|
|
@@ -1,4 +1,3 @@
|
|
|
-//use anyhow::{Error as E, Result};
|
|
|
use clap::Parser;
|
|
|
use codegeex4_candle::codegeex4::*;
|
|
|
|
|
|
@@ -18,6 +17,7 @@ struct TextGeneration {
|
|
|
repeat_penalty: f32,
|
|
|
repeat_last_n: usize,
|
|
|
verbose_prompt: bool,
|
|
|
+ dtype: DType,
|
|
|
}
|
|
|
|
|
|
impl TextGeneration {
|
|
|
@@ -32,6 +32,7 @@ impl TextGeneration {
|
|
|
repeat_last_n: usize,
|
|
|
verbose_prompt: bool,
|
|
|
device: &Device,
|
|
|
+ dtype: DType,
|
|
|
) -> Self {
|
|
|
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
|
|
Self {
|
|
|
@@ -42,6 +43,7 @@ impl TextGeneration {
|
|
|
repeat_last_n,
|
|
|
verbose_prompt,
|
|
|
device: device.clone(),
|
|
|
+ dtype,
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -49,7 +51,6 @@ impl TextGeneration {
|
|
|
use std::io::Write;
|
|
|
println!("starting the inference loop");
|
|
|
let tokens = self.tokenizer.encode(prompt, true).expect("tokens error");
|
|
|
- println!("run starting the token 57");
|
|
|
if tokens.is_empty() {
|
|
|
panic!("Empty prompts are not supported in the chatglm model.")
|
|
|
}
|
|
|
@@ -82,7 +83,7 @@ impl TextGeneration {
|
|
|
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
|
|
let input = Tensor::new(ctxt, &self.device).unwrap().unsqueeze(0).expect("create tensor input error");
|
|
|
let logits = self.model.forward(&input).unwrap();
|
|
|
- let logits = logits.squeeze(0).unwrap().to_dtype(DType::F32).unwrap();
|
|
|
+ let logits = logits.squeeze(0).unwrap().to_dtype(self.dtype).unwrap();
|
|
|
let logits = if self.repeat_penalty == 1. {
|
|
|
logits
|
|
|
} else {
|
|
|
@@ -239,6 +240,7 @@ fn main() -> Result<(),()> {
|
|
|
args.repeat_last_n,
|
|
|
args.verbose_prompt,
|
|
|
&device,
|
|
|
+ dtype,
|
|
|
);
|
|
|
pipeline.run(&args.prompt, args.sample_len)?;
|
|
|
Ok(())
|