@@ -6,12 +6,16 @@ extern crate accelerate_src;
 
 use clap::Parser;
 use codegeex4_candle::codegeex4::*;
+use owo_colors::{self, OwoColorize};
+use std::io::BufRead;
+use std::io::BufReader;
 
 use candle_core as candle;
 use candle_core::{DType, Device, Tensor};
 use candle_nn::VarBuilder;
 use candle_transformers::generation::LogitsProcessor;
 use hf_hub::{Repo, RepoType};
+use rand::Rng;
 use tokenizers::Tokenizer;
 
 struct TextGeneration {
@@ -52,81 +56,92 @@ impl TextGeneration {
         }
     }
 
-    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<(), ()> {
+    fn run(&mut self, sample_len: usize) -> Result<(), ()> {
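+        // Interactive mode: every line read from stdin below is treated as a prompt,
+        // and at most `sample_len` tokens are generated for each one.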
         use std::io::Write;
-        println!("starting the inference loop");
-        let tokens = self.tokenizer.encode(prompt, true).expect("tokens error");
-        if tokens.is_empty() {
-            panic!("Empty prompts are not supported in the chatglm model.")
-        }
-        if self.verbose_prompt {
-            for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
-                let token = token.replace('▁', " ").replace("<0x0A>", "\n");
-                println!("{id:7} -> '{token}'");
+
+        let stdin = std::io::stdin();
+        let reader = BufReader::new(stdin);
+        // Read prompts from standard input.
+        for line in reader.lines() {
+            println!("[Welcome to Codegeex4, please enter a prompt]");
+            let line = line.expect("Failed to read line");
+            let tokens = self.tokenizer.encode(line, true).expect("tokens error");
+            if tokens.is_empty() {
+                panic!("Empty prompts are not supported in the chatglm model.")
             }
-        }
-        let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
-            Some(token) => *token,
-            None => panic!("cannot find the endoftext token"),
-        };
-        let mut tokens = tokens.get_ids().to_vec();
-        let mut generated_tokens = 0usize;
+            if self.verbose_prompt {
+                for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
+                    let token = token.replace('▁', " ").replace("<0x0A>", "\n");
+                    println!("{id:7} -> '{token}'");
+                }
+            }
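+            // Generation for this prompt stops once the model emits <|endoftext|>.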
+            let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
+                Some(token) => *token,
+                None => panic!("cannot find the endoftext token"),
+            };
+            let mut tokens = tokens.get_ids().to_vec();
+            let mut generated_tokens = 0usize;
 
-        print!("{prompt}");
-        std::io::stdout().flush().expect("output flush error");
-        let start_gen = std::time::Instant::now();
+            std::io::stdout().flush().expect("output flush error");
+            let start_gen = std::time::Instant::now();
 
-        println!("\n 开始生成");
-        println!("samplelen {}", sample_len);
-        let mut count = 0;
-        let mut result = vec![];
-        for index in 0..sample_len {
-            count += 1;
-            println!("sample count {}", count);
-            let context_size = if index > 0 { 1 } else { tokens.len() };
-            let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
-            let input = Tensor::new(ctxt, &self.device)
-                .unwrap()
-                .unsqueeze(0)
-                .expect("create tensor input error");
-            let logits = self.model.forward(&input).unwrap();
-            let logits = logits.squeeze(0).unwrap().to_dtype(self.dtype).unwrap();
-            let logits = if self.repeat_penalty == 1. {
-                logits
-            } else {
-                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
-                candle_transformers::utils::apply_repeat_penalty(
-                    &logits,
-                    self.repeat_penalty,
-                    &tokens[start_at..],
-                )
-                .unwrap()
-            };
+            // println!("\n Start generating");
+            println!("samplelen {}", sample_len.blue());
+            let mut result = vec![];
 
-            let next_token = self.logits_processor.sample(&logits).unwrap();
-            tokens.push(next_token);
-            generated_tokens += 1;
-            if next_token == eos_token {
-                break;
+            for index in 0..sample_len {
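+                // After the first step only the newly sampled token is fed to the model;
+                // earlier context is carried by the model's KV cache.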
+                let context_size = if index > 0 { 1 } else { tokens.len() };
+                let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
+                let input = Tensor::new(ctxt, &self.device)
+                    .unwrap()
+                    .unsqueeze(0)
+                    .expect("create tensor input error");
+                let logits = self.model.forward(&input).unwrap();
+                let logits = logits.squeeze(0).unwrap().to_dtype(self.dtype).unwrap();
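+                // Penalize tokens that appeared in the last `repeat_last_n` positions
+                // to discourage repetition (skipped when the penalty is 1.0).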
+                let logits = if self.repeat_penalty == 1. {
+                    logits
+                } else {
+                    let start_at = tokens.len().saturating_sub(self.repeat_last_n);
+                    candle_transformers::utils::apply_repeat_penalty(
+                        &logits,
+                        self.repeat_penalty,
+                        &tokens[start_at..],
+                    )
+                    .unwrap()
+                };
+
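+                // Sample the next token from the (possibly penalized) logits.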
+                let next_token = self.logits_processor.sample(&logits).unwrap();
+                tokens.push(next_token);
+                generated_tokens += 1;
+                if next_token == eos_token {
+                    break;
+                }
+                let token = self
+                    .tokenizer
+                    .decode(&[next_token], true)
+                    .expect("Token error");
+                if self.verbose_prompt {
+                    println!(
+                        "[Index: {}] [Raw Token: {}] [Decode Token: {}]",
+                        index.blue(),
+                        next_token.green(),
+                        token.yellow()
+                    );
+                }
+                result.push(token);
+                std::io::stdout().flush().unwrap();
+            }
+            let dt = start_gen.elapsed();
+            println!(
+                "\n{generated_tokens} tokens generated ({:.2} token/s)",
+                generated_tokens as f64 / dt.as_secs_f64(),
+            );
+            println!("Result:");
+            for tokens in result {
+                print!("{tokens}");
             }
-            println!("raw generate token {}", next_token);
-            let token = self
-                .tokenizer
-                .decode(&[next_token], true)
-                .expect("Token error");
-            println!("[token:{token}]");
-            result.push(token);
-            std::io::stdout().flush().unwrap();
-        }
-        let dt = start_gen.elapsed();
-        println!(
-            "\n{generated_tokens} tokens generated ({:.2} token/s)",
-            generated_tokens as f64 / dt.as_secs_f64(),
-        );
-        println!("Result:");
-        for tokens in result {
-            print!("{tokens}");
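+            // Without this reset, the next prompt would be conditioned on the
+            // previous conversation's KV cache.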
+            self.model.reset_kv_cache(); // clear the model's KV cache
         }
         Ok(())
     }
 }
@@ -142,7 +157,7 @@ struct Args {
     cpu: bool,
 
     /// Display the token for the specified prompt.
-    #[arg(long,default_value_t=true)]
+    #[arg(long)]
     verbose_prompt: bool,
 
     #[arg(long)]
@@ -157,8 +172,8 @@ struct Args {
     top_p: Option<f64>,
 
     /// The seed to use when generating random samples.
-    #[arg(long, default_value_t = 299792458)]
-    seed: u64,
+    #[arg(long)]
+    seed: Option<u64>,
 
     /// The length of the sample to generate (in tokens).
     #[arg(long, short = 'n', default_value_t = 5000)]
@@ -189,19 +204,28 @@ fn main() -> Result<(), ()> {
     let args = Args::parse();
     println!(
         "avx: {}, neon: {}, simd128: {}, f16c: {}",
-        candle::utils::with_avx(),
-        candle::utils::with_neon(),
-        candle::utils::with_simd128(),
-        candle::utils::with_f16c()
+        candle::utils::with_avx().red(),
+        candle::utils::with_neon().red(),
+        candle::utils::with_simd128().red(),
+        candle::utils::with_f16c().red(),
     );
     println!(
         "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
-        args.temperature.unwrap_or(0.95),
-        args.repeat_penalty,
-        args.repeat_last_n
+        args.temperature.unwrap_or(0.95).red(),
+        args.repeat_penalty.red(),
+        args.repeat_last_n.red(),
     );
 
-    println!("cache path {}", args.cache_path);
+    println!("cache path {}", args.cache_path.blue());
+    println!("Prompt: [{}]", args.prompt.green());
+    let seed: u64 = match args.seed {
+        Some(seed) => seed,
+        None => {
+            let mut rng = rand::thread_rng();
+            rng.gen()
+        }
+    };
+    println!("Using Seed {}", seed.red());
     let api = hf_hub::api::sync::ApiBuilder::from_cache(hf_hub::Cache::new(args.cache_path.into()))
         .build()
         .unwrap();
@@ -209,7 +233,7 @@ fn main() -> Result<(), ()> {
     let model_id = match args.model_id {
         Some(model_id) => model_id.to_string(),
         None => "THUDM/codegeex4-all-9b".to_string(),
     };
     let revision = match args.revision {
         Some(rev) => rev.to_string(),
         None => "main".to_string(),
@@ -237,16 +261,16 @@ fn main() -> Result<(), ()> {
     } else {
         DType::F32
     };
-    println!("dtype is {:?}", dtype);
+    println!("DType is {:?}", dtype.yellow());
     let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device).unwrap() };
     let model = Model::new(&config, vb).unwrap();
 
-    println!("模型加载完毕 {:?}", start.elapsed().as_secs());
+    println!("Model loaded in {:?} s", start.elapsed().as_secs().green());
 
     let mut pipeline = TextGeneration::new(
         model,
         tokenizer,
-        args.seed,
+        seed,
         args.temperature,
         args.top_p,
         args.repeat_penalty,
@@ -255,6 +279,6 @@ fn main() -> Result<(), ()> {
         &device,
         dtype,
     );
-    pipeline.run(&args.prompt, args.sample_len)?;
+    pipeline.run(args.sample_len)?;
     Ok(())
 }