Browse Source

assert_eq for all types, add more list functions

David Peter 1 year ago
parent
commit
26d3282b1c

+ 28 - 13
examples/list_tests.nbt

@@ -1,25 +1,40 @@
 let xs = [1, 2, 3]
 
-assert(len([]) == 0)
-assert(len(xs) == 3)
+assert_eq(len([]), 0)
+assert_eq(len(xs), 3)
 
-assert(head(xs) == 1)
-assert(tail(xs) == [2, 3])
+assert_eq(head(xs), 1)
+assert_eq(head(tail(xs)), 2)
 
-assert(sequence(0) == [])
-assert(sequence(5) == [0, 1, 2, 3, 4])
+assert_eq(tail(xs), [2, 3])
 
-fn const_5(x) = 5
-assert(generate(3, const_5) == [5, 5, 5])
+assert_eq(cons(0, xs), [0, 1, 2, 3])
 
+assert(is_empty([]))
+assert(!is_empty(xs))
+
+assert_eq(concat([], []), [])
+assert_eq(concat([], xs), xs)
+assert_eq(concat(xs, []), xs)
+assert_eq(concat(xs, xs), [1, 2, 3, 1, 2, 3])
+
+assert_eq(range(0, 0), [0])
+assert_eq(range(0, 5), [0, 1, 2, 3, 4, 5])
+
+assert_eq(reverse([]), [])
+assert_eq(reverse(xs), [3, 2, 1])
 
 fn inc(x) = x + 1
-assert(map(inc, xs) == [2, 3, 4])
+assert_eq(map(inc, xs), [2, 3, 4])
+# fn gen_range(n) = range(1, n)
+# assert_eq(map(gen_range, xs), [[1], [1, 2], [1, 2, 3]])
+# fn to_string(x) = "{x}"
+# assert_eq(map(to_string, xs), ["1", "2", "3"])
 
-assert(reverse([]) == [])
-assert(reverse(xs) == [3, 2, 1])
+fn is_even(x) = mod(x, 2) == 0
+assert_eq(filter(is_even, range(1, 10)), [2, 4, 6, 8, 10])
 
-assert(sum([1, 2, 3, 4, 5]) == 15)
+assert_eq(sum([1, 2, 3, 4, 5]), 15)
 
 fn mul(x, y) = x * y
-assert(foldl(mul, 1, [1, 2, 3, 4, 5]) == 120)
+assert_eq(foldl(mul, 1, [1, 2, 3, 4, 5]), 120)

+ 36 - 12
numbat/modules/core/lists.nbt

@@ -1,41 +1,65 @@
 use core::scalar
 
+@description("Get the length of a list")
 fn len<A>(xs: List<A>) -> Scalar
+
+@description("Get the first element of a list. Yields a runtime error if the list is empty.")
 fn head<A>(xs: List<A>) -> A
+
+@description("Get everything but the first element of a list. Yields a runtime error if the list is empty.")
 fn tail<A>(xs: List<A>) -> List<A>
+
+@description("Prepend an element to a list")
 fn cons<A>(x: A, xs: List<A>) -> List<A>
 
+@description("Check if a list is empty")
 fn is_empty<A>(xs: List<A>) -> Bool = xs == []
 
-fn generate<A>(n: Scalar, f: Fn[() -> A]) -> List<A> =
-  if n == 0
-    then []
-    else cons(f(), generate(n - 1, f))
-
-fn map<A, B>(f: Fn[(A) -> B], xs: List<A>) -> List<B> =
+@description("Concatenate two lists")
+fn concat<A>(xs: List<A>, ys: List<A>) -> List<A> =
   if is_empty(xs)
+    then ys
+    else cons(head(xs), concat(tail(xs), ys))
+
+@description("Generate a range of integer numbers from `start` to `end` (inclusive)")
+fn range(start: Scalar, end: Scalar) -> List<Scalar> =
+  if start > end
     then []
-    else cons(f(head(xs)), map(f, tail(xs)))
+    else cons(start, range(start + 1, end))
 
+@description("Append an element to the end of a list")
 fn cons_end<A>(xs: List<A>, x: A) -> List<A> =
   if is_empty(xs)
     then [x]
     else cons(head(xs), cons_end(tail(xs), x))
 
+@description("Reverse the order of a list")
 fn reverse<A>(xs: List<A>) -> List<A> =
   if is_empty(xs)
     then []
     else cons_end(reverse(tail(xs)), head(xs))
 
-fn sequence(n: Scalar) -> List<Scalar> =
-  if n == 0
+@description("Generate a new list by applying a function to each element of the input list")
+fn map<A, B>(f: Fn[(A) -> B], xs: List<A>) -> List<B> =
+  if is_empty(xs)
     then []
-    else cons_end(sequence(n - 1), n - 1)
+    else cons(f(head(xs)), map(f, tail(xs)))
 
