donjuanplatinum 1 год назад
Родитель
Сommit
84dea1e0bb
2 измененных файлов с 13 добавлено и 3 удалено
  1. 8 0
      candle_demo/README.md
  2. 5 3
      candle_demo/src/main.rs

+ 8 - 0
candle_demo/README.md

@@ -1,3 +1,11 @@
+# CPU运行
 ```
 cargo run --release -- --prompt your prompt
 ```
+
+# Cuda运行
+- 注意 需要cuda为>=12.4以上的版本
+```
+cargo build --release --features cuda
+./target/release/codegeex4-candle --prompt your prompt
+```

+ 5 - 3
candle_demo/src/main.rs

@@ -1,4 +1,3 @@
-//use anyhow::{Error as E, Result};
 use clap::Parser;
 use codegeex4_candle::codegeex4::*;
 
@@ -18,6 +17,7 @@ struct TextGeneration {
     repeat_penalty: f32,
     repeat_last_n: usize,
     verbose_prompt: bool,
+    dtype: DType,
 }
 
 impl TextGeneration {
@@ -32,6 +32,7 @@ impl TextGeneration {
         repeat_last_n: usize,
         verbose_prompt: bool,
         device: &Device,
+	dtype: DType,
     ) -> Self {
         let logits_processor = LogitsProcessor::new(seed, temp, top_p);
         Self {
@@ -42,6 +43,7 @@ impl TextGeneration {
             repeat_last_n,
             verbose_prompt,
             device: device.clone(),
+	    dtype,
         }
     }
 
@@ -49,7 +51,6 @@ impl TextGeneration {
         use std::io::Write;
         println!("starting the inference loop");
         let tokens = self.tokenizer.encode(prompt, true).expect("tokens error");
-	println!("run starting the token 57");
         if tokens.is_empty() {
             panic!("Empty prompts are not supported in the chatglm model.")
         }
@@ -82,7 +83,7 @@ impl TextGeneration {
             let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
             let input = Tensor::new(ctxt, &self.device).unwrap().unsqueeze(0).expect("create tensor input error");
             let logits = self.model.forward(&input).unwrap();
-            let logits = logits.squeeze(0).unwrap().to_dtype(DType::F32).unwrap();
+            let logits = logits.squeeze(0).unwrap().to_dtype(self.dtype).unwrap();
             let logits = if self.repeat_penalty == 1. {
                 logits
             } else {
@@ -239,6 +240,7 @@ fn main() -> Result<(),()> {
         args.repeat_last_n,
         args.verbose_prompt,
         &device,
+	dtype,
     );
     pipeline.run(&args.prompt, args.sample_len)?;
     Ok(())