Modify the program to read the prompt from standard input

donjuanplatinum committed 1 year ago
commit 668c0208df
1 changed file with 108 additions and 84 deletions

candle_demo/src/main.rs (+108, -84)
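The gist of the change below: instead of taking a single `--prompt` argument, the demo now loops over standard input and runs one generation per line. A minimal sketch of that input pattern in isolation (the `generate` function is a hypothetical stand-in for the candle pipeline call):

```rust
use std::io::{BufRead, BufReader};

// Hypothetical stand-in for the per-prompt generation done by TextGeneration::run.
fn generate(prompt: &str) {
    println!("would generate a completion for: {prompt}");
}

fn main() {
    let stdin = std::io::stdin();
    let reader = BufReader::new(stdin);
    // One prompt per line; the loop ends on EOF (Ctrl-D).
    for line in reader.lines() {
        let prompt = line.expect("failed to read a line from stdin");
        if prompt.trim().is_empty() {
            continue; // ignore empty lines instead of panicking on an empty prompt
        }
        generate(&prompt);
    }
}
```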

@@ -6,12 +6,16 @@ extern crate accelerate_src;
 
 use clap::Parser;
 use codegeex4_candle::codegeex4::*;
+use owo_colors::{self, OwoColorize};
+use std::io::BufRead;
+use std::io::BufReader;
 
 use candle_core as candle;
 use candle_core::{DType, Device, Tensor};
 use candle_nn::VarBuilder;
 use candle_transformers::generation::LogitsProcessor;
 use hf_hub::{Repo, RepoType};
+use rand::Rng;
 use tokenizers::Tokenizer;

 struct TextGeneration {
@@ -52,81 +56,92 @@ impl TextGeneration {
         }
     }

-    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<(), ()> {
+    fn run(&mut self, sample_len: usize) -> Result<(), ()> {
         use std::io::Write;
-        println!("starting the inference loop");
-        let tokens = self.tokenizer.encode(prompt, true).expect("tokens error");
-        if tokens.is_empty() {
-            panic!("Empty prompts are not supported in the chatglm model.")
-        }
-        if self.verbose_prompt {
-            for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
-                let token = token.replace('▁', " ").replace("<0x0A>", "\n");
-                println!("{id:7} -> '{token}'");
+
+        let stdin = std::io::stdin();
+        let reader = BufReader::new(stdin);
+            // read the prompt from standard input
+        for line in reader.lines() {
+            println!("[欢迎使用Codegeex4,请输入prompt]");
+            let line = line.expect("Failed to read line");
+            let tokens = self.tokenizer.encode(line, true).expect("tokens error");
+            if tokens.is_empty() {
+                panic!("Empty prompts are not supported in the chatglm model.")
             }
-        }
-        let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
-            Some(token) => *token,
-            None => panic!("cannot find the endoftext token"),
-        };
-        let mut tokens = tokens.get_ids().to_vec();
-        let mut generated_tokens = 0usize;
+            if self.verbose_prompt {
+                for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
+                    let token = token.replace('▁', " ").replace("<0x0A>", "\n");
+                    println!("{id:7} -> '{token}'");
+                }
+            }
+            let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
+                Some(token) => *token,
+                None => panic!("cannot find the endoftext token"),
+            };
+            let mut tokens = tokens.get_ids().to_vec();
+            let mut generated_tokens = 0usize;
 
-        print!("{prompt}");
-        std::io::stdout().flush().expect("output flush error");
-        let start_gen = std::time::Instant::now();
+            std::io::stdout().flush().expect("output flush error");
+            let start_gen = std::time::Instant::now();
 
-        println!("\n 开始生成");
-        println!("samplelen {}", sample_len);
-        let mut count = 0;
-        let mut result = vec![];
-        for index in 0..sample_len {
-            count += 1;
-            println!("sample count {}", count);
-            let context_size = if index > 0 { 1 } else { tokens.len() };
-            let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
-            let input = Tensor::new(ctxt, &self.device)
-                .unwrap()
-                .unsqueeze(0)
-                .expect("create tensor input error");
-            let logits = self.model.forward(&input).unwrap();
-            let logits = logits.squeeze(0).unwrap().to_dtype(self.dtype).unwrap();
-            let logits = if self.repeat_penalty == 1. {
-                logits
-            } else {
-                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
-                candle_transformers::utils::apply_repeat_penalty(
-                    &logits,
-                    self.repeat_penalty,
-                    &tokens[start_at..],
-                )
-                .unwrap()
-            };
+            //            println!("\n 开始生成");
+            println!("samplelen {}", sample_len.blue());
+            let mut result = vec![];
 
-            let next_token = self.logits_processor.sample(&logits).unwrap();
-            tokens.push(next_token);
-            generated_tokens += 1;
-            if next_token == eos_token {
-                break;
+            for index in 0..sample_len {
+                let context_size = if index > 0 { 1 } else { tokens.len() };
+                let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
+                let input = Tensor::new(ctxt, &self.device)
+                    .unwrap()
+                    .unsqueeze(0)
+                    .expect("create tensor input error");
+                let logits = self.model.forward(&input).unwrap();
+                let logits = logits.squeeze(0).unwrap().to_dtype(self.dtype).unwrap();
+                let logits = if self.repeat_penalty == 1. {
+                    logits
+                } else {
+                    let start_at = tokens.len().saturating_sub(self.repeat_last_n);
+                    candle_transformers::utils::apply_repeat_penalty(
+                        &logits,
+                        self.repeat_penalty,
+                        &tokens[start_at..],
+                    )
+                    .unwrap()
+                };
+
+                let next_token = self.logits_processor.sample(&logits).unwrap();
+                tokens.push(next_token);
+                generated_tokens += 1;
+                if next_token == eos_token {
+                    break;
+                }
+                let token = self
+                    .tokenizer
+                    .decode(&[next_token], true)
+                    .expect("Token error");
+                if self.verbose_prompt {
+                    println!(
+                        "[Index: {}] [Raw Token: {}] [Decode Token: {}]",
+                        index.blue(),
+                        next_token.green(),
+                        token.yellow()
+                    );
+                }
+                result.push(token);
+                std::io::stdout().flush().unwrap();
+            }
+            let dt = start_gen.elapsed();
+            println!(
+                "\n{generated_tokens} tokens generated ({:.2} token/s)",
+                generated_tokens as f64 / dt.as_secs_f64(),
+            );
+            println!("Result:");
+            for tokens in result {
+                print!("{tokens}");
             }
-            println!("raw generate token {}", next_token);
-            let token = self
-                .tokenizer
-                .decode(&[next_token], true)
-                .expect("Token error");
-            println!("[token:{token}]");
-            result.push(token);
-            std::io::stdout().flush().unwrap();
-        }
-        let dt = start_gen.elapsed();
-        println!(
-            "\n{generated_tokens} tokens generated ({:.2} token/s)",
-            generated_tokens as f64 / dt.as_secs_f64(),
-        );
-        println!("Result:");
-        for tokens in result {
-            print!("{tokens}");
         }
+        self.model.reset_kv_cache(); // clear the model's KV cache before the next prompt
         Ok(())
     }
 }
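The rewritten loop keeps the same sampling core, including `candle_transformers::utils::apply_repeat_penalty` over the last `repeat_last_n` tokens. For readers unfamiliar with that step, here is a rough standalone sketch of the usual repeat-penalty idea on plain `f32` logits; it illustrates the technique and is not the candle implementation:

```rust
/// Dampen the logits of recently generated tokens (CTRL-style repeat penalty).
/// `recent` holds the token ids seen in the last `repeat_last_n` positions.
fn penalize_repeats(logits: &mut [f32], penalty: f32, recent: &[u32]) {
    for &token_id in recent {
        if let Some(logit) = logits.get_mut(token_id as usize) {
            // Push the logit towards "less likely": divide positive values,
            // multiply negative ones by the penalty (> 1.0).
            if *logit >= 0.0 {
                *logit /= penalty;
            } else {
                *logit *= penalty;
            }
        }
    }
}

fn main() {
    let mut logits = vec![2.0_f32, -1.0, 0.5];
    penalize_repeats(&mut logits, 1.2, &[0, 1]);
    println!("{logits:?}"); // token 0 damped, token 1 pushed further down, token 2 untouched
}
```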
@@ -142,7 +157,7 @@ struct Args {
     cpu: bool,

     /// Display the token for the specified prompt.
-    #[arg(long,default_value_t=true)]
+    #[arg(long)]
     verbose_prompt: bool,

     #[arg(long)]
@@ -157,8 +172,8 @@ struct Args {
     top_p: Option<f64>,

     /// The seed to use when generating random samples.
-    #[arg(long, default_value_t = 299792458)]
-    seed: u64,
+    #[arg(long)]
+    seed: Option<u64>,

     /// The length of the sample to generate (in tokens).
     #[arg(long, short = 'n', default_value_t = 5000)]
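Two small CLI changes sit in the hunks above: `verbose_prompt` loses its `default_value_t = true` (a boolean flag that defaults to true effectively can never be switched off from the command line, so it becomes an ordinary opt-in switch), and `seed` becomes optional so that `main` can fall back to a random seed. A short sketch of the resulting clap definitions, assuming the same derive-style `Args` struct as the diff (the surrounding binary is only for illustration):

```rust
use clap::Parser;

#[derive(Parser, Debug)]
struct Args {
    /// Display the token for the specified prompt (off unless the flag is passed).
    #[arg(long)]
    verbose_prompt: bool,

    /// The seed to use when generating random samples; picked at random if omitted.
    #[arg(long)]
    seed: Option<u64>,
}

fn main() {
    // `--verbose-prompt` alone sets the bool; `--seed 42` fills the Option.
    let args = Args::parse();
    println!("{args:?}");
}
```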
@@ -189,19 +204,28 @@ fn main() -> Result<(), ()> {
     let args = Args::parse();
     println!(
         "avx: {}, neon: {}, simd128: {}, f16c: {}",
-        candle::utils::with_avx(),
-        candle::utils::with_neon(),
-        candle::utils::with_simd128(),
-        candle::utils::with_f16c()
+        candle::utils::with_avx().red(),
+        candle::utils::with_neon().red(),
+        candle::utils::with_simd128().red(),
+        candle::utils::with_f16c().red(),
     );
     println!(
         "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
-        args.temperature.unwrap_or(0.95),
-        args.repeat_penalty,
-        args.repeat_last_n
+        args.temperature.unwrap_or(0.95).red(),
+        args.repeat_penalty.red(),
+        args.repeat_last_n.red(),
     );
 
-    println!("cache path {}", args.cache_path);
+    println!("cache path {}", args.cache_path.blue());
+    println!("Prompt: [{}]", args.prompt.green());
+    let mut seed: u64 = 0;
+    if let Some(_seed) = args.seed {
+        seed = _seed;
+    } else {
+        let mut rng = rand::thread_rng();
+        seed = rng.gen();
+    }
+    println!("Using Seed {}", seed.red());
     let api = hf_hub::api::sync::ApiBuilder::from_cache(hf_hub::Cache::new(args.cache_path.into()))
         .build()
         .unwrap();
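The new seed handling above works, but the mutable `seed` plus the `if let`/`else` can be collapsed into a single expression using the same `rand` crate the diff already imports. A hypothetical helper just to show the shape:

```rust
use rand::Rng;

// Hypothetical helper: take the seed from the CLI if present, otherwise draw one.
fn pick_seed(cli_seed: Option<u64>) -> u64 {
    cli_seed.unwrap_or_else(|| rand::thread_rng().gen())
}

fn main() {
    println!("Using Seed {}", pick_seed(Some(299792458)));
    println!("Using Seed {}", pick_seed(None));
}
```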
@@ -209,7 +233,7 @@ fn main() -> Result<(), ()> {
     let model_id = match args.model_id {
         Some(model_id) => model_id.to_string(),
         None => "THUDM/codegeex4-all-9b".to_string(),
-    };
+    };
     let revision = match args.revision {
         Some(rev) => rev.to_string(),
         None => "main".to_string(),
@@ -237,16 +261,16 @@ fn main() -> Result<(), ()> {
     } else {
         DType::F32
     };
-    println!("dtype is {:?}", dtype);
+    println!("DType is {:?}", dtype.yellow());
     let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device).unwrap() };
     let model = Model::new(&config, vb).unwrap();
 
-    println!("模型加载完毕 {:?}", start.elapsed().as_secs());
+    println!("模型加载完毕 {:?}", start.elapsed().as_secs().green());
 
     let mut pipeline = TextGeneration::new(
         model,
         tokenizer,
-        args.seed,
+        seed,
         args.temperature,
         args.top_p,
         args.repeat_penalty,
@@ -255,6 +279,6 @@ fn main() -> Result<(), ()> {
         &device,
         dtype,
     );
-    pipeline.run(&args.prompt, args.sample_len)?;
+    pipeline.run(args.sample_len)?;
     Ok(())
 }