Browse Source

Add 'local mode' tutorials

shaobo 1 year ago
parent
commit
81ff1bea82

+ 37 - 0
local_mode/README.md

@@ -0,0 +1,37 @@
+![](../resources/logo.jpeg)
+
+[English](README.md) | [中文](README_zh.md)
+
+## Local Mode
+
+The new version of the CodeGeeX plugin **supports offline mode**, allowing the use of offline deployed models to complete automatic
+completion and simple conversation functions.
+
+## Usage Tutorial
+
+### 1. Install Dependencies
+
+```bash
+cd local_mode
+pip install -r requirements.txt
+```
+
+### 2. Run the Project
+
+```bash
+python main.py --model_name_or_path THUDM/codegeex4-all-9b --device cuda --bf16 true
+>>> Running on local URL:  http://127.0.0.1:8080
+```
+
+### 3. Set API Address and Model Name
+
+As shown in the figure below, after opening the plugin with the local mode, enter the API address and model name in the settings.
+![](resources/pic1.png)
+
+### 4. Start Using
+
+Click 'Connect' to test, or click 'Ask CodeGeeX' to start using.
+
+## Demo
+
+![](resources/demo.gif)

+ 37 - 0
local_mode/README_zh.md

@@ -0,0 +1,37 @@
+![](../resources/logo.jpeg)
+
+[English](README.md) | [中文](README_zh.md)
+
+## 本地模式
+
+CodeGeeX新版插件**支持离线模式**,可使用离线部署的模型完成自动补全以及简单对话功能。
+
+## 使用教程
+
+### 1. 安装依赖项
+
+```bash
+cd local_mode
+pip install -r requirements.txt
+```
+
+### 2. 运行项目
+
+```bash
+python main.py --model_name_or_path THUDM/codegeex4-all-9b --device cuda --bf16 true
+
+>>> Running on local URL:  http://127.0.0.1:8080
+```
+
+### 3. 设置api地址和模型名称
+
+如下图所示,打开插件后进入本地模式,在设置中输入api地址和模型名称。
+![](resources/pic1.png)
+
+### 4. 开始使用
+
+点击‘连接’进行测试,或点击‘Ask CodeGeeX’即可开始使用。
+
+## Demo
+
+![](resources/demo_zh.gif)

+ 51 - 0
local_mode/main.py

@@ -0,0 +1,51 @@
+"""
+coding   : utf-8
+@Date    : 2024/7/10
+@Author  : Shaobo
+@Describe: 
+"""
+import argparse
+
+import torch
+import uvicorn
+from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import JSONResponse
+from starlette.responses import StreamingResponse
+
+from protocols.openai_api import ChatCompletionRequest
+from services.chat import init_model, chat_with_codegeex, stream_chat_with_codegeex
+
+app = FastAPI()
+
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+
+def parse_arguments():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model_name_or_path", type=str, default="THUDM/codegeex4-all-9b")
+    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
+    parser.add_argument("--bf16", type=bool, default=False)
+    return parser.parse_args()
+
+
[email protected]("/v1/chat/completions")
+async def chat(request: ChatCompletionRequest):
+    try:
+        if request.stream:
+            return StreamingResponse(stream_chat_with_codegeex(request), media_type="text/event-stream")
+        else:
+            return JSONResponse(chat_with_codegeex(request))
+    except Exception as e:
+        return JSONResponse(e, status_code=500)
+
+
+if __name__ == "__main__":
+    args = parse_arguments()
+    init_model(args)
+    uvicorn.run(app, host="127.0.0.1", port=8080)

+ 78 - 0
local_mode/models/codegeex.py

