浏览代码

Make structs nominally typed

Ben Simms 1 年之前
父节点
当前提交
ff205d167f

+ 6 - 3
book/src/example-numbat_syntax.md

@@ -118,10 +118,13 @@ type(2 m/s)                      # Print the type of an expression
 
 # 9. Structs
 
-${foo: 1A, bar: pi}  # Constructing a record
+# defining a struct
+struct Foo { foo: Current, bar: Scalar }
 
-# Records are structurally typed and field order is not fixed
-let r: ${bar: Scalar, foo: Current} = ${foo: 1A, bar: pi}
+# Constructing a struct
+Foo {foo: 1A, bar: pi}
+
+let r: Foo = Foo {foo: 1A, bar: pi}
 
 r.foo  # Field access is performed with `.field` notation
 ```

+ 11 - 8
book/src/structs.md

@@ -1,21 +1,24 @@
 # Struct types
 
-Numbat supports structurally typed records:
+Numbat supports nominally typed records:
 
 ```nbt
-# A value level struct
-${x: Length, y: Length}
+# Define a struct
+struct Vector {
+  x: Length,
+  y: Length,
+}
+
+let v = Vector {x: 6m, y: 8m}
 
 # A function with a struct as a parameter
-fn euclidian_distance(a: ${x: Length, y: Length}, b: ${x: Length, y: Length}) =
+fn euclidian_distance(a: Vector, b: Vector) =
   sqrt(sqr(a.x - b.x) + sqr(a.y - b.y))
   
 assert_eq(
-  euclidian_distance(${x: 0m, y: 0m}, ${x: 6m, y: 8m}),
+  euclidian_distance(Vector {x: 0m, y: 0m}, v),
   10m)
   
-let r = ${foo: 1}
-
 # Struct fields can be accessed using `.field` notation
