Kaynağa Gözat

Fix pretty printing for some generic types

closes #524
closes #545
David Peter 1 yıl önce
ebeveyn
işleme
7a677456a4

+ 2 - 2
numbat/src/bytecode_interpreter.rs

@@ -417,7 +417,7 @@ impl BytecodeInterpreter {
                             readable_type: annotation
                                 .as_ref()
                                 .map(|a| a.pretty_print())
-                                .unwrap_or(type_.to_readable_type(dimension_registry)),
+                                .unwrap_or(type_.to_readable_type(dimension_registry, false)),
                             aliases,
                             name: decorator::name(decorators),
                             canonical_name: decorator::get_canonical_unit_name(
@@ -473,7 +473,7 @@ impl BytecodeInterpreter {
                         readable_type: annotation
                             .as_ref()
                             .map(|a| a.pretty_print())
-                            .unwrap_or(type_.to_readable_type(dimension_registry)),
+                            .unwrap_or(type_.to_readable_type(dimension_registry, false)),
                         aliases,
                         name: decorator::name(decorators),
                         canonical_name: decorator::get_canonical_unit_name(unit_name, decorators),

+ 1 - 1
numbat/src/interpreter/mod.rs

@@ -93,7 +93,7 @@ impl InterpreterResult {
                             if type_.is_scalar() {
                                 None
                             } else {
-                                let ty = type_.to_readable_type(registry);
+                                let ty = type_.to_readable_type(registry, true);
                                 Some(m::dimmed("    [") + ty + m::dimmed("]"))
                             }
                         })

+ 3 - 1
numbat/src/typechecker/mod.rs

@@ -1813,7 +1813,9 @@ impl TypeChecker {
         if let Some((span, type_of_hole)) = elaborated_statement.find_typed_hole()? {
             return Err(TypeCheckError::TypedHoleInStatement(
                 span,
-                type_of_hole.to_readable_type(&self.registry).to_string(),
+                type_of_hole
+                    .to_readable_type(&self.registry, true)
+                    .to_string(),
                 elaborated_statement.pretty_print().to_string(),
                 self.env
                     .iter_relevant_matches()

+ 21 - 9
numbat/src/typechecker/type_scheme.rs

@@ -109,22 +109,34 @@ impl TypeScheme {
     pub(crate) fn to_readable_type(
         &self,
         registry: &crate::dimension::DimensionRegistry,
+        with_quantifiers: bool,
     ) -> crate::markup::Markup {
         let (instantiated_type, type_parameters) = self.instantiate_for_printing(None);
 
         let mut markup = m::empty();
-        for type_parameter in &type_parameters {
-            markup += m::keyword("forall");
-            markup += m::space();
-            markup += m::type_identifier(type_parameter.unsafe_name());
 
-            if instantiated_type.bounds.is_dtype_bound(type_parameter) {
-                markup += m::operator(":");
+        if with_quantifiers {
+            let has_type_parameters = !type_parameters.is_empty();
+
+            if has_type_parameters {
+                markup += m::keyword("forall");
+            }
+
+            for type_parameter in &type_parameters {
+                markup += m::space();
+                markup += m::type_identifier(type_parameter.unsafe_name());
+
+                if instantiated_type.bounds.is_dtype_bound(type_parameter) {
+                    markup += m::operator(":");
+                    markup += m::space();
+                    markup += m::type_identifier("Dim");
+                }
+                markup += m::operator(".");
+            }
+
+            if has_type_parameters {
                 markup += m::space();
-                markup += m::type_identifier("Dim");
             }
-            markup += m::operator(".");
-            markup += m::space();
         }
 
         markup + instantiated_type.inner.to_readable_type(registry)

+ 9 - 4
numbat/src/typed_ast.rs

@@ -630,11 +630,12 @@ impl Statement {
         registry: &DimensionRegistry,
         type_: &TypeScheme,
         annotation: &Option<TypeAnnotation>,
+        with_quantifiers: bool,
     ) -> Markup {
         if let Some(annotation) = annotation {
             annotation.pretty_print()
         } else {
-            type_.to_readable_type(registry)
+            type_.to_readable_type(registry, with_quantifiers)
         }
     }
 
@@ -649,7 +650,7 @@ impl Statement {
                 type_,
                 readable_type,
             )) => {
-                *readable_type = Self::create_readable_type(registry, type_, type_annotation);
+                *readable_type = Self::create_readable_type(registry, type_, type_annotation, true);
             }
             Statement::DefineFunction(
                 _,
@@ -669,7 +670,8 @@ impl Statement {
                 for DefineVariable(_, _, _, type_annotation, type_, readable_type) in
                     local_variables
                 {
-                    *readable_type = Self::create_readable_type(registry, type_, type_annotation);
+                    *readable_type =
+                        Self::create_readable_type(registry, type_, type_annotation, false);
                 }
 
                 let Type::Fn(parameter_types, return_type) = fn_type.inner else {
@@ -680,6 +682,7 @@ impl Statement {
                     registry,
                     &TypeScheme::concrete(*return_type),
                     return_type_annotation,
+                    false,
                 );
 
                 for ((_, _, type_annotation, readable_parameter_type), parameter_type) in
@@ -689,13 +692,15 @@ impl Statement {
                         registry,
                         &TypeScheme::concrete(parameter_type.clone()),
                         type_annotation,
+                        false,
                     );
                 }
             }
             Statement::DefineDimension(_, _) => {}
             Statement::DefineBaseUnit(_, _, _, _) => {}
             Statement::DefineDerivedUnit(_, _, _, type_annotation, type_, readable_type) => {
-                *readable_type = Self::create_readable_type(registry, type_, type_annotation);
+                *readable_type =
+                    Self::create_readable_type(registry, type_, type_annotation, false);
             }
             Statement::ProcedureCall(_, _) => {}
             Statement::DefineStruct(_) => {}

+ 11 - 0
numbat/tests/interpreter.rs

@@ -818,6 +818,11 @@ fn test_statement_pretty_printing() {
 
     expect_pretty_print("fn f(x) = 2 x", "fn f<A: Dim>(x: A) -> A = 2 x");
 
+    expect_pretty_print(
+        "fn f(x, y) = x * y",
+        "fn f<A: Dim, B: Dim>(x: A, y: B) -> A × B = x × y",
+    );
+
     // Partially annotated functions
     expect_pretty_print(
         "fn f() -> Length * Frequency = c",
@@ -832,6 +837,12 @@ fn test_statement_pretty_printing() {
     expect_pretty_print("fn f(x) -> Length = 2 x", "fn f(x: Length) -> Length = 2 x");
 
     expect_pretty_print("fn f<Z>(z: Z) = z", "fn f<Z>(z: Z) -> Z = z");
+
+    // Functions with local variables
+    expect_pretty_print(
+        "fn f(x) = y where y = x",
+        "fn f<A>(x: A) -> A = y\n  where y: A = x",
+    );
 }
 #[cfg(test)]
 mod tests {