浏览代码

Merge pull request #136 from sharkdp/improve-conversion-errors

Improve "incompatible dimension" errors
David Peter 2 年之前
父节点
当前提交
1ce94964d4
共有 10 个文件被更改,包括 425 次插入150 次删除
  1. 1 0
      Cargo.lock
  2. 1 0
      numbat/Cargo.toml
  3. 7 0
      numbat/src/arithmetic.rs
  4. 7 4
      numbat/src/diagnostic.rs
  5. 8 0
      numbat/src/dimension.rs
  6. 55 1
      numbat/src/product.rs
  7. 22 12
      numbat/src/registry.rs
  8. 267 79
      numbat/src/typechecker.rs
  9. 14 54
      numbat/src/unit.rs
  10. 43 0
      numbat/tests/interpreter.rs

+ 1 - 0
Cargo.lock

@@ -674,6 +674,7 @@ dependencies = [
  "strsim",
  "thiserror",
  "unicode-ident",
+ "unicode-width",
 ]
 
 [[package]]

+ 1 - 0
numbat/Cargo.toml

@@ -22,6 +22,7 @@ pretty_dtoa = "0.3"
 numbat-exchange-rates = { version = "0.1.0", path = "../numbat-exchange-rates" }
 heck = { version = "0.4.1", features = ["unicode"] }
 unicode-ident = "1.0.11"
+unicode-width = "0.1.10"
 
 [dev-dependencies]
 approx = "0.5"

+ 7 - 0
numbat/src/arithmetic.rs