@@ -0,0 +1,78 @@
+"""
+coding   : utf-8
+@Date    : 2024/7/10
+@Author  : Shaobo
+@Describe: 
+"""
+
+import torch
+from protocols.openai_api import ChatCompletionRequest, ChatCompletionStreamResponse, ChatCompletionResponse
+from sseclient import Event
+from transformers import AutoTokenizer, AutoModel
+
+SYS_PROMPT = "你是一位智能编程助手,你叫CodeGeeX。你会为用户回答关于编程、代码、计算机方面的任何问题,并提供格式规范、可以执行、准确安全的代码,并在必要时提供详细的解释。"
+
+
+class CodegeexChatModel:
+    def __init__(self, args):
+        self.tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)
+        if args.bf16:
+            self.model = AutoModel.from_pretrained(
+                args.model_name_or_path,
+                trust_remote_code=True,
+                torch_dtype=torch.bfloat16,
+            ).to(args.device).eval()
+        else:
+            self.model = AutoModel.from_pretrained(
+                args.model_name_or_path,
+                trust_remote_code=True
+            ).to(args.device).eval()
+        print("Model is initialized.")
+
+    def stream_chat(self, request: ChatCompletionRequest):
+        try:
+            length = 0
+            for i, (response, _) in enumerate(self.model.stream_chat(
+                    self.tokenizer,
+                    query=request.messages[-1].content,
+                    history=[msg.model_dump() for msg in request.messages[:-1]],
+                    max_new_tokens=request.max_tokens,
+                    temperature=request.temperature,
+                    top_p=request.top_p,
+                    repetition_penalty=request.presence_penalty
+            )):
+                resp = ChatCompletionStreamResponse()
+                resp.choices[0].index = i
+                resp.choices[0].delta.content = response[length:]
+                event = Event(id=resp.id, data=resp.json(), event='message')
+                yield event.dump()
+                length = len(response)
+            resp = ChatCompletionStreamResponse()
+            resp.choices[0].finish_reason = 'stop'
+            event = Event(id=resp.id, data=resp.json(), event='message')
+            yield event.dump()
+        except Exception as e:
+            resp = ChatCompletionStreamResponse()
+            resp.choices[0].finish_reason = 'stop'
+            event = Event(id=resp.id, data=f"请求报错,错误原因:{e}", event='message')
+            yield event.dump()
+
+    def chat(self, request: ChatCompletionRequest):
+        try:
+            response, _ = self.model.chat(
+                self.tokenizer,
+                query=request.messages[0].content,
+                history=[msg.model_dump() for msg in request.messages[:-1]],
+                max_new_tokens=request.max_tokens,
+                temperature=request.temperature,
+                top_p=request.top_p,
+                repetition_penalty=request.presence_penalty
+            )
+            resp = ChatCompletionResponse()
+            resp.choices[0].message.content = response
+            resp.choices[0].finish_reason = 'stop'
+            # event = Event(id=resp.id, data=resp.json(), event='message')
+            # return event.dump()
+            return resp.model_dump()
+        except Exception as e:
+            return f"请求报错,错误原因:{e}"

+ 61 - 0
local_mode/protocols/openai_api.py

@@ -0,0 +1,61 @@
+"""
+coding   : utf-8
+@Date    : 2024/7/11
+@Author  : Shaobo
+@Describe: 
+"""
+import time
+from typing import Literal
+
+import shortuuid
+from pydantic import BaseModel
+
+
+class ChatMessage(BaseModel):
+    role: str
+    content: str
+
+
+class ChatCompletionRequest(BaseModel):
+    model: str = "codegeex4"
+    messages: list[ChatMessage]
+    temperature: float = 0.2
+    top_p: float = 1.0
+    max_tokens: int = 1024
+    stop: list[str] = ['<|user|>', '<|assistant|>', '<|observation|>', '<|endoftext|>']
+    stream: bool = True
+    presence_penalty: float = None
+
+
+class DeltaMessage(BaseModel):
+    role: str
+    content: str
+
+
+class ChatCompletionResponseStreamChoice(BaseModel):
+    index: int = 0
+    delta: DeltaMessage = DeltaMessage(role='assistant', content='')
+    finish_reason: Literal["stop", "length"] = None
+
+
+class ChatCompletionStreamResponse(BaseModel):
+    id: str = f"chatcmpl-{shortuuid.random()}"
+    object: str = "chat.completion.chunk"
+    created: int = int(time.time())
+    model: str = "codegeex4"
+    choices: list[ChatCompletionResponseStreamChoice] = [ChatCompletionResponseStreamChoice()]
+
+
+class ChatCompletionResponseChoice(BaseModel):
+    index: int = 0
+    message: ChatMessage = ChatMessage(role="assistant", content="")
+    finish_reason: Literal["stop", "length"] = None
+
+
+class ChatCompletionResponse(BaseModel):
+    id: str = f"chatcmpl-{shortuuid.random()}"
+    object: str = "chat.completion"
+    created: int = int(time.time())
+    model: str = "codegeex4"
+    choices: list[ChatCompletionResponseChoice] = [ChatCompletionResponseChoice()]
+    # usage: UsageInfo

+ 13 - 0
local_mode/requirements.txt

@@ -0,0 +1,13 @@
+accelerate==0.31.0
+fastapi==0.111.0
+openai==1.35.12
+pydantic==2.8.2
+regex==2024.5.15
+requests==2.32.3
+shortuuid==1.0.13
+sseclient==0.0.27
+starlette==0.37.2
+tiktoken==0.7.0
+torch==2.3.1
+transformers==4.39.0
+uvicorn==0.30.1

BIN
local_mode/resources/demo.gif


BIN
local_mode/resources/demo_zh.gif


BIN
local_mode/resources/pic1.png


+ 22 - 0
local_mode/services/chat.py

@@ -0,0 +1,22 @@
+"""
+coding   : utf-8
+@Date    : 2024/7/10
+@Author  : Shaobo
+@Describe: 
+"""
+from models.codegeex import CodegeexChatModel
+
+model: CodegeexChatModel
+
+
+def stream_chat_with_codegeex(request):
+    yield from model.stream_chat(request)
+
+
+def chat_with_codegeex(request):
+    return model.chat(request)
+
+
+def init_model(args):
+    global model
+    model = CodegeexChatModel(args)