Procházet zdrojové kódy

Fix non-dtype lists

David Peter před 1 rokem
rodič
revize
0ae495b3ec

+ 13 - 0
examples/list_tests.nbt

@@ -43,3 +43,16 @@ assert_eq(intersperse(0, [1, 2, 3]), [1, 0, 2, 0, 3])
 
 assert_eq(sum([1, 2, 3, 4, 5]), 15)
 assert_eq(sum([1 m, 200 cm, 3 m]), 6 m)
+
+# Non-dtype lists
+let words = ["hello", "world"]
+assert_eq(head(words), "hello")
+
+fn join(xs: List<String>, sep: String) =
+  if is_empty(xs)
+    then ""
+    else if len(xs) == 1
+      then head(xs)
+      else "{head(xs)}{sep}{join(tail(xs), sep)}"
+
+assert_eq(join(words, " "), "hello world")

+ 14 - 0
numbat/src/typechecker/constraints.rs

@@ -250,6 +250,20 @@ impl Constraint {
                     t.clone(),
                 )))
             }
+            Constraint::Equal(Type::Dimension(dtype_x), t)
+                if dtype_x.deconstruct_as_single_type_variable().is_some() =>
+            {
+                let x = dtype_x.deconstruct_as_single_type_variable().unwrap();
+                debug!(
+                    "  (3) SOLVING: {x} ~ {t} with substitution {x} := {t}",
+                    x = x.unsafe_name(),
+                    t = t
+                );
+                Some(Satisfied::with_substitution(Substitution::single(
+                    x.clone(),
+                    t.clone(),
+                )))
+            }
             Constraint::Equal(t @ Type::Fn(params1, return1), s @ Type::Fn(params2, return2)) => {
                 debug!(
                     "  (4) SOLVING: {t} ~ {s} with new constraints for all parameters and return types",

+ 8 - 1
numbat/src/typechecker/substitutions.rs

@@ -38,7 +38,7 @@ impl Substitution {
 
 #[derive(Debug, Clone, Error, PartialEq, Eq)]
 pub enum SubstitutionError {
-    #[error("Used non-dimension type in a dimension expression: {0}")]
+    #[error("Used non-dimension type '{0}' in a dimension expression")]
     SubstitutedNonDTypeWithinDType(Type),
 }
 
@@ -61,6 +61,13 @@ impl ApplySubstitution for Type {
                 }
                 Ok(())
             }
+            Type::Dimension(dtype) if dtype.deconstruct_as_single_type_variable().is_some() => {
+                let v = dtype.deconstruct_as_single_type_variable().unwrap();
+                if let Some(type_) = s.lookup(&v) {
+                    *self = type_.clone();
+                }
+                Ok(())
+            }
             Type::Dimension(dtype) => dtype.apply(s),
             Type::Boolean => Ok(()),
             Type::String => Ok(()),

+ 3 - 0
numbat/src/typechecker/tests/mod.rs

@@ -31,6 +31,9 @@ const TEST_PRELUDE: &str = "
     fn atan2<T>(x: T, y: T) -> Scalar
 
     fn len<T>(x: List<T>) -> Scalar
+    fn head<T>(x: List<T>) -> T
+
+    fn id<T>(x: T) -> T = x
     ";
 
 fn type_a() -> DType {

+ 3 - 0
numbat/src/typechecker/tests/type_checking.rs

@@ -638,6 +638,9 @@ fn lists() {
 
     assert_successful_typecheck("[[1 a, 2 a], [3 a]]");
 
+    assert_successful_typecheck("[true]");
+    assert_successful_typecheck("head([true, false])");
+
     assert!(matches!(
         get_typecheck_error("[1, a]"),
         TypeCheckError::IncompatibleTypesInList(..)

+ 0 - 1
numbat/src/typechecker/tests/type_inference.rs

@@ -1,4 +1,3 @@
-use ast::scalar;
 use qualified_type::Bounds;
 use qualified_type::QualifiedType;
 use tests::get_typecheck_error;

+ 10 - 0
numbat/src/typed_ast.rs

@@ -94,6 +94,16 @@ impl DType {
         DType::from_factors(&[(DTypeFactor::TPar(name), Exponent::from_integer(1))])
     }
 
+    pub fn deconstruct_as_single_type_variable(&self) -> Option<TypeVariable> {
+        match &self.factors[..] {
+            [(factor, exponent)] if exponent == &Exponent::from_integer(1) => match factor {
+                DTypeFactor::TVar(v) => Some(v.clone()),
+                _ => None,
+            },
+            _ => None,
+        }
+    }
+
     pub fn from_tgen(i: usize) -> DType {
         DType::from_factors(&[(
             DTypeFactor::TVar(TypeVariable::Quantified(i)),