|
|
@@ -68,7 +68,7 @@ impl RotaryEmbedding {
|
|
|
let inv_freq_len = inv_freq.len();
|
|
|
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
|
|
|
let t = Tensor::arange(0u32, cfg.seq_length as u32, dev)?
|
|
|
- .to_dtype(dtype)?
|
|
|
+ .to_dtype(dtype).expect("unalbe to dytpe in Rotray Embedding new")
|
|
|
.reshape((cfg.seq_length, 1))?;
|
|
|
let freqs = t.matmul(&inv_freq)?;
|
|
|
let cache = Tensor::stack(&[&freqs.cos()?, &freqs.sin()?], D::Minus1)?;
|