
Merge pull request #33 from donjuanplatinum/dev

feat(candle_demo): add README and change candle_demo,intel mkl
Qinkai, 1 year ago
Parent
Commit
1e56872777
6 files changed, with 174 insertions and 66 deletions
  1. candle_demo/Cargo.toml (+4 −1)
  2. candle_demo/README.md (+42 −6)
  3. candle_demo/README_zh.md (+54 −0)
  4. candle_demo/src/codegeex4.rs (+20 −18)
  5. candle_demo/src/main.rs (+54 −41)
  6. resources/candle_example.png (binary)

+ 4 - 1
candle_demo/Cargo.toml

@@ -25,6 +25,8 @@ candle-transformers = "0.6.0"
 candle-examples = "0.6.0"
 candle-nn = "0.6.0"
 safetensors = "0.4.3"
+accelerate-src = { version = "0.3.2", optional = true}
+intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] ,optional = true}
 #safetensors = {path ="../safetensors/safetensors"}
 [build-dependencies]
 bindgen_cuda = { version = "0.1.1", optional = true }
@@ -33,4 +35,5 @@ bindgen_cuda = { version = "0.1.1", optional = true }
 [features]
 default = []
 cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda", "dep:bindgen_cuda"]
-
+accelerate = ["dep:accelerate-src", "candle-core/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
+mkl = ["dep:intel-mkl-src", "candle-core/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
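
With these optional dependencies declared, the CPU math backend is chosen at build time through Cargo features. As a minimal sketch (assuming Intel MKL is installed on the host, or Apple Accelerate on macOS), the corresponding build commands would be:

``` shell
# Link Intel MKL as the BLAS backend (requires an MKL installation on the host)
cargo build --release --features mkl

# Link Apple Accelerate on macOS
cargo build --release --features accelerate
```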

+ 42 - 6
candle_demo/README.md

@@ -1,11 +1,47 @@
-# CPU运行
+![](../resources/candle_example.png)
+[English](./README.md) | [中文](./README_zh.md)
+# CPU Running
 ```
 cargo run --release -- --prompt your prompt
 ```
 
-# Cuda运行
-- 注意 需要cuda为>=12.4以上的版本
-```
-cargo build --release --features cuda
-./target/release/codegeex4-candle --prompt your prompt
+# Use
+``` shell
+Codegeex4
+
+Usage: codegeex4-candle [OPTIONS] --prompt <PROMPT>
+
+Options:
+  -c, --cache <cache>
+          Run on CPU rather than on GPU [default: .]
+      --cpu
+          
+      --verbose-prompt
+          Display the token for the specified prompt
+      --prompt <PROMPT>
+          
+      --temperature <TEMPERATURE>
+          The temperature used to generate samples
+      --top-p <TOP_P>
+          Nucleus sampling probability cutoff
+      --seed <SEED>
+          The seed to use when generating random samples [default: 299792458]
+  -n, --sample-len <SAMPLE_LEN>
+          The length of the sample to generate (in tokens) [default: 5000]
+      --model-id <MODEL_ID>
+          
+      --revision <REVISION>
+          
+      --weight-file <WEIGHT_FILE>
+          
+      --tokenizer <TOKENIZER>
+          
+      --repeat-penalty <REPEAT_PENALTY>
+          Penalty to be applied for repeating tokens, 1. means no penalty [default: 1.1]
+      --repeat-last-n <REPEAT_LAST_N>
+          The context size to consider for the repeat penalty [default: 64]
+  -h, --help
+          Print help
+  -V, --version
+          Print version
 ```
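
As a worked example of the options listed above, a fuller invocation might look like the following; the prompt text and sampling values here are illustrative placeholders, not project defaults:

``` shell
cargo run --release -- \
    --prompt "write a quicksort function in Rust" \
    --temperature 0.8 \
    --top-p 0.9 \
    --sample-len 512
```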

+ 54 - 0
candle_demo/README_zh.md

@@ -0,0 +1,54 @@
+![](../resources/candle_example.png)
+[English](./README.md) | [中文](./README_zh.md)
+# CPU运行
+```
+cargo run --release -- --prompt your prompt
+```
+
+# 使用
+``` shell
+Codegeex4
+
+Usage: codegeex4-candle [OPTIONS] --prompt <PROMPT>
+
+Options:
+  -c, --cache <cache>
+          Run on CPU rather than on GPU [default: .]
+      --cpu
+          
+      --verbose-prompt
+          Display the token for the specified prompt
+      --prompt <PROMPT>
+          
+      --temperature <TEMPERATURE>
+          The temperature used to generate samples
+      --top-p <TOP_P>
+          Nucleus sampling probability cutoff
+      --seed <SEED>
+          The seed to use when generating random samples [default: 299792458]
+  -n, --sample-len <SAMPLE_LEN>
+          The length of the sample to generate (in tokens) [default: 5000]
+      --model-id <MODEL_ID>
+          
+      --revision <REVISION>
+          
+      --weight-file <WEIGHT_FILE>
+          
+      --tokenizer <TOKENIZER>
+          
+      --repeat-penalty <REPEAT_PENALTY>
+          Penalty to be applied for repeating tokens, 1. means no penalty [default: 1.1]
+      --repeat-last-n <REPEAT_LAST_N>
+          The context size to consider for the repeat penalty [default: 64]
+  -h, --help
+          Print help
+  -V, --version
+          Print version
+```
+# Cuda运行
+- 注意 需要cuda为>=12.4以上的版本
+```
+cargo build --release --features cuda
+./target/release/codegeex4-candle --prompt your prompt
+```
+

+ 20 - 18
candle_demo/src/codegeex4.rs

@@ -1,7 +1,7 @@
-use candle_transformers::models::with_tracing::{linear_b as linear, Linear};
-use candle_core::{DType, Device, IndexOp, Module, Result, Tensor, D};
 use candle_core as candle;
+use candle_core::{DType, Device, IndexOp, Module, Result, Tensor, D};
 use candle_nn::VarBuilder;
+use candle_transformers::models::with_tracing::{linear_b as linear, Linear};
 
 #[derive(Debug, Clone)]
 pub struct Config {
@@ -30,7 +30,7 @@ impl Config {
     pub fn codegeex4() -> Self {
         Self {
             num_layers: 40,
-	    padded_vocab_size: 151552,
+            padded_vocab_size: 151552,
             hidden_size: 4096,
             ffn_hidden_size: 13696,
             kv_channels: 128,
@@ -68,7 +68,8 @@ impl RotaryEmbedding {
         let inv_freq_len = inv_freq.len();
         let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
         let t = Tensor::arange(0u32, cfg.seq_length as u32, dev)?
-            .to_dtype(dtype).expect("unalbe to dytpe in Rotray Embedding new")
+            .to_dtype(dtype)
+            .expect("unalbe to dytpe in Rotray Embedding new")
             .reshape((cfg.seq_length, 1))?;
         let freqs = t.matmul(&inv_freq)?;
         let cache = Tensor::stack(&[&freqs.cos()?, &freqs.sin()?], D::Minus1)?;
@@ -106,17 +107,18 @@ impl RotaryEmbedding {
 struct CoreAttention {
     coeff: Option<f64>,
     norm_factor: f64,
+    dtype: DType,
 }
 
-fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
+fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32,dtype:DType) -> Result<Tensor> {
     let shape = mask.shape();
     let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
-    let m = mask.where_cond(&on_true, on_false)?;
+    let m = mask.where_cond(&on_true.to_dtype(dtype)?, on_false)?;
     Ok(m)
 }
 
 impl CoreAttention {
-    fn new(layer_number: usize, cfg: &Config) -> Result<Self> {
+    fn new(layer_number: usize, cfg: &Config,dtype: DType) -> Result<Self> {
         let norm_factor = (cfg.kv_channels as f64).sqrt();
         let (norm_factor, coeff) = if cfg.apply_query_key_layer_scaling {
             let coeff = f64::max(1.0, layer_number as f64);
@@ -124,7 +126,7 @@ impl CoreAttention {
         } else {
             (norm_factor, None)
         };
-        Ok(Self { coeff, norm_factor })
+        Ok(Self { coeff, norm_factor, dtype})
     }
 
     fn forward(
@@ -144,8 +146,8 @@ impl CoreAttention {
             query_layer.reshape((output_size.2, output_size.0 * output_size.1, ()))?;
         let key_layer = key_layer.reshape((output_size.3, output_size.0 * output_size.1, ()))?;
         let matmul_result = Tensor::matmul(
-            &query_layer.transpose(0, 1)?,
-            &key_layer.transpose(0, 1)?.transpose(1, 2)?,
+            &query_layer.transpose(0, 1)?.contiguous()?,
+            &key_layer.transpose(0, 1)?.transpose(1, 2)?.contiguous()?,
         )?;
         let matmul_result = (matmul_result / self.norm_factor)?.reshape(output_size)?;
         let matmul_result = match self.coeff {
@@ -157,6 +159,7 @@ impl CoreAttention {
                 &matmul_result,
                 &mask.broadcast_left((matmul_result.dim(0)?, matmul_result.dim(1)?))?,
                 f32::NEG_INFINITY,
+		self.dtype,
             )?,
             None => matmul_result,
         };
@@ -172,7 +175,7 @@ impl CoreAttention {
             value_layer.reshape((value_layer.dim(0)?, output_size.0 * output_size.1, ()))?;
         let attention_probs =
             attention_probs.reshape((output_size.0 * output_size.1, output_size.2, ()))?;
-        let context_layer = Tensor::matmul(&attention_probs, &value_layer.transpose(0, 1)?)?;
+        let context_layer = Tensor::matmul(&attention_probs.contiguous()?, &value_layer.transpose(0, 1)?.contiguous()?)?;
         let context_layer = context_layer.reshape(output_size)?;
         let context_layer = context_layer.permute((2, 0, 1, 3))?.contiguous()?;
         context_layer.flatten_from(D::Minus2)
@@ -206,7 +209,7 @@ impl SelfAttention {
             cfg.add_bias_linear || cfg.add_qkv_bias,
             vb.pp("query_key_value"),
         )?;
-        let core_attention = CoreAttention::new(layer_number, cfg)?;
+        let core_attention = CoreAttention::new(layer_number, cfg,vb.dtype())?;
         let dense = linear(
             cfg.hidden_size,
             cfg.hidden_size,
@@ -455,11 +458,11 @@ impl Transformer {
     fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
         let vb_l = vb.pp("layers");
         let mut layers = Vec::with_capacity(cfg.num_layers);
-	println!("transofrmer layers create");
-	let mut count = 0;
+        println!("transofrmer layers create");
+        let mut count = 0;
         for layer_index in 0..cfg.num_layers {
-	    count += 1;
-	    println!("for layer index in {} total is {} ",count, cfg.num_layers);
+            count += 1;
+            println!("for layer index in {} total is {} ", count, cfg.num_layers);
             let block = Block::new(layer_index + 1, cfg, vb_l.pp(layer_index))?;
             layers.push(block)
         }
@@ -564,8 +567,7 @@ impl Model {
             false,
             vb.pp("output_layer"),
         )?;
-	
-	
+
         Ok(Self {
             embedding,
             encoder,

+ 54 - 41
candle_demo/src/main.rs

@@ -1,12 +1,17 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
 use clap::Parser;
 use codegeex4_candle::codegeex4::*;
 
-
-use candle_core::{DType, Device, Tensor};
 use candle_core as candle;
+use candle_core::{DType, Device, Tensor};
 use candle_nn::VarBuilder;
 use candle_transformers::generation::LogitsProcessor;
-use hf_hub::{api::sync::Api, Repo, RepoType};
+use hf_hub::{Repo, RepoType};
 use tokenizers::Tokenizer;
 
 struct TextGeneration {
@@ -32,7 +37,7 @@ impl TextGeneration {
         repeat_last_n: usize,
         verbose_prompt: bool,
         device: &Device,
-	dtype: DType,
+        dtype: DType,
     ) -> Self {
         let logits_processor = LogitsProcessor::new(seed, temp, top_p);
         Self {
@@ -43,11 +48,11 @@ impl TextGeneration {
             repeat_last_n,
             verbose_prompt,
             device: device.clone(),
-	    dtype,
+            dtype,
         }
     }
 
-    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<(),()> {
+    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<(), ()> {
         use std::io::Write;
         println!("starting the inference loop");
         let tokens = self.tokenizer.encode(prompt, true).expect("tokens error");
@@ -66,22 +71,24 @@ impl TextGeneration {
         };
         let mut tokens = tokens.get_ids().to_vec();
         let mut generated_tokens = 0usize;
-        
-        
+
         print!("{prompt}");
         std::io::stdout().flush().expect("output flush error");
         let start_gen = std::time::Instant::now();
-	
-	println!("\n start_gen");
-	println!("samplelen {}",sample_len);
-	let mut count = 0;
-	let mut result = vec!();
+
+        println!("\n 开始生成");
+        println!("samplelen {}", sample_len);
+        let mut count = 0;
+        let mut result = vec![];
         for index in 0..sample_len {
-	    count += 1;
-	    println!("sample count {}",count);
+            count += 1;
+            println!("sample count {}", count);
             let context_size = if index > 0 { 1 } else { tokens.len() };
             let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
-            let input = Tensor::new(ctxt, &self.device).unwrap().unsqueeze(0).expect("create tensor input error");
+            let input = Tensor::new(ctxt, &self.device)
+                .unwrap()
+                .unsqueeze(0)
+                .expect("create tensor input error");
             let logits = self.model.forward(&input).unwrap();
             let logits = logits.squeeze(0).unwrap().to_dtype(self.dtype).unwrap();
             let logits = if self.repeat_penalty == 1. {
@@ -92,7 +99,8 @@ impl TextGeneration {
                     &logits,
                     self.repeat_penalty,
                     &tokens[start_at..],
-                ).unwrap()
+                )
+                .unwrap()
             };
 
             let next_token = self.logits_processor.sample(&logits).unwrap();
@@ -101,10 +109,13 @@ impl TextGeneration {
             if next_token == eos_token {
                 break;
             }
-	    println!("raw generate token {}",next_token);
-            let token = self.tokenizer.decode(&[next_token], true).expect("Token error");
+            println!("raw generate token {}", next_token);
+            let token = self
+                .tokenizer
+                .decode(&[next_token], true)
+                .expect("Token error");
             println!("[token:{token}]");
-	    result.push(token);
+            result.push(token);
             std::io::stdout().flush().unwrap();
         }
         let dt = start_gen.elapsed();
@@ -112,10 +123,10 @@ impl TextGeneration {
             "\n{generated_tokens} tokens generated ({:.2} token/s)",
             generated_tokens as f64 / dt.as_secs_f64(),
         );
-	println!("Result:");
-	for tokens in result {
-	    print!("{tokens}");
-	}
+        println!("Result:");
+        for tokens in result {
+            print!("{tokens}");
+        }
         Ok(())
     }
 }
@@ -124,14 +135,14 @@ impl TextGeneration {
 #[command(author, version, about, long_about = None)]
 struct Args {
     /// Run on CPU rather than on GPU.
-    #[arg(name="cache",short,long, default_value=".")]
+    #[arg(name = "cache", short, long, default_value = ".")]
     cache_path: String,
-    
+
     #[arg(long)]
     cpu: bool,
 
     /// Display the token for the specified prompt.
-    #[arg(long)]
+    #[arg(long,default_value_t=true)]
     verbose_prompt: bool,
 
     #[arg(long)]
@@ -174,8 +185,7 @@ struct Args {
     repeat_last_n: usize,
 }
 
-fn main() -> Result<(),()> {
-    
+fn main() -> Result<(), ()> {
     let args = Args::parse();
     println!(
         "avx: {}, neon: {}, simd128: {}, f16c: {}",
@@ -191,10 +201,11 @@ fn main() -> Result<(),()> {
         args.repeat_last_n
     );
 
-    let start = std::time::Instant::now();
-    println!("cache path {}",args.cache_path);
-    let api = hf_hub::api::sync::ApiBuilder::from_cache(hf_hub::Cache::new(args.cache_path.into())).build().unwrap();
-    
+    println!("cache path {}", args.cache_path);
+    let api = hf_hub::api::sync::ApiBuilder::from_cache(hf_hub::Cache::new(args.cache_path.into()))
+        .build()
+        .unwrap();
+
     let model_id = match args.model_id {
         Some(model_id) => model_id.to_string(),
         None => "THUDM/codegeex4-all-9b".to_string(),
@@ -208,27 +219,29 @@ fn main() -> Result<(),()> {
         Some(file) => std::path::PathBuf::from(file),
         None => api
             .model("THUDM/codegeex4-all-9b".to_string())
-            .get("tokenizer.json").unwrap(),
+            .get("tokenizer.json")
+            .unwrap(),
     };
     let filenames = match args.weight_file {
         Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
-        None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json").unwrap(),
+        None => {
+            candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json").unwrap()
+        }
     };
-    println!("retrieved the files in {:?}", start.elapsed());
     let tokenizer = Tokenizer::from_file(tokenizer_filename).expect("Tokenizer Error");
-
     let start = std::time::Instant::now();
     let config = Config::codegeex4();
     let device = candle_examples::device(args.cpu).unwrap();
     let dtype = if device.is_cuda() {
-	DType::BF16
+        DType::BF16
     } else {
-	DType::F32
+        DType::F32
     };
+    println!("dtype is {:?}", dtype);
     let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device).unwrap() };
     let model = Model::new(&config, vb).unwrap();
 
-    println!("loaded the model in {:?}", start.elapsed());
+    println!("模型加载完毕 {:?}", start.elapsed().as_secs());
 
     let mut pipeline = TextGeneration::new(
         model,
@@ -240,7 +253,7 @@ fn main() -> Result<(),()> {
         args.repeat_last_n,
         args.verbose_prompt,
         &device,
-	dtype,
+        dtype,
     );
     pipeline.run(&args.prompt, args.sample_len)?;
     Ok(())
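
Since the new `extern crate intel_mkl_src` / `extern crate accelerate_src` declarations are gated behind Cargo features, they only take effect when the binary is built with the matching feature. A sketch of a CPU run that actually links MKL (assuming MKL is available on the host; the prompt is a placeholder):

``` shell
# Build with the mkl feature so intel_mkl_src is compiled in
cargo build --release --features mkl
# Run on CPU; per the dtype selection above, F32 is used instead of BF16
./target/release/codegeex4-candle --cpu --prompt "your prompt"
```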

Binary
resources/candle_example.png