Browse Source

update the DefineFunction type to include a local variables array

Tamo 1 year ago
parent
commit
9177e4023b

+ 33 - 20
numbat/src/ast.rs

@@ -393,16 +393,19 @@ pub enum TypeParameterBound {
     Dim,
 }
 
+#[derive(Debug, Clone, PartialEq)]
+pub struct DefineVariable {
+    pub identifier_span: Span,
+    pub identifier: String,
+    pub expr: Expression,
+    pub type_annotation: Option<TypeAnnotation>,
+    pub decorators: Vec<Decorator>,
+}
+
 #[derive(Debug, Clone, PartialEq)]
 pub enum Statement {
     Expression(Expression),
-    DefineVariable {
-        identifier_span: Span,
-        identifier: String,
-        expr: Expression,
-        type_annotation: Option<TypeAnnotation>,
-        decorators: Vec<Decorator>,
-    },
+    DefineVariable(DefineVariable),
     DefineFunction {
         function_name_span: Span,
         function_name: String,
@@ -411,6 +414,8 @@ pub enum Statement {
         parameters: Vec<(Span, String, Option<TypeAnnotation>)>,
         /// Function body. If it is absent, the function is implemented via FFI
         body: Option<Expression>,
+        /// Local variables
+        local_variables: Vec<DefineVariable>,
         /// Optional annotated return type
         return_type_annotation: Option<TypeAnnotation>,
         decorators: Vec<Decorator>,
@@ -573,30 +578,34 @@ impl ReplaceSpans for Expression {
     }
 }
 
+#[cfg(test)]
+impl ReplaceSpans for DefineVariable {
+    fn replace_spans(&self) -> Self {
+        Self {
+            identifier_span: Span::dummy(),
+            identifier: self.identifier.clone(),
+            expr: self.expr.replace_spans(),
+            type_annotation: self.type_annotation.as_ref().map(|t| t.replace_spans()),
+            decorators: self.decorators.clone(),
+        }
+    }
+}
+
 #[cfg(test)]
 impl ReplaceSpans for Statement {
     fn replace_spans(&self) -> Self {
         match self {
             Statement::Expression(expr) => Statement::Expression(expr.replace_spans()),
-            Statement::DefineVariable {
-                identifier_span: _,
-                identifier,
-                expr,
-                type_annotation,
-                decorators,
-            } => Statement::DefineVariable {
-                identifier_span: Span::dummy(),
-                identifier: identifier.clone(),
-                expr: expr.replace_spans(),
-                type_annotation: type_annotation.as_ref().map(|t| t.replace_spans()),
-                decorators: decorators.clone(),
-            },
+            Statement::DefineVariable(variable) => {
+                Statement::DefineVariable(variable.replace_spans())
+            }
             Statement::DefineFunction {
                 function_name_span: _,
                 function_name,
                 type_parameters,
                 parameters,
                 body,
+                local_variables,
                 return_type_annotation,
                 decorators,
             } => Statement::DefineFunction {
@@ -617,6 +626,10 @@ impl ReplaceSpans for Statement {
                     })
                     .collect(),
                 body: body.clone().map(|b| b.replace_spans()),
+                local_variables: local_variables
+                    .iter()
+                    .map(DefineVariable::replace_spans)
+                    .collect(),
                 return_type_annotation: return_type_annotation.as_ref().map(|t| t.replace_spans()),
                 decorators: decorators.clone(),
             },

+ 39 - 32
numbat/src/bytecode_interpreter.rs

@@ -12,7 +12,9 @@ use crate::name_resolution::LAST_RESULT_IDENTIFIERS;
 use crate::prefix::Prefix;
 use crate::prefix_parser::AcceptsPrefix;
 use crate::pretty_print::PrettyPrint;
-use crate::typed_ast::{BinaryOperator, Expression, Statement, StringPart, UnaryOperator};
+use crate::typed_ast::{
+    BinaryOperator, DefineVariable, Expression, Statement, StringPart, UnaryOperator,
+};
 use crate::unit::{CanonicalName, Unit};
 use crate::unit_registry::{UnitMetadata, UnitRegistry};
 use crate::value::FunctionReference;
@@ -300,6 +302,35 @@ impl BytecodeInterpreter {
         Ok(())
     }
 
+    fn compile_define_variable(&mut self, define_variable: &DefineVariable) -> Result<()> {
+        let DefineVariable(identifier, decorators, expr, _annotation, _type, _readable_type) =
+            define_variable;
+        let current_depth = self.current_depth();
+
+        // For variables, we ignore the prefix info and only use the names
+        let aliases = crate::decorator::name_and_aliases(identifier, decorators)
+            .map(|(name, _)| name)
+            .cloned()
+            .collect::<Vec<_>>();
+        let metadata = LocalMetadata {
+            name: crate::decorator::name(decorators),
+            url: crate::decorator::url(decorators),
+            description: crate::decorator::description(decorators),
+            aliases: aliases.clone(),
+        };
+
+        for alias_name in aliases {
+            self.compile_expression_with_simplify(expr)?;
+
+            self.locals[current_depth].push(Local {
+                identifier: alias_name.clone(),
+                depth: 0,
+                metadata: metadata.clone(),
+            });
+        }
+        Ok(())
+    }
+
     fn compile_statement(
         &mut self,
         stmt: &Statement,
@@ -310,37 +341,8 @@ impl BytecodeInterpreter {
                 self.compile_expression_with_simplify(expr)?;
                 self.vm.add_op(Op::Return);
             }
-            Statement::DefineVariable(
-                identifier,
-                decorators,
-                expr,
-                _annotation,
-                _type,
-                _readable_type,
-            ) => {
-                let current_depth = self.current_depth();
-
-                // For variables, we ignore the prefix info and only use the names
-                let aliases = crate::decorator::name_and_aliases(identifier, decorators)
-                    .map(|(name, _)| name)
-                    .cloned()
-                    .collect::<Vec<_>>();
-                let metadata = LocalMetadata {
-                    name: crate::decorator::name(decorators),
-                    url: crate::decorator::url(decorators),
-                    description: crate::decorator::description(decorators),
-                    aliases: aliases.clone(),
-                };
-
-                for alias_name in aliases {
-                    self.compile_expression_with_simplify(expr)?;
-
-                    self.locals[current_depth].push(Local {
-                        identifier: alias_name.clone(),
-                        depth: 0,
-                        metadata: metadata.clone(),
-                    });
-                }
+            Statement::DefineVariable(define_variable) => {
+                self.compile_define_variable(define_variable)?
             }
             Statement::DefineFunction(
                 name,
@@ -348,6 +350,7 @@ impl BytecodeInterpreter {
                 _type_parameters,
                 parameters,
                 Some(expr),
+                local_variables,
                 _return_type,
                 _return_type_annotation,
                 _readable_return_type,
@@ -364,6 +367,9 @@ impl BytecodeInterpreter {
                         metadata: LocalMetadata::default(),
                     });
                 }
+                for local_variables in local_variables {
+                    self.compile_define_variable(local_variables)?;
+                }
 
                 self.compile_expression_with_simplify(expr)?;
                 self.vm.add_op(Op::Return);
@@ -380,6 +386,7 @@ impl BytecodeInterpreter {
                 _type_parameters,
                 parameters,
                 None,
+                _local_variables,
                 _return_type,
                 _return_type_annotation,
                 _readable_return_type,

+ 1 - 0
numbat/src/keywords.rs

@@ -6,6 +6,7 @@ pub const KEYWORDS: &[&str] = &[
     "to ",
     "let ",
     "fn ",
+    "where ",
     "dimension ",
     "unit ",
     "use ",

+ 21 - 10
numbat/src/parser.rs

@@ -64,8 +64,8 @@
 
 use crate::arithmetic::{Exponent, Rational};
 use crate::ast::{
-    BinaryOperator, Expression, ProcedureKind, Statement, StringPart, TypeAnnotation,
-    TypeExpression, TypeParameterBound, UnaryOperator,
+    BinaryOperator, DefineVariable, Expression, ProcedureKind, Statement, StringPart,
+    TypeAnnotation, TypeExpression, TypeParameterBound, UnaryOperator,
 };
 use crate::decorator::{self, Decorator};
 use crate::number::Number;
@@ -443,13 +443,13 @@ impl<'a> Parser<'a> {
                 let mut decorators = vec![];
                 std::mem::swap(&mut decorators, &mut self.decorator_stack);
 
-                Ok(Statement::DefineVariable {
+                Ok(Statement::DefineVariable(DefineVariable {
                     identifier_span,
                     identifier: identifier.lexeme.clone(),
                     expr,
                     type_annotation,
                     decorators,
-                })
+                }))
             }
         } else {
             Err(ParseError {
@@ -584,6 +584,7 @@ impl<'a> Parser<'a> {
                 type_parameters,
                 parameters,
                 body,
+                local_variables: todo!(),
                 return_type_annotation,
                 decorators,
             })
@@ -2282,18 +2283,18 @@ mod tests {
     fn variable_definition() {
         parse_as(
             &["let foo = 1", "let foo=1"],
-            Statement::DefineVariable {
+            Statement::DefineVariable(DefineVariable {
                 identifier_span: Span::dummy(),
                 identifier: "foo".into(),
                 expr: scalar!(1.0),
                 type_annotation: None,
                 decorators: Vec::new(),
-            },
+            }),
         );
 
         parse_as(
             &["let x: Length = 1 * meter"],
-            Statement::DefineVariable {
+            Statement::DefineVariable(DefineVariable {
                 identifier_span: Span::dummy(),
                 identifier: "x".into(),
                 expr: binop!(scalar!(1.0), Mul, identifier!("meter")),
@@ -2301,13 +2302,13 @@ mod tests {
                     TypeExpression::TypeIdentifier(Span::dummy(), "Length".into()),
                 )),
                 decorators: Vec::new(),
-            },
+            }),
         );
 
         // same as above, but with some decorators
         parse_as(
             &["@name(\"myvar\") @aliases(foo, bar) let x: Length = 1 * meter"],
-            Statement::DefineVariable {
+            Statement::DefineVariable(DefineVariable {
                 identifier_span: Span::dummy(),
                 identifier: "x".into(),
                 expr: binop!(scalar!(1.0), Mul, identifier!("meter")),
@@ -2318,7 +2319,7 @@ mod tests {
                     decorator::Decorator::Name("myvar".into()),
                     decorator::Decorator::Aliases(vec![("foo".into(), None), ("bar".into(), None)]),
                 ],
-            },
+            }),
         );
 
         should_fail_with(
@@ -2471,6 +2472,7 @@ mod tests {
                 type_parameters: vec![],
                 parameters: vec![],
                 body: Some(scalar!(1.0)),
+                local_variables: vec![],
                 return_type_annotation: None,
                 decorators: vec![],
             },
@@ -2484,6 +2486,7 @@ mod tests {
                 type_parameters: vec![],
                 parameters: vec![],
                 body: Some(scalar!(1.0)),
+                local_variables: vec![],
                 return_type_annotation: Some(TypeAnnotation::TypeExpression(
                     TypeExpression::TypeIdentifier(Span::dummy(), "Scalar".into()),
                 )),
@@ -2499,6 +2502,7 @@ mod tests {
                 type_parameters: vec![],
                 parameters: vec![(Span::dummy(), "x".into(), None)],
                 body: Some(scalar!(1.0)),
+                local_variables: vec![],
                 return_type_annotation: None,
                 decorators: vec![],
             },
@@ -2512,6 +2516,7 @@ mod tests {
                 type_parameters: vec![],
                 parameters: vec![(Span::dummy(), "x".into(), None)],
                 body: Some(scalar!(1.0)),
+                local_variables: vec![],
                 return_type_annotation: None,
                 decorators: vec![],
             },
@@ -2531,6 +2536,7 @@ mod tests {
                     (Span::dummy(), "y".into(), None),
                 ],
                 body: Some(scalar!(1.0)),
+                local_variables: vec![],
                 return_type_annotation: None,
                 decorators: vec![],
             },
@@ -2548,6 +2554,7 @@ mod tests {
                     (Span::dummy(), "z".into(), None),
                 ],
                 body: Some(scalar!(1.0)),
+                local_variables: vec![],
                 return_type_annotation: None,
                 decorators: vec![],
             },
@@ -2601,6 +2608,7 @@ mod tests {
                     ),
                 ],
                 body: Some(scalar!(1.0)),
+                local_variables: vec![],
                 return_type_annotation: Some(TypeAnnotation::TypeExpression(
                     TypeExpression::TypeIdentifier(Span::dummy(), "Scalar".into()),
                 )),
@@ -2622,6 +2630,7 @@ mod tests {
                     )),
                 )],
                 body: Some(scalar!(1.0)),
+                local_variables: vec![],
                 return_type_annotation: None,
                 decorators: vec![],
             },
@@ -2641,6 +2650,7 @@ mod tests {
                     )),
                 )],
                 body: Some(scalar!(1.0)),
+                local_variables: vec![],
                 return_type_annotation: None,
                 decorators: vec![],
             },
@@ -2654,6 +2664,7 @@ mod tests {
                 type_parameters: vec![],
                 parameters: vec![(Span::dummy(), "x".into(), None)],
                 body: Some(scalar!(1.0)),
+                local_variables: vec![],
                 return_type_annotation: None,
                 decorators: vec![
                     decorator::Decorator::Name("Some function".into()),

+ 34 - 20
numbat/src/prefix_transformer.rs

@@ -1,5 +1,5 @@
 use crate::{
-    ast::{Expression, Statement, StringPart},
+    ast::{DefineVariable, Expression, Statement, StringPart},
     decorator::{self, Decorator},
     name_resolution::NameResolutionError,
     prefix_parser::{PrefixParser, PrefixParserResult},
@@ -159,6 +159,32 @@ impl Transformer {
         Ok(())
     }
 
+    fn transform_define_variable(
+        &mut self,
+        define_variable: DefineVariable,
+    ) -> Result<DefineVariable> {
+        let DefineVariable {
+            identifier_span,
+            identifier,
+            expr,
+            type_annotation,
+            decorators,
+        } = define_variable;
+
+        for (name, _) in decorator::name_and_aliases(&identifier, &decorators) {
+            self.variable_names.push(name.clone());
+        }
+        self.prefix_parser
+            .add_other_identifier(&identifier, identifier_span)?;
+        Ok(DefineVariable {
+            identifier_span,
+            identifier,
+            expr: self.transform_expression(expr),
+            type_annotation,
+            decorators,
+        })
+    }
+
     fn transform_statement(&mut self, statement: Statement) -> Result<Statement> {
         Ok(match statement {
             Statement::Expression(expr) => Statement::Expression(self.transform_expression(expr)),
@@ -184,25 +210,8 @@ impl Transformer {
                     decorators,
                 }
             }
-            Statement::DefineVariable {
-                identifier_span,
-                identifier,
-                expr,
-                type_annotation,
-                decorators,
-            } => {
-                for (name, _) in decorator::name_and_aliases(&identifier, &decorators) {
-                    self.variable_names.push(name.clone());
-                }
-                self.prefix_parser
-                    .add_other_identifier(&identifier, identifier_span)?;
-                Statement::DefineVariable {
-                    identifier_span,
-                    identifier,
-                    expr: self.transform_expression(expr),
-                    type_annotation,
-                    decorators,
-                }
+            Statement::DefineVariable(define_variable) => {
+                Statement::DefineVariable(self.transform_define_variable(define_variable)?)
             }
             Statement::DefineFunction {
                 function_name_span,
@@ -210,6 +219,7 @@ impl Transformer {
                 type_parameters,
                 parameters,
                 body,
+                local_variables,
                 return_type_annotation,
                 decorators,
             } => {
@@ -238,6 +248,10 @@ impl Transformer {
                     type_parameters,
                     parameters,
                     body: body.map(|expr| self.transform_expression(expr)),
+                    local_variables: local_variables
+                        .into_iter()
+                        .map(|def| self.transform_define_variable(def))
+                        .collect::<Result<_>>()?,
                     return_type_annotation,
                     decorators,
                 }

+ 12 - 9
numbat/src/resolver.rs

@@ -145,7 +145,10 @@ impl Resolver {
 
 #[cfg(test)]
 mod tests {
-    use crate::{ast::Expression, number::Number};
+    use crate::{
+        ast::{DefineVariable, Expression},
+        number::Number,
+    };
 
     use super::*;
 
@@ -189,13 +192,13 @@ mod tests {
         assert_eq!(
             &program_inlined.replace_spans(),
             &[
-                Statement::DefineVariable {
+                Statement::DefineVariable(DefineVariable {
                     identifier_span: Span::dummy(),
                     identifier: "a".into(),
                     expr: Expression::Scalar(Span::dummy(), Number::from_f64(1.0)),
                     type_annotation: None,
                     decorators: Vec::new(),
-                },
+                }),
                 Statement::Expression(Expression::Identifier(Span::dummy(), "a".into()))
             ]
         );
@@ -219,13 +222,13 @@ mod tests {
         assert_eq!(
             &program_inlined.replace_spans(),
             &[
-                Statement::DefineVariable {
+                Statement::DefineVariable(DefineVariable {
                     identifier_span: Span::dummy(),
                     identifier: "a".into(),
                     expr: Expression::Scalar(Span::dummy(), Number::from_f64(1.0)),
                     type_annotation: None,
                     decorators: Vec::new(),
-                },
+                }),
                 Statement::Expression(Expression::Identifier(Span::dummy(), "a".into()))
             ]
         );
@@ -248,20 +251,20 @@ mod tests {
         assert_eq!(
             &program_inlined.replace_spans(),
             &[
-                Statement::DefineVariable {
+                Statement::DefineVariable(DefineVariable {
                     identifier_span: Span::dummy(),
                     identifier: "y".into(),
                     expr: Expression::Scalar(Span::dummy(), Number::from_f64(1.0)),
                     type_annotation: None,
                     decorators: Vec::new(),
-                },
-                Statement::DefineVariable {
+                }),
+                Statement::DefineVariable(DefineVariable {
                     identifier_span: Span::dummy(),
                     identifier: "x".into(),
                     expr: Expression::Identifier(Span::dummy(), "y".into()),
                     type_annotation: None,
                     decorators: Vec::new(),
-                },
+                }),
             ]
         );
     }

+ 2 - 0
numbat/src/tokenizer.rs

@@ -89,6 +89,7 @@ pub enum TokenKind {
     To,
     Let,
     Fn, // 'fn'
+    Where,
     Dimension,
     Unit,
     Use,
@@ -374,6 +375,7 @@ impl Tokenizer {
             m.insert("to", TokenKind::To);
             m.insert("let", TokenKind::Let);
             m.insert("fn", TokenKind::Fn);
+            m.insert("where", TokenKind::Where);
             m.insert("dimension", TokenKind::Dimension);
             m.insert("unit", TokenKind::Unit);
             m.insert("use", TokenKind::Use);

+ 14 - 5
numbat/src/traversal.rs

@@ -1,5 +1,5 @@
 use crate::typechecker::type_scheme::TypeScheme;
-use crate::typed_ast::{Expression, Statement, StructInfo};
+use crate::typed_ast::{DefineVariable, Expression, Statement, StructInfo};
 
 pub trait ForAllTypeSchemes {
     fn for_all_type_schemes(&mut self, f: &mut dyn FnMut(&mut TypeScheme));
@@ -79,11 +79,15 @@ impl ForAllTypeSchemes for Statement {
     fn for_all_type_schemes(&mut self, f: &mut dyn FnMut(&mut TypeScheme)) {
         match self {
             Statement::Expression(expr) => expr.for_all_type_schemes(f),
-            Statement::DefineVariable(_, _, expr, _annotation, type_, _) => {
+            Statement::DefineVariable(DefineVariable(_, _, expr, _annotation, type_, _)) => {
                 expr.for_all_type_schemes(f);
                 f(type_);
             }
-            Statement::DefineFunction(_, _, _, _, body, fn_type, _, _) => {
+            Statement::DefineFunction(_, _, _, _, body, local_variables, fn_type, _, _) => {
+                for local_variable in local_variables {
+                    local_variable.2.for_all_type_schemes(f);
+                    f(&mut local_variable.4);
+                }
                 if let Some(body) = body {
                     body.for_all_type_schemes(f);
                 }
@@ -115,8 +119,13 @@ impl ForAllExpressions for Statement {
     fn for_all_expressions(&self, f: &mut dyn FnMut(&Expression)) {
         match self {
             Statement::Expression(expr) => expr.for_all_expressions(f),
-            Statement::DefineVariable(_, _, expr, _, _, _) => expr.for_all_expressions(f),
-            Statement::DefineFunction(_, _, _, _, body, _, _, _) => {
+            Statement::DefineVariable(DefineVariable(_, _, expr, _, _, _)) => {
+                expr.for_all_expressions(f)
+            }
+            Statement::DefineFunction(_, _, _, _, body, local_variables, _, _, _) => {
+                for local_variable in local_variables {
+                    local_variable.2.for_all_expressions(f);
+                }
                 if let Some(body) = body {
                     body.for_all_expressions(f);
                 }

+ 94 - 80
numbat/src/typechecker/mod.rs

@@ -16,8 +16,8 @@ use std::ops::Deref;
 
 use crate::arithmetic::Exponent;
 use crate::ast::{
-    self, BinaryOperator, ProcedureKind, StringPart, TypeAnnotation, TypeExpression,
-    TypeParameterBound,
+    self, BinaryOperator, DefineVariable, ProcedureKind, StringPart, TypeAnnotation,
+    TypeExpression, TypeParameterBound,
 };
 use crate::dimension::DimensionRegistry;
 use crate::name_resolution::Namespace;
@@ -1064,6 +1064,89 @@ impl TypeChecker {
         })
     }
 
+    fn elaborate_define_variable(
+        &mut self,
+        define_variable: &ast::DefineVariable,
+    ) -> Result<typed_ast::DefineVariable> {
+        let DefineVariable {
+            identifier_span,
+            identifier,
+            expr,
+            type_annotation,
+            decorators,
+        } = define_variable;
+
+        let expr_checked = self.elaborate_expression(expr)?;
+        let type_deduced = expr_checked.get_type();
+
+        if let Some(ref type_annotation) = type_annotation {
+            let type_annotated = self.type_from_annotation(type_annotation)?;
+
+            match (&type_deduced, &type_annotated) {
+                (Type::Dimension(dexpr_deduced), Type::Dimension(dexpr_specified))
+                    if type_deduced.is_closed() && type_annotated.is_closed() =>
+                {
+                    if dexpr_deduced != dexpr_specified {
+                        return Err(TypeCheckError::IncompatibleDimensions(
+                            IncompatibleDimensionsError {
+                                span_operation: *identifier_span,
+                                operation: "variable definition".into(),
+                                span_expected: type_annotation.full_span(),
+                                expected_name: "specified dimension",
+                                expected_dimensions: self.registry.get_derived_entry_names_for(
+                                    &dexpr_specified.to_base_representation(),
+                                ),
+                                expected_type: dexpr_specified.to_base_representation(),
+                                span_actual: expr.full_span(),
+                                actual_name: "   actual dimension",
+                                actual_name_for_fix: "right hand side expression",
+                                actual_dimensions: self.registry.get_derived_entry_names_for(
+                                    &dexpr_deduced.to_base_representation(),
+                                ),
+                                actual_type: dexpr_deduced.to_base_representation(),
+                            },
+                        ));
+                    }
+                }
+                (deduced, annotated) => {
+                    if self
+                        .add_equal_constraint(deduced, annotated)
+                        .is_trivially_violated()
+                    {
+                        return Err(TypeCheckError::IncompatibleTypesInAnnotation(
+                            "definition".into(),
+                            *identifier_span,
+                            annotated.clone(),
+                            type_annotation.full_span(),
+                            deduced.clone(),
+                            expr_checked.full_span(),
+                        ));
+                    }
+                }
+            }
+        }
+
+        for (name, _) in decorator::name_and_aliases(identifier, decorators) {
+            self.env
+                .add(name.clone(), type_deduced.clone(), *identifier_span, false);
+
+            self.value_namespace.add_identifier_allow_override(
+                name.clone(),
+                *identifier_span,
+                "constant".to_owned(),
+            )?;
+        }
+
+        Ok(typed_ast::DefineVariable(
+            identifier.clone(),
+            decorators.clone(),
+            expr_checked,
+            type_annotation.clone(),
+            TypeScheme::concrete(type_deduced),
+            crate::markup::empty(),
+        ))
+    }
+
     fn elaborate_statement(&mut self, ast: &ast::Statement) -> Result<typed_ast::Statement> {
         Ok(match ast {
             ast::Statement::Expression(expr) => {
@@ -1076,85 +1159,9 @@ impl TypeChecker {
                 }
                 typed_ast::Statement::Expression(checked_expr)
             }
-            ast::Statement::DefineVariable {
-                identifier_span,
-                identifier,
-                expr,
-                type_annotation,
-                decorators,
-            } => {
-                let expr_checked = self.elaborate_expression(expr)?;
-                let type_deduced = expr_checked.get_type();
-
-                if let Some(ref type_annotation) = type_annotation {
-                    let type_annotated = self.type_from_annotation(type_annotation)?;
-
-                    match (&type_deduced, &type_annotated) {
-                        (Type::Dimension(dexpr_deduced), Type::Dimension(dexpr_specified))
-                            if type_deduced.is_closed() && type_annotated.is_closed() =>
-                        {
-                            if dexpr_deduced != dexpr_specified {
-                                return Err(TypeCheckError::IncompatibleDimensions(
-                                    IncompatibleDimensionsError {
-                                        span_operation: *identifier_span,
-                                        operation: "variable definition".into(),
-                                        span_expected: type_annotation.full_span(),
-                                        expected_name: "specified dimension",
-                                        expected_dimensions: self
-                                            .registry
-                                            .get_derived_entry_names_for(
-                                                &dexpr_specified.to_base_representation(),
-                                            ),
-                                        expected_type: dexpr_specified.to_base_representation(),
-                                        span_actual: expr.full_span(),
-                                        actual_name: "   actual dimension",
-                                        actual_name_for_fix: "right hand side expression",
-                                        actual_dimensions: self
-                                            .registry
-                                            .get_derived_entry_names_for(
-                                                &dexpr_deduced.to_base_representation(),
-                                            ),
-                                        actual_type: dexpr_deduced.to_base_representation(),
-                                    },
-                                ));
-                            }
-                        }
-                        (deduced, annotated) => {
-                            if self
-                                .add_equal_constraint(deduced, annotated)
-                                .is_trivially_violated()
-                            {
-                                return Err(TypeCheckError::IncompatibleTypesInAnnotation(
-                                    "definition".into(),
-                                    *identifier_span,
-                                    annotated.clone(),
-                                    type_annotation.full_span(),
-                                    deduced.clone(),
-                                    expr_checked.full_span(),
-                                ));
-                            }
-                        }
-                    }
-                }
-
-                for (name, _) in decorator::name_and_aliases(identifier, decorators) {
-                    self.env
-                        .add(name.clone(), type_deduced.clone(), *identifier_span, false);
-
-                    self.value_namespace.add_identifier_allow_override(
-                        name.clone(),
-                        *identifier_span,
-                        "constant".to_owned(),
-                    )?;
-                }
-
+            ast::Statement::DefineVariable(define_variable) => {
                 typed_ast::Statement::DefineVariable(
-                    identifier.clone(),
-                    decorators.clone(),
-                    expr_checked,
-                    type_annotation.clone(),
-                    TypeScheme::concrete(type_deduced),
-                    crate::markup::empty(),
+                    self.elaborate_define_variable(define_variable)?,
                 )
             }
             ast::Statement::DefineBaseUnit(span, unit_name, type_annotation, decorators) => {
@@ -1283,6 +1290,7 @@ impl TypeChecker {
                 type_parameters,
                 parameters,
                 body,
+                local_variables,
                 return_type_annotation,
                 decorators,
             } => {
@@ -1365,6 +1373,11 @@ impl TypeChecker {
                     ));
                 }
 
+                let mut typed_local_variables = vec![];
+                for local_variable in local_variables {
+                    typed_local_variables.push(self.elaborate_define_variable(local_variable)?);
+                }
+
                 let annotated_return_type = return_type_annotation
                     .as_ref()
                     .map(|annotation| typechecker_fn.type_from_annotation(annotation))
@@ -1516,6 +1529,7 @@ impl TypeChecker {
                         })
                         .collect(),
                     body_checked,
+                    typed_local_variables,
                     fn_type,
                     return_type_annotation.clone(),
                     crate::markup::empty(),

+ 7 - 3
numbat/src/typechecker/substitutions.rs

@@ -1,7 +1,7 @@
 use thiserror::Error;
 
 use crate::type_variable::TypeVariable;
-use crate::typed_ast::{DType, DTypeFactor, Expression, StructInfo, Type};
+use crate::typed_ast::{DType, DTypeFactor, DefineVariable, Expression, StructInfo, Type};
 use crate::Statement;
 
 #[derive(Debug, Clone)]
@@ -214,11 +214,15 @@ impl ApplySubstitution for Statement {
     fn apply(&mut self, s: &Substitution) -> Result<(), SubstitutionError> {
         match self {
             Statement::Expression(e) => e.apply(s),
-            Statement::DefineVariable(_, _, e, _annotation, type_, _) => {
+            Statement::DefineVariable(DefineVariable(_, _, e, _annotation, type_, _)) => {
                 e.apply(s)?;
                 type_.apply(s)
             }
-            Statement::DefineFunction(_, _, _, _, body, fn_type, _, _) => {
+            Statement::DefineFunction(_, _, _, _, body, local_variables, fn_type, _, _) => {
+                for local_variable in local_variables {
+                    local_variable.2.apply(s)?;
+                    local_variable.4.apply(s)?;
+                }
                 if let Some(body) = body {
                     body.apply(s)?;
                 }

+ 1 - 1
numbat/src/typechecker/tests/mod.rs

@@ -68,7 +68,7 @@ fn assert_successful_typecheck(input: &str) {
 fn get_inferred_fn_type(input: &str) -> TypeScheme {
     let statement = run_typecheck(input).expect("Input was expected to type-check");
     match statement {
-        Statement::DefineFunction(_, _, _, _, _, fn_type, _, _) => fn_type,
+        Statement::DefineFunction(_, _, _, _, _, _, fn_type, _, _) => fn_type,
         _ => {
             unreachable!();
         }

+ 56 - 11
numbat/src/typed_ast.rs

@@ -567,17 +567,20 @@ impl Expression {
     }
 }
 
+#[derive(Debug, Clone, PartialEq)]
+pub struct DefineVariable(
+    pub String,
+    pub Vec<Decorator>,
+    pub Expression,
+    pub Option<TypeAnnotation>,
+    pub TypeScheme,
+    pub Markup,
+);
+
 #[derive(Debug, Clone, PartialEq)]
 pub enum Statement {
     Expression(Expression),
-    DefineVariable(
-        String,
-        Vec<Decorator>,
-        Expression,
-        Option<TypeAnnotation>,
-        TypeScheme,
-        Markup,
-    ),
+    DefineVariable(DefineVariable),
     DefineFunction(
         String,
         Vec<Decorator>,                            // decorators
@@ -590,6 +593,7 @@ pub enum Statement {
             Markup,                 // readable parameter type
         )>,
         Option<Expression>,     // function body
+        Vec<DefineVariable>,    // local variables
         TypeScheme,             // function type
         Option<TypeAnnotation>, // return type annotation
         Markup,                 // readable return type
@@ -636,7 +640,14 @@ impl Statement {
     pub(crate) fn update_readable_types(&mut self, registry: &DimensionRegistry) {
         match self {
             Statement::Expression(_) => {}
-            Statement::DefineVariable(_, _, _, type_annotation, type_, readable_type) => {
+            Statement::DefineVariable(DefineVariable(
+                _,
+                _,
+                _,
+                type_annotation,
+                type_,
+                readable_type,
+            )) => {
                 *readable_type = Self::create_readable_type(registry, type_, type_annotation);
             }
             Statement::DefineFunction(
@@ -645,6 +656,7 @@ impl Statement {
                 type_parameters,
                 parameters,
                 _,
+                local_variables,
                 fn_type,
                 return_type_annotation,
                 readable_return_type,
@@ -653,6 +665,12 @@ impl Statement {
                     type_parameters.iter().map(|(n, _)| n.clone()).collect(),
                 ));
 
+                for DefineVariable(_, _, _, type_annotation, type_, readable_type) in
+                    local_variables
+                {
+                    *readable_type = Self::create_readable_type(registry, type_, type_annotation);
+                }
+
                 let Type::Fn(parameter_types, return_type) = fn_type.inner else {
                     unreachable!("Expected a function type")
                 };
@@ -894,14 +912,14 @@ pub fn pretty_print_function_signature(
 impl PrettyPrint for Statement {
     fn pretty_print(&self) -> Markup {
         match self {
-            Statement::DefineVariable(
+            Statement::DefineVariable(DefineVariable(
                 identifier,
                 _decs,
                 expr,
                 _annotation,
                 _type,
                 readable_type,
-            ) => {
+            )) => {
                 m::keyword("let")
                     + m::space()
                     + m::identifier(identifier)
@@ -919,6 +937,7 @@ impl PrettyPrint for Statement {
                 type_parameters,
                 parameters,
                 body,
+                local_variables,
                 fn_type,
                 _return_type_annotation,
                 readable_return_type,
@@ -927,6 +946,31 @@ impl PrettyPrint for Statement {
                     type_parameters.iter().map(|(n, _)| n.clone()).collect(),
                 ));
 
+                let mut pretty_local_variables = None;
+                if !local_variables.is_empty() {
+                    let mut plv = m::keyword("where");
+                    for DefineVariable(
+                        identifier,
+                        _decs,
+                        expr,
+                        _annotation,
+                        _type,
+                        readable_type,
+                    ) in local_variables
+                    {
+                        plv += m::space()
+                            + m::identifier(identifier)
+                            + m::operator(":")
+                            + m::space()
+                            + readable_type.clone()
+                            + m::space()
+                            + m::operator("=")
+                            + m::space()
+                            + expr.pretty_print();
+                    }
+                    pretty_local_variables = Some(plv);
+                }
+
                 pretty_print_function_signature(
                     function_name,
                     &fn_type,
@@ -939,6 +983,7 @@ impl PrettyPrint for Statement {
                     .as_ref()
                     .map(|e| m::space() + m::operator("=") + m::space() + e.pretty_print())
                     .unwrap_or_default()
+                    + pretty_local_variables.unwrap_or_default()
             }
             Statement::Expression(expr) => expr.pretty_print(),
             Statement::DefineDimension(identifier, dexprs) if dexprs.is_empty() => {