فهرست منبع

Add span to scalars

David Peter 2 سال پیش
والد
کامیت
e80ba9effe
6فایلهای تغییر یافته به همراه65 افزوده شده و 40 حذف شده
  1. 15 15
      numbat/src/ast.rs
  2. 7 1
      numbat/src/diagnostic.rs
  3. 26 13
      numbat/src/parser.rs
  4. 1 1
      numbat/src/prefix_transformer.rs
  5. 1 1
      numbat/src/resolver.rs
  6. 15 9
      numbat/src/typechecker.rs

+ 15 - 15
numbat/src/ast.rs

@@ -33,7 +33,7 @@ impl PrettyPrint for BinaryOperator {
 
 #[derive(Debug, Clone, PartialEq)]
 pub enum Expression {
-    Scalar(Number),
+    Scalar(Span, Number),
     Identifier(Span, String),
     UnitIdentifier(Span, Prefix, String, String),
     Negate(Box<Expression>),
@@ -49,7 +49,7 @@ pub enum Expression {
 #[cfg(test)]
 macro_rules! scalar {
     ( $num:expr ) => {{
-        Expression::Scalar(Number::from_f64($num))
+        Expression::Scalar(Span::dummy(), Number::from_f64($num))
     }};
 }
 
@@ -94,11 +94,11 @@ fn pretty_scalar(Number(n): Number) -> Markup {
 
 fn with_parens(expr: &Expression) -> Markup {
     match expr {
-        Expression::Scalar(_)
-        | Expression::Identifier(_, _)
-        | Expression::UnitIdentifier(_, _, _, _)
-        | Expression::FunctionCall(_, _) => expr.pretty_print(),
-        Expression::Negate(_) | Expression::BinaryOperator { .. } => {
+        Expression::Scalar(..)
+        | Expression::Identifier(..)
+        | Expression::UnitIdentifier(..)
+        | Expression::FunctionCall(..) => expr.pretty_print(),
+        Expression::Negate(..) | Expression::BinaryOperator { .. } => {
             m::operator("(") + expr.pretty_print() + m::operator(")")
         }
     }
@@ -112,8 +112,8 @@ fn with_parens_liberal(expr: &Expression) -> Markup {
             lhs,
             rhs,
             span_op: _,
-        } if matches!(**lhs, Expression::Scalar(_))
-            && matches!(**rhs, Expression::UnitIdentifier(_, _, _, _)) =>
+        } if matches!(**lhs, Expression::Scalar(..))
+            && matches!(**rhs, Expression::UnitIdentifier(..)) =>
         {
             expr.pretty_print()
         }
@@ -128,13 +128,13 @@ fn pretty_print_binop(op: &BinaryOperator, lhs: &Expression, rhs: &Expression) -
             lhs.pretty_print() + op.pretty_print() + rhs.pretty_print()
         }
         BinaryOperator::Mul => match (lhs, rhs) {
-            (Expression::Scalar(s), Expression::UnitIdentifier(_, prefix, _name, full_name)) => {
+            (Expression::Scalar(_, s), Expression::UnitIdentifier(_, prefix, _name, full_name)) => {
                 // Fuse multiplication of a scalar and a unit to a quantity
                 pretty_scalar(*s)
                     + m::space()
                     + m::unit(format!("{}{}", prefix.as_string_long(), full_name))
             }
-            (Expression::Scalar(s), Expression::Identifier(_, name)) => {
+            (Expression::Scalar(_, s), Expression::Identifier(_, name)) => {
                 // Fuse multiplication of a scalar and identifier
                 pretty_scalar(*s) + m::space() + m::identifier(name)
             }
@@ -235,10 +235,10 @@ fn pretty_print_binop(op: &BinaryOperator, lhs: &Expression, rhs: &Expression) -
 
             add_parens_if_needed(lhs) + op.pretty_print() + add_parens_if_needed(rhs)
         }
-        BinaryOperator::Power if matches!(rhs, Expression::Scalar(n) if n.to_f64() == 2.0) => {
+        BinaryOperator::Power if matches!(rhs, Expression::Scalar(_, n) if n.to_f64() == 2.0) => {
             with_parens(lhs) + m::operator("²")
         }
-        BinaryOperator::Power if matches!(rhs, Expression::Scalar(n) if n.to_f64() == 3.0) => {
+        BinaryOperator::Power if matches!(rhs, Expression::Scalar(_, n) if n.to_f64() == 3.0) => {
             with_parens(lhs) + m::operator("³")
         }
         _ => with_parens(lhs) + op.pretty_print() + with_parens(rhs),
@@ -250,7 +250,7 @@ impl PrettyPrint for Expression {
         use Expression::*;
 
         match self {
-            Scalar(n) => pretty_scalar(*n),
+            Scalar(_, n) => pretty_scalar(*n),
             Identifier(_, name) => m::identifier(name),
             UnitIdentifier(_, prefix, _name, full_name) => {
                 m::unit(format!("{}{}", prefix.as_string_long(), full_name))
@@ -547,7 +547,7 @@ pub trait ReplaceSpans {
 impl ReplaceSpans for Expression {
     fn replace_spans(&self) -> Self {
         match self {
-            e @ Expression::Scalar(_) => e.clone(),
+            Expression::Scalar(_, name) => Expression::Scalar(Span::dummy(), name.clone()),
             Expression::Identifier(_, name) => Expression::Identifier(Span::dummy(), name.clone()),
             Expression::UnitIdentifier(_, prefix, name, full_name) => Expression::UnitIdentifier(
                 Span::dummy(),

+ 7 - 1
numbat/src/diagnostic.rs

@@ -61,7 +61,13 @@ impl ErrorDiagnostic for TypeCheckError {
                 d.with_labels(vec![span.diagnostic_label(LabelStyle::Primary)])
             }
             TypeCheckError::UnknownFunction(_) => d,
-            TypeCheckError::IncompatibleDimensions(_, _, _, _, _) => d,
+            TypeCheckError::IncompatibleDimensions(span, _, _, _, _, _) => {
+                if let Some(span) = span {
+                    d.with_labels(vec![span.diagnostic_label(LabelStyle::Primary)])
+                } else {
+                    d
+                }
+            }
             TypeCheckError::NonScalarExponent(_) => d,
             TypeCheckError::UnsupportedConstEvalExpression(_) => d,
             TypeCheckError::DivisionByZeroInConstEvalExpression => d,

+ 26 - 13
numbat/src/parser.rs

@@ -699,7 +699,10 @@ impl<'a> Parser<'a> {
             expr = Expression::BinaryOperator {
                 op: BinaryOperator::Power,
                 lhs: Box::new(expr),
-                rhs: Box::new(Expression::Scalar(Number::from_f64(exp as f64))),
+                rhs: Box::new(Expression::Scalar(
+                    exponent.span,
+                    Number::from_f64(exp as f64),
+                )),
                 span_op: None,
             };
         }
@@ -743,21 +746,31 @@ impl<'a> Parser<'a> {
         // This function needs to be kept in sync with `next_token_could_start_primary` below.
 
         if let Some(num) = self.match_exact(TokenKind::Number) {
-            Ok(Expression::Scalar(Number::from_f64(
-                num.lexeme.parse::<f64>().unwrap(),
-            )))
+            Ok(Expression::Scalar(
+                self.last().unwrap().span,
+                Number::from_f64(num.lexeme.parse::<f64>().unwrap()),
+            ))
         } else if let Some(hex_int) = self.match_exact(TokenKind::IntegerWithBase(16)) {
-            Ok(Expression::Scalar(Number::from_f64(
-                i128::from_str_radix(&hex_int.lexeme[2..], 16).unwrap() as f64, // TODO: i128 limits our precision here
-            )))
+            Ok(Expression::Scalar(
+                self.last().unwrap().span,
+                Number::from_f64(
+                    i128::from_str_radix(&hex_int.lexeme[2..], 16).unwrap() as f64, // TODO: i128 limits our precision here
+                ),
+            ))
         } else if let Some(oct_int) = self.match_exact(TokenKind::IntegerWithBase(8)) {
-            Ok(Expression::Scalar(Number::from_f64(
-                i128::from_str_radix(&oct_int.lexeme[2..], 8).unwrap() as f64, // TODO: i128 limits our precision here
-            )))
+            Ok(Expression::Scalar(
+                self.last().unwrap().span,
+                Number::from_f64(
+                    i128::from_str_radix(&oct_int.lexeme[2..], 8).unwrap() as f64, // TODO: i128 limits our precision here
+                ),
+            ))
         } else if let Some(bin_int) = self.match_exact(TokenKind::IntegerWithBase(2)) {
-            Ok(Expression::Scalar(Number::from_f64(
-                i128::from_str_radix(&bin_int.lexeme[2..], 2).unwrap() as f64, // TODO: i128 limits our precision here
-            )))
+            Ok(Expression::Scalar(
+                self.last().unwrap().span,
+                Number::from_f64(
+                    i128::from_str_radix(&bin_int.lexeme[2..], 2).unwrap() as f64, // TODO: i128 limits our precision here
+                ),
+            ))
         } else if let Some(identifier) = self.match_exact(TokenKind::Identifier) {
             let span = self.last().unwrap().span;
             Ok(Expression::Identifier(span, identifier.lexeme.clone()))

+ 1 - 1
numbat/src/prefix_transformer.rs

@@ -32,7 +32,7 @@ impl Transformer {
 
     fn transform_expression(&self, expression: Expression) -> Expression {
         match expression {
-            expr @ Expression::Scalar(_) => expr,
+            expr @ Expression::Scalar(..) => expr,
             Expression::Identifier(span, identifier) => {
                 if let PrefixParserResult::UnitIdentifier(prefix, unit_name, full_name) =
                     self.prefix_parser.parse(&identifier)

+ 1 - 1
numbat/src/resolver.rs

@@ -206,7 +206,7 @@ mod tests {
                 Statement::DeclareVariable(
                     Span::dummy(),
                     "a".into(),
-                    Expression::Scalar(Number::from_f64(1.0)),
+                    Expression::Scalar(Span::dummy(), Number::from_f64(1.0)),
                     None
                 ),
                 Statement::Expression(Expression::Identifier(Span::dummy(), "a".into()))

+ 15 - 9
numbat/src/typechecker.rs

@@ -20,8 +20,9 @@ pub enum TypeCheckError {
     #[error("Unknown function '{0}'.")]
     UnknownFunction(String),
 
-    #[error("Incompatible dimensions in {0}:\n    {1}: {2}\n    {3}: {4}")]
+    #[error("Incompatible dimensions in {1}:\n    {2}: {3}\n    {4}: {5}")]
     IncompatibleDimensions(
+        Option<Span>,
         String,
         &'static str,
         BaseRepresentation,
@@ -140,7 +141,7 @@ impl TypeChecker {
 
     pub(crate) fn check_expression(&self, ast: ast::Expression) -> Result<typed_ast::Expression> {
         Ok(match ast {
-            ast::Expression::Scalar(n) => typed_ast::Expression::Scalar(n),
+            ast::Expression::Scalar(_, n) => typed_ast::Expression::Scalar(n),
             ast::Expression::Identifier(span, name) => {
                 let type_ = self.type_for_identifier(span, &name)?.clone();
 
@@ -170,6 +171,7 @@ impl TypeChecker {
                     let rhs_type = rhs.get_type();
                     if lhs_type != rhs_type {
                         Err(TypeCheckError::IncompatibleDimensions(
+                            span_op,
                             "binary operator".into(),
                             " left hand side",
                             lhs_type,
@@ -303,6 +305,7 @@ impl TypeChecker {
 
                     if parameter_type != argument_type {
                         return Err(TypeCheckError::IncompatibleDimensions(
+                            None, // TODO
                             format!(
                                 "argument {num} of function call to '{name}'",
                                 num = idx + 1,
@@ -360,6 +363,7 @@ impl TypeChecker {
                         .map_err(TypeCheckError::RegistryError)?;
                     if type_deduced != type_specified {
                         return Err(TypeCheckError::IncompatibleDimensions(
+                            None, // TODO
                             "variable declaration".into(),
                             "specified dimension",
                             type_specified,
@@ -401,6 +405,7 @@ impl TypeChecker {
                         .map_err(TypeCheckError::RegistryError)?;
                     if type_deduced != type_specified {
                         return Err(TypeCheckError::IncompatibleDimensions(
+                            None, // TODO
                             "derived unit declaration".into(),
                             "specified dimension",
                             type_specified,
@@ -471,6 +476,7 @@ impl TypeChecker {
                     if let Some(return_type_specified) = return_type_specified {
                         if return_type_deduced != return_type_specified {
                             return Err(TypeCheckError::IncompatibleDimensions(
+                                None, // TODO
                                 "function return type".into(),
                                 "specified return type",
                                 return_type_specified,
@@ -649,7 +655,7 @@ mod tests {
 
         assert!(matches!(
             get_typecheck_error("a + b"),
-            TypeCheckError::IncompatibleDimensions(_, _, t1, _, t2) if t1 == type_a() && t2 == type_b()
+            TypeCheckError::IncompatibleDimensions(_, _, _, t1, _, t2) if t1 == type_a() && t2 == type_b()
         ));
     }
 
@@ -709,7 +715,7 @@ mod tests {
 
         assert!(matches!(
             get_typecheck_error("let x: A = b"),
-            TypeCheckError::IncompatibleDimensions(_, _, t1, _, t2) if t1 == type_a() && t2 == type_b()
+            TypeCheckError::IncompatibleDimensions(_, _, _, t1, _, t2) if t1 == type_a() && t2 == type_b()
         ));
     }
 
@@ -720,7 +726,7 @@ mod tests {
 
         assert!(matches!(
             get_typecheck_error("unit my_c: C = a"),
-            TypeCheckError::IncompatibleDimensions(_, _, t1, _, t2) if t1 == type_c() && t2 == type_a()
+            TypeCheckError::IncompatibleDimensions(_, _, _, t1, _, t2) if t1 == type_c() && t2 == type_a()
         ));
     }
 
@@ -734,13 +740,13 @@ mod tests {
 
         assert!(matches!(
             get_typecheck_error("fn f(x: A, y: B) -> C = x / y"),
-            TypeCheckError::IncompatibleDimensions(_, _, t1, _, t2) if t1 == type_c() && t2 == type_a() / type_b()
+            TypeCheckError::IncompatibleDimensions(_, _, _, t1, _, t2) if t1 == type_c() && t2 == type_a() / type_b()
         ));
 
         assert!(matches!(
             get_typecheck_error("fn f(x: A) -> A = a\n\
                                  f(b)"),
-            TypeCheckError::IncompatibleDimensions(_, _, t1, _, t2) if t1 == type_a() && t2 == type_b()
+            TypeCheckError::IncompatibleDimensions(_, _, _, t1, _, t2) if t1 == type_a() && t2 == type_b()
         ));
     }
 
@@ -770,7 +776,7 @@ mod tests {
 
         assert!(matches!(
             get_typecheck_error("fn f<T1, T2>(x: T1, y: T2) -> T2/T1 = x/y"),
-            TypeCheckError::IncompatibleDimensions(_, _, t1, _, t2)
+            TypeCheckError::IncompatibleDimensions(_, _, _, t1, _, t2)
                 if t1 == base_type("T2") / base_type("T1") &&
                    t2 == base_type("T1") / base_type("T2")
         ));
@@ -898,7 +904,7 @@ mod tests {
                 mean(1 a, 1 b)
             "
             ),
-            TypeCheckError::IncompatibleDimensions(_, _, _, _, _)
+            TypeCheckError::IncompatibleDimensions(..)
         ));
     }