Browse Source

Better wrong-arity errors

David Peter 2 years ago
parent
commit
27f4fee657

+ 2 - 3
examples/typecheck_error/incompatible_types_in_function_argument.nbt

@@ -1,4 +1,3 @@
-fn speed(s: Length, t: Time) -> Speed = s/t
+fn speed(distance: Length, duration: Time) -> Speed = distance / duration
 
-speed(3 meter, 2 second) # okay
-speed(3 meter, 2 meter)  # type check error
+speed(3 meter, 2 meter)

+ 1 - 0
examples/typecheck_error/unsupported_const_eval_expr_conversion.nbt

@@ -0,0 +1 @@
+meter^(3 -> 4)

+ 1 - 0
examples/typecheck_error/unsupported_const_eval_expr_function_call.nbt

@@ -0,0 +1 @@
+meter^sqrt(4)

+ 1 - 0
examples/typecheck_error/unsupported_const_eval_expr_non_integer_exponent.nbt

@@ -0,0 +1 @@
+meter^(2^1.5)

+ 1 - 0
examples/typecheck_error/unsupported_const_eval_expr_unit.nbt

@@ -0,0 +1 @@
+meter^(second/second)

+ 3 - 0
examples/typecheck_error/unsupported_const_eval_expr_variable.nbt

@@ -0,0 +1,3 @@
+let x = 4
+
+meter^x

+ 10 - 10
numbat/src/bytecode_interpreter.rs

@@ -24,7 +24,7 @@ impl BytecodeInterpreter {
                 let index = self.vm.add_constant(Constant::Scalar(n.to_f64()));
                 self.vm.add_op1(Op::LoadConstant, index);
             }
