Browse Source

Better wrong-arity errors

David Peter 2 years ago
parent
commit
1cf4e86e82

+ 1 - 0
examples/typecheck_error/wrong_arity_fn.nbt

@@ -0,0 +1 @@
+sin(2, 3)

+ 1 - 0
examples/typecheck_error/wrong_arity_procedure.nbt

@@ -0,0 +1 @@
+print(2 meter, 3 second)

+ 4 - 3
numbat/src/ast.rs

@@ -369,7 +369,7 @@ pub enum Statement {
         type_annotation: Option<DimensionExpression>,
         decorators: Vec<Decorator>,
     },
-    ProcedureCall(ProcedureKind, Vec<Expression>),
+    ProcedureCall(Span, ProcedureKind, Vec<Expression>),
     ModuleImport(Span, ModulePath),
 }
 
@@ -554,7 +554,7 @@ impl PrettyPrint for Statement {
                     + m::space()
                     + expr.pretty_print()
             }
-            Statement::ProcedureCall(kind, args) => {
+            Statement::ProcedureCall(_, kind, args) => {
                 let identifier = match kind {
                     ProcedureKind::Print => "print",
                     ProcedureKind::AssertEq => "assert_eq",
@@ -678,7 +678,8 @@ impl ReplaceSpans for Statement {
                 type_annotation: type_annotation.clone(),
                 decorators: decorators.clone(),
             },
-            Statement::ProcedureCall(proc, args) => Statement::ProcedureCall(
+            Statement::ProcedureCall(_, proc, args) => Statement::ProcedureCall(
+                Span::dummy(),
                 proc.clone(),
                 args.iter().map(|a| a.replace_spans()).collect(),
             ),

+ 4 - 1
numbat/src/diagnostic.rs

@@ -103,10 +103,13 @@ impl ErrorDiagnostic for TypeCheckError {
                 d.with_notes(vec![format!("{self:#}")])
             }
             TypeCheckError::WrongArity {
+                callable_span,
                 callable_name: _,
                 arity: _,
                 num_args: _,
-            } => d.with_notes(vec![format!("{self:#}")]),
+            } => d.with_labels(vec![callable_span
+                .diagnostic_label(LabelStyle::Primary)
+                .with_message(format!("{self}"))]),
             TypeCheckError::TypeParameterNameClash(_) => d.with_notes(vec![format!("{self:#}")]),
             TypeCheckError::CanNotInferTypeParameters(_, _) => {
                 d.with_notes(vec![format!("{self:#}")])

+ 8 - 2
numbat/src/parser.rs

@@ -503,6 +503,7 @@ impl<'a> Parser<'a> {
             .match_any(&[TokenKind::ProcedurePrint, TokenKind::ProcedureAssertEq])
             .is_some()
         {
+            let span = self.last().unwrap().span;
             let procedure_kind = match self.last().unwrap().kind {
                 TokenKind::ProcedurePrint => ProcedureKind::Print,
                 TokenKind::ProcedureAssertEq => ProcedureKind::AssertEq,
@@ -515,7 +516,11 @@ impl<'a> Parser<'a> {
                     span: self.peek().span,
                 })
             } else {
-                Ok(Statement::ProcedureCall(procedure_kind, self.arguments()?))
+                Ok(Statement::ProcedureCall(
+                    span,
+                    procedure_kind,
+                    self.arguments()?,
+                ))
             }
         } else {
             Ok(Statement::Expression(self.expression()?))
@@ -1525,12 +1530,13 @@ mod tests {
     fn procedure_call() {
         parse_as(
             &["print(2)"],
-            Statement::ProcedureCall(ProcedureKind::Print, vec![scalar!(2.0)]),
+            Statement::ProcedureCall(Span::dummy(), ProcedureKind::Print, vec![scalar!(2.0)]),
         );
 
         parse_as(
             &["print(2, 3, 4)"],
             Statement::ProcedureCall(
+                Span::dummy(),
                 ProcedureKind::Print,
                 vec![scalar!(2.0), scalar!(3.0), scalar!(4.0)],
             ),

+ 2 - 1
numbat/src/prefix_transformer.rs

@@ -169,7 +169,8 @@ impl Transformer {
                 self.dimension_names.push(name.clone());
                 Statement::DeclareDimension(name, dexprs)
             }
-            Statement::ProcedureCall(procedure, args) => Statement::ProcedureCall(
+            Statement::ProcedureCall(span, procedure, args) => Statement::ProcedureCall(
+                span,
                 procedure,
                 args.into_iter()
                     .map(|arg| self.transform_expression(arg))

+ 10 - 7
numbat/src/typechecker.rs

@@ -50,6 +50,7 @@ pub enum TypeCheckError {
     #[error("Function or procedure '{callable_name}' called with {num_args} arguments(s), but needs {}..{}", arity.start(), arity.end())]
     // TODO: better formatting of the arity range (e.g. in case it includes just one number)
     WrongArity {
+        callable_span: Span,
         callable_name: String,
         arity: ArityRange,
         num_args: usize,
@@ -255,6 +256,7 @@ impl TypeChecker {
 
                 if !arity_range.contains(&args.len()) {
                     return Err(TypeCheckError::WrongArity {
+                        callable_span: span,
                         callable_name: function_name.clone(),
                         arity: arity_range,
                         num_args: args.len(),
@@ -596,10 +598,11 @@ impl TypeChecker {
                 }
                 typed_ast::Statement::DeclareDimension(name)
             }
-            ast::Statement::ProcedureCall(kind, args) => {
+            ast::Statement::ProcedureCall(span, kind, args) => {
                 let procedure = ffi::procedures().get(&kind).unwrap();
                 if !procedure.arity.contains(&args.len()) {
                     return Err(TypeCheckError::WrongArity {
+                        callable_span: span,
                         callable_name: procedure.name.clone(),
                         arity: procedure.arity.clone(),
                         num_args: args.len(),
@@ -916,7 +919,7 @@ mod tests {
                 fn f() = 1
                 f(1)
             "),
-            TypeCheckError::WrongArity{callable_name, arity, num_args: 1} if arity == (0..=0) && callable_name == "f"
+            TypeCheckError::WrongArity{callable_span:_, callable_name, arity, num_args: 1} if arity == (0..=0) && callable_name == "f"
         ));
 
         assert!(matches!(
@@ -924,7 +927,7 @@ mod tests {
                 fn f(x: Scalar) = x
                 f()
             "),
-            TypeCheckError::WrongArity{callable_name, arity, num_args: 0} if arity == (1..=1) && callable_name == "f"
+            TypeCheckError::WrongArity{callable_span:_, callable_name, arity, num_args: 0} if arity == (1..=1) && callable_name == "f"
         ));
 
         assert!(matches!(
@@ -932,7 +935,7 @@ mod tests {
                 fn f(x: Scalar) = x
                 f(2, 3)
             "),
-            TypeCheckError::WrongArity{callable_name, arity, num_args: 2} if arity == (1..=1) && callable_name == "f"
+            TypeCheckError::WrongArity{callable_span:_, callable_name, arity, num_args: 2} if arity == (1..=1) && callable_name == "f"
         ));
 
         assert!(matches!(
@@ -940,7 +943,7 @@ mod tests {
                 fn mean<D>(xs: D…) -> D
                 mean()
             "),
-            TypeCheckError::WrongArity{callable_name, arity, num_args: 0} if arity == (1..=usize::MAX) && callable_name == "mean"
+            TypeCheckError::WrongArity{callable_span:_, callable_name, arity, num_args: 0} if arity == (1..=usize::MAX) && callable_name == "mean"
         ));
     }
 
@@ -977,13 +980,13 @@ mod tests {
     fn arity_checks_in_procedure_calls() {
         assert!(matches!(
             get_typecheck_error("assert_eq(1)"),
-            TypeCheckError::WrongArity{callable_name, arity, num_args: 1} if arity == (2..=3) && callable_name == "assert_eq"
+            TypeCheckError::WrongArity{callable_span:_, callable_name, arity, num_args: 1} if arity == (2..=3) && callable_name == "assert_eq"
         ));
         assert_successful_typecheck("assert_eq(1,2)");
         assert_successful_typecheck("assert_eq(1,2,3)");
         assert!(matches!(
             get_typecheck_error("assert_eq(1,2,3,4)"),
-            TypeCheckError::WrongArity{callable_name, arity, num_args: 4} if arity == (2..=3) && callable_name == "assert_eq"
+            TypeCheckError::WrongArity{callable_span:_, callable_name, arity, num_args: 4} if arity == (2..=3) && callable_name == "assert_eq"
         ));
     }
 }