Browse Source

Consolidated duplicated code in `TypeChecker::elaborate_define_variable` and the `ast::Statement::DefineDerivedUnit` branch of `TypeChecker::elaborate_statement`

Robert Bennett 1 year ago
parent
commit
da09448743
1 changed files with 74 additions and 72 deletions
  1. 74 72
      numbat/src/typechecker/mod.rs

+ 74 - 72
numbat/src/typechecker/mod.rs

@@ -62,6 +62,18 @@ pub struct TypeChecker {
     constraints: ConstraintSet,
 }
 
+struct ElaborationDefinitionArgs<'a> {
+    identifier_span: Span,
+    expr: &'a ast::Expression,
+    type_annotation_span: Option<Span>,
+    type_annotation: Option<&'a TypeAnnotation>,
+    operation: &'a str,
+    expected_name: &'static str,
+    actual_name: &'static str,
+    actual_name_for_fix: &'static str,
+    elaboration_kind: &'a str,
+}
+
 impl TypeChecker {
     fn fresh_type_variable(&mut self) -> Type {
         Type::TVar(self.name_generator.fresh_type_variable())
@@ -1068,24 +1080,31 @@ impl TypeChecker {
         })
     }
 
-    fn elaborate_define_variable(
+    fn _elaborate_inner(
         &mut self,
-        define_variable: &ast::DefineVariable,
-    ) -> Result<typed_ast::DefineVariable> {
-        let DefineVariable {
+        definition: ElaborationDefinitionArgs,
+    ) -> Result<(typed_ast::Expression, typed_ast::Type)> {
+        let ElaborationDefinitionArgs {
             identifier_span,
-            identifier,
             expr,
+            type_annotation_span,
             type_annotation,
-            decorators,
-        } = define_variable;
+            operation,
+            expected_name,
+            actual_name,
+            actual_name_for_fix,
+            elaboration_kind,
+        } = definition;
 
         let expr_checked = self.elaborate_expression(expr)?;
         let type_deduced = expr_checked.get_type();
 
-        if let Some(ref type_annotation) = type_annotation {
+        if let Some(type_annotation) = type_annotation {
             let type_annotated = self.type_from_annotation(type_annotation)?;
 
+            let type_annotation_span =
+                type_annotation_span.unwrap_or_else(|| type_annotation.full_span());
+
             match (&type_deduced, &type_annotated) {
                 (Type::Dimension(dexpr_deduced), Type::Dimension(dexpr_specified))
                     if type_deduced.is_closed() && type_annotated.is_closed() =>
@@ -1093,17 +1112,17 @@ impl TypeChecker {
                     if dexpr_deduced != dexpr_specified {
                         return Err(TypeCheckError::IncompatibleDimensions(
                             IncompatibleDimensionsError {
-                                span_operation: *identifier_span,
-                                operation: "variable definition".into(),
-                                span_expected: type_annotation.full_span(),
-                                expected_name: "specified dimension",
+                                span_operation: identifier_span,
+                                operation: operation.into(),
+                                span_expected: type_annotation_span,
+                                expected_name,
                                 expected_dimensions: self.registry.get_derived_entry_names_for(
                                     &dexpr_specified.to_base_representation(),
                                 ),
                                 expected_type: dexpr_specified.to_base_representation(),
                                 span_actual: expr.full_span(),
-                                actual_name: "   actual dimension",
-                                actual_name_for_fix: "right hand side expression",
+                                actual_name,
+                                actual_name_for_fix,
                                 actual_dimensions: self.registry.get_derived_entry_names_for(
                                     &dexpr_deduced.to_base_representation(),
                                 ),
@@ -1118,8 +1137,8 @@ impl TypeChecker {
                         .is_trivially_violated()
                     {
                         return Err(TypeCheckError::IncompatibleTypesInAnnotation(
-                            "definition".into(),
-                            *identifier_span,
+                            elaboration_kind.into(),
+                            identifier_span,
                             annotated.clone(),
                             type_annotation.full_span(),
                             deduced.clone(),
@@ -1130,6 +1149,33 @@ impl TypeChecker {
             }
         }
 
+        Ok((expr_checked, type_deduced))
+    }
+
+    fn elaborate_define_variable(
+        &mut self,
+        define_variable: &ast::DefineVariable,
+    ) -> Result<typed_ast::DefineVariable> {
+        let DefineVariable {
+            identifier_span,
+            identifier,
+            expr,
+            type_annotation,
+            decorators,
+        } = define_variable;
+
+        let (expr_checked, type_deduced) = self._elaborate_inner(ElaborationDefinitionArgs {
+            identifier_span: *identifier_span,
+            expr,
+            type_annotation_span: None,
+            type_annotation: type_annotation.as_ref(),
+            operation: "variable definition",
+            expected_name: "specified dimension",
+            actual_name: "   actual dimension",
+            actual_name_for_fix: "right hand side expression",
+            elaboration_kind: "definition",
+        })?;
+
         for (name, _) in decorator::name_and_aliases(identifier, decorators) {
             self.env
                 .add(name.clone(), type_deduced.clone(), *identifier_span, false);
@@ -1218,62 +1264,18 @@ impl TypeChecker {
                 type_annotation,
                 decorators,
             } => {
-                // TODO: this is the *exact same code* that we have above for
-                // variable definitions => deduplicate this somehow
-
-                let expr_checked = self.elaborate_expression(expr)?;
-                let type_deduced = expr_checked.get_type();
-
-                if let Some(ref type_annotation) = type_annotation {
-                    let type_annotated = self.type_from_annotation(type_annotation)?;
-
-                    match (&type_deduced, &type_annotated) {
-                        (Type::Dimension(dexpr_deduced), Type::Dimension(dexpr_specified))
-                            if type_deduced.is_closed() && type_annotated.is_closed() =>
-                        {
-                            if dexpr_deduced != dexpr_specified {
-                                return Err(TypeCheckError::IncompatibleDimensions(
-                                    IncompatibleDimensionsError {
-                                        span_operation: *identifier_span,
-                                        operation: "unit definition".into(),
-                                        span_expected: type_annotation_span.unwrap(),
-                                        expected_name: "specified dimension",
-                                        expected_dimensions: self
-                                            .registry
-                                            .get_derived_entry_names_for(
-                                                &dexpr_specified.to_base_representation(),
-                                            ),
-                                        expected_type: dexpr_specified.to_base_representation(),
-                                        span_actual: expr.full_span(),
-                                        actual_name: "   actual dimension",
-                                        actual_name_for_fix: "right hand side expression",
-                                        actual_dimensions: self
-                                            .registry
-                                            .get_derived_entry_names_for(
-                                                &dexpr_deduced.to_base_representation(),
-                                            ),
-                                        actual_type: dexpr_deduced.to_base_representation(),
-                                    },
-                                ));
-                            }
-                        }
-                        (deduced, annotated) => {
-                            if self
-                                .add_equal_constraint(deduced, annotated)
-                                .is_trivially_violated()
-                            {
-                                return Err(TypeCheckError::IncompatibleTypesInAnnotation(
-                                    "unit definition".into(),
-                                    *identifier_span,
-                                    annotated.clone(),
-                                    type_annotation.full_span(),
-                                    deduced.clone(),
-                                    expr_checked.full_span(),
-                                ));
-                            }
-                        }
-                    }
-                }
+                let (expr_checked, type_deduced) =
+                    self._elaborate_inner(ElaborationDefinitionArgs {
+                        identifier_span: *identifier_span,
+                        expr,
+                        type_annotation_span: type_annotation_span.as_ref().copied(),
+                        type_annotation: type_annotation.as_ref(),
+                        operation: "unit definition",
+                        expected_name: "specified dimension",
+                        actual_name: "   actual dimension",
+                        actual_name_for_fix: "right hand side expression",
+                        elaboration_kind: "unit definition",
+                    })?;
 
                 for (name, _) in decorator::name_and_aliases(identifier, decorators) {
                     self.env