@@ -107,17 +107,18 @@ impl RotaryEmbedding {
 struct CoreAttention {
     coeff: Option<f64>,
     norm_factor: f64,
+    dtype: DType,
 }
 
-fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
+fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32, dtype: DType) -> Result<Tensor> {
     let shape = mask.shape();
     let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
-    let m = mask.where_cond(&on_true.to_dtype(DType::BF16)?, on_false)?;
+    let m = mask.where_cond(&on_true.to_dtype(dtype)?, on_false)?;
     Ok(m)
 }
 
 impl CoreAttention {
-    fn new(layer_number: usize, cfg: &Config) -> Result<Self> {
+    fn new(layer_number: usize, cfg: &Config, dtype: DType) -> Result<Self> {
         let norm_factor = (cfg.kv_channels as f64).sqrt();
         let (norm_factor, coeff) = if cfg.apply_query_key_layer_scaling {
             let coeff = f64::max(1.0, layer_number as f64);
@@ -125,7 +126,7 @@ impl CoreAttention {
         } else {
             (norm_factor, None)
         };
-        Ok(Self { coeff, norm_factor })
+        Ok(Self { coeff, norm_factor, dtype })
     }
 
     fn forward(
@@ -158,6 +159,7 @@ impl CoreAttention {
                 &matmul_result,
                 &mask.broadcast_left((matmul_result.dim(0)?, matmul_result.dim(1)?))?,
                 f32::NEG_INFINITY,
+                self.dtype,
             )?,
             None => matmul_result,
         };
@@ -207,7 +209,7 @@ impl SelfAttention {
             cfg.add_bias_linear || cfg.add_qkv_bias,
             vb.pp("query_key_value"),
         )?;
-        let core_attention = CoreAttention::new(layer_number, cfg)?;
+        let core_attention = CoreAttention::new(layer_number, cfg, vb.dtype())?;
         let dense = linear(
             cfg.hidden_size,
             cfg.hidden_size,
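
A minimal standalone sketch, not part of the patch, illustrating why masked_fill now takes a dtype: the fill value is cast to the caller-supplied dtype instead of a hardcoded DType::BF16, so F16 or F32 attention scores can be masked without a dtype mismatch. It assumes the tensor API is available as the candle_core crate; the test values below are hypothetical.

// Sketch of the updated masked_fill with the dtype threaded through.
use candle_core::{DType, Device, Result, Tensor};

fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32, dtype: DType) -> Result<Tensor> {
    let shape = mask.shape();
    // Broadcast the scalar fill value, then cast it to the caller's dtype so it
    // matches `on_false` (previously this cast was hardcoded to BF16).
    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
    let m = mask.where_cond(&on_true.to_dtype(dtype)?, on_false)?;
    Ok(m)
}

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Hypothetical F16 attention scores and a U8 mask marking positions to blank out.
    let scores = Tensor::new(&[0.1f32, 0.2, 0.3, 0.4], &device)?.to_dtype(DType::F16)?;
    let mask = Tensor::new(&[0u8, 0, 1, 1], &device)?;
    // Passing scores.dtype() mirrors how CoreAttention now forwards self.dtype.
    let masked = masked_fill(&scores, &mask, f32::NEG_INFINITY, scores.dtype())?;
    println!("{masked}");
    Ok(())
}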