Browse Source

Add tests for exponentiation

David Peter 1 year ago
parent
commit
198809f578

+ 0 - 12
numbat/src/typechecker/tests/type_checking.rs

@@ -72,18 +72,6 @@ fn exponentiation_with_dimensionful_base() {
     ));
 }
 
-#[test]
-fn exponentiation_type_inference() {
-    assert_successful_typecheck("fn f(x: Scalar, y) = x^y");
-    assert_successful_typecheck("fn f(x) = x^2");
-    assert_successful_typecheck("fn f(x) = 2^x");
-
-    assert!(matches!(
-        get_typecheck_error("fn f(x, y) = x^y"),
-        TypeCheckError::ExponentiationNeedsTypeAnnotation(..)
-    ));
-}
-
 #[test]
 fn equality() {
     assert_successful_typecheck("2 a == a");

+ 25 - 0
numbat/src/typechecker/tests/type_inference.rs

@@ -52,6 +52,10 @@ fn d_dtype() -> Type {
     Type::Dimension(DType::from_type_variable(TypeVariable::new("D")))
 }
 
+fn d_squared() -> Type {
+    Type::Dimension(DType::from_type_variable(TypeVariable::new("D")).power(2.into()))
+}
+
 fn e_type() -> Type {
     Type::TVar(TypeVariable::new("E"))
 }
@@ -211,6 +215,27 @@ fn dimension_types_multiplication() {
     );
 }
 
+#[test]
+fn dimension_types_exponentiation() {
+    assert_eq!(
+        get_inferred_fn_type("fn f(x: Scalar, y) = x^y"),
+        fn_type!(scalar(), scalar() => scalar())
+    );
+    assert_eq!(
+        get_inferred_fn_type("fn f(x) = x^2"),
+        fn_type!(forall d_type(); dim d_type(); d_type() => d_squared())
+    );
+    assert_eq!(
+        get_inferred_fn_type("fn f(x) = 2^x"),
+        fn_type!(scalar() => scalar())
+    );
+
+    assert!(matches!(
+        get_typecheck_error("fn f(x, y) = x^y"),
+        TypeCheckError::ExponentiationNeedsTypeAnnotation(..)
+    ));
+}
+
 #[test]
 fn dimension_types_combinations() {
     assert_eq!(