|
|
@@ -59,16 +59,22 @@ impl TextGeneration {
|
|
|
println!("{id:7} -> '{token}'");
|
|
|
}
|
|
|
}
|
|
|
+ let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
|
|
|
+ Some(token) => *token,
|
|
|
+ None => panic!("cannot find the endoftext token"),
|
|
|
+ };
|
|
|
let mut tokens = tokens.get_ids().to_vec();
|
|
|
let mut generated_tokens = 0usize;
|
|
|
- let eos_token = 151329;
|
|
|
+
|
|
|
|
|
|
print!("{prompt}");
|
|
|
std::io::stdout().flush().expect("output flush error");
|
|
|
let start_gen = std::time::Instant::now();
|
|
|
- println!("start_gen");
|
|
|
+
|
|
|
+ println!("\n start_gen");
|
|
|
println!("samplelen {}",sample_len);
|
|
|
let mut count = 0;
|
|
|
+ let mut result = vec!();
|
|
|
for index in 0..sample_len {
|
|
|
count += 1;
|
|
|
println!("sample count {}",count);
|
|
|
@@ -96,7 +102,8 @@ impl TextGeneration {
|
|
|
}
|
|
|
println!("raw generate token {}",next_token);
|
|
|
let token = self.tokenizer.decode(&[next_token], true).expect("Token error");
|
|
|
- print!("{token}");
|
|
|
+ println!("[token:{token}]");
|
|
|
+ result.push(token);
|
|
|
std::io::stdout().flush().unwrap();
|
|
|
}
|
|
|
let dt = start_gen.elapsed();
|
|
|
@@ -104,6 +111,10 @@ impl TextGeneration {
|
|
|
"\n{generated_tokens} tokens generated ({:.2} token/s)",
|
|
|
generated_tokens as f64 / dt.as_secs_f64(),
|
|
|
);
|
|
|
+ println!("Result:");
|
|
|
+ for tokens in result {
|
|
|
+ print!("{tokens}");
|
|
|
+ }
|
|
|
Ok(())
|
|
|
}
|
|
|
}
|
|
|
@@ -208,7 +219,12 @@ fn main() -> Result<(),()> {
|
|
|
let start = std::time::Instant::now();
|
|
|
let config = Config::codegeex4();
|
|
|
let device = candle_examples::device(args.cpu).unwrap();
|
|
|
- let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device).unwrap() };
|
|
|
+ let dtype = if device.is_cuda() {
|
|
|
+ DType::BF16
|
|
|
+ } else {
|
|
|
+ DType::F32
|
|
|
+ };
|
|
|
+ let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device).unwrap() };
|
|
|
let model = Model::new(&config, vb).unwrap();
|
|
|
|
|
|
println!("loaded the model in {:?}", start.elapsed());
|