+@description("Filter a list by a predicate")
+fn filter<A>(p: Fn[(A) -> Bool], xs: List<A>) -> List<A> =
+  if is_empty(xs)
+    then []
+    else if p(head(xs))
+      then cons(head(xs), filter(p, tail(xs)))
+      else filter(p, tail(xs))
+
+@description("Fold a function over a list")
 fn foldl<A, B>(f: Fn[(A, B) -> A], acc: A, xs: List<B>) -> A =
   if is_empty(xs)
     then acc
     else foldl(f, f(acc, head(xs)), tail(xs))
 
-fn add(x, y) = x + y # TODO
-fn sum<A>(xs: List<A>) -> A = foldl(add, 0, xs)
+fn _add(x, y) = x + y # TODO: replace this with a local function once we support them
+
+@description("Sum all elements of a list")
+fn sum<A>(xs: List<A>) -> A = foldl(_add, 0, xs)

+ 21 - 7
numbat/src/ffi.rs

@@ -511,21 +511,35 @@ fn assert(_: &mut ExecutionContext, args: &[Value]) -> ControlFlow {
 fn assert_eq(_: &mut ExecutionContext, args: &[Value]) -> ControlFlow {
     assert!(args.len() == 2 || args.len() == 3);
 
-    let lhs = args[0].unsafe_as_quantity();
-    let rhs = args[1].unsafe_as_quantity();
-
     if args.len() == 2 {
+        let lhs = &args[0];
+        let rhs = &args[1];
+
         let error = ControlFlow::Break(RuntimeError::AssertEq2Failed(lhs.clone(), rhs.clone()));
-        if let Ok(args1_converted) = rhs.convert_to(lhs.unit()) {
-            if lhs == &args1_converted {
-                ControlFlow::Continue(())
+
+        if lhs.is_quantity() {
+            let lhs = lhs.unsafe_as_quantity();
+            let rhs = rhs.unsafe_as_quantity();
+
+            if let Ok(args1_converted) = rhs.convert_to(lhs.unit()) {
+                if lhs == &args1_converted {
+                    ControlFlow::Continue(())
+                } else {
+                    error
+                }
             } else {
                 error
             }
         } else {
-            error
+            if lhs == rhs {
+                ControlFlow::Continue(())
+            } else {
+                error
+            }
         }
     } else {
+        let lhs = args[0].unsafe_as_quantity();
+        let rhs = args[1].unsafe_as_quantity();
         let result = lhs - rhs;
         let eps = args[2].unsafe_as_quantity();
 

+ 2 - 4
numbat/src/interpreter.rs

@@ -27,10 +27,8 @@ pub enum RuntimeError {
     QuantityError(QuantityError),
     #[error("Assertion failed")]
     AssertFailed,
-    #[error(
-        "Assertion failed because the following two quantities are not the same:\n  {0}\n  {1}"
-    )]
-    AssertEq2Failed(Quantity, Quantity),
+    #[error("Assertion failed because the following two values are not the same:\n  {0}\n  {1}")]
+    AssertEq2Failed(Value, Value),
     #[error("Assertion failed because the following two quantities differ by more than {2}:\n  {0}\n  {1}")]
     AssertEq3Failed(Quantity, Quantity, Quantity),
     #[error("Could not load exchange rates from European Central Bank.")]

+ 3 - 1
numbat/src/typechecker/constraints.rs

@@ -226,7 +226,9 @@ impl Constraint {
                 Type::Dimension(_) => TrivialResultion::Satisfied,
                 _ => TrivialResultion::Violated,
             },
-            Constraint::IsDType(_) => TrivialResultion::Unknown,
+            Constraint::IsDType(Type::Dimension(_)) => TrivialResultion::Satisfied,
+            Constraint::IsDType(Type::TVar(_)) => TrivialResultion::Unknown,
+            Constraint::IsDType(_) => TrivialResultion::Violated,
             Constraint::EqualScalar(d) if d.is_scalar() => TrivialResultion::Satisfied,
             Constraint::EqualScalar(d) if d.type_variables().is_empty() => {
                 TrivialResultion::Violated

+ 9 - 2
numbat/src/typechecker/mod.rs

@@ -1558,12 +1558,19 @@ impl TypeChecker {
                         }
                     }
                     ProcedureKind::AssertEq => {
+                        // The three-argument version of assert_eq requires dtypes as inputs:
+                        let needs_dtypes = checked_args.len() == 3;
+
                         let type_first = &checked_args[0].get_type();
-                        self.enforce_dtype(type_first, checked_args[0].full_span())?;
+                        if needs_dtypes {
+                            self.enforce_dtype(type_first, checked_args[0].full_span())?;
+                        }
 
                         for arg in &checked_args[1..] {
                             let type_arg = arg.get_type();
-                            self.enforce_dtype(&type_arg, arg.full_span())?;
+                            if needs_dtypes {
+                                self.enforce_dtype(&type_arg, arg.full_span())?;
+                            }
 
                             if self
                                 .add_equal_constraint(type_first, &type_arg)

+ 4 - 0
numbat/src/value.rs

@@ -100,6 +100,10 @@ impl Value {
             panic!("Expected value to be a list");
         }
     }
+
+    pub(crate) fn is_quantity(&self) -> bool {
+        matches!(self, Value::Quantity(_))
+    }
 }
 
 impl std::fmt::Display for Value {