Browse Source

Remove Never, initial constraints

David Peter 1 year ago
parent
commit
6264040d6c

+ 1 - 1
numbat/modules/core/error.nbt

@@ -1,2 +1,2 @@
 @description("Throw a user-defined error.")
-fn error(message: String) -> !
+fn error(message: String) -> Scalar

+ 0 - 4
numbat/src/ast.rs

@@ -256,7 +256,6 @@ pub(crate) use struct_;
 
 #[derive(Debug, Clone, PartialEq)]
 pub enum TypeAnnotation {
-    Never(Span),
     TypeExpression(TypeExpression),
     Bool(Span),
     String(Span),
@@ -268,7 +267,6 @@ pub enum TypeAnnotation {
 impl TypeAnnotation {
     pub fn full_span(&self) -> Span {
         match self {
-            TypeAnnotation::Never(span) => *span,
             TypeAnnotation::TypeExpression(d) => d.full_span(),
             TypeAnnotation::Bool(span) => *span,
             TypeAnnotation::String(span) => *span,
@@ -282,7 +280,6 @@ impl TypeAnnotation {
 impl PrettyPrint for TypeAnnotation {
     fn pretty_print(&self) -> Markup {
         match self {
-            TypeAnnotation::Never(_) => m::type_identifier("!"),
             TypeAnnotation::TypeExpression(d) => d.pretty_print(),
             TypeAnnotation::Bool(_) => m::type_identifier("Bool"),
             TypeAnnotation::String(_) => m::type_identifier("String"),
@@ -440,7 +437,6 @@ pub trait ReplaceSpans {
 impl ReplaceSpans for TypeAnnotation {
     fn replace_spans(&self) -> Self {
         match self {
-            TypeAnnotation::Never(_) => TypeAnnotation::Never(Span::dummy()),
             TypeAnnotation::TypeExpression(d) => TypeAnnotation::TypeExpression(d.replace_spans()),
             TypeAnnotation::Bool(_) => TypeAnnotation::Bool(Span::dummy()),
             TypeAnnotation::String(_) => TypeAnnotation::String(Span::dummy()),

+ 6 - 8
numbat/src/bytecode_interpreter.rs

@@ -47,7 +47,7 @@ pub struct BytecodeInterpreter {
 impl BytecodeInterpreter {
     fn compile_expression(&mut self, expr: &Expression) -> Result<()> {
         match expr {
-            Expression::Scalar(_span, n) => {
+            Expression::Scalar(_span, n, _type) => {
                 let index = self.vm.add_constant(Constant::Scalar(n.to_f64()));
                 self.vm.add_op1(Op::LoadConstant, index);
             }
@@ -295,7 +295,7 @@ impl BytecodeInterpreter {
                 self.compile_expression_with_simplify(expr)?;
                 self.vm.add_op(Op::Return);
             }
-            Statement::DefineVariable(identifier, decorators, expr, _type_annotation, _type) => {
+            Statement::DefineVariable(identifier, decorators, expr, _type) => {
                 let current_depth = self.current_depth();
 
                 // For variables, we ignore the prefix info and only use the names
@@ -326,7 +326,6 @@ impl BytecodeInterpreter {
                 _type_parameters,
                 parameters,
                 Some(expr),
-                _return_type_annotation,
                 _return_type,
             ) => {
                 self.vm.begin_function(name);
@@ -357,7 +356,6 @@ impl BytecodeInterpreter {
                 _type_parameters,
                 parameters,
                 None,
-                _return_type_annotation,
                 _return_type,
             ) => {
                 // Declaring a foreign function does not generate any bytecode. But we register
@@ -380,7 +378,7 @@ impl BytecodeInterpreter {
                 // Declaring a dimension is like introducing a new type. The information
                 // is only relevant for the type checker. Nothing happens at run time.
             }
-            Statement::DefineBaseUnit(unit_name, decorators, readable_type, type_) => {
+            Statement::DefineBaseUnit(unit_name, decorators, type_) => {
                 let aliases = decorator::name_and_aliases(unit_name, decorators)
                     .map(|(name, ap)| (name.clone(), ap))
                     .collect();
@@ -391,7 +389,7 @@ impl BytecodeInterpreter {
                         unit_name,
                         UnitMetadata {
                             type_: type_.clone(),
-                            readable_type: readable_type.clone(),
+                            readable_type: crate::markup::empty(), // TODO: type_.to_readable_type(registry)
                             aliases,
                             name: decorator::name(decorators),
                             canonical_name: decorator::get_canonical_unit_name(
@@ -414,7 +412,7 @@ impl BytecodeInterpreter {
                         .insert(name.into(), constant_idx);
                 }
             }
-            Statement::DefineDerivedUnit(unit_name, expr, decorators, readable_type, type_) => {
+            Statement::DefineDerivedUnit(unit_name, expr, decorators, type_) => {
                 let aliases = decorator::name_and_aliases(unit_name, decorators)
                     .map(|(name, ap)| (name.clone(), ap))
                     .collect();
@@ -437,7 +435,7 @@ impl BytecodeInterpreter {
                     ),
                     UnitMetadata {
                         type_: type_.clone(),
-                        readable_type: readable_type.clone(),
+                        readable_type: crate::markup::empty(), // TODO
                         aliases,
                         name: decorator::name(decorators),
                         canonical_name: decorator::get_canonical_unit_name(unit_name, decorators),

+ 4 - 2
numbat/src/diagnostic.rs

@@ -259,7 +259,7 @@ impl ErrorDiagnostic for TypeCheckError {
                     .with_message(rhs_type.to_string()),
                 op_span
                     .diagnostic_label(LabelStyle::Primary)
-                    .with_message("Incompatible types comparison operator"),
+                    .with_message("Incompatible types in comparison operator"),
             ]),
             TypeCheckError::IncompatibleTypeInAssert(procedure_span, type_, type_span) => d
                 .with_labels(vec![
@@ -479,7 +479,9 @@ impl ErrorDiagnostic for TypeCheckError {
             TypeCheckError::NameResolutionError(inner) => {
                 return inner.diagnostics();
             }
-            TypeCheckError::ConstraintSolverError(_) => d.with_message(inner_error),
+            TypeCheckError::ConstraintSolverError(_) | TypeCheckError::SubstitutionError(_) => {
+                d.with_message(inner_error) // TODO
+            }
         };
         vec![d]
     }

+ 1 - 3
numbat/src/parser.rs

@@ -1421,9 +1421,7 @@ impl<'a> Parser<'a> {
     }
 
     fn type_annotation(&mut self) -> Result<TypeAnnotation> {
-        if let Some(token) = self.match_exact(TokenKind::ExclamationMark) {
-            Ok(TypeAnnotation::Never(token.span))
-        } else if let Some(token) = self.match_exact(TokenKind::Bool) {
+        if let Some(token) = self.match_exact(TokenKind::Bool) {
             Ok(TypeAnnotation::Bool(token.span))
         } else if let Some(token) = self.match_exact(TokenKind::String) {
             Ok(TypeAnnotation::String(token.span))

+ 21 - 9
numbat/src/quantity.rs

@@ -59,7 +59,7 @@ impl Quantity {
     }
 
     pub fn convert_to(&self, target_unit: &Unit) -> Result<Quantity> {
-        if &self.unit == target_unit {
+        if &self.unit == target_unit || self.unsafe_value().to_f64().is_zero() {
             Ok(Quantity::new(self.value, target_unit.clone()))
         } else {
             // Remove common unit factors to reduce unnecessary conversion procedures
@@ -246,10 +246,16 @@ impl std::ops::Add for &Quantity {
     type Output = Result<Quantity>;
 
     fn add(self, rhs: Self) -> Self::Output {
-        Ok(Quantity {
-            value: self.value + rhs.convert_to(&self.unit)?.value,
-            unit: self.unit.clone(),
-        })
+        if self.is_zero() {
+            Ok(rhs.clone())
+        } else if rhs.is_zero() {
+            Ok(self.clone())
+        } else {
+            Ok(Quantity {
+                value: self.value + rhs.convert_to(&self.unit)?.value,
+                unit: self.unit.clone(),
+            })
+        }
     }
 }
 
@@ -257,10 +263,16 @@ impl std::ops::Sub for &Quantity {
     type Output = Result<Quantity>;
 
     fn sub(self, rhs: Self) -> Self::Output {
-        Ok(Quantity {
-            value: self.value - rhs.convert_to(&self.unit)?.value,
-            unit: self.unit.clone(),
-        })
+        if self.is_zero() {
+            Ok(-rhs.clone())
+        } else if rhs.is_zero() {
+            Ok(self.clone())
+        } else {
+            Ok(Quantity {
+                value: self.value - rhs.convert_to(&self.unit)?.value,
+                unit: self.unit.clone(),
+            })
+        }
     }
 }
 

+ 121 - 128
numbat/src/typechecker/constraints.rs

@@ -171,137 +171,131 @@ impl ApplySubstitution for ConstraintSet {
 pub enum Constraint {
     Equal(Type, Type),
     IsDType(Type),
-    EqualScalar(DType),
+    // EqualScalar(DType),
 }
 
 impl Constraint {
     /// Try to solve a constraint. Returns `None` if the constaint can not (yet) be solved.
     fn try_satisfy(&self) -> Option<Satisfied> {
         match self {
-            _ => None, // …
-
-                       // Constraint::Equal(t1, t2) if t1 == t2 => {
-                       //     println!(
-                       //         "  (1) SOLVING: {} ~ {} trivially ",
-                       //         t1.pretty_print(),
-                       //         t2.pretty_print()
-                       //     );
-                       //     Some(Satisfied::trivially())
-                       // }
-                       // Constraint::Equal(Type::TVar(x), t) if !t.contains(x) => {
-                       //     println!(
-                       //         "  (2) SOLVING: {x} ~ {t} with substitution {x} := {t}",
-                       //         x = x,
-                       //         t = t.pretty_print()
-                       //     );
-                       //     Some(Satisfied::with_substitution(Substitution::single(
-                       //         x.clone(),
-                       //         t.clone(),
-                       //     )))
-                       // }
-                       // Constraint::Equal(s, Type::TVar(x)) if !s.contains(x) => {
-                       //     println!(
-                       //         "  (3) SOLVING: {s} ~ {x} with substitution {x} := {s}",
-                       //         s = s.pretty_print(),
-                       //         x = x
-                       //     );
-                       //     Some(Satisfied::with_substitution(Substitution::single(
-                       //         x.clone(),
-                       //         s.clone(),
-                       //     )))
-                       // }
-                       // Constraint::Equal(t @ Type::TArr(s1, s2), s @ Type::TArr(t1, t2)) => {
-                       //     println!(
-                       //         "  (4) SOLVING: {t} ~ {s} with new constraints {s1} ~ {t1} and {s2} ~ {t2}",
-                       //         t = t.pretty_print(),
-                       //         s = s.pretty_print(),
-                       //         s1 = s1.pretty_print(),
-                       //         s2 = s2.pretty_print(),
-                       //         t1 = t1.pretty_print(),
-                       //         t2 = t2.pretty_print()
-                       //     );
-                       //     Some(Satisfied::with_new_constraints(vec![
-                       //         Constraint::Equal(s1.as_ref().clone(), t1.as_ref().clone()),
-                       //         Constraint::Equal(s2.as_ref().clone(), t2.as_ref().clone()),
-                       //     ]))
-                       // }
-                       // Constraint::Equal(s @ Type::List(s1), t @ Type::List(t1)) => {
-                       //     println!(
-                       //         "  (5) SOLVING: {s} ~ {t} with new constraint {s1} ~ {t1}",
-                       //         s = s.pretty_print(),
-                       //         t = t.pretty_print(),
-                       //         s1 = s1.pretty_print(),
-                       //         t1 = t1.pretty_print()
-                       //     );
-                       //     Some(Satisfied::with_new_constraints(vec![Constraint::Equal(
-                       //         s1.as_ref().clone(),
-                       //         t1.as_ref().clone(),
-                       //     )]))
-                       // }
-                       // Constraint::Equal(Type::TVar(tv), Type::DType(d))
-                       // | Constraint::Equal(Type::DType(d), Type::TVar(tv)) => {
-                       //     println!(
-                       //         "  (6) SOLVING: {tv} ~ {d} by lifting the type variable to a DType",
-                       //         tv = tv,
-                       //         d = d.pretty_print()
-                       //     );
-
-                       //     Some(Satisfied::with_new_constraints(vec![Constraint::Equal(
-                       //         Type::DType(DType::from_type_variable(tv.clone())),
-                       //         Type::DType(d.clone()),
-                       //     )]))
-                       // }
-                       // Constraint::Equal(Type::DType(d1), Type::DType(d2)) => {
-                       //     let d_result = d1.divide(d2);
-                       //     println!(
-                       //         "  (7) SOLVING: {} ~ {} with new constraint d_result = Scalar",
-                       //         d1.pretty_print(),
-                       //         d2.pretty_print()
-                       //     );
-                       //     Some(Satisfied::with_new_constraints(vec![
-                       //         Constraint::EqualScalar(d_result),
-                       //     ]))
-                       // }
-                       // Constraint::Equal(_, _) => None,
-                       // Constraint::IsDType(Type::DType(inner)) => {
-                       //     let new_constraints = inner
-                       //         .type_variables()
-                       //         .iter()
-                       //         .map(|tv| Constraint::IsDType(Type::TVar(tv.clone())))
-                       //         .collect();
-                       //     println!(
-                       //         "  (8) SOLVING: {} : DType through new constraints: {:?}",
-                       //         inner.pretty_print(),
-                       //         new_constraints
-                       //     );
-                       //     Some(Satisfied::with_new_constraints(new_constraints))
-                       // }
-                       // Constraint::IsDType(_) => None,
-                       // Constraint::EqualScalar(d) if d == &DType::scalar() => {
-                       //     println!("  (9) SOLVING: Scalar = Scalar trivially");
-                       //     Some(Satisfied::trivially())
-                       // }
-                       // Constraint::EqualScalar(dtype) => match dtype.split_first_factor() {
-                       //     Some(((DTypeFactor::TVar(tv), k), rest)) => {
-                       //         let result = DType::from_factors(
-                       //             &rest
-                       //                 .iter()
-                       //                 .map(|(f, j)| (f.clone(), -j / k))
-                       //                 .collect::<Vec<_>>(),
-                       //         );
-                       //         println!(
-                       //             "  (10) SOLVING: {dtype} = Scalar with substitution {tv} := {result}",
-                       //             dtype = dtype.pretty_print(),
-                       //             tv = tv,
-                       //             result = result.pretty_print()
-                       //         );
-                       //         Some(Satisfied::with_substitution(Substitution::single(
-                       //             tv.clone(),
-                       //             Type::DType(result),
-                       //         )))
-                       //     }
-                       //     _ => None,
-                       // },
+            Constraint::Equal(t1, t2) if t1 == t2 => {
+                println!("  (1) SOLVING: {} ~ {} trivially ", t1, t2);
+                Some(Satisfied::trivially())
+            }
+            Constraint::Equal(Type::TVar(x), t) if !t.contains(x) => {
+                println!(
+                    "  (2) SOLVING: {x} ~ {t} with substitution {x} := {t}",
+                    x = x.name(),
+                    t = t
+                );
+                Some(Satisfied::with_substitution(Substitution::single(
+                    x.clone(),
+                    t.clone(),
+                )))
+            }
+            Constraint::Equal(s, Type::TVar(x)) if !s.contains(x) => {
+                println!(
+                    "  (3) SOLVING: {s} ~ {x} with substitution {x} := {s}",
+                    s = s,
+                    x = x.name()
+                );
+                Some(Satisfied::with_substitution(Substitution::single(
+                    x.clone(),
+                    s.clone(),
+                )))
+            }
+            // Constraint::Equal(t @ Type::TArr(s1, s2), s @ Type::TArr(t1, t2)) => {
+            //     println!(
+            //         "  (4) SOLVING: {t} ~ {s} with new constraints {s1} ~ {t1} and {s2} ~ {t2}",
+            //         t = t,
+            //         s = s,
+            //         s1 = s1,
+            //         s2 = s2,
+            //         t1 = t1,
+            //         t2 = t2
+            //     );
+            //     Some(Satisfied::with_new_constraints(vec![
+            //         Constraint::Equal(s1.as_ref().clone(), t1.as_ref().clone()),
+            //         Constraint::Equal(s2.as_ref().clone(), t2.as_ref().clone()),
+            //     ]))
+            // }
+            Constraint::Equal(s @ Type::List(s1), t @ Type::List(t1)) => {
+                println!(
+                    "  (5) SOLVING: {s} ~ {t} with new constraint {s1} ~ {t1}",
+                    s = s,
+                    t = t,
+                    s1 = s1,
+                    t1 = t1
+                );
+                Some(Satisfied::with_new_constraints(vec![Constraint::Equal(
+                    s1.as_ref().clone(),
+                    t1.as_ref().clone(),
+                )]))
+            }
+            // Constraint::Equal(Type::TVar(tv), Type::Dimension(d))
+            // | Constraint::Equal(Type::Dimension(d), Type::TVar(tv)) => {
+            //     println!(
+            //         "  (6) SOLVING: {tv} ~ {d} by lifting the type variable to a DType",
+            //         tv = tv.name(),
+            //         d = d
+            //     );
+
+            //     Some(Satisfied::with_new_constraints(vec![Constraint::Equal(
+            //         Type::Dimension(DType::from_type_variable(tv.clone())),
+            //         Type::Dimension(d.clone()),
+            //     )]))
+            // }
+            // Constraint::Equal(Type::Dimension(d1), Type::Dimension(d2)) => {
+            //     let d_result = d1.divide(d2);
+            //     println!(
+            //         "  (7) SOLVING: {} ~ {} with new constraint d_result = Scalar",
+            //         d1.pretty_print(),
+            //         d2.pretty_print()
+            //     );
+            //     Some(Satisfied::with_new_constraints(vec![
+            //         Constraint::EqualScalar(d_result),
+            //     ]))
+            // }
+            Constraint::Equal(_, _) => None,
+            // Constraint::IsDType(Type::Dimension(inner)) => {
+            //     let new_constraints = inner
+            //         .type_variables()
+            //         .iter()
+            //         .map(|tv| Constraint::IsDType(Type::TVar(tv.clone())))
+            //         .collect();
+            //     println!(
+            //         "  (8) SOLVING: {} : DType through new constraints: {:?}",
+            //         inner.pretty_print(),
+            //         new_constraints
+            //     );
+            //     Some(Satisfied::with_new_constraints(new_constraints))
+            // }
+            // Constraint::IsDType(_) => None,
+            // Constraint::EqualScalar(d) if d == &DType::scalar() => {
+            //     println!("  (9) SOLVING: Scalar = Scalar trivially");
+            //     Some(Satisfied::trivially())
+            // }
+            // Constraint::EqualScalar(dtype) => match dtype.split_first_factor() {
+            //     Some(((DTypeFactor::TVar(tv), k), rest)) => {
+            //         let result = DType::from_factors(
+            //             &rest
+            //                 .iter()
+            //                 .map(|(f, j)| (f.clone(), -j / k))
+            //                 .collect::<Vec<_>>(),
+            //         );
+            //         println!(
+            //             "  (10) SOLVING: {dtype} = Scalar with substitution {tv} := {result}",
+            //             dtype = dtype.pretty_print(),
+            //             tv = tv,
+            //             result = result.pretty_print()
+            //         );
+            //         Some(Satisfied::with_substitution(Substitution::single(
+            //             tv.clone(),
+            //             Type::Dimension(result),
+            //         )))
+            //     }
+            _ => None,
+            // },
         }
     }
 
@@ -311,7 +305,7 @@ impl Constraint {
                 format!("  {} ~ {}", t1, t2)
             }
             Constraint::IsDType(t) => format!("  {}: DType", t),
-            Constraint::EqualScalar(d) => format!("  {} = Scalar", d),
+            // Constraint::EqualScalar(d) => format!("  {} = Scalar", d),
         }
     }
 
@@ -333,8 +327,7 @@ impl ApplySubstitution for Constraint {
             }
             Constraint::IsDType(t) => {
                 t.apply(substitution)?;
-            }
-            Constraint::EqualScalar(d) => d.apply(substitution)?,
+            } // Constraint::EqualScalar(d) => d.apply(substitution)?,
         }
         Ok(())
     }

+ 4 - 0
numbat/src/typechecker/error.rs

@@ -7,6 +7,7 @@ use crate::{NameResolutionError, Type};
 use thiserror::Error;
 
 use super::constraints::ConstraintSolverError;
+use super::substitutions::SubstitutionError;
 use super::IncompatibleDimensionsError;
 
 #[derive(Debug, Clone, Error, PartialEq, Eq)]
@@ -136,6 +137,9 @@ pub enum TypeCheckError {
 
     #[error(transparent)]
     ConstraintSolverError(#[from] ConstraintSolverError),
+
+    #[error(transparent)]
+    SubstitutionError(#[from] SubstitutionError),
 }
 
 pub type Result<T> = std::result::Result<T, TypeCheckError>;

+ 155 - 175
numbat/src/typechecker/mod.rs

@@ -23,12 +23,14 @@ use crate::span::Span;
 use crate::typed_ast::{self, DType, Expression, StructInfo, Type};
 use crate::{decorator, ffi, suggestion};
 
-use constraints::ConstraintSet;
+use constraints::{Constraint, ConstraintSet};
 use itertools::Itertools;
+use name_generator::NameGenerator;
 use num_traits::{CheckedAdd, CheckedDiv, CheckedMul, CheckedSub, FromPrimitive, Zero};
 
 pub use error::{Result, TypeCheckError};
 pub use incompatible_dimensions::IncompatibleDimensionsError;
+use substitutions::ApplySubstitution;
 
 fn to_rational_exponent(exponent_f64: f64) -> Option<Exponent> {
     Rational::from_f64(exponent_f64)
@@ -46,7 +48,7 @@ fn dtype(e: &Expression) -> Result<DType> {
 /// need to know not just the *type* but also the *value* of the exponent.
 fn evaluate_const_expr(expr: &typed_ast::Expression) -> Result<Exponent> {
     match expr {
-        typed_ast::Expression::Scalar(span, n) => {
+        typed_ast::Expression::Scalar(span, n, _type) => {
             Ok(to_rational_exponent(n.to_f64())
                 .ok_or(TypeCheckError::NonRationalExponent(*span))?)
         }
@@ -183,13 +185,17 @@ pub struct TypeChecker {
     type_namespace: Namespace,
     value_namespace: Namespace,
 
+    name_generator: NameGenerator,
     constraints: ConstraintSet,
 }
 
 impl TypeChecker {
+    fn fresh_type_variable(&mut self) -> crate::type_variable::TypeVariable {
+        self.name_generator.fresh_type_variable()
+    }
+
     fn type_from_annotation(&self, annotation: &TypeAnnotation) -> Result<Type> {
         match annotation {
-            TypeAnnotation::Never(_) => Ok(Type::Never),
             TypeAnnotation::TypeExpression(dexpr) => {
                 if let TypeExpression::TypeIdentifier(_, name) = dexpr {
                     if let Some(info) = self.structs.get(name) {
@@ -433,13 +439,13 @@ impl TypeChecker {
                     }
                 }
                 (parameter_type, argument_type) => {
-                    if !argument_type.is_subtype_of(parameter_type) {
-                        return Err(TypeCheckError::IncompatibleTypesInFunctionCall(
-                            Some(*parameter_span),
-                            parameter_type.clone(),
-                            arguments[idx].full_span(),
-                            argument_type.clone(),
-                        ));
+                    if &argument_type != parameter_type {
+                        // return Err(TypeCheckError::IncompatibleTypesInFunctionCall(
+                        //     Some(*parameter_span),
+                        //     parameter_type.clone(),
+                        //     arguments[idx].full_span(),
+                        //     argument_type.clone(),
+                        // ));
                     }
                 }
             }
@@ -513,9 +519,15 @@ impl TypeChecker {
         ))
     }
 
-    fn elaborate_expression(&self, ast: &ast::Expression) -> Result<typed_ast::Expression> {
+    fn elaborate_expression(&mut self, ast: &ast::Expression) -> Result<typed_ast::Expression> {
         Ok(match ast {
-            ast::Expression::Scalar(span, n) => typed_ast::Expression::Scalar(*span, *n),
+            ast::Expression::Scalar(span, n) if n.to_f64().is_zero() => {
+                let tv = self.fresh_type_variable();
+                typed_ast::Expression::Scalar(*span, *n, Type::TVar(tv))
+            }
+            ast::Expression::Scalar(span, n) => {
+                typed_ast::Expression::Scalar(*span, *n, Type::scalar())
+            }
             ast::Expression::Identifier(span, name) => {
                 let type_ = self.identifier_type(*span, name)?.clone();
 
@@ -536,7 +548,6 @@ impl TypeChecker {
                 let checked_expr = self.elaborate_expression(expr)?;
                 let type_ = checked_expr.get_type();
                 match (&type_, op) {
-                    (Type::Never, _) => {}
                     (Type::Dimension(dtype), ast::UnaryOperator::Factorial) => {
                         if !dtype.is_scalar() {
                             return Err(TypeCheckError::NonScalarFactorialArgument(
@@ -572,23 +583,7 @@ impl TypeChecker {
                 let lhs_type = lhs_checked.get_type();
                 let rhs_type = rhs_checked.get_type();
 
-                if lhs_type.is_never() {
-                    return Ok(typed_ast::Expression::BinaryOperator(
-                        *span_op,
-                        *op,
-                        Box::new(lhs_checked),
-                        Box::new(rhs_checked),
-                        rhs_type,
-                    ));
-                } else if rhs_type.is_never() {
-                    return Ok(typed_ast::Expression::BinaryOperator(
-                        *span_op,
-                        *op,
-                        Box::new(lhs_checked),
-                        Box::new(rhs_checked),
-                        lhs_type,
-                    ));
-                } else if rhs_type.is_fn_type() && op == &BinaryOperator::ConvertTo {
+                if rhs_type.is_fn_type() && op == &BinaryOperator::ConvertTo {
                     let (parameter_types, return_type) = match rhs_type {
                         Type::Fn(p, r) => (p, r),
                         _ => unreachable!(),
@@ -604,13 +599,13 @@ impl TypeChecker {
                         });
                     }
 
-                    if !parameter_types[0].is_subtype_of(&lhs_type) {
-                        return Err(TypeCheckError::IncompatibleTypesInFunctionCall(
-                            None,
-                            parameter_types[0].clone(),
-                            lhs.full_span(),
-                            lhs_type,
-                        ));
+                    if parameter_types[0] != lhs_type {
+                        // return Err(TypeCheckError::IncompatibleTypesInFunctionCall(
+                        //     None,
+                        //     parameter_types[0].clone(),
+                        //     lhs.full_span(),
+                        //     lhs_type,
+                        // ));
                     }
 
                     typed_ast::Expression::CallableCall(
@@ -675,58 +670,64 @@ impl TypeChecker {
                         ));
                     }
                 } else {
-                    let get_type_and_assert_equality = || {
-                        let lhs_type = dtype(&lhs_checked)?;
-                        let rhs_type = dtype(&rhs_checked)?;
-                        if lhs_type != rhs_type {
-                            let full_span = ast::Expression::BinaryOperator {
-                                op: *op,
-                                lhs: lhs.clone(),
-                                rhs: rhs.clone(),
-                                span_op: *span_op,
-                            }
-                            .full_span();
-                            Err(TypeCheckError::IncompatibleDimensions(
-                                IncompatibleDimensionsError {
-                                    span_operation: span_op.unwrap_or(full_span),
-                                    operation: match op {
-                                        typed_ast::BinaryOperator::Add => "addition".into(),
-                                        typed_ast::BinaryOperator::Sub => "subtraction".into(),
-                                        typed_ast::BinaryOperator::Mul => "multiplication".into(),
-                                        typed_ast::BinaryOperator::Div => "division".into(),
-                                        typed_ast::BinaryOperator::Power => "exponentiation".into(),
-                                        typed_ast::BinaryOperator::ConvertTo => {
-                                            "unit conversion".into()
-                                        }
-                                        typed_ast::BinaryOperator::LessThan
-                                        | typed_ast::BinaryOperator::GreaterThan
-                                        | typed_ast::BinaryOperator::LessOrEqual
-                                        | typed_ast::BinaryOperator::GreaterOrEqual
-                                        | typed_ast::BinaryOperator::Equal
-                                        | typed_ast::BinaryOperator::NotEqual => {
-                                            "comparison".into()
-                                        }
-                                        typed_ast::BinaryOperator::LogicalAnd => "and".into(),
-                                        typed_ast::BinaryOperator::LogicalOr => "or".into(),
-                                    },
-                                    span_expected: lhs.full_span(),
-                                    expected_name: " left hand side",
-                                    expected_dimensions: self
-                                        .registry
-                                        .get_derived_entry_names_for(&lhs_type),
-                                    expected_type: lhs_type,
-                                    span_actual: rhs.full_span(),
-                                    actual_name: "right hand side",
-                                    actual_name_for_fix: "expression on the right hand side",
-                                    actual_dimensions: self
-                                        .registry
-                                        .get_derived_entry_names_for(&rhs_type),
-                                    actual_type: rhs_type,
-                                },
-                            ))
-                        } else {
-                            Ok(Type::Dimension(lhs_type))
-                        }
+                    let mut get_type_and_assert_equality = || -> Result<Type> {
+                        let lhs_type = lhs_checked.get_type();
+                        let rhs_type = rhs_checked.get_type();
+
+                        self.constraints
+                            .add(Constraint::Equal(lhs_type.clone(), rhs_type));
+
+                        Ok(lhs_type)
+
+                        // if lhs_type != rhs_type {
+                        //     let full_span = ast::Expression::BinaryOperator {
+                        //         op: *op,
+                        //         lhs: lhs.clone(),
+                        //         rhs: rhs.clone(),
+                        //         span_op: *span_op,
+                        //     }
+                        //     .full_span();
+                        //     Err(TypeCheckError::IncompatibleDimensions(
+                        //         IncompatibleDimensionsError {
+                        //             span_operation: span_op.unwrap_or(full_span),
+                        //             operation: match op {
+                        //                 typed_ast::BinaryOperator::Add => "addition".into(),
+                        //                 typed_ast::BinaryOperator::Sub => "subtraction".into(),
+                        //                 typed_ast::BinaryOperator::Mul => "multiplication".into(),
+                        //                 typed_ast::BinaryOperator::Div => "division".into(),
+                        //                 typed_ast::BinaryOperator::Power => "exponentiation".into(),
+                        //                 typed_ast::BinaryOperator::ConvertTo => {
+                        //                     "unit conversion".into()
+                        //                 }
+                        //                 typed_ast::BinaryOperator::LessThan
+                        //                 | typed_ast::BinaryOperator::GreaterThan
+                        //                 | typed_ast::BinaryOperator::LessOrEqual
+                        //                 | typed_ast::BinaryOperator::GreaterOrEqual
+                        //                 | typed_ast::BinaryOperator::Equal
+                        //                 | typed_ast::BinaryOperator::NotEqual => {
+                        //                     "comparison".into()
+                        //                 }
+                        //                 typed_ast::BinaryOperator::LogicalAnd => "and".into(),
+                        //                 typed_ast::BinaryOperator::LogicalOr => "or".into(),
+                        //             },
+                        //             span_expected: lhs.full_span(),
+                        //             expected_name: " left hand side",
+                        //             expected_dimensions: self
+                        //                 .registry
+                        //                 .get_derived_entry_names_for(&lhs_type),
+                        //             expected_type: lhs_type,
+                        //             span_actual: rhs.full_span(),
+                        //             actual_name: "right hand side",
+                        //             actual_name_for_fix: "expression on the right hand side",
+                        //             actual_dimensions: self
+                        //                 .registry
+                        //                 .get_derived_entry_names_for(&rhs_type),
+                        //             actual_type: rhs_type,
+                        //         },
+                        //     ))
+                        // } else {
+
+                        // }
                     };
 
                     let type_ = match op {
@@ -773,13 +774,13 @@ impl TypeChecker {
                                 || lhs_type.is_fn_type()
                                 || rhs_type.is_fn_type()
                             {
-                                return Err(TypeCheckError::IncompatibleTypesInComparison(
-                                    span_op.unwrap(),
-                                    lhs_type,
-                                    lhs.full_span(),
-                                    rhs_type,
-                                    rhs.full_span(),
-                                ));
+                                // return Err(TypeCheckError::IncompatibleTypesInComparison(
+                                //     span_op.unwrap(),
+                                //     lhs_type,
+                                //     lhs.full_span(),
+                                //     rhs_type,
+                                //     rhs.full_span(),
+                                // ));
                             }
 
                             Type::Boolean
@@ -850,13 +851,13 @@ impl TypeChecker {
                             for (param_type, arg_checked) in
                                 parameters_types.iter().zip(&arguments_checked)
                             {
-                                if !arg_checked.get_type().is_subtype_of(param_type) {
-                                    return Err(TypeCheckError::IncompatibleTypesInFunctionCall(
-                                        None,
-                                        param_type.clone(),
-                                        arg_checked.full_span(),
-                                        arg_checked.get_type(),
-                                    ));
+                                if &arg_checked.get_type() != param_type {
+                                    // return Err(TypeCheckError::IncompatibleTypesInFunctionCall(
+                                    //     None,
+                                    //     param_type.clone(),
+                                    //     arg_checked.full_span(),
+                                    //     arg_checked.get_type(),
+                                    // ));
                                 }
                             }
 
@@ -896,9 +897,9 @@ impl TypeChecker {
             ),
             ast::Expression::Condition(span, condition, then, else_) => {
                 let condition = self.elaborate_expression(condition)?;
-                if condition.get_type() != Type::Boolean {
-                    return Err(TypeCheckError::ExpectedBool(condition.full_span()));
-                }
+                // if condition.get_type() != Type::Boolean {
+                //     return Err(TypeCheckError::ExpectedBool(condition.full_span()));
+                // }
 
                 let then = self.elaborate_expression(then)?;
                 let else_ = self.elaborate_expression(else_)?;
@@ -906,24 +907,15 @@ impl TypeChecker {
                 let then_type = then.get_type();
                 let else_type = else_.get_type();
 
-                if then_type.is_never() || else_type.is_never() {
-                    // This case is fine. We use the type of the *other* branch in those cases.
-                    // For example:
-                    //
-                    //   if <some precondition>
-                    //     then X
-                    //     else error("please make sure <some precondition> is met")
-                    //
-                    // Here, we simply use the type of `X` as the type of the whole expression.
-                } else if then_type != else_type {
-                    return Err(TypeCheckError::IncompatibleTypesInCondition(
-                        *span,
-                        then_type,
-                        then.full_span(),
-                        else_type,
-                        else_.full_span(),
-                    ));
-                }
+                // if then_type != else_type {
+                //     return Err(TypeCheckError::IncompatibleTypesInCondition(
+                //         *span,
+                //         then_type,
+                //         then.full_span(),
+                //         else_type,
+                //         else_.full_span(),
+                //     ));
+                // }
 
                 typed_ast::Expression::Condition(
                     *span,
@@ -971,7 +963,7 @@ impl TypeChecker {
                     };
 
                     let found_type = &expr.get_type();
-                    if !found_type.is_subtype_of(expected_type) {
+                    if found_type != expected_type {
                         return Err(TypeCheckError::IncompatibleTypesForStructField(
                             *expected_field_span,
                             expected_type.clone(),
@@ -1047,20 +1039,21 @@ impl TypeChecker {
                     elements_checked.iter().map(|e| e.get_type()).collect();
 
                 let element_type = if element_types.is_empty() {
-                    todo!()
+                    let tv = self.fresh_type_variable();
+                    Type::List(Box::new(Type::TVar(tv)))
                 } else {
                     let type_of_first_element = element_types[0].clone();
                     for (subsequent_element, type_of_subsequent_element) in
                         elements_checked.iter().zip(element_types.iter()).skip(1)
                     {
-                        if type_of_first_element != *type_of_subsequent_element {
-                            return Err(TypeCheckError::IncompatibleTypesInList(
-                                elements_checked[0].full_span(),
-                                type_of_first_element.clone(),
-                                subsequent_element.full_span(),
-                                type_of_subsequent_element.clone(),
-                            ));
-                        }
+                        // if type_of_first_element != *type_of_subsequent_element {
+                        //     return Err(TypeCheckError::IncompatibleTypesInList(
+                        //         elements_checked[0].full_span(),
+                        //         type_of_first_element.clone(),
+                        //         subsequent_element.full_span(),
+                        //         type_of_subsequent_element.clone(),
+                        //     ));
+                        // }
                     }
                     type_of_first_element
                 };
@@ -1118,15 +1111,15 @@ impl TypeChecker {
                             }
                         }
                         (deduced, annotated) => {
-                            if !deduced.is_subtype_of(&annotated) {
-                                return Err(TypeCheckError::IncompatibleTypesInAnnotation(
-                                    "definition".into(),
-                                    *identifier_span,
-                                    annotated,
-                                    type_annotation.full_span(),
-                                    deduced.clone(),
-                                    expr_checked.full_span(),
-                                ));
+                            if deduced != &annotated {
+                                // return Err(TypeCheckError::IncompatibleTypesInAnnotation(
+                                //     "definition".into(),
+                                //     *identifier_span,
+                                //     annotated,
+                                //     type_annotation.full_span(),
+                                //     deduced.clone(),
+                                //     expr_checked.full_span(),
+                                // ));
                             }
                         }
                     }
@@ -1147,10 +1140,6 @@ impl TypeChecker {
                     identifier.clone(),
                     decorators.clone(),
                     expr_checked,
-                    type_annotation
-                        .as_ref()
-                        .map(|d| d.pretty_print())
-                        .unwrap_or_else(|| type_deduced.to_readable_type(&self.registry)),
                     type_deduced,
                 )
             }
@@ -1187,10 +1176,6 @@ impl TypeChecker {
                 typed_ast::Statement::DefineBaseUnit(
                     unit_name.clone(),
                     decorators.clone(),
-                    type_annotation
-                        .as_ref()
-                        .map(|d| d.pretty_print())
-                        .unwrap_or_else(|| type_specified.to_readable_type(&self.registry)),
                     Type::Dimension(type_specified),
                 )
             }
@@ -1247,10 +1232,6 @@ impl TypeChecker {
                     identifier.clone(),
                     expr_checked,
                     decorators.clone(),
-                    type_annotation
-                        .as_ref()
-                        .map(|d| d.pretty_print())
-                        .unwrap_or_else(|| type_deduced.to_readable_type(&self.registry)),
                     Type::Dimension(type_deduced),
                 )
             }
@@ -1345,10 +1326,6 @@ impl TypeChecker {
                         *parameter_span,
                         parameter.clone(),
                         *p_is_variadic,
-                        type_annotation
-                            .as_ref()
-                            .map(|d| d.pretty_print())
-                            .unwrap_or_else(|| parameter_type.to_readable_type(&self.registry)),
                         parameter_type,
                     ));
 
@@ -1367,7 +1344,7 @@ impl TypeChecker {
                 let add_function_signature = |tc: &mut TypeChecker, return_type: Type| {
                     let parameter_types = typed_parameters
                         .iter()
-                        .map(|(span, name, _, _, t)| (*span, name.clone(), t.clone()))
+                        .map(|(span, name, _, t)| (*span, name.clone(), t.clone()))
                         .collect();
                     tc.functions.insert(
                         function_name.clone(),
@@ -1435,16 +1412,16 @@ impl TypeChecker {
                                 Type::Dimension(dtype_deduced)
                             }
                             (type_deduced, type_specified) => {
-                                if !type_deduced.is_subtype_of(&type_specified) {
-                                    return Err(TypeCheckError::IncompatibleTypesInAnnotation(
-                                        "function definition".into(),
-                                        *function_name_span,
-                                        type_specified,
-                                        return_type_annotation_span.unwrap(),
-                                        type_deduced,
-                                        body.as_ref().map(|b| b.full_span()).unwrap(),
-                                    ));
-                                }
+                                // if type_deduced != type_specified {
+                                //     return Err(TypeCheckError::IncompatibleTypesInAnnotation(
+                                //         "function definition".into(),
+                                //         *function_name_span,
+                                //         type_specified,
+                                //         return_type_annotation_span.unwrap(),
+                                //         type_deduced,
+                                //         body.as_ref().map(|b| b.full_span()).unwrap(),
+                                //     ));
+                                // }
                                 type_specified
                             }
                         }
@@ -1478,10 +1455,6 @@ impl TypeChecker {
                         .collect(),
                     typed_parameters,
                     body_checked,
-                    return_type_annotation
-                        .as_ref()
-                        .map(|d| d.pretty_print())
-                        .unwrap_or_else(|| return_type.to_readable_type(&self.registry)),
                     return_type,
                 )
             }
@@ -1656,10 +1629,17 @@ impl TypeChecker {
         }
 
         // Solve constraints
-        self.constraints
+        let (substitution, dtype_variables) = self
+            .constraints
             .solve()
             .map_err(TypeCheckError::ConstraintSolverError)?;
 
+        for statement in &mut statements_elaborated {
+            statement
+                .apply(&substitution)
+                .map_err(TypeCheckError::SubstitutionError)?;
+        }
+
         Ok(statements_elaborated)
     }
 

+ 1 - 1
numbat/src/typechecker/name_generator.rs

@@ -1,6 +1,6 @@
 use crate::type_variable::TypeVariable;
 
-#[derive(Clone)]
+#[derive(Clone, Default)]
 pub struct NameGenerator {
     counter: u64,
 }

+ 139 - 5
numbat/src/typechecker/substitutions.rs

@@ -1,7 +1,8 @@
 use thiserror::Error;
 
 use crate::type_variable::TypeVariable;
-use crate::typed_ast::{DType, Type};
+use crate::typed_ast::{DType, Expression, StructInfo, Type};
+use crate::Statement;
 
 #[derive(Debug, Clone)]
 pub struct Substitution(pub Vec<(TypeVariable, Type)>);
@@ -46,13 +47,146 @@ pub trait ApplySubstitution {
 }
 
 impl ApplySubstitution for Type {
-    fn apply(&mut self, _substitution: &Substitution) -> Result<(), SubstitutionError> {
-        Ok(()) // TODO
+    fn apply(&mut self, s: &Substitution) -> Result<(), SubstitutionError> {
+        match self {
+            Type::TVar(v) => {
+                if let Some(type_) = s.lookup(v) {
+                    *self = type_.clone();
+                }
+                Ok(())
+            }
+            Type::Dimension(dtype) => dtype.apply(s),
+            Type::Boolean => Ok(()),
+            Type::String => Ok(()),
+            Type::DateTime => Ok(()),
+            Type::Fn(param_types, return_type) => {
+                for param_type in param_types {
+                    param_type.apply(s)?;
+                }
+                return_type.apply(s)
+            }
+            Type::Struct(info) => {
+                for (_, field_type) in info.fields.values_mut() {
+                    field_type.apply(s)?;
+                }
+                Ok(())
+            }
+            Type::List(element_type) => element_type.apply(s),
+        }
     }
 }
 
 impl ApplySubstitution for DType {
-    fn apply(&mut self, _substitution: &Substitution) -> Result<(), SubstitutionError> {
-        Ok(()) // TODO
+    fn apply(&mut self, s: &Substitution) -> Result<(), SubstitutionError> {
+        // TODO
+        Ok(())
+    }
+}
+
+impl ApplySubstitution for StructInfo {
+    fn apply(&mut self, s: &Substitution) -> Result<(), SubstitutionError> {
+        for (_, field_type) in self.fields.values_mut() {
+            field_type.apply(s)?;
+        }
+        Ok(())
+    }
+}
+
+impl ApplySubstitution for Expression {
+    fn apply(&mut self, s: &Substitution) -> Result<(), SubstitutionError> {
+        match self {
+            Expression::Scalar(_, _, type_) => type_.apply(s),
+            Expression::Identifier(_, _, type_) => type_.apply(s),
+            Expression::UnitIdentifier(_, _, _, _, type_) => type_.apply(s),
+            Expression::UnaryOperator(_, _, expr, type_) => {
+                expr.apply(s)?;
+                type_.apply(s)
+            }
+            Expression::BinaryOperator(_, _, lhs, rhs, type_) => {
+                lhs.apply(s)?;
+                rhs.apply(s)?;
+                type_.apply(s)
+            }
+            Expression::BinaryOperatorForDate(_, _, lhs, rhs, type_) => {
+                lhs.apply(s)?;
+                rhs.apply(s)?;
+                type_.apply(s)
+            }
+            Expression::FunctionCall(_, _, _, arguments, return_type) => {
+                for arg in arguments {
+                    arg.apply(s)?;
+                }
+                return_type.apply(s)
+            }
+            Expression::CallableCall(_, callable, arguments, return_type) => {
+                callable.apply(s)?;
+                for arg in arguments {
+                    arg.apply(s)?;
+                }
+                return_type.apply(s)
+            }
+            Expression::Boolean(_, _) => Ok(()),
+            Expression::Condition(_, if_, then_, else_) => {
+                if_.apply(s)?;
+                then_.apply(s)?;
+                else_.apply(s)
+            }
+            Expression::String(_, _) => Ok(()),
+            Expression::InstantiateStruct(_, initializers, info) => {
+                for (_, expr) in initializers {
+                    expr.apply(s)?;
+                }
+                info.apply(s)
+            }
+            Expression::AccessField(_, _, instance, _, info, type_) => {
+                instance.apply(s)?;
+                info.apply(s)?;
+                type_.apply(s)
+            }
+            Expression::List(_, elements, element_type) => {
+                for element in elements {
+                    element.apply(s)?;
+                }
+                element_type.apply(s)
+            }
+        }
+    }
+}
+
+impl ApplySubstitution for Statement {
+    fn apply(&mut self, s: &Substitution) -> Result<(), SubstitutionError> {
+        match self {
+            Statement::Expression(e) => e.apply(s),
+            Statement::DefineVariable(_, _, e, type_) => {
+                e.apply(s)?;
+                type_.apply(s)
+            }
+            Statement::DefineFunction(_, _, _, parameters, body, return_type) => {
+                for (_, _, _, parameter_type) in parameters {
+                    parameter_type.apply(s)?;
+                }
+                if let Some(body) = body {
+                    body.apply(s)?;
+                }
+                return_type.apply(s)
+            }
+            Statement::DefineDimension(_, _) => Ok(()),
+            Statement::DefineBaseUnit(_, _, type_) => type_.apply(s),
+            Statement::DefineDerivedUnit(_, e, _, type_) => {
+                e.apply(s)?;
+                type_.apply(s)
+            }
+            Statement::ProcedureCall(_, args) => {
+                for arg in args {
+                    arg.apply(s)?;
+                }
+                Ok(())
+            }
+            Statement::DefineStruct(info) => {
+                info.apply(s)?;
+
+                Ok(())
+            }
+        }
     }
 }

+ 25 - 46
numbat/src/typed_ast.rs

@@ -66,7 +66,6 @@ pub struct StructInfo {
 #[derive(Debug, Clone, PartialEq, Eq)]
 pub enum Type {
     TVar(TypeVariable),
-    Never,
     Dimension(DType),
     Boolean,
     String,
@@ -80,7 +79,6 @@ impl std::fmt::Display for Type {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         match self {
             Type::TVar(v) => write!(f, "{}", v.name()),
-            Type::Never => write!(f, "!"),
             Type::Dimension(d) => d.fmt(f),
             Type::Boolean => write!(f, "Bool"),
             Type::String => write!(f, "String"),
@@ -111,7 +109,6 @@ impl PrettyPrint for Type {
     fn pretty_print(&self) -> Markup {
         match self {
             Type::TVar(v) => m::type_identifier(&v.name()),
-            Type::Never => m::type_identifier("!"),
             Type::Dimension(d) => d.pretty_print(),
             Type::Boolean => m::type_identifier("Bool"),
             Type::String => m::type_identifier("String"),
@@ -154,10 +151,6 @@ impl Type {
         Type::Dimension(DType::unity())
     }
 
-    pub fn is_never(&self) -> bool {
-        matches!(self, Type::Never)
-    }
-
     pub fn is_dtype(&self) -> bool {
         matches!(self, Type::Dimension(..))
     }
@@ -166,15 +159,6 @@ impl Type {
         matches!(self, Type::Fn(..))
     }
 
-    pub fn is_subtype_of(&self, other: &Type) -> bool {
-        match (self, other) {
-            (Type::Never, _) => true,
-            (_, Type::Never) => false,
-            (Type::List(el1), Type::List(el2)) => el1.is_subtype_of(el2),
-            (t1, t2) => t1 == t2,
-        }
-    }
-
     pub(crate) fn type_variables(&self) -> Vec<TypeVariable> {
         todo!()
     }
@@ -182,6 +166,10 @@ impl Type {
     pub(crate) fn instantiate(&self, type_variables: &[TypeVariable]) -> Type {
         todo!()
     }
+
+    pub(crate) fn contains(&self, x: &TypeVariable) -> bool {
+        false // TODO!
+    }
 }
 
 #[derive(Debug, Clone, PartialEq)]
@@ -225,7 +213,7 @@ impl PrettyPrint for &Vec<StringPart> {
 
 #[derive(Debug, Clone, PartialEq)]
 pub enum Expression {
-    Scalar(Span, Number),
+    Scalar(Span, Number, Type),
     Identifier(Span, String, Type),
     UnitIdentifier(Span, Prefix, String, String, Type),
     UnaryOperator(Span, UnaryOperator, Box<Expression>, Type),
@@ -296,7 +284,7 @@ impl Expression {
 #[derive(Debug, Clone, PartialEq)]
 pub enum Statement {
     Expression(Expression),
-    DefineVariable(String, Vec<Decorator>, Expression, Markup, Type),
+    DefineVariable(String, Vec<Decorator>, Expression, Type),
     DefineFunction(
         String,
         Vec<Decorator>, // decorators
@@ -306,16 +294,14 @@ pub enum Statement {
             Span,   // span of the parameter
             String, // parameter name
             bool,   // whether or not it is variadic
-            Markup, // readable parameter type
             Type,   // parameter type
         )>,
         Option<Expression>, // function body
-        Markup,             // readable return type
         Type,               // return type
     ),
     DefineDimension(String, Vec<TypeExpression>),
-    DefineBaseUnit(String, Vec<Decorator>, Markup, Type),
-    DefineDerivedUnit(String, Expression, Vec<Decorator>, Markup, Type),
+    DefineBaseUnit(String, Vec<Decorator>, Type),
+    DefineDerivedUnit(String, Expression, Vec<Decorator>, Type),
     ProcedureCall(crate::ast::ProcedureKind, Vec<Expression>),
     DefineStruct(StructInfo),
 }
@@ -333,7 +319,7 @@ impl Statement {
 impl Expression {
     pub fn get_type(&self) -> Type {
         match self {
-            Expression::Scalar(_, _) => Type::Dimension(DType::unity()),
+            Expression::Scalar(_, _, type_) => type_.clone(),
             Expression::Identifier(_, _, type_) => type_.clone(),
             Expression::UnitIdentifier(_, _, _, _, _type) => _type.clone(),
             Expression::UnaryOperator(_, _, _, type_) => type_.clone(),
@@ -342,13 +328,7 @@ impl Expression {
             Expression::FunctionCall(_, _, _, _, type_) => type_.clone(),
             Expression::CallableCall(_, _, _, type_) => type_.clone(),
             Expression::Boolean(_, _) => Type::Boolean,
-            Expression::Condition(_, _, then_, else_) => {
-                if then_.get_type().is_never() {
-                    else_.get_type()
-                } else {
-                    then_.get_type()
-                }
-            }
+            Expression::Condition(_, _, then_, _) => then_.get_type(),
             Expression::String(_, _) => Type::String,
             Expression::InstantiateStruct(_, _, type_) => Type::Struct(type_.clone()),
             Expression::AccessField(_, _, _, _, _, type_) => type_.clone(),
@@ -424,13 +404,13 @@ fn decorator_markup(decorators: &Vec<Decorator>) -> Markup {
 impl PrettyPrint for Statement {
     fn pretty_print(&self) -> Markup {
         match self {
-            Statement::DefineVariable(identifier, _decs, expr, readable_type, _type) => {
+            Statement::DefineVariable(identifier, _decs, expr, type_) => {
                 m::keyword("let")
                     + m::space()
                     + m::identifier(identifier)
                     + m::operator(":")
                     + m::space()
-                    + readable_type.clone()
+                    + type_.pretty_print()
                     + m::space()
                     + m::operator("=")
                     + m::space()
@@ -442,8 +422,7 @@ impl PrettyPrint for Statement {
                 type_parameters,
                 parameters,
                 body,
-                readable_return_type,
-                _return_type,
+                return_type,
             ) => {
                 let markup_type_parameters = if type_parameters.is_empty() {
                     m::empty()
@@ -460,11 +439,11 @@ impl PrettyPrint for Statement {
                 let markup_parameters = Itertools::intersperse(
                     parameters
                         .iter()
-                        .map(|(_span, name, is_variadic, readable_type, _type)| {
+                        .map(|(_span, name, is_variadic, parameter_type)| {
                             m::identifier(name)
                                 + m::operator(":")
                                 + m::space()
-                                + readable_type.clone()
+                                + parameter_type.pretty_print()
                                 + if *is_variadic {
                                     m::operator("…")
                                 } else {
@@ -476,7 +455,7 @@ impl PrettyPrint for Statement {
                 .sum();
 
                 let markup_return_type =
-                    m::space() + m::operator("->") + m::space() + readable_return_type.clone();
+                    m::space() + m::operator("->") + m::space() + return_type.pretty_print();
 
                 m::keyword("fn")
                     + m::space()
@@ -508,23 +487,23 @@ impl PrettyPrint for Statement {
                     )
                     .sum()
             }
-            Statement::DefineBaseUnit(identifier, decorators, readable_type, _type) => {
+            Statement::DefineBaseUnit(identifier, decorators, type_) => {
                 decorator_markup(decorators)
                     + m::keyword("unit")
                     + m::space()
                     + m::unit(identifier)
                     + m::operator(":")
                     + m::space()
-                    + readable_type.clone()
+                    + type_.pretty_print()
             }
-            Statement::DefineDerivedUnit(identifier, expr, decorators, readable_type, _type) => {
+            Statement::DefineDerivedUnit(identifier, expr, decorators, type_) => {
                 decorator_markup(decorators)
                     + m::keyword("unit")
                     + m::space()
                     + m::unit(identifier)
                     + m::operator(":")
                     + m::space()
-                    + readable_type.clone()
+                    + type_.pretty_print()
                     + m::space()
                     + m::operator("=")
                     + m::space()
@@ -618,7 +597,7 @@ fn pretty_print_binop(op: &BinaryOperator, lhs: &Expression, rhs: &Expression) -
         }
         BinaryOperator::Mul => match (lhs, rhs) {
             (
-                Expression::Scalar(_, s),
+                Expression::Scalar(_, s, _type_scalar),
                 Expression::UnitIdentifier(_, prefix, _name, full_name, _type),
             ) => {
                 // Fuse multiplication of a scalar and a unit to a quantity
@@ -626,7 +605,7 @@ fn pretty_print_binop(op: &BinaryOperator, lhs: &Expression, rhs: &Expression) -
                     + m::space()
                     + m::unit(format!("{}{}", prefix.as_string_long(), full_name))
             }
-            (Expression::Scalar(_, s), Expression::Identifier(_, name, _type)) => {
+            (Expression::Scalar(_, s, _), Expression::Identifier(_, name, _type)) => {
                 // Fuse multiplication of a scalar and identifier
                 pretty_scalar(*s) + m::space() + m::identifier(name)
             }
@@ -702,10 +681,10 @@ fn pretty_print_binop(op: &BinaryOperator, lhs: &Expression, rhs: &Expression) -
 
             add_parens_if_needed(lhs) + op.pretty_print() + add_parens_if_needed(rhs)
         }
-        BinaryOperator::Power if matches!(rhs, Expression::Scalar(_, n) if n.to_f64() == 2.0) => {
+        BinaryOperator::Power if matches!(rhs, Expression::Scalar(_, n, _type) if n.to_f64() == 2.0) => {
             with_parens(lhs) + m::operator("²")
         }
-        BinaryOperator::Power if matches!(rhs, Expression::Scalar(_, n) if n.to_f64() == 3.0) => {
+        BinaryOperator::Power if matches!(rhs, Expression::Scalar(_, n, _type) if n.to_f64() == 3.0) => {
             with_parens(lhs) + m::operator("³")
         }
         _ => with_parens(lhs) + op.pretty_print() + with_parens(rhs),
@@ -717,7 +696,7 @@ impl PrettyPrint for Expression {
         use Expression::*;
 
         match self {
-            Scalar(_, n) => pretty_scalar(*n),
+            Scalar(_, n, _) => pretty_scalar(*n),
             Identifier(_, name, _type) => m::identifier(name),
             UnitIdentifier(_, prefix, _name, full_name, _type) => {
                 m::unit(format!("{}{}", prefix.as_string_long(), full_name))