donjuanplatinum 1 year ago
Parent commit 657f200f41
1 changed file with 7 additions and 5 deletions

+ 7 - 5
candle_demo/src/codegeex4.rs

@@ -107,17 +107,18 @@ impl RotaryEmbedding {
 struct CoreAttention {
     coeff: Option<f64>,
     norm_factor: f64,
+    dtype: DType,
 }
 
-fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
+fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32, dtype: DType) -> Result<Tensor> {
     let shape = mask.shape();
     let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
-    let m = mask.where_cond(&on_true.to_dtype(DType::BF16)?, on_false)?;
+    let m = mask.where_cond(&on_true.to_dtype(dtype)?, on_false)?;
     Ok(m)
 }
 
 impl CoreAttention {
-    fn new(layer_number: usize, cfg: &Config) -> Result<Self> {
+    fn new(layer_number: usize, cfg: &Config, dtype: DType) -> Result<Self> {
         let norm_factor = (cfg.kv_channels as f64).sqrt();
         let (norm_factor, coeff) = if cfg.apply_query_key_layer_scaling {
             let coeff = f64::max(1.0, layer_number as f64);
@@ -125,7 +126,7 @@ impl CoreAttention {
         } else {
             (norm_factor, None)
         };
-        Ok(Self { coeff, norm_factor })
+        Ok(Self { coeff, norm_factor, dtype })
     }
 
     fn forward(
@@ -158,6 +159,7 @@ impl CoreAttention {
                 &matmul_result,
                 &mask.broadcast_left((matmul_result.dim(0)?, matmul_result.dim(1)?))?,
                 f32::NEG_INFINITY,
+                self.dtype,
             )?,
             None => matmul_result,
         };
@@ -207,7 +209,7 @@ impl SelfAttention {
             cfg.add_bias_linear || cfg.add_qkv_bias,
             vb.pp("query_key_value"),
         )?;
-        let core_attention = CoreAttention::new(layer_number, cfg)?;
+        let core_attention = CoreAttention::new(layer_number, cfg, vb.dtype())?;
         let dense = linear(
             cfg.hidden_size,
             cfg.hidden_size,
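The commit threads the model dtype from the VarBuilder into CoreAttention so that masked_fill casts its fill value to the active dtype instead of the hardcoded DType::BF16. Below is a minimal, self-contained sketch of the updated helper, assuming candle_core as the tensor backend (the toy scores and mask in main are illustrative only, not part of the commit):

use candle_core::{DType, Device, Result, Tensor};

fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32, dtype: DType) -> Result<Tensor> {
    let shape = mask.shape();
    // Broadcast the scalar fill value over the mask shape, then cast it to the
    // caller-supplied dtype (F16/BF16/F32) rather than a hardcoded BF16.
    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
    let m = mask.where_cond(&on_true.to_dtype(dtype)?, on_false)?;
    Ok(m)
}

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Toy attention scores (1 x 3, F32) and a mask marking the last position.
    let scores = Tensor::new(&[[1f32, 2.0, 3.0]], &device)?;
    let mask = Tensor::new(&[[0u8, 0, 1]], &device)?;
    // The masked position becomes -inf, cast here to the model's F32 dtype.
    let masked = masked_fill(&scores, &mask, f32::NEG_INFINITY, DType::F32)?;
    println!("{masked}");
    Ok(())
}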