Browse Source

Better incompatible-dimensions errors

David Peter 2 years ago
parent
commit
13f3a39bd2

+ 6 - 0
examples/typecheck_error/incompatible_types_in_addition.nbt

@@ -0,0 +1,6 @@
+let v: Speed = 3 m/s
+let mass: Mass = 80 kg
+let height = 10 m
+let E_pot: Energy = mass * gravity * height
+
+0.5 * mass * v + E_pot

+ 1 - 0
examples/typecheck_error/incompatible_types_in_conversion.nbt

@@ -0,0 +1 @@
+0.5 * (80 kg)^2 * 3 m/s -> J

+ 0 - 0
examples/typecheck_error/function_called_with_wrong_argument_types.nbt → examples/typecheck_error/incompatible_types_in_function_argument.nbt


+ 1 - 0
examples/typecheck_error/incompatible_types_in_modulo.nbt

@@ -0,0 +1 @@
+130 cm % 1

+ 1 - 0
examples/typecheck_error/unit_declaration.nbt

@@ -0,0 +1 @@
+unit my_unit: Length = 2 fortnight

+ 1 - 0
examples/typecheck_error/variable_declaration.nbt

@@ -0,0 +1 @@
+let x: Length = 2 second

+ 3 - 0
examples/typecheck_error/wrong_return_type.nbt

@@ -0,0 +1,3 @@
+let time: Time = 5 second
+
+fn foo(distance: Length) -> Speed = time / distance

+ 101 - 61
numbat/src/ast.rs

