浏览代码

Allow overwriting of constants/functions, closes #149

David Peter 2 年之前
父节点
当前提交
5760135223

+ 0 - 2
examples/name_resolution_error/duplicate_variable.nbt

@@ -1,2 +0,0 @@
-let foo = 2
-let foo = 3

+ 0 - 3
examples/name_resolution_error/function_name_clash.nbt

@@ -1,3 +0,0 @@
-let x = 2
-
-fn x() = 2

+ 0 - 3
examples/name_resolution_error/variable_clashes_with_parameter.nbt

@@ -1,3 +0,0 @@
-let foo = 2
-
-fn bar(foo: Scalar) -> Scalar = foo

+ 25 - 8
numbat/src/diagnostic.rs

@@ -269,14 +269,6 @@ impl ErrorDiagnostic for TypeCheckError {
                     .diagnostic_label(LabelStyle::Primary)
                     .with_message("Incompatible types in 'assert_eq' call"),
             ]),
-            TypeCheckError::ForeignFunctionNeedsTypeAnnotations(span, _)
-            | TypeCheckError::UnknownForeignFunction(span, _)
-            | TypeCheckError::NonRationalExponent(span)
-            | TypeCheckError::OverflowInConstExpr(span)
-            | TypeCheckError::ExpectedDimensionType(span, _)
-            | TypeCheckError::ExpectedBool(span) => d.with_labels(vec![span
-                .diagnostic_label(LabelStyle::Primary)
-                .with_message(inner_error)]),
             TypeCheckError::IncompatibleTypesInAnnotation(
                 what,
                 what_span,
@@ -295,6 +287,31 @@ impl ErrorDiagnostic for TypeCheckError {
                     .diagnostic_label(LabelStyle::Primary)
                     .with_message(format!("Incompatible types in {what}")),
             ]),
+            TypeCheckError::NameAlreadyUsedBy(_, definition_span, previous_definition_span) => {
+                let mut labels = vec![];
+
+                if let Some(span) = previous_definition_span {
+                    labels.push(
+                        span.diagnostic_label(LabelStyle::Secondary)
+                            .with_message("Previously defined here"),
+                    );
+                }
+
+                labels.push(
+                    definition_span
+                        .diagnostic_label(LabelStyle::Primary)
+                        .with_message(inner_error),
+                );
+                d.with_labels(labels)
+            }
+            TypeCheckError::ForeignFunctionNeedsTypeAnnotations(span, _)
+            | TypeCheckError::UnknownForeignFunction(span, _)
+            | TypeCheckError::NonRationalExponent(span)
+            | TypeCheckError::OverflowInConstExpr(span)
+            | TypeCheckError::ExpectedDimensionType(span, _)
+            | TypeCheckError::ExpectedBool(span) => d.with_labels(vec![span
+                .diagnostic_label(LabelStyle::Primary)
+                .with_message(inner_error)]),
         }
     }
 }

+ 18 - 13
numbat/src/prefix_parser.rs

@@ -140,13 +140,20 @@ impl PrefixParser {
         }
     }
 
