@@ -1,12 +1,17 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
 use clap::Parser;
 use codegeex4_candle::codegeex4::*;
 
-
-use candle_core::{DType, Device, Tensor};
 use candle_core as candle;
+use candle_core::{DType, Device, Tensor};
 use candle_nn::VarBuilder;
 use candle_transformers::generation::LogitsProcessor;
-use hf_hub::{api::sync::Api, Repo, RepoType};
+use hf_hub::{Repo, RepoType};
 use tokenizers::Tokenizer;
 
 struct TextGeneration {
@@ -32,7 +37,7 @@ impl TextGeneration {
         repeat_last_n: usize,
         verbose_prompt: bool,
         device: &Device,
-	dtype: DType,
+        dtype: DType,
     ) -> Self {
         let logits_processor = LogitsProcessor::new(seed, temp, top_p);
         Self {
@@ -43,11 +48,11 @@ impl TextGeneration {
             repeat_last_n,
             verbose_prompt,
             device: device.clone(),
-	dtype,
+            dtype,
         }
     }
 
-    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<(),()> {
+    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<(), ()> {
         use std::io::Write;
         println!("starting the inference loop");
         let tokens = self.tokenizer.encode(prompt, true).expect("tokens error");
@@ -66,22 +71,24 @@ impl TextGeneration {
         };
         let mut tokens = tokens.get_ids().to_vec();
         let mut generated_tokens = 0usize;
-
-
+
         print!("{prompt}");
         std::io::stdout().flush().expect("output flush error");
         let start_gen = std::time::Instant::now();
-
-	println!("\n start_gen");
-	println!("samplelen {}",sample_len);
-	let mut count = 0;
-	let mut result = vec!();
+
+        println!("\n start generation");
+        println!("samplelen {}", sample_len);
+        let mut count = 0;
+        let mut result = vec![];
         for index in 0..sample_len {
-	    count += 1;
-	    println!("sample count {}",count);
+            count += 1;
+            println!("sample count {}", count);
             let context_size = if index > 0 { 1 } else { tokens.len() };
             let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
-            let input = Tensor::new(ctxt, &self.device).unwrap().unsqueeze(0).expect("create tensor input error");
+            let input = Tensor::new(ctxt, &self.device)
+                .unwrap()
+                .unsqueeze(0)
+                .expect("create tensor input error");
             let logits = self.model.forward(&input).unwrap();
             let logits = logits.squeeze(0).unwrap().to_dtype(self.dtype).unwrap();
             let logits = if self.repeat_penalty == 1. {
@@ -92,7 +99,8 @@ impl TextGeneration {
                     &logits,
                     self.repeat_penalty,
                     &tokens[start_at..],
-                ).unwrap()
+                )
+                .unwrap()
             };
 
             let next_token = self.logits_processor.sample(&logits).unwrap();
@@ -101,10 +109,13 @@ impl TextGeneration {
             if next_token == eos_token {
                 break;
             }
-	    println!("raw generate token {}",next_token);
-	    let token = self.tokenizer.decode(&[next_token], true).expect("Token error");
+            println!("raw generate token {}", next_token);
+            let token = self
+                .tokenizer
+                .decode(&[next_token], true)
+                .expect("Token error");
             println!("[token:{token}]");
-	    result.push(token);
+            result.push(token);
             std::io::stdout().flush().unwrap();
         }
         let dt = start_gen.elapsed();
@@ -112,10 +123,10 @@ impl TextGeneration {
             "\n{generated_tokens} tokens generated ({:.2} token/s)",
             generated_tokens as f64 / dt.as_secs_f64(),
         );
-	println!("Result:");
-	for tokens in result {
-	    print!("{tokens}");
-	}
+        println!("Result:");
+        for tokens in result {
+            print!("{tokens}");
+        }
         Ok(())
     }
 }
@@ -124,14 +135,14 @@ impl TextGeneration {
 #[command(author, version, about, long_about = None)]
 struct Args {
     /// Run on CPU rather than on GPU.
-    #[arg(name="cache",short,long, default_value=".")]
+    #[arg(name = "cache", short, long, default_value = ".")]
     cache_path: String,
-
+
     #[arg(long)]
     cpu: bool,
 
     /// Display the token for the specified prompt.
-    #[arg(long)]
+    #[arg(long, default_value_t = true)]
    verbose_prompt: bool,
 
     #[arg(long)]
@@ -174,8 +185,7 @@ struct Args {
     repeat_last_n: usize,
 }
 
-fn main() -> Result<(),()> {
-
+fn main() -> Result<(), ()> {
     let args = Args::parse();
     println!(
         "avx: {}, neon: {}, simd128: {}, f16c: {}",
@@ -191,10 +201,11 @@ fn main() -> Result<(),()> {
         args.repeat_last_n
     );
 
-    let start = std::time::Instant::now();
-    println!("cache path {}",args.cache_path);
-    let api = hf_hub::api::sync::ApiBuilder::from_cache(hf_hub::Cache::new(args.cache_path.into())).build().unwrap();
-
+    println!("cache path {}", args.cache_path);
+    let api = hf_hub::api::sync::ApiBuilder::from_cache(hf_hub::Cache::new(args.cache_path.into()))
+        .build()
+        .unwrap();
+
     let model_id = match args.model_id {
         Some(model_id) => model_id.to_string(),
         None => "THUDM/codegeex4-all-9b".to_string(),
@@ -208,27 +219,29 @@ fn main() -> Result<(),()> {
         Some(file) => std::path::PathBuf::from(file),
         None => api
             .model("THUDM/codegeex4-all-9b".to_string())
-            .get("tokenizer.json").unwrap(),
+            .get("tokenizer.json")
+            .unwrap(),
     };
     let filenames = match args.weight_file {
         Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
-        None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json").unwrap(),
+        None => {
+            candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json").unwrap()
+        }
     };
-    println!("retrieved the files in {:?}", start.elapsed());
     let tokenizer = Tokenizer::from_file(tokenizer_filename).expect("Tokenizer Error");
-
     let start = std::time::Instant::now();
     let config = Config::codegeex4();
     let device = candle_examples::device(args.cpu).unwrap();
     let dtype = if device.is_cuda() {
-	DType::BF16
+        DType::BF16
     } else {
-	DType::F32
+        DType::F32
    };
+    println!("dtype is {:?}", dtype);
     let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device).unwrap() };
     let model = Model::new(&config, vb).unwrap();
 
-    println!("loaded the model in {:?}", start.elapsed());
+    println!("model loaded in {:?} seconds", start.elapsed().as_secs());
 
     let mut pipeline = TextGeneration::new(
         model,
@@ -240,7 +253,7 @@ fn main() -> Result<(),()> {
         args.repeat_last_n,
         args.verbose_prompt,
         &device,
-	dtype,
+        dtype,
     );
     pipeline.run(&args.prompt, args.sample_len)?;
     Ok(())