浏览代码

Fix `unit_of`-related problems

Introduces some alternative helpers to `unit_of`
David Peter 9 月之前
父节点
当前提交
e8bac1feeb

+ 5 - 5
book/src/list-functions-math.md

@@ -886,24 +886,24 @@ Compute the numerical derivative of the function \\( f \\) at point \\( x \\) us
 More information [here](https://en.wikipedia.org/wiki/Numerical_differentiation).
 
 ```nbt
-fn diff<X: Dim, Y: Dim>(f: Fn[(X) -> Y], x: X) -> Y / X
+fn diff<X: Dim, Y: Dim>(f: Fn[(X) -> Y], x: X, Δx: X) -> Y / X
 ```
 
 <details>
 <summary>Examples</summary>
 
 Compute the derivative of \\( f(x) = x² -x -1 \\) at \\( x=1 \\).
-<pre><div class="buttons"><button class="fa fa-play play-button" title="Run this code" aria-label="Run this code"  onclick=" window.open('https://numbat.dev/?q=use%20numerics%3A%3Adiff%0Afn%20polynomial%28x%29%20%3D%20x%C2%B2%20%2D%20x%20%2D%201%0Adiff%28polynomial%2C%201%29')""></button></div><code class="language-nbt hljs numbat">use numerics::diff
+<pre><div class="buttons"><button class="fa fa-play play-button" title="Run this code" aria-label="Run this code"  onclick=" window.open('https://numbat.dev/?q=use%20numerics%3A%3Adiff%0Afn%20polynomial%28x%29%20%3D%20x%C2%B2%20%2D%20x%20%2D%201%0Adiff%28polynomial%2C%201%2C%201e%2D10%29')""></button></div><code class="language-nbt hljs numbat">use numerics::diff
 fn polynomial(x) = x² - x - 1
-diff(polynomial, 1)
+diff(polynomial, 1, 1e-10)
 
     = 1.0
 </code></pre>
 
 Compute the free fall velocity after \\( t=2 s \\).
-<pre><div class="buttons"><button class="fa fa-play play-button" title="Run this code" aria-label="Run this code"  onclick=" window.open('https://numbat.dev/?q=use%20numerics%3A%3Adiff%0Afn%20distance%28t%29%20%3D%200%2E5%20g0%20t%C2%B2%0Afn%20velocity%28t%29%20%3D%20diff%28distance%2C%20t%29%0Avelocity%282%20s%29')""></button></div><code class="language-nbt hljs numbat">use numerics::diff
+<pre><div class="buttons"><button class="fa fa-play play-button" title="Run this code" aria-label="Run this code"  onclick=" window.open('https://numbat.dev/?q=use%20numerics%3A%3Adiff%0Afn%20distance%28t%29%20%3D%200%2E5%20g0%20t%C2%B2%0Afn%20velocity%28t%29%20%3D%20diff%28distance%2C%20t%2C%201e%2D10%20s%29%0Avelocity%282%20s%29')""></button></div><code class="language-nbt hljs numbat">use numerics::diff
 fn distance(t) = 0.5 g0 t²
-fn velocity(t) = diff(distance, t)
+fn velocity(t) = diff(distance, t, 1e-10 s)
 velocity(2 s)
 
     = 19.6133 m/s    [Velocity]

+ 71 - 0
book/src/list-functions-other.md

@@ -189,6 +189,77 @@ fn unit_of<T: Dim>(x: T) -> T
 
 </details>
 
+### `has_unit`
+Returns true if `quantity` has the same unit as `unit_query`, or if `quantity` evaluates to zero.
+
+```nbt
+fn has_unit<T: Dim>(quantity: T, unit_query: T) -> Bool
+```
+
+<details>
+<summary>Examples</summary>
+
+<pre><div class="buttons"><button class="fa fa-play play-button" title="Run this code" aria-label="Run this code"  onclick=" window.open('https://numbat.dev/?q=has%5Funit%2820%20km%2Fh%2C%20km%2Fh%29')""></button></div><code class="language-nbt hljs numbat">has_unit(20 km/h, km/h)
+
+    = true    [Bool]
+</code></pre>
+
+<pre><div class="buttons"><button class="fa fa-play play-button" title="Run this code" aria-label="Run this code"  onclick=" window.open('https://numbat.dev/?q=has%5Funit%2820%20km%2Fh%2C%20m%2Fs%29')""></button></div><code class="language-nbt hljs numbat">has_unit(20 km/h, m/s)
+
+    = false    [Bool]
+</code></pre>
+
+</details>
+
+### `is_dimensionless`
+Returns true if `quantity` is dimensionless, or if `quantity` is zero.
+
+```nbt
+fn is_dimensionless<T: Dim>(quantity: T) -> Bool
+```
+
+<details>
+<summary>Examples</summary>
+
+<pre><div class="buttons"><button class="fa fa-play play-button" title="Run this code" aria-label="Run this code"  onclick=" window.open('https://numbat.dev/?q=is%5Fdimensionless%2810%29')""></button></div><code class="language-nbt hljs numbat">is_dimensionless(10)
+
+    = true    [Bool]
+</code></pre>
+
+<pre><div class="buttons"><button class="fa fa-play play-button" title="Run this code" aria-label="Run this code"  onclick=" window.open('https://numbat.dev/?q=is%5Fdimensionless%2810%20km%2Fh%29')""></button></div><code class="language-nbt hljs numbat">is_dimensionless(10 km/h)
+
+    = false    [Bool]
+</code></pre>
+
+</details>
+
+### `unit_name`
+Returns a string representation of the unit of `quantity`. Returns an empty string if `quantity` is dimensionless.
+
+```nbt
+fn unit_name<T: Dim>(quantity: T) -> String
+```
+
+<details>
+<summary>Examples</summary>
+
+<pre><div class="buttons"><button class="fa fa-play play-button" title="Run this code" aria-label="Run this code"  onclick=" window.open('https://numbat.dev/?q=unit%5Fname%2820%29')""></button></div><code class="language-nbt hljs numbat">unit_name(20)
+
+    = ""    [String]
+</code></pre>
+
+<pre><div class="buttons"><button class="fa fa-play play-button" title="Run this code" aria-label="Run this code"  onclick=" window.open('https://numbat.dev/?q=unit%5Fname%2820%20m%5E2%29')""></button></div><code class="language-nbt hljs numbat">unit_name(20 m^2)
+
+    = "m²"    [String]
+</code></pre>
+
+<pre><div class="buttons"><button class="fa fa-play play-button" title="Run this code" aria-label="Run this code"  onclick=" window.open('https://numbat.dev/?q=unit%5Fname%2820%20km%2Fh%29')""></button></div><code class="language-nbt hljs numbat">unit_name(20 km/h)
+
+    = "km/h"    [String]
+</code></pre>
+
+</details>
+
 ## Chemical elements
 
 Defined in: `chemistry::elements`

+ 35 - 8
examples/tests/core.nbt

@@ -1,3 +1,16 @@
+# value_of
+
+assert_eq(value_of(0),            0)
+
+assert_eq(value_of(1),            1)
+assert_eq(value_of(1.2345),       1.2345)
+
+assert_eq(value_of(1 m),          1)
+assert_eq(value_of(1.2345 m),     1.2345)
+
+assert_eq(value_of(1 m^2/s),      1)
+assert_eq(value_of(1.2345 m^2/s), 1.2345)
+
 # unit_of
 
 assert_eq(unit_of(1),            1)
@@ -9,18 +22,32 @@ assert_eq(unit_of(1.2345 m),     m)
 assert_eq(unit_of(1 m^2/s),      m^2/s)
 assert_eq(unit_of(1.2345 m^2/s), m^2/s)
 
-# value_of
+# has_unit
 
-assert_eq(value_of(0),            0)
+assert(has_unit(1 m, m))
+assert(has_unit(2 m, m))
 
-assert_eq(value_of(1),            1)
-assert_eq(value_of(1.2345),       1.2345)
+assert(!has_unit(1 m, cm))
+assert(!has_unit(1 m, km))
+assert(!has_unit(1 m, ft))
 
-assert_eq(value_of(1 m),          1)
-assert_eq(value_of(1.2345 m),     1.2345)
+assert(has_unit(0, m))
+assert(has_unit(0, cm))
+assert(has_unit(0, s))
 
-assert_eq(value_of(1 m^2/s),      1)
-assert_eq(value_of(1.2345 m^2/s), 1.2345)
+# is_dimensionless
+
+assert(is_dimensionless(0))
+assert(is_dimensionless(1))
+assert(!is_dimensionless(1 m))
+assert(!is_dimensionless(1 m/s))
+
+# unit_name
+
+assert_eq(unit_name(0), "")
+assert_eq(unit_name(1), "")
+assert_eq(unit_name(1 m), "m")
+assert_eq(unit_name(1 m^2/s), "m²/s")
 
 # round, round_in
 

+ 5 - 6
examples/tests/numerics.nbt

@@ -19,18 +19,17 @@ assert_eq(fixed_point(f_sqrt3, 1, 1e-10), sqrt(3), 1e-10)
 
 # Differentiation
 
-assert_eq(diff(log, 2.0), 0.5, 1e-5)
+assert_eq(diff(log, 2.0, 1e-10), 0.5, 1e-5)
 
-# TODO: Sadly, the following is not possible at the moment. See https://github.com/sharkdp/numbat/issues/521 for details
-# assert_eq(diff(sin, 0.0), 1.0, 1e-5)
+assert_eq(diff(sin, 0.0, 1e-10), 1.0, 1e-5)
 
-assert_eq(diff(sqrt, 1.0), 0.5, 1e-5)
+assert_eq(diff(sqrt, 1.0, 1e-10), 0.5, 1e-5)
 
 fn f2(x: Scalar) -> Scalar = x² + 4 x + 1
 
-assert_eq(diff(f2, 2.0), 8.0, 1e-5)
+assert_eq(diff(f2, 2.0, 1e-10), 8.0, 1e-5)
 
 fn dist(t: Time) -> Length = 0.5 g0 t^2
-fn velocity(t: Time) -> Velocity = diff(dist, t)
+fn velocity(t: Time) -> Velocity = diff(dist, t, 1e-10 s)
 
 assert_eq(velocity(2.0 s), 2.0 s × g0, 1e-3 m/s)

+ 16 - 0
numbat/modules/core/quantities.nbt

@@ -9,3 +9,19 @@ fn value_of<T: Dim>(x: T) -> Scalar
 @example("unit_of(20 km/h)")
 fn unit_of<T: Dim>(x: T) -> T = if x_value == 0 then error("Invalid argument: cannot call `unit_of` on a value that evaluates to 0") else x / value_of(x)
     where x_value = value_of(x)
+
+@description("Returns true if `quantity` has the same unit as `unit_query`, or if `quantity` evaluates to zero.")
+@example("has_unit(20 km/h, km/h)")
+@example("has_unit(20 km/h, m/s)")
+fn has_unit<T: Dim>(quantity: T, unit_query: T) -> Bool
+
+@description("Returns true if `quantity` is dimensionless, or if `quantity` is zero.")
+@example("is_dimensionless(10)")
+@example("is_dimensionless(10 km/h)")
+fn is_dimensionless<T: Dim>(quantity: T) -> Bool
+
+@description("Returns a string representation of the unit of `quantity`. Returns an empty string if `quantity` is dimensionless.")
+@example("unit_name(20)")
+@example("unit_name(20 m^2)")
+@example("unit_name(20 km/h)")
+fn unit_name<T: Dim>(quantity: T) -> String

+ 7 - 7
numbat/modules/datetime/functions.nbt

@@ -62,18 +62,18 @@ fn _add_years(dt: DateTime, n_years: Scalar) -> DateTime
 @description("Adds the given time span to a `DateTime`. This uses leap-year and DST-aware calendar arithmetic with variable-length days, months, and years.")
 @example("calendar_add(datetime(\"2022-07-20 21:52 +0200\"), 2 years)")
 fn calendar_add(dt: DateTime, span: Time) -> DateTime =
-   if span_unit == days
+   if span == 0
+     then dt
+   else if has_unit(span, days)
      then _add_days(dt, span / days)
-   else if span_unit == months
+   else if has_unit(span, months)
      then _add_months(dt, span / months)
-   else if span_unit == years
+   else if has_unit(span, years)
      then _add_years(dt, span / years)
-   else if span_unit == seconds || span_unit == minutes || span_unit == hours
+   else if has_unit(span, seconds) || has_unit(span, minutes) || has_unit(span, hours)
      then dt + span
    else
-     error("calendar_add: Unsupported unit: {span_unit}")
-  where
-    span_unit = unit_of(span)
+     error("calendar_add: Unsupported unit for `span`")
 
 @description("Subtract the given time span from a `DateTime`. This uses leap-year and DST-aware calendar arithmetic with variable-length days, months, and years.")
 @example("calendar_sub(datetime(\"2022-07-20 21:52 +0200\"), 3 years)")

+ 3 - 5
numbat/modules/numerics/diff.nbt

@@ -3,9 +3,7 @@ use core::quantities
 @name("Numerical differentiation")
 @url("https://en.wikipedia.org/wiki/Numerical_differentiation")
 @description("Compute the numerical derivative of the function $f$ at point $x$ using the central difference method.")
-@example("fn polynomial(x) = x² - x - 1\ndiff(polynomial, 1)", "Compute the derivative of $f(x) = x² -x -1$ at $x=1$.")
-@example("fn distance(t) = 0.5 g0 t²\nfn velocity(t) = diff(distance, t)\nvelocity(2 s)", "Compute the free fall velocity after $t=2 s$.")
-fn diff<X: Dim, Y: Dim>(f: Fn[(X) -> Y], x: X) -> Y / X =
+@example("fn polynomial(x) = x² - x - 1\ndiff(polynomial, 1, 1e-10)", "Compute the derivative of $f(x) = x² -x -1$ at $x=1$.")
+@example("fn distance(t) = 0.5 g0 t²\nfn velocity(t) = diff(distance, t, 1e-10 s)\nvelocity(2 s)", "Compute the free fall velocity after $t=2 s$.")
+fn diff<X: Dim, Y: Dim>(f: Fn[(X) -> Y], x: X, Δx: X) -> Y / X =
   (f(x + Δx) - f(x - Δx)) / 2 Δx
-  where
-    Δx = 1e-10 × unit_of(x)

+ 1 - 1
numbat/modules/plot/bar_chart.nbt

@@ -15,7 +15,7 @@ fn _default_label(n: Scalar) -> String = "{n}"
 fn bar_chart<A: Dim>(values: List<A>) -> BarChart =
   BarChart {
     value_label: "",
-    value_unit: if _is_scalar(head(values)) then "" else _unit_name(head(values)),
+    value_unit: unit_name(head(values)),
     values: map(value_of, values),
     x_labels: map(_default_label, range(1, len(values))),
   }

+ 0 - 4
numbat/modules/plot/common.nbt

@@ -1,10 +1,6 @@
 use core::quantities
 use core::strings
 
-fn _unit_name<A: Dim>(x: A) -> String = str_replace("1 ", "", "{unit_of(x)}")
-
-fn _is_scalar<A: Dim>(x: A) -> Bool = "{unit_of(x)}" == "1"
-
 # TODO: this function is overly generic, but we don't have bounded
 # polymorphism yet.
 fn show<Plot>(plot: Plot) -> String

+ 3 - 3
numbat/modules/plot/line_plot.nbt

@@ -15,10 +15,10 @@ let _num_points_for_line_plot = 2000
 fn line_plot<A: Dim, B: Dim>(f: Fn[(A) -> B], x_start: A, x_end: A) -> LinePlot =
   LinePlot {
     x_label: "",
-    x_unit: if _is_scalar(x_end) then "" else _unit_name(x_end),
+    x_unit: unit_name(x_end),
     y_label: "",
-    y_unit: if _is_scalar(f(x_end)) then "" else _unit_name(f(x_end)),
-    xs: linspace(x_start / unit_of(x_start), x_end / unit_of(x_end), _num_points_for_line_plot),
+    y_unit: unit_name(f(x_end)),
+    xs: linspace(value_of(x_start), value_of(x_end), _num_points_for_line_plot),
     ys: map(value_of, map(f, linspace(x_start, x_end, _num_points_for_line_plot))),
   }
 

+ 22 - 0
numbat/src/ffi/functions.rs

@@ -39,6 +39,9 @@ pub(crate) fn functions() -> &'static HashMap<String, ForeignFunction> {
         // Core
         insert_function!(error, 1..=1);
         insert_function!(value_of, 1..=1);
+        insert_function!(has_unit, 2..=2);
+        insert_function!(is_dimensionless, 1..=1);
+        insert_function!(unit_name, 1..=1);
 
         // Math
         insert_function!("mod", mod_, 2..=2);
@@ -126,3 +129,22 @@ fn value_of(mut args: Args) -> Result<Value> {
 
     return_scalar!(quantity.unsafe_value().to_f64())
 }
+
+fn has_unit(mut args: Args) -> Result<Value> {
+    let quantity = quantity_arg!(args);
+    let unit_query = quantity_arg!(args);
+
+    return_boolean!(quantity.is_zero() || quantity.unit() == unit_query.unit())
+}
+
+fn is_dimensionless(mut args: Args) -> Result<Value> {
+    let quantity = quantity_arg!(args);
+
+    return_boolean!(quantity.unit().is_scalar())
+}
+
+fn unit_name(mut args: Args) -> Result<Value> {
+    let quantity = quantity_arg!(args);
+
+    return_string!(from = &quantity.unit().to_string())
+}