# main.py
"""
@Date: 2024/7/10
@Author: Shaobo
@Describe: FastAPI server exposing an OpenAI-style /v1/chat/completions
endpoint (streaming and non-streaming) backed by the CodeGeeX4 model.
"""
import argparse
import logging

import torch
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from starlette.responses import StreamingResponse

from protocols.openai_api import ChatCompletionRequest
from services.chat import init_model, chat_with_codegeex, stream_chat_with_codegeex
  16. app = FastAPI()
  17. app.add_middleware(
  18. CORSMiddleware,
  19. allow_origins=["*"],
  20. allow_methods=["*"],
  21. allow_headers=["*"],
  22. )
  23. def parse_arguments():
  24. parser = argparse.ArgumentParser()
  25. parser.add_argument("--model_name_or_path", type=str, default="THUDM/codegeex4-all-9b")
  26. parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
  27. parser.add_argument("--bf16", type=bool, default=False)
  28. return parser.parse_args()
  29. @app.post("/v1/chat/completions")
  30. async def chat(request: ChatCompletionRequest):
  31. try:
  32. if request.stream:
  33. return StreamingResponse(stream_chat_with_codegeex(request), media_type="text/event-stream")
  34. else:
  35. return JSONResponse(chat_with_codegeex(request))
  36. except Exception as e:
  37. return JSONResponse(e, status_code=500)
  38. if __name__ == "__main__":
  39. args = parse_arguments()
  40. init_model(args)
  41. uvicorn.run(app, host="127.0.0.1", port=8080)