Browse Source

Re-enable some tests

David Peter 1 year ago
parent
commit
a01bfb1751

+ 40 - 40
numbat-cli/tests/integration.rs

@@ -17,51 +17,51 @@ fn numbat() -> Command {
     cmd
 }
 
-// #[test]
-// fn pass_expression_on_command_line() {
-//     numbat()
-//         .arg("--expression")
-//         .arg("2 meter + 3 meter")
-//         .assert()
-//         .success()
-//         .stdout(predicates::str::contains("5 m"));
+#[test]
+fn pass_expression_on_command_line() {
+    numbat()
+        .arg("--expression")
+        .arg("2 meter + 3 meter")
+        .assert()
+        .success()
+        .stdout(predicates::str::contains("5 m"));
 
-//     numbat()
-//         .arg("-e")
-//         .arg("let x = 2")
-//         .arg("-e")
-//         .arg("x^3")
-//         .assert()
-//         .success()
-//         .stdout(predicates::str::contains("8"));
+    numbat()
+        .arg("-e")
+        .arg("let x = 2")
+        .arg("-e")
+        .arg("x^3")
+        .assert()
+        .success()
+        .stdout(predicates::str::contains("8"));
 
-//     numbat()
-//         .arg("--expression")
-//         .arg("2 +/ 3")
-//         .assert()
-//         .stderr(predicates::str::contains("while parsing"));
+    numbat()
+        .arg("--expression")
+        .arg("2 +/ 3")
+        .assert()
+        .stderr(predicates::str::contains("while parsing"));
 
-//     numbat()
-//         .arg("--expression")
-//         .arg("2 meter + 3 second")
-//         .assert()
-//         .failure()
-//         .stderr(predicates::str::contains("while type checking"));
+    numbat()
+        .arg("--expression")
+        .arg("2 meter + 3 second")
+        .assert()
+        .failure()
+        .stderr(predicates::str::contains("while type checking"));
 
-//     numbat()
-//         .arg("--expression")
-//         .arg("1/0")
-//         .assert()
-//         .failure()
-//         .stderr(predicates::str::contains("runtime error"));
+    // numbat()
+    //     .arg("--expression")
+    //     .arg("1/0")
+    //     .assert()
+    //     .failure()
+    //     .stderr(predicates::str::contains("runtime error"));
 
-//     numbat()
-//         .arg("--expression")
-//         .arg("type(2 m/s)")
-//         .assert()
-//         .success()
-//         .stdout(predicates::str::contains("Length / Time"));
-// }
+    numbat()
+        .arg("--expression")
+        .arg("type(2 m/s)")
+        .assert()
+        .success()
+        .stdout(predicates::str::contains("Length / Time"));
+}
 
 #[test]
 fn read_code_from_file() {

+ 3 - 3
numbat/src/typechecker/constraints.rs

@@ -179,7 +179,7 @@ pub enum TrivialResultion {
 }
 
 impl TrivialResultion {
-    pub fn is_violated(self) -> bool {
+    pub fn is_trivially_violated(self) -> bool {
         matches!(self, TrivialResultion::Violated)
     }
 
@@ -302,7 +302,7 @@ impl Constraint {
             //     ]))
             // }
             Constraint::Equal(_, _) => None,
-            Constraint::IsDType(Type::Dimension(inner)) => {
+            Constraint::IsDType(Type::Dimension(_inner)) => {
                 Some(Satisfied::trivially()) // TODO: this is not correct, see below
 
                 // let new_constraints = inner
@@ -341,7 +341,7 @@ impl Constraint {
             //             Type::Dimension(result),
             //         )))
             //     }
-            _ => None,
+            // _ => None,
             // },
         }
     }

+ 2 - 2
numbat/src/typechecker/error.rs

@@ -19,10 +19,10 @@ pub enum TypeCheckError {
     IncompatibleDimensions(IncompatibleDimensionsError),
 
     #[error("Exponents need to be dimensionless (got {1}).")]
-    NonScalarExponent(Span, DType),
+    NonScalarExponent(Span, Type),
 
     #[error("Argument of factorial needs to be dimensionless (got {1}).")]
-    NonScalarFactorialArgument(Span, DType),
+    NonScalarFactorialArgument(Span, Type),
 
     #[error("Unsupported expression in const-evaluation of exponent: {1}.")]
     UnsupportedConstEvalExpression(Span, &'static str),

+ 107 - 66
numbat/src/typechecker/mod.rs

@@ -24,7 +24,7 @@ use crate::typed_ast::{self, DType, Expression, StructInfo, Type};
 use crate::{decorator, ffi, suggestion};
 
 use const_evaluation::evaluate_const_expr;
-use constraints::{Constraint, ConstraintSet};
+use constraints::{Constraint, ConstraintSet, TrivialResultion};
 use itertools::Itertools;
 use name_generator::NameGenerator;
 use num_traits::Zero;
@@ -75,6 +75,15 @@ impl TypeChecker {
         Type::TVar(self.name_generator.fresh_type_variable())
     }
 
+    fn add_equal_constraint(&mut self, lhs: &Type, rhs: &Type) -> TrivialResultion {
+        self.constraints
+            .add(Constraint::Equal(lhs.clone(), rhs.clone()))
+    }
+
+    fn add_dtype_constraint(&mut self, type_: &Type) -> TrivialResultion {
+        self.constraints.add(Constraint::IsDType(type_.clone()))
+    }
+
     fn type_from_annotation(&self, annotation: &TypeAnnotation) -> Result<Type> {
         match annotation {
             TypeAnnotation::TypeExpression(dexpr) => {
@@ -163,7 +172,7 @@ impl TypeChecker {
     }
 
     fn proper_function_call(
-        &self,
+        &mut self,
         span: &Span,
         full_span: &Span,
         function_name: &str,
@@ -320,13 +329,16 @@ impl TypeChecker {
                     }
                 }
                 (parameter_type, argument_type) => {
-                    if &argument_type != parameter_type {
-                        // return Err(TypeCheckError::IncompatibleTypesInFunctionCall(
-                        //     Some(*parameter_span),
-                        //     parameter_type.clone(),
-                        //     arguments[idx].full_span(),
-                        //     argument_type.clone(),
-                        // ));
+                    if self
+                        .add_equal_constraint(parameter_type, &argument_type)
+                        .is_trivially_violated()
+                    {
+                        return Err(TypeCheckError::IncompatibleTypesInFunctionCall(
+                            Some(*parameter_span),
+                            parameter_type.clone(),
+                            arguments[idx].full_span(),
+                            argument_type.clone(),
+                        ));
                     }
                 }
             }
@@ -431,27 +443,56 @@ impl TypeChecker {
             ast::Expression::UnaryOperator { op, expr, span_op } => {
                 let checked_expr = self.elaborate_expression(expr)?;
                 let type_ = checked_expr.get_type();
-                match (&type_, op) {
-                    (Type::Dimension(dtype), ast::UnaryOperator::Factorial) => {
-                        if !dtype.is_scalar() {
+                // match (&type_, op) {
+                //     (Type::Dimension(dtype), ast::UnaryOperator::Factorial) => {
+                //         if !dtype.is_scalar() {
+                //             return Err(TypeCheckError::NonScalarFactorialArgument(
+                //                 expr.full_span(),
+                //                 dtype.clone(),
+                //             ));
+                //         }
+                //     }
+                //     (Type::Dimension(_), ast::UnaryOperator::Negate) => (),
+                //     (Type::Boolean, ast::UnaryOperator::LogicalNeg) => (),
+                //     (_, ast::UnaryOperator::LogicalNeg) => {
+                //         return Err(TypeCheckError::ExpectedBool(expr.full_span()))
+                //     }
+                //     _ => {
+                //         // return Err(TypeCheckError::ExpectedDimensionType(
+                //         //     checked_expr.full_span(),
+                //         //     type_.clone(),
+                //         // ));
+                //     }
+                // };
+                match op {
+                    ast::UnaryOperator::Factorial => {
+                        if self
+                            .add_equal_constraint(&type_, &Type::scalar())
+                            .is_trivially_violated()
+                        {
                             return Err(TypeCheckError::NonScalarFactorialArgument(
                                 expr.full_span(),
-                                dtype.clone(),
+                                type_,
                             ));
                         }
                     }
-                    (Type::Dimension(_), ast::UnaryOperator::Negate) => (),
-                    (Type::Boolean, ast::UnaryOperator::LogicalNeg) => (),
-                    (_, ast::UnaryOperator::LogicalNeg) => {
-                        return Err(TypeCheckError::ExpectedBool(expr.full_span()))
+                    ast::UnaryOperator::Negate => {
+                        if self.add_dtype_constraint(&type_).is_trivially_violated() {
+                            return Err(TypeCheckError::ExpectedDimensionType(
+                                expr.full_span(),
+                                type_,
+                            ));
+                        }
                     }
-                    _ => {
-                        // return Err(TypeCheckError::ExpectedDimensionType(
-                        //     checked_expr.full_span(),
-                        //     type_.clone(),
-                        // ));
+                    ast::UnaryOperator::LogicalNeg => {
+                        if self
+                            .add_equal_constraint(&type_, &Type::Boolean)
+                            .is_trivially_violated()
+                        {
+                            return Err(TypeCheckError::ExpectedBool(expr.full_span()));
+                        }
                     }
-                };
+                }
 
                 typed_ast::Expression::UnaryOperator(*span_op, *op, Box::new(checked_expr), type_)
             }
@@ -483,13 +524,16 @@ impl TypeChecker {
                         });
                     }
 
-                    if parameter_types[0] != lhs_type {
-                        // return Err(TypeCheckError::IncompatibleTypesInFunctionCall(
-                        //     None,
-                        //     parameter_types[0].clone(),
-                        //     lhs.full_span(),
-                        //     lhs_type,
-                        // ));
+                    if self
+                        .add_equal_constraint(&lhs_type, &parameter_types[0])
+                        .is_trivially_violated()
+                    {
+                        return Err(TypeCheckError::IncompatibleTypesInFunctionCall(
+                            None,
+                            parameter_types[0].clone(),
+                            lhs.full_span(),
+                            lhs_type,
+                        ));
                     }
 
                     typed_ast::Expression::CallableCall(
@@ -559,9 +603,8 @@ impl TypeChecker {
                         let rhs_type = rhs_checked.get_type();
 
                         if self
-                            .constraints
-                            .add(Constraint::Equal(lhs_type.clone(), rhs_type))
-                            .is_violated()
+                            .add_equal_constraint(&lhs_type, &rhs_type)
+                            .is_trivially_violated()
                         {
                             let lhs_dtype = dtype(&lhs_checked)?;
                             let rhs_dtype = dtype(&rhs_checked)?;
@@ -629,7 +672,7 @@ impl TypeChecker {
                             if !exponent_type.is_scalar() {
                                 return Err(TypeCheckError::NonScalarExponent(
                                     rhs.full_span(),
-                                    exponent_type,
+                                    Type::Dimension(exponent_type), // TODO
                                 ));
                             }
 
@@ -706,11 +749,13 @@ impl TypeChecker {
                 // that evaluates to a function "pointer".
 
                 if let Some((name, signature)) = self.get_proper_function_reference(callable) {
+                    let name = name.clone(); // TODO: this is just a hack for now to get around a borrowing issue. not fixed properly since this will probably be removed anyways
+                    let signature = signature.clone(); // TODO: same
                     self.proper_function_call(
                         span,
                         full_span,
                         &name,
-                        signature,
+                        &signature,
                         arguments_checked,
                         argument_types,
                     )?
@@ -784,9 +829,8 @@ impl TypeChecker {
                 let condition = self.elaborate_expression(condition)?;
 
                 if self
-                    .constraints
-                    .add(Constraint::Equal(condition.get_type(), Type::Boolean))
-                    .is_violated()
+                    .add_equal_constraint(&condition.get_type(), &Type::Boolean)
+                    .is_trivially_violated()
                 {
                     return Err(TypeCheckError::ExpectedBool(condition.full_span()));
                 }
@@ -798,9 +842,8 @@ impl TypeChecker {
                 let else_type = else_.get_type();
 
                 if self
-                    .constraints
-                    .add(Constraint::Equal(then_type.clone(), else_type.clone()))
-                    .is_violated()
+                    .add_equal_constraint(&then_type, &else_type)
+                    .is_trivially_violated()
                 {
                     return Err(TypeCheckError::IncompatibleTypesInCondition(
                         *span,
@@ -936,23 +979,18 @@ impl TypeChecker {
 
                 if !element_types.is_empty() {
                     let type_of_first_element = element_types[0].clone();
-                    self.constraints
-                        .add(Constraint::Equal(
-                            result_element_type.clone(),
-                            type_of_first_element.clone(),
-                        ))
+                    self.add_equal_constraint(&result_element_type, &type_of_first_element)
                         .ok(); // This can never be satisfied trivially, so ignore the result
 
                     for (subsequent_element, type_of_subsequent_element) in
                         elements_checked.iter().zip(element_types.iter()).skip(1)
                     {
                         if self
-                            .constraints
-                            .add(Constraint::Equal(
-                                type_of_subsequent_element.clone(),
-                                type_of_first_element.clone(),
-                            ))
-                            .is_violated()
+                            .add_equal_constraint(
+                                &type_of_subsequent_element,
+                                &type_of_first_element,
+                            )
+                            .is_trivially_violated()
                         {
                             return Err(TypeCheckError::IncompatibleTypesInList(
                                 elements_checked[0].full_span(),
@@ -1018,9 +1056,8 @@ impl TypeChecker {
                         }
                         (deduced, annotated) => {
                             if self
-                                .constraints
-                                .add(Constraint::Equal(deduced.clone(), annotated.clone()))
-                                .is_violated()
+                                .add_equal_constraint(&deduced, &annotated)
+                                .is_trivially_violated()
                             {
                                 return Err(TypeCheckError::IncompatibleTypesInAnnotation(
                                     "definition".into(),
@@ -1322,16 +1359,20 @@ impl TypeChecker {
                                 Type::Dimension(dtype_deduced)
                             }
                             (type_deduced, type_specified) => {
-                                // if type_deduced != type_specified {
-                                //     return Err(TypeCheckError::IncompatibleTypesInAnnotation(
-                                //         "function definition".into(),
-                                //         *function_name_span,
-                                //         type_specified,
-                                //         return_type_annotation_span.unwrap(),
-                                //         type_deduced,
-                                //         body.as_ref().map(|b| b.full_span()).unwrap(),
-                                //     ));
-                                // }
+                                if self
+                                    .add_equal_constraint(&type_deduced, &type_specified)
+                                    .is_trivially_violated()
+                                {
+                                    return Err(TypeCheckError::IncompatibleTypesInAnnotation(
+                                        "function definition".into(),
+                                        *function_name_span,
+                                        type_specified,
+                                        return_type_annotation_span.unwrap(),
+                                        type_deduced,
+                                        body.as_ref().map(|b| b.full_span()).unwrap(),
+                                    ));
+                                }
+
                                 type_specified
                             }
                         }
@@ -1539,7 +1580,7 @@ impl TypeChecker {
         }
 
         // Solve constraints
-        let (substitution, dtype_variables) = self
+        let (substitution, _dtype_variables) = self
             .constraints
             .solve()
             .map_err(TypeCheckError::ConstraintSolverError)?;

+ 0 - 4
numbat/src/typechecker/name_generator.rs

@@ -6,10 +6,6 @@ pub struct NameGenerator {
 }
 
 impl NameGenerator {
-    pub fn new() -> NameGenerator {
-        NameGenerator { counter: 0 }
-    }
-
     pub fn fresh_type_variable(&mut self) -> TypeVariable {
         let name = format!("T{}", self.counter);
         self.counter += 1;

+ 8 - 8
numbat/src/typechecker/substitutions.rs

@@ -20,13 +20,13 @@ impl Substitution {
         self.0.iter().find(|(var, _)| var == v).map(|(_, t)| t)
     }
 
-    pub fn pretty_print(&self) -> String {
-        self.0
-            .iter()
-            .map(|(v, t)| format!("  {} := {}", v.name(), t))
-            .collect::<Vec<String>>()
-            .join("\n")
-    }
+    // pub fn pretty_print(&self) -> String {
+    //     self.0
+    //         .iter()
+    //         .map(|(v, t)| format!("  {} := {}", v.name(), t))
+    //         .collect::<Vec<String>>()
+    //         .join("\n")
+    // }
 
     pub fn extend(&mut self, other: Substitution) {
         for (_, t) in &mut self.0 {
@@ -77,7 +77,7 @@ impl ApplySubstitution for Type {
 }
 
 impl ApplySubstitution for DType {
-    fn apply(&mut self, s: &Substitution) -> Result<(), SubstitutionError> {
+    fn apply(&mut self, _s: &Substitution) -> Result<(), SubstitutionError> {
         // TODO
         Ok(())
     }

+ 88 - 88
numbat/src/typechecker/tests.rs

@@ -89,14 +89,14 @@ fn power_operator_with_scalar_base() {
     assert_successful_typecheck("2^2");
     assert_successful_typecheck("2^(2^2)");
 
-    // assert!(matches!(
-    //     get_typecheck_error("2^a"),
-    //     TypeCheckError::NonScalarExponent(_, t) if t == type_a()
-    // ));
-    // assert!(matches!(
-    //     get_typecheck_error("2^(c/b)"),
-    //     TypeCheckError::NonScalarExponent(_, t) if t == type_a()
-    // ));
+    assert!(matches!(
+        get_typecheck_error("2^a"),
+        TypeCheckError::NonScalarExponent(_, t) if t == Type::Dimension(type_a())
+    ));
+    assert!(matches!(
+        get_typecheck_error("2^(c/b)"),
+        TypeCheckError::NonScalarExponent(_, t) if t == Type::Dimension(type_a())
+    ));
 }
 
 #[test]
@@ -110,7 +110,7 @@ fn power_operator_with_dimensionful_base() {
 
     assert!(matches!(
         get_typecheck_error("a^b"),
-        TypeCheckError::NonScalarExponent(_, t) if t == type_b()
+        TypeCheckError::NonScalarExponent(_, t) if t == Type::Dimension(type_b())
     ));
 
     // TODO: if we add ("constexpr") constants later, it would be great to support those in exponents.
@@ -152,26 +152,26 @@ fn variable_definitions() {
     assert_successful_typecheck("let x: Bool = true");
     assert_successful_typecheck("let x: String = \"hello\"");
 
-    // assert!(matches!(
-    //     get_typecheck_error("let x: A = b"),
-    //     TypeCheckError::IncompatibleDimensions(IncompatibleDimensionsError {expected_type, actual_type, ..}) if expected_type == type_a() && actual_type == type_b()
-    // ));
-    // assert!(matches!(
-    //     get_typecheck_error("let x: A = true"),
-    //     TypeCheckError::IncompatibleTypesInAnnotation(_, _, annotated_type, _, actual_type, _) if annotated_type == Type::Dimension(type_a()) && actual_type == Type::Boolean
-    // ));
-    // assert!(matches!(
-    //     get_typecheck_error("let x: A = \"foo\""),
-    //     TypeCheckError::IncompatibleTypesInAnnotation(_, _, annotated_type, _, actual_type, _) if annotated_type == Type::Dimension(type_a()) && actual_type == Type::String
-    // ));
-    // assert!(matches!(
-    //     get_typecheck_error("let x: Bool = a"),
-    //     TypeCheckError::IncompatibleTypesInAnnotation(_, _, annotated_type, _, actual_type, _) if annotated_type == Type::Boolean && actual_type == Type::Dimension(type_a())
-    // ));
-    // assert!(matches!(
-    //     get_typecheck_error("let x: String = true"),
-    //     TypeCheckError::IncompatibleTypesInAnnotation(_, _, annotated_type, _, actual_type, _) if annotated_type == Type::String && actual_type == Type::Boolean
-    // ));
+    assert!(matches!(
+        get_typecheck_error("let x: A = b"),
+        TypeCheckError::IncompatibleDimensions(IncompatibleDimensionsError {expected_type, actual_type, ..}) if expected_type == type_a() && actual_type == type_b()
+    ));
+    assert!(matches!(
+        get_typecheck_error("let x: A = true"),
+        TypeCheckError::IncompatibleTypesInAnnotation(_, _, annotated_type, _, actual_type, _) if annotated_type == Type::Dimension(type_a()) && actual_type == Type::Boolean
+    ));
+    assert!(matches!(
+        get_typecheck_error("let x: A = \"foo\""),
+        TypeCheckError::IncompatibleTypesInAnnotation(_, _, annotated_type, _, actual_type, _) if annotated_type == Type::Dimension(type_a()) && actual_type == Type::String
+    ));
+    assert!(matches!(
+        get_typecheck_error("let x: Bool = a"),
+        TypeCheckError::IncompatibleTypesInAnnotation(_, _, annotated_type, _, actual_type, _) if annotated_type == Type::Boolean && actual_type == Type::Dimension(type_a())
+    ));
+    assert!(matches!(
+        get_typecheck_error("let x: String = true"),
+        TypeCheckError::IncompatibleTypesInAnnotation(_, _, annotated_type, _, actual_type, _) if annotated_type == Type::String && actual_type == Type::Boolean
+    ));
 }
 
 #[test]
@@ -179,10 +179,10 @@ fn unit_definitions() {
     assert_successful_typecheck("unit my_c: C = a * b");
     assert_successful_typecheck("unit foo: A*B^2 = a b^2");
 
-    // assert!(matches!(
-    //     get_typecheck_error("unit my_c: C = a"),
-    //     TypeCheckError::IncompatibleDimensions(IncompatibleDimensionsError {expected_type, actual_type, ..}) if expected_type == type_c() && actual_type == type_a()
-    // ));
+    assert!(matches!(
+        get_typecheck_error("unit my_c: C = a"),
+        TypeCheckError::IncompatibleDimensions(IncompatibleDimensionsError {expected_type, actual_type, ..}) if expected_type == type_c() && actual_type == type_a()
+    ));
 }
 
 #[test]
@@ -193,16 +193,16 @@ fn function_definitions() {
 
     assert_successful_typecheck("fn f(x: A) = x");
 
-    // assert!(matches!(
-    //     get_typecheck_error("fn f(x: A, y: B) -> C = x / y"),
-    //     TypeCheckError::IncompatibleDimensions(IncompatibleDimensionsError {expected_type, actual_type, ..}) if expected_type == type_c() && actual_type == type_a() / type_b()
-    // ));
+    assert!(matches!(
+        get_typecheck_error("fn f(x: A, y: B) -> C = x / y"),
+        TypeCheckError::IncompatibleDimensions(IncompatibleDimensionsError {expected_type, actual_type, ..}) if expected_type == type_c() && actual_type == type_a() / type_b()
+    ));
 
-    // assert!(matches!(
-    //     get_typecheck_error("fn f(x: A) -> A = a\n\
-    //                          f(b)"),
-    //     TypeCheckError::IncompatibleDimensions(IncompatibleDimensionsError {expected_type, actual_type, ..}) if expected_type == type_a() && actual_type == type_b()
-    // ));
+    assert!(matches!(
+        get_typecheck_error("fn f(x: A) -> A = a\n\
+                             f(b)"),
+        TypeCheckError::IncompatibleDimensions(IncompatibleDimensionsError {expected_type, actual_type, ..}) if expected_type == type_a() && actual_type == type_b()
+    ));
 }
 
 #[test]
@@ -212,10 +212,10 @@ fn recursive_functions() {
         "fn factorial(n: Scalar) -> Scalar = if n < 0 then 1 else factorial(n - 1) * n",
     );
 
-    // assert!(matches!(
-    //     get_typecheck_error("fn f(x: Scalar) -> A = if x < 0 then f(-x) else 2 b"),
-    //     TypeCheckError::IncompatibleTypesInCondition(_, lhs, _, rhs, _) if lhs == Type::Dimension(type_a()) && rhs == Type::Dimension(type_b())
-    // ));
+    assert!(matches!(
+        get_typecheck_error("fn f(x: Scalar) -> A = if x < 0 then f(-x) else 2 b"),
+        TypeCheckError::IncompatibleTypesInCondition(_, lhs, _, rhs, _) if lhs == Type::Dimension(type_a()) && rhs == Type::Dimension(type_b())
+    ));
 }
 
 #[test]
@@ -242,12 +242,12 @@ fn generics_basic() {
             ",
     );
 
-    // assert!(matches!(
-    //     get_typecheck_error("fn f<T1, T2>(x: T1, y: T2) -> T2/T1 = x/y"),
-    //     TypeCheckError::IncompatibleDimensions(IncompatibleDimensionsError {expected_type, actual_type, ..})
-    //         if expected_type == base_type("T2") / base_type("T1") &&
-    //         actual_type == base_type("T1") / base_type("T2")
-    // ));
+    assert!(matches!(
+        get_typecheck_error("fn f<T1, T2>(x: T1, y: T2) -> T2/T1 = x/y"),
+        TypeCheckError::IncompatibleDimensions(IncompatibleDimensionsError {expected_type, actual_type, ..})
+            if expected_type == base_type("T2") / base_type("T1") &&
+            actual_type == base_type("T1") / base_type("T2")
+    ));
 }
 
 // #[test]
@@ -427,50 +427,50 @@ fn conditionals() {
     assert_successful_typecheck("if true then 1 else 2");
     assert_successful_typecheck("if true then true else false");
 
-    // assert!(matches!(
-    //     get_typecheck_error("if 1 then 2 else 3"),
-    //     TypeCheckError::ExpectedBool(_)
-    // ));
+    assert!(matches!(
+        get_typecheck_error("if 1 then 2 else 3"),
+        TypeCheckError::ExpectedBool(_)
+    ));
 
-    // assert!(matches!(
-    //     get_typecheck_error("if true then a else b"),
-    //     TypeCheckError::IncompatibleTypesInCondition(_, t1, _, t2, _) if t1 == Type::Dimension(base_type("A")) && t2 == Type::Dimension(base_type("B"))
-    // ));
+    assert!(matches!(
+        get_typecheck_error("if true then a else b"),
+        TypeCheckError::IncompatibleTypesInCondition(_, t1, _, t2, _) if t1 == Type::Dimension(base_type("A")) && t2 == Type::Dimension(base_type("B"))
+    ));
 
-    // assert!(matches!(
-    //     get_typecheck_error("if true then true else a"),
-    //     TypeCheckError::IncompatibleTypesInCondition(_, t1, _, t2, _) if t1 == Type::Boolean && t2 == Type::Dimension(base_type("A"))
-    // ));
+    assert!(matches!(
+        get_typecheck_error("if true then true else a"),
+        TypeCheckError::IncompatibleTypesInCondition(_, t1, _, t2, _) if t1 == Type::Boolean && t2 == Type::Dimension(base_type("A"))
+    ));
 }
 
 #[test]
 fn non_dtype_return_types() {
-    // assert!(matches!(
-    //     get_typecheck_error("fn f() -> String = 1"),
-    //     TypeCheckError::IncompatibleTypesInAnnotation(..)
-    // ));
-    // assert!(matches!(
-    //     get_typecheck_error("fn f() -> Scalar = \"test\""),
-    //     TypeCheckError::IncompatibleTypesInAnnotation(..)
-    // ));
+    assert!(matches!(
+        get_typecheck_error("fn f() -> String = 1"),
+        TypeCheckError::IncompatibleTypesInAnnotation(..)
+    ));
+    assert!(matches!(
+        get_typecheck_error("fn f() -> Scalar = \"test\""),
+        TypeCheckError::IncompatibleTypesInAnnotation(..)
+    ));
 
-    // assert!(matches!(
-    //     get_typecheck_error("fn f() -> Bool = 1"),
-    //     TypeCheckError::IncompatibleTypesInAnnotation(..)
-    // ));
-    // assert!(matches!(
-    //     get_typecheck_error("fn f() -> Scalar = true"),
-    //     TypeCheckError::IncompatibleTypesInAnnotation(..)
-    // ));
+    assert!(matches!(
+        get_typecheck_error("fn f() -> Bool = 1"),
+        TypeCheckError::IncompatibleTypesInAnnotation(..)
+    ));
+    assert!(matches!(
+        get_typecheck_error("fn f() -> Scalar = true"),
+        TypeCheckError::IncompatibleTypesInAnnotation(..)
+    ));
 
-    // assert!(matches!(
-    //     get_typecheck_error("fn f() -> String = true"),
-    //     TypeCheckError::IncompatibleTypesInAnnotation(..)
-    // ));
-    // assert!(matches!(
-    //     get_typecheck_error("fn f() -> Bool = \"test\""),
-    //     TypeCheckError::IncompatibleTypesInAnnotation(..)
-    // ));
+    assert!(matches!(
+        get_typecheck_error("fn f() -> String = true"),
+        TypeCheckError::IncompatibleTypesInAnnotation(..)
+    ));
+    assert!(matches!(
+        get_typecheck_error("fn f() -> Bool = \"test\""),
+        TypeCheckError::IncompatibleTypesInAnnotation(..)
+    ));
 }
 
 #[test]

+ 2 - 2
numbat/src/typed_ast.rs

@@ -181,11 +181,11 @@ impl Type {
         }
     }
 
-    pub(crate) fn instantiate(&self, type_variables: &[TypeVariable]) -> Type {
+    pub(crate) fn instantiate(&self, _type_variables: &[TypeVariable]) -> Type {
         todo!()
     }
 
-    pub(crate) fn contains(&self, x: &TypeVariable) -> bool {
+    pub(crate) fn contains(&self, _x: &TypeVariable) -> bool {
         false // TODO!
     }
 }

+ 86 - 86
numbat/tests/interpreter.rs

@@ -283,108 +283,108 @@ fn test_math() {
     )
 }
 
-// #[test]
-// fn test_incompatible_dimension_errors() {
-//     assert_snapshot!(
-//         get_error_message("kg m / s^2 + kg m^2"),
-//         @r###"
-//      left hand side: Length  × Mass × Time⁻²    [= Force]
-//     right hand side: Length² × Mass             [= MomentOfInertia]
-//     "###
-//     );
+#[test]
+fn test_incompatible_dimension_errors() {
+    assert_snapshot!(
+        get_error_message("kg m / s^2 + kg m^2"),
+        @r###"
+     left hand side: Length  × Mass × Time⁻²    [= Force]
+    right hand side: Length² × Mass             [= MomentOfInertia]
+    "###
+    );
 
-//     assert_snapshot!(
-//         get_error_message("1 + m"),
-//         @r###"
-//      left hand side: Scalar    [= Angle, Scalar, SolidAngle]
-//     right hand side: Length
+    assert_snapshot!(
+        get_error_message("1 + m"),
+        @r###"
+     left hand side: Scalar    [= Angle, Scalar, SolidAngle]
+    right hand side: Length
 
-//     Suggested fix: divide the expression on the right hand side by a `Length` factor
-//     "###
-//     );
+    Suggested fix: divide the expression on the right hand side by a `Length` factor
+    "###
+    );
 
-//     assert_snapshot!(
-//         get_error_message("m / s + K A"),
-//         @r###"
-//      left hand side: Length / Time            [= Velocity]
-//     right hand side: Current × Temperature
-//     "###
-//     );
+    assert_snapshot!(
+        get_error_message("m / s + K A"),
+        @r###"
+     left hand side: Length / Time            [= Velocity]
+    right hand side: Current × Temperature
+    "###
+    );
 
-//     assert_snapshot!(
-//         get_error_message("m + 1 / m"),
-//         @r###"
-//      left hand side: Length
-//     right hand side: Length⁻¹    [= Wavenumber]
+    assert_snapshot!(
+        get_error_message("m + 1 / m"),
+        @r###"
+     left hand side: Length
+    right hand side: Length⁻¹    [= Wavenumber]
 
-//     Suggested fix: invert the expression on the right hand side
-//     "###
-//     );
+    Suggested fix: invert the expression on the right hand side
+    "###
+    );
 
-//     assert_snapshot!(
-//         get_error_message("kW -> J"),
-//         @r###"
-//      left hand side: Length² × Mass × Time⁻³    [= Power]
-//     right hand side: Length² × Mass × Time⁻²    [= Energy, Torque]
+    assert_snapshot!(
+        get_error_message("kW -> J"),
+        @r###"
+     left hand side: Length² × Mass × Time⁻³    [= Power]
+    right hand side: Length² × Mass × Time⁻²    [= Energy, Torque]
 
-//     Suggested fix: divide the expression on the right hand side by a `Time` factor
-//     "###
-//     );
+    Suggested fix: divide the expression on the right hand side by a `Time` factor
+    "###
+    );
 
-//     assert_snapshot!(
-//         get_error_message("sin(1 meter)"),
-//         @r###"
-//     parameter type: Scalar    [= Angle, Scalar, SolidAngle]
-//      argument type: Length
+    assert_snapshot!(
+        get_error_message("sin(1 meter)"),
+        @r###"
+    parameter type: Scalar    [= Angle, Scalar, SolidAngle]
+     argument type: Length
 
-//     Suggested fix: divide the function argument by a `Length` factor
-//     "###
-//     );
+    Suggested fix: divide the function argument by a `Length` factor
+    "###
+    );
 
-//     assert_snapshot!(
-//         get_error_message("let x: Acceleration = 4 m / s"),
-//         @r###"
-//     specified dimension: Length × Time⁻²    [= Acceleration]
-//        actual dimension: Length × Time⁻¹    [= Velocity]
+    assert_snapshot!(
+        get_error_message("let x: Acceleration = 4 m / s"),
+        @r###"
+    specified dimension: Length × Time⁻²    [= Acceleration]
+       actual dimension: Length × Time⁻¹    [= Velocity]
 
-//     Suggested fix: divide the right hand side expression by a `Time` factor
-//     "###
-//     );
+    Suggested fix: divide the right hand side expression by a `Time` factor
+    "###
+    );
 
-//     assert_snapshot!(
-//         get_error_message("unit x: Acceleration = 4 m / s"),
-//         @r###"
-//     specified dimension: Length × Time⁻²    [= Acceleration]
-//        actual dimension: Length × Time⁻¹    [= Velocity]
+    assert_snapshot!(
+        get_error_message("unit x: Acceleration = 4 m / s"),
+        @r###"
+    specified dimension: Length × Time⁻²    [= Acceleration]
+       actual dimension: Length × Time⁻¹    [= Velocity]
 
-//     Suggested fix: divide the right hand side expression by a `Time` factor
-//     "###
-//     );
+    Suggested fix: divide the right hand side expression by a `Time` factor
+    "###
+    );
 
-//     assert_snapshot!(
-//         get_error_message("fn acceleration(length: Length, time: Time) -> Acceleration = length / time"),
-//         @r###"
-//     specified return type: Length × Time⁻²    [= Acceleration]
-//        actual return type: Length × Time⁻¹    [= Velocity]
+    assert_snapshot!(
+        get_error_message("fn acceleration(length: Length, time: Time) -> Acceleration = length / time"),
+        @r###"
+    specified return type: Length × Time⁻²    [= Acceleration]
+       actual return type: Length × Time⁻¹    [= Velocity]
 
-//     Suggested fix: divide the expression in the function body by a `Time` factor
-//     "###
-//     );
-// }
+    Suggested fix: divide the expression in the function body by a `Time` factor
+    "###
+    );
+}
 
-// #[test]
-// fn test_temperature_conversions() {
-//     expect_output("from_celsius(11.5)", "284.65 K");
-//     expect_output("from_fahrenheit(89.3)", "304.983 K");
-//     expect_output("0 K -> celsius", "-273.15");
-//     expect_output("fahrenheit(30 K)", "-405.67");
-//     expect_output("from_celsius(100) -> celsius", "100");
-//     expect_output("from_fahrenheit(100) -> fahrenheit", "100.0");
-//     expect_output("from_celsius(123 K -> celsius)", "123 K");
-//     expect_output("from_fahrenheit(123 K -> fahrenheit)", "123 K");
-
-//     expect_output("-40 -> from_fahrenheit -> celsius", "-40");
-// }
+#[test]
+fn test_temperature_conversions() {
+    expect_output("from_celsius(11.5)", "284.65 K");
+    expect_output("from_fahrenheit(89.3)", "304.983 K");
+    // expect_output("0 K -> celsius", "-273.15");
+    expect_output("fahrenheit(30 K)", "-405.67");
+    expect_output("from_celsius(100) -> celsius", "100");
+    expect_output("from_fahrenheit(100) -> fahrenheit", "100.0");
+    expect_output("from_celsius(123 K -> celsius)", "123 K");
+    expect_output("from_fahrenheit(123 K -> fahrenheit)", "123 K");
+
+    expect_output("-40 -> from_fahrenheit -> celsius", "-40");
+}
 
 #[test]
 fn test_other_functions() {