-let x = r.foo
+let x = v.x
 ```

+ 6 - 3
examples/numbat_syntax.nbt

@@ -113,9 +113,12 @@ type(2 m/s)                      # Print the type of an expression
 
 # 9. Structs
 
-${foo: 1A, bar: pi}  # Constructing a record
+# defining a struct
+struct Foo { foo: Current, bar: Scalar }
 
-# Records are structurally typed and field order is not fixed
-let r: ${bar: Scalar, foo: Current} = ${foo: 1A, bar: pi}
+# Constructing a struct
+Foo {foo: 1A, bar: pi}
+
+let r: Foo = Foo {foo: 1A, bar: pi}
 
 r.foo  # Field access is performed with `.field` notation

+ 77 - 70
numbat/src/ast.rs

@@ -85,7 +85,12 @@ pub enum Expression {
     Boolean(Span, bool),
     String(Span, Vec<StringPart>),
     Condition(Span, Box<Expression>, Box<Expression>, Box<Expression>),
-    MakeStruct(Span, Vec<(Span, String, Expression)>),
+    MakeStruct {
+        full_span: Span,
+        ident_span: Span,
+        name: String,
+        fields: Vec<(Span, String, Expression)>,
+    },
     AccessStruct(Span, Span, Box<Expression>, String),
 }
 
@@ -118,7 +123,7 @@ impl Expression {
                 span_if.extend(&then_expr.full_span())
             }
             Expression::String(span, _) => *span,
-            Expression::MakeStruct(full_span, _) => *full_span,
+            Expression::MakeStruct { full_span, .. } => *full_span,
             Expression::AccessStruct(full_span, _ident_span, _, _) => *full_span,
         }
     }
@@ -204,13 +209,15 @@ macro_rules! conditional {
 
 #[cfg(test)]
 macro_rules! struct_ {
-    ( $( $field:ident : $val:expr ),* ) => {{
-        crate::ast::Expression::MakeStruct(
-            Span::dummy(),
-            vec![
+    ( $name:ident, $( $field:ident : $val:expr ),* ) => {{
+        crate::ast::Expression::MakeStruct {
+            full_span: Span::dummy(),
+            ident_span: Span::dummy(),
+            name: stringify!($name).to_owned(),
+            fields: vec![
                 $((Span::dummy(), stringify!($field).to_owned(), $val)),*
             ]
-        )
+        }
     }};
 }
 
@@ -236,24 +243,22 @@ pub(crate) use struct_;
 #[derive(Debug, Clone, PartialEq)]
 pub enum TypeAnnotation {
     Never(Span),
-    DimensionExpression(DimensionExpression),
+    TypeExpression(TypeExpression),
     Bool(Span),
     String(Span),
     DateTime(Span),
     Fn(Span, Vec<TypeAnnotation>, Box<TypeAnnotation>),
-    Struct(Span, Vec<(Span, String, TypeAnnotation)>),
 }
 
 impl TypeAnnotation {
     pub fn full_span(&self) -> Span {
         match self {
             TypeAnnotation::Never(span) => *span,
-            TypeAnnotation::DimensionExpression(d) => d.full_span(),
+            TypeAnnotation::TypeExpression(d) => d.full_span(),
             TypeAnnotation::Bool(span) => *span,
             TypeAnnotation::String(span) => *span,
             TypeAnnotation::DateTime(span) => *span,
             TypeAnnotation::Fn(span, _, _) => *span,
-            TypeAnnotation::Struct(span, _) => *span,
         }
     }
 }
@@ -262,7 +267,7 @@ impl PrettyPrint for TypeAnnotation {
     fn pretty_print(&self) -> Markup {
         match self {
             TypeAnnotation::Never(_) => m::type_identifier("!"),
-            TypeAnnotation::DimensionExpression(d) => d.pretty_print(),
+            TypeAnnotation::TypeExpression(d) => d.pretty_print(),
             TypeAnnotation::Bool(_) => m::type_identifier("Bool"),
             TypeAnnotation::String(_) => m::type_identifier("String"),
             TypeAnnotation::DateTime(_) => m::type_identifier("DateTime"),
@@ -281,48 +286,37 @@ impl PrettyPrint for TypeAnnotation {
                     + return_type.pretty_print()
                     + m::operator("]")
             }
-            TypeAnnotation::Struct(_, fields) => {
-                m::operator("${")
-                    + Itertools::intersperse(
-                        fields.iter().map(|(_, n, t)| {
-                            m::identifier(n) + m::operator(":") + m::space() + t.pretty_print()
-                        }),
-                        m::operator(",") + m::space(),
-                    )
-                    .sum()
-                    + m::operator("}")
-            }
         }
     }
 }
 
 #[derive(Debug, Clone, PartialEq)]
 
-pub enum DimensionExpression {
+pub enum TypeExpression {
     Unity(Span),
-    Dimension(Span, String),
-    Multiply(Span, Box<DimensionExpression>, Box<DimensionExpression>),
-    Divide(Span, Box<DimensionExpression>, Box<DimensionExpression>),
+    TypeIdentifier(Span, String),
+    Multiply(Span, Box<TypeExpression>, Box<TypeExpression>),
+    Divide(Span, Box<TypeExpression>, Box<TypeExpression>),
     Power(
         Option<Span>, // operator span, not available for unicode exponents
-        Box<DimensionExpression>,
+        Box<TypeExpression>,
         Span, // span for the exponent
         Exponent,
     ),
 }
 
-impl DimensionExpression {
+impl TypeExpression {
     pub fn full_span(&self) -> Span {
         match self {
-            DimensionExpression::Unity(s) => *s,
-            DimensionExpression::Dimension(s, _) => *s,
-            DimensionExpression::Multiply(span_op, lhs, rhs) => {
+            TypeExpression::Unity(s) => *s,
+            TypeExpression::TypeIdentifier(s, _) => *s,
+            TypeExpression::Multiply(span_op, lhs, rhs) => {
                 span_op.extend(&lhs.full_span()).extend(&rhs.full_span())
             }
-            DimensionExpression::Divide(span_op, lhs, rhs) => {
+            TypeExpression::Divide(span_op, lhs, rhs) => {
                 span_op.extend(&lhs.full_span()).extend(&rhs.full_span())
             }
-            DimensionExpression::Power(span_op, lhs, span_exponent, _exp) => match span_op {
+            TypeExpression::Power(span_op, lhs, span_exponent, _exp) => match span_op {
                 Some(span_op) => span_op.extend(&lhs.full_span()).extend(span_exponent),
                 None => lhs.full_span().extend(span_exponent),
             },
@@ -330,29 +324,29 @@ impl DimensionExpression {
     }
 }
 
-fn with_parens(dexpr: &DimensionExpression) -> Markup {
+fn with_parens(dexpr: &TypeExpression) -> Markup {
     match dexpr {
-        expr @ (DimensionExpression::Unity(..)
-        | DimensionExpression::Dimension(..)
-        | DimensionExpression::Power(..)) => expr.pretty_print(),
-        expr @ (DimensionExpression::Multiply(..) | DimensionExpression::Divide(..)) => {
+        expr @ (TypeExpression::Unity(..)
+        | TypeExpression::TypeIdentifier(..)
+        | TypeExpression::Power(..)) => expr.pretty_print(),
+        expr @ (TypeExpression::Multiply(..) | TypeExpression::Divide(..)) => {
             m::operator("(") + expr.pretty_print() + m::operator(")")
         }
     }
 }
 
-impl PrettyPrint for DimensionExpression {
+impl PrettyPrint for TypeExpression {
     fn pretty_print(&self) -> Markup {
         match self {
-            DimensionExpression::Unity(_) => m::type_identifier("1"),
-            DimensionExpression::Dimension(_, ident) => m::type_identifier(ident),
-            DimensionExpression::Multiply(_, lhs, rhs) => {
+            TypeExpression::Unity(_) => m::type_identifier("1"),
+            TypeExpression::TypeIdentifier(_, ident) => m::type_identifier(ident),
+            TypeExpression::Multiply(_, lhs, rhs) => {
                 lhs.pretty_print() + m::space() + m::operator("×") + m::space() + rhs.pretty_print()
             }
-            DimensionExpression::Divide(_, lhs, rhs) => {
+            TypeExpression::Divide(_, lhs, rhs) => {
                 lhs.pretty_print() + m::space() + m::operator("/") + m::space() + with_parens(rhs)
             }
-            DimensionExpression::Power(_, lhs, _, exp) => {
+            TypeExpression::Power(_, lhs, _, exp) => {
                 with_parens(lhs)
                     + m::operator("^")
                     + if exp.is_positive() {
@@ -395,18 +389,23 @@ pub enum Statement {
         /// Optional annotated return type
         return_type_annotation: Option<TypeAnnotation>,
     },
-    DefineDimension(String, Vec<DimensionExpression>),
-    DefineBaseUnit(Span, String, Option<DimensionExpression>, Vec<Decorator>),
+    DefineDimension(Span, String, Vec<TypeExpression>),
+    DefineBaseUnit(Span, String, Option<TypeExpression>, Vec<Decorator>),
     DefineDerivedUnit {
         identifier_span: Span,
         identifier: String,
         expr: Expression,
         type_annotation_span: Option<Span>,
-        type_annotation: Option<DimensionExpression>,
+        type_annotation: Option<TypeExpression>,
         decorators: Vec<Decorator>,
     },
     ProcedureCall(Span, ProcedureKind, Vec<Expression>),
     ModuleImport(Span, ModulePath),
+    DefineStruct {
+        struct_name_span: Span,
+        struct_name: String,
+        fields: Vec<(Span, String, TypeAnnotation)>,
+    },
 }
 
 #[cfg(test)]
@@ -419,45 +418,36 @@ impl ReplaceSpans for TypeAnnotation {
     fn replace_spans(&self) -> Self {
         match self {
             TypeAnnotation::Never(_) => TypeAnnotation::Never(Span::dummy()),
-            TypeAnnotation::DimensionExpression(d) => {
-                TypeAnnotation::DimensionExpression(d.replace_spans())
-            }
+            TypeAnnotation::TypeExpression(d) => TypeAnnotation::TypeExpression(d.replace_spans()),
             TypeAnnotation::Bool(_) => TypeAnnotation::Bool(Span::dummy()),
             TypeAnnotation::String(_) => TypeAnnotation::String(Span::dummy()),
             TypeAnnotation::DateTime(_) => TypeAnnotation::DateTime(Span::dummy()),
             TypeAnnotation::Fn(_, pt, rt) => {
                 TypeAnnotation::Fn(Span::dummy(), pt.clone(), rt.clone())
             }
-            TypeAnnotation::Struct(_, fields) => TypeAnnotation::Struct(
-                Span::dummy(),
-                fields
-                    .iter()
-                    .map(|(_, n, v)| (Span::dummy(), n.clone(), v.replace_spans()))
-                    .collect(),
-            ),
         }
     }
 }
 
 #[cfg(test)]
-impl ReplaceSpans for DimensionExpression {
+impl ReplaceSpans for TypeExpression {
     fn replace_spans(&self) -> Self {
         match self {
-            DimensionExpression::Unity(_) => DimensionExpression::Unity(Span::dummy()),
-            DimensionExpression::Dimension(_, d) => {
-                DimensionExpression::Dimension(Span::dummy(), d.clone())
+            TypeExpression::Unity(_) => TypeExpression::Unity(Span::dummy()),
+            TypeExpression::TypeIdentifier(_, d) => {
+                TypeExpression::TypeIdentifier(Span::dummy(), d.clone())
             }
-            DimensionExpression::Multiply(_, lhs, rhs) => DimensionExpression::Multiply(
+            TypeExpression::Multiply(_, lhs, rhs) => TypeExpression::Multiply(
                 Span::dummy(),
                 Box::new(lhs.replace_spans()),
                 Box::new(rhs.replace_spans()),
             ),
-            DimensionExpression::Divide(_, lhs, rhs) => DimensionExpression::Divide(
+            TypeExpression::Divide(_, lhs, rhs) => TypeExpression::Divide(
                 Span::dummy(),
                 Box::new(lhs.replace_spans()),
                 Box::new(rhs.replace_spans()),
             ),
-            DimensionExpression::Power(span_op, lhs, _, exp) => DimensionExpression::Power(
+            TypeExpression::Power(span_op, lhs, _, exp) => TypeExpression::Power(
                 span_op.map(|_| Span::dummy()),
                 Box::new(lhs.replace_spans()),
                 Span::dummy(),
@@ -531,13 +521,15 @@ impl ReplaceSpans for Expression {
                 Span::dummy(),
                 parts.iter().map(|p| p.replace_spans()).collect(),
             ),
-            Expression::MakeStruct(_, fields) => Expression::MakeStruct(
-                Span::dummy(),
-                fields
+            Expression::MakeStruct { name, fields, .. } => Expression::MakeStruct {
+                full_span: Span::dummy(),
+                ident_span: Span::dummy(),
+                name: name.clone(),
+                fields: fields
                     .iter()
                     .map(|(_, n, v)| (Span::dummy(), n.clone(), v.replace_spans()))
                     .collect(),
-            ),
+            },
             Expression::AccessStruct(_, _, expr, attr) => Expression::AccessStruct(
                 Span::dummy(),
                 Span::dummy(),
@@ -596,7 +588,8 @@ impl ReplaceSpans for Statement {
                 return_type_annotation_span: return_type_span.map(|_| Span::dummy()),
                 return_type_annotation: return_type_annotation.as_ref().map(|t| t.replace_spans()),
             },
-            Statement::DefineDimension(name, dexprs) => Statement::DefineDimension(
+            Statement::DefineDimension(_, name, dexprs) => Statement::DefineDimension(
+                Span::dummy(),
                 name.clone(),
                 dexprs.iter().map(|t| t.replace_spans()).collect(),
             ),
@@ -629,6 +622,20 @@ impl ReplaceSpans for Statement {
             Statement::ModuleImport(_, module_path) => {
                 Statement::ModuleImport(Span::dummy(), module_path.clone())
             }
+            Statement::DefineStruct {
+                struct_name,
+                fields,
+                ..
+            } => Statement::DefineStruct {
+                struct_name_span: Span::dummy(),
+                struct_name: struct_name.clone(),
+                fields: fields
+                    .into_iter()
+                    .map(|(_span, name, type_)| {
+                        (Span::dummy(), name.clone(), type_.replace_spans())
+                    })
+                    .collect(),
+            },
         }
     }
 }

+ 14 - 23
numbat/src/bytecode_interpreter.rs

@@ -1,6 +1,5 @@
 use std::collections::HashMap;
 
-use indexmap::IndexMap;
 use itertools::Itertools;
 
 use crate::ast::ProcedureKind;
@@ -168,39 +167,28 @@ impl BytecodeInterpreter {
                     self.vm.add_op2(Op::Call, idx, args.len() as u16); // TODO: check overflow
                 }
             }
-            Expression::MakeStruct(_span, exprs, type_) => {
-                // here we discard any type information about the struct, so we must have a consistent ordering of fields
-                //
-                // to do that, sort the fields lexographically
+            Expression::MakeStruct(_span, exprs, struct_info) => {
+                // structs must be consistently ordered in the VM, so we reorder
+                // the field values so that they are evaluated in the order the
+                // struct fields are defined.
+
                 let sorted_exprs = exprs
                     .iter()
-                    .sorted_by_key(|(n, _)| n)
-                    .cloned()
-                    .collect::<IndexMap<_, _>>();
+                    .sorted_by_key(|(n, _)| struct_info.fields.get_index_of(n).unwrap());
 
-                for (_, expr) in sorted_exprs.iter().rev() {
+                for (_, expr) in sorted_exprs.rev() {
                     self.compile_expression_with_simplify(expr)?;
                 }
 
-                let field_meta = type_
-                    .iter()
-                    .map(|(n, _)| (n.clone(), sorted_exprs.get_index_of(n).unwrap()))
-                    .collect();
-
-                let field_meta_idx = self.vm.add_struct_fields(field_meta);
+                let struct_info_idx = self.vm.get_structinfo_idx(&struct_info.name).unwrap() as u16;
 
                 self.vm
-                    .add_op2(Op::BuildStruct, field_meta_idx, exprs.len() as u16);
+                    .add_op2(Op::BuildStruct, struct_info_idx, exprs.len() as u16);
             }
-            Expression::AccessStruct(_span, _full_span, expr, attr, struct_type, _result_type) => {
+            Expression::AccessStruct(_span, _full_span, expr, attr, struct_info, _result_type) => {
                 self.compile_expression_with_simplify(expr)?;
 
-                let (idx, _) = struct_type
-                    .iter()
-                    .map(|(n, _)| n)
-                    .sorted()
-                    .find_position(|n| *n == attr)
-                    .unwrap();
+                let idx = struct_info.fields.get_index_of(attr).unwrap();
 
                 self.vm.add_op1(Op::DestructureStruct, idx as u16);
             }
@@ -478,6 +466,9 @@ impl BytecodeInterpreter {
                 self.vm
                     .add_op2(Op::FFICallProcedure, idx, args.len() as u16); // TODO: check overflow
             }
+            Statement::DefineStruct(struct_info) => {
+                self.vm.add_struct_info(struct_info);
+            }
         }
 
         Ok(())

+ 67 - 18
numbat/src/diagnostic.rs

@@ -46,6 +46,7 @@ impl ErrorDiagnostic for NameResolutionError {
         match self {
             NameResolutionError::IdentifierClash {
                 conflicting_identifier: _,
+                original_item_type,
                 conflict_span,
                 original_span,
             } => vec![Diagnostic::error()
@@ -53,7 +54,11 @@ impl ErrorDiagnostic for NameResolutionError {
                 .with_labels(vec![
                     original_span
                         .diagnostic_label(LabelStyle::Secondary)
-                        .with_message("Previously defined here"),
+                        .with_message(if let Some(t) = original_item_type.as_ref() {
+                            format!("Previously defined {t} here")
+                        } else {
+                            "Previously defined here".to_owned()
+                        }),
                     conflict_span
                         .diagnostic_label(LabelStyle::Primary)
                         .with_message("identifier is already in use"),
@@ -323,23 +328,6 @@ impl ErrorDiagnostic for TypeCheckError {
                         .with_notes(vec![inner_error])
                 }
             }
-            TypeCheckError::NameAlreadyUsedBy(_, definition_span, previous_definition_span) => {
-                let mut labels = vec![];
-
-                if let Some(span) = previous_definition_span {
-                    labels.push(
-                        span.diagnostic_label(LabelStyle::Secondary)
-                            .with_message("Previously defined here"),
-                    );
-                }
-
-                labels.push(
-                    definition_span
-                        .diagnostic_label(LabelStyle::Primary)
-                        .with_message(inner_error),
-                );
-                d.with_labels(labels)
-            }
             TypeCheckError::NoDimensionlessBaseUnit(span, unit_name) => d
                 .with_labels(vec![span
                     .diagnostic_label(LabelStyle::Primary)
@@ -416,6 +404,67 @@ impl ErrorDiagnostic for TypeCheckError {
                         .diagnostic_label(LabelStyle::Secondary)
                         .with_message(type_.to_string()),
                 ]),
+            TypeCheckError::IncompatibleTypesForStructField(
+                expected_field_span,
+                _expected_type,
+                expr_span,
+                _found_type,
+            ) => d.with_labels(vec![
+                expr_span
+                    .diagnostic_label(LabelStyle::Primary)
+                    .with_message(inner_error),
+                expected_field_span
+                    .diagnostic_label(LabelStyle::Secondary)
+                    .with_message("Defined here"),
+            ]),
+            TypeCheckError::UnknownStruct(span, _name) => d.with_labels(vec![span
+                .diagnostic_label(LabelStyle::Primary)
+                .with_message(inner_error)]),
+            TypeCheckError::UnknownFieldOfStruct(field_span, defn_span, _, _) => {
+                d.with_labels(vec![
+                    field_span
+                        .diagnostic_label(LabelStyle::Primary)
+                        .with_message(inner_error),
+                    defn_span
+                        .diagnostic_label(LabelStyle::Secondary)
+                        .with_message("Struct defined here"),
+                ])
+            }
+            TypeCheckError::DuplicateFieldInStructDefinition(
+                this_field_span,
+                that_field_span,
+                _attr_name,
+            ) => d.with_labels(vec![
+                this_field_span
+                    .diagnostic_label(LabelStyle::Primary)
+                    .with_message(inner_error),
+                that_field_span
+                    .diagnostic_label(LabelStyle::Secondary)
+                    .with_message("Already defined here"),
+            ]),
+            TypeCheckError::MissingFieldsFromStructConstruction(
+                construction_span,
+                defn_span,
+                missing,
+            ) => d
+                .with_labels(vec![
+                    construction_span
+                        .diagnostic_label(LabelStyle::Primary)
+                        .with_message(inner_error),
+                    defn_span
+                        .diagnostic_label(LabelStyle::Secondary)
+                        .with_message("Struct defined here"),
+                ])
+                .with_notes(vec!["Missing fields: ".to_owned()])
+                .with_notes(
+                    missing
+                        .iter()
+                        .map(|(n, t)| n.to_owned() + ": " + &t.to_string())
+                        .collect(),
+                ),
+            TypeCheckError::NameResolutionError(inner) => {
+                return inner.diagnostics();
+            }
         };
         vec![d]
     }

+ 8 - 8
numbat/src/dimension.rs

@@ -1,5 +1,5 @@
 use crate::arithmetic::Power;
-use crate::ast::DimensionExpression;
+use crate::ast::TypeExpression;
 use crate::registry::{BaseRepresentation, Registry, Result};
 
 #[derive(Default, Clone)]
@@ -10,27 +10,27 @@ pub struct DimensionRegistry {
 impl DimensionRegistry {
     pub fn get_base_representation(
         &self,
-        expression: &DimensionExpression,
+        expression: &TypeExpression,
     ) -> Result<BaseRepresentation> {
         match expression {
-            DimensionExpression::Unity(_) => Ok(BaseRepresentation::unity()),
-            DimensionExpression::Dimension(_, name) => self
+            TypeExpression::Unity(_) => Ok(BaseRepresentation::unity()),
+            TypeExpression::TypeIdentifier(_, name) => self
                 .registry
                 .get_base_representation_for_name(name)
                 .map(|r| r.0),
-            DimensionExpression::Multiply(_, lhs, rhs) => {
+            TypeExpression::Multiply(_, lhs, rhs) => {
                 let lhs = self.get_base_representation(lhs)?;
                 let rhs = self.get_base_representation(rhs)?;
 
                 Ok(lhs * rhs)
             }
-            DimensionExpression::Divide(_, lhs, rhs) => {
+            TypeExpression::Divide(_, lhs, rhs) => {
                 let lhs = self.get_base_representation(lhs)?;
                 let rhs = self.get_base_representation(rhs)?;
 
                 Ok(lhs / rhs)
             }
-            DimensionExpression::Power(_, expr, _, outer_exponent) => {
+            TypeExpression::Power(_, expr, _, outer_exponent) => {
                 Ok(self.get_base_representation(expr)?.power(*outer_exponent))
             }
         }
@@ -62,7 +62,7 @@ impl DimensionRegistry {
     pub fn add_derived_dimension(
         &mut self,
         name: &str,
-        expression: &DimensionExpression,
+        expression: &TypeExpression,
     ) -> Result<BaseRepresentation> {
         let base_representation = self.get_base_representation(expression)?;
         self.registry

+ 60 - 2
numbat/src/name_resolution.rs

@@ -1,18 +1,76 @@
+use std::collections::HashMap;
+
 use thiserror::Error;
 
 use crate::span::Span;
 
 pub const LAST_RESULT_IDENTIFIERS: &[&str] = &["ans", "_"];
 
-#[derive(Debug, Clone, Error)]
+#[derive(Debug, Clone, Error, PartialEq, Eq)]
 pub enum NameResolutionError {
-    #[error("Identifier is already in use: '{conflicting_identifier}'.")]
+    #[error("Identifier is already in use{}: '{conflicting_identifier}'.",
+            if let Some(t) = .original_item_type { format!(" by the {t}") } else { "".to_owned() })]
     IdentifierClash {
         conflicting_identifier: String,
         conflict_span: Span,
         original_span: Span,
+        original_item_type: Option<String>,
     },
 
     #[error("Reserved identifier")]
     ReservedIdentifier(Span),
 }
+
+#[derive(Debug, Clone, Default)]
+pub struct Namespace {
+    seen: HashMap<String, (String, Span)>,
+}
+
+impl Namespace {
+    pub fn add_allow_override(
+        &mut self,
+        name: String,
+        span: Span,
+        item_type: String,
+    ) -> Result<(), NameResolutionError> {
+        self.add_inner(name, span, item_type, true)
+    }
+
+    pub fn add(
+        &mut self,
+        name: String,
+        span: Span,
+        item_type: String,
+    ) -> Result<(), NameResolutionError> {
+        self.add_inner(name, span, item_type, false)
+    }
+
+    fn add_inner(
+        &mut self,
+        name: String,
+        span: Span,
+        item_type: String,
+        allow_override: bool,
+    ) -> Result<(), NameResolutionError> {
+        if let Some((original_item_type, original_span)) = self.seen.get(&name) {
+            if original_span == &span {
+                return Ok(());
+            }
+
+            if allow_override && original_item_type == &item_type {
+                return Ok(());
+            }
+
+            return Err(NameResolutionError::IdentifierClash {
+                conflicting_identifier: name,
+                conflict_span: span,
+                original_span: *original_span,
+                original_item_type: Some(original_item_type.clone()),
+            });
+        }
+
+        self.seen.insert(name, (item_type, span));
+
+        Ok(())
+    }
+}

+ 168 - 146
numbat/src/parser.rs

@@ -2,9 +2,10 @@
 //!
 //! Grammar:
 //! ```txt
