Browse Source

Better type checking for assert_eq

David Peter 2 years ago
parent
commit
25001b5ed1

+ 0 - 0
examples/runtime_error/assert_eq_failure2.nbt → examples/runtime_error/assert_eq_1.nbt


+ 0 - 0
examples/runtime_error/assert_eq_failure3.nbt → examples/runtime_error/assert_eq_2.nbt


+ 0 - 0
examples/runtime_error/assert_eq_failure1.nbt → examples/typecheck_error/assert_eq_1.nbt


+ 2 - 0
examples/typecheck_error/assert_eq_2.nbt

@@ -0,0 +1,2 @@
+
+assert_eq(2 meter, 2.1 meter, 0.2)

+ 1 - 0
examples/typecheck_error/assert_eq_3.nbt

@@ -0,0 +1 @@
+assert_eq(1 second, 1 meter, 0.01 meter)

+ 17 - 0
numbat/src/diagnostic.rs

@@ -226,6 +226,23 @@ impl ErrorDiagnostic for TypeCheckError {
                     "Incompatible types in 'then' and 'else' branches of conditional",
                 ),
             ]),
+            TypeCheckError::IncompatibleTypesInAssertEq(
+                procedure_span,
+                first_type,
+                first_span,
+                arg_type,
+                arg_span,
+            ) => d.with_labels(vec![
+                first_span
+                    .diagnostic_label(LabelStyle::Secondary)
+                    .with_message(first_type.to_string()),
+                arg_span
+                    .diagnostic_label(LabelStyle::Secondary)
+                    .with_message(arg_type.to_string()),
+                procedure_span
+                    .diagnostic_label(LabelStyle::Primary)
+                    .with_message("Incompatible types in 'assert_eq' call"),
+            ]),
             TypeCheckError::ForeignFunctionNeedsTypeAnnotations(span, _)
             | TypeCheckError::UnknownForeignFunction(span, _)
             | TypeCheckError::NonRationalExponent(span)

+ 27 - 0
numbat/src/typechecker.rs

@@ -238,6 +238,9 @@ pub enum TypeCheckError {
 
     #[error("Incompatible types in condition")]
     IncompatibleTypesInCondition(Span, Type, Span, Type, Span),
+
+    #[error("Argument types in assert_eq calls must match")]
+    IncompatibleTypesInAssertEq(Span, Type, Span, Type, Span),
 }
 
 type Result<T> = std::result::Result<T, TypeCheckError>;
@@ -1052,6 +1055,30 @@ impl TypeChecker {
                     .map(|e| self.check_expression(e))
                     .collect::<Result<Vec<_>>>()?;
 
+                match kind {
+                    ProcedureKind::Print => {
+                        // no argument type checks required, everything can be printed
+                    }
+                    ProcedureKind::AssertEq => {
+                        let type_first = dtype(&checked_args[0])?;
+                        for arg in &checked_args[1..] {
+                            let type_arg = dtype(&arg)?;
+                            if type_arg != type_first {
+                                return Err(TypeCheckError::IncompatibleTypesInAssertEq(
+                                    *span,
+                                    checked_args[0].get_type(),
+                                    checked_args[0].full_span(),
+                                    arg.get_type(),
+                                    arg.full_span(),
+                                ));
+                            }
+                        }
+                    }
+                    ProcedureKind::Type => {
+                        unreachable!("type() calls have a special handling above")
+                    }
+                }
+
                 typed_ast::Statement::ProcedureCall(kind.clone(), checked_args)
             }
             ast::Statement::ModuleImport(_, _) => {