
Merge remote-tracking branch 'origin/main'

shaobo 1 year ago
Parent
Commit
53dbb39321

+ 13 - 0
README.md

@@ -87,6 +87,19 @@ python -m vllm.entrypoints.openai.api_server \
      --trust_remote_code
 ```
 
+### Rust-candle
+CodeGeeX4 now supports the Candle framework: [Repo](https://github.com/huggingface/candle/blob/main/candle-examples/examples/codegeex4-9b/README.org)
+#### CLI
+Use Rust to launch [codegeex4-all-9b](https://huggingface.co/THUDM/codegeex4-all-9b):
+```shell
+cd candle_demo
+cargo build -p codegeex4-cli --release --features cuda # for CUDA
+cargo build -p codegeex4-cli --release # for CPU
+./target/release/codegeex4-cli --sample-len 512
+```
+
+
+
 ## Tutorials
 CodeGeeX4-ALL-9B provides three user guides to help users quickly understand and use the model:
 

+ 10 - 0
README_zh.md

@@ -89,6 +89,16 @@ python -m vllm.entrypoints.openai.api_server \
      --trust_remote_code
 ```
 
+### Rust-candle
+Codegeex4现已支持Candle框架 [Repo](https://github.com/huggingface/candle/blob/main/candle-examples/examples/codegeex4-9b/README.org)
+
+Use Rust to launch [codegeex4-all-9b](https://huggingface.co/THUDM/codegeex4-all-9b):
+```shell
+cd candle_demo
+cargo build -p codegeex4-cli --release --features cuda # for CUDA
+cargo build -p codegeex4-cli --release # for CPU
+./target/release/codegeex4-cli --sample-len 512
+```
 ## 用户指南
 我们为 CodeGeeX4-ALL-9B 提供了用户指南,帮助用户快速了解和使用该模型:
 

+ 2 - 0
candle_demo/.gitignore

@@ -0,0 +1,2 @@
+target
+Cargo.lock

+ 30 - 0
candle_demo/Cargo.toml

@@ -0,0 +1,30 @@
+[workspace]
+members = ["cli", "codegeex4", "api-server"]
+resolver = "2"
+[workspace.package]
+version = "0.1.0"
+edition = "2021"
+authors = ["Donjuan Platinum <[email protected]>"]
+license = "GPL-2.0-only"
+description = "Codegeex4"
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[workspace.dependencies]
+hf-hub = "0.3.2"
+clap = { version = "4.5.6", features = ["derive"] }
+tokenizers = "0.19.1"
+serde_json = "1.0.120"
+candle-core = "0.6.0"
+candle-transformers = "0.6.0"
+candle-examples = "0.6.0"
+candle-nn = "0.6.0"
+safetensors = "0.4.3"
+accelerate-src = { version = "0.3.2"}
+intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
+rand = "0.8.5"
+owo-colors = "4.0.0"
+codegeex4 = {path = "./codegeex4"}
+
+
+

+ 88 - 0
candle_demo/README.org

@@ -0,0 +1,88 @@
+* candle-codegeex4_9b
+THUDM/CodeGeeX4 is a versatile model for all AI software development scenarios, including code completion, code interpreter, web search, function calling, repository-level Q&A and much more.
+[[../resources/candle_example.png][file:../resources/candle_example.png]]
+
+- [[https://github.com/THUDM/CodeGeeX4][Github]]
+- [[https://codegeex.cn/][HomePage]]
+- [[https://huggingface.co/THUDM/codegeex4-all-9b][HuggingFace]]
+- [[https://github.com/huggingface/candle/blob/main/candle-examples/examples/codegeex4-9b/README.org][Candle]]
+
+- The OpenAI-compatible API server is currently under development.
+** CLI
+#+begin_src shell
+  cargo build --release -p codegeex4-cli # CPU
+  cargo build --release -p codegeex4-cli --features cuda # if CUDA is available
+  ./target/release/codegeex4-cli --sample-len 500
+#+end_src
+** Output Example
+#+begin_src shell
+  avx: false, neon: false, simd128: false, f16c: false
+  temp: 0.95 repeat-penalty: 1.10 repeat-last-n: 64
+  cache path /root/autodl-tmp
+  Prompt: [please write a FFT in rust]
+  Using Seed 11511762269791786684
+  DType is BF16
+  creating transformer layers
+  Model loaded in 4s
+  starting the inference loop
+
+  sample_len 500
+
+  500 tokens generated (34.60 token/s)
+  Result:
+
+  Sure, I can help you with that. Here's an example of a Fast Fourier Transform (FFT) implementation in Rust:
+
+  ```rust
+  use num_complex::Complex;
+
+  fn fft(input: &[Complex<f64>]) -> Vec<Complex<f64>> {
+      let n = input.len();
+
+      if n == 1 {
+          return vec![input[0]];
+      }
+
+      let mut even = vec![];
+      let mut odd = vec![];
+
+      for i in 0..n {
+          if i % 2 == 0 {
+              even.push(input[i]);
+          } else {
+              odd.push(input[i]);
+          }
+      }
+
+      let even_fft = fft(&even);
+      let odd_fft = fft(&odd);
+
+      let mut output = vec![];
+
+      for k in 0..n / 2 {
+          let t = Complex::new(0.0, -2.0 * std::f64::consts::PI * (k as f64) / (n as f64)).exp();
+
+          output.push(even_fft[k] + odd_fft[k] * t);
+          output.push(even_fft[k] - odd_fft[k] * t);
+      }
+
+      return output;
+  }
+  ```
+
+  This implementation uses the Cooley-Tukey algorithm to perform the FFT. The function takes an array of complex numbers and returns an array of complex numbers which is the result of the FFT.
+#+end_src
+
+
+* Citation
+#+begin_src bibtex
+  @inproceedings{zheng2023codegeex,
+  title={CodeGeeX: A Pre-Trained Model for Code Generation with Multilingual Benchmarking on HumanEval-X},
+  author={Qinkai Zheng and Xiao Xia and Xu Zou and Yuxiao Dong and Shan Wang and Yufei Xue and Zihan Wang and Lei Shen and Andi Wang and Yang Li and Teng Su and Zhilin Yang and Jie Tang},
+  booktitle={Proceedings of the 29th ACM SIGKDD Conference on Knowledge Discovery and Data Mining},
+  pages={5673--5684},
+  year={2023}
+}
+#+end_src

+ 37 - 0
candle_demo/api-server/Cargo.toml

@@ -0,0 +1,37 @@
+[package]
+name = "api-server"
+version.workspace = true
+edition.workspace = true
+authors.workspace = true
+license.workspace = true
+description.workspace = true
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+hf-hub = {workspace = true}
+clap = { workspace = true}
+tokenizers = {workspace = true}
+serde_json.workspace = true
+candle-core = {workspace = true}
+candle-transformers = {workspace = true}
+candle-examples = {workspace = true}
+candle-nn = {workspace = true}
+safetensors = {workspace = true}
+accelerate-src = { workspace = true, optional = true}
+intel-mkl-src = { workspace = true ,optional = true}
+rand = { workspace = true}
+owo-colors = {workspace = true}
+codegeex4 = {workspace = true}
+tokio = {version = "1.39.1", features = ["full"]}
+actix-web = "4.8.0"
+serde = { version = "1.0.204", features = ["derive"] }
+short-uuid = "0.1.2"
+[build-dependencies]
+bindgen_cuda = { version = "0.1.1", optional = true }
+[features]
+default = []
+cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda", "dep:bindgen_cuda"]
+accelerate = ["dep:accelerate-src", "candle-core/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
+mkl = ["dep:intel-mkl-src", "candle-core/mkl", "candle-nn/mkl", "candle-transformers/mkl"]

+ 123 - 0
candle_demo/api-server/src/api.rs

@@ -0,0 +1,123 @@
+use actix_web::{
+    get, post,
+    web::{self, Data},
+    HttpRequest, Responder,
+};
+use owo_colors::OwoColorize;
+use serde::{Deserialize, Serialize};
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct ChatMessage {
+    pub role: String,
+    pub content: String,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct ChatCompletionRequest {
+    pub model: String,
+    pub messages: Vec<ChatMessage>,
+    pub temperature: f64,
+    pub top_p: f64,
+    pub max_tokens: usize,
+    pub stop: Vec<String>,
+    pub stream: bool,
+    pub presence_penalty: Option<f32>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct DeltaMessage {
+    pub role: String,
+    pub content: String,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct ChatCompletionResponseStreamChoice {
+    pub index: i32,
+    pub delta: DeltaMessage,
+    pub finish_reason: Option<String>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct ChatCompletionStreamResponse {
+    pub id: String,
+    pub object: String,
+    pub created: i32,
+    pub model: String,
+    pub choices: Vec<ChatCompletionResponseStreamChoice>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct ChatCompletionResponseChoice {
+    pub index: i32,
+    pub message: ChatMessage,
+    pub finish_reason: Option<FinishReason>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct ChatCompletionResponse {
+    pub id: String,
+    pub object: String,
+    pub created: u64,
+    pub model: String,
+    pub choices: Vec<ChatCompletionResponseChoice>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub enum FinishReason {
+    STOP,
+    LENGTH,
+}
+use std::time::{SystemTime, UNIX_EPOCH};
+impl ChatCompletionResponse {
+    pub fn empty() -> Self {
+        let current_time = SystemTime::now();
+        Self {
+            id: format!("chatcmpl-{}", short_uuid::ShortUuid::generate()),
+            object: "chat.completion".to_string(),
+            created: current_time
+                .duration_since(UNIX_EPOCH)
+                .expect("failed to get time")
+                .as_secs()
+                .into(),
+            model: "codegeex4".to_string(),
+            choices: vec![ChatCompletionResponseChoice::empty()],
+        }
+    }
+}
+
+impl ChatCompletionResponseChoice {
+    pub fn empty() -> Self {
+        Self {
+            index: 0,
+            message: ChatMessage {
+                role: "assistant".to_string(),
+                content: "".to_string(),
+            },
+            finish_reason: None,
+        }
+    }
+}
+
+impl ChatCompletionRequest {
+    pub fn empty() -> Self {
+        Self {
+            model: "codegeex4".to_string(),
+            messages: vec![ChatMessage {
+                role: "assistant".to_string(),
+                content: "".to_string(),
+            }],
+            temperature: 0.2_f64,
+            top_p: 0.2_f64,
+            max_tokens: 1024_usize,
+            stop: vec![
+                "<|user|>".to_string(),
+                "<|assistant|>".to_string(),
+                "<|observation|>".to_string(),
+                "<|endoftext|>".to_string(),
+            ],
+            stream: true,
+            presence_penalty: None,
+        }
+    }
+}
+
+// impl DeltaMessage {
+//     pub fn new() -> Self {
+// 	role:
+//     }
+// }
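
These types mirror the OpenAI chat-completion schema and all derive `Serialize`/`Deserialize`. As a minimal sketch (not part of this commit), they round-trip through `serde_json`, which is already a workspace dependency; the function name and field values below are illustrative only:

```rust
// Sketch only: assumes the `api` module above and the workspace `serde_json` crate.
use crate::api::{ChatCompletionRequest, ChatMessage};

fn example_request_body() -> serde_json::Result<String> {
    let request = ChatCompletionRequest {
        model: "codegeex4".to_string(),
        messages: vec![ChatMessage {
            role: "user".to_string(),
            content: "write a quicksort in python".to_string(),
        }],
        temperature: 0.2,
        top_p: 0.95,
        max_tokens: 1024,
        stop: vec!["<|endoftext|>".to_string()],
        stream: false,
        presence_penalty: None,
    };
    // Serializes to an OpenAI-style chat.completion request body.
    serde_json::to_string_pretty(&request)
}
```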

+ 10 - 0
candle_demo/api-server/src/args.rs

@@ -0,0 +1,10 @@
+use clap::Parser;
+
+#[derive(Parser, Debug, Clone)]
+#[clap(version, about)]
+pub struct Args {
+    #[arg(name = "listen", short, long, default_value = "0.0.0.0:3000")]
+    pub address: String,
+    #[arg(short, long, default_value_t = 1)]
+    pub workers: usize,
+}

+ 19 - 0
candle_demo/api-server/src/main.rs

@@ -0,0 +1,19 @@
+mod api;
+mod args;
+mod server;
+mod model;
+use clap::Parser;
+use owo_colors::OwoColorize;
+
+#[tokio::main]
+async fn main() {
+    let args = args::Args::parse();
+    let server = server::Server::new(args.clone());
+    println!(
+        "{} Server Binding On {} with {} workers",
+        "[INFO]".green(),
+        &args.address.purple(),
+        &args.workers.purple()
+    );
+    server.run().await;
+}

+ 6 - 0
candle_demo/api-server/src/model.rs

@@ -0,0 +1,6 @@
+use crate::api::ChatCompletionRequest;
+use codegeex4::codegeex4::Config;
+
+// Placeholder for the streaming chat handler: it only loads the default model
+// configuration for now; generation is not wired up yet.
+fn stream_chat(_request: ChatCompletionRequest) {
+    let _default_config = Config::codegeex4();
+}

+ 31 - 0
candle_demo/api-server/src/server.rs

@@ -0,0 +1,31 @@
+use crate::args::Args;
+use actix_web::{web, App, HttpResponse, HttpServer};
+use owo_colors::OwoColorize;
+
+#[derive(Debug)]
+pub struct Server {
+    config: Args,
+}
+
+impl Server {
+    pub fn new(config: Args) -> Self {
+        Server { config }
+    }
+    pub async fn run(&self) {
+        HttpServer::new(move || App::new())
+            .bind(&self.config.address)
+            .expect(&format!("{}", "Unable To Bind Server !".red()))
+            .workers(self.config.workers)
+            .run()
+            .await
+            .expect(&format!("{}", "Unable To Run the Server !".red()));
+    }
+}
+
+// use super::api::*;
+// use uuid;
+// pub async fn chat(request: ChatCompletionRequest) ->impl Responder {
+//     if request.stream == true {
+// 	return Htt
+//     }
+// }
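
The server currently starts with an empty `App`; the commented-out handler above hints at the planned chat endpoint. Below is a minimal sketch (not part of this commit) of how such a route could be registered with actix-web, reusing the `api` types from this commit; the route path and handler name are assumptions:

```rust
// Sketch only: a hypothetical /v1/chat/completions route built on the api module.
use actix_web::{post, web, App, HttpResponse, HttpServer, Responder};

use crate::api::{ChatCompletionRequest, ChatCompletionResponse};

#[post("/v1/chat/completions")]
async fn chat_completions(request: web::Json<ChatCompletionRequest>) -> impl Responder {
    let _request = request.into_inner();
    // A real handler would run the model here; this just returns an empty response.
    HttpResponse::Ok().json(ChatCompletionResponse::empty())
}

// Registration would replace `App::new()` inside Server::run with:
// HttpServer::new(move || App::new().service(chat_completions))
```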

+ 33 - 0
candle_demo/cli/Cargo.toml

@@ -0,0 +1,33 @@
+[package]
+name = "codegeex4-cli"
+version.workspace = true
+edition.workspace = true
+authors.workspace = true
+license.workspace = true
+description.workspace = true
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+hf-hub = {workspace = true}
+clap = { workspace = true}
+tokenizers = {workspace = true}
+serde_json = {workspace = true}
+candle-core = {workspace = true}
+candle-transformers = {workspace = true}
+candle-examples = {workspace = true}
+candle-nn = {workspace = true}
+safetensors = {workspace = true}
+accelerate-src = { workspace = true, optional = true}
+intel-mkl-src = { workspace = true ,optional = true}
+rand = { workspace = true}
+owo-colors = {workspace = true}
+codegeex4 = {workspace = true}
+
+[build-dependencies]
+bindgen_cuda = { version = "0.1.1", optional = true }
+[features]
+default = []
+cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda", "dep:bindgen_cuda"]
+accelerate = ["dep:accelerate-src", "candle-core/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
+mkl = ["dep:intel-mkl-src", "candle-core/mkl", "candle-nn/mkl", "candle-transformers/mkl"]

+ 99 - 0
candle_demo/cli/src/main.rs

@@ -0,0 +1,99 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use candle_core as candle;
+use candle_core::DType;
+use candle_nn::VarBuilder;
+use clap::Parser;
+use codegeex4::args::Args;
+use codegeex4::codegeex4::*;
+use codegeex4::TextGeneration;
+use hf_hub::{Repo, RepoType};
+use owo_colors::{self, OwoColorize};
+use rand::Rng;
+use tokenizers::Tokenizer;
+
+fn main() -> Result<(), ()> {
+    let args = Args::parse();
+    println!(
+        "avx: {}, neon: {}, simd128: {}, f16c: {}",
+        candle::utils::with_avx().red(),
+        candle::utils::with_neon().red(),
+        candle::utils::with_simd128().red(),
+        candle::utils::with_f16c().red(),
+    );
+    println!(
+        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
+        args.temperature.unwrap_or(0.95).red(),
+        args.repeat_penalty.red(),
+        args.repeat_last_n.red(),
+    );
+
+    println!("cache path {}", args.cache_path.blue());
+    let seed: u64 = match args.seed {
+        Some(seed) => seed,
+        None => rand::thread_rng().gen(),
+    };
+    println!("Using Seed {}", seed.red());
+    let api = hf_hub::api::sync::ApiBuilder::from_cache(hf_hub::Cache::new(args.cache_path.into()))
+        .build()
+        .unwrap();
+
+    let model_id = match args.model_id {
+        Some(model_id) => model_id.to_string(),
+        None => "THUDM/codegeex4-all-9b".to_string(),
+    };
+    let revision = match args.revision {
+        Some(rev) => rev.to_string(),
+        None => "main".to_string(),
+    };
+    let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
+    let tokenizer_filename = match args.tokenizer {
+        Some(file) => std::path::PathBuf::from(file),
+        None => api
+            .model("THUDM/codegeex4-all-9b".to_string())
+            .get("tokenizer.json")
+            .unwrap(),
+    };
+    let filenames = match args.weight_file {
+        Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
+        None => {
+            candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json").unwrap()
+        }
+    };
+    let tokenizer = Tokenizer::from_file(tokenizer_filename).expect("Tokenizer Error");
+    let start = std::time::Instant::now();
+    let config = Config::codegeex4();
+    let device = candle_examples::device(args.cpu).unwrap();
+    let dtype = if device.is_cuda() {
+        DType::BF16
+    } else {
+        DType::F32
+    };
+    println!("DType is {:?}", dtype.yellow());
+    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device).unwrap() };
+    let model = Model::new(&config, vb).unwrap();
+
+    println!("模型加载完毕 {:?}", start.elapsed().as_secs().green());
+
+    let mut pipeline = TextGeneration::new(
+        model,
+        tokenizer,
+        seed,
+        args.temperature,
+        args.top_p,
+        args.repeat_penalty,
+        args.repeat_last_n,
+        args.verbose_prompt,
+        &device,
+        dtype,
+    );
+    pipeline.run(args.sample_len)?;
+    Ok(())
+}

+ 34 - 0
candle_demo/codegeex4/Cargo.toml

@@ -0,0 +1,34 @@
+[package]
+name = "codegeex4"
+version.workspace = true
+edition.workspace = true
+authors.workspace = true
+license.workspace = true
+description.workspace = true
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+hf-hub = {workspace = true}
+clap = { workspace = true}
+tokenizers = {workspace = true}
+serde_json = {workspace = true}
+candle-core = {workspace = true}
+candle-transformers = {workspace = true}
+candle-examples = {workspace = true}
+candle-nn = {workspace = true}
+safetensors = {workspace = true}
+accelerate-src = { workspace = true, optional = true}
+intel-mkl-src = { workspace = true ,optional = true}
+rand = { workspace = true}
+owo-colors = {workspace = true}
+
+
+
+[build-dependencies]
+bindgen_cuda = { version = "0.1.1", optional = true }
+[features]
+default = []
+cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda", "dep:bindgen_cuda"]
+accelerate = ["dep:accelerate-src", "candle-core/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
+mkl = ["dep:intel-mkl-src", "candle-core/mkl", "candle-nn/mkl", "candle-transformers/mkl"]

+ 51 - 0
candle_demo/codegeex4/src/args.rs

@@ -0,0 +1,51 @@
+use clap::Parser;
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+pub struct Args {
+    /// Local cache directory for model weights and tokenizer files.
+    #[arg(name = "cache", short, long, default_value = ".")]
+    pub cache_path: String,
+
+    /// Run on CPU rather than on GPU.
+    #[arg(long)]
+    pub cpu: bool,
+
+    /// Display the token for the specified prompt.
+    #[arg(long)]
+    pub verbose_prompt: bool,
+
+    /// The temperature used to generate samples.
+    #[arg(long)]
+    pub temperature: Option<f64>,
+
+    /// Nucleus sampling probability cutoff.
+    #[arg(long)]
+    pub top_p: Option<f64>,
+
+    /// The seed to use when generating random samples.
+    #[arg(long)]
+    pub seed: Option<u64>,
+
+    /// The length of the sample to generate (in tokens).
+    #[arg(long, short = 'n', default_value_t = 5000)]
+    pub sample_len: usize,
+
+    #[arg(long)]
+    pub model_id: Option<String>,
+
+    #[arg(long)]
+    pub revision: Option<String>,
+
+    #[arg(long)]
+    pub weight_file: Option<String>,
+
+    #[arg(long)]
+    pub tokenizer: Option<String>,
+
+    /// Penalty to be applied for repeating tokens, 1. means no penalty.
+    #[arg(long, default_value_t = 1.1)]
+    pub repeat_penalty: f32,
+
+    /// The context size to consider for the repeat penalty.
+    #[arg(long, default_value_t = 64)]
+    pub repeat_last_n: usize,
+}

+ 601 - 0
candle_demo/codegeex4/src/codegeex4.rs

@@ -0,0 +1,601 @@
+use candle_core as candle;
+use candle_core::{DType, Device, IndexOp, Module, Result, Tensor, D};
+use candle_nn::VarBuilder;
+use candle_transformers::models::with_tracing::{linear_b as linear, Linear};
+
+#[derive(Debug, Clone)]
+pub struct Config {
+    pub num_layers: usize,
+    pub padded_vocab_size: usize,
+    pub hidden_size: usize,
+    pub ffn_hidden_size: usize,
+    pub kv_channels: usize,
+    pub num_attention_heads: usize,
+    pub seq_length: usize,
+    pub layernorm_epsilon: f64,
+    pub rmsnorm: bool,
+    pub apply_residual_connection_post_layernorm: bool,
+    pub post_layer_norm: bool,
+    pub add_bias_linear: bool,
+    pub add_qkv_bias: bool,
+    pub bias_dropout_fusion: bool,
+    pub multi_query_attention: bool,
+    pub multi_query_group_num: usize,
+    pub apply_query_key_layer_scaling: bool,
+    pub attention_softmax_in_fp32: bool,
+    pub fp32_residual_connection: bool,
+}
+
+impl Config {
+    pub fn codegeex4() -> Self {
+        Self {
+            num_layers: 40,
+            padded_vocab_size: 151552,
+            hidden_size: 4096,
+            ffn_hidden_size: 13696,
+            kv_channels: 128,
+            num_attention_heads: 32,
+            seq_length: 131072,
+            layernorm_epsilon: 1e-5,
+            rmsnorm: true,
+            apply_residual_connection_post_layernorm: false,
+            post_layer_norm: true,
+            add_bias_linear: false,
+            add_qkv_bias: true,
+            bias_dropout_fusion: true,
+            multi_query_attention: true,
+            multi_query_group_num: 2,
+            apply_query_key_layer_scaling: true,
+            attention_softmax_in_fp32: true,
+            fp32_residual_connection: false,
+        }
+    }
+}
+
+#[derive(Debug, Clone)]
+struct RotaryEmbedding {
+    cache: Tensor,
+}
+
+impl RotaryEmbedding {
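+    // Precomputes a per-position cache of cos/sin rotary frequencies so that
+    // `apply` only has to slice out the rows for a given sequence offset.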
+    fn new(cfg: &Config, dtype: DType, dev: &Device) -> Result<Self> {
+        let rotary_dim = cfg.kv_channels;
+        let n_elem = rotary_dim / 2;
+        let inv_freq: Vec<_> = (0..n_elem)
+            .step_by(2)
+            .map(|i| 1f32 / 10_000f64.powf(i as f64 / n_elem as f64) as f32)
+            .collect();
+        let inv_freq_len = inv_freq.len();
+        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
+        let t = Tensor::arange(0u32, cfg.seq_length as u32, dev)?
+            .to_dtype(dtype)
+            .expect("unalbe to dytpe in Rotray Embedding new")
+            .reshape((cfg.seq_length, 1))?;
+        let freqs = t.matmul(&inv_freq)?;
+        let cache = Tensor::stack(&[&freqs.cos()?, &freqs.sin()?], D::Minus1)?;
+        Ok(Self { cache })
+    }
+
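+    // Applies rotary embeddings to the first `rot_dim` channels of `xs`
+    // (laid out as [seq_len, batch, heads, head_dim]), starting at
+    // `seqlen_offset` so positions already in the KV cache are skipped;
+    // the remaining channels pass through unchanged.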
+    fn apply(&self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
+        let (seqlen, _b, np, _hn) = xs.dims4()?;
+        let cache = self.cache.narrow(0, seqlen_offset, seqlen)?;
+        let rot_dim = cache.dim(D::Minus2)? * 2;
+        let (xs, xs_pass) = (
+            xs.narrow(D::Minus1, 0, rot_dim)?,
+            xs.narrow(D::Minus1, rot_dim, rot_dim)?,
+        );
+        let xshaped = xs.reshape((seqlen, (), np, rot_dim / 2, 2))?;
+        let cache = cache.reshape((seqlen, (), 1, rot_dim / 2, 2))?;
+        let (xshaped0, xshaped1) = (
+            xshaped.i((.., .., .., .., 0))?,
+            xshaped.i((.., .., .., .., 1))?,
+        );
+        let (cache0, cache1) = (cache.i((.., .., .., .., 0))?, cache.i((.., .., .., .., 1))?);
+        let xs_out = Tensor::stack(
+            &[
+                (xshaped0.broadcast_mul(&cache0)? - xshaped1.broadcast_mul(&cache1)?)?,
+                (xshaped1.broadcast_mul(&cache0)? + xshaped0.broadcast_mul(&cache1)?)?,
+            ],
+            D::Minus1,
+        )?;
+        let xs_out = xs_out.flatten_from(3)?;
+        Tensor::cat(&[xs_out, xs_pass], D::Minus1)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct CoreAttention {
+    coeff: Option<f64>,
+    norm_factor: f64,
+    dtype: DType,
+}
+
+fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32, dtype: DType) -> Result<Tensor> {
+    let shape = mask.shape();
+    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
+    let m = mask.where_cond(&on_true.to_dtype(dtype)?, on_false)?;
+    Ok(m)
+}
+
+impl CoreAttention {
+    fn new(layer_number: usize, cfg: &Config, dtype: DType) -> Result<Self> {
+        let norm_factor = (cfg.kv_channels as f64).sqrt();
+        let (norm_factor, coeff) = if cfg.apply_query_key_layer_scaling {
+            let coeff = f64::max(1.0, layer_number as f64);
+            (norm_factor * coeff, Some(coeff))
+        } else {
+            (norm_factor, None)
+        };
+        Ok(Self {
+            coeff,
+            norm_factor,
+            dtype,
+        })
+    }
+
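+    // Scaled dot-product attention over [seq, batch, heads, head_dim] inputs:
+    // Q·K^T is scaled by 1/norm_factor (and by `coeff` when query-key layer
+    // scaling is enabled), masked, softmaxed, applied to V, and reshaped back
+    // to [seq, batch, hidden].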
+    fn forward(
+        &self,
+        query_layer: &Tensor,
+        key_layer: &Tensor,
+        value_layer: &Tensor,
+        attention_mask: &Option<Tensor>,
+    ) -> Result<Tensor> {
+        let output_size = (
+            query_layer.dim(1)?, // b
+            query_layer.dim(2)?, // np
+            query_layer.dim(0)?, // sq
+            key_layer.dim(0)?,   // sk
+        );
+        let query_layer =
+            query_layer.reshape((output_size.2, output_size.0 * output_size.1, ()))?;
+        let key_layer = key_layer.reshape((output_size.3, output_size.0 * output_size.1, ()))?;
+        let matmul_result = Tensor::matmul(
+            &query_layer.transpose(0, 1)?.contiguous()?,
+            &key_layer.transpose(0, 1)?.transpose(1, 2)?.contiguous()?,
+        )?;
+        let matmul_result = (matmul_result / self.norm_factor)?.reshape(output_size)?;
+        let matmul_result = match self.coeff {
+            None => matmul_result,
+            Some(coeff) => (matmul_result * coeff)?,
+        };
+        let attention_scores = match attention_mask {
+            Some(mask) => masked_fill(
+                &matmul_result,
+                &mask.broadcast_left((matmul_result.dim(0)?, matmul_result.dim(1)?))?,
+                f32::NEG_INFINITY,
+                self.dtype,
+            )?,
+            None => matmul_result,
+        };
+        let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
+
+        let output_size = (
+            value_layer.dim(1)?,
+            value_layer.dim(2)?,
+            query_layer.dim(0)?,
+            value_layer.dim(3)?,
+        );
+        let value_layer =
+            value_layer.reshape((value_layer.dim(0)?, output_size.0 * output_size.1, ()))?;
+        let attention_probs =
+            attention_probs.reshape((output_size.0 * output_size.1, output_size.2, ()))?;
+        let context_layer = Tensor::matmul(
+            &attention_probs.contiguous()?,
+            &value_layer.transpose(0, 1)?.contiguous()?,
+        )?;
+        let context_layer = context_layer.reshape(output_size)?;
+        let context_layer = context_layer.permute((2, 0, 1, 3))?.contiguous()?;
+        context_layer.flatten_from(D::Minus2)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct SelfAttention {
+    query_key_value: Linear,
+    core_attention: CoreAttention,
+    dense: Linear,
+    multi_query_attention: bool,
+    num_attention_heads_per_partition: usize,
+    num_multi_query_groups_per_partition: usize,
+    hidden_size_per_attention_head: usize,
+    kv_cache: Option<(Tensor, Tensor)>,
+}
+
+impl SelfAttention {
+    fn new(layer_number: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let projection_size = cfg.kv_channels * cfg.num_attention_heads;
+        let hidden_size_per_attention_head = projection_size / cfg.num_attention_heads;
+        let qkv_hidden_size = if cfg.multi_query_attention {
+            projection_size + 2 * hidden_size_per_attention_head * cfg.multi_query_group_num
+        } else {
+            3 * projection_size
+        };
+        let query_key_value = linear(
+            cfg.hidden_size,
+            qkv_hidden_size,
+            cfg.add_bias_linear || cfg.add_qkv_bias,
+            vb.pp("query_key_value"),
+        )?;
+        let core_attention = CoreAttention::new(layer_number, cfg, vb.dtype())?;
+        let dense = linear(
+            cfg.hidden_size,
+            cfg.hidden_size,
+            cfg.add_bias_linear,
+            vb.pp("dense"),
+        )?;
+        Ok(Self {
+            query_key_value,
+            core_attention,
+            dense,
+            multi_query_attention: cfg.multi_query_attention,
+            num_attention_heads_per_partition: cfg.num_attention_heads,
+            num_multi_query_groups_per_partition: cfg.multi_query_group_num,
+            hidden_size_per_attention_head: cfg.kv_channels,
+            kv_cache: None,
+        })
+    }
+
+    fn reset_kv_cache(&mut self) {
+        self.kv_cache = None
+    }
+
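+    // Multi-query attention: the fused QKV projection is split into per-head
+    // queries plus `multi_query_group_num` shared key/value groups, rotary
+    // embeddings are applied, the KV cache is extended, and K/V are repeated
+    // to match the number of attention heads before core attention runs.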
+    fn forward(
+        &mut self,
+        xs: &Tensor,
+        attention_mask: &Option<Tensor>,
+        rotary_emb: &RotaryEmbedding,
+    ) -> Result<Tensor> {
+        let mixed_x_layer = xs.apply(&self.query_key_value)?;
+        if !self.multi_query_attention {
+            candle::bail!("only multi_query_attention=true is supported")
+        }
+        let hpa = self.hidden_size_per_attention_head;
+        let query_layer =
+            mixed_x_layer.narrow(D::Minus1, 0, self.num_attention_heads_per_partition * hpa)?;
+        let key_layer = mixed_x_layer.narrow(
+            D::Minus1,
+            self.num_attention_heads_per_partition * hpa,
+            self.num_multi_query_groups_per_partition * hpa,
+        )?;
+        let value_layer = mixed_x_layer.narrow(
+            D::Minus1,
+            self.num_attention_heads_per_partition * hpa
+                + self.num_multi_query_groups_per_partition * hpa,
+            self.num_multi_query_groups_per_partition * hpa,
+        )?;
+        let query_layer = query_layer.reshape((
+            query_layer.dim(0)?,
+            query_layer.dim(1)?,
+            self.num_attention_heads_per_partition,
+            hpa,
+        ))?;
+        let key_layer = key_layer.reshape((
+            key_layer.dim(0)?,
+            key_layer.dim(1)?,
+            self.num_multi_query_groups_per_partition,
+            hpa,
+        ))?;
+        let value_layer = value_layer.reshape((
+            value_layer.dim(0)?,
+            value_layer.dim(1)?,
+            self.num_multi_query_groups_per_partition,
+            hpa,
+        ))?;
+
+        // Rotary embeddings.
+        let seqlen_offset = match &self.kv_cache {
+            None => 0,
+            Some((prev_k, _)) => prev_k.dim(0)?,
+        };
+        let query_layer = rotary_emb.apply(&query_layer, seqlen_offset)?;
+        let key_layer = rotary_emb.apply(&key_layer, seqlen_offset)?;
+
+        // KV cache.
+        let (key_layer, value_layer) = match &self.kv_cache {
+            None => (key_layer, value_layer),
+            Some((prev_k, prev_v)) => {
+                let k = Tensor::cat(&[prev_k, &key_layer], 0)?;
+                let v = Tensor::cat(&[prev_v, &value_layer], 0)?;
+                (k, v)
+            }
+        };
+        self.kv_cache = Some((key_layer.clone(), value_layer.clone()));
+
+        // Repeat KV.
+        let ratio =
+            self.num_attention_heads_per_partition / self.num_multi_query_groups_per_partition;
+        let key_layer = {
+            let (d0, d1, d2, d3) = key_layer.dims4()?;
+            key_layer
+                .unsqueeze(D::Minus2)?
+                .expand((d0, d1, d2, ratio, d3))?
+                .reshape((
+                    d0,
+                    d1,
+                    self.num_attention_heads_per_partition,
+                    self.hidden_size_per_attention_head,
+                ))?
+        };
+        let value_layer = {
+            let (d0, d1, d2, d3) = value_layer.dims4()?;
+            value_layer
+                .unsqueeze(D::Minus2)?
+                .expand((d0, d1, d2, ratio, d3))?
+                .reshape((
+                    d0,
+                    d1,
+                    self.num_attention_heads_per_partition,
+                    self.hidden_size_per_attention_head,
+                ))?
+        };
+
+        let context_layer =
+            self.core_attention
+                .forward(&query_layer, &key_layer, &value_layer, attention_mask)?;
+        let output = context_layer.apply(&self.dense)?;
+        Ok(output)
+    }
+}
+
+#[allow(clippy::upper_case_acronyms)]
+#[derive(Debug, Clone)]
+struct MLP {
+    dense_h_to_4h: Linear,
+    dense_4h_to_h: Linear,
+}
+
+impl MLP {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let dense_h_to_4h = linear(
+            cfg.hidden_size,
+            cfg.ffn_hidden_size * 2,
+            cfg.add_bias_linear,
+            vb.pp("dense_h_to_4h"),
+        )?;
+        let dense_4h_to_h = linear(
+            cfg.ffn_hidden_size,
+            cfg.hidden_size,
+            cfg.add_bias_linear,
+            vb.pp("dense_4h_to_h"),
+        )?;
+        Ok(Self {
+            dense_4h_to_h,
+            dense_h_to_4h,
+        })
+    }
+}
+
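+// The single `dense_h_to_4h` projection produces both halves consumed by the
+// SwiGLU activation, which are then projected back down by `dense_4h_to_h`.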
+impl Module for MLP {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        xs.apply(&self.dense_h_to_4h)?
+            .apply(&candle_nn::Activation::Swiglu)?
+            .apply(&self.dense_4h_to_h)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct Block {
+    input_layernorm: candle_nn::LayerNorm,
+    self_attention: SelfAttention,
+    post_attention_layernorm: candle_nn::LayerNorm,
+    mlp: MLP,
+    apply_residual_connection_post_layernorm: bool,
+}
+
+impl Block {
+    fn new(layer_number: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let input_layernorm = if cfg.rmsnorm {
+            candle_nn::rms_norm(
+                cfg.hidden_size,
+                cfg.layernorm_epsilon,
+                vb.pp("input_layernorm"),
+            )?
+            .into_inner()
+        } else {
+            candle_nn::layer_norm(
+                cfg.hidden_size,
+                cfg.layernorm_epsilon,
+                vb.pp("input_layernorm"),
+            )?
+        };
+        let post_attention_layernorm = if cfg.rmsnorm {
+            candle_nn::rms_norm(
+                cfg.hidden_size,
+                cfg.layernorm_epsilon,
+                vb.pp("post_attention_layernorm"),
+            )?
+            .into_inner()
+        } else {
+            candle_nn::layer_norm(
+                cfg.hidden_size,
+                cfg.layernorm_epsilon,
+                vb.pp("post_attention_layernorm"),
+            )?
+        };
+        let self_attention = SelfAttention::new(layer_number, cfg, vb.pp("self_attention"))?;
+        let mlp = MLP::new(cfg, vb.pp("mlp"))?;
+        Ok(Self {
+            input_layernorm,
+            self_attention,
+            post_attention_layernorm,
+            mlp,
+            apply_residual_connection_post_layernorm: cfg.apply_residual_connection_post_layernorm,
+        })
+    }
+
+    fn reset_kv_cache(&mut self) {
+        self.self_attention.reset_kv_cache()
+    }
+
+    fn forward(
+        &mut self,
+        xs: &Tensor,
+        attention_mask: &Option<Tensor>,
+        rotary_emb: &RotaryEmbedding,
+    ) -> Result<Tensor> {
+        let layernorm_output = xs.apply(&self.input_layernorm)?;
+        let attention_output =
+            self.self_attention
+                .forward(&layernorm_output, attention_mask, rotary_emb)?;
+        let residual = if self.apply_residual_connection_post_layernorm {
+            &layernorm_output
+        } else {
+            xs
+        };
+        let layernorm_input = (residual + attention_output)?;
+        let layernorm_output = layernorm_input.apply(&self.post_attention_layernorm)?;
+        let mlp_output = layernorm_output.apply(&self.mlp)?;
+        let residual = if self.apply_residual_connection_post_layernorm {
+            &layernorm_output
+        } else {
+            &layernorm_input
+        };
+        mlp_output + residual
+    }
+}
+
+#[derive(Debug, Clone)]
+struct Transformer {
+    layers: Vec<Block>,
+    final_layernorm: Option<candle_nn::LayerNorm>,
+    rotary_emb: RotaryEmbedding,
+}
+
+impl Transformer {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let vb_l = vb.pp("layers");
+        let mut layers = Vec::with_capacity(cfg.num_layers);
+        println!("transofrmer layers create");
+        let mut count = 0;
+        for layer_index in 0..cfg.num_layers {
+            count += 1;
+            println!("for layer index in {} total is {} ", count, cfg.num_layers);
+            let block = Block::new(layer_index + 1, cfg, vb_l.pp(layer_index))?;
+            layers.push(block)
+        }
+        let final_layernorm = if cfg.post_layer_norm {
+            let ln = if cfg.rmsnorm {
+                candle_nn::rms_norm(
+                    cfg.hidden_size,
+                    cfg.layernorm_epsilon,
+                    vb.pp("final_layernorm"),
+                )?
+                .into_inner()
+            } else {
+                candle_nn::layer_norm(
+                    cfg.hidden_size,
+                    cfg.layernorm_epsilon,
+                    vb.pp("final_layernorm"),
+                )?
+            };
+            Some(ln)
+        } else {
+            None
+        };
+        let rotary_emb = RotaryEmbedding::new(cfg, vb.dtype(), vb.device())?;
+        Ok(Self {
+            layers,
+            final_layernorm,
+            rotary_emb,
+        })
+    }
+
+    fn reset_kv_cache(&mut self) {
+        for block in self.layers.iter_mut() {
+            block.reset_kv_cache()
+        }
+    }
+
+    fn forward(&mut self, xs: &Tensor, attention_mask: &Option<Tensor>) -> Result<Tensor> {
+        let mut xs = xs.clone();
+        for block in self.layers.iter_mut() {
+            xs = block.forward(&xs, attention_mask, &self.rotary_emb)?
+        }
+        match self.final_layernorm.as_ref() {
+            None => Ok(xs),
+            Some(ln) => xs.apply(ln),
+        }
+    }
+}
+
+#[derive(Debug, Clone)]
+struct Embedding {
+    word_embeddings: candle_nn::Embedding,
+    fp32_residual_connection: bool,
+}
+
+impl Embedding {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let word_embeddings = candle_nn::embedding(
+            cfg.padded_vocab_size,
+            cfg.hidden_size,
+            vb.pp("word_embeddings"),
+        )?;
+        Ok(Self {
+            word_embeddings,
+            fp32_residual_connection: cfg.fp32_residual_connection,
+        })
+    }
+}
+
+impl Module for Embedding {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let xs = self.word_embeddings.forward(xs)?.transpose(0, 1)?; // b,s,h -> s,b,h
+        if self.fp32_residual_connection {
+            xs.to_dtype(candle::DType::F32)
+        } else {
+            xs.contiguous()
+        }
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct Model {
+    embedding: Embedding,
+    encoder: Transformer,
+    output_layer: Linear,
+}
+
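+// Causal mask: 1 marks key positions strictly after the query position,
+// which `masked_fill` replaces with -inf before the softmax.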
+fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
+    let mask: Vec<_> = (0..size)
+        .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
+        .collect();
+    Tensor::from_slice(&mask, (size, size), device)
+}
+
+impl Model {
+    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let vb = vb.pp("transformer");
+        let embedding = Embedding::new(cfg, vb.pp("embedding"))?;
+        let encoder = Transformer::new(cfg, vb.pp("encoder"))?;
+        let output_layer = linear(
+            cfg.hidden_size,
+            cfg.padded_vocab_size,
+            false,
+            vb.pp("output_layer"),
+        )?;
+
+        Ok(Self {
+            embedding,
+            encoder,
+            output_layer,
+        })
+    }
+
+    pub fn reset_kv_cache(&mut self) {
+        self.encoder.reset_kv_cache()
+    }
+
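+    // Embeds the [batch, seq] token ids, builds a causal mask when more than
+    // one token is fed (the prompt pass), runs the transformer stack, and
+    // returns logits for the last position only; incremental decoding relies
+    // on the per-layer KV caches.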
+    pub fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
+        let (_b_size, seq_len) = xs.dims2()?;
+        let input_embeds = xs.apply(&self.embedding)?;
+        let attention_mask = if seq_len <= 1 {
+            None
+        } else {
+            Some(get_mask(seq_len, xs.device())?)
+        };
+        let xs = self.encoder.forward(&input_embeds, &attention_mask)?;
+        let lm_logits = xs.i(seq_len - 1)?.apply(&self.output_layer)?;
+        Ok(lm_logits)
+    }
+}

+ 140 - 0
candle_demo/codegeex4/src/lib.rs

@@ -0,0 +1,140 @@
+pub mod codegeex4;
+
+pub mod args;
+
+use candle_core::{DType, Device, Tensor};
+use candle_transformers::generation::LogitsProcessor;
+use codegeex4::*;
+use owo_colors::{self, OwoColorize};
+use std::io::BufRead;
+use std::io::BufReader;
+use tokenizers::Tokenizer;
+
+pub struct TextGeneration {
+    model: Model,
+    device: Device,
+    tokenizer: Tokenizer,
+    logits_processor: LogitsProcessor,
+    repeat_penalty: f32,
+    repeat_last_n: usize,
+    verbose_prompt: bool,
+    dtype: DType,
+}
+
+impl TextGeneration {
+    #[allow(clippy::too_many_arguments)]
+    pub fn new(
+        model: Model,
+        tokenizer: Tokenizer,
+        seed: u64,
+        temp: Option<f64>,
+        top_p: Option<f64>,
+        repeat_penalty: f32,
+        repeat_last_n: usize,
+        verbose_prompt: bool,
+        device: &Device,
+        dtype: DType,
+    ) -> Self {
+        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
+        Self {
+            model,
+            tokenizer,
+            logits_processor,
+            repeat_penalty,
+            repeat_last_n,
+            verbose_prompt,
+            device: device.clone(),
+            dtype,
+        }
+    }
+
+    pub fn run(&mut self, sample_len: usize) -> Result<(), ()> {
+        use std::io::Write;
+
+        println!("[欢迎使用Codegeex4,请输入prompt]");
+        let stdin = std::io::stdin();
+        let reader = BufReader::new(stdin);
+        // Read prompts from stdin, one per line.
+        for line in reader.lines() {
+            let line = line.expect("Failed to read line");
+            let tokens = self.tokenizer.encode(line, true).expect("tokens error");
+            if tokens.is_empty() {
+                panic!("Empty prompts are not supported in the chatglm model.")
+            }
+            if self.verbose_prompt {
+                for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
+                    let token = token.replace('▁', " ").replace("<0x0A>", "\n");
+                    println!("{id:7} -> '{token}'");
+                }
+            }
+            let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
+                Some(token) => *token,
+                None => panic!("cannot find the endoftext token"),
+            };
+            let mut tokens = tokens.get_ids().to_vec();
+            let mut generated_tokens = 0usize;
+
+            std::io::stdout().flush().expect("output flush error");
+            let start_gen = std::time::Instant::now();
+
+            println!("sample_len {}", sample_len.blue());
+            let mut result = vec![];
+
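+            // Incremental decoding: after the first step only the newest token
+            // is fed to the model, since earlier context lives in the KV caches.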
+            for index in 0..sample_len {
+                let context_size = if index > 0 { 1 } else { tokens.len() };
+                let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
+                let input = Tensor::new(ctxt, &self.device)
+                    .unwrap()
+                    .unsqueeze(0)
+                    .expect("create tensor input error");
+                let logits = self.model.forward(&input).unwrap();
+                let logits = logits.squeeze(0).unwrap().to_dtype(self.dtype).unwrap();
+                let logits = if self.repeat_penalty == 1. {
+                    logits
+                } else {
+                    let start_at = tokens.len().saturating_sub(self.repeat_last_n);
+                    candle_transformers::utils::apply_repeat_penalty(
+                        &logits,
+                        self.repeat_penalty,
+                        &tokens[start_at..],
+                    )
+                    .unwrap()
+                };
+
+                let next_token = self.logits_processor.sample(&logits).unwrap();
+                tokens.push(next_token);
+                generated_tokens += 1;
+                if next_token == eos_token {
+                    break;
+                }
+                let token = self
+                    .tokenizer
+                    .decode(&[next_token], true)
+                    .expect("Token error");
+                if self.verbose_prompt {
+                    println!(
+                        "[Index: {}] [Raw Token: {}] [Decode Token: {}]",
+                        index.blue(),
+                        next_token.green(),
+                        token.yellow()
+                    );
+                }
+                result.push(token);
+                std::io::stdout().flush().unwrap();
+            }
+            let dt = start_gen.elapsed();
+            println!(
+                "\n{generated_tokens} tokens generated ({:.2} token/s)",
+                generated_tokens as f64 / dt.as_secs_f64(),
+            );
+            println!("Result:");
+            for tokens in result {
+                print!("{tokens}");
+            }
+            self.model.reset_kv_cache(); // Clear the model's KV cache between prompts.
+        }
+
+        Ok(())
+    }
+}

BIN
metric/.DS_Store


BIN
metric/pics/.DS_Store


BIN
metric/pics/Bigcodebench.png


+ 20 - 29
repodemo/chainlit.md

@@ -1,50 +1,41 @@
-# CodeGeeX
-
-# Welcome to My Chat Demo Application
-
-This is a simple demonstration application.
-
-## Instructions
-
-1. Enter your question.
-2. Wait for a response.
-3. Enjoy the conversation!
+![](../resources/logo.jpeg)
+## Welcome to Chat Demo Application
+![](https://github.com/user-attachments/assets/f2cb6c13-a715-4adf-bf3a-b9ca5ee165df)
+This is a simple demo application designed to showcase multi-turn conversations and project Q&A functionalities.
 
 ## Features
 
-- Supports multi-turn conversations.
-- Supports online Q&A.
-- Supports uploading local zip packages for project Q&A and modifications.
-- Supports inputting GitHub project links for project Q&A and modifications.
-
+- Supports multi-turn conversations
+- Supports online Q&A
+- Supports uploading local zip files for project Q&A and modifications
+- Supports inputting GitHub project links for Q&A and modifications
+![](https://github.com/user-attachments/assets/ff6f6e32-457c-4733-815b-b639e4197899)
 ## Installation
 
 1. Clone the repository locally.
-2. Start the model. You can deploy the model using vllm or ollama, provide the OpenAI request format, and set the deployed `api_base` and `api_key`. Alternatively, visit [CodeGeeX API](https://open.bigmodel.cn/dev/api#codegeex-4) to get the API key.
+2. Start the model. You can deploy it via vLLM or Ollama with an OpenAI-compatible request format and set the deployed `api_base` and `api_key`, or get an API key from the [CodeGeeX API](https://open.bigmodel.cn/dev/api#codegeex-4). Fill in the corresponding information in the `.env` file.
+![](https://github.com/user-attachments/assets/6aabc3e4-a930-4853-b511-68b9389fa42f)
 
 ```shell
-#use open.bigmodel.cn api
-openai_api_key = "<|apikey|>"
+# Using open.bigmodel.cn API
+openai_api_key = ""
 openai_api_base = "https://open.bigmodel.cn/api/paas/v4/"
 model_name = "codegeex-4"
-#use vllm
+# Using vllm
 openai_api_key = "EMPTY"
 openai_api_base = "http://xxxx:xxxx/v1"
 model_name = "codegeex4-all-9b"
 ```
 
-3. Fill in the corresponding model information and `bing_search_api` (if you want to experience online search) in the `.env` file.
-4. Install dependencies: `pip install -r requirements.txt`.
-5. Run the application: `chainlit run run.py --port 8899`.
+3. Fill in the corresponding model information and `bing_search_api` (if you want to try online queries) in the `.env` file. During a chat, turn on the online-query switch on the left side of the input box; it is off by default.
+![](https://github.com/user-attachments/assets/e9d9b620-cfc7-4c2d-bedc-a01d41f79e29)
+4. Install dependencies: `pip install -r requirements.txt`
+5. Run the application: `chainlit run run.py --port 8899`
 
-## Note
+## Notes
 
 Please ensure your network environment can access the CodeGeeX API.
 
-## Disclaimer
-
-This application is for educational and research purposes only and should not be used for any commercial purposes. The developer is not responsible for any loss or damage caused by the use of this application.
-
-## Acknowledgements
+## Acknowledgments
 
 Thank you for using our application. If you have any questions or suggestions, please feel free to contact us. We look forward to your feedback and are committed to providing you with better service.

+ 20 - 26
repodemo/chainlit_zh-CN.md

@@ -1,51 +1,45 @@
-# CodeGeeX
+![](../resources/logo.jpeg)
+## 欢迎使用Chat Demo应用
+![](https://github.com/user-attachments/assets/f2cb6c13-a715-4adf-bf3a-b9ca5ee165df)
+这是一个简单的演示应用程序,用于展示多轮对话和项目问答功能。
 
-# 欢迎使用我的chat demo应用
-
-这是一个简单的演示应用程序。
-
-## 使用说明
-
-1. 输入您的问题
-2. 等待回复
-3. 享受对话!
 
 ## 功能
 
--  支持多轮对话
--  支持联网问答
--  支持上传本地zip压缩包项目,可以进行项目问答和对项目进行修改
--  支持输入GitHub链接项目,可以进行项目问答和对项目进行修改。
+- 支持多轮对话
+- 支持联网问答
+- 支持上传本地zip压缩包项目进行问答和修改
+- 支持输入GitHub链接项目进行问答和修改
+![](https://github.com/user-attachments/assets/ff6f6e32-457c-4733-815b-b639e4197899)
 
 ## 安装
 
 1. 克隆仓库到本地
-2. 启动模型,可以通过vllm或者ollama部署模型,提供openai的请求格式,设置部署的api_base和api_key,或者访问[CodeGeeX API](https://open.bigmodel.cn/dev/api#codegeex-4)获取apikey.
+2. 启动模型,可以通过vllm或者ollama部署模型,提供openai的请求格式,设置部署的api_base和api_key,或者访问[CodeGeeX API](https://open.bigmodel.cn/dev/api#codegeex-4)获取apikey。在.env文件中填写对应的信息
+![](https://github.com/user-attachments/assets/6aabc3e4-a930-4853-b511-68b9389fa42f)
 
 ```shell
-#use open.bigmodel.cn api
-openai_api_key = "<|apikey|>"
+# 使用open.bigmodel.cn API
+openai_api_key = ""
 openai_api_base = "https://open.bigmodel.cn/api/paas/v4/"
 model_name = "codegeex-4"
-#use vllm
+# 使用vllm
 openai_api_key = "EMPTY"
 openai_api_base = "http://xxxx:xxxx/v1"
 model_name = "codegeex4-all-9b"
 ```
 
-3. 到.env文件里填写对应模型信息和bing_search_api(如果需要体验联网查询)
+3. 在.env文件中填写对应模型信息和bing_search_api(如果需要体验联网查询),并且在聊天的时候在输入框左侧打开
+联网查询开关,默认关闭。
+![](https://github.com/user-attachments/assets/e9d9b620-cfc7-4c2d-bedc-a01d41f79e29)
 4. 安装依赖:`pip install -r requirements.txt`
-5. 运行应用:`chainlit run run.py --port 8899` 
+5. 运行应用:`chainlit run run.py --port 8899`
 
-
-## 注意
+## 注意事项
 
 请确保您的网络环境可以访问CodeGeeX的API。
 
-## 免责声明
-
-本应用仅供学习和研究使用,不得用于任何商业用途。开发者不对因使用本应用而导致的任何损失或损害负责。
 
 ## 感谢
 
-感谢您使用我们的应用。如果您有任何问题或建议,请随时联系我们。我们期待您的反馈,并致力于为您提供更好的服务。
+感谢您使用我们的应用。如果您有任何问题或建议,请随时联系我们。我们期待您的反馈,并致力于为您提供更好的服务。

+ 21 - 29
repodemo/readme.md

@@ -1,50 +1,42 @@
-# CodeGeeX
-
-# Welcome to My Chat Demo Application
-
-This is a simple demonstration application.
-
-## Instructions
-
-1. Enter your question.
-2. Wait for a response.
-3. Enjoy the conversation!
+![](../resources/logo.jpeg)
+[English](./readme.md) | [中文](./readme_zh.md)
+## Welcome to Chat Demo Application
+![](https://github.com/user-attachments/assets/f2cb6c13-a715-4adf-bf3a-b9ca5ee165df)
+This is a simple demo application designed to showcase multi-turn conversations and project Q&A functionalities.
 
 ## Features
 
-- Supports multi-turn conversations.
-- Supports online Q&A.
-- Supports uploading local zip packages for project Q&A and modifications.
-- Supports inputting GitHub project links for project Q&A and modifications.
-
+- Supports multi-turn conversations
+- Supports online Q&A
+- Supports uploading local zip files for project Q&A and modifications
+- Supports inputting GitHub project links for Q&A and modifications
+![](https://github.com/user-attachments/assets/ff6f6e32-457c-4733-815b-b639e4197899)
 ## Installation
 
 1. Clone the repository locally.
-2. Start the model. You can deploy the model using vllm or ollama, provide the OpenAI request format, and set the deployed `api_base` and `api_key`. Alternatively, visit [CodeGeeX API](https://open.bigmodel.cn/dev/api#codegeex-4) to get the API key.
+2. Start the model. You can deploy it via vLLM or Ollama with an OpenAI-compatible request format and set the deployed `api_base` and `api_key`, or get an API key from the [CodeGeeX API](https://open.bigmodel.cn/dev/api#codegeex-4). Fill in the corresponding information in the `.env` file.
+![](https://github.com/user-attachments/assets/6aabc3e4-a930-4853-b511-68b9389fa42f)
 
 ```shell
-#use open.bigmodel.cn api
-openai_api_key = "<|apikey|>"
+# Using open.bigmodel.cn API
+openai_api_key = ""
 openai_api_base = "https://open.bigmodel.cn/api/paas/v4/"
 model_name = "codegeex-4"
-#use vllm
+# Using vllm
 openai_api_key = "EMPTY"
 openai_api_base = "http://xxxx:xxxx/v1"
 model_name = "codegeex4-all-9b"
 ```
 
-3. Fill in the corresponding model information and `bing_search_api` (if you want to experience online search) in the `.env` file.
-4. Install dependencies: `pip install -r requirements.txt`.
-5. Run the application: `chainlit run run.py --port 8899`.
+3. Fill in the corresponding model information and `bing_search_api` (if you want to try online queries) in the `.env` file. During a chat, turn on the online-query switch on the left side of the input box; it is off by default.
+![](https://github.com/user-attachments/assets/e9d9b620-cfc7-4c2d-bedc-a01d41f79e29)
+4. Install dependencies: `pip install -r requirements.txt`
+5. Run the application: `chainlit run run.py --port 8899`
 
-## Note
+## Notes
 
 Please ensure your network environment can access the CodeGeeX API.
 
-## Disclaimer
-
-This application is for educational and research purposes only and should not be used for any commercial purposes. The developer is not responsible for any loss or damage caused by the use of this application.
-
-## Acknowledgements
+## Acknowledgments
 
 Thank you for using our application. If you have any questions or suggestions, please feel free to contact us. We look forward to your feedback and are committed to providing you with better service.

+ 45 - 0
repodemo/readme_zh.md

@@ -0,0 +1,45 @@
+![](../resources/logo.jpeg)
+[English](./readme.md) | [中文](./readme_zh.md)
+## 欢迎使用Chat Demo应用
+![](https://github.com/user-attachments/assets/f2cb6c13-a715-4adf-bf3a-b9ca5ee165df)
+这是一个简单的演示应用程序,用于展示多轮对话和项目问答功能。
+
+
+## 功能
+
+- 支持多轮对话
+- 支持联网问答
+- 支持上传本地zip压缩包项目进行问答和修改
+- 支持输入GitHub链接项目进行问答和修改
+![](https://github.com/user-attachments/assets/ff6f6e32-457c-4733-815b-b639e4197899)
+## 安装
+
+1. 克隆仓库到本地
+2. 启动模型,可以通过vllm或者ollama部署模型,提供openai的请求格式,设置部署的api_base和api_key,或者访问[CodeGeeX API](https://open.bigmodel.cn/dev/api#codegeex-4)获取apikey。在.env文件中填写对应的信息
+![](https://github.com/user-attachments/assets/6aabc3e4-a930-4853-b511-68b9389fa42f)
+
+```shell
+# 使用open.bigmodel.cn API
+openai_api_key = ""
+openai_api_base = "https://open.bigmodel.cn/api/paas/v4/"
+model_name = "codegeex-4"
+# 使用vllm
+openai_api_key = "EMPTY"
+openai_api_base = "http://xxxx:xxxx/v1"
+model_name = "codegeex4-all-9b"
+```
+
+3. 在.env文件中填写对应模型信息和bing_search_api(如果需要体验联网查询),并且在聊天的时候在输入框左侧打开
+联网查询开关,默认关闭。
+![](https://github.com/user-attachments/assets/e9d9b620-cfc7-4c2d-bedc-a01d41f79e29)
+4. 安装依赖:`pip install -r requirements.txt`
+5. 运行应用:`chainlit run run.py --port 8899`
+
+## 注意事项
+
+请确保您的网络环境可以访问CodeGeeX的API。
+
+
+## 感谢
+
+感谢您使用我们的应用。如果您有任何问题或建议,请随时联系我们。我们期待您的反馈,并致力于为您提供更好的服务。

+ 2 - 1
repodemo/requirements.txt

@@ -1,4 +1,5 @@
 chainlit==1.1.305
 beautifulsoup4
 python-dotenv
-gitpython
+gitpython
+openai==1.35.4

+ 3 - 3
repodemo/run.py

@@ -49,8 +49,8 @@ def tools_choose_agent(input_text):
 async def chat_profile():
     return [
         cl.ChatProfile(
-            name="联网聊天",
-            markdown_description="聊天demo:支持多轮对话。支持联网回答用户问题。默认联网,如不联网在输入框左边关闭联网功能。",
+            name="chat聊天",
+            markdown_description="聊天demo:支持多轮对话。支持联网回答用户问题(需要在输入框左边打开联网开关)。默认联网,如不联网在输入框左边关闭联网功能。",
             starters=[
                 cl.Starter(
                     label="请你用python写一个快速排序。",
@@ -107,7 +107,7 @@ async def start():
             Switch(
                 id="is_online",
                 label="CodeGeeX4 - is_online",
-                initial=True
+                initial=False
             ),
         ]
     ).send()

BIN
resources/candle_example.png