donjuanplatinum 1 year ago
Parent
Commit
9afa8d4428

+ 11 - 22
candle_demo/Cargo.toml

@@ -1,5 +1,7 @@
-[package]
-name = "codegeex4-candle"
+[workspace]
+members = ["gui", "cli", "codegeex4", "api-server"]
+resolver = "2"
+[workspace.package]
 version = "0.1.0"
 edition = "2021"
 authors = ["Donjuan Platinum <[email protected]>"]
@@ -7,35 +9,22 @@ license = "GPL-2.0-only"
 description = "Codegeex4"
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

-[dependencies]
-# candle-transformers = {path = "../candle/candle-transformers"}
-# candle-core = {path = "../candle/candle-core"}
-# candle-nn = {path = "../candle/candle-nn"}
-#anyhow = "1.0.86"
+[workspace.dependencies]
 hf-hub = "0.3.2"
-#tokenizer = "0.1.2"
 clap = { version = "4.5.6", features = ["derive"] }
-#tracing-chrome = "0.7.2"
-#candle-examples = {path = "../candle/candle-examples"}
-#tracing-subscriber = "0.3.18"
 tokenizers = "0.19.1"
 serde_json = "1.0.120"
 candle-core = "0.6.0"
-candle-transformers = "0.6.0"
+# candle-transformers = "0.6.0"
+candle-transformers = "0.6.0"
 candle-examples = "0.6.0"
 candle-nn = "0.6.0"
 safetensors = "0.4.3"
-accelerate-src = { version = "0.3.2", optional = true}
-intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] ,optional = true}
+accelerate-src = { version = "0.3.2"}
+intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
 rand = "0.8.5"
 owo-colors = "4.0.0"
-#safetensors = {path ="../safetensors/safetensors"}
-[build-dependencies]
-bindgen_cuda = { version = "0.1.1", optional = true }
+codegeex4 = {path = "./codegeex4"}
+

-[features]
-default = []
-cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda", "dep:bindgen_cuda"]
-accelerate = ["dep:accelerate-src", "candle-core/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
-mkl = ["dep:intel-mkl-src", "candle-core/mkl", "candle-nn/mkl", "candle-transformers/mkl"]

+ 4 - 4
candle_demo/README.org

@@ -6,12 +6,12 @@ THUDM/CodeGeeX4 is a versatile model for all AI software development scenarios,
 - [[https://codegeex.cn/][HomePage]]
 - [[https://huggingface.co/THUDM/codegeex4-all-9b][huggingface]]
 - [[https://github.com/huggingface/candle/blob/main/candle-examples/examples/codegeex4-9b/README.org][Candle]]
-** Running with ~cuda~
-
+  
+** CLI
 #+begin_src shell
-  cargo run --example codegeex4-9b --release --features cuda   --  --sample-len 300
+  cargo build --release -p codegeex4-cli
+  cargo build --release -p codegeex4-cli --features cuda # if CUDA is available
 #+end_src
-
 ** Running with ~cpu~
 #+begin_src shell
   cargo run --example codegeex4-9b --release --cpu   --  --sample-len 300

+ 37 - 0
candle_demo/api-server/Cargo.toml

@@ -0,0 +1,37 @@
+[package]
+name = "api-server"
+version.workspace = true
+edition.workspace = true
+authors.workspace = true
+license.workspace = true
+description.workspace = true
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+hf-hub = {workspace = true}
+clap = { workspace = true}
+tokenizers = {workspace = true}
+serde_json.workspace = true
+candle-core = {workspace = true}
+candle-transformers = {workspace = true}
+candle-examples = {workspace = true}
+candle-nn = {workspace = true}
+safetensors = {workspace = true}
+accelerate-src = { workspace = true, optional = true}
+intel-mkl-src = { workspace = true, optional = true }
+rand = { workspace = true}
+owo-colors = {workspace = true}
+codegeex4 = {workspace = true}
+tokio = {version = "1.39.1", features = ["full"]}
+actix-web = "4.8.0"
+serde = { version = "1.0.204", features = ["derive"] }
+shortuuid = "0.0.1"
+short-uuid = "0.1.2"
+[build-dependencies]
+bindgen_cuda = { version = "0.1.1", optional = true }
+[features]
+default = []
+cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda", "dep:bindgen_cuda"]
+accelerate = ["dep:accelerate-src", "candle-core/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
+mkl = ["dep:intel-mkl-src", "candle-core/mkl", "candle-nn/mkl", "candle-transformers/mkl"]

+ 123 - 0
candle_demo/api-server/src/api.rs

@@ -0,0 +1,123 @@
+use actix_web::{
+    get, post,
+    web::{self, Data},
+    HttpRequest, Responder,
+};
+use owo_colors::OwoColorize;
+use serde::{Deserialize, Serialize};
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct ChatMessage {
+    pub role: String,
+    pub content: String,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct ChatCompletionRequest {
+    pub model: String,
+    pub messages: Vec<ChatMessage>,
+    pub temperature: f64,
+    pub top_p: f64,
+    pub max_tokens: usize,
+    pub stop: Vec<String>,
+    pub stream: bool,
+    pub presence_penalty: Option<f32>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct DeltaMessage {
+    pub role: String,
+    pub content: String,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct ChatCompletionResponseStreamChoice {
+    pub index: i32,
+    pub delta: DeltaMessage,
+    pub finish_reason: Option<String>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct ChatCompletionStreamResponse {
+    pub id: String,
+    pub object: String,
+    pub created: i32,
+    pub model: String,
+    pub choices: Vec<ChatCompletionResponseStreamChoice>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct ChatCompletionResponseChoice {
+    pub index: i32,
+    pub message: ChatMessage,
+    pub finish_reason: Option<FinishReason>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct ChatCompletionResponse {
+    pub id: String,
+    pub object: String,
+    pub created: u64,
+    pub model: String,
+    pub choices: Vec<ChatCompletionResponseChoice>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub enum FinishReason {
+    STOP,
+    LENGTH,
+}
+use std::time::{SystemTime, UNIX_EPOCH};
+impl ChatCompletionResponse {
+    pub fn empty() -> Self {
+        let current_time = SystemTime::now();
+        Self {
+            id: format!("chatcmpl-{}", short_uuid::ShortUuid::generate()),
+            object: "chat.completion".to_string(),
+            created: current_time
+                .duration_since(UNIX_EPOCH)
+                .expect("failed to get time")
+                .as_secs()
+                .into(),
+            model: "codegeex4".to_string(),
+            choices: vec![ChatCompletionResponseChoice::empty()],
+        }
+    }
+}
+
+impl ChatCompletionResponseChoice {
+    pub fn empty() -> Self {
+        Self {
+            index: 0,
+            message: ChatMessage {
+                role: "assistant".to_string(),
+                content: "".to_string(),
+            },
+            finish_reason: None,
+        }
+    }
+}
+
+impl ChatCompletionRequest {
+    pub fn empty() -> Self {
+        Self {
+            model: "codegeex4".to_string(),
+            messages: vec![ChatMessage {
+                role: "assistant".to_string(),
+                content: "".to_string(),
+            }],
+            temperature: 0.2_f64,
+            top_p: 0.2_f64,
+            max_tokens: 1024_usize,
+            stop: vec![
+                "<|user|>".to_string(),
+                "<|assistant|>".to_string(),
+                "<|observation|>".to_string(),
+                "<|endoftext|>".to_string(),
+            ],
+            stream: true,
+            presence_penalty: None,
+        }
+    }
+}
+
+// impl DeltaMessage {
+//     pub fn new() -> Self {
+// 	role:
+//     }
+// }

+ 10 - 0
candle_demo/api-server/src/args.rs

@@ -0,0 +1,10 @@
+use clap::Parser;
+
+#[derive(Parser, Debug, Clone)]
+#[clap(version, about)]
+pub struct Args {
+    #[arg(name = "listen", short, long, default_value = "0.0.0.0:3000")]
+    pub address: String,
+    #[arg(short, long, default_value_t = 1)]
+    pub workers: usize,
+}

+ 19 - 0
candle_demo/api-server/src/main.rs

@@ -0,0 +1,19 @@
+mod api;
+mod args;
+mod server;
+mod model;
+use clap::Parser;
+use owo_colors::OwoColorize;
+
+#[tokio::main]
+async fn main() {
+    let args = args::Args::parse();
+    let server = server::Server::new(args.clone());
+    println!(
+        "{} Server Binding On {} with {} workers",
+        "[INFO]".green(),
+        &args.address.purple(),
+        &args.workers.purple()
+    );
+    server.run().await;
+}

+ 6 - 0
candle_demo/api-server/src/model.rs

@@ -0,0 +1,6 @@
+use crate::api::ChatCompletionRequest;
+use codegeex4::codegeex4::Config;
+
+// Stub: will stream a chat completion from the model for the given request.
+fn stream_chat(_request: ChatCompletionRequest) {
+    let _default_config = Config::codegeex4();
+}

+ 31 - 0
candle_demo/api-server/src/server.rs

@@ -0,0 +1,31 @@
+use crate::args::Args;
+use actix_web::{web, App, HttpResponse, HttpServer};
+use owo_colors::OwoColorize;
+
+#[derive(Debug)]
+pub struct Server {
+    config: Args,
+}
+
+impl Server {
+    pub fn new(config: Args) -> Self {
+        Server { config }
+    }
+    pub async fn run(&self) {
+        HttpServer::new(move || App::new())
+            .bind(&self.config.address)
+            .expect(&format!("{}", "Unable To Bind Server !".red()))
+            .workers(self.config.workers)
+            .run()
+            .await
+            .expect(&format!("{}", "Unable To Run the Server !".red()));
+    }
+}
+
+// use super::api::*;
+// use uuid;
+// pub async fn chat(request: ChatCompletionRequest) ->impl Responder {
+//     if request.stream == true {
+// 	return Htt
+//     }
+// }
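
The commented-out `chat` stub above hints at the planned endpoint. A sketch of how it could look with actix-web, assuming an OpenAI-style route (the path, handler body, and registration are assumptions, not part of the commit):

use actix_web::{post, web, HttpResponse, Responder};

use crate::api::{ChatCompletionRequest, ChatCompletionResponse};

// Hypothetical handler: returns an empty completion for now; a streaming
// reply (`request.stream == true`) would need a chunked/SSE body instead.
#[post("/v1/chat/completions")]
async fn chat(request: web::Json<ChatCompletionRequest>) -> impl Responder {
    let _request = request.into_inner();
    HttpResponse::Ok().json(ChatCompletionResponse::empty())
}

It would then be registered inside `Server::run`, e.g. `HttpServer::new(move || App::new().service(chat))`.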

+ 33 - 0
candle_demo/cli/Cargo.toml

@@ -0,0 +1,33 @@
+[package]
+name = "codegeex4-cli"
+version.workspace = true
+edition.workspace = true
+authors.workspace = true
+license.workspace = true
+description.workspace = true
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+hf-hub = {workspace = true}
+clap = { workspace = true}
+tokenizers = {workspace = true}
+serde_json = {workspace = true}
+candle-core = {workspace = true}
+candle-transformers = {workspace = true}
+candle-examples = {workspace = true}
+candle-nn = {workspace = true}
+safetensors = {workspace = true}
+accelerate-src = { workspace = true, optional = true}
+intel-mkl-src = { workspace = true, optional = true }
+rand = { workspace = true}
+owo-colors = {workspace = true}
+codegeex4 = {workspace = true}
+
+[build-dependencies]
+bindgen_cuda = { version = "0.1.1", optional = true }
+[features]
+default = []
+cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda", "dep:bindgen_cuda"]
+accelerate = ["dep:accelerate-src", "candle-core/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
+mkl = ["dep:intel-mkl-src", "candle-core/mkl", "candle-nn/mkl", "candle-transformers/mkl"]

+ 99 - 0
candle_demo/cli/src/main.rs

@@ -0,0 +1,99 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use candle_core as candle;
+use candle_core::DType;
+use candle_nn::VarBuilder;
+use clap::Parser;
+use codegeex4::args::Args;
+use codegeex4::codegeex4::*;
+use codegeex4::TextGeneration;
+use hf_hub::{Repo, RepoType};
+use owo_colors::{self, OwoColorize};
+use rand::Rng;
+use tokenizers::Tokenizer;
+
+fn main() -> Result<(), ()> {
+    let args = Args::parse();
+    println!(
+        "avx: {}, neon: {}, simd128: {}, f16c: {}",
+        candle::utils::with_avx().red(),
+        candle::utils::with_neon().red(),
+        candle::utils::with_simd128().red(),
+        candle::utils::with_f16c().red(),
+    );
+    println!(
+        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
+        args.temperature.unwrap_or(0.95).red(),
+        args.repeat_penalty.red(),
+        args.repeat_last_n.red(),
+    );
+
+    println!("cache path {}", args.cache_path.blue());
+    // Use the provided seed, or fall back to a random one.
+    let seed: u64 = match args.seed {
+        Some(seed) => seed,
+        None => rand::thread_rng().gen(),
+    };
+    println!("Using Seed {}", seed.red());
+    let api = hf_hub::api::sync::ApiBuilder::from_cache(hf_hub::Cache::new(args.cache_path.into()))
+        .build()
+        .unwrap();
+
+    let model_id = match args.model_id {
+        Some(model_id) => model_id.to_string(),
+        None => "THUDM/codegeex4-all-9b".to_string(),
+    };
+    let revision = match args.revision {
+        Some(rev) => rev.to_string(),
+        None => "main".to_string(),
+    };
+    let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
+    let tokenizer_filename = match args.tokenizer {
+        Some(file) => std::path::PathBuf::from(file),
+        None => api
+            .model("THUDM/codegeex4-all-9b".to_string())
+            .get("tokenizer.json")
+            .unwrap(),
+    };
+    let filenames = match args.weight_file {
+        Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
+        None => {
+            candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json").unwrap()
+        }
+    };
+    let tokenizer = Tokenizer::from_file(tokenizer_filename).expect("Tokenizer Error");
+    let start = std::time::Instant::now();
+    let config = Config::codegeex4();
+    let device = candle_examples::device(args.cpu).unwrap();
+    let dtype = if device.is_cuda() {
+        DType::BF16
+    } else {
+        DType::F32
+    };
+    println!("DType is {:?}", dtype.yellow());
+    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device).unwrap() };
+    let model = Model::new(&config, vb).unwrap();
+
+    println!("模型加载完毕 {:?}", start.elapsed().as_secs().green());
+
+    let mut pipeline = TextGeneration::new(
+        model,
+        tokenizer,
+        seed,
+        args.temperature,
+        args.top_p,
+        args.repeat_penalty,
+        args.repeat_last_n,
+        args.verbose_prompt,
+        &device,
+        dtype,
+    );
+    pipeline.run(args.sample_len)?;
+    Ok(())
+}

+ 34 - 0
candle_demo/codegeex4/Cargo.toml

@@ -0,0 +1,34 @@
+[package]
+name = "codegeex4"
+version.workspace = true
+edition.workspace = true
+authors.workspace = true
+license.workspace = true
+description.workspace = true
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+hf-hub = {workspace = true}
+clap = { workspace = true}
+tokenizers = {workspace = true}
+serde_json = {workspace = true}
+candle-core = {workspace = true}
+candle-transformers = {workspace = true}
+candle-examples = {workspace = true}
+candle-nn = {workspace = true}
+safetensors = {workspace = true}
+accelerate-src = { workspace = true, optional = true}
+intel-mkl-src = { workspace = true, optional = true }
+rand = { workspace = true}
+owo-colors = {workspace = true}
+
+[build-dependencies]
+bindgen_cuda = { version = "0.1.1", optional = true }
+[features]
+default = []
+cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda", "dep:bindgen_cuda"]
+accelerate = ["dep:accelerate-src", "candle-core/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
+mkl = ["dep:intel-mkl-src", "candle-core/mkl", "candle-nn/mkl", "candle-transformers/mkl"]

+ 51 - 0
candle_demo/codegeex4/src/args.rs

@@ -0,0 +1,51 @@
+use clap::Parser;
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+pub struct Args {
+    /// Path of the hf-hub cache directory.
+    #[arg(name = "cache", short, long, default_value = ".")]
+    pub cache_path: String,
+
+    /// Run on CPU rather than on GPU.
+    #[arg(long)]
+    pub cpu: bool,
+
+    /// Display the token for the specified prompt.
+    #[arg(long)]
+    pub verbose_prompt: bool,
+
+    /// The temperature used to generate samples.
+    #[arg(long)]
+    pub temperature: Option<f64>,
+
+    /// Nucleus sampling probability cutoff.
+    #[arg(long)]
+    pub top_p: Option<f64>,
+
+    /// The seed to use when generating random samples.
+    #[arg(long)]
+    pub seed: Option<u64>,
+
+    /// The length of the sample to generate (in tokens).
+    #[arg(long, short = 'n', default_value_t = 5000)]
+    pub sample_len: usize,
+
+    #[arg(long)]
+    pub model_id: Option<String>,
+
+    #[arg(long)]
+    pub revision: Option<String>,
+
+    #[arg(long)]
+    pub weight_file: Option<String>,
+
+    #[arg(long)]
+    pub tokenizer: Option<String>,
+
+    /// Penalty to be applied for repeating tokens, 1. means no penalty.
+    #[arg(long, default_value_t = 1.1)]
+    pub repeat_penalty: f32,
+
+    /// The context size to consider for the repeat penalty.
+    #[arg(long, default_value_t = 64)]
+    pub repeat_last_n: usize,
+}
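
As a quick sanity check on the defaults, `Args` can be parsed from an explicit argument vector (a usage sketch; `parse_from` is standard clap):

use clap::Parser;
use codegeex4::args::Args;

fn main() {
    // No flags given: cache ".", 5000-token samples,
    // repeat penalty 1.1 over the last 64 tokens.
    let args = Args::parse_from(["codegeex4-cli"]);
    assert_eq!(args.sample_len, 5000);
    assert_eq!(args.repeat_penalty, 1.1);
    assert_eq!(args.repeat_last_n, 64);
}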

+ 13 - 6
candle_demo/src/codegeex4.rs → candle_demo/codegeex4/src/codegeex4.rs

@@ -110,7 +110,7 @@ struct CoreAttention {
     dtype: DType,
 }

-fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32,dtype:DType) -> Result<Tensor> {
+fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32, dtype: DType) -> Result<Tensor> {
     let shape = mask.shape();
     let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
     let m = mask.where_cond(&on_true.to_dtype(dtype)?, on_false)?;
@@ -118,7 +118,7 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32,dtype:DType) -> Re
 }

 impl CoreAttention {
-    fn new(layer_number: usize, cfg: &Config,dtype: DType) -> Result<Self> {
+    fn new(layer_number: usize, cfg: &Config, dtype: DType) -> Result<Self> {
         let norm_factor = (cfg.kv_channels as f64).sqrt();
         let (norm_factor, coeff) = if cfg.apply_query_key_layer_scaling {
             let coeff = f64::max(1.0, layer_number as f64);
@@ -126,7 +126,11 @@ impl CoreAttention {
         } else {
             (norm_factor, None)
         };
-        Ok(Self { coeff, norm_factor, dtype})
+        Ok(Self {
+            coeff,
+            norm_factor,
+            dtype,
+        })
     }

     fn forward(
@@ -159,7 +163,7 @@ impl CoreAttention {
                 &matmul_result,
                 &mask.broadcast_left((matmul_result.dim(0)?, matmul_result.dim(1)?))?,
                 f32::NEG_INFINITY,
-		self.dtype,
+                self.dtype,
             )?,
             None => matmul_result,
         };
@@ -175,7 +179,10 @@ impl CoreAttention {
             value_layer.reshape((value_layer.dim(0)?, output_size.0 * output_size.1, ()))?;
         let attention_probs =
             attention_probs.reshape((output_size.0 * output_size.1, output_size.2, ()))?;
-        let context_layer = Tensor::matmul(&attention_probs.contiguous()?, &value_layer.transpose(0, 1)?.contiguous()?)?;
+        let context_layer = Tensor::matmul(
+            &attention_probs.contiguous()?,
+            &value_layer.transpose(0, 1)?.contiguous()?,
+        )?;
         let context_layer = context_layer.reshape(output_size)?;
         let context_layer = context_layer.permute((2, 0, 1, 3))?.contiguous()?;
         context_layer.flatten_from(D::Minus2)
@@ -209,7 +216,7 @@ impl SelfAttention {
             cfg.add_bias_linear || cfg.add_qkv_bias,
             vb.pp("query_key_value"),
         )?;
-        let core_attention = CoreAttention::new(layer_number, cfg,vb.dtype())?;
+        let core_attention = CoreAttention::new(layer_number, cfg, vb.dtype())?;
         let dense = linear(
             cfg.hidden_size,
             cfg.hidden_size,

+ 8 - 153
candle_demo/src/main.rs → candle_demo/codegeex4/src/lib.rs

@@ -1,24 +1,16 @@
-#[cfg(feature = "mkl")]
-extern crate intel_mkl_src;
+pub mod codegeex4;

-#[cfg(feature = "accelerate")]
-extern crate accelerate_src;
+pub mod args;

-use clap::Parser;
-use codegeex4_candle::codegeex4::*;
+use candle_core::{DType, Device, Tensor};
+use candle_transformers::generation::LogitsProcessor;
+use codegeex4::*;
 use owo_colors::{self, OwoColorize};
 use std::io::BufRead;
 use std::io::BufReader;
-
-use candle_core as candle;
-use candle_core::{DType, Device, Tensor};
-use candle_nn::VarBuilder;
-use candle_transformers::generation::LogitsProcessor;
-use hf_hub::{Repo, RepoType};
-use rand::Rng;
 use tokenizers::Tokenizer;

-struct TextGeneration {
+pub struct TextGeneration {
     model: Model,
     device: Device,
     tokenizer: Tokenizer,
@@ -31,7 +23,7 @@ struct TextGeneration {

 impl TextGeneration {
     #[allow(clippy::too_many_arguments)]
-    fn new(
+    pub fn new(
         model: Model,
         tokenizer: Tokenizer,
         seed: u64,
@@ -56,7 +48,7 @@ impl TextGeneration {
         }
     }

-    fn run(&mut self, sample_len: usize) -> Result<(), ()> {
+    pub fn run(&mut self, sample_len: usize) -> Result<(), ()> {
         use std::io::Write;

         let stdin = std::io::stdin();
@@ -145,140 +137,3 @@ impl TextGeneration {
         Ok(())
     }
 }
-
-#[derive(Parser, Debug)]
-#[command(author, version, about, long_about = None)]
-struct Args {
-    /// Run on CPU rather than on GPU.
-    #[arg(name = "cache", short, long, default_value = ".")]
-    cache_path: String,
-
-    #[arg(long)]
-    cpu: bool,
-
-    /// Display the token for the specified prompt.
-    #[arg(long)]
-    verbose_prompt: bool,
-
-    #[arg(long)]
-    prompt: String,
-
-    /// The temperature used to generate samples.
-    #[arg(long)]
-    temperature: Option<f64>,
-
-    /// Nucleus sampling probability cutoff.
-    #[arg(long)]
-    top_p: Option<f64>,
-
-    /// The seed to use when generating random samples.
-    #[arg(long)]
-    seed: Option<u64>,
-
-    /// The length of the sample to generate (in tokens).
-    #[arg(long, short = 'n', default_value_t = 5000)]
-    sample_len: usize,
-
-    #[arg(long)]
-    model_id: Option<String>,
-
-    #[arg(long)]
-    revision: Option<String>,
-
-    #[arg(long)]
-    weight_file: Option<String>,
-
-    #[arg(long)]
-    tokenizer: Option<String>,
-
-    /// Penalty to be applied for repeating tokens, 1. means no penalty.
-    #[arg(long, default_value_t = 1.1)]
-    repeat_penalty: f32,
-
-    /// The context size to consider for the repeat penalty.
-    #[arg(long, default_value_t = 64)]
-    repeat_last_n: usize,
-}
-
-fn main() -> Result<(), ()> {
-    let args = Args::parse();
-    println!(
-        "avx: {}, neon: {}, simd128: {}, f16c: {}",
-        candle::utils::with_avx().red(),
-        candle::utils::with_neon().red(),
-        candle::utils::with_simd128().red(),
-        candle::utils::with_f16c().red(),
-    );
-    println!(
-        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
-        args.temperature.unwrap_or(0.95).red(),
-        args.repeat_penalty.red(),
-        args.repeat_last_n.red(),
-    );
-
-    println!("cache path {}", args.cache_path.blue());
-    println!("Prompt: [{}]", args.prompt.green());
-    let mut seed: u64 = 0;
-    if let Some(_seed) = args.seed {
-        seed = _seed;
-    } else {
-        let mut rng = rand::thread_rng();
-        seed = rng.gen();
-    }
-    println!("Using Seed {}", seed.red());
-    let api = hf_hub::api::sync::ApiBuilder::from_cache(hf_hub::Cache::new(args.cache_path.into()))
-        .build()
-        .unwrap();
-
-    let model_id = match args.model_id {
-        Some(model_id) => model_id.to_string(),
-        None => "THUDM/codegeex4-all-9b".to_string(),
-    };
-    let revision = match args.revision {
-        Some(rev) => rev.to_string(),
-        None => "main".to_string(),
-    };
-    let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
-    let tokenizer_filename = match args.tokenizer {
-        Some(file) => std::path::PathBuf::from(file),
-        None => api
-            .model("THUDM/codegeex4-all-9b".to_string())
-            .get("tokenizer.json")
-            .unwrap(),
-    };
-    let filenames = match args.weight_file {
-        Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
-        None => {
-            candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json").unwrap()
-        }
-    };
-    let tokenizer = Tokenizer::from_file(tokenizer_filename).expect("Tokenizer Error");
-    let start = std::time::Instant::now();
-    let config = Config::codegeex4();
-    let device = candle_examples::device(args.cpu).unwrap();
-    let dtype = if device.is_cuda() {
-        DType::BF16
-    } else {
-        DType::F32
-    };
-    println!("DType is {:?}", dtype.yellow());
-    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device).unwrap() };
-    let model = Model::new(&config, vb).unwrap();
-
-    println!("模型加载完毕 {:?}", start.elapsed().as_secs().green());
-
-    let mut pipeline = TextGeneration::new(
-        model,
-        tokenizer,
-        seed,
-        args.temperature,
-        args.top_p,
-        args.repeat_penalty,
-        args.repeat_last_n,
-        args.verbose_prompt,
-        &device,
-        dtype,
-    );
-    pipeline.run(args.sample_len)?;
-    Ok(())
-}

+ 33 - 0
candle_demo/gui/Cargo.toml

@@ -0,0 +1,33 @@
+[package]
+name = "codegeex4-gui"
+version.workspace = true
+edition.workspace = true
+authors.workspace = true
+license.workspace = true
+description.workspace = true
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+hf-hub = {workspace = true}
+clap = { workspace = true}
+tokenizers = {workspace = true}
+serde_json = {workspace = true}
+candle-core = {workspace = true}
+candle-transformers = {workspace = true}
+candle-examples = {workspace = true}
+candle-nn = {workspace = true}
+safetensors = {workspace = true}
+accelerate-src = { workspace = true, optional = true}
+intel-mkl-src = { workspace = true, optional = true }
+rand = { workspace = true}
+owo-colors = {workspace = true}
+iced = "0.12.1"
+
+[build-dependencies]
+bindgen_cuda = { version = "0.1.1", optional = true }
+[features]
+default = []
+cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda", "dep:bindgen_cuda"]
+accelerate = ["dep:accelerate-src", "candle-core/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
+mkl = ["dep:intel-mkl-src", "candle-core/mkl", "candle-nn/mkl", "candle-transformers/mkl"]

+ 1 - 0
candle_demo/gui/src/main.rs

@@ -0,0 +1 @@
+fn main() {}

+ 8 - 0
candle_demo/server/Cargo.toml

@@ -0,0 +1,8 @@
+[package]
+name = "server"
+version = "0.1.0"
+edition = "2021"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]

+ 3 - 0
candle_demo/server/src/main.rs

@@ -0,0 +1,3 @@
+fn main() {
+    println!("Hello, world!");
+}

+ 0 - 1
candle_demo/src/lib.rs

@@ -1 +0,0 @@
-pub mod codegeex4;