|
|
@@ -51,11 +51,11 @@ impl TextGeneration {
|
|
|
pub fn run(&mut self, sample_len: usize) -> Result<(), ()> {
|
|
|
use std::io::Write;
|
|
|
|
|
|
+ println!("[欢迎使用Codegeex4,请输入prompt]");
|
|
|
let stdin = std::io::stdin();
|
|
|
let reader = BufReader::new(stdin);
|
|
|
// 从标准输入读取prompt
|
|
|
for line in reader.lines() {
|
|
|
- println!("[欢迎使用Codegeex4,请输入prompt]");
|
|
|
let line = line.expect("Failed to read line");
|
|
|
let tokens = self.tokenizer.encode(line, true).expect("tokens error");
|
|
|
if tokens.is_empty() {
|
|
|
@@ -132,8 +132,9 @@ impl TextGeneration {
|
|
|
for tokens in result {
|
|
|
print!("{tokens}");
|
|
|
}
|
|
|
+ self.model.reset_kv_cache(); // 清理模型kv
|
|
|
}
|
|
|
- self.model.reset_kv_cache(); // 清理模型kv
|
|
|
+
|
|
|
Ok(())
|
|
|
}
|
|
|
}
|