-    fn ensure_name_is_available(&self, name: &str, conflict_span: Span) -> Result<()> {
+    fn ensure_name_is_available(
+        &self,
+        name: &str,
+        conflict_span: Span,
+        clash_with_other_identifiers: bool,
+    ) -> Result<()> {
         if self.reserved_identifiers.contains(&name) {
             return Err(NameResolutionError::ReservedIdentifier(conflict_span));
         }
 
-        if let Some(original_span) = self.other_identifiers.get(name) {
-            return Err(self.identifier_clash_error(name, conflict_span, *original_span));
+        if clash_with_other_identifiers {
+            if let Some(original_span) = self.other_identifiers.get(name) {
+                return Err(self.identifier_clash_error(name, conflict_span, *original_span));
+            }
         }
 
         match self.parse(name) {
@@ -166,7 +173,7 @@ impl PrefixParser {
         full_name: &str,
         definition_span: Span,
     ) -> Result<()> {
-        self.ensure_name_is_available(unit_name, definition_span)?;
+        self.ensure_name_is_available(unit_name, definition_span, true)?;
 
         for (prefix_long, prefix_short, prefix) in Self::prefixes() {
             if !(prefix.is_metric() && metric || prefix.is_binary() && binary) {
@@ -177,12 +184,14 @@ impl PrefixParser {
                 self.ensure_name_is_available(
                     &format!("{}{}", prefix_long, unit_name),
                     definition_span,
+                    true,
                 )?;
             }
             if accepts_prefix.short {
                 self.ensure_name_is_available(
                     &format!("{}{}", prefix_short, unit_name),
                     definition_span,
+                    true,
                 )?;
             }
         }
@@ -201,15 +210,11 @@ impl PrefixParser {
     }
 
     pub fn add_other_identifier(&mut self, identifier: &str, definition_span: Span) -> Result<()> {
-        self.ensure_name_is_available(identifier, definition_span)?;
-
-        if let Some(original_span) = self.other_identifiers.get(identifier) {
-            Err(self.identifier_clash_error(identifier, definition_span, *original_span))
-        } else {
-            self.other_identifiers
-                .insert(identifier.into(), definition_span);
-            Ok(())
-        }
+        self.ensure_name_is_available(identifier, definition_span, false)?;
+
+        self.other_identifiers
+            .insert(identifier.into(), definition_span);
+        Ok(())
     }
 
     pub fn parse(&self, input: &str) -> PrefixParserResult {

+ 41 - 27
numbat/src/typechecker.rs

@@ -253,6 +253,9 @@ pub enum TypeCheckError {
 
     #[error("Incompatible types in comparison operator")]
     IncompatibleTypesInComparison(Span, Type, Span, Type, Span),
+
+    #[error("This name is already used by {0}")]
+    NameAlreadyUsedBy(&'static str, Span, Option<Span>),
 }
 
 type Result<T> = std::result::Result<T, TypeCheckError>;
@@ -357,15 +360,9 @@ fn evaluate_const_expr(expr: &typed_ast::Expression) -> Result<Exponent> {
     }
 }
 
-#[derive(Clone, PartialEq)]
-enum IdentifierKind {
-    Variable,
-    Other,
-}
-
 #[derive(Clone, Default)]
 pub struct TypeChecker {
-    identifiers: HashMap<String, (Type, IdentifierKind)>,
+    identifiers: HashMap<String, (Type, Option<Span>)>,
     function_signatures: HashMap<
         String,
         (
@@ -380,15 +377,14 @@ pub struct TypeChecker {
 }
 
 impl TypeChecker {
-    fn identifier_type_and_kind(&self, span: Span, name: &str) -> Result<&(Type, IdentifierKind)> {
-        self.identifiers.get(name).ok_or_else(|| {
-            let suggestion = suggestion::did_you_mean(self.identifiers.keys(), name);
-            TypeCheckError::UnknownIdentifier(span, name.into(), suggestion)
-        })
-    }
-
     fn identifier_type(&self, span: Span, name: &str) -> Result<&Type> {
-        self.identifier_type_and_kind(span, name).map(|(t, _)| t)
+        self.identifiers
+            .get(name)
+            .ok_or_else(|| {
+                let suggestion = suggestion::did_you_mean(self.identifiers.keys(), name);
+                TypeCheckError::UnknownIdentifier(span, name.into(), suggestion)
+            })
+            .map(|(type_, _)| type_)
     }
 
     pub(crate) fn check_expression(&self, ast: &ast::Expression) -> Result<typed_ast::Expression> {
@@ -797,10 +793,8 @@ impl TypeChecker {
             ast::Statement::Expression(expr) => {
                 let checked_expr = self.check_expression(expr)?;
                 for &identifier in LAST_RESULT_IDENTIFIERS {
-                    self.identifiers.insert(
-                        identifier.into(),
-                        (checked_expr.get_type(), IdentifierKind::Variable),
-                    );
+                    self.identifiers
+                        .insert(identifier.into(), (checked_expr.get_type(), None));
                 }
                 typed_ast::Statement::Expression(checked_expr)
             }
@@ -810,6 +804,16 @@ impl TypeChecker {
                 expr,
                 type_annotation,
             } => {
+                // Make sure that identifier does not clash with a function name. We do not
+                // check for clashes with unit names, as this is handled by the prefix parser.
+                if let Some(entry) = self.function_signatures.get(identifier) {
+                    return Err(TypeCheckError::NameAlreadyUsedBy(
+                        "a function",
+                        *identifier_span,
+                        Some(entry.0),
+                    ));
+                }
+
                 let expr_checked = self.check_expression(expr)?;
                 let type_deduced = expr_checked.get_type();
 
@@ -856,7 +860,7 @@ impl TypeChecker {
 
                 self.identifiers.insert(
                     identifier.clone(),
-                    (type_deduced.clone(), IdentifierKind::Variable),
+                    (type_deduced.clone(), Some(*identifier_span)),
                 );
 
                 typed_ast::Statement::DefineVariable(
@@ -869,7 +873,7 @@ impl TypeChecker {
                     type_deduced,
                 )
             }
-            ast::Statement::DefineBaseUnit(_span, unit_name, type_annotation, decorators) => {
+            ast::Statement::DefineBaseUnit(span, unit_name, type_annotation, decorators) => {
                 let type_specified = if let Some(dexpr) = type_annotation {
                     self.registry
                         .get_base_representation(dexpr)
@@ -886,10 +890,7 @@ impl TypeChecker {
                 for (name, _) in decorator::name_and_aliases(unit_name, decorators) {
                     self.identifiers.insert(
                         name.clone(),
-                        (
-                            Type::Dimension(type_specified.clone()),
-                            IdentifierKind::Other,
-                        ),
+                        (Type::Dimension(type_specified.clone()), Some(*span)),
                     );
                 }
                 typed_ast::Statement::DefineBaseUnit(
@@ -944,7 +945,10 @@ impl TypeChecker {
                 for (name, _) in decorator::name_and_aliases(identifier, decorators) {
                     self.identifiers.insert(
                         name.clone(),
-                        (Type::Dimension(type_deduced.clone()), IdentifierKind::Other),
+                        (
+                            Type::Dimension(type_deduced.clone()),
+                            Some(*identifier_span),
+                        ),
                     );
                 }
                 typed_ast::Statement::DefineDerivedUnit(
@@ -966,6 +970,16 @@ impl TypeChecker {
                 return_type_annotation_span,
                 return_type_annotation,
             } => {
+                // Make sure that function name does not clash with an identifier. We do not
+                // check for clashes with unit names, as this is handled by the prefix parser.
+                if let Some((_, span)) = self.identifiers.get(function_name) {
+                    return Err(TypeCheckError::NameAlreadyUsedBy(
+                        "a constant",
+                        *function_name_span,
+                        span.clone(),
+                    ));
+                }
+
                 let mut typechecker_fn = self.clone();
                 let is_ffi_function = body.is_none();
                 let mut type_parameters = type_parameters.clone();
@@ -1020,7 +1034,7 @@ impl TypeChecker {
 
                     typechecker_fn.identifiers.insert(
                         parameter.clone(),
-                        (parameter_type.clone(), IdentifierKind::Variable),
+                        (parameter_type.clone(), Some(*parameter_span)),
                     );
                     typed_parameters.push((
                         *parameter_span,

+ 4 - 0
numbat/src/vm.rs

@@ -355,6 +355,10 @@ impl Vm {
     }
 
     pub(crate) fn begin_function(&mut self, name: &str) {
+        // This allows us to overwrite functions
+        self.bytecode.retain(|(n, _)| n != name);
+        self.ffi_callables.retain(|ff| ff.name != name);
+
         self.bytecode.push((name.into(), vec![]));
         self.current_chunk_index = self.bytecode.len() - 1
     }

+ 31 - 10
numbat/tests/interpreter.rs

@@ -305,26 +305,35 @@ fn test_prefixes() {
 }
 
 #[test]
-fn test_error_messages() {
+fn test_parse_errors() {
     expect_failure(
         "3kg+",
         "Expected one of: number, identifier, parenthesized expression",
     );
-    expect_failure("let kg=2", "Identifier is already in use: 'kg'");
-    expect_failure("let pi=2", "Identifier is already in use: 'pi'");
-    expect_failure("let sin=2", "Identifier is already in use: 'sin'");
     expect_failure("let print=2", "Expected identifier after 'let' keyword");
-    expect_failure("fn kg(x: Scalar) = 1", "Identifier is already in use: 'kg'");
-    expect_failure("fn pi(x: Scalar) = 1", "Identifier is already in use: 'pi'");
-    expect_failure(
-        "fn sin(x: Scalar) = 1",
-        "Identifier is already in use: 'sin'",
-    );
     expect_failure(
         "fn print(x: Scalar) = 1",
         "Expected identifier after 'fn' keyword",
     );
+}
+
+#[test]
+fn test_name_clash_errors() {
+    expect_failure("let kg=2", "Identifier is already in use: 'kg'");
+    expect_failure("fn kg(x: Scalar) = 1", "Identifier is already in use: 'kg'");
+    expect_failure("fn _()=0", "Reserved identifier");
+}
+
+#[test]
+fn test_type_check_errors() {
     expect_failure("foo", "Unknown identifier 'foo'");
+
+    expect_failure("let sin=2", "This name is already used by a function");
+    expect_failure("fn pi() = 1", "This name is already used by a constant");
+}
+
+#[test]
+fn test_runtime_errors() {
     expect_failure("1/0", "Division by zero");
 }
 
@@ -368,3 +377,15 @@ fn test_string_interpolation() {
     expect_output("\"pi = {pi}!\"", "pi = 3.14159!");
     expect_output("\"1 + 2 = {1 + 2}\"", "1 + 2 = 3");
 }
+
+#[test]
+fn test_override_constants() {
+    expect_output("let x = 1\nlet x = 2\nx", "2");
+    expect_output("let pi = 4\npi", "4");
+}
+
+#[test]
+fn test_overwrite_functions() {
+    expect_output("fn f(x)=0\nfn f(x)=1\nf(2)", "1");
+    expect_output("fn sin(x)=0\nsin(1)", "0");
+}