Browse Source

Re-implement mean,maximum,minimum using lists

David Peter 1 year ago
parent
commit
476ee2ed44
3 changed files with 47 additions and 9 deletions
  1. 21 0
      examples/prelude_tests.nbt
  2. 4 4
      numbat/modules/core/lists.nbt
  3. 22 5
      numbat/modules/math/functions.nbt

+ 21 - 0
examples/prelude_tests.nbt

@@ -96,3 +96,24 @@ assert((0xa -> hex) == "0xa")
 assert((0xf -> hex) == "0xf")
 assert((0xabc1234567890 -> hex) == "0xabc1234567890")
 assert((-0xc0ffee -> hex) == "-0xc0ffee")
+
+# mean
+
+assert_eq(mean([]), 0)
+assert_eq(mean([1]), 1)
+assert_eq(mean([1, 3]), 2)
+assert_eq(mean([1 m, 300 cm]), 2 m)
+
+# maximum
+
+assert_eq(maximum([1]), 1)
+assert_eq(maximum([1, 3]), 3)
+assert_eq(maximum([3, 1]), 3)
+assert_eq(maximum([100 cm, 3 m]), 3 m)
+
+# minimum
+
+assert_eq(minimum([1]), 1)
+assert_eq(minimum([1, 3]), 1)
+assert_eq(minimum([3, 1]), 1)
+assert_eq(minimum([100 cm, 3 m]), 100 cm)

+ 4 - 4
numbat/modules/core/lists.nbt

@@ -16,10 +16,10 @@ fn cons<A>(x: A, xs: List<A>) -> List<A>
 fn is_empty<A>(xs: List<A>) -> Bool = xs == []
 
 @description("Concatenate two lists")
-fn concat<A>(xs: List<A>, ys: List<A>) -> List<A> =
-  if is_empty(xs)
-    then ys
-    else cons(head(xs), concat(tail(xs), ys))
+fn concat<A>(xs1: List<A>, xs2: List<A>) -> List<A> =
+  if is_empty(xs1)
+    then xs2
+    else cons(head(xs1), concat(tail(xs1), xs2))
 
 @description("Generate a range of integer numbers from `start` to `end` (inclusive)")
 fn range(start: Scalar, end: Scalar) -> List<Scalar> =

+ 22 - 5
numbat/modules/math/functions.nbt

@@ -1,4 +1,5 @@
 use core::scalar
+use core::lists
 use math::constants
 
 ## Basics
@@ -96,11 +97,27 @@ fn gamma(x: Scalar) -> Scalar
 
 ### Statistics
 
-#@name("Arithmetic mean")
-#@url("https://en.wikipedia.org/wiki/Arithmetic_mean")
-#fn mean<D>(xs: D…) -> D
-#fn maximum<D>(xs: D…) -> D
-#fn minimum<D>(xs: D…) -> D
+@name("Arithmetic mean")
+@url("https://en.wikipedia.org/wiki/Arithmetic_mean")
+fn mean<D: Dim>(xs: List<D>) -> D = if is_empty(xs) then 0 else sum(xs) / len(xs)
+
+# TODO: remove these helpers once we support local definitions
+fn _max<D: Dim>(x: D, y: D) -> D = if x > y then x else y
+fn _min<D: Dim>(x: D, y: D) -> D = if x < y then x else y
+
+@name("Maxmimum")
+@description("Get the largest element of a list")
+fn maximum<D: Dim>(xs: List<D>) -> D =
+  if len(xs) == 1
+    then head(xs)
+    else _max(head(xs), maximum(tail(xs)))
+
+@name("Minimum")
+@description("Get the smallest element of a list")
+fn minimum<D: Dim>(xs: List<D>) -> D =
+  if len(xs) == 1
+    then head(xs)
+    else _min(head(xs), minimum(tail(xs)))
 
 ### Geometry