Browse Source

fix the span of statements

Tamo 1 week ago
parent
commit
e815b385bd

+ 6 - 8
numbat/src/bytecode_interpreter.rs

@@ -433,6 +433,7 @@ impl BytecodeInterpreter {
             }
             Statement::DefineDerivedUnit(
                 unit_name,
+                full_span,
                 expr,
                 decorators,
                 annotation,
@@ -473,8 +474,7 @@ impl BytecodeInterpreter {
                     Op::SetUnitConstant,
                     unit_information_idx,
                     constant_idx,
-                    // TODO: TAMO: We should have the full span of the stmt here
-                    expr.full_span(),
+                    *full_span,
                 );
 
                 // TODO: code duplication with DeclareBaseUnit branch above
@@ -483,7 +483,7 @@ impl BytecodeInterpreter {
                         .insert(name.into(), constant_idx);
                 }
             }
-            Statement::ProcedureCall(ProcedureKind::Type, args) => {
+            Statement::ProcedureCall(ProcedureKind::Type, full_span, args) => {
                 assert_eq!(args.len(), 1);
                 let arg = &args[0];
 
@@ -491,10 +491,9 @@ impl BytecodeInterpreter {
                 let idx = self.vm.add_string(
                     m::dimmed("=") + m::whitespace(" ") + arg.get_type_scheme().pretty_print(), // TODO
                 );
-                // TODO: TAMO: We should have the full span of the stmt here
-                self.vm.add_op1(Op::PrintString, idx, arg.full_span());
+                self.vm.add_op1(Op::PrintString, idx, *full_span);
             }
-            Statement::ProcedureCall(kind, args) => {
+            Statement::ProcedureCall(kind, full_span, args) => {
                 // Put all arguments on top of the stack
                 for arg in args {
                     self.compile_expression(arg);
@@ -512,8 +511,7 @@ impl BytecodeInterpreter {
                     callable_idx,
                     args.len() as u16,
                     spans_idx,
-                    // TODO: TAMO: We should have the full span of the stmt here
-                    args[0].full_span(),
+                    *full_span,
                 );
                 // TODO: check overflow
             }

+ 4 - 4
numbat/src/traversal.rs

@@ -97,11 +97,11 @@ impl ForAllTypeSchemes for Statement<'_> {
             Statement::DefineBaseUnit(_, _, _annotation, type_) => {
                 f(type_);
             }
-            Statement::DefineDerivedUnit(_, expr, _, _annotation, type_, _) => {
+            Statement::DefineDerivedUnit(_, _, expr, _, _annotation, type_, _) => {
                 expr.for_all_type_schemes(f);
                 f(type_);
             }
-            Statement::ProcedureCall(_, args) => {
+            Statement::ProcedureCall(_, _, args) => {
                 for arg in args {
                     arg.for_all_type_schemes(f);
                 }
@@ -132,8 +132,8 @@ impl ForAllExpressions for Statement<'_> {
             }
             Statement::DefineDimension(_, _) => {}
             Statement::DefineBaseUnit(_, _, _, _) => {}
-            Statement::DefineDerivedUnit(_, expr, _, _, _, _) => expr.for_all_expressions(f),
-            Statement::ProcedureCall(_, args) => {
+            Statement::DefineDerivedUnit(_, _, expr, _, _, _, _) => expr.for_all_expressions(f),
+            Statement::ProcedureCall(_, _, args) => {
                 for arg in args {
                     arg.for_all_expressions(f);
                 }

+ 13 - 3
numbat/src/typechecker/mod.rs

@@ -1331,6 +1331,7 @@ impl TypeChecker {
                 }
                 typed_ast::Statement::DefineDerivedUnit(
                     identifier,
+                    identifier_span.extend(&expr.full_span()),
                     expr_checked,
                     decorators.clone(),
                     type_annotation.clone(),
@@ -1680,7 +1681,11 @@ impl TypeChecker {
                     .map(|e| self.elaborate_expression(e))
                     .collect::<Result<Vec<_>>>()?;
 
-                typed_ast::Statement::ProcedureCall(kind.clone(), checked_args)
+                typed_ast::Statement::ProcedureCall(
+                    kind.clone(),
+                    span.extend(&args[0].full_span()),
+                    checked_args,
+                )
             }
             ast::Statement::ProcedureCall(span, kind, args) => {
                 let procedure = ffi::procedures().get(kind).unwrap();
@@ -1749,7 +1754,12 @@ impl TypeChecker {
                     }
                 }
 
-                typed_ast::Statement::ProcedureCall(kind.clone(), checked_args)
+                typed_ast::Statement::ProcedureCall(
+                    kind.clone(),
+                    args.last()
+                        .map_or(*span, |last| span.extend(&last.full_span())),
+                    checked_args,
+                )
             }
             ast::Statement::ModuleImport(_, _) => {
                 unreachable!("Modules should have been inlined by now")
@@ -1839,7 +1849,7 @@ impl TypeChecker {
             TypeCheckError::SubstitutionError(elaborated_statement.pretty_print().to_string(), e)
         })?;
 
-        if let typed_ast::Statement::DefineDerivedUnit(_, expr, _, _annotation, type_, _) =
+        if let typed_ast::Statement::DefineDerivedUnit(_, _, expr, _, _annotation, type_, _) =
             &elaborated_statement
         {
             if !type_.unsafe_as_concrete().is_closed() {

+ 2 - 2
numbat/src/typechecker/substitutions.rs

@@ -230,11 +230,11 @@ impl ApplySubstitution for Statement<'_> {
             }
             Statement::DefineDimension(_, _) => Ok(()),
             Statement::DefineBaseUnit(_, _, _annotation, type_) => type_.apply(s),
-            Statement::DefineDerivedUnit(_, e, _, _annotation, type_, _) => {
+            Statement::DefineDerivedUnit(_, _, e, _, _annotation, type_, _) => {
                 e.apply(s)?;
                 type_.apply(s)
             }
-            Statement::ProcedureCall(_, args) => {
+            Statement::ProcedureCall(_, _, args) => {
                 for arg in args {
                     arg.apply(s)?;
                 }

+ 6 - 4
numbat/src/typed_ast.rs

@@ -629,13 +629,14 @@ pub enum Statement<'a> {
     ),
     DefineDerivedUnit(
         &'a str,
+        Span,
         Expression<'a>,
         Vec<Decorator<'a>>,
         Option<TypeAnnotation>,
         TypeScheme,
         Markup,
     ),
-    ProcedureCall(crate::ast::ProcedureKind, Vec<Expression<'a>>),
+    ProcedureCall(crate::ast::ProcedureKind, Span, Vec<Expression<'a>>),
     DefineStruct(StructInfo),
 }
 
@@ -723,11 +724,11 @@ impl Statement<'_> {
             }
             Statement::DefineDimension(_, _) => {}
             Statement::DefineBaseUnit(_, _, _, _) => {}
-            Statement::DefineDerivedUnit(_, _, _, type_annotation, type_, readable_type) => {
+            Statement::DefineDerivedUnit(_, _, _, _, type_annotation, type_, readable_type) => {
                 *readable_type =
                     Self::create_readable_type(registry, type_, type_annotation, false);
             }
-            Statement::ProcedureCall(_, _) => {}
+            Statement::ProcedureCall(_, _, _) => {}
             Statement::DefineStruct(_) => {}
         }
     }
@@ -1080,6 +1081,7 @@ impl PrettyPrint for Statement<'_> {
             }
             Statement::DefineDerivedUnit(
                 identifier,
+                _,
                 expr,
                 decorators,
                 _annotation,
@@ -1098,7 +1100,7 @@ impl PrettyPrint for Statement<'_> {
                     + m::space()
                     + expr.pretty_print()
             }
-            Statement::ProcedureCall(kind, args) => {
+            Statement::ProcedureCall(kind, _, args) => {
                 let identifier = match kind {
                     ProcedureKind::Print => "print",
                     ProcedureKind::Assert => "assert",