Browse Source

More inference tests

David Peter 1 year ago
parent
commit
d2f480a6aa

+ 7 - 2
numbat/src/typechecker/constraints.rs

@@ -251,7 +251,10 @@ impl Constraint {
                 )))
             }
             Constraint::Equal(Type::Dimension(dtype_x), t)
-                if dtype_x.deconstruct_as_single_type_variable().is_some() =>
+                if dtype_x
+                    .deconstruct_as_single_type_variable()
+                    .map(|tv| !dtype_x.contains(&tv, false))
+                    .unwrap_or(false) =>
             {
                 let x = dtype_x.deconstruct_as_single_type_variable().unwrap();
                 debug!(
@@ -264,7 +267,9 @@ impl Constraint {
                     t.clone(),
                 )))
             }
-            Constraint::Equal(t @ Type::Fn(params1, return1), s @ Type::Fn(params2, return2)) => {
+            Constraint::Equal(t @ Type::Fn(params1, return1), s @ Type::Fn(params2, return2))
+                if params1.len() == params2.len() =>
+            {
                 debug!(
                     "  (4) SOLVING: {t} ~ {s} with new constraints for all parameters and return types",
                     t = t,

+ 77 - 3
numbat/src/typechecker/tests/type_inference.rs

@@ -16,6 +16,10 @@ fn a() -> Type {
     Type::Dimension(type_a())
 }
 
+fn a_squared() -> Type {
+    Type::Dimension(type_a().power(2.into()))
+}
+
 fn b() -> Type {
     Type::Dimension(type_b())
 }
@@ -56,6 +60,14 @@ fn d_squared() -> Type {
     Type::Dimension(DType::from_type_variable(TypeVariable::new("D")).power(2.into()))
 }
 
+fn d_cubed() -> Type {
+    Type::Dimension(DType::from_type_variable(TypeVariable::new("D")).power(3.into()))
+}
+
+fn d_power6() -> Type {
+    Type::Dimension(DType::from_type_variable(TypeVariable::new("D")).power(6.into()))
+}
+
 fn e_type() -> Type {
     Type::TVar(TypeVariable::new("E"))
 }
@@ -96,6 +108,13 @@ macro_rules! fn_type {
     };
 }
 
+macro_rules! concrete_fn_type {
+    ($($param_types:expr),* => $return_type:expr) => {
+        Type::Fn(vec![$($param_types),*], Box::new($return_type))
+    };
+
+}
+
 #[test]
 fn if_then_else() {
     assert_eq!(
@@ -311,6 +330,10 @@ fn dimension_types_combinations() {
         get_inferred_fn_type("fn f(x) = (x * b) + c"),
         fn_type!(a() => c())
     );
+    assert_eq!(
+        get_inferred_fn_type("fn f(x) = x^2 + a^2"),
+        fn_type!(a() => a_squared())
+    );
 
     assert!(matches!(
         get_typecheck_error("fn f(x) = (x + a) * (x + b)"),
@@ -318,9 +341,60 @@ fn dimension_types_combinations() {
     ));
 }
 
+#[test]
+fn dimension_types_gauss_elimination() {
+    assert_eq!(
+        get_inferred_fn_type("fn f(x, y) = x^2 + y^3"),
+        fn_type!(forall d_type(); dim d_type(); d_cubed(), d_squared() => d_power6())
+    );
+
+    assert_eq!(
+        get_inferred_fn_type("fn f(x) = x^2 + x"),
+        fn_type!(scalar() => scalar())
+    );
+}
+
 #[test]
 fn function_types() {
-    // TODO
-    // λf. λx. f (f x)  :  (T0 -> T0) -> T0 -> T0
-    // λf. λx. (f (1 m)) + x  :  Dim(T0) => (Length -> T0) -> T0 -> T0
+    assert_eq!(
+        get_inferred_fn_type("fn f(g) = g() + a"),
+        fn_type!(concrete_fn_type!(/* no params */ => a()) => a())
+    );
+    assert_eq!(
+        get_inferred_fn_type("fn f(g) = g(a) + a"),
+        fn_type!(concrete_fn_type!(a() => a()) => a())
+    );
+    assert_eq!(
+        get_inferred_fn_type("fn f(g) = g(a)"),
+        fn_type!(forall t(); concrete_fn_type!(a() => t()) => t())
+    );
+    assert_eq!(
+        get_inferred_fn_type("fn apply(g, x) = g(x)"),
+        fn_type!(forall s(), t(); concrete_fn_type!(t() => s()), t() => s())
+    );
+    assert_eq!(
+        get_inferred_fn_type("fn twice(g, x) = g(g(x))"),
+        fn_type!(forall t(); concrete_fn_type!(t() => t()), t() => t())
+    );
+
+    assert!(matches!(
+        get_typecheck_error("fn f(g) = if true then g() else g(1)"),
+        TypeCheckError::ConstraintSolverError(..)
+    ));
+}
+
+#[test]
+fn recursive_functions() {
+    assert_eq!(
+        get_inferred_fn_type("fn absurd() = absurd()"),
+        fn_type!(forall t(); /* no params */ => t())
+    );
+    assert_eq!(
+        get_inferred_fn_type("fn loop(x) = loop(x)"),
+        fn_type!(forall s(), t(); s() => t())
+    );
+    assert_eq!(
+        get_inferred_fn_type("fn fac(n) = if n == 0 then 1 else n * fac(n - 1)"),
+        fn_type!(scalar() => scalar())
+    );
 }