Browse Source

Instantiation tests

David Peter 1 year ago
parent
commit
c3acc33155

+ 1 - 0
numbat/src/typechecker/constraints.rs

@@ -376,6 +376,7 @@ impl Constraint {
     fn get_dtype_constraint_type_variable(&self) -> Option<TypeVariable> {
         match self {
             Constraint::IsDType(Type::TVar(tvar)) => Some(tvar.clone()),
+            Constraint::IsDType(Type::TPar(name)) => Some(TypeVariable::new(name.clone())),
             _ => None,
         }
     }

+ 9 - 0
numbat/src/typechecker/mod.rs

@@ -1297,6 +1297,15 @@ impl TypeChecker {
                         type_parameter.clone(),
                         bound.clone(),
                     ));
+
+                    match bound {
+                        Some(TypeParameterBound::Dim) => {
+                            typechecker_fn
+                                .add_dtype_constraint(&Type::TPar(type_parameter.clone()))
+                                .ok();
+                        }
+                        None => {}
+                    }
                 }
 
                 let mut typed_parameters = vec![];

+ 1 - 0
numbat/src/typechecker/tests/mod.rs

@@ -34,6 +34,7 @@ const TEST_PRELUDE: &str = "
     fn head<T>(x: List<T>) -> T
 
     fn id<T>(x: T) -> T = x
+    fn id_for_dim<T: Dim>(x: T) -> T = x
     ";
 
 fn type_a() -> DType {

+ 14 - 0
numbat/src/typechecker/tests/type_checking.rs

@@ -730,3 +730,17 @@ fn name_resolution() {
             ",
     );
 }
+
+#[test]
+fn instantiation() {
+    assert_successful_typecheck("id(1)");
+    assert_successful_typecheck("id(1 a) / id(1 b)");
+    assert_successful_typecheck("if id(true) then id(1) else id(2)");
+
+    assert_successful_typecheck("id_for_dim(1)");
+    assert_successful_typecheck("id(1 a) / id(1 b)");
+    assert!(matches!(
+        get_typecheck_error("id_for_dim(true)"),
+        TypeCheckError::ConstraintSolverError(..)
+    ));
+}

+ 0 - 7
numbat/src/typechecker/tests/type_inference.rs

@@ -234,10 +234,3 @@ fn function_types() {
     // λf. λx. f (f x)  :  (T0 -> T0) -> T0 -> T0
     // λf. λx. (f (1 m)) + x  :  Dim(T0) => (Length -> T0) -> T0 -> T0
 }
-
-#[test]
-fn instantiation() {
-    // TODO
-    // make sure that e.g. `id` can be used twice in the same expression,
-    // but with different types
-}