Просмотр исходного кода

修改输出格式 (#1)

* Revise the output format

* Revise the output format

* Revise the output format

* Add files via upload
JasonYANG17 1 год назад
Родитель
Сommit
1ba5426235
2 измененных файлов с 69 добавлено и 5 удалено
  1. 49 1
      candle_demo/Cargo.lock
  2. 20 4
      candle_demo/src/main.rs

+ 49 - 1
candle_demo/Cargo.lock

@@ -132,6 +132,17 @@ version = "0.22.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
 
+[[package]]
+name = "bindgen_cuda"
+version = "0.1.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1f8489af5b7d17a81bffe37e0f4d6e1e4de87c87329d05447f22c35d95a1227d"
+dependencies = [
+ "glob",
+ "num_cpus",
+ "rayon",
+]
+
 [[package]]
 name = "bit-set"
 version = "0.5.3"
@@ -204,6 +215,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "d5b18de020c2729dbf7ac390325312644808b6ba9b7962f1f724e9185b1d53c7"
 dependencies = [
  "byteorder",
+ "candle-kernels",
+ "cudarc",
  "gemm",
  "half",
  "memmap2",
@@ -239,6 +252,15 @@ dependencies = [
  "tokenizers",
 ]
 
+[[package]]
+name = "candle-kernels"
+version = "0.6.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8bc0a71be8b2f0950b63fd602a5e10a74a4f94a5fd63059ae455e96163389488"
+dependencies = [
+ "bindgen_cuda",
+]
+
 [[package]]
 name = "candle-nn"
 version = "0.6.0"
@@ -329,7 +351,7 @@ checksum = "4b82cf0babdbd58558212896d1a4272303a57bdb245c2bf1147185fb45640e70"
 name = "codegeex4-candle"
 version = "0.1.0"
 dependencies = [
- "anyhow",
+ "bindgen_cuda",
  "candle-core",
  "candle-examples",
  "candle-nn",
@@ -437,6 +459,16 @@ dependencies = [
  "memchr",
 ]
 
+[[package]]
+name = "cudarc"
+version = "0.11.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "56ee2a3fbbd981e1c7ea73cc2af136e754eb22d17436de37155227ee4dbe0cf4"
+dependencies = [
+ "half",
+ "libloading",
+]
+
 [[package]]
 name = "darling"
 version = "0.20.10"
@@ -904,6 +936,12 @@ version = "0.29.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "40ecd4077b5ae9fd2e9e169b102c6c330d0605168eb0e8bf79952b256dbefffd"
 
+[[package]]
+name = "glob"
+version = "0.3.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b"
+
 [[package]]
 name = "h2"
 version = "0.3.26"
@@ -1172,6 +1210,16 @@ version = "0.2.155"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c"
 
+[[package]]
+name = "libloading"
+version = "0.8.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e310b3a6b5907f99202fcdb4960ff45b93735d7c7d96b760fcff8db2dc0e103d"
+dependencies = [
+ "cfg-if",
+ "windows-targets 0.52.6",
+]
+
 [[package]]
 name = "libm"
 version = "0.2.8"

+ 20 - 4
candle_demo/src/main.rs

@@ -59,16 +59,22 @@ impl TextGeneration {
                 println!("{id:7} -> '{token}'");
             }
         }
+        let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
+            Some(token) => *token,
+            None => panic!("cannot find the endoftext token"),
+        };
         let mut tokens = tokens.get_ids().to_vec();
         let mut generated_tokens = 0usize;
-        let eos_token = 151329;
+        
         
         print!("{prompt}");
         std::io::stdout().flush().expect("output flush error");
         let start_gen = std::time::Instant::now();
-	println!("start_gen");
+	
+	println!("\n start_gen");
 	println!("samplelen {}",sample_len);
 	let mut count = 0;
+	let mut result = vec!();
         for index in 0..sample_len {
 	    count += 1;
 	    println!("sample count {}",count);
@@ -96,7 +102,8 @@ impl TextGeneration {
             }
 	    println!("raw generate token {}",next_token);
             let token = self.tokenizer.decode(&[next_token], true).expect("Token error");
-            print!("{token}");
+            println!("[token:{token}]");
+	    result.push(token);
             std::io::stdout().flush().unwrap();
         }
         let dt = start_gen.elapsed();
@@ -104,6 +111,10 @@ impl TextGeneration {
             "\n{generated_tokens} tokens generated ({:.2} token/s)",
             generated_tokens as f64 / dt.as_secs_f64(),
         );
+	println!("Result:");
+	for tokens in result {
+	    print!("{tokens}");
+	}
         Ok(())
     }
 }
@@ -208,7 +219,12 @@ fn main() -> Result<(),()> {
     let start = std::time::Instant::now();
     let config = Config::codegeex4();
     let device = candle_examples::device(args.cpu).unwrap();
-    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device).unwrap() };
+    let dtype = if device.is_cuda() {
+	DType::BF16
+    } else {
+	DType::F32
+    };
+    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device).unwrap() };
     let model = Model::new(&config, vb).unwrap();
 
     println!("loaded the model in {:?}", start.elapsed());