@@ -340,29 +340,35 @@ pub enum ProcedureKind {
 #[derive(Debug, Clone, PartialEq)]
 pub enum Statement {
     Expression(Expression),
-    DeclareVariable(Span, String, Expression, Option<DimensionExpression>),
-    DeclareFunction(
-        Span,
-        /// Function name
-        String,
-        /// Introduced type parameters
-        Vec<String>,
-        /// Arguments, optionally with type annotations. The boolean argument specifies whether or not the parameter is variadic
-        Vec<(String, Option<DimensionExpression>, bool)>,
+    DeclareVariable {
+        identifier_span: Span,
+        identifier: String,
+        expr: Expression,
+        type_annotation_span: Option<Span>,
+        type_annotation: Option<DimensionExpression>,
+    },
+    DeclareFunction {
+        function_name_span: Span,
+        function_name: String,
+        type_parameters: Vec<String>,
+        /// Parameters, optionally with type annotations. The boolean argument specifies whether or not the parameter is variadic
+        parameters: Vec<(String, Option<DimensionExpression>, bool)>,
         /// Function body. If it is absent, the function is implemented via FFI
-        Option<Expression>,
+        body: Option<Expression>,
+        return_type_span: Option<Span>,
         /// Optional annotated return type
-        Option<DimensionExpression>,
-    ),
+        return_type_annotation: Option<DimensionExpression>,
+    },
     DeclareDimension(String, Vec<DimensionExpression>),
     DeclareBaseUnit(Span, String, DimensionExpression, Vec<Decorator>),
-    DeclareDerivedUnit(
-        Span,
-        String,
-        Expression,
-        Option<DimensionExpression>,
-        Vec<Decorator>,
-    ),
+    DeclareDerivedUnit {
+        identifier_span: Span,
+        identifier: String,
+        expr: Expression,
+        type_annotation_span: Option<Span>,
+        type_annotation: Option<DimensionExpression>,
+        decorators: Vec<Decorator>,
+    },
     ProcedureCall(ProcedureKind, Vec<Expression>),
     ModuleImport(Span, ModulePath),
 }
@@ -422,11 +428,17 @@ fn decorator_markup(decorators: &Vec<Decorator>) -> Markup {
 impl PrettyPrint for Statement {
     fn pretty_print(&self) -> Markup {
         match self {
-            Statement::DeclareVariable(_span, identifier, expr, dexpr) => {
+            Statement::DeclareVariable {
+                identifier_span: _,
+                identifier,
+                expr,
+                type_annotation_span: _,
+                type_annotation,
+            } => {
                 m::keyword("let")
                     + m::space()
                     + m::identifier(identifier)
-                    + dexpr
+                    + type_annotation
                         .as_ref()
                         .map(|d| m::operator(":") + m::space() + d.pretty_print())
                         .unwrap_or_default()
@@ -435,20 +447,21 @@ impl PrettyPrint for Statement {
                     + m::space()
                     + expr.pretty_print()
             }
-            Statement::DeclareFunction(
-                _span,
-                identifier,
-                type_variables,
-                parameters,
+            Statement::DeclareFunction {
+                function_name_span: _,
+                function_name,
+                type_parameters,
+                parameters: arguments,
                 body,
-                dexpr,
-            ) => {
-                let markup_type_variables = if type_variables.is_empty() {
+                return_type_span: _,
+                return_type_annotation,
+            } => {
+                let markup_type_parameters = if type_parameters.is_empty() {
                     Markup::default()
                 } else {
                     m::operator("<")
                         + Itertools::intersperse(
-                            type_variables.iter().map(m::type_identifier),
+                            type_parameters.iter().map(m::type_identifier),
                             m::operator(", "),
                         )
                         .sum()
@@ -456,7 +469,7 @@ impl PrettyPrint for Statement {
                 };
 
                 let markup_parameters = Itertools::intersperse(
-                    parameters.iter().map(|(name, dexpr, is_variadic)| {
+                    arguments.iter().map(|(name, dexpr, is_variadic)| {
                         m::identifier(name)
                             + dexpr
                                 .as_ref()
@@ -476,15 +489,15 @@ impl PrettyPrint for Statement {
                 )
                 .sum();
 
-                let markup_return_type = dexpr
+                let markup_return_type = return_type_annotation
                     .as_ref()
                     .map(|d| m::space() + m::operator("->") + m::space() + d.pretty_print())
                     .unwrap_or_default();
 
                 m::keyword("fn")
                     + m::space()
-                    + m::identifier(identifier)
-                    + markup_type_variables
+                    + m::identifier(function_name)
+                    + markup_type_parameters
                     + m::operator("(")
                     + markup_parameters
                     + m::operator(")")
@@ -520,12 +533,19 @@ impl PrettyPrint for Statement {
                     + m::space()
                     + dexpr.pretty_print()
             }
-            Statement::DeclareDerivedUnit(_span, identifier, expr, dexpr, decorators) => {
+            Statement::DeclareDerivedUnit {
+                identifier_span: _,
+                identifier,
+                expr,
+                type_annotation_span: _,
+                type_annotation,
+                decorators,
+            } => {
                 decorator_markup(decorators)
                     + m::keyword("unit")
                     + m::space()
                     + m::unit(identifier)
-                    + dexpr
+                    + type_annotation
                         .as_ref()
                         .map(|d| m::operator(":") + m::space() + d.pretty_print())
                         .unwrap_or_default()
@@ -606,22 +626,36 @@ impl ReplaceSpans for Statement {
     fn replace_spans(&self) -> Self {
         match self {
             Statement::Expression(expr) => Statement::Expression(expr.replace_spans()),
-            Statement::DeclareVariable(_, name, expr, type_) => Statement::DeclareVariable(
-                Span::dummy(),
-                name.clone(),
-                expr.replace_spans(),
-                type_.clone(),
-            ),
-            Statement::DeclareFunction(_span, name, type_params, args, body, type_) => {
-                Statement::DeclareFunction(
-                    Span::dummy(),
-                    name.clone(),
-                    type_params.clone(),
-                    args.clone(),
-                    body.clone().map(|b| b.replace_spans()),
-                    type_.clone(),
-                )
-            }
+            Statement::DeclareVariable {
+                identifier_span: _,
+                identifier,
+                expr,
+                type_annotation_span,
+                type_annotation,
+            } => Statement::DeclareVariable {
+                identifier_span: Span::dummy(),
+                identifier: identifier.clone(),
+                expr: expr.replace_spans(),
+                type_annotation_span: type_annotation_span.map(|_| Span::dummy()),
+                type_annotation: type_annotation.clone(),
+            },
+            Statement::DeclareFunction {
+                function_name_span: _,
+                function_name,
+                type_parameters,
+                parameters,
+                body,
+                return_type_span,
+                return_type_annotation,
+            } => Statement::DeclareFunction {
+                function_name_span: Span::dummy(),
+                function_name: function_name.clone(),
+                type_parameters: type_parameters.clone(),
+                parameters: parameters.clone(),
+                body: body.clone().map(|b| b.replace_spans()),
+                return_type_span: return_type_span.map(|_| Span::dummy()),
+                return_type_annotation: return_type_annotation.clone(),
+            },
             s @ Statement::DeclareDimension(_, _) => s.clone(),
             Statement::DeclareBaseUnit(_, name, type_, decorators) => Statement::DeclareBaseUnit(
                 Span::dummy(),
@@ -629,15 +663,21 @@ impl ReplaceSpans for Statement {
                 type_.clone(),
                 decorators.clone(),
             ),
-            Statement::DeclareDerivedUnit(_, name, expr, type_, decorators) => {
-                Statement::DeclareDerivedUnit(
-                    Span::dummy(),
-                    name.clone(),
-                    expr.replace_spans(),
-                    type_.clone(),
-                    decorators.clone(),
-                )
-            }
+            Statement::DeclareDerivedUnit {
+                identifier_span: _,
+                identifier,
+                expr,
+                type_annotation_span,
+                type_annotation,
+                decorators,
+            } => Statement::DeclareDerivedUnit {
+                identifier_span: Span::dummy(),
+                identifier: identifier.clone(),
+                expr: expr.replace_spans(),
+                type_annotation_span: type_annotation_span.map(|_| Span::dummy()),
+                type_annotation: type_annotation.clone(),
+                decorators: decorators.clone(),
+            },
             Statement::ProcedureCall(proc, args) => Statement::ProcedureCall(
                 proc.clone(),
                 args.iter().map(|a| a.replace_spans()).collect(),

+ 64 - 27
numbat/src/diagnostic.rs

@@ -8,11 +8,11 @@ use crate::{
 pub type Diagnostic = codespan_reporting::diagnostic::Diagnostic<usize>;
 
 pub trait ErrorDiagnostic {
-    fn diagnostic(self) -> Diagnostic;
+    fn diagnostic(&self) -> Diagnostic;
 }
 
 impl ErrorDiagnostic for ParseError {
-    fn diagnostic(self) -> Diagnostic {
+    fn diagnostic(&self) -> Diagnostic {
         Diagnostic::error()
             .with_message("while parsing")
             .with_labels(vec![self
@@ -23,7 +23,7 @@ impl ErrorDiagnostic for ParseError {
 }
 
 impl ErrorDiagnostic for ResolverError {
-    fn diagnostic(self) -> Diagnostic {
+    fn diagnostic(&self) -> Diagnostic {
         match self {
             ResolverError::UnknownModule(span, _) => Diagnostic::error()
                 .with_message("while resolving imports in")
@@ -36,7 +36,7 @@ impl ErrorDiagnostic for ResolverError {
 }
 
 impl ErrorDiagnostic for NameResolutionError {
-    fn diagnostic(self) -> Diagnostic {
+    fn diagnostic(&self) -> Diagnostic {
         match self {
             NameResolutionError::IdentifierClash {
                 conflicting_identifier: _,
@@ -51,45 +51,82 @@ impl ErrorDiagnostic for NameResolutionError {
 }
 
 impl ErrorDiagnostic for TypeCheckError {
-    fn diagnostic(self) -> Diagnostic {
-        let d = Diagnostic::error()
-            .with_message("while type checking")
-            .with_notes(vec![format!("{self:#}")]);
+    fn diagnostic(&self) -> Diagnostic {
+        let d = Diagnostic::error().with_message("while type checking");
 
         match self {
             TypeCheckError::UnknownIdentifier(span, _) => {
                 d.with_labels(vec![span.diagnostic_label(LabelStyle::Primary)])
             }
-            TypeCheckError::UnknownFunction(_) => d,
-            TypeCheckError::IncompatibleDimensions(span, _, _, _, _, _) => {
-                if let Some(span) = span {
-                    d.with_labels(vec![span.diagnostic_label(LabelStyle::Primary)])
-                } else {
-                    d
+            TypeCheckError::UnknownFunction(_) => d.with_notes(vec![format!("{self:#}")]),
+            TypeCheckError::IncompatibleDimensions {
+                operation,
+                span_operation,
+                span_actual,
+                actual_type,
+                span_expected,
+                expected_type,
+                ..
+            } => {
+                let mut labels = vec![span_operation
+                    .diagnostic_label(LabelStyle::Secondary)
+                    .with_message(format!("incompatible dimensions in {}", operation))];
+                if let Some(span_actual) = span_actual {
+                    labels.push(
+                        span_actual
+                            .diagnostic_label(LabelStyle::Primary)
+                            .with_message(format!("{actual_type}")),
+                    );
                 }
+                if let Some(span_expected) = span_expected {
+                    labels.push(
+                        span_expected
+                            .diagnostic_label(LabelStyle::Primary)
+                            .with_message(format!("{expected_type}")),
+                    );
+                }
+                d.with_labels(labels).with_notes(vec![format!("{self:#}")])
+            }
+            TypeCheckError::NonScalarExponent(span, type_) => d
+                .with_labels(vec![span
+                    .diagnostic_label(LabelStyle::Primary)
+                    .with_message(format!("{type_}"))])
+                .with_notes(vec![format!("{self:#}")]),
+            TypeCheckError::UnsupportedConstEvalExpression(_) => {
+                d.with_notes(vec![format!("{self:#}")])
+            }
+            TypeCheckError::DivisionByZeroInConstEvalExpression => {
+                d.with_notes(vec![format!("{self:#}")])
+            }
+            TypeCheckError::RegistryError(_) => d.with_notes(vec![format!("{self:#}")]),
+            TypeCheckError::IncompatibleAlternativeDimensionExpression(_) => {
+                d.with_notes(vec![format!("{self:#}")])
             }
-            TypeCheckError::NonScalarExponent(_) => d,
-            TypeCheckError::UnsupportedConstEvalExpression(_) => d,
-            TypeCheckError::DivisionByZeroInConstEvalExpression => d,
-            TypeCheckError::RegistryError(_) => d,
-            TypeCheckError::IncompatibleAlternativeDimensionExpression(_) => d,
             TypeCheckError::WrongArity {
                 callable_name: _,
                 arity: _,
                 num_args: _,
-            } => d,
-            TypeCheckError::TypeParameterNameClash(_) => d,
-            TypeCheckError::CanNotInferTypeParameters(_, _) => d,
-            TypeCheckError::MultipleUnresolvedTypeParameters => d,
-            TypeCheckError::ForeignFunctionNeedsReturnTypeAnnotation(_) => d,
-            TypeCheckError::UnknownForeignFunction(_) => d,
-            TypeCheckError::ParameterTypesCanNotBeDeduced => d,
+            } => d.with_notes(vec![format!("{self:#}")]),
+            TypeCheckError::TypeParameterNameClash(_) => d.with_notes(vec![format!("{self:#}")]),
+            TypeCheckError::CanNotInferTypeParameters(_, _) => {
+                d.with_notes(vec![format!("{self:#}")])
+            }
+            TypeCheckError::MultipleUnresolvedTypeParameters => {
+                d.with_notes(vec![format!("{self:#}")])
+            }
+            TypeCheckError::ForeignFunctionNeedsReturnTypeAnnotation(_) => {
+                d.with_notes(vec![format!("{self:#}")])
+            }
+            TypeCheckError::UnknownForeignFunction(_) => d.with_notes(vec![format!("{self:#}")]),
+            TypeCheckError::ParameterTypesCanNotBeDeduced => {
+                d.with_notes(vec![format!("{self:#}")])
+            }
         }
     }
 }
 
 impl ErrorDiagnostic for RuntimeError {
-    fn diagnostic(self) -> Diagnostic {
+    fn diagnostic(&self) -> Diagnostic {
         Diagnostic::error()
             .with_message("runtime error")
             .with_notes(vec![format!("{self:#}")])

+ 118 - 92
numbat/src/parser.rs

@@ -237,10 +237,12 @@ impl<'a> Parser<'a> {
             if let Some(identifier) = self.match_exact(TokenKind::Identifier) {
                 let identifier_span = self.last().unwrap().span;
 
-                let dexpr = if self.match_exact(TokenKind::Colon).is_some() {
-                    Some(self.dimension_expression()?)
+                let (type_annotation_span, dexpr) = if self.match_exact(TokenKind::Colon).is_some()
+                {
+                    let type_annotation = self.dimension_expression()?;
+                    (Some(self.last().unwrap().span), Some(type_annotation))
                 } else {
-                    None
+                    (None, None)
                 };
 
                 if self.match_exact(TokenKind::Equal).is_none() {
@@ -251,12 +253,13 @@ impl<'a> Parser<'a> {
                 } else {
                     let expr = self.expression()?;
 
-                    Ok(Statement::DeclareVariable(
+                    Ok(Statement::DeclareVariable {
                         identifier_span,
-                        identifier.lexeme.clone(),
+                        identifier: identifier.lexeme.clone(),
                         expr,
-                        dexpr,
-                    ))
+                        type_annotation_span,
+                        type_annotation: dexpr,
+                    })
                 }
             } else {
                 Err(ParseError {
@@ -266,12 +269,12 @@ impl<'a> Parser<'a> {
             }
         } else if self.match_exact(TokenKind::Fn).is_some() {
             if let Some(fn_name) = self.match_exact(TokenKind::Identifier) {
-                let identifier_span = self.last().unwrap().span;
+                let function_name_span = self.last().unwrap().span;
                 let mut type_parameters = vec![];
                 if self.match_exact(TokenKind::LeftAngleBracket).is_some() {
                     while self.match_exact(TokenKind::RightAngleBracket).is_none() {
-                        if let Some(param_name) = self.match_exact(TokenKind::Identifier) {
-                            type_parameters.push(param_name.lexeme.to_string());
+                        if let Some(type_parameter_name) = self.match_exact(TokenKind::Identifier) {
+                            type_parameters.push(type_parameter_name.lexeme.to_string());
 
                             if self.match_exact(TokenKind::Comma).is_none()
                                 && self.peek().kind != TokenKind::RightAngleBracket
@@ -334,12 +337,17 @@ impl<'a> Parser<'a> {
                     }
                 }
 
-                let optional_return_type_dexpr = if self.match_exact(TokenKind::Arrow).is_some() {
-                    // Parse return type
-                    Some(self.dimension_expression()?)
-                } else {
-                    None
-                };
+                let (return_type_span, return_type_annotation) =
+                    if self.match_exact(TokenKind::Arrow).is_some() {
+                        // Parse return type
+                        let return_type_annotation = self.dimension_expression()?;
+                        (
+                            Some(self.last().unwrap().span),
+                            Some(return_type_annotation),
+                        )
+                    } else {
+                        (None, None)
+                    };
 
                 let fn_is_variadic = parameters.iter().any(|p| p.2);
                 if fn_is_variadic && parameters.len() > 1 {
@@ -362,14 +370,15 @@ impl<'a> Parser<'a> {
                     });
                 }
 
-                Ok(Statement::DeclareFunction(
-                    identifier_span,
-                    fn_name.lexeme.clone(),
-                    type_parameters,
+                Ok(Statement::DeclareFunction {
+                    function_name_span,
+                    function_name: fn_name.lexeme.clone(),
+                    type_parameters: type_parameters,
                     parameters,
                     body,
-                    optional_return_type_dexpr,
-                ))
+                    return_type_span,
+                    return_type_annotation,
+                })
             } else {
                 Err(ParseError {
                     kind: ParseErrorKind::ExpectedIdentifierAfterFn,
@@ -429,10 +438,12 @@ impl<'a> Parser<'a> {
         } else if self.match_exact(TokenKind::Unit).is_some() {
             if let Some(identifier) = self.match_exact(TokenKind::Identifier) {
                 let identifier_span = self.last().unwrap().span;
-                let dexpr = if self.match_exact(TokenKind::Colon).is_some() {
-                    Some(self.dimension_expression()?)
+                let (type_annotation_span, dexpr) = if self.match_exact(TokenKind::Colon).is_some()
+                {
+                    let type_annotation = self.dimension_expression()?;
+                    (Some(self.last().unwrap().span), Some(type_annotation))
                 } else {
-                    None
+                    (None, None)
                 };
 
                 let unit_name = identifier.lexeme.clone();
@@ -442,13 +453,14 @@ impl<'a> Parser<'a> {
 
                 if self.match_exact(TokenKind::Equal).is_some() {
                     let expr = self.expression()?;
-                    Ok(Statement::DeclareDerivedUnit(
+                    Ok(Statement::DeclareDerivedUnit {
                         identifier_span,
-                        unit_name,
+                        identifier: unit_name,
                         expr,
-                        dexpr,
+                        type_annotation_span,
+                        type_annotation: dexpr,
                         decorators,
-                    ))
+                    })
                 } else if let Some(dexpr) = dexpr {
                     Ok(Statement::DeclareBaseUnit(
                         identifier_span,
@@ -1254,17 +1266,24 @@ mod tests {
     fn variable_declaration() {
         parse_as(
             &["let foo = 1", "let foo=1"],
-            Statement::DeclareVariable(Span::dummy(), "foo".into(), scalar!(1.0), None),
+            Statement::DeclareVariable {
+                identifier_span: Span::dummy(),
+                identifier: "foo".into(),
+                expr: scalar!(1.0),
+                type_annotation_span: None,
+                type_annotation: None,
+            },
         );
 
         parse_as(
             &["let x: Length = 1 * meter"],
-            Statement::DeclareVariable(
-                Span::dummy(),
-                "x".into(),
-                binop!(scalar!(1.0), Mul, identifier!("meter")),
-                Some(DimensionExpression::Dimension("Length".into())),
-            ),
+            Statement::DeclareVariable {
+                identifier_span: Span::dummy(),
+                identifier: "x".into(),
+                expr: binop!(scalar!(1.0), Mul, identifier!("meter")),
+                type_annotation_span: Some(Span::dummy()),
+                type_annotation: Some(DimensionExpression::Dimension("Length".into())),
+            },
         );
 
         should_fail_with(
@@ -1346,63 +1365,67 @@ mod tests {
     fn function_declaration() {
         parse_as(
             &["fn foo() = 1"],
-            Statement::DeclareFunction(
-                Span::dummy(),
-                "foo".into(),
-                vec![],
-                vec![],
-                Some(scalar!(1.0)),
-                None,
-            ),
+            Statement::DeclareFunction {
+                function_name_span: Span::dummy(),
+                function_name: "foo".into(),
+                type_parameters: vec![],
+                parameters: vec![],
+                body: Some(scalar!(1.0)),
+                return_type_span: None,
+                return_type_annotation: None,
+            },
         );
 
         parse_as(
             &["fn foo() -> Scalar = 1"],
-            Statement::DeclareFunction(
-                Span::dummy(),
-                "foo".into(),
-                vec![],
-                vec![],
-                Some(scalar!(1.0)),
-                Some(DimensionExpression::Dimension("Scalar".into())),
-            ),
+            Statement::DeclareFunction {
+                function_name_span: Span::dummy(),
+                function_name: "foo".into(),
+                type_parameters: vec![],
+                parameters: vec![],
+                body: Some(scalar!(1.0)),
+                return_type_span: Some(Span::dummy()),
+                return_type_annotation: Some(DimensionExpression::Dimension("Scalar".into())),
+            },
         );
 
         parse_as(
             &["fn foo(x) = 1"],
-            Statement::DeclareFunction(
-                Span::dummy(),
-                "foo".into(),
-                vec![],
-                vec![("x".into(), None, false)],
-                Some(scalar!(1.0)),
-                None,
-            ),
+            Statement::DeclareFunction {
+                function_name_span: Span::dummy(),
+                function_name: "foo".into(),
+                type_parameters: vec![],
+                parameters: vec![("x".into(), None, false)],
+                body: Some(scalar!(1.0)),
+                return_type_span: None,
+                return_type_annotation: None,
+            },
         );
 
         parse_as(
             &["fn foo(x, y, z) = 1"],
-            Statement::DeclareFunction(
-                Span::dummy(),
-                "foo".into(),
-                vec![],
-                vec![
+            Statement::DeclareFunction {
+                function_name_span: Span::dummy(),
+                function_name: "foo".into(),
+                type_parameters: vec![],
+                parameters: vec![
                     ("x".into(), None, false),
                     ("y".into(), None, false),
                     ("z".into(), None, false),
                 ],
-                Some(scalar!(1.0)),
-                None,
-            ),
+                body: Some(scalar!(1.0)),
+                return_type_span: None,
+                return_type_annotation: None,
+            },
         );
 
         parse_as(
             &["fn foo(x: Length, y: Time, z: Length^3 · Time^2) -> Scalar = 1"],
-            Statement::DeclareFunction(
-                Span::dummy(),
-                "foo".into(),
-                vec![],
-                vec![
+            Statement::DeclareFunction {
+                function_name_span: Span::dummy(),
+                function_name: "foo".into(),
+                type_parameters: vec![],
+                parameters: vec![
                     (
                         "x".into(),
                         Some(DimensionExpression::Dimension("Length".into())),
@@ -1428,41 +1451,44 @@ mod tests {
                         false,
                     ),
                 ],
-                Some(scalar!(1.0)),
-                Some(DimensionExpression::Dimension("Scalar".into())),
-            ),
+                body: Some(scalar!(1.0)),
+                return_type_span: Some(Span::dummy()),
+                return_type_annotation: Some(DimensionExpression::Dimension("Scalar".into())),
+            },
         );
 
         parse_as(
             &["fn foo<X>(x: X) = 1"],
-            Statement::DeclareFunction(
-                Span::dummy(),
-                "foo".into(),
-                vec!["X".into()],
-                vec![(
+            Statement::DeclareFunction {
+                function_name_span: Span::dummy(),
+                function_name: "foo".into(),
+                type_parameters: vec!["X".into()],
+                parameters: vec![(
                     "x".into(),
                     Some(DimensionExpression::Dimension("X".into())),
                     false,
                 )],
-                Some(scalar!(1.0)),
-                None,
-            ),
+                body: Some(scalar!(1.0)),
+                return_type_span: None,
+                return_type_annotation: None,
+            },
         );
 
         parse_as(
             &["fn foo<D>(x: D…) -> D"],
-            Statement::DeclareFunction(
-                Span::dummy(),
-                "foo".into(),
-                vec!["D".into()],
-                vec![(
+            Statement::DeclareFunction {
+                function_name_span: Span::dummy(),
+                function_name: "foo".into(),
+                type_parameters: vec!["D".into()],
+                parameters: vec![(
                     "x".into(),
                     Some(DimensionExpression::Dimension("D".into())),
                     true,
                 )],
-                None,
-                Some(DimensionExpression::Dimension("D".into())),
-            ),
+                body: None,
+                return_type_span: Some(Span::dummy()),
+                return_type_annotation: Some(DimensionExpression::Dimension("D".into())),
+            },
         );
 
         should_fail_with(

+ 54 - 23
numbat/src/prefix_transformer.rs

@@ -107,32 +107,63 @@ impl Transformer {
                 self.register_name_and_aliases(&name, &decorators, span)?;
                 Statement::DeclareBaseUnit(span, name, dexpr, decorators)
             }
-            Statement::DeclareDerivedUnit(span, name, expr, dexpr, decorators) => {
-                self.register_name_and_aliases(&name, &decorators, span)?;
-                Statement::DeclareDerivedUnit(
-                    span,
-                    name,
-                    self.transform_expression(expr),
-                    dexpr,
+            Statement::DeclareDerivedUnit {
+                identifier_span,
+                identifier,
+                expr,
+                type_annotation_span,
+                type_annotation,
+                decorators,
+            } => {
+                self.register_name_and_aliases(&identifier, &decorators, identifier_span)?;
+                Statement::DeclareDerivedUnit {
+                    identifier_span,
+                    identifier,
+                    expr: self.transform_expression(expr),
+                    type_annotation_span,
+                    type_annotation,
                     decorators,
-                )
+                }
             }
-            Statement::DeclareVariable(span, name, expr, dexpr) => {
-                self.variable_names.push(name.clone());
-                self.prefix_parser.add_other_identifier(&name, span)?;
-                Statement::DeclareVariable(span, name, self.transform_expression(expr), dexpr)
+            Statement::DeclareVariable {
+                identifier_span,
+                identifier,
+                expr,
+                type_annotation_span,
+                type_annotation,
+            } => {
+                self.variable_names.push(identifier.clone());
+                self.prefix_parser
+                    .add_other_identifier(&identifier, identifier_span)?;
+                Statement::DeclareVariable {
+                    identifier_span,
+                    identifier,
+                    expr: self.transform_expression(expr),
+                    type_annotation_span,
+                    type_annotation,
+                }
             }
-            Statement::DeclareFunction(span, name, type_params, args, body, return_type) => {
-                self.function_names.push(name.clone());
-                self.prefix_parser.add_other_identifier(&name, span)?;
-                Statement::DeclareFunction(
-                    span,
-                    name,
-                    type_params,
-                    args,
-                    body.map(|expr| self.transform_expression(expr)),
-                    return_type,
-                )
+            Statement::DeclareFunction {
+                function_name_span,
+                function_name,
+                type_parameters,
+                parameters,
+                body,
+                return_type_span,
+                return_type_annotation,
+            } => {
+                self.function_names.push(function_name.clone());
+                self.prefix_parser
+                    .add_other_identifier(&function_name, function_name_span)?;
+                Statement::DeclareFunction {
+                    function_name_span,
+                    function_name,
+                    type_parameters,
+                    parameters,
+                    body: body.map(|expr| self.transform_expression(expr)),
+                    return_type_span,
+                    return_type_annotation,
+                }
             }
             Statement::DeclareDimension(name, dexprs) => {
                 self.dimension_names.push(name.clone());

+ 7 - 6
numbat/src/resolver.rs

@@ -203,12 +203,13 @@ mod tests {
         assert_eq!(
             &program_inlined.replace_spans(),
             &[
-                Statement::DeclareVariable(
-                    Span::dummy(),
-                    "a".into(),
-                    Expression::Scalar(Span::dummy(), Number::from_f64(1.0)),
-                    None
-                ),
+                Statement::DeclareVariable {
+                    identifier_span: Span::dummy(),
+                    identifier: "a".into(),
+                    expr: Expression::Scalar(Span::dummy(), Number::from_f64(1.0)),
+                    type_annotation_span: None,
+                    type_annotation: None
+                },
                 Statement::Expression(Expression::Identifier(Span::dummy(), "a".into()))
             ]
         );

+ 3 - 3
numbat/src/span.rs

@@ -1,6 +1,6 @@
 use codespan_reporting::diagnostic::{Label, LabelStyle};
 
-#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
 pub struct SourceCodePositition {
     pub byte: usize,
     pub index: usize,
@@ -38,8 +38,8 @@ impl Span {
     pub fn extend(&self, other: &Span) -> Span {
         assert_eq!(self.code_source_index, other.code_source_index);
         Span {
-            start: self.start,
-            end: other.end,
+            start: std::cmp::min(self.start, other.start),
+            end: std::cmp::max(self.end, other.end),
             code_source_index: self.code_source_index,
         }
     }

+ 151 - 102
numbat/src/typechecker.rs

@@ -20,18 +20,20 @@ pub enum TypeCheckError {
     #[error("Unknown function '{0}'.")]
     UnknownFunction(String),
 
-    #[error("Incompatible dimensions in {1}:\n    {2}: {3}\n    {4}: {5}")]
-    IncompatibleDimensions(
-        Option<Span>,
-        String,
-        &'static str,
-        BaseRepresentation,
-        &'static str,
-        BaseRepresentation,
-    ),
-
-    #[error("Got dimension {0}, but exponent must be dimensionless.")]
-    NonScalarExponent(BaseRepresentation),
+    #[error("{expected_name}: {expected_type}\n{actual_name}: {actual_type}")]
+    IncompatibleDimensions {
+        span_operation: Span,
+        operation: String,
+        span_expected: Option<Span>,
+        expected_name: &'static str,
+        expected_type: BaseRepresentation,
+        span_actual: Option<Span>,
+        actual_name: &'static str,
+        actual_type: BaseRepresentation,
+    },
+
+    #[error("Exponents need to be dimensionless (got {1}).")]
+    NonScalarExponent(Span, BaseRepresentation),
 
     #[error("Unsupported expression in const-evaluation of exponent: {0}.")]
     UnsupportedConstEvalExpression(&'static str),
@@ -163,21 +165,38 @@ impl TypeChecker {
                 rhs,
                 span_op,
             } => {
-                let lhs = self.check_expression(*lhs)?;
-                let rhs = self.check_expression(*rhs)?;
+                let lhs_checked = self.check_expression((*lhs).clone())?;
+                let rhs_checked = self.check_expression((*rhs).clone())?;
 
                 let get_type_and_assert_equality = || {
-                    let lhs_type = lhs.get_type();
-                    let rhs_type = rhs.get_type();
+                    let lhs_type = lhs_checked.get_type();
+                    let rhs_type = rhs_checked.get_type();
                     if lhs_type != rhs_type {
-                        Err(TypeCheckError::IncompatibleDimensions(
-                            span_op,
-                            "binary operator".into(),
-                            " left hand side",
-                            lhs_type,
-                            "right hand side",
-                            rhs_type,
-                        ))
+                        Err(TypeCheckError::IncompatibleDimensions {
+                            span_operation: span_op.unwrap_or(
+                                ast::Expression::BinaryOperator {
+                                    op,
+                                    lhs: lhs.clone(),
+                                    rhs: rhs.clone(),
+                                    span_op,
+                                }
+                                .full_span(),
+                            ),
+                            operation: match op {
+                                typed_ast::BinaryOperator::Add => "addition".into(),
+                                typed_ast::BinaryOperator::Sub => "subtraction".into(),
+                                typed_ast::BinaryOperator::Mul => "multiplication".into(),
+                                typed_ast::BinaryOperator::Div => "division".into(),
+                                typed_ast::BinaryOperator::Power => "exponentiation".into(),
+                                typed_ast::BinaryOperator::ConvertTo => "unit conversion".into(),
+                            },
+                            span_expected: Some(lhs.full_span()),
+                            expected_name: " left hand side",
+                            expected_type: lhs_type,
+                            span_actual: Some(rhs.full_span()),
+                            actual_name: "right hand side",
+                            actual_type: rhs_type,
+                        })
                     } else {
                         Ok(lhs_type)
                     }
@@ -186,31 +205,43 @@ impl TypeChecker {
                 let type_ = match op {
                     typed_ast::BinaryOperator::Add => get_type_and_assert_equality()?,
                     typed_ast::BinaryOperator::Sub => get_type_and_assert_equality()?,
-                    typed_ast::BinaryOperator::Mul => lhs.get_type() * rhs.get_type(),
-                    typed_ast::BinaryOperator::Div => lhs.get_type() / rhs.get_type(),
+                    typed_ast::BinaryOperator::Mul => {
+                        lhs_checked.get_type() * rhs_checked.get_type()
+                    }
+                    typed_ast::BinaryOperator::Div => {
+                        lhs_checked.get_type() / rhs_checked.get_type()
+                    }
                     typed_ast::BinaryOperator::Power => {
-                        let exponent_type = rhs.get_type();
+                        let exponent_type = rhs_checked.get_type();
                         if exponent_type != Type::unity() {
-                            return Err(TypeCheckError::NonScalarExponent(exponent_type));
+                            return Err(TypeCheckError::NonScalarExponent(
+                                rhs.full_span(),
+                                exponent_type,
+                            ));
                         }
 
-                        let base_type = lhs.get_type();
+                        let base_type = lhs_checked.get_type();
                         if base_type == Type::unity() {
                             // Skip evaluating the exponent if the lhs is a scalar. This allows
                             // for arbitrary (decimal) exponents, if the base is a scalar.
 
                             base_type
                         } else {
-                            let exponent = evaluate_const_expr(&rhs)?;
+                            let exponent = evaluate_const_expr(&rhs_checked)?;
                             base_type.power(exponent)
                         }
                     }
                     typed_ast::BinaryOperator::ConvertTo => get_type_and_assert_equality()?,
                 };
 
-                typed_ast::Expression::BinaryOperator(op, Box::new(lhs), Box::new(rhs), type_)
+                typed_ast::Expression::BinaryOperator(
+                    op,
+                    Box::new(lhs_checked),
+                    Box::new(rhs_checked),
+                    type_,
+                )
             }
-            ast::Expression::FunctionCall(_, function_name, args) => {
+            ast::Expression::FunctionCall(span, function_name, args) => {
                 let (type_parameters, parameter_types, is_variadic, return_type) = self
                     .function_signatures
                     .get(&function_name)
@@ -231,8 +262,8 @@ impl TypeChecker {
                 }
 
                 let arguments_checked = args
-                    .into_iter()
-                    .map(|a| self.check_expression(a))
+                    .iter()
+                    .map(|a| self.check_expression(a.clone()))
                     .collect::<Result<Vec<_>>>()?;
                 let argument_types = arguments_checked.iter().map(|e| e.get_type());
 
@@ -304,18 +335,20 @@ impl TypeChecker {
                     }
 
                     if parameter_type != argument_type {
-                        return Err(TypeCheckError::IncompatibleDimensions(
-                            None, // TODO
-                            format!(
+                        return Err(TypeCheckError::IncompatibleDimensions {
+                            span_operation: span,
+                            operation: format!(
                                 "argument {num} of function call to '{name}'",
                                 num = idx + 1,
                                 name = function_name
                             ),
-                            "parameter type",
-                            parameter_type.clone(),
-                            " argument type",
-                            argument_type,
-                        ));
+                            span_expected: None, // TODO
+                            expected_name: "parameter type",
+                            expected_type: parameter_type.clone(),
+                            span_actual: Some(args[idx].full_span()),
+                            actual_name: " argument type",
+                            actual_type: argument_type,
+                        });
                     }
                 }
 
@@ -352,28 +385,37 @@ impl TypeChecker {
                 }
                 typed_ast::Statement::Expression(checked_expr)
             }
-            ast::Statement::DeclareVariable(_span, name, expr, optional_dexpr) => {
-                let expr = self.check_expression(expr)?;
-                let type_deduced = expr.get_type();
+            ast::Statement::DeclareVariable {
+                identifier_span,
+                identifier,
+                expr,
+                type_annotation_span,
+                type_annotation,
+            } => {
+                let expr_checked = self.check_expression(expr.clone())?;
+                let type_deduced = expr_checked.get_type();
 
-                if let Some(ref dexpr) = optional_dexpr {
+                if let Some(ref dexpr) = type_annotation {
                     let type_specified = self
                         .registry
                         .get_base_representation(dexpr)
                         .map_err(TypeCheckError::RegistryError)?;
                     if type_deduced != type_specified {
-                        return Err(TypeCheckError::IncompatibleDimensions(
-                            None, // TODO
-                            "variable declaration".into(),
-                            "specified dimension",
-                            type_specified,
-                            "   actual dimension",
-                            type_deduced,
-                        ));
+                        return Err(TypeCheckError::IncompatibleDimensions {
+                            span_operation: identifier_span,
+                            operation: "variable declaration".into(),
+                            span_expected: type_annotation_span,
+                            expected_name: "specified dimension",
+                            expected_type: type_specified,
+                            span_actual: Some(expr.full_span()),
+                            actual_name: "   actual dimension",
+                            actual_type: type_deduced,
+                        });
                     }
                 }
-                self.identifiers.insert(name.clone(), type_deduced.clone());
-                typed_ast::Statement::DeclareVariable(name, expr, type_deduced)
+                self.identifiers
+                    .insert(identifier.clone(), type_deduced.clone());
+                typed_ast::Statement::DeclareVariable(identifier, expr_checked, type_deduced)
             }
             ast::Statement::DeclareBaseUnit(_span, unit_name, dexpr, decorators) => {
                 let type_specified = self
@@ -386,47 +428,51 @@ impl TypeChecker {
                 }
                 typed_ast::Statement::DeclareBaseUnit(unit_name, decorators, type_specified)
             }
-            ast::Statement::DeclareDerivedUnit(
-                _span,
-                unit_name,
+            ast::Statement::DeclareDerivedUnit {
+                identifier_span,
+                identifier,
                 expr,
-                optional_dexpr,
+                type_annotation_span,
+                type_annotation,
                 decorators,
-            ) => {
+            } => {
                 // TODO: this is the *exact same code* that we have above for
                 // variable declarations => deduplicate this somehow
-                let expr = self.check_expression(expr)?;
-                let type_deduced = expr.get_type();
+                let expr_checked = self.check_expression(expr.clone())?;
+                let type_deduced = expr_checked.get_type();
 
-                if let Some(ref dexpr) = optional_dexpr {
+                if let Some(ref dexpr) = type_annotation {
                     let type_specified = self
                         .registry
                         .get_base_representation(dexpr)
                         .map_err(TypeCheckError::RegistryError)?;
                     if type_deduced != type_specified {
-                        return Err(TypeCheckError::IncompatibleDimensions(
-                            None, // TODO
-                            "derived unit declaration".into(),
-                            "specified dimension",
-                            type_specified,
-                            "   actual dimension",
-                            type_deduced,
-                        ));
+                        return Err(TypeCheckError::IncompatibleDimensions {
+                            span_operation: identifier_span,
+                            operation: "derived unit declaration".into(),
+                            span_expected: type_annotation_span,
+                            expected_name: "specified dimension",
+                            expected_type: type_specified,
+                            span_actual: Some(expr.full_span()),
+                            actual_name: "   actual dimension",
+                            actual_type: type_deduced,
+                        });
                     }
                 }
-                for (name, _) in decorator::name_and_aliases(&unit_name, &decorators) {
+                for (name, _) in decorator::name_and_aliases(&identifier, &decorators) {
                     self.identifiers.insert(name.clone(), type_deduced.clone());
                 }
-                typed_ast::Statement::DeclareDerivedUnit(unit_name, expr, decorators)
+                typed_ast::Statement::DeclareDerivedUnit(identifier, expr_checked, decorators)
             }
-            ast::Statement::DeclareFunction(
-                _span,
+            ast::Statement::DeclareFunction {
+                function_name_span,
                 function_name,
                 type_parameters,
                 parameters,
                 body,
-                optional_return_type_dexpr,
-            ) => {
+                return_type_span,
+                return_type_annotation,
+            } => {
                 let mut typechecker_fn = self.clone();
 
                 for type_parameter in &type_parameters {
@@ -458,7 +504,7 @@ impl TypeChecker {
                     is_variadic |= p_is_variadic;
                 }
 
-                let return_type_specified = optional_return_type_dexpr
+                let return_type_specified = return_type_annotation
                     .map(|ref return_type_dexpr| {
                         typechecker_fn
                             .registry
@@ -467,22 +513,25 @@ impl TypeChecker {
                     })
                     .transpose()?;
 
-                let body = body
+                let body_checked = body
+                    .clone()
                     .map(|expr| typechecker_fn.check_expression(expr))
                     .transpose()?;
 
-                let return_type = if let Some(ref expr) = body {
+                let return_type = if let Some(ref expr) = body_checked {
                     let return_type_deduced = expr.get_type();
                     if let Some(return_type_specified) = return_type_specified {
                         if return_type_deduced != return_type_specified {
-                            return Err(TypeCheckError::IncompatibleDimensions(
-                                None, // TODO
-                                "function return type".into(),
-                                "specified return type",
-                                return_type_specified,
-                                "   actual return type",
-                                return_type_deduced,
-                            ));
+                            return Err(TypeCheckError::IncompatibleDimensions {
+                                span_operation: function_name_span,
+                                operation: "function return type".into(),
+                                span_expected: return_type_span,
+                                expected_name: "specified return type",
+                                expected_type: return_type_specified,
+                                span_actual: Some(body.unwrap().full_span()),
+                                actual_name: "   actual return type",
+                                actual_type: return_type_deduced,
+                            });
                         }
                     }
                     return_type_deduced
@@ -512,7 +561,7 @@ impl TypeChecker {
                 typed_ast::Statement::DeclareFunction(
                     function_name,
                     typed_parameters,
-                    body,
+                    body_checked,
                     return_type,
                 )
             }
@@ -655,7 +704,7 @@ mod tests {
 
         assert!(matches!(
             get_typecheck_error("a + b"),
-            TypeCheckError::IncompatibleDimensions(_, _, _, t1, _, t2) if t1 == type_a() && t2 == type_b()
+            TypeCheckError::IncompatibleDimensions{expected_type, actual_type, ..} if expected_type == type_a() && actual_type == type_b()
         ));
     }
 
@@ -666,11 +715,11 @@ mod tests {
 
         assert!(matches!(
             get_typecheck_error("2^a"),
-            TypeCheckError::NonScalarExponent(t) if t == type_a()
+            TypeCheckError::NonScalarExponent(_, t) if t == type_a()
         ));
         assert!(matches!(
             get_typecheck_error("2^(c/b)"),
-            TypeCheckError::NonScalarExponent(t) if t == type_a()
+            TypeCheckError::NonScalarExponent(_, t) if t == type_a()
         ));
     }
 
@@ -685,7 +734,7 @@ mod tests {
 
         assert!(matches!(
             get_typecheck_error("a^b"),
-            TypeCheckError::NonScalarExponent(t) if t == type_b()
+            TypeCheckError::NonScalarExponent(_, t) if t == type_b()
         ));
 
         // TODO: if we add ("constexpr") constants later, it would be great to support those in exponents.
@@ -715,7 +764,7 @@ mod tests {
 
         assert!(matches!(
             get_typecheck_error("let x: A = b"),
-            TypeCheckError::IncompatibleDimensions(_, _, _, t1, _, t2) if t1 == type_a() && t2 == type_b()
+            TypeCheckError::IncompatibleDimensions{expected_type, actual_type, ..} if expected_type == type_a() && actual_type == type_b()
         ));
     }
 
@@ -726,7 +775,7 @@ mod tests {
 
         assert!(matches!(
             get_typecheck_error("unit my_c: C = a"),
-            TypeCheckError::IncompatibleDimensions(_, _, _, t1, _, t2) if t1 == type_c() && t2 == type_a()
+            TypeCheckError::IncompatibleDimensions{expected_type, actual_type, ..} if expected_type == type_c() && actual_type == type_a()
         ));
     }
 
@@ -740,13 +789,13 @@ mod tests {
 
         assert!(matches!(
             get_typecheck_error("fn f(x: A, y: B) -> C = x / y"),
-            TypeCheckError::IncompatibleDimensions(_, _, _, t1, _, t2) if t1 == type_c() && t2 == type_a() / type_b()
+            TypeCheckError::IncompatibleDimensions{expected_type, actual_type, ..} if expected_type == type_c() && actual_type == type_a() / type_b()
         ));
 
         assert!(matches!(
             get_typecheck_error("fn f(x: A) -> A = a\n\
                                  f(b)"),
-            TypeCheckError::IncompatibleDimensions(_, _, _, t1, _, t2) if t1 == type_a() && t2 == type_b()
+            TypeCheckError::IncompatibleDimensions{expected_type, actual_type, ..} if expected_type == type_a() && actual_type == type_b()
         ));
     }
 
@@ -776,9 +825,9 @@ mod tests {
 
         assert!(matches!(
             get_typecheck_error("fn f<T1, T2>(x: T1, y: T2) -> T2/T1 = x/y"),
-            TypeCheckError::IncompatibleDimensions(_, _, _, t1, _, t2)
-                if t1 == base_type("T2") / base_type("T1") &&
-                   t2 == base_type("T1") / base_type("T2")
+            TypeCheckError::IncompatibleDimensions{expected_type, actual_type, ..}
+                if expected_type == base_type("T2") / base_type("T1") &&
+                actual_type == base_type("T1") / base_type("T2")
         ));
     }
 
@@ -904,7 +953,7 @@ mod tests {
                 mean(1 a, 1 b)
             "
             ),
-            TypeCheckError::IncompatibleDimensions(..)
+            TypeCheckError::IncompatibleDimensions { .. }
         ));
     }