瀏覽代碼

Use 'where' clause in existing Numbat code

David Peter 1 年之前
父節點
當前提交
0e30e77a6f

+ 4 - 2
numbat/modules/core/lists.nbt

@@ -138,5 +138,7 @@ fn split(input: String, separator: String) -> List<String> =
     then []
     else if !str_contains(input, separator)
       then [input]
-      else cons(str_slice(input, 0, str_find(input, separator)),  # TODO: avoid duplication
-                split(str_slice(input, str_find(input, separator) + str_length(separator), str_length(input)), separator))
+      else cons(str_slice(input, 0, idx_separator),
+                split(str_slice(input, idx_separator + str_length(separator), str_length(input)), separator))
+  where
+    idx_separator = str_find(input, separator)

+ 12 - 5
numbat/modules/core/strings.nbt

@@ -25,13 +25,16 @@ fn str_append(a: String, b: String) -> String = "{a}{b}"
 
 @description("Find the first occurrence of a substring in a string")
 fn str_find(haystack: String, needle: String) -> Scalar =
-  if str_length(haystack) == 0
+  if len_haystack == 0
     then -1
     else if str_slice(haystack, 0, str_length(needle)) == needle
       then 0
-      else if str_find(str_slice(haystack, 1, str_length(haystack)), needle) == -1  # TODO: we need local variables!
+      else if str_find(tail_haystack, needle) == -1
         then -1
-        else 1 + str_find(str_slice(haystack, 1, str_length(haystack)), needle)
+        else 1 + str_find(tail_haystack, needle)
+  where
+    len_haystack = str_length(haystack)
+    tail_haystack = str_slice(haystack, 1, len_haystack)
 
 @description("Check if a string contains a substring")
 fn str_contains(haystack: String, needle: String) -> Bool =
@@ -60,12 +63,16 @@ fn _oct_digit(x: Scalar) -> String =
   chr(48 + mod(x, 8))
 
 fn _hex_digit(x: Scalar) -> String =
-    if mod(x, 16) < 10 then chr(48 + mod(x, 16)) else chr(97 + mod(x, 16) - 10)
+  if x_16 < 10 then chr(48 + x_16) else chr(97 + x_16 - 10)
+  where
+    x_16 = mod(x, 16)
 
 fn _digit_in_base(x: Scalar, base: Scalar) -> String =
   if base < 2 || base > 16
     then error("base must be between 2 and 16")
-    else if mod(x, 16) < 10 then chr(48 + mod(x, 16)) else chr(97 + mod(x, 16) - 10)
+    else if x_16 < 10 then chr(48 + x_16) else chr(97 + x_16 - 10)
+  where
+    x_16 = mod(x, 16)
 
 fn _number_in_base(x: Scalar, b: Scalar) -> String =
   if x < 0

+ 11 - 8
numbat/modules/datetime/functions.nbt

@@ -51,15 +51,18 @@ fn _add_years(dt: DateTime, n_years: Scalar) -> DateTime
 
 @description("Adds the given time span to a `DateTime`. This uses leap-year and DST-aware calendar arithmetic with variable-length days, months, and years.")
 fn calendar_add(dt: DateTime, span: Time) -> DateTime =
-   if unit_of(span) == days
+   if span_unit == days
      then _add_days(dt, span / days)
-     else if unit_of(span) == months
-       then _add_months(dt, span / months)
-       else if unit_of(span) == years
-         then _add_years(dt, span / years)
-         else if unit_of(span) == seconds || unit_of(span) == minutes || unit_of(span) == hours
-           then dt + span
-           else error("calendar_add: Unsupported unit: {unit_of(span)}")
+   else if span_unit == months
+     then _add_months(dt, span / months)
+   else if span_unit == years
+     then _add_years(dt, span / years)
+   else if span_unit == seconds || span_unit == minutes || span_unit == hours
+     then dt + span
+   else
+     error("calendar_add: Unsupported unit: {span_unit}")
+  where
+    span_unit = unit_of(span)
 
 @description("Subtract the given time span from a `DateTime`. This uses leap-year and DST-aware calendar arithmetic with variable-length days, months, and years.")
 fn calendar_sub(dt: DateTime, span: Time) -> DateTime =

