浏览代码

Add new dtype helper

David Peter 1 年之前
父节点
当前提交
4bb29f6d82
共有 1 个文件被更改,包括 21 次插入46 次删除
  1. 21 46
      numbat/src/typechecker/mod.rs

+ 21 - 46
numbat/src/typechecker/mod.rs

@@ -74,6 +74,18 @@ impl TypeChecker {
         self.constraints.add(Constraint::IsDType(type_.clone()))
     }
 
+    fn enforce_dtype(&mut self, type_: &Type, span: Span) -> Result<()> {
+        if self
+            .constraints
+            .add(Constraint::IsDType(type_.clone()))
+            .is_trivially_violated()
+        {
+            return Err(TypeCheckError::ExpectedDimensionType(span, type_.clone()));
+        }
+
+        Ok(())
+    }
+
     fn type_from_annotation(&self, annotation: &TypeAnnotation) -> Result<Type> {
         match annotation {
             TypeAnnotation::TypeExpression(dexpr) => {
@@ -337,12 +349,7 @@ impl TypeChecker {
                         }
                     }
                     ast::UnaryOperator::Negate => {
-                        if self.add_dtype_constraint(&type_).is_trivially_violated() {
-                            return Err(TypeCheckError::ExpectedDimensionType(
-                                expr.full_span(),
-                                type_,
-                            ));
-                        }
+                        self.enforce_dtype(&type_, expr.full_span())?;
                     }
                     ast::UnaryOperator::LogicalNeg => {
                         if self
@@ -524,8 +531,8 @@ impl TypeChecker {
                             let type_lhs = lhs_checked.get_type();
                             let type_rhs = rhs_checked.get_type();
 
-                            self.add_dtype_constraint(&type_lhs).ok(); // TODO: here we can fail immediately, if this constraint is trivially violated
-                            self.add_dtype_constraint(&type_rhs).ok();
+                            self.enforce_dtype(&type_lhs, lhs_checked.full_span())?;
+                            self.enforce_dtype(&type_rhs, rhs_checked.full_span())?;
 
                             // We first introduce a fresh type variable for the result
                             let tv_result = self.name_generator.fresh_type_variable();
@@ -591,46 +598,11 @@ impl TypeChecker {
                             type_result
                         }
                         typed_ast::BinaryOperator::Power => {
-                            // let exponent_type = dtype(&rhs_checked)?;
-                            // if !exponent_type.is_scalar() {
-                            //     return Err(TypeCheckError::NonScalarExponent(
-                            //         rhs.full_span(),
-                            //         Type::Dimension(exponent_type), // TODO
-                            //     ));
-                            // }
-
-                            // let base_type = dtype(&lhs_checked)?;
-                            // if base_type.is_scalar() {
-                            //     // Skip evaluating the exponent if the lhs is a scalar. This allows
-                            //     // for arbitrary (decimal) exponents, if the base is a scalar.
-
-                            //     Type::Dimension(base_type)
-                            // } else {
-                            //     let exponent = evaluate_const_expr(&rhs_checked)?;
-                            //     Type::Dimension(base_type.power(exponent))
-                            // }
-
                             let type_base_inferred = lhs_type;
                             let type_exponent_inferred = rhs_type;
 
-                            if self
-                                .add_dtype_constraint(&type_base_inferred)
-                                .is_trivially_violated()
-                            {
-                                return Err(TypeCheckError::ExpectedDimensionType(
-                                    lhs.full_span(),
-                                    type_base_inferred,
-                                ));
-                            }
-                            if self
-                                .add_dtype_constraint(&type_exponent_inferred)
-                                .is_trivially_violated()
-                            {
-                                return Err(TypeCheckError::ExpectedDimensionType(
-                                    rhs.full_span(),
-                                    type_exponent_inferred,
-                                ));
-                            }
+                            self.enforce_dtype(&type_base_inferred, lhs.full_span())?;
+                            self.enforce_dtype(&type_exponent_inferred, rhs.full_span())?;
 
                             match type_base_inferred {
                                 Type::Dimension(base_dtype) if base_dtype.is_scalar() => {
@@ -1577,7 +1549,10 @@ impl TypeChecker {
                         // no argument type checks required, everything can be printed
                     }
                     ProcedureKind::Assert => {
-                        if checked_args[0].get_type() != Type::Boolean {
+                        if self
+                            .add_equal_constraint(&checked_args[0].get_type(), &Type::Boolean)
+                            .is_trivially_violated()
+                        {
                             return Err(TypeCheckError::IncompatibleTypeInAssert(
                                 *span,
                                 checked_args[0].get_type(),