@@ -6,6 +6,13 @@ pub type Exponent = Rational;
 
 pub trait Power {
     fn power(self, e: Exponent) -> Self;
+
+    fn invert(self) -> Self
+    where
+        Self: Sized,
+    {
+        self.power(Exponent::from_integer(-1))
+    }
 }
 
 pub fn pretty_exponent(e: &Exponent) -> String {

+ 7 - 4
numbat/src/diagnostic.rs

@@ -1,8 +1,11 @@
 use codespan_reporting::diagnostic::LabelStyle;
 
 use crate::{
-    interpreter::RuntimeError, parser::ParseError, resolver::ResolverError,
-    typechecker::TypeCheckError, NameResolutionError,
+    interpreter::RuntimeError,
+    parser::ParseError,
+    resolver::ResolverError,
+    typechecker::{IncompatibleDimensionsError, TypeCheckError},
+    NameResolutionError,
 };
 
 pub type Diagnostic = codespan_reporting::diagnostic::Diagnostic<usize>;
@@ -81,7 +84,7 @@ impl ErrorDiagnostic for TypeCheckError {
             TypeCheckError::UnknownCallable(span, _) => d.with_labels(vec![span
                 .diagnostic_label(LabelStyle::Primary)
                 .with_message("unknown callable")]),
-            TypeCheckError::IncompatibleDimensions {
+            TypeCheckError::IncompatibleDimensions(IncompatibleDimensionsError {
                 operation,
                 span_operation,
                 span_actual,
@@ -89,7 +92,7 @@ impl ErrorDiagnostic for TypeCheckError {
                 span_expected,
                 expected_type,
                 ..
-            } => {
+            }) => {
                 let labels = vec![
                     span_operation
                         .diagnostic_label(LabelStyle::Secondary)

+ 8 - 0
numbat/src/dimension.rs

@@ -39,6 +39,14 @@ impl DimensionRegistry {
         self.registry.get_base_representation_for_name(name)
     }
 
+    pub fn get_derived_entry_names_for(
+        &self,
+        base_representation: &BaseRepresentation,
+    ) -> Vec<String> {
+        self.registry
+            .get_derived_entry_names_for(base_representation)
+    }
+
     pub fn add_base_dimension(&mut self, name: &str) -> Result<BaseRepresentation> {
         self.registry.add_base_entry(name, ())?;
         Ok(self

+ 55 - 1
numbat/src/product.rs

@@ -1,8 +1,12 @@
-use std::ops::{Div, Mul};
+use std::{
+    fmt::Display,
+    ops::{Div, Mul},
+};
 
 use crate::arithmetic::{Exponent, Power};
 use itertools::Itertools;
 use num_rational::Ratio;
+use num_traits::Signed;
 
 pub trait Canonicalize {
     type MergeKey: PartialEq;
@@ -17,6 +21,56 @@ pub struct Product<Factor, const CANONICALIZE: bool = false> {
     factors: Vec<Factor>,
 }
 
+impl<Factor: Power + Clone + Canonicalize + Ord + Display, const CANONICALIZE: bool>
+    Product<Factor, CANONICALIZE>
+{
+    pub fn as_string<GetExponent>(
+        &self,
+        get_exponent: GetExponent,
+        times_separator: &'static str,
+        over_separator: &'static str,
+    ) -> String
+    where
+        GetExponent: Fn(&Factor) -> Exponent,
+    {
+        let to_string = |fs: &[Factor]| -> String {
+            let mut result = String::new();
+            for factor in fs.iter() {
+                result.push_str(&factor.to_string());
+                result.push_str(times_separator);
+            }
+            result.trim_end_matches(times_separator).into()
+        };
+
+        let factors_positive: Vec<_> = self
+            .iter()
+            .filter(|f| get_exponent(*f).is_positive())
+            .cloned()
+            .collect();
+        let factors_negative: Vec<_> = self
+            .iter()
+            .filter(|f| !get_exponent(*f).is_positive())
+            .cloned()
+            .collect();
+
+        match (&factors_positive[..], &factors_negative[..]) {
+            (&[], &[]) => "".into(),
+            (&[], negative) => to_string(negative),
+            (positive, &[]) => to_string(positive),
+            (positive, [single_negative]) => format!(
+                "{}{over_separator}{}",
+                to_string(positive),
+                to_string(&[single_negative.clone().invert()])
+            ),
+            (positive, negative) => format!(
+                "{}{over_separator}({})",
+                to_string(positive),
+                to_string(&negative.iter().map(|f| f.clone().invert()).collect_vec())
+            ),
+        }
+    }
+}
+
 impl<Factor: Clone + Ord + Canonicalize, const CANONICALIZE: bool> Product<Factor, CANONICALIZE> {
     pub fn unity() -> Self {
         Self::from_factors([])

+ 22 - 12
numbat/src/registry.rs

@@ -1,5 +1,6 @@
 use std::{collections::HashMap, fmt::Display};
 
+use itertools::Itertools;
 use num_traits::Zero;
 use thiserror::Error;
 
@@ -28,6 +29,12 @@ pub struct BaseIndex(isize);
 #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
 pub struct BaseRepresentationFactor(pub BaseEntry, pub Exponent);
 
+impl Display for BaseRepresentationFactor {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}{}", self.0, pretty_exponent(&self.1))
+    }
+}
+
 impl Canonicalize for BaseRepresentationFactor {
     type MergeKey = BaseEntry;
 
@@ -56,20 +63,11 @@ pub type BaseRepresentation = Product<BaseRepresentationFactor, true>;
 
 impl Display for BaseRepresentation {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        let num_factors = self.iter().count();
-
-        if num_factors == 0 {
-            write!(f, "Scalar")?;
+        if self.iter().count() == 0 {
+            f.write_str("Scalar")
         } else {
-            for (n, BaseRepresentationFactor(name, exp)) in self.iter().enumerate() {
-                write!(f, "{}{}", name, pretty_exponent(exp))?;
-
-                if n != self.iter().count() - 1 {
-                    write!(f, " × ")?;
-                }
-            }
+            f.write_str(&self.as_string(|f| f.1, " × ", " / "))
         }
-        Ok(())
     }
 }
 
@@ -98,6 +96,18 @@ impl<Metadata> Registry<Metadata> {
         Ok(())
     }
 
+    pub fn get_derived_entry_names_for(
+        &self,
+        base_representation: &BaseRepresentation,
+    ) -> Vec<String> {
+        self.derived_entries
+            .iter()
+            .filter(|(_, br)| *br == base_representation)
+            .map(|(name, _)| name.clone())
+            .sorted_unstable()
+            .collect()
+    }
+
     pub fn is_base_entry(&self, name: &str) -> bool {
         self.base_entries.iter().any(|(n, _)| n == name)
     }

+ 267 - 79
numbat/src/typechecker.rs

@@ -1,6 +1,10 @@
-use std::collections::{HashMap, HashSet};
+use std::{
+    collections::{HashMap, HashSet},
+    error::Error,
+    fmt,
+};
 
-use crate::arithmetic::{Exponent, Power, Rational};
+use crate::arithmetic::{pretty_exponent, Exponent, Power, Rational};
 use crate::dimension::DimensionRegistry;
 use crate::ffi::ArityRange;
 use crate::name_resolution::LAST_RESULT_IDENTIFIERS;
@@ -10,8 +14,159 @@ use crate::typed_ast::{self, Type};
 use crate::{ast, decorator, ffi, suggestion};
 
 use ast::DimensionExpression;
+use itertools::Itertools;
 use num_traits::{CheckedAdd, CheckedDiv, CheckedMul, CheckedSub, FromPrimitive, Zero};
 use thiserror::Error;
+use unicode_width::UnicodeWidthStr;
+
+#[derive(Debug, PartialEq, Eq)]
+pub struct IncompatibleDimensionsError {
+    pub span_operation: Span,
+    pub operation: String,
+    pub span_expected: Span,
+    pub expected_name: &'static str,
+    pub expected_type: BaseRepresentation,
+    pub expected_dimensions: Vec<String>,
+    pub span_actual: Span,
+    pub actual_name: &'static str,
+    pub actual_type: BaseRepresentation,
+    pub actual_dimensions: Vec<String>,
+}
+
+fn pad(a: &str, b: &str) -> (String, String) {
+    let max_length = a.width().max(b.width());
+
+    (
+        format!("{a: <width$}", width = max_length),
+        format!("{b: <width$}", width = max_length),
+    )
+}
+
+impl fmt::Display for IncompatibleDimensionsError {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let have_common_factors = self
+            .expected_type
+            .iter()
+            .any(|f| self.actual_type.iter().map(|f| &f.0).contains(&f.0));
+
+        let (mut expected_result_string, mut actual_result_string) = if !have_common_factors
+            || (self.expected_type.iter().count() == 1 && self.actual_type.iter().count() == 1)
+        {
+            pad(
+                &self.expected_type.to_string(),
+                &self.actual_type.to_string(),
+            )
+        } else {
+            let format_factor =
+                |name: &str, exponent: &Exponent| format!(" × {name}{}", pretty_exponent(exponent));
+
+            let mut shared_factors = HashMap::<&String, (Exponent, Exponent)>::new();
+            let mut expected_factors = HashMap::<&String, Exponent>::new();
+            let mut actual_factors = HashMap::<&String, Exponent>::new();
+
+            for BaseRepresentationFactor(name, expected_exponent) in self.expected_type.iter() {
+                if let Some(BaseRepresentationFactor(_, actual_exponent)) =
+                    self.actual_type.iter().find(|f| *name == f.0)
+                {
+                    shared_factors.insert(&name, (*expected_exponent, *actual_exponent));
+                } else {
+                    expected_factors.insert(&name, *expected_exponent);
+                }
+            }
+
+            for BaseRepresentationFactor(name, exponent) in self.actual_type.iter() {
+                if !shared_factors.contains_key(&name) {
+                    actual_factors.insert(&name, *exponent);
+                }
+            }
+
+            let mut expected_result_string = String::new();
+            let mut actual_result_string = String::new();
+
+            for (name, (exp1, exp2)) in shared_factors
+                .iter()
+                .sorted_unstable_by_key(|entry| entry.0)
+            {
+                let (str1, str2) = pad(&format_factor(name, exp1), &format_factor(name, exp2));
+
+                expected_result_string.push_str(&str1);
+                actual_result_string.push_str(&str2);
+            }
+
+            let mut expected_factors_string = String::new();
+
+            for (name, exp) in expected_factors
+                .iter()
+                .sorted_unstable_by_key(|entry| entry.0)
+            {
+                expected_factors_string.push_str(&format_factor(name, exp));
+            }
+
+            let mut actual_factors_string = String::new();
+
+            for (name, exp) in actual_factors
+                .iter()
+                .sorted_unstable_by_key(|entry| entry.0)
+            {
+                actual_factors_string.push_str(&format_factor(name, exp));
+            }
+
+            expected_result_string.push_str(&format!(
+                "{expected_factors_string: <width$}",
+                width = expected_factors_string.width() + actual_factors_string.width()
+            ));
+            actual_result_string.push_str(&" ".repeat(expected_factors_string.width()));
+            actual_result_string.push_str(&actual_factors_string);
+
+            (expected_result_string, actual_result_string)
+        };
+
+        if !self.expected_dimensions.is_empty() {
+            expected_result_string
+                .push_str(&format!("    [= {}]", self.expected_dimensions.join(", ")));
+        }
+
+        if !self.actual_dimensions.is_empty() {
+            actual_result_string
+                .push_str(&format!("    [= {}]", self.actual_dimensions.join(", ")));
+        }
+
+        write!(
+            f,
+            "{}: {}",
+            self.expected_name,
+            expected_result_string.trim_start_matches(" × ").trim_end(),
+        )?;
+
+        write!(
+            f,
+            "\n{}: {}",
+            self.actual_name,
+            actual_result_string.trim_start_matches(" × ").trim_end(),
+        )?;
+
+        let mut missing_type = self.actual_type.clone() / self.expected_type.clone();
+
+        let suggestion_name =
+            if missing_type.iter().fold(Exponent::zero(), |a, b| a + b.1) >= Exponent::zero() {
+                self.expected_name
+            } else {
+                missing_type = missing_type.invert();
+                self.actual_name
+            };
+
+        write!(
+            f,
+            "\n\nSuggested fix: multiply {} by {}",
+            // Remove leading whitespace padding.
+            // TODO: don't pass in names with whitespace, pad them in programatically in this method instead.
+            suggestion_name.trim_start(),
+            missing_type,
+        )
+    }
+}
+
+impl Error for IncompatibleDimensionsError {}
 
 #[derive(Debug, Error, PartialEq, Eq)]
 pub enum TypeCheckError {
@@ -21,17 +176,8 @@ pub enum TypeCheckError {
     #[error("Unknown callable '{1}'.")]
     UnknownCallable(Span, String),
 
-    #[error("{expected_name}: {expected_type}\n{actual_name}: {actual_type}")]
-    IncompatibleDimensions {
-        span_operation: Span,
-        operation: String,
-        span_expected: Span,
-        expected_name: &'static str,
-        expected_type: BaseRepresentation,
-        span_actual: Span,
-        actual_name: &'static str,
-        actual_type: BaseRepresentation,
-    },
+    #[error(transparent)]
+    IncompatibleDimensions(IncompatibleDimensionsError),
 
     #[error("Exponents need to be dimensionless (got {1}).")]
     NonScalarExponent(Span, BaseRepresentation),
@@ -232,23 +378,33 @@ impl TypeChecker {
                             span_op: *span_op,
                         }
                         .full_span();
-                        Err(TypeCheckError::IncompatibleDimensions {
-                            span_operation: span_op.unwrap_or(full_span),
-                            operation: match op {
-                                typed_ast::BinaryOperator::Add => "addition".into(),
-                                typed_ast::BinaryOperator::Sub => "subtraction".into(),
-                                typed_ast::BinaryOperator::Mul => "multiplication".into(),
-                                typed_ast::BinaryOperator::Div => "division".into(),
-                                typed_ast::BinaryOperator::Power => "exponentiation".into(),
-                                typed_ast::BinaryOperator::ConvertTo => "unit conversion".into(),
+                        Err(TypeCheckError::IncompatibleDimensions(
+                            IncompatibleDimensionsError {
+                                span_operation: span_op.unwrap_or(full_span),
+                                operation: match op {
+                                    typed_ast::BinaryOperator::Add => "addition".into(),
+                                    typed_ast::BinaryOperator::Sub => "subtraction".into(),
+                                    typed_ast::BinaryOperator::Mul => "multiplication".into(),
+                                    typed_ast::BinaryOperator::Div => "division".into(),
+                                    typed_ast::BinaryOperator::Power => "exponentiation".into(),
+                                    typed_ast::BinaryOperator::ConvertTo => {
+                                        "unit conversion".into()
+                                    }
+                                },
+                                span_expected: lhs.full_span(),
+                                expected_name: " left hand side",
+                                expected_dimensions: self
+                                    .registry
+                                    .get_derived_entry_names_for(&lhs_type),
+                                expected_type: lhs_type,
+                                span_actual: rhs.full_span(),
+                                actual_name: "right hand side",
+                                actual_dimensions: self
+                                    .registry
+                                    .get_derived_entry_names_for(&rhs_type),
+                                actual_type: rhs_type,
                             },
-                            span_expected: lhs.full_span(),
-                            expected_name: " left hand side",
-                            expected_type: lhs_type,
-                            span_actual: rhs.full_span(),
-                            actual_name: "right hand side",
-                            actual_type: rhs_type,
-                        })
+                        ))
                     } else {
                         Ok(lhs_type)
                     }
@@ -401,20 +557,28 @@ impl TypeChecker {
                     }
 
                     if parameter_type != argument_type {
-                        return Err(TypeCheckError::IncompatibleDimensions {
-                            span_operation: *span,
-                            operation: format!(
-                                "argument {num} of function call to '{name}'",
-                                num = idx + 1,
-                                name = function_name
-                            ),
-                            span_expected: parameter_types[idx].0,
-                            expected_name: "parameter type",
-                            expected_type: parameter_type.clone(),
-                            span_actual: args[idx].full_span(),
-                            actual_name: " argument type",
-                            actual_type: argument_type,
-                        });
+                        return Err(TypeCheckError::IncompatibleDimensions(
+                            IncompatibleDimensionsError {
+                                span_operation: *span,
+                                operation: format!(
+                                    "argument {num} of function call to '{name}'",
+                                    num = idx + 1,
+                                    name = function_name
+                                ),
+                                span_expected: parameter_types[idx].0,
+                                expected_name: "parameter type",
+                                expected_dimensions: self
+                                    .registry
+                                    .get_derived_entry_names_for(&parameter_type),
+                                expected_type: parameter_type,
+                                span_actual: args[idx].full_span(),
+                                actual_name: " argument type",
+                                actual_dimensions: self
+                                    .registry
+                                    .get_derived_entry_names_for(&argument_type),
+                                actual_type: argument_type,
+                            },
+                        ));
                     }
                 }
 
@@ -478,16 +642,24 @@ impl TypeChecker {
                         .get_base_representation(dexpr)
                         .map_err(TypeCheckError::RegistryError)?;
                     if type_deduced != type_specified {
-                        return Err(TypeCheckError::IncompatibleDimensions {
-                            span_operation: *identifier_span,
-                            operation: "variable definition".into(),
-                            span_expected: dexpr.full_span(),
-                            expected_name: "specified dimension",
-                            expected_type: type_specified,
-                            span_actual: expr.full_span(),
-                            actual_name: "   actual dimension",
-                            actual_type: type_deduced,
-                        });
+                        return Err(TypeCheckError::IncompatibleDimensions(
+                            IncompatibleDimensionsError {
+                                span_operation: *identifier_span,
+                                operation: "variable definition".into(),
+                                span_expected: dexpr.full_span(),
+                                expected_name: "specified dimension",
+                                expected_dimensions: self
+                                    .registry
+                                    .get_derived_entry_names_for(&type_specified),
+                                expected_type: type_specified,
+                                span_actual: expr.full_span(),
+                                actual_name: "   actual dimension",
+                                actual_dimensions: self
+                                    .registry
+                                    .get_derived_entry_names_for(&type_deduced),
+                                actual_type: type_deduced,
+                            },
+                        ));
                     }
                 }
                 self.identifiers
@@ -537,16 +709,24 @@ impl TypeChecker {
                         .get_base_representation(dexpr)
                         .map_err(TypeCheckError::RegistryError)?;
                     if type_deduced != type_specified {
-                        return Err(TypeCheckError::IncompatibleDimensions {
-                            span_operation: *identifier_span,
-                            operation: "unit definition".into(),
-                            span_expected: type_annotation_span.unwrap(),
-                            expected_name: "specified dimension",
-                            expected_type: type_specified,
-                            span_actual: expr.full_span(),
-                            actual_name: "   actual dimension",
-                            actual_type: type_deduced,
-                        });
+                        return Err(TypeCheckError::IncompatibleDimensions(
+                            IncompatibleDimensionsError {
+                                span_operation: *identifier_span,
+                                operation: "unit definition".into(),
+                                span_expected: type_annotation_span.unwrap(),
+                                expected_name: "specified dimension",
+                                expected_dimensions: self
+                                    .registry
+                                    .get_derived_entry_names_for(&type_specified),
+                                expected_type: type_specified,
+                                span_actual: expr.full_span(),
+                                actual_name: "   actual dimension",
+                                actual_dimensions: self
+                                    .registry
+                                    .get_derived_entry_names_for(&type_deduced),
+                                actual_type: type_deduced,
+                            },
+                        ));
                     }
                 }
                 for (name, _) in decorator::name_and_aliases(&identifier, &decorators) {
@@ -650,16 +830,24 @@ impl TypeChecker {
                     let return_type_deduced = expr.get_type();
                     if let Some(return_type_specified) = return_type_specified {
                         if return_type_deduced != return_type_specified {
-                            return Err(TypeCheckError::IncompatibleDimensions {
-                                span_operation: *function_name_span,
-                                operation: "function return type".into(),
-                                span_expected: return_type_span.unwrap(),
-                                expected_name: "specified return type",
-                                expected_type: return_type_specified,
-                                span_actual: body.as_ref().map(|b| b.full_span()).unwrap(),
-                                actual_name: "   actual return type",
-                                actual_type: return_type_deduced,
-                            });
+                            return Err(TypeCheckError::IncompatibleDimensions(
+                                IncompatibleDimensionsError {
+                                    span_operation: *function_name_span,
+                                    operation: "function return type".into(),
+                                    span_expected: return_type_span.unwrap(),
+                                    expected_name: "specified return type",
+                                    expected_dimensions: self
+                                        .registry
+                                        .get_derived_entry_names_for(&return_type_specified),
+                                    expected_type: return_type_specified,
+                                    span_actual: body.as_ref().map(|b| b.full_span()).unwrap(),
+                                    actual_name: "   actual return type",
+                                    actual_dimensions: self
+                                        .registry
+                                        .get_derived_entry_names_for(&return_type_deduced),
+                                    actual_type: return_type_deduced,
+                                },
+                            ));
                         }
                     }
                     return_type_deduced
@@ -846,7 +1034,7 @@ mod tests {
 
         assert!(matches!(
             get_typecheck_error("a + b"),
-            TypeCheckError::IncompatibleDimensions{expected_type, actual_type, ..} if expected_type == type_a() && actual_type == type_b()
+            TypeCheckError::IncompatibleDimensions(IncompatibleDimensionsError {expected_type, actual_type, ..}) if expected_type == type_a() && actual_type == type_b()
         ));
     }
 
@@ -906,7 +1094,7 @@ mod tests {
 
         assert!(matches!(
             get_typecheck_error("let x: A = b"),
-            TypeCheckError::IncompatibleDimensions{expected_type, actual_type, ..} if expected_type == type_a() && actual_type == type_b()
+            TypeCheckError::IncompatibleDimensions(IncompatibleDimensionsError {expected_type, actual_type, ..}) if expected_type == type_a() && actual_type == type_b()
         ));
     }
 
@@ -917,7 +1105,7 @@ mod tests {
 
         assert!(matches!(
             get_typecheck_error("unit my_c: C = a"),
-            TypeCheckError::IncompatibleDimensions{expected_type, actual_type, ..} if expected_type == type_c() && actual_type == type_a()
+            TypeCheckError::IncompatibleDimensions(IncompatibleDimensionsError {expected_type, actual_type, ..}) if expected_type == type_c() && actual_type == type_a()
         ));
     }
 
@@ -931,13 +1119,13 @@ mod tests {
 
         assert!(matches!(
             get_typecheck_error("fn f(x: A, y: B) -> C = x / y"),
-            TypeCheckError::IncompatibleDimensions{expected_type, actual_type, ..} if expected_type == type_c() && actual_type == type_a() / type_b()
+            TypeCheckError::IncompatibleDimensions(IncompatibleDimensionsError {expected_type, actual_type, ..}) if expected_type == type_c() && actual_type == type_a() / type_b()
         ));
 
         assert!(matches!(
             get_typecheck_error("fn f(x: A) -> A = a\n\
                                  f(b)"),
-            TypeCheckError::IncompatibleDimensions{expected_type, actual_type, ..} if expected_type == type_a() && actual_type == type_b()
+            TypeCheckError::IncompatibleDimensions(IncompatibleDimensionsError {expected_type, actual_type, ..}) if expected_type == type_a() && actual_type == type_b()
         ));
     }
 
@@ -967,7 +1155,7 @@ mod tests {
 
         assert!(matches!(
             get_typecheck_error("fn f<T1, T2>(x: T1, y: T2) -> T2/T1 = x/y"),
-            TypeCheckError::IncompatibleDimensions{expected_type, actual_type, ..}
+            TypeCheckError::IncompatibleDimensions(IncompatibleDimensionsError {expected_type, actual_type, ..})
                 if expected_type == base_type("T2") / base_type("T1") &&
                 actual_type == base_type("T1") / base_type("T2")
         ));

+ 14 - 54
numbat/src/unit.rs

@@ -1,6 +1,6 @@
 use std::fmt::Display;
 
-use num_traits::{Signed, ToPrimitive, Zero};
+use num_traits::{ToPrimitive, Zero};
 
 use crate::{
     arithmetic::{pretty_exponent, Exponent, Power, Rational},
@@ -144,6 +144,18 @@ impl Power for UnitFactor {
     }
 }
 
+impl Display for UnitFactor {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(
+            f,
+            "{}{}{}",
+            self.prefix.as_string_short(),
+            self.unit_id.canonical_name,
+            pretty_exponent(&self.exponent)
+        )
+    }
+}
+
 pub type Unit = Product<UnitFactor, false>;
 
 impl Unit {
@@ -318,59 +330,7 @@ impl Unit {
 
 impl Display for Unit {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        let to_string = |fs: &[UnitFactor]| -> String {
-            let mut result = String::new();
-            for &UnitFactor {
-                prefix,
-                unit_id: ref base_unit,
-                exponent,
-            } in fs.iter()
-            {
-                result.push_str(&prefix.as_string_short());
-                result.push_str(&base_unit.canonical_name);
-                result.push_str(&pretty_exponent(&exponent));
-                result.push('·');
-            }
-            result.trim_end_matches('·').into()
-        };
-
-        let flip_exponents = |fs: &[UnitFactor]| -> Vec<UnitFactor> {
-            fs.iter()
-                .map(|f| UnitFactor {
-                    exponent: -f.exponent,
-                    ..f.clone()
-                })
-                .collect()
-        };
-
-        let factors_positive: Vec<_> = self
-            .iter()
-            .filter(|f| f.exponent.is_positive())
-            .cloned()
-            .collect();
-        let factors_negative: Vec<_> = self
-            .iter()
-            .filter(|f| !f.exponent.is_positive())
-            .cloned()
-            .collect();
-
-        let result: String = match (&factors_positive[..], &factors_negative[..]) {
-            (&[], &[]) => "".into(),
-            (&[], negative) => to_string(negative),
-            (positive, &[]) => to_string(positive),
-            (positive, [single_negative]) => format!(
-                "{}/{}",
-                to_string(positive),
-                to_string(&flip_exponents(&[single_negative.clone()]))
-            ),
-            (positive, negative) => format!(
-                "{}/({})",
-                to_string(positive),
-                to_string(&flip_exponents(negative))
-            ),
-        };
-
-        write!(f, "{}", result)
+        f.write_str(&self.as_string(|f| f.exponent, "·", "/"))
     }
 }
 

+ 43 - 0
numbat/tests/interpreter.rs

@@ -33,6 +33,15 @@ fn expect_failure(code: &str, msg_part: &str) {
     }
 }
 
+fn expect_exact_failure(code: &str, expected: &str) {
+    let mut ctx = get_test_context();
+    if let Err(e) = ctx.interpret(code, CodeSource::Text) {
+        assert_eq!(e.to_string(), expected);
+    } else {
+        panic!();
+    }
+}
+
 #[test]
 fn test_factorial() {
     expect_output("0!", "1");
@@ -159,6 +168,40 @@ fn test_math() {
     )
 }
 
+#[test]
+fn test_incompatible_dimension_errors() {
+    expect_exact_failure(
+        "kg m / s^2 + kg m^2",
+        " left hand side: Length  × Mass × Time⁻²    [= Force]\n\
+         right hand side: Length² × Mass             [= MomentOfInertia]\n\n\
+         Suggested fix: multiply left hand side by Length × Time²",
+    );
+    expect_exact_failure(
+        "1 + m",
+        " left hand side: Scalar    [= Angle, Scalar, SolidAngle]\n\
+         right hand side: Length\n\n\
+         Suggested fix: multiply left hand side by Length",
+    );
+    expect_exact_failure(
+        "m / s + K A",
+        " left hand side: Length / Time            [= Speed]\n\
+         right hand side: Current × Temperature\n\n\
+         Suggested fix: multiply left hand side by Current × Temperature × Time / Length",
+    );
+    expect_exact_failure(
+        "m + 1 / m",
+        " left hand side: Length\n\
+         right hand side: Length⁻¹\n\n\
+         Suggested fix: multiply right hand side by Length²",
+    );
+    expect_exact_failure(
+        "kW -> J",
+        " left hand side: Length² × Mass × Time⁻³    [= Power]\n\
+         right hand side: Length² × Mass × Time⁻²    [= Energy, Torque]\n\n\
+         Suggested fix: multiply left hand side by Time",
+    );
+}
+
 #[test]
 fn test_temperature_conversions() {
     expect_output("from_celsius(11.5)", "284.65 K");