|
|
@@ -52,6 +52,10 @@ fn d_dtype() -> Type {
|
|
|
Type::Dimension(DType::from_type_variable(TypeVariable::new("D")))
|
|
|
}
|
|
|
|
|
|
+fn d_squared() -> Type {
|
|
|
+ Type::Dimension(DType::from_type_variable(TypeVariable::new("D")).power(2.into()))
|
|
|
+}
|
|
|
+
|
|
|
fn e_type() -> Type {
|
|
|
Type::TVar(TypeVariable::new("E"))
|
|
|
}
|
|
|
@@ -211,6 +215,27 @@ fn dimension_types_multiplication() {
|
|
|
);
|
|
|
}
|
|
|
|
|
|
+#[test]
|
|
|
+fn dimension_types_exponentiation() {
|
|
|
+ assert_eq!(
|
|
|
+ get_inferred_fn_type("fn f(x: Scalar, y) = x^y"),
|
|
|
+ fn_type!(scalar(), scalar() => scalar())
|
|
|
+ );
|
|
|
+ assert_eq!(
|
|
|
+ get_inferred_fn_type("fn f(x) = x^2"),
|
|
|
+ fn_type!(forall d_type(); dim d_type(); d_type() => d_squared())
|
|
|
+ );
|
|
|
+ assert_eq!(
|
|
|
+ get_inferred_fn_type("fn f(x) = 2^x"),
|
|
|
+ fn_type!(scalar() => scalar())
|
|
|
+ );
|
|
|
+
|
|
|
+ assert!(matches!(
|
|
|
+ get_typecheck_error("fn f(x, y) = x^y"),
|
|
|
+ TypeCheckError::ExponentiationNeedsTypeAnnotation(..)
|
|
|
+ ));
|
|
|
+}
|
|
|
+
|
|
|
#[test]
|
|
|
fn dimension_types_combinations() {
|
|
|
assert_eq!(
|