|
|
@@ -16,6 +16,10 @@ fn a() -> Type {
|
|
|
Type::Dimension(type_a())
|
|
|
}
|
|
|
|
|
|
+fn a_squared() -> Type {
|
|
|
+ Type::Dimension(type_a().power(2.into()))
|
|
|
+}
|
|
|
+
|
|
|
fn b() -> Type {
|
|
|
Type::Dimension(type_b())
|
|
|
}
|
|
|
@@ -56,6 +60,14 @@ fn d_squared() -> Type {
|
|
|
Type::Dimension(DType::from_type_variable(TypeVariable::new("D")).power(2.into()))
|
|
|
}
|
|
|
|
|
|
+fn d_cubed() -> Type {
|
|
|
+ Type::Dimension(DType::from_type_variable(TypeVariable::new("D")).power(3.into()))
|
|
|
+}
|
|
|
+
|
|
|
+fn d_power6() -> Type {
|
|
|
+ Type::Dimension(DType::from_type_variable(TypeVariable::new("D")).power(6.into()))
|
|
|
+}
|
|
|
+
|
|
|
fn e_type() -> Type {
|
|
|
Type::TVar(TypeVariable::new("E"))
|
|
|
}
|
|
|
@@ -96,6 +108,13 @@ macro_rules! fn_type {
|
|
|
};
|
|
|
}
|
|
|
|
|
|
+macro_rules! concrete_fn_type {
|
|
|
+ ($($param_types:expr),* => $return_type:expr) => {
|
|
|
+ Type::Fn(vec![$($param_types),*], Box::new($return_type))
|
|
|
+ };
|
|
|
+
|
|
|
+}
|
|
|
+
|
|
|
#[test]
|
|
|
fn if_then_else() {
|
|
|
assert_eq!(
|
|
|
@@ -311,6 +330,10 @@ fn dimension_types_combinations() {
|
|
|
get_inferred_fn_type("fn f(x) = (x * b) + c"),
|
|
|
fn_type!(a() => c())
|
|
|
);
|
|
|
+ assert_eq!(
|
|
|
+ get_inferred_fn_type("fn f(x) = x^2 + a^2"),
|
|
|
+ fn_type!(a() => a_squared())
|
|
|
+ );
|
|
|
|
|
|
assert!(matches!(
|
|
|
get_typecheck_error("fn f(x) = (x + a) * (x + b)"),
|
|
|
@@ -318,9 +341,60 @@ fn dimension_types_combinations() {
|
|
|
));
|
|
|
}
|
|
|
|
|
|
+#[test]
|
|
|
+fn dimension_types_gauss_elimination() {
|
|
|
+ assert_eq!(
|
|
|
+ get_inferred_fn_type("fn f(x, y) = x^2 + y^3"),
|
|
|
+ fn_type!(forall d_type(); dim d_type(); d_cubed(), d_squared() => d_power6())
|
|
|
+ );
|
|
|
+
|
|
|
+ assert_eq!(
|
|
|
+ get_inferred_fn_type("fn f(x) = x^2 + x"),
|
|
|
+ fn_type!(scalar() => scalar())
|
|
|
+ );
|
|
|
+}
|
|
|
+
|
|
|
#[test]
|
|
|
fn function_types() {
|
|
|
- // TODO
|
|
|
- // λf. λx. f (f x) : (T0 -> T0) -> T0 -> T0
|
|
|
- // λf. λx. (f (1 m)) + x : Dim(T0) => (Length -> T0) -> T0 -> T0
|
|
|
+ assert_eq!(
|
|
|
+ get_inferred_fn_type("fn f(g) = g() + a"),
|
|
|
+ fn_type!(concrete_fn_type!(/* no params */ => a()) => a())
|
|
|
+ );
|
|
|
+ assert_eq!(
|
|
|
+ get_inferred_fn_type("fn f(g) = g(a) + a"),
|
|
|
+ fn_type!(concrete_fn_type!(a() => a()) => a())
|
|
|
+ );
|
|
|
+ assert_eq!(
|
|
|
+ get_inferred_fn_type("fn f(g) = g(a)"),
|
|
|
+ fn_type!(forall t(); concrete_fn_type!(a() => t()) => t())
|
|
|
+ );
|
|
|
+ assert_eq!(
|
|
|
+ get_inferred_fn_type("fn apply(g, x) = g(x)"),
|
|
|
+ fn_type!(forall s(), t(); concrete_fn_type!(t() => s()), t() => s())
|
|
|
+ );
|
|
|
+ assert_eq!(
|
|
|
+ get_inferred_fn_type("fn twice(g, x) = g(g(x))"),
|
|
|
+ fn_type!(forall t(); concrete_fn_type!(t() => t()), t() => t())
|
|
|
+ );
|
|
|
+
|
|
|
+ assert!(matches!(
|
|
|
+ get_typecheck_error("fn f(g) = if true then g() else g(1)"),
|
|
|
+ TypeCheckError::ConstraintSolverError(..)
|
|
|
+ ));
|
|
|
+}
|
|
|
+
|
|
|
+#[test]
|
|
|
+fn recursive_functions() {
|
|
|
+ assert_eq!(
|
|
|
+ get_inferred_fn_type("fn absurd() = absurd()"),
|
|
|
+ fn_type!(forall t(); /* no params */ => t())
|
|
|
+ );
|
|
|
+ assert_eq!(
|
|
|
+ get_inferred_fn_type("fn loop(x) = loop(x)"),
|
|
|
+ fn_type!(forall s(), t(); s() => t())
|
|
|
+ );
|
|
|
+ assert_eq!(
|
|
|
+ get_inferred_fn_type("fn fac(n) = if n == 0 then 1 else n * fac(n - 1)"),
|
|
|
+ fn_type!(scalar() => scalar())
|
|
|
+ );
|
|
|
}
|