-//! statement       ::=   variable_decl | function_decl | dimension_decl | unit_decl | module_import | procedure_call | expression
+//! statement       ::=   variable_decl | struct_decl | function_decl | dimension_decl | unit_decl | module_import | procedure_call | expression
 //!
 //! variable_decl   ::=   "let" identifier ( ":" type_annotation ) ? "=" expression
+//! struct_decl     ::=   "struct" identifier "{" ( identifier ":" type_annotation "," )* ( identifier ":" type_annotation "," ? ) ? "}"
 //! function_decl   ::=   "fn" identifier ( fn_decl_generic ) ? fn_decl_param ( "->" type_annotation ) ? ( "=" expression ) ?
 //! fn_decl_generic ::=   "<" ( identifier "," ) * identifier ">"
 //! fn_decl_param   ::=   "(" ( identifier ( ":" type_annotation ) ? "," )* ( identifier ( ":" type_annotation ) ) ? ")"
@@ -15,8 +16,7 @@
 //!
 //! decorator       ::=   "@" ( "metric_prefixes" | "binary_prefixes" | ( "aliases(" list_of_aliases ")" ) )
 //!
-//! type_annotation ::=   "Bool" | "String" | dimension_expr | struct_type
-//! struct_type     ::=   "${" ( identifier ":" type_annotation "," )* ( identifier ":" type_annotation "," ? ) ? "}"
+//! type_annotation ::=   "Bool" | "String" | dimension_expr
 //! dimension_expr  ::=   dim_factor
 //! dim_factor      ::=   dim_power ( (multiply | divide) dim_power ) *
 //! dim_power       ::=   dim_primary ( power dim_exponent | unicode_exponent ) ?
@@ -41,8 +41,8 @@
 //! unicode_power   ::=   call ( "⁻" ? ( "¹" | "²" | "³" | "⁴" | "⁵" | "⁶" | "⁷" | "⁸" | "⁹" ) ) ?
 //! call            ::=   primary ( ( "(" arguments? ")" ) | "." identifier ) *
 //! arguments       ::=   expression ( "," expression ) *
-//! primary         ::=   boolean | string | hex_number | oct_number | bin_number | number | identifier | "(" expression ")" | struct_expr
-//! struct_expr     ::=   "${" ( identifier ":" type_annotation "," )* ( identifier ":" expression "," ? ) ? "}"
+//! primary         ::=   boolean | string | hex_number | oct_number | bin_number | number | identifier ( struct_expr ? ) | "(" expression ")" | struct_expr
+//! struct_expr     ::=   "{" ( identifier ":" type_annotation "," )* ( identifier ":" expression "," ? ) ? "}"
 //!
 //! number          ::=   [0-9][0-9_]*("." ([0-9][0-9_]*)?)?([eE][+-]?[0-9][0-9_]*)?
 //! hex_number      ::=   "0x" [0-9a-fA-F]*
@@ -62,8 +62,8 @@
 
 use crate::arithmetic::{Exponent, Rational};
 use crate::ast::{
-    BinaryOperator, DimensionExpression, Expression, ProcedureKind, Statement, StringPart,
-    TypeAnnotation, UnaryOperator,
+    BinaryOperator, Expression, ProcedureKind, Statement, StringPart, TypeAnnotation,
+    TypeExpression, UnaryOperator,
 };
 use crate::decorator::{self, Decorator};
 use crate::number::Number;
@@ -211,6 +211,9 @@ pub enum ParseErrorKind {
 
     #[error("Expected {0} in function type")]
     ExpectedTokenInFunctionType(&'static str),
+
+    #[error("Expected '{{' after struct name")]
+    ExpectedLeftCurlyAfterStructName,
 }
 
 #[derive(Debug, Clone, Error)]
@@ -556,11 +559,13 @@ impl<'a> Parser<'a> {
                     }
 
                     Ok(Statement::DefineDimension(
+                        identifier.span,
                         identifier.lexeme.clone(),
                         dexprs,
                     ))
                 } else {
                     Ok(Statement::DefineDimension(
+                        identifier.span,
                         identifier.lexeme.clone(),
                         vec![],
                     ))
@@ -714,6 +719,52 @@ impl<'a> Parser<'a> {
                     span: self.peek().span,
                 })
             }
