@@ -1,24 +1,16 @@
-#[cfg(feature = "mkl")]
-extern crate intel_mkl_src;
+pub mod codegeex4;
 
-#[cfg(feature = "accelerate")]
-extern crate accelerate_src;
+pub mod args;
 
-use clap::Parser;
-use codegeex4_candle::codegeex4::*;
+use candle_core::{DType, Device, Tensor};
+use candle_transformers::generation::LogitsProcessor;
+use codegeex4::*;
 use owo_colors::{self, OwoColorize};
 use std::io::BufRead;
 use std::io::BufReader;
-
-use candle_core as candle;
-use candle_core::{DType, Device, Tensor};
-use candle_nn::VarBuilder;
-use candle_transformers::generation::LogitsProcessor;
-use hf_hub::{Repo, RepoType};
-use rand::Rng;
 use tokenizers::Tokenizer;
 
-struct TextGeneration {
+pub struct TextGeneration {
     model: Model,
     device: Device,
     tokenizer: Tokenizer,
@@ -31,7 +23,7 @@ struct TextGeneration {
 
 impl TextGeneration {
     #[allow(clippy::too_many_arguments)]
-    fn new(
+    pub fn new(
         model: Model,
         tokenizer: Tokenizer,
         seed: u64,
@@ -56,7 +48,7 @@ impl TextGeneration {
         }
     }
 
-    fn run(&mut self, sample_len: usize) -> Result<(), ()> {
+    pub fn run(&mut self, sample_len: usize) -> Result<(), ()> {
         use std::io::Write;
 
         let stdin = std::io::stdin();
@@ -145,140 +137,3 @@ impl TextGeneration {
         Ok(())
     }
 }
-
-#[derive(Parser, Debug)]
-#[command(author, version, about, long_about = None)]
-struct Args {
-    /// Run on CPU rather than on GPU.
-    #[arg(name = "cache", short, long, default_value = ".")]
-    cache_path: String,
-
-    #[arg(long)]
-    cpu: bool,
-
-    /// Display the token for the specified prompt.
-    #[arg(long)]
-    verbose_prompt: bool,
-
-    #[arg(long)]
-    prompt: String,
-
-    /// The temperature used to generate samples.
-    #[arg(long)]
-    temperature: Option<f64>,
-
-    /// Nucleus sampling probability cutoff.
-    #[arg(long)]
-    top_p: Option<f64>,
-
-    /// The seed to use when generating random samples.
-    #[arg(long)]
-    seed: Option<u64>,
-
-    /// The length of the sample to generate (in tokens).
-    #[arg(long, short = 'n', default_value_t = 5000)]
-    sample_len: usize,
-
-    #[arg(long)]
-    model_id: Option<String>,
-
-    #[arg(long)]
-    revision: Option<String>,
-
-    #[arg(long)]
-    weight_file: Option<String>,
-
-    #[arg(long)]
-    tokenizer: Option<String>,
-
-    /// Penalty to be applied for repeating tokens, 1. means no penalty.
-    #[arg(long, default_value_t = 1.1)]
-    repeat_penalty: f32,
-
-    /// The context size to consider for the repeat penalty.
-    #[arg(long, default_value_t = 64)]
-    repeat_last_n: usize,
-}
-
-fn main() -> Result<(), ()> {
-    let args = Args::parse();
-    println!(
-        "avx: {}, neon: {}, simd128: {}, f16c: {}",
-        candle::utils::with_avx().red(),
-        candle::utils::with_neon().red(),
-        candle::utils::with_simd128().red(),
-        candle::utils::with_f16c().red(),
-    );
-    println!(
-        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
-        args.temperature.unwrap_or(0.95).red(),
-        args.repeat_penalty.red(),
-        args.repeat_last_n.red(),
-    );
-
-    println!("cache path {}", args.cache_path.blue());
-    println!("Prompt: [{}]", args.prompt.green());
-    let mut seed: u64 = 0;
-    if let Some(_seed) = args.seed {
-        seed = _seed;
-    } else {
-        let mut rng = rand::thread_rng();
-        seed = rng.gen();
-    }
-    println!("Using Seed {}", seed.red());
-    let api = hf_hub::api::sync::ApiBuilder::from_cache(hf_hub::Cache::new(args.cache_path.into()))
-        .build()
-        .unwrap();
-
-    let model_id = match args.model_id {
-        Some(model_id) => model_id.to_string(),
-        None => "THUDM/codegeex4-all-9b".to_string(),
-    };
-    let revision = match args.revision {
-        Some(rev) => rev.to_string(),
-        None => "main".to_string(),
-    };
-    let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
-    let tokenizer_filename = match args.tokenizer {
-        Some(file) => std::path::PathBuf::from(file),
-        None => api
-            .model("THUDM/codegeex4-all-9b".to_string())
-            .get("tokenizer.json")
-            .unwrap(),
-    };
-    let filenames = match args.weight_file {
-        Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
-        None => {
-            candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json").unwrap()
-        }
-    };
-    let tokenizer = Tokenizer::from_file(tokenizer_filename).expect("Tokenizer Error");
-    let start = std::time::Instant::now();
-    let config = Config::codegeex4();
-    let device = candle_examples::device(args.cpu).unwrap();
-    let dtype = if device.is_cuda() {
-        DType::BF16
-    } else {
-        DType::F32
-    };
-    println!("DType is {:?}", dtype.yellow());
-    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device).unwrap() };
-    let model = Model::new(&config, vb).unwrap();
-
- println!("模型加载完毕 {:?}", start.elapsed().as_secs().green());
|
|
|
|
|
-
-    let mut pipeline = TextGeneration::new(
-        model,
-        tokenizer,
-        seed,
-        args.temperature,
-        args.top_p,
-        args.repeat_penalty,
-        args.repeat_last_n,
-        args.verbose_prompt,
-        &device,
-        dtype,
-    );
-    pipeline.run(args.sample_len)?;
-    Ok(())
-}
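
The deleted `Args` struct and `main` have no `+` counterpart in this diff; given the added `pub mod args;` declaration and the newly `pub` `TextGeneration` API, they presumably move into the new `args` module and a separate entry point. A minimal sketch of what the relocated clap definitions could look like, assuming the field set carries over unchanged (abridged to three of the fourteen options; the `args.rs` file name and `pub` field visibility are assumptions, not shown in the diff):

```rust
// args.rs (hypothetical): the clap parser removed from this file,
// now exposed through the new `pub mod args;` declaration. Fields are
// copied from the deleted struct and made `pub` so an external entry
// point can read them after parsing.
use clap::Parser;

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
pub struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    pub cpu: bool,

    #[arg(long)]
    pub prompt: String,

    /// The length of the sample to generate (in tokens).
    #[arg(long, short = 'n', default_value_t = 5000)]
    pub sample_len: usize,
}
```

With `TextGeneration::new` and `TextGeneration::run` now public, whatever binary replaces the deleted `main` can parse `Args`, build the model and tokenizer as before, and call `pipeline.run(args.sample_len)` without reaching into private items.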