| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051 |
- """
- coding : utf-8
- @Date : 2024/7/10
- @Author : Shaobo
- @Describe:
- """
import argparse
import logging

import torch
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from starlette.responses import StreamingResponse

from protocols.openai_api import ChatCompletionRequest
from services.chat import init_model, chat_with_codegeex, stream_chat_with_codegeex
app = FastAPI()

# Wide-open CORS policy: requests from any origin, using any HTTP method and
# carrying any request header, are accepted. allow_credentials is not set here.
# NOTE(review): with "*" origins, any web page the user visits can call this
# local inference server from the browser; narrow the list if that matters.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
def parse_arguments(argv=None):
    """Parse the command-line options for the CodeGeeX4 API server.

    Args:
        argv: Sequence of argument strings to parse. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, matching the original
            zero-argument behaviour.

    Returns:
        An ``argparse.Namespace`` with ``model_name_or_path`` (str),
        ``device`` (str) and ``bf16`` (bool).

    Raises:
        SystemExit: raised by argparse on invalid options, e.g. a ``--bf16``
            value that is not a recognised boolean spelling.
    """

    def _str_to_bool(value):
        # argparse's ``type=bool`` is a trap: bool("False") is True because any
        # non-empty string is truthy, so ``--bf16 False`` used to enable bf16.
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y", "on"):
            return True
        if lowered in ("0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name_or_path", type=str, default="THUDM/codegeex4-all-9b")
    # Default to CUDA when torch reports it is available, otherwise CPU.
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    # Accepts a bare ``--bf16`` (-> True) as well as ``--bf16 true|false``.
    parser.add_argument("--bf16", type=_str_to_bool, nargs="?", const=True, default=False)
    return parser.parse_args(argv)
@app.post("/v1/chat/completions")
async def chat(request: ChatCompletionRequest):
    """Handle a chat completion request on ``/v1/chat/completions``.

    If ``request.stream`` is truthy, the reply is returned as a
    ``text/event-stream`` ``StreamingResponse`` over
    ``stream_chat_with_codegeex(request)``. Otherwise the value returned by
    ``chat_with_codegeex(request)`` is sent as a JSON body.

    Exceptions raised while building the response are logged with their
    traceback and reported as HTTP 500 with a JSON body of the form
    ``{"error": {"message": ..., "type": ...}}``.
    """
    try:
        if request.stream:
            # NOTE(review): the streamed body is iterated by the server after
            # this function returns, so errors raised mid-stream are not
            # caught by the except clause below.
            return StreamingResponse(stream_chat_with_codegeex(request), media_type="text/event-stream")
        # NOTE(review): chat_with_codegeex is called synchronously inside an
        # async endpoint; if it blocks (model inference), it stalls the event
        # loop for all other requests — confirm and consider a threadpool.
        return JSONResponse(chat_with_codegeex(request))
    except Exception as e:
        # Top-level request boundary: keep the traceback in the logs and return
        # a JSON-serializable error. Passing the exception object itself to
        # JSONResponse (as before) fails to encode and raises inside the handler.
        logging.getLogger(__name__).exception("chat completion request failed")
        return JSONResponse(
            {"error": {"message": str(e), "type": type(e).__name__}},
            status_code=500,
        )
def _serve():
    """Script entry point: parse CLI options, load the model, then run the server."""
    options = parse_arguments()
    # The model must be initialised before uvicorn starts accepting requests.
    init_model(options)
    # Listens on the loopback interface only, on a fixed port.
    uvicorn.run(app, host="127.0.0.1", port=8080)


if __name__ == "__main__":
    _serve()
|