Browse Source

add README and change candle_demo

donjuanplatinum 1 year ago
parent
commit
e1a09f26f5
3 changed files with 139 additions and 46 deletions
  1. 40 6
      candle_demo/README.md
  2. 52 0
      candle_demo/README_zh.md
  3. 47 40
      candle_demo/src/main.rs

+ 40 - 6
candle_demo/README.md

@@ -1,11 +1,45 @@
-# CPU运行
+# CPU Running
 ```
 cargo run --release -- --prompt your prompt
 ```
 
-# Cuda运行
-- 注意 需要cuda为>=12.4以上的版本
-```
-cargo build --release --features cuda
-./target/release/codegeex4-candle --prompt your prompt
+# Usage
+``` shell
+Codegeex4
+
+Usage: codegeex4-candle [OPTIONS] --prompt <PROMPT>
+
+Options:
+  -c, --cache <cache>
+          Run on CPU rather than on GPU [default: .]
+      --cpu
+          
+      --verbose-prompt
+          Display the token for the specified prompt
+      --prompt <PROMPT>
+          
+      --temperature <TEMPERATURE>
+          The temperature used to generate samples
+      --top-p <TOP_P>
+          Nucleus sampling probability cutoff
+      --seed <SEED>
+          The seed to use when generating random samples [default: 299792458]
+  -n, --sample-len <SAMPLE_LEN>
+          The length of the sample to generate (in tokens) [default: 5000]
+      --model-id <MODEL_ID>
+          
+      --revision <REVISION>
+          
+      --weight-file <WEIGHT_FILE>
+          
+      --tokenizer <TOKENIZER>
+          
+      --repeat-penalty <REPEAT_PENALTY>
+          Penalty to be applied for repeating tokens, 1. means no penalty [default: 1.1]
+      --repeat-last-n <REPEAT_LAST_N>
+          The context size to consider for the repeat penalty [default: 64]
+  -h, --help
+          Print help
+  -V, --version
+          Print version
 ```

+ 52 - 0
candle_demo/README_zh.md

@@ -0,0 +1,52 @@
+# CPU运行
+```
+cargo run --release -- --prompt your prompt
+```
+
+# 使用
+``` shell
+Codegeex4
+
+Usage: codegeex4-candle [OPTIONS] --prompt <PROMPT>
+
+Options:
+  -c, --cache <cache>
+          Run on CPU rather than on GPU [default: .]
+      --cpu
+          
+      --verbose-prompt
+          Display the token for the specified prompt
+      --prompt <PROMPT>
+          
+      --temperature <TEMPERATURE>
+          The temperature used to generate samples
+      --top-p <TOP_P>
+          Nucleus sampling probability cutoff
+      --seed <SEED>
+          The seed to use when generating random samples [default: 299792458]
+  -n, --sample-len <SAMPLE_LEN>
+          The length of the sample to generate (in tokens) [default: 5000]
+      --model-id <MODEL_ID>
+          
+      --revision <REVISION>
+          
+      --weight-file <WEIGHT_FILE>
+          
+      --tokenizer <TOKENIZER>
+          
+      --repeat-penalty <REPEAT_PENALTY>
+          Penalty to be applied for repeating tokens, 1. means no penalty [default: 1.1]
+      --repeat-last-n <REPEAT_LAST_N>
+          The context size to consider for the repeat penalty [default: 64]
+  -h, --help
+          Print help
+  -V, --version
+          Print version
+```
+# Cuda运行
+- 注意:需要 CUDA 版本 >= 12.4
+```
+cargo build --release --features cuda
+./target/release/codegeex4-candle --prompt your prompt
+```
+

+ 47 - 40
candle_demo/src/main.rs

@@ -1,12 +1,11 @@
 use clap::Parser;
 use codegeex4_candle::codegeex4::*;
 
-
-use candle_core::{DType, Device, Tensor};
 use candle_core as candle;
+use candle_core::{DType, Device, Tensor};
 use candle_nn::VarBuilder;
 use candle_transformers::generation::LogitsProcessor;
-use hf_hub::{api::sync::Api, Repo, RepoType};
+use hf_hub::{Repo, RepoType};
 use tokenizers::Tokenizer;
 
 struct TextGeneration {
@@ -32,7 +31,7 @@ impl TextGeneration {
         repeat_last_n: usize,
         verbose_prompt: bool,
         device: &Device,
-	dtype: DType,
+        dtype: DType,
     ) -> Self {
         let logits_processor = LogitsProcessor::new(seed, temp, top_p);
         Self {
@@ -43,11 +42,11 @@ impl TextGeneration {
             repeat_last_n,
             verbose_prompt,
             device: device.clone(),
-	    dtype,
+            dtype,
         }
     }
 
-    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<(),()> {
+    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<(), ()> {
         use std::io::Write;
         println!("starting the inference loop");
         let tokens = self.tokenizer.encode(prompt, true).expect("tokens error");
@@ -66,22 +65,24 @@ impl TextGeneration {
         };
         let mut tokens = tokens.get_ids().to_vec();
         let mut generated_tokens = 0usize;
-        
-        
+
         print!("{prompt}");
         std::io::stdout().flush().expect("output flush error");
         let start_gen = std::time::Instant::now();
-	
-	println!("\n start_gen");
-	println!("samplelen {}",sample_len);
-	let mut count = 0;
-	let mut result = vec!();
+
+        println!("\n 开始生成");
+        println!("samplelen {}", sample_len);
+        let mut count = 0;
+        let mut result = vec![];
         for index in 0..sample_len {
-	    count += 1;
-	    println!("sample count {}",count);
+            count += 1;
+            println!("sample count {}", count);
             let context_size = if index > 0 { 1 } else { tokens.len() };
             let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
-            let input = Tensor::new(ctxt, &self.device).unwrap().unsqueeze(0).expect("create tensor input error");
+            let input = Tensor::new(ctxt, &self.device)
+                .unwrap()
+                .unsqueeze(0)
+                .expect("create tensor input error");
             let logits = self.model.forward(&input).unwrap();
             let logits = logits.squeeze(0).unwrap().to_dtype(self.dtype).unwrap();
             let logits = if self.repeat_penalty == 1. {
@@ -92,7 +93,8 @@ impl TextGeneration {
                     &logits,
                     self.repeat_penalty,
                     &tokens[start_at..],
-                ).unwrap()
+                )
+                .unwrap()
             };
 
             let next_token = self.logits_processor.sample(&logits).unwrap();
@@ -101,10 +103,13 @@ impl TextGeneration {
             if next_token == eos_token {
                 break;
             }
-	    println!("raw generate token {}",next_token);
-            let token = self.tokenizer.decode(&[next_token], true).expect("Token error");
+            println!("raw generate token {}", next_token);
+            let token = self
+                .tokenizer
+                .decode(&[next_token], true)
+                .expect("Token error");
             println!("[token:{token}]");
-	    result.push(token);
+            result.push(token);
             std::io::stdout().flush().unwrap();
         }
         let dt = start_gen.elapsed();
@@ -112,10 +117,10 @@ impl TextGeneration {
             "\n{generated_tokens} tokens generated ({:.2} token/s)",
             generated_tokens as f64 / dt.as_secs_f64(),
         );
-	println!("Result:");
-	for tokens in result {
-	    print!("{tokens}");
-	}
+        println!("Result:");
+        for tokens in result {
+            print!("{tokens}");
+        }
         Ok(())
     }
 }
@@ -124,9 +129,9 @@ impl TextGeneration {
 #[command(author, version, about, long_about = None)]
 struct Args {
     /// Run on CPU rather than on GPU.
-    #[arg(name="cache",short,long, default_value=".")]
+    #[arg(name = "cache", short, long, default_value = ".")]
     cache_path: String,
-    
+
     #[arg(long)]
     cpu: bool,
 
@@ -174,8 +179,7 @@ struct Args {
     repeat_last_n: usize,
 }
 
-fn main() -> Result<(),()> {
-    
+fn main() -> Result<(), ()> {
     let args = Args::parse();
     println!(
         "avx: {}, neon: {}, simd128: {}, f16c: {}",
@@ -191,10 +195,11 @@ fn main() -> Result<(),()> {
         args.repeat_last_n
     );
 
-    let start = std::time::Instant::now();
-    println!("cache path {}",args.cache_path);
-    let api = hf_hub::api::sync::ApiBuilder::from_cache(hf_hub::Cache::new(args.cache_path.into())).build().unwrap();
-    
+    println!("cache path {}", args.cache_path);
+    let api = hf_hub::api::sync::ApiBuilder::from_cache(hf_hub::Cache::new(args.cache_path.into()))
+        .build()
+        .unwrap();
+
     let model_id = match args.model_id {
         Some(model_id) => model_id.to_string(),
         None => "THUDM/codegeex4-all-9b".to_string(),
@@ -208,27 +213,29 @@ fn main() -> Result<(),()> {
         Some(file) => std::path::PathBuf::from(file),
         None => api
             .model("THUDM/codegeex4-all-9b".to_string())
-            .get("tokenizer.json").unwrap(),
+            .get("tokenizer.json")
+            .unwrap(),
     };
     let filenames = match args.weight_file {
         Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
-        None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json").unwrap(),
+        None => {
+            candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json").unwrap()
+        }
     };
-    println!("retrieved the files in {:?}", start.elapsed());
     let tokenizer = Tokenizer::from_file(tokenizer_filename).expect("Tokenizer Error");
-
     let start = std::time::Instant::now();
     let config = Config::codegeex4();
     let device = candle_examples::device(args.cpu).unwrap();
     let dtype = if device.is_cuda() {
-	DType::BF16
+        DType::BF16
     } else {
-	DType::F32
+        DType::F32
     };
+    println!("dtype is {:?}", dtype);
     let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device).unwrap() };
     let model = Model::new(&config, vb).unwrap();
 
-    println!("loaded the model in {:?}", start.elapsed());
+    println!("模型加载完毕 {:?}", start.elapsed().as_secs());
 
     let mut pipeline = TextGeneration::new(
         model,
@@ -240,7 +247,7 @@ fn main() -> Result<(),()> {
         args.repeat_last_n,
         args.verbose_prompt,
         &device,
-	dtype,
+        dtype,
     );
     pipeline.run(&args.prompt, args.sample_len)?;
     Ok(())