+ 5 - 3
numbat/modules/math/statistics.nbt

@@ -38,6 +38,8 @@ fn stdev<D: Dim>(xs: List<D>) -> D = sqrt(variance(xs))
 @url("https://en.wikipedia.org/wiki/Median")
 @description("Calculate the median of a list of quantities")
 fn median<D: Dim>(xs: List<D>) -> D =  # TODO: this is extremely inefficient
-  if mod(len(xs), 2) == 1
-    then element_at((len(xs) - 1) / 2, sort(xs))
-    else mean([element_at(len(xs) / 2 - 1, sort(xs)), element_at(len(xs) / 2, sort(xs))])
+  if mod(n, 2) == 1
+    then element_at((n - 1) / 2, sort(xs))
+    else mean([element_at(n / 2 - 1, sort(xs)), element_at(n / 2, sort(xs))])
+  where
+    n = len(xs)

+ 3 - 4
numbat/modules/numerics/diff.nbt

@@ -1,10 +1,9 @@
 use core::quantities
 
-# TODO: Move this to a local definition inside `diff` once we support that
-fn _delta<X: Dim>(x: X) -> X = 1e-10 × unit_of(x)
-
 @name("Numerical differentiation")
 @url("https://en.wikipedia.org/wiki/Numerical_differentiation")
 @description("Compute the numerical derivative of the function $f$ at point $x$ using the central difference method.")
 fn diff<X: Dim, Y: Dim>(f: Fn[(X) -> Y], x: X) -> Y / X =
-  (f(x + _delta(x)) - f(x - _delta(x))) / (2 _delta(x))
+  (f(x + Δx) - f(x - Δx)) / (2 Δx)
+  where
+    Δx = 1e-10 × unit_of(x)

+ 13 - 9
numbat/modules/numerics/solve.nbt

@@ -6,20 +6,24 @@ use core::error
 @description("Find the root of the function $f$ in the interval $[x_1, x_2]$ using the bisection method. The function $f$ must be continuous and $f(x_1) \cdot f(x_2) < 0$.")
 fn root_bisect<X: Dim, Y: Dim>(f: Fn[(X) -> Y], x1: X, x2: X, x_tol: X, y_tol: Y) -> X =
   if abs(x1 - x2) < x_tol
-    then (x1 + x2) / 2
-    else if abs(f((x1 + x2) / 2)) < y_tol
-      then (x1 + x2) / 2
-      else if f((x1 + x2) / 2) * f(x1) < 0
-        then root_bisect(f, x1, (x1 + x2) / 2, x_tol, y_tol)
-        else root_bisect(f, (x1 + x2) / 2, x2, x_tol, y_tol)
-  # TODO: move (x1 + x2) / 2 to a local variable once we support them
+    then x_mean
+    else if abs(f_x_mean) < y_tol
+      then x_mean
+      else if f_x_mean × f(x1) < 0
+        then root_bisect(f, x1, x_mean, x_tol, y_tol)
+        else root_bisect(f, x_mean, x2, x_tol, y_tol)
+  where
+    x_mean = (x1 + x2) / 2
+    f_x_mean = f(x_mean)
 
 fn _root_newton_helper<X: Dim, Y: Dim>(f: Fn[(X) -> Y], f_prime: Fn[(X) -> Y / X], x0: X, y_tol: Y, max_iterations: Scalar) -> X =
   if max_iterations <= 0
     then error("root_newton: Maximum number of iterations reached. Try another initial guess?")
-    else if abs(f(x0)) < y_tol
+    else if abs(f_x0) < y_tol
       then x0
-      else _root_newton_helper(f, f_prime, x0 - f(x0) / f_prime(x0), y_tol, max_iterations - 1)
+      else _root_newton_helper(f, f_prime, x0 - f_x0 / f_prime(x0), y_tol, max_iterations - 1)
+  where
+    f_x0 = f(x0)
 
 @name("Newton's method")
 @url("https://en.wikipedia.org/wiki/Newton%27s_method")