+        } else if self.match_exact(TokenKind::Struct).is_some() {
+            let name = self.identifier()?;
+            let name_span = self.last().unwrap().span;
+
+            if self.match_exact(TokenKind::LeftCurly).is_none() {
+                return Err(ParseError {
+                    kind: ParseErrorKind::ExpectedLeftCurlyAfterStructName,
+                    span: self.peek().span,
+                });
+            }
+
+            let mut fields = vec![];
+            while self.match_exact(TokenKind::RightCurly).is_none() {
+                let Some(field_name) = self.match_exact(TokenKind::Identifier) else {
+                    return Err(ParseError {
+                        kind: ParseErrorKind::ExpectedFieldNameInStruct,
+                        span: self.peek().span,
+                    });
+                };
+
+                if self.match_exact(TokenKind::Colon).is_none() {
+                    return Err(ParseError {
+                        kind: ParseErrorKind::ExpectedColonAfterFieldName,
+                        span: self.peek().span,
+                    });
+                }
+
+                let attr_type = self.type_annotation()?;
+
+                let has_comma = self.match_exact(TokenKind::Comma).is_some();
+                self.match_exact(TokenKind::Newline);
+                if !has_comma && self.peek().kind != TokenKind::RightCurly {
+                    return Err(ParseError {
+                        kind: ParseErrorKind::ExpectedCommaOrRightCurlyInStructFieldList,
+                        span: self.peek().span,
+                    });
+                }
+
+                fields.push((field_name.span, field_name.lexeme.to_owned(), attr_type));
+            }
+
+            Ok(Statement::DefineStruct {
+                struct_name_span: name_span,
+                struct_name: name,
+                fields,
+            })
         } else if self.match_any(PROCEDURES).is_some() {
             let span = self.last().unwrap().span;
             let procedure_kind = match self.last().unwrap().kind {
@@ -1154,45 +1205,51 @@ impl<'a> Parser<'a> {
         } else if let Some(_) = self.match_exact(TokenKind::Inf) {
             let span = self.last().unwrap().span;
             Ok(Expression::Scalar(span, Number::from_f64(f64::INFINITY)))
-        } else if self.match_exact(TokenKind::DollarLeftCurly).is_some() {
-            self.match_exact(TokenKind::Newline);
+        } else if let Some(identifier) = self.match_exact(TokenKind::Identifier) {
             let span = self.last().unwrap().span;
 
-            let mut fields = vec![];
-            while self.match_exact(TokenKind::RightCurly).is_none() {
-                let Some(field_name) = self.match_exact(TokenKind::Identifier) else {
-                    return Err(ParseError {
-                        kind: ParseErrorKind::ExpectedFieldNameInStruct,
-                        span: self.peek().span,
-                    });
-                };
+            if self.match_exact(TokenKind::LeftCurly).is_some() {
+                self.match_exact(TokenKind::Newline);
+                let mut fields = vec![];
+                while self.match_exact(TokenKind::RightCurly).is_none() {
+                    let Some(field_name) = self.match_exact(TokenKind::Identifier) else {
+                        return Err(ParseError {
+                            kind: ParseErrorKind::ExpectedFieldNameInStruct,
+                            span: self.peek().span,
+                        });
+                    };
 
-                if self.match_exact(TokenKind::Colon).is_none() {
-                    return Err(ParseError {
-                        kind: ParseErrorKind::ExpectedColonAfterFieldName,
-                        span: self.peek().span,
-                    });
-                }
+                    if self.match_exact(TokenKind::Colon).is_none() {
+                        return Err(ParseError {
+                            kind: ParseErrorKind::ExpectedColonAfterFieldName,
+                            span: self.peek().span,
+                        });
+                    }
 
-                let expr = self.expression()?;
+                    let expr = self.expression()?;
 
-                let has_comma = self.match_exact(TokenKind::Comma).is_some();
-                self.match_exact(TokenKind::Newline);
-                if !has_comma && self.peek().kind != TokenKind::RightCurly {
-                    return Err(ParseError {
-                        kind: ParseErrorKind::ExpectedCommaOrRightCurlyInStructFieldList,
-                        span: self.peek().span,
-                    });
+                    let has_comma = self.match_exact(TokenKind::Comma).is_some();
+                    self.match_exact(TokenKind::Newline);
+                    if !has_comma && self.peek().kind != TokenKind::RightCurly {
+                        return Err(ParseError {
+                            kind: ParseErrorKind::ExpectedCommaOrRightCurlyInStructFieldList,
+                            span: self.peek().span,
+                        });
+                    }
+
+                    fields.push((field_name.span, field_name.lexeme.to_owned(), expr));
                 }
 
-                fields.push((field_name.span, field_name.lexeme.to_owned(), expr));
-            }
+                let full_span = span.extend(&self.last().unwrap().span);
 
-            let span = span.extend(&self.last().unwrap().span);
+                return Ok(Expression::MakeStruct {
+                    full_span,
+                    ident_span: span,
+                    name: identifier.lexeme.clone(),
+                    fields,
+                });
+            }
 
-            Ok(Expression::MakeStruct(span, fields))
-        } else if let Some(identifier) = self.match_exact(TokenKind::Identifier) {
-            let span = self.last().unwrap().span;
             Ok(Expression::Identifier(span, identifier.lexeme.clone()))
         } else if let Some(inner) = self.match_any(&[TokenKind::True, TokenKind::False]) {
             Ok(Expression::Boolean(
@@ -1305,44 +1362,6 @@ impl<'a> Parser<'a> {
             Ok(TypeAnnotation::String(token.span))
         } else if let Some(token) = self.match_exact(TokenKind::DateTime) {
             Ok(TypeAnnotation::DateTime(token.span))
-        } else if self.match_exact(TokenKind::DollarLeftCurly).is_some() {
-            let span = self.last().unwrap().span;
-
-            self.match_exact(TokenKind::Newline);
-
-            let mut fields = vec![];
-            while self.match_exact(TokenKind::RightCurly).is_none() {
-                let Some(field_name) = self.match_exact(TokenKind::Identifier) else {
-                    return Err(ParseError {
-                        kind: ParseErrorKind::ExpectedFieldNameInStruct,
-                        span: self.peek().span,
-                    });
-                };
-
-                if self.match_exact(TokenKind::Colon).is_none() {
-                    return Err(ParseError {
-                        kind: ParseErrorKind::ExpectedColonAfterFieldName,
-                        span: self.peek().span,
-                    });
-                }
-
-                let attr_type = self.type_annotation()?;
-
-                let has_comma = self.match_exact(TokenKind::Comma).is_some();
-                self.match_exact(TokenKind::Newline);
-                if !has_comma && self.peek().kind != TokenKind::RightCurly {
-                    return Err(ParseError {
-                        kind: ParseErrorKind::ExpectedCommaOrRightCurlyInStructFieldList,
-                        span: self.peek().span,
-                    });
-                }
-
-                fields.push((field_name.span, field_name.lexeme.to_owned(), attr_type));
-            }
-
-            let span = span.extend(&self.last().unwrap().span);
-
-            Ok(TypeAnnotation::Struct(span, fields))
         } else if self.match_exact(TokenKind::CapitalFn).is_some() {
             let span = self.last().unwrap().span;
             if self.match_exact(TokenKind::LeftBracket).is_none() {
@@ -1393,39 +1412,37 @@ impl<'a> Parser<'a> {
 
             Ok(TypeAnnotation::Fn(span, params, Box::new(return_type)))
         } else {
-            Ok(TypeAnnotation::DimensionExpression(
-                self.dimension_expression()?,
-            ))
+            Ok(TypeAnnotation::TypeExpression(self.dimension_expression()?))
         }
     }
 
-    fn dimension_expression(&mut self) -> Result<DimensionExpression> {
+    fn dimension_expression(&mut self) -> Result<TypeExpression> {
         self.dimension_factor()
     }
 
-    fn dimension_factor(&mut self) -> Result<DimensionExpression> {
+    fn dimension_factor(&mut self) -> Result<TypeExpression> {
         let mut expr = self.dimension_power()?;
         while let Some(operator_token) = self.match_any(&[TokenKind::Multiply, TokenKind::Divide]) {
             let span = self.last().unwrap().span;
             let rhs = self.dimension_power()?;
 
             expr = if operator_token.kind == TokenKind::Multiply {
-                DimensionExpression::Multiply(span, Box::new(expr), Box::new(rhs))
+                TypeExpression::Multiply(span, Box::new(expr), Box::new(rhs))
             } else {
-                DimensionExpression::Divide(span, Box::new(expr), Box::new(rhs))
+                TypeExpression::Divide(span, Box::new(expr), Box::new(rhs))
             };
         }
         Ok(expr)
     }
 
-    fn dimension_power(&mut self) -> Result<DimensionExpression> {
+    fn dimension_power(&mut self) -> Result<TypeExpression> {
         let expr = self.dimension_primary()?;
 
         if self.match_exact(TokenKind::Power).is_some() {
             let span = self.last().unwrap().span;
             let (span_exponent, exponent) = self.dimension_exponent()?;
 
-            Ok(DimensionExpression::Power(
+            Ok(TypeExpression::Power(
                 Some(span),
                 Box::new(expr),
                 span_exponent,
@@ -1435,7 +1452,7 @@ impl<'a> Parser<'a> {
             let span_exponent = self.last().unwrap().span;
             let exp = Self::unicode_exponent_to_int(exponent.lexeme.as_str());
 
-            Ok(DimensionExpression::Power(
+            Ok(TypeExpression::Power(
                 None,
                 Box::new(expr),
                 span_exponent,
@@ -1508,7 +1525,7 @@ impl<'a> Parser<'a> {
         }
     }
 
-    fn dimension_primary(&mut self) -> Result<DimensionExpression> {
+    fn dimension_primary(&mut self) -> Result<TypeExpression> {
         let e = Err(ParseError::new(
             ParseErrorKind::ExpectedDimensionPrimary,
             self.peek().span,
@@ -1521,13 +1538,13 @@ impl<'a> Parser<'a> {
                 ));
             }
             let span = self.last().unwrap().span;
-            Ok(DimensionExpression::Dimension(span, token.lexeme.clone()))
+            Ok(TypeExpression::TypeIdentifier(span, token.lexeme.clone()))
         } else if let Some(number) = self.match_exact(TokenKind::Number) {
             let span = self.last().unwrap().span;
             if number.lexeme != "1" {
                 e
             } else {
-                Ok(DimensionExpression::Unity(span))
+                Ok(TypeExpression::Unity(span))
             }
         } else if self.match_exact(TokenKind::LeftParen).is_some() {
             let dexpr = self.dimension_expression()?;
@@ -1607,7 +1624,7 @@ pub fn parse(input: &str, code_source_id: usize) -> ParseResult {
 }
 
 #[cfg(test)]
-pub fn parse_dexpr(input: &str) -> DimensionExpression {
+pub fn parse_dexpr(input: &str) -> TypeExpression {
     let tokens = crate::tokenizer::tokenize(input, 0).expect("No tokenizer errors in tests");
     let mut parser = crate::parser::Parser::new(&tokens);
     let expr = parser
@@ -2053,8 +2070,8 @@ mod tests {
                 identifier_span: Span::dummy(),
                 identifier: "x".into(),
                 expr: binop!(scalar!(1.0), Mul, identifier!("meter")),
-                type_annotation: Some(TypeAnnotation::DimensionExpression(
-                    DimensionExpression::Dimension(Span::dummy(), "Length".into()),
+                type_annotation: Some(TypeAnnotation::TypeExpression(
+                    TypeExpression::TypeIdentifier(Span::dummy(), "Length".into()),
                 )),
                 decorators: Vec::new(),
             },
@@ -2067,8 +2084,8 @@ mod tests {
                 identifier_span: Span::dummy(),
                 identifier: "x".into(),
                 expr: binop!(scalar!(1.0), Mul, identifier!("meter")),
-                type_annotation: Some(TypeAnnotation::DimensionExpression(
-                    DimensionExpression::Dimension(Span::dummy(), "Length".into()),
+                type_annotation: Some(TypeAnnotation::TypeExpression(
+                    TypeExpression::TypeIdentifier(Span::dummy(), "Length".into()),
                 )),
                 decorators: vec![
                     decorator::Decorator::Name("myvar".into()),
@@ -2104,7 +2121,7 @@ mod tests {
     fn dimension_definition() {
         parse_as(
             &["dimension px"],
-            Statement::DefineDimension("px".into(), vec![]),
+            Statement::DefineDimension(Span::dummy(), "px".into(), vec![]),
         );
 
         parse_as(
@@ -2114,14 +2131,15 @@ mod tests {
                 "dimension Area =\n  Length × Length",
             ],
             Statement::DefineDimension(
+                Span::dummy(),
                 "Area".into(),
-                vec![DimensionExpression::Multiply(
+                vec![TypeExpression::Multiply(
                     Span::dummy(),
-                    Box::new(DimensionExpression::Dimension(
+                    Box::new(TypeExpression::TypeIdentifier(
                         Span::dummy(),
                         "Length".into(),
                     )),
-                    Box::new(DimensionExpression::Dimension(
+                    Box::new(TypeExpression::TypeIdentifier(
                         Span::dummy(),
                         "Length".into(),
                     )),
@@ -2132,14 +2150,15 @@ mod tests {
         parse_as(
             &["dimension Velocity = Length / Time"],
             Statement::DefineDimension(
+                Span::dummy(),
                 "Velocity".into(),
-                vec![DimensionExpression::Divide(
+                vec![TypeExpression::Divide(
                     Span::dummy(),
-                    Box::new(DimensionExpression::Dimension(
+                    Box::new(TypeExpression::TypeIdentifier(
                         Span::dummy(),
                         "Length".into(),
                     )),
-                    Box::new(DimensionExpression::Dimension(Span::dummy(), "Time".into())),
+                    Box::new(TypeExpression::TypeIdentifier(Span::dummy(), "Time".into())),
                 )],
             ),
         );
@@ -2147,10 +2166,11 @@ mod tests {
         parse_as(
             &["dimension Area = Length^2"],
             Statement::DefineDimension(
+                Span::dummy(),
                 "Area".into(),
-                vec![DimensionExpression::Power(
+                vec![TypeExpression::Power(
                     Some(Span::dummy()),
-                    Box::new(DimensionExpression::Dimension(
+                    Box::new(TypeExpression::TypeIdentifier(
                         Span::dummy(),
                         "Length".into(),
                     )),
@@ -2163,15 +2183,16 @@ mod tests {
         parse_as(
             &["dimension Energy = Mass * Length^2 / Time^2"],
             Statement::DefineDimension(
+                Span::dummy(),
                 "Energy".into(),
-                vec![DimensionExpression::Divide(
+                vec![TypeExpression::Divide(
                     Span::dummy(),
-                    Box::new(DimensionExpression::Multiply(
+                    Box::new(TypeExpression::Multiply(
                         Span::dummy(),
-                        Box::new(DimensionExpression::Dimension(Span::dummy(), "Mass".into())),
-                        Box::new(DimensionExpression::Power(
+                        Box::new(TypeExpression::TypeIdentifier(Span::dummy(), "Mass".into())),
+                        Box::new(TypeExpression::Power(
                             Some(Span::dummy()),
-                            Box::new(DimensionExpression::Dimension(
+                            Box::new(TypeExpression::TypeIdentifier(
                                 Span::dummy(),
                                 "Length".into(),
                             )),
@@ -2179,9 +2200,9 @@ mod tests {
                             Rational::from_integer(2),
                         )),
                     )),
-                    Box::new(DimensionExpression::Power(
+                    Box::new(TypeExpression::Power(
                         Some(Span::dummy()),
-                        Box::new(DimensionExpression::Dimension(Span::dummy(), "Time".into())),
+                        Box::new(TypeExpression::TypeIdentifier(Span::dummy(), "Time".into())),
                         Span::dummy(),
                         Rational::from_integer(2),
                     )),
@@ -2192,10 +2213,11 @@ mod tests {
         parse_as(
             &["dimension X = Length^(12345/67890)"],
             Statement::DefineDimension(
+                Span::dummy(),
                 "X".into(),
-                vec![DimensionExpression::Power(
+                vec![TypeExpression::Power(
                     Some(Span::dummy()),
-                    Box::new(DimensionExpression::Dimension(
+                    Box::new(TypeExpression::TypeIdentifier(
                         Span::dummy(),
                         "Length".into(),
                     )),
@@ -2236,8 +2258,8 @@ mod tests {
                 parameters: vec![],
                 body: Some(scalar!(1.0)),
                 return_type_annotation_span: Some(Span::dummy()),
-                return_type_annotation: Some(TypeAnnotation::DimensionExpression(
-                    DimensionExpression::Dimension(Span::dummy(), "Scalar".into()),
+                return_type_annotation: Some(TypeAnnotation::TypeExpression(
+                    TypeExpression::TypeIdentifier(Span::dummy(), "Scalar".into()),
                 )),
             },
         );
@@ -2282,52 +2304,50 @@ mod tests {
                     (
                         Span::dummy(),
                         "x".into(),
-                        Some(TypeAnnotation::DimensionExpression(
-                            DimensionExpression::Dimension(Span::dummy(), "Length".into()),
+                        Some(TypeAnnotation::TypeExpression(
+                            TypeExpression::TypeIdentifier(Span::dummy(), "Length".into()),
                         )),
                         false,
                     ),
                     (
                         Span::dummy(),
                         "y".into(),
-                        Some(TypeAnnotation::DimensionExpression(
-                            DimensionExpression::Dimension(Span::dummy(), "Time".into()),
+                        Some(TypeAnnotation::TypeExpression(
+                            TypeExpression::TypeIdentifier(Span::dummy(), "Time".into()),
                         )),
                         false,
                     ),
                     (
                         Span::dummy(),
                         "z".into(),
-                        Some(TypeAnnotation::DimensionExpression(
-                            DimensionExpression::Multiply(
-                                Span::dummy(),
-                                Box::new(DimensionExpression::Power(
-                                    Some(Span::dummy()),
-                                    Box::new(DimensionExpression::Dimension(
-                                        Span::dummy(),
-                                        "Length".into(),
-                                    )),
+                        Some(TypeAnnotation::TypeExpression(TypeExpression::Multiply(
+                            Span::dummy(),
+                            Box::new(TypeExpression::Power(
+                                Some(Span::dummy()),
+                                Box::new(TypeExpression::TypeIdentifier(
                                     Span::dummy(),
-                                    Rational::new(3, 1),
+                                    "Length".into(),
                                 )),
-                                Box::new(DimensionExpression::Power(
-                                    Some(Span::dummy()),
-                                    Box::new(DimensionExpression::Dimension(
-                                        Span::dummy(),
-                                        "Time".into(),
-                                    )),
+                                Span::dummy(),
+                                Rational::new(3, 1),
+                            )),
+                            Box::new(TypeExpression::Power(
+                                Some(Span::dummy()),
+                                Box::new(TypeExpression::TypeIdentifier(
                                     Span::dummy(),
-                                    Rational::new(2, 1),
+                                    "Time".into(),
                                 )),
-                            ),
-                        )),
+                                Span::dummy(),
+                                Rational::new(2, 1),
+                            )),
+                        ))),
                         false,
                     ),
                 ],
                 body: Some(scalar!(1.0)),
                 return_type_annotation_span: Some(Span::dummy()),
-                return_type_annotation: Some(TypeAnnotation::DimensionExpression(
-                    DimensionExpression::Dimension(Span::dummy(), "Scalar".into()),
+                return_type_annotation: Some(TypeAnnotation::TypeExpression(
+                    TypeExpression::TypeIdentifier(Span::dummy(), "Scalar".into()),
                 )),
             },
         );
@@ -2341,8 +2361,8 @@ mod tests {
                 parameters: vec![(
                     Span::dummy(),
                     "x".into(),
-                    Some(TypeAnnotation::DimensionExpression(
-                        DimensionExpression::Dimension(Span::dummy(), "X".into()),
+                    Some(TypeAnnotation::TypeExpression(
+                        TypeExpression::TypeIdentifier(Span::dummy(), "X".into()),
                     )),
                     false,
                 )],
@@ -2361,15 +2381,15 @@ mod tests {
                 parameters: vec![(
                     Span::dummy(),
                     "x".into(),
-                    Some(TypeAnnotation::DimensionExpression(
-                        DimensionExpression::Dimension(Span::dummy(), "D".into()),
+                    Some(TypeAnnotation::TypeExpression(
+                        TypeExpression::TypeIdentifier(Span::dummy(), "D".into()),
                     )),
                     true,
                 )],
                 body: None,
                 return_type_annotation_span: Some(Span::dummy()),
-                return_type_annotation: Some(TypeAnnotation::DimensionExpression(
-                    DimensionExpression::Dimension(Span::dummy(), "D".into()),
+                return_type_annotation: Some(TypeAnnotation::TypeExpression(
+                    TypeExpression::TypeIdentifier(Span::dummy(), "D".into()),
                 )),
             },
         );
@@ -2693,19 +2713,21 @@ mod tests {
     #[test]
     fn structs() {
         parse_as_expression(
-            &["${foo: 1, bar: 2}"],
+            &["Foo {foo: 1, bar: 2}"],
             struct_! {
+                Foo,
                 foo: scalar!(1.0),
                 bar: scalar!(2.0)
             },
         );
 
         parse_as_expression(
-            &["${foo: 1, bar: 2}.foo"],
+            &["Foo {foo: 1, bar: 2}.foo"],
             Expression::AccessStruct(
                 Span::dummy(),
                 Span::dummy(),
                 Box::new(struct_! {
+                    Foo,
                     foo: scalar!(1.0),
                     bar: scalar!(2.0)
                 }),

+ 1 - 0
numbat/src/prefix_parser.rs

@@ -143,6 +143,7 @@ impl PrefixParser {
     ) -> NameResolutionError {
         NameResolutionError::IdentifierClash {
             conflicting_identifier: name.to_string(),
+            original_item_type: None,
             conflict_span,
             original_span,
         }

+ 23 - 6
numbat/src/prefix_transformer.rs

@@ -97,12 +97,20 @@ impl Transformer {
                     })
                     .collect(),
             ),
-            Expression::MakeStruct(span, args) => Expression::MakeStruct(
-                span,
-                args.into_iter()
+            Expression::MakeStruct {
+                full_span,
+                ident_span,
+                name,
+                fields,
+            } => Expression::MakeStruct {
+                full_span,
+                ident_span,
+                name,
+                fields: fields
+                    .into_iter()
                     .map(|(span, attr, arg)| (span, attr, self.transform_expression(arg)))
                     .collect(),
-            ),
+            },
             Expression::AccessStruct(full_span, ident_span, expr, attr) => {
                 Expression::AccessStruct(
                     full_span,
@@ -228,9 +236,18 @@ impl Transformer {
                     return_type_annotation,
                 }
             }
-            Statement::DefineDimension(name, dexprs) => {
+            Statement::DefineStruct {
+                struct_name_span,
+                struct_name,
+                fields,
+            } => Statement::DefineStruct {
+                struct_name_span,
+                struct_name,
+                fields,
+            },
+            Statement::DefineDimension(name_span, name, dexprs) => {
                 self.dimension_names.push(name.clone());
-                Statement::DefineDimension(name, dexprs)
+                Statement::DefineDimension(name_span, name, dexprs)
             }
             Statement::ProcedureCall(span, procedure, args) => Statement::ProcedureCall(
                 span,

+ 4 - 2
numbat/src/tokenizer.rs

@@ -54,7 +54,7 @@ pub enum TokenKind {
     LeftBracket,
     RightBracket,
 
-    DollarLeftCurly, // ${
+    LeftCurly,
     RightCurly,
 
     // Operators and special signs
@@ -90,6 +90,7 @@ pub enum TokenKind {
     Dimension,
     Unit,
     Use,
+    Struct,
 
     To,
 
@@ -349,6 +350,7 @@ impl Tokenizer {
             m.insert("to", TokenKind::To);
             m.insert("let", TokenKind::Let);
             m.insert("fn", TokenKind::Fn);
+            m.insert("struct", TokenKind::Struct);
             m.insert("dimension", TokenKind::Dimension);
             m.insert("unit", TokenKind::Unit);
             m.insert("use", TokenKind::Use);
@@ -403,7 +405,7 @@ impl Tokenizer {
             ')' => TokenKind::RightParen,
             '[' => TokenKind::LeftBracket,
             ']' => TokenKind::RightBracket,
-            '$' if self.match_char('{') => TokenKind::DollarLeftCurly,
+            '{' if !self.interpolation_state.is_inside() => TokenKind::LeftCurly,
             '}' if !self.interpolation_state.is_inside() => TokenKind::RightCurly,
             '≤' => TokenKind::LessOrEqual,
             '<' if self.match_char('=') => TokenKind::LessOrEqual,

+ 231 - 103
numbat/src/typechecker.rs

@@ -4,7 +4,6 @@ use std::{
     fmt,
 };
 
-use crate::typed_ast::{self, Type};
 use crate::{
     arithmetic::{pretty_exponent, Exponent, Power, Rational},
     ast::ProcedureKind,
@@ -17,9 +16,14 @@ use crate::{
 };
 use crate::{dimension::DimensionRegistry, typed_ast::DType};
 use crate::{ffi::ArityRange, typed_ast::Expression};
+use crate::{
+    name_resolution::Namespace,
+    typed_ast::{self, StructInfo, Type},
+    NameResolutionError,
+};
 use crate::{name_resolution::LAST_RESULT_IDENTIFIERS, pretty_print::PrettyPrint};
 
-use ast::{BinaryOperator, DimensionExpression};
+use ast::{BinaryOperator, TypeExpression};
 use itertools::Itertools;
 use num_traits::{CheckedAdd, CheckedDiv, CheckedMul, CheckedSub, FromPrimitive, Zero};
 use thiserror::Error;
@@ -283,8 +287,8 @@ pub enum TypeCheckError {
     #[error("Incompatible types in function call: expected '{1}', got '{3}' instead")]
     IncompatibleTypesInFunctionCall(Option<Span>, Type, Span, Type),
 
-    #[error("This name is already used by {0}")]
-    NameAlreadyUsedBy(&'static str, Span, Option<Span>),
+    #[error("Incompatible types for struct field: expected '{1}', got '{3}' instead")]
+    IncompatibleTypesForStructField(Span, Type, Span, Type),
 
     #[error("Missing a definition for dimension {1}")]
     MissingDimension(Span, String),
@@ -298,6 +302,15 @@ pub enum TypeCheckError {
     #[error("Base units can not be dimensionless.")]
     NoDimensionlessBaseUnit(Span, String),
 
+    #[error("Unknown struct '{1}")]
+    UnknownStruct(Span, String),
+
+    #[error("Unknown field {2} of struct {3}")]
+    UnknownFieldOfStruct(Span, Span, String, String),
+
+    #[error("Duplicate field {2} in struct definition")]
+    DuplicateFieldInStructDefinition(Span, Span, String),
+
     #[error("Duplicate field {2} in struct construction")]
     DuplicateFieldInStructConstruction(Span, Span, String),
 
@@ -306,6 +319,12 @@ pub enum TypeCheckError {
 
     #[error("Accessing unknown field {2} of struct {3}")]
     AccessingUnknownFieldOfStruct(Span, Span, String, Type),
+
+    #[error("Missing fields from struct construction")]
+    MissingFieldsFromStructConstruction(Span, Span, Vec<(String, Type)>),
+
+    #[error(transparent)]
+    NameResolutionError(#[from] NameResolutionError),
 }
 
 type Result<T> = std::result::Result<T, TypeCheckError>;
@@ -441,14 +460,17 @@ struct FunctionSignature {
     parameter_types: Vec<(Span, Type)>,
     is_variadic: bool,
     return_type: Type,
-    is_foreign: bool,
 }
 
 #[derive(Clone, Default)]
 pub struct TypeChecker {
     identifiers: HashMap<String, (Type, Option<Span>)>,
     function_signatures: HashMap<String, FunctionSignature>,
+    structs: HashMap<String, StructInfo>,
     registry: DimensionRegistry,
+
+    type_namespace: Namespace,
+    value_namespace: Namespace,
 }
 
 impl TypeChecker {
@@ -521,7 +543,6 @@ impl TypeChecker {
             parameter_types,
             is_variadic,
             return_type,
-            is_foreign: _,
         } = signature;
 
         let arity_range = if *is_variadic {
@@ -683,17 +704,23 @@ impl TypeChecker {
         ) -> Type {
             match t {
                 Type::Dimension(d) => Type::Dimension(substitute(&substitutions, d)),
-                Type::Struct(fields) => Type::Struct(
-                    fields
+                Type::Struct(StructInfo {
+                    definition_span,
+                    name,
+                    fields,
+                }) => Type::Struct(StructInfo {
+                    definition_span: *definition_span,
+                    name: name.clone(),
+                    fields: fields
                         .into_iter()
-                        .map(|(n, t)| {
+                        .map(|(n, (s, t))| {
                             (
                                 n.to_owned(),
-                                apply_substitutions(t, substitute, substitutions),
+                                (s.clone(), apply_substitutions(t, substitute, substitutions)),
                             )
                         })
                         .collect(),
-                ),
+                }),
                 type_ => type_.clone(),
             }
         }
@@ -1128,13 +1155,21 @@ impl TypeChecker {
                     Box::new(else_),
                 )
             }
-            ast::Expression::MakeStruct(span, fields) => {
+            ast::Expression::MakeStruct {
+                full_span,
+                ident_span,
+                name,
+                fields,
+            } => {
                 let fields_checked = fields
                     .iter()
                     .map(|(_, n, v)| Ok((n.to_string(), self.check_expression(v)?)))
                     .collect::<Result<Vec<_>>>()?;
 
-                let mut field_types = indexmap::IndexMap::new();
+                let Some(struct_info) = self.structs.get(name) else {
+                    return Err(TypeCheckError::UnknownStruct(*ident_span, name.clone()));
+                };
+
                 let mut seen_fields = HashMap::new();
 
                 for ((field, expr), span) in
@@ -1147,18 +1182,52 @@ impl TypeChecker {
                             field.to_string(),
                         ));
                     }
-                    field_types.insert(field.to_string(), expr.get_type());
+
+                    let Some((expected_field_span, expected_type)) = struct_info.fields.get(field)
+                    else {
+                        return Err(TypeCheckError::UnknownFieldOfStruct(
+                            *span,
+                            struct_info.definition_span,
+                            field.clone(),
+                            struct_info.name.clone(),
+                        ));
+                    };
+
+                    let found_type = &expr.get_type();
+                    if !found_type.is_subtype_of(expected_type) {
+                        return Err(TypeCheckError::IncompatibleTypesForStructField(
+                            *expected_field_span,
+                            expected_type.clone(),
+                            expr.full_span(),
+                            found_type.clone(),
+                        ));
+                    }
+
                     seen_fields.insert(field, *span);
                 }
 
-                typed_ast::Expression::MakeStruct(*span, fields_checked, field_types)
+                let missing_fields = {
+                    let mut fields = struct_info.fields.clone();
+                    fields.retain(|f, _| !seen_fields.contains_key(f));
+                    fields.into_iter().map(|(n, (_, t))| (n, t)).collect_vec()
+                };
+
+                if !missing_fields.is_empty() {
+                    return Err(TypeCheckError::MissingFieldsFromStructConstruction(
+                        *full_span,
+                        struct_info.definition_span,
+                        missing_fields,
+                    ));
+                }
+
+                typed_ast::Expression::MakeStruct(*full_span, fields_checked, struct_info.clone())
             }
             ast::Expression::AccessStruct(full_span, ident_span, expr, attr) => {
                 let expr_checked = self.check_expression(expr)?;
 
                 let type_ = expr_checked.get_type();
 
-                let Type::Struct(fields) = type_.clone() else {
+                let Type::Struct(struct_info) = type_.clone() else {
                     return Err(TypeCheckError::AccessingFieldOfNonStruct(
                         *ident_span,
                         expr.full_span(),
@@ -1167,7 +1236,7 @@ impl TypeChecker {
                     ));
                 };
 
-                let Some((_, ret_ty)) = fields.iter().find(|(n, _)| *n == attr) else {
+                let Some((_, ret_ty)) = struct_info.fields.get(attr) else {
                     return Err(TypeCheckError::AccessingUnknownFieldOfStruct(
                         *ident_span,
                         expr.full_span(),
@@ -1183,7 +1252,7 @@ impl TypeChecker {
                     *full_span,
                     Box::new(expr_checked),
                     attr.to_owned(),
-                    fields,
+                    struct_info,
                     ret_ty,
                 )
             }
@@ -1207,16 +1276,6 @@ impl TypeChecker {
                 type_annotation,
                 decorators,
             } => {
-                // Make sure that identifier does not clash with a function name. We do not
-                // check for clashes with unit names, as this is handled by the prefix parser.
-                if let Some(signature) = self.function_signatures.get(identifier) {
-                    return Err(TypeCheckError::NameAlreadyUsedBy(
-                        "a function",
-                        *identifier_span,
-                        Some(signature.definition_span),
-                    ));
-                }
-
                 let expr_checked = self.check_expression(expr)?;
                 let type_deduced = expr_checked.get_type();
 
@@ -1265,6 +1324,12 @@ impl TypeChecker {
                 for (name, _) in decorator::name_and_aliases(identifier, decorators) {
                     self.identifiers
                         .insert(name.clone(), (type_deduced.clone(), Some(*identifier_span)));
+
+                    self.value_namespace.add_allow_override(
+                        name.clone(),
+                        *identifier_span,
+                        "constant".to_owned(),
+                    )?;
                 }
 
                 typed_ast::Statement::DefineVariable(
@@ -1387,24 +1452,18 @@ impl TypeChecker {
                 return_type_annotation_span,
                 return_type_annotation,
             } => {
-                // Make sure that function name does not clash with an identifier. We do not
-                // check for clashes with unit names, as this is handled by the prefix parser.
-                if let Some((_, span)) = self.identifiers.get(function_name) {
-                    return Err(TypeCheckError::NameAlreadyUsedBy(
-                        "a constant",
+                if body.is_none() {
+                    self.value_namespace.add(
+                        function_name.clone(),
                         *function_name_span,
-                        *span,
-                    ));
-                }
-
-                if let Some(signature) = self.function_signatures.get(function_name) {
-                    if signature.is_foreign {
-                        return Err(TypeCheckError::NameAlreadyUsedBy(
-                            "a foreign function",
-                            *function_name_span,
-                            Some(signature.definition_span),
-                        ));
-                    }
+                        "foreign function".to_owned(),
+                    )?;
+                } else {
+                    self.value_namespace.add_allow_override(
+                        function_name.clone(),
+                        *function_name_span,
+                        "function".to_owned(),
+                    )?;
                 }
 
                 let mut typechecker_fn = self.clone();
@@ -1451,7 +1510,7 @@ impl TypeChecker {
                         Type::Dimension(
                             typechecker_fn
                                 .registry
-                                .get_base_representation(&DimensionExpression::Dimension(
+                                .get_base_representation(&TypeExpression::TypeIdentifier(
                                     *parameter_span,
                                     free_type_parameter,
                                 ))
@@ -1486,35 +1545,29 @@ impl TypeChecker {
                     .map(|annotation| typechecker_fn.type_from_annotation(annotation))
                     .transpose()?;
 
-                let add_function_signature =
-                    |tc: &mut TypeChecker, return_type: Type, is_foreign: bool| {
-                        let parameter_types = typed_parameters
-                            .iter()
-                            .map(|(span, _, _, _, t)| (*span, t.clone()))
-                            .collect();
-                        tc.function_signatures.insert(
-                            function_name.clone(),
-                            FunctionSignature {
-                                definition_span: *function_name_span,
-                                type_parameters: type_parameters.clone(),
-                                parameter_types,
-                                is_variadic,
-                                return_type,
-                                is_foreign,
-                            },
-                        );
-                    };
+                let add_function_signature = |tc: &mut TypeChecker, return_type: Type| {
+                    let parameter_types = typed_parameters
+                        .iter()
+                        .map(|(span, _, _, _, t)| (*span, t.clone()))
+                        .collect();
+                    tc.function_signatures.insert(
+                        function_name.clone(),
+                        FunctionSignature {
+                            definition_span: *function_name_span,
+                            type_parameters: type_parameters.clone(),
+                            parameter_types,
+                            is_variadic,
+                            return_type,
+                        },
+                    );
+                };
 
                 if let Some(ref return_type_specified) = return_type_specified {
                     // This is needed for recursive functions. If the return type
                     // has been specified, we can already provide a function
                     // signature before we check the body of the function. This
                     // way, the 'typechecker_fn' can resolve the recursive call.
-                    add_function_signature(
-                        &mut typechecker_fn,
-                        return_type_specified.clone(),
-                        body.is_none(),
-                    );
+                    add_function_signature(&mut typechecker_fn, return_type_specified.clone());
                 }
 
                 let body_checked = body
@@ -1588,7 +1641,7 @@ impl TypeChecker {
                     })?
                 };
 
-                add_function_signature(self, return_type.clone(), body.is_none());
+                add_function_signature(self, return_type.clone());
 
                 typed_ast::Statement::DefineFunction(
                     function_name.clone(),
@@ -1605,7 +1658,10 @@ impl TypeChecker {
                     return_type,
                 )
             }
-            ast::Statement::DefineDimension(name, dexprs) => {
+            ast::Statement::DefineDimension(name_span, name, dexprs) => {
+                self.type_namespace
+                    .add(name.clone(), *name_span, "dimension".to_owned())?;
+
                 if let Some(dexpr) = dexprs.first() {
                     self.registry
                         .add_derived_dimension(name, dexpr)
@@ -1713,6 +1769,46 @@ impl TypeChecker {
             ast::Statement::ModuleImport(_, _) => {
                 unreachable!("Modules should have been inlined by now")
             }
+            ast::Statement::DefineStruct {
+                struct_name_span,
+                struct_name,
+                fields,
+            } => {
+                self.type_namespace.add(
+                    struct_name.clone(),
+                    *struct_name_span,
+                    "struct".to_owned(),
+                )?;
+
+                let mut seen_fields = HashMap::new();
+
+                for (span, field, _) in fields {
+                    if let Some(other_span) = seen_fields.get(field) {
+                        return Err(TypeCheckError::DuplicateFieldInStructDefinition(
+                            *span,
+                            *other_span,
+                            field.to_string(),
+                        ));
+                    }
+
+                    seen_fields.insert(field, *span);
+                }
+
+                let struct_info = StructInfo {
+                    definition_span: *struct_name_span,
+                    name: struct_name.clone(),
+                    fields: fields
+                        .iter()
+                        .map(|(span, name, type_)| {
+                            Ok((name.clone(), (*span, self.type_from_annotation(type_)?)))
+                        })
+                        .collect::<Result<_>>()?,
+                };
+                self.structs
+                    .insert(struct_name.clone(), struct_info.clone());
+
+                typed_ast::Statement::DefineStruct(struct_info)
+            }
         })
     }
 
@@ -1735,11 +1831,21 @@ impl TypeChecker {
     fn type_from_annotation(&self, annotation: &TypeAnnotation) -> Result<Type> {
         match annotation {
             TypeAnnotation::Never(_) => Ok(Type::Never),
-            TypeAnnotation::DimensionExpression(dexpr) => self
-                .registry
-                .get_base_representation(dexpr)
-                .map(Type::Dimension)
-                .map_err(TypeCheckError::RegistryError),
+            TypeAnnotation::TypeExpression(dexpr) => {
+                if let TypeExpression::TypeIdentifier(_, name) = dexpr {
+                    if let Some(info) = self.structs.get(name) {
+                        // if we see a struct name here, it's safe to assume it
+                        // isn't accidentally clashing with a dimension, we
+                        // check that earlier.
+                        return Ok(Type::Struct(info.clone()));
+                    }
+                }
+
+                self.registry
+                    .get_base_representation(dexpr)
+                    .map(Type::Dimension)
+                    .map_err(TypeCheckError::RegistryError)
+            }
             TypeAnnotation::Bool(_) => Ok(Type::Boolean),
             TypeAnnotation::String(_) => Ok(Type::String),
             TypeAnnotation::DateTime(_) => Ok(Type::DateTime),
@@ -1750,12 +1856,6 @@ impl TypeChecker {
                     .collect::<Result<Vec<_>>>()?,
                 Box::new(self.type_from_annotation(return_type)?),
             )),
-            TypeAnnotation::Struct(_, fields) => Ok(Type::Struct(
-                fields
-                    .iter()
-                    .map(|(_, n, t)| Ok((n.clone(), self.type_from_annotation(t)?)))
-                    .collect::<Result<indexmap::IndexMap<_, _>>>()?,
-            )),
         }
     }
 }
@@ -1785,6 +1885,8 @@ mod tests {
     fn returns_never() -> ! = error(\"…\")
     fn takes_never_returns_a(x: !) -> A = a
 
+    struct SomeStruct { a: A, b: B }
+
     let callable = takes_a_returns_b
     ";
 
@@ -2014,24 +2116,24 @@ mod tests {
         ));
     }
 
-    #[test]
-    fn generics_with_records() {
-        assert_successful_typecheck(
-            "
-            fn f<D>(x: D) = ${foo: x}
-            f(2)
-            f(2 a).foo == 2 a
-            ",
-        );
-
-        assert_successful_typecheck(
-            "
-            fn f<D>(x: D) -> ${foo: D} = ${foo: x}
-            f(2)
-            f(2 a).foo == 2 a
-            ",
-        );
-    }
+    // #[test]
+    // fn generics_with_records() {
+    //     assert_successful_typecheck(
+    //         "
+    //         fn f<D>(x: D) = ${foo: x}
+    //         f(2)
+    //         f(2 a).foo == 2 a
+    //         ",
+    //     );
+
+    //     assert_successful_typecheck(
+    //         "
+    //         fn f<D>(x: D) -> ${foo: D} = ${foo: x}
+    //         f(2)
+    //         f(2 a).foo == 2 a
+    //         ",
+    //     );
+    // }
 
     #[test]
     fn generics_multiple_unresolved_type_parameters() {
@@ -2455,14 +2557,35 @@ mod tests {
         ));
     }
 
+    #[test]
     fn struct_errors() {
         assert!(matches!(
-            get_typecheck_error("${foo: 1, foo: 2}"),
-            TypeCheckError::DuplicateFieldInStructConstruction(_, _, field) if field == "foo"
+            get_typecheck_error("SomeStruct {a: 1, b: 1b}"),
+            TypeCheckError::IncompatibleTypesForStructField(..)
         ));
 
         assert!(matches!(
-            get_typecheck_error("${}.foo"),
+            get_typecheck_error("NotAStruct {}"),
+            TypeCheckError::UnknownStruct(_, name) if name == "NotAStruct"
+        ));
+
+        assert!(matches!(
+            get_typecheck_error("SomeStruct {not_a_field: 1}"),
+            TypeCheckError::UnknownFieldOfStruct(_, _, field, _) if field == "not_a_field"
+        ));
+
+        assert!(matches!(
+            get_typecheck_error("struct Foo { foo: A, foo: A }"),
+            TypeCheckError::DuplicateFieldInStructDefinition(_, _, field) if field == "foo"
+        ));
+
+        assert!(matches!(
+            get_typecheck_error("SomeStruct {a: 1a, a: 1a, b: 2b}"),
+            TypeCheckError::DuplicateFieldInStructConstruction(_, _, field) if field == "a"
+        ));
+
+        assert!(matches!(
+            get_typecheck_error("SomeStruct {a: 1a, b: 1b}.foo"),
             TypeCheckError::AccessingUnknownFieldOfStruct(_, _, field, _) if field == "foo"
         ));
 
@@ -2470,5 +2593,10 @@ mod tests {
             get_typecheck_error("(1).foo"),
             TypeCheckError::AccessingFieldOfNonStruct(_, _, field, _) if field == "foo"
         ));
+
+        assert!(matches!(
+            get_typecheck_error("SomeStruct {}"),
+            TypeCheckError::MissingFieldsFromStructConstruction(..)
+        ));
     }
 }

+ 45 - 26
numbat/src/typed_ast.rs

@@ -1,8 +1,9 @@
+use indexmap::IndexMap;
 use itertools::Itertools;
 
 use crate::arithmetic::{Exponent, Rational};
 use crate::ast::ProcedureKind;
-pub use crate::ast::{BinaryOperator, DimensionExpression, UnaryOperator};
+pub use crate::ast::{BinaryOperator, TypeExpression, UnaryOperator};
 use crate::dimension::DimensionRegistry;
 use crate::{
     decorator::Decorator, markup::Markup, number::Number, prefix::Prefix,
@@ -52,6 +53,13 @@ impl DType {
     }
 }
 
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct StructInfo {
+    pub definition_span: Span,
+    pub name: String,
+    pub fields: IndexMap<String, (Span, Type)>,
+}
+
 #[derive(Debug, Clone, PartialEq, Eq)]
 pub enum Type {
     Never,
@@ -60,7 +68,7 @@ pub enum Type {
     String,
     DateTime,
     Fn(Vec<Type>, Box<Type>),
-    Struct(indexmap::IndexMap<String, Type>),
+    Struct(StructInfo),
 }
 
 impl std::fmt::Display for Type {
@@ -78,13 +86,13 @@ impl std::fmt::Display for Type {
                     ps = param_types.iter().map(|p| p.to_string()).join(", ")
                 )
             }
-            Type::Struct(members) => {
+            Type::Struct(StructInfo { name, fields, .. }) => {
                 write!(
                     f,
-                    "${{{}}}",
-                    members
+                    "{name} {{{}}}",
+                    fields
                         .iter()
-                        .map(|(n, t)| n.to_string() + ": " + &t.to_string())
+                        .map(|(n, (_, t))| n.to_string() + ": " + &t.to_string())
                         .join(", ")
                 )
             }
@@ -115,10 +123,12 @@ impl PrettyPrint for Type {
                     + return_type.pretty_print()
                     + m::operator("]")
             }
-            Type::Struct(fields) => {
-                m::operator("${")
+            Type::Struct(StructInfo { name, fields, .. }) => {
+                m::type_identifier(name)
+                    + m::space()
+                    + m::operator("{")
                     + Itertools::intersperse(
-                        fields.iter().map(|(n, t)| {
+                        fields.iter().map(|(n, (_, t))| {
                             m::type_identifier(n) + m::operator(":") + m::space() + t.pretty_print()
                         }),
                         m::operator(",") + m::space(),
@@ -232,19 +242,8 @@ pub enum Expression {
     Boolean(Span, bool),
     Condition(Span, Box<Expression>, Box<Expression>, Box<Expression>),
     String(Span, Vec<StringPart>),
-    MakeStruct(
-        Span,
-        Vec<(String, Expression)>,
-        indexmap::IndexMap<String, Type>,
-    ),
-    AccessStruct(
-        Span,
-        Span,
-        Box<Expression>,
-        String,
-        indexmap::IndexMap<String, Type>,
-        Type,
-    ),
+    MakeStruct(Span, Vec<(String, Expression)>, StructInfo),
+    AccessStruct(Span, Span, Box<Expression>, String, StructInfo, Type),
 }
 
 impl Expression {
@@ -300,10 +299,11 @@ pub enum Statement {
         Markup,             // readable return type
         Type,               // return type
     ),
-    DefineDimension(String, Vec<DimensionExpression>),
+    DefineDimension(String, Vec<TypeExpression>),
     DefineBaseUnit(String, Vec<Decorator>, Markup, Type),
     DefineDerivedUnit(String, Expression, Vec<Decorator>, Markup, Type),
     ProcedureCall(crate::ast::ProcedureKind, Vec<Expression>),
+    DefineStruct(StructInfo),
 }
 
 impl Statement {
@@ -524,6 +524,21 @@ impl PrettyPrint for Statement {
                     .sum()
                     + m::operator(")")
             }
+            Statement::DefineStruct(StructInfo { name, fields, .. }) => {
+                m::keyword("struct")
+                    + m::space()
+                    + m::type_identifier(name.clone())
+                    + m::space()
+                    + m::operator("{")
+                    + Itertools::intersperse(
+                        fields.iter().map(|(n, (_, t))| {
+                            m::identifier(n) + m::operator(":") + m::space() + t.pretty_print()
+                        }),
+                        m::operator(",") + m::space(),
+                    )
+                    .sum()
+                    + m::operator("}")
+            }
         }
     }
 }
@@ -721,8 +736,10 @@ impl PrettyPrint for Expression {
                     + m::space()
                     + with_parens(else_)
             }
-            MakeStruct(_, exprs, _type) => {
-                m::operator("${")
+            MakeStruct(_, exprs, struct_info) => {
+                m::type_identifier(struct_info.name.clone())
+                    + m::space()
+                    + m::operator("{")
                     + itertools::Itertools::intersperse(
                         exprs.iter().map(|(n, e)| {
                             m::identifier(n) + m::operator(":") + m::space() + e.pretty_print()
@@ -787,6 +804,8 @@ mod tests {
                  @metric_prefixes
                  unit points
 
+                 struct Foo {{foo: Length, bar: Time}}
+
                  let a = 1
                  let b = 1
                  let c = 1
@@ -903,7 +922,7 @@ mod tests {
         roundtrip_check("-3!");
         roundtrip_check("(-3)!");
         roundtrip_check("megapoints");
-        roundtrip_check("${foo: 1 meter, bar: 1 second}");
+        roundtrip_check("Foo {foo: 1 meter, bar: 1 second}");
     }
 
     #[test]

+ 15 - 12
numbat/src/value.rs

@@ -2,7 +2,7 @@ use std::sync::Arc;
 
 use itertools::Itertools;
 
-use crate::{pretty_print::PrettyPrint, quantity::Quantity};
+use crate::{pretty_print::PrettyPrint, quantity::Quantity, typed_ast::StructInfo};
 
 #[derive(Debug, Clone, PartialEq, Eq)]
 pub enum FunctionReference {
@@ -33,7 +33,7 @@ pub enum Value {
     DateTime(chrono::DateTime<chrono::FixedOffset>),
     FunctionReference(FunctionReference),
     FormatSpecifiers(Option<String>),
-    Struct(Arc<[(String, usize)]>, Vec<Value>),
+    Struct(Arc<StructInfo>, Vec<Value>),
 }
 
 impl Value {
@@ -92,12 +92,15 @@ impl std::fmt::Display for Value {
             Value::DateTime(dt) => write!(f, "datetime(\"{}\")", dt),
             Value::FunctionReference(r) => write!(f, "{}", r),
             Value::FormatSpecifiers(_) => write!(f, "<format specfiers>"),
-            Value::Struct(field_meta, values) => write!(
+            Value::Struct(struct_info, values) => write!(
                 f,
-                "${{ {} }}",
-                field_meta
-                    .iter()
-                    .map(|(name, idx)| name.to_owned() + ": " + &values[*idx].to_string())
+                "{} {{ {} }}",
+                struct_info.name,
+                struct_info
+                    .fields
+                    .keys()
+                    .zip(values)
+                    .map(|(name, value)| name.to_owned() + ": " + &value.to_string())
                     .join(", ")
             ),
         }
@@ -114,12 +117,12 @@ impl PrettyPrint for Value {
             Value::FunctionReference(r) => crate::markup::string(r.to_string()),
             Value::FormatSpecifiers(Some(s)) => crate::markup::string(s),
             Value::FormatSpecifiers(None) => crate::markup::empty(),
-            Value::Struct(field_meta, values) => {
-                crate::markup::operator("${")
+            Value::Struct(struct_info, values) => {
+                crate::markup::type_identifier(struct_info.name.clone())
+                    + crate::markup::space()
+                    + crate::markup::operator("{")
                     + itertools::Itertools::intersperse(
-                        field_meta.iter().map(|(name, idx)| {
-                            let val = &values[*idx];
-
+                        struct_info.fields.keys().zip(values).map(|(name, val)| {
                             crate::markup::identifier(name)
                                 + crate::markup::operator(":")
                                 + crate::markup::space()

+ 21 - 14
numbat/src/vm.rs

@@ -2,8 +2,9 @@ use std::collections::HashMap;
 use std::sync::Arc;
 use std::{cmp::Ordering, fmt::Display};
 
-use indexmap::IndexSet;
+use indexmap::IndexMap;
 
+use crate::typed_ast::StructInfo;
 use crate::{
     ffi::{self, ArityRange, Callable, ForeignFunction},
     interpreter::{InterpreterResult, PrintFunction, Result, RuntimeError},
@@ -279,8 +280,8 @@ pub struct Vm {
     /// Constants are numbers like '1.4' or a [Unit] like 'meter'.
     pub constants: Vec<Constant>,
 
-    /// struct field metadata, used so we can recover struct field layout at runtime
-    struct_fields: IndexSet<Arc<[(String, usize)]>>,
+    /// struct metadata, used so we can display struct fields at runtime
+    struct_infos: IndexMap<String, Arc<StructInfo>>,
 
     /// Unit prefixes in use
     prefixes: Vec<Prefix>,
@@ -318,7 +319,7 @@ impl Vm {
             bytecode: vec![("<main>".into(), vec![])],
             current_chunk_index: 0,
             constants: vec![],
-            struct_fields: IndexSet::new(),
+            struct_infos: IndexMap::new(),
             prefixes: vec![],
             strings: vec![],
             unit_information: vec![],
@@ -380,10 +381,16 @@ impl Vm {
         (self.constants.len() - 1) as u16 // TODO: this can overflow, see above
     }
 
-    pub fn add_struct_fields(&mut self, fields: Vec<(String, usize)>) -> u16 {
-        let (idx, _) = self.struct_fields.insert_full(fields.into());
+    pub fn add_struct_info(&mut self, struct_info: &StructInfo) -> usize {
+        let e = self.struct_infos.entry(struct_info.name.clone());
+        let idx = e.index();
+        e.or_insert_with(|| Arc::new(struct_info.clone()));
 
-        idx as u16
+        idx
+    }
+
+    pub fn get_structinfo_idx(&self, name: &str) -> Option<usize> {
+        self.struct_infos.get_index_of(name)
     }
 
     pub fn add_prefix(&mut self, prefix: Prefix) -> u16 {
@@ -972,12 +979,12 @@ impl Vm {
                     }
                 }
                 Op::BuildStruct => {
-                    let meta_idx = self.read_u16();
-                    let fields_meta = self
-                        .struct_fields
-                        .get_index(meta_idx as usize)
-                        .expect("Missing struct metadata {meta_idx}")
-                        .clone();
+                    let info_idx = self.read_u16();
+                    let (_, struct_info) = self
+                        .struct_infos
+                        .get_index(info_idx as usize)
+                        .expect("Missing struct metadata");
+                    let struct_info = Arc::clone(struct_info);
                     let num_args = self.read_u16();
 
                     let mut content = Vec::with_capacity(num_args as usize);
@@ -986,7 +993,7 @@ impl Vm {
                         content.push(self.pop());
                     }
 
-                    self.stack.push(Value::Struct(fields_meta, content));
+                    self.stack.push(Value::Struct(struct_info, content));
                 }
                 Op::DestructureStruct => {
                     let field_idx = self.read_u16();

+ 14 - 5
numbat/tests/interpreter.rs

@@ -52,9 +52,12 @@ fn expect_failure(code: &str, msg_part: &str) {
     if let Err(e) = ctx.interpret(code, CodeSource::Internal) {
         let error_message = e.to_string();
         println!("{}", error_message);
-        assert!(error_message.contains(msg_part));
+        assert!(
+            error_message.contains(msg_part),
+            "Expected {msg_part} but got {error_message}"
+        );
     } else {
-        panic!();
+        panic!("Expected an error but but instead {code} did not fail");
     }
 }
 
@@ -497,11 +500,17 @@ fn test_name_clash_errors() {
 fn test_type_check_errors() {
     expect_failure("foo", "Unknown identifier 'foo'");
 
-    expect_failure("let sin=2", "This name is already used by a function");
-    expect_failure("fn pi() = 1", "This name is already used by a constant");
+    expect_failure(
+        "let sin=2",
+        "Identifier is already in use by the foreign function: 'sin'",
+    );
+    expect_failure(
+        "fn pi() = 1",
+        "Identifier is already in use by the constant: 'pi'",
+    );
     expect_failure(
         "fn sin(x)=0",
-        "This name is already used by a foreign function",
+        "Identifier is already in use by the foreign function: 'sin'",
     );
 }
 

+ 1 - 1
numbat/tests/prelude_and_examples.rs

@@ -12,7 +12,7 @@ use crate::common::get_test_context_without_prelude;
 
 fn assert_runs(code: &str) {
     let result = get_test_context().interpret(code, CodeSource::Internal);
-    assert!(result.is_ok());
+    assert!(result.is_ok(), "Failed with: {result:#?}");
     assert!(matches!(
         result.unwrap().1,
         InterpreterResult::Value(_) | InterpreterResult::Continue