-            Expression::Identifier(identifier, _type) => {
+            Expression::Identifier(_span, identifier, _type) => {
                 if let Some(position) = self.local_variables.iter().position(|n| n == identifier) {
                     self.vm.add_op1(Op::GetLocal, position as u16); // TODO: check overflow
                 } else {
@@ -32,7 +32,7 @@ impl BytecodeInterpreter {
                     self.vm.add_op1(Op::GetVariable, identifier_idx);
                 }
             }
-            Expression::UnitIdentifier(prefix, unit_name, _full_name, _type) => {
+            Expression::UnitIdentifier(_span, prefix, unit_name, _full_name, _type) => {
                 if let Some(index) = self
                     .unit_name_to_constant_index
                     .get(&(*prefix, unit_name.clone()))
@@ -55,7 +55,7 @@ impl BytecodeInterpreter {
                     }
                 }
             }
-            Expression::Negate(rhs, _type) => {
+            Expression::Negate(_span, rhs, _type) => {
                 self.compile_expression(rhs)?;
                 self.vm.add_op(Op::Negate);
             }
@@ -73,7 +73,7 @@ impl BytecodeInterpreter {
                 };
                 self.vm.add_op(op);
             }
-            Expression::FunctionCall(name, args, _type) => {
+            Expression::FunctionCall(_span, name, args, _type) => {
                 // Put all arguments on top of the stack
                 for arg in args {
                     self.compile_expression(arg)?;
@@ -97,13 +97,13 @@ impl BytecodeInterpreter {
         self.compile_expression(expr)?;
 
         match expr {
-            Expression::Scalar(_)
-            | Expression::Identifier(_, _)
-            | Expression::UnitIdentifier(_, _, _, _)
-            | Expression::FunctionCall(_, _, _)
-            | Expression::Negate(_, _)
+            Expression::Scalar(..)
+            | Expression::Identifier(..)
+            | Expression::UnitIdentifier(..)
+            | Expression::FunctionCall(..)
+            | Expression::Negate(..)
             | Expression::BinaryOperator(_, BinaryOperator::ConvertTo, _, _, _) => {}
-            Expression::BinaryOperator(_, _, _, _, _) => {
+            Expression::BinaryOperator(..) => {
                 self.vm.add_op(Op::FullSimplify);
             }
         }

+ 33 - 16
numbat/src/diagnostic.rs

@@ -59,7 +59,7 @@ impl ErrorDiagnostic for TypeCheckError {
             TypeCheckError::UnknownIdentifier(span, _) => d.with_labels(vec![span
                 .diagnostic_label(LabelStyle::Primary)
                 .with_message("unknown identifier")]),
-            TypeCheckError::UnknownFunction(span, _) => d.with_labels(vec![span
+            TypeCheckError::UnknownCallable(span, _) => d.with_labels(vec![span
                 .diagnostic_label(LabelStyle::Primary)
                 .with_message("unknown callable")]),
             TypeCheckError::IncompatibleDimensions {
@@ -115,23 +115,40 @@ impl ErrorDiagnostic for TypeCheckError {
             TypeCheckError::WrongArity {
                 callable_span,
                 callable_name: _,
+                callable_definition_span,
                 arity,
                 num_args,
-            } => d.with_labels(vec![callable_span
-                .diagnostic_label(LabelStyle::Primary)
-                .with_message(format!(
-                    "Function or procedure called with {num}, but takes {range}",
-                    num = if *num_args == 1 {
-                        "one argument".into()
-                    } else {
-                        format!("{num_args} arguments")
-                    },
-                    range = if arity.start() == arity.end() {
-                        format!("{}", arity.start())
-                    } else {
-                        format!("{} to {}", arity.start(), arity.end())
-                    }
-                ))]),
+            } => {
+                let mut labels = vec![callable_span
+                    .diagnostic_label(LabelStyle::Primary)
+                    .with_message(format!(
+                        "{what}was called with {num}, but takes {range}",
+                        what = if callable_definition_span.is_some() {
+                            ""
+                        } else {
+                            "procedure "
+                        },
+                        num = if *num_args == 1 {
+                            "one argument".into()
+                        } else {
+                            format!("{num_args} arguments")
+                        },
+                        range = if arity.start() == arity.end() {
+                            format!("{}", arity.start())
+                        } else {
+                            format!("{} to {}", arity.start(), arity.end())
+                        }
+                    ))];
+                if let Some(span) = callable_definition_span {
+                    labels.insert(
+                        0,
+                        span.diagnostic_label(LabelStyle::Secondary)
+                            .with_message("The function defined here …"),
+                    );
+                }
+
+                d.with_labels(labels)
+            }
             TypeCheckError::TypeParameterNameClash(_) => d.with_notes(vec![inner_error]),
             TypeCheckError::CanNotInferTypeParameters(_, _) => d.with_notes(vec![inner_error]),
             TypeCheckError::MultipleUnresolvedTypeParameters => d.with_notes(vec![inner_error]),

+ 1 - 1
numbat/src/span.rs

@@ -52,7 +52,7 @@ impl Span {
         )
     }
 
-    // TODO: make this #[cfg(test)]
+    #[cfg(test)]
     pub fn dummy() -> Span {
         Self {
             start: SourceCodePositition::start(),

+ 42 - 27
numbat/src/typechecker.rs

@@ -17,8 +17,8 @@ pub enum TypeCheckError {
     #[error("Unknown identifier '{1}'.")]
     UnknownIdentifier(Span, String),
 
-    #[error("Unknown function '{1}'.")]
-    UnknownFunction(Span, String),
+    #[error("Unknown callable '{1}'.")]
+    UnknownCallable(Span, String),
 
     #[error("{expected_name}: {expected_type}\n{actual_name}: {actual_type}")]
     IncompatibleDimensions {
@@ -51,6 +51,7 @@ pub enum TypeCheckError {
     WrongArity {
         callable_span: Span,
         callable_name: String,
+        callable_definition_span: Option<Span>,
         arity: ArityRange,
         num_args: usize,
     },
@@ -86,7 +87,7 @@ fn to_rational_exponent(exponent_f64: f64) -> Exponent {
 fn evaluate_const_expr(expr: &typed_ast::Expression) -> Result<Exponent> {
     match expr {
         typed_ast::Expression::Scalar(n) => Ok(to_rational_exponent(n.to_f64())),
-        typed_ast::Expression::Negate(ref expr, _) => Ok(-evaluate_const_expr(expr)?),
+        typed_ast::Expression::Negate(_, ref expr, _) => Ok(-evaluate_const_expr(expr)?),
         typed_ast::Expression::BinaryOperator(span, op, lhs_expr, rhs_expr, _) => {
             let lhs = evaluate_const_expr(lhs_expr)?;
             let rhs = evaluate_const_expr(rhs_expr)?;
@@ -112,18 +113,18 @@ fn evaluate_const_expr(expr: &typed_ast::Expression) -> Result<Exponent> {
                     }
                 }
                 typed_ast::BinaryOperator::ConvertTo => Err(
-                    TypeCheckError::UnsupportedConstEvalExpression(Span::dummy(), "conversion"), // TODO
+                    TypeCheckError::UnsupportedConstEvalExpression(*span, "conversion"),
                 ),
             }
         }
-        typed_ast::Expression::Identifier(_, _) => Err(
-            TypeCheckError::UnsupportedConstEvalExpression(Span::dummy(), "identifier"), // TODO
+        typed_ast::Expression::Identifier(span, _, _) => Err(
+            TypeCheckError::UnsupportedConstEvalExpression(*span, "variable"),
         ),
-        typed_ast::Expression::UnitIdentifier(_, _, _, _) => Err(
-            TypeCheckError::UnsupportedConstEvalExpression(Span::dummy(), "identifier"), // TODO
+        typed_ast::Expression::UnitIdentifier(span, _, _, _, _) => Err(
+            TypeCheckError::UnsupportedConstEvalExpression(*span, "unit identifier"),
         ),
-        typed_ast::Expression::FunctionCall(_, _, _) => Err(
-            TypeCheckError::UnsupportedConstEvalExpression(Span::dummy(), "function call"), // TODO
+        typed_ast::Expression::FunctionCall(span, _, _, _) => Err(
+            TypeCheckError::UnsupportedConstEvalExpression(*span, "function call"),
         ),
     }
 }
@@ -131,7 +132,7 @@ fn evaluate_const_expr(expr: &typed_ast::Expression) -> Result<Exponent> {
 #[derive(Clone, Default)]
 pub struct TypeChecker {
     identifiers: HashMap<String, Type>,
-    function_signatures: HashMap<String, (Vec<String>, Vec<(Span, Type)>, bool, Type)>,
+    function_signatures: HashMap<String, (Span, Vec<String>, Vec<(Span, Type)>, bool, Type)>,
     registry: DimensionRegistry,
 }
 
@@ -148,17 +149,17 @@ impl TypeChecker {
             ast::Expression::Identifier(span, name) => {
                 let type_ = self.type_for_identifier(span, &name)?.clone();
 
-                typed_ast::Expression::Identifier(name, type_)
+                typed_ast::Expression::Identifier(span, name, type_)
             }
             ast::Expression::UnitIdentifier(span, prefix, name, full_name) => {
                 let type_ = self.type_for_identifier(span, &name)?.clone();
 
-                typed_ast::Expression::UnitIdentifier(prefix, name, full_name, type_)
+                typed_ast::Expression::UnitIdentifier(span, prefix, name, full_name, type_)
             }
-            ast::Expression::Negate(_, expr) => {
+            ast::Expression::Negate(span, expr) => {
                 let checked_expr = self.check_expression(*expr)?;
                 let type_ = checked_expr.get_type();
-                typed_ast::Expression::Negate(Box::new(checked_expr), type_)
+                typed_ast::Expression::Negate(span, Box::new(checked_expr), type_)
             }
             ast::Expression::BinaryOperator {
                 op,
@@ -244,10 +245,16 @@ impl TypeChecker {
                 )
             }
             ast::Expression::FunctionCall(span, function_name, args) => {
-                let (type_parameters, parameter_types, is_variadic, return_type) = self
+                let (
+                    callable_definition_span,
+                    type_parameters,
+                    parameter_types,
+                    is_variadic,
+                    return_type,
+                ) = self
                     .function_signatures
                     .get(&function_name)
-                    .ok_or_else(|| TypeCheckError::UnknownFunction(span, function_name.clone()))?;
+                    .ok_or_else(|| TypeCheckError::UnknownCallable(span, function_name.clone()))?;
 
                 let arity_range = if *is_variadic {
                     1..=usize::MAX
@@ -259,6 +266,7 @@ impl TypeChecker {
                     return Err(TypeCheckError::WrongArity {
                         callable_span: span,
                         callable_name: function_name.clone(),
+                        callable_definition_span: Some(*callable_definition_span),
                         arity: arity_range,
                         num_args: args.len(),
                     });
@@ -373,7 +381,12 @@ impl TypeChecker {
 
                 let return_type = substitute(&substitutions, return_type);
 
-                typed_ast::Expression::FunctionCall(function_name, arguments_checked, return_type)
+                typed_ast::Expression::FunctionCall(
+                    span,
+                    function_name,
+                    arguments_checked,
+                    return_type,
+                )
             }
         })
     }
@@ -452,7 +465,7 @@ impl TypeChecker {
                     if type_deduced != type_specified {
                         return Err(TypeCheckError::IncompatibleDimensions {
                             span_operation: identifier_span,
-                            operation: "derived unit declaration".into(),
+                            operation: "unit definition".into(),
                             span_expected: type_annotation_span.unwrap(),
                             expected_name: "specified dimension",
                             expected_type: type_specified,
@@ -566,6 +579,7 @@ impl TypeChecker {
                 self.function_signatures.insert(
                     function_name.clone(),
                     (
+                        function_name_span,
                         type_parameters,
                         parameter_types,
                         is_variadic,
@@ -621,6 +635,7 @@ impl TypeChecker {
                     return Err(TypeCheckError::WrongArity {
                         callable_span: span,
                         callable_name: procedure.name.clone(),
+                        callable_definition_span: None,
                         arity: procedure.arity.clone(),
                         num_args: args.len(),
                     });
@@ -761,7 +776,7 @@ mod tests {
         assert!(matches!(
             get_typecheck_error("let x=2
                                  a^x"),
-            TypeCheckError::UnsupportedConstEvalExpression(_, desc) if desc == "identifier"
+            TypeCheckError::UnsupportedConstEvalExpression(_, desc) if desc == "variable"
         ));
 
         assert!(matches!(
@@ -914,7 +929,7 @@ mod tests {
     fn unknown_function() {
         assert!(matches!(
             get_typecheck_error("foo(2)"),
-            TypeCheckError::UnknownFunction(_, name) if name == "foo"
+            TypeCheckError::UnknownCallable(_, name) if name == "foo"
         ));
     }
 
@@ -936,7 +951,7 @@ mod tests {
                 fn f() = 1
                 f(1)
             "),
-            TypeCheckError::WrongArity{callable_span:_, callable_name, arity, num_args: 1} if arity == (0..=0) && callable_name == "f"
+            TypeCheckError::WrongArity{callable_span:_, callable_name, callable_definition_span: _, arity, num_args: 1} if arity == (0..=0) && callable_name == "f"
         ));
 
         assert!(matches!(
@@ -944,7 +959,7 @@ mod tests {
                 fn f(x: Scalar) = x
                 f()
             "),
-            TypeCheckError::WrongArity{callable_span:_, callable_name, arity, num_args: 0} if arity == (1..=1) && callable_name == "f"
+            TypeCheckError::WrongArity{callable_span:_, callable_name, callable_definition_span: _,  arity, num_args: 0} if arity == (1..=1) && callable_name == "f"
         ));
 
         assert!(matches!(
@@ -952,7 +967,7 @@ mod tests {
                 fn f(x: Scalar) = x
                 f(2, 3)
             "),
-            TypeCheckError::WrongArity{callable_span:_, callable_name, arity, num_args: 2} if arity == (1..=1) && callable_name == "f"
+            TypeCheckError::WrongArity{callable_span:_, callable_name, callable_definition_span: _,  arity, num_args: 2} if arity == (1..=1) && callable_name == "f"
         ));
 
         assert!(matches!(
@@ -960,7 +975,7 @@ mod tests {
                 fn mean<D>(xs: D…) -> D
                 mean()
             "),
-            TypeCheckError::WrongArity{callable_span:_, callable_name, arity, num_args: 0} if arity == (1..=usize::MAX) && callable_name == "mean"
+            TypeCheckError::WrongArity{callable_span:_, callable_name, callable_definition_span: _,  arity, num_args: 0} if arity == (1..=usize::MAX) && callable_name == "mean"
         ));
     }
 
@@ -997,13 +1012,13 @@ mod tests {
     fn arity_checks_in_procedure_calls() {
         assert!(matches!(
             get_typecheck_error("assert_eq(1)"),
-            TypeCheckError::WrongArity{callable_span:_, callable_name, arity, num_args: 1} if arity == (2..=3) && callable_name == "assert_eq"
+            TypeCheckError::WrongArity{callable_span:_, callable_name, callable_definition_span: _,  arity, num_args: 1} if arity == (2..=3) && callable_name == "assert_eq"
         ));
         assert_successful_typecheck("assert_eq(1,2)");
         assert_successful_typecheck("assert_eq(1,2,3)");
         assert!(matches!(
             get_typecheck_error("assert_eq(1,2,3,4)"),
-            TypeCheckError::WrongArity{callable_span:_, callable_name, arity, num_args: 4} if arity == (2..=3) && callable_name == "assert_eq"
+            TypeCheckError::WrongArity{callable_span:_, callable_name, callable_definition_span: _,  arity, num_args: 4} if arity == (2..=3) && callable_name == "assert_eq"
         ));
     }
 }

+ 8 - 8
numbat/src/typed_ast.rs

@@ -8,11 +8,11 @@ pub type Type = BaseRepresentation;
 #[derive(Debug, Clone, PartialEq)]
 pub enum Expression {
     Scalar(Number),
-    Identifier(String, Type),
-    UnitIdentifier(Prefix, String, String, Type),
-    Negate(Box<Expression>, Type),
+    Identifier(Span, String, Type),
+    UnitIdentifier(Span, Prefix, String, String, Type),
+    Negate(Span, Box<Expression>, Type),
     BinaryOperator(Span, BinaryOperator, Box<Expression>, Box<Expression>, Type),
-    FunctionCall(String, Vec<Expression>, Type),
+    FunctionCall(Span, String, Vec<Expression>, Type),
 }
 
 #[derive(Debug, Clone, PartialEq)]
@@ -35,11 +35,11 @@ impl Expression {
     pub(crate) fn get_type(&self) -> Type {
         match self {
             Expression::Scalar(_) => Type::unity(),
-            Expression::Identifier(_, type_) => type_.clone(),
-            Expression::UnitIdentifier(_, _, _, _type) => _type.clone(),
-            Expression::Negate(_, type_) => type_.clone(),
+            Expression::Identifier(_, _, type_) => type_.clone(),
+            Expression::UnitIdentifier(_, _, _, _, _type) => _type.clone(),
+            Expression::Negate(_, _, type_) => type_.clone(),
             Expression::BinaryOperator(_, _, _, _, type_) => type_.clone(),
-            Expression::FunctionCall(_, _, type_) => type_.clone(),
+            Expression::FunctionCall(_, _, _, type_) => type_.clone(),
         }
     }
 }