Procházet zdrojové kódy

Enable printing of boolean values

David Peter před 2 roky
rodič
revize
8483f2a7aa
4 změnil soubory, kde provedl 189 přidání a 122 odebrání
  1. 180 116
      numbat/src/ffi.rs
  2. 1 1
      numbat/src/interpreter.rs
  3. 2 2
      numbat/src/value.rs
  4. 6 3
      numbat/src/vm.rs

+ 180 - 116
numbat/src/ffi.rs

@@ -4,6 +4,7 @@ use std::sync::OnceLock;
 
 use crate::currency::ExchangeRatesCache;
 use crate::interpreter::RuntimeError;
+use crate::value::Value;
 use crate::vm::ExecutionContext;
 use crate::{ast::ProcedureKind, quantity::Quantity};
 
@@ -11,11 +12,11 @@ type ControlFlow = std::ops::ControlFlow<RuntimeError>;
 
 pub(crate) type ArityRange = std::ops::RangeInclusive<usize>;
 
-type BoxedFunction = Box<dyn Fn(&[Quantity]) -> Quantity + Send + Sync>;
+type BoxedFunction = Box<dyn Fn(&[Value]) -> Value + Send + Sync>;
 
 pub(crate) enum Callable {
     Function(BoxedFunction),
-    Procedure(fn(&mut ExecutionContext, &[Quantity]) -> ControlFlow),
+    Procedure(fn(&mut ExecutionContext, &[Value]) -> ControlFlow),
 }
 
 pub(crate) struct ForeignFunction {
@@ -293,7 +294,7 @@ pub(crate) fn functions() -> &'static HashMap<String, ForeignFunction> {
     })
 }
 
-fn print(ctx: &mut ExecutionContext, args: &[Quantity]) -> ControlFlow {
+fn print(ctx: &mut ExecutionContext, args: &[Value]) -> ControlFlow {
     assert!(args.len() == 1);
 
     (ctx.print_fn)(&format!("{}\n", args[0]));
@@ -301,16 +302,16 @@ fn print(ctx: &mut ExecutionContext, args: &[Quantity]) -> ControlFlow {
     ControlFlow::Continue(())
 }
 
-fn assert_eq(_: &mut ExecutionContext, args: &[Quantity]) -> ControlFlow {
+fn assert_eq(_: &mut ExecutionContext, args: &[Value]) -> ControlFlow {
     assert!(args.len() == 2 || args.len() == 3);
 
+    let lhs = args[0].unsafe_as_quantity();
+    let rhs = args[1].unsafe_as_quantity();
+
     if args.len() == 2 {
-        let error = ControlFlow::Break(RuntimeError::AssertEq2Failed(
-            args[0].clone(),
-            args[1].clone(),
-        ));
-        if let Ok(args1_converted) = args[1].convert_to(args[0].unit()) {
-            if args[0] == args1_converted {
+        let error = ControlFlow::Break(RuntimeError::AssertEq2Failed(lhs.clone(), rhs.clone()));
+        if let Ok(args1_converted) = rhs.convert_to(lhs.unit()) {
+            if lhs == &args1_converted {
                 ControlFlow::Continue(())
             } else {
                 error
@@ -319,21 +320,20 @@ fn assert_eq(_: &mut ExecutionContext, args: &[Quantity]) -> ControlFlow {
             error
         }
     } else {
-        let result = &args[0] - &args[1];
+        let result = lhs - rhs;
+        let eps = args[2].unsafe_as_quantity();
 
         match result {
-            Ok(diff) => match diff.convert_to(args[2].unit()) {
+            Ok(diff) => match diff.convert_to(eps.unit()) {
                 Err(e) => ControlFlow::Break(RuntimeError::QuantityError(e)),
                 Ok(diff_converted) => {
-                    if diff_converted.unsafe_value().to_f64().abs()
-                        < args[2].unsafe_value().to_f64()
-                    {
+                    if diff_converted.unsafe_value().to_f64().abs() < eps.unsafe_value().to_f64() {
                         ControlFlow::Continue(())
                     } else {
                         ControlFlow::Break(RuntimeError::AssertEq3Failed(
-                            args[0].clone(),
-                            args[1].clone(),
-                            args[2].clone(),
+                            lhs.clone(),
+                            rhs.clone(),
+                            eps.clone(),
                         ))
                     }
                 }
@@ -343,225 +343,289 @@ fn assert_eq(_: &mut ExecutionContext, args: &[Quantity]) -> ControlFlow {
     }
 }
 
-fn unit_of(args: &[Quantity]) -> Quantity {
+fn unit_of(args: &[Value]) -> Value {
     assert!(args.len() == 1);
 
-    Quantity::new_f64(1.0, args[0].unit().clone())
+    Value::Quantity(Quantity::new_f64(
+        1.0,
+        args[0].unsafe_as_quantity().unit().clone(),
+    ))
 }
 
-fn abs(args: &[Quantity]) -> Quantity {
+fn abs(args: &[Value]) -> Value {
     assert!(args.len() == 1);
 
-    let value = args[0].unsafe_value().to_f64();
-    Quantity::new_f64(value.abs(), args[0].unit().clone())
+    let arg = args[0].unsafe_as_quantity();
+
+    let value = arg.unsafe_value().to_f64();
+    Value::Quantity(Quantity::new_f64(value.abs(), arg.unit().clone()))
 }
 
-fn round(args: &[Quantity]) -> Quantity {
+fn round(args: &[Value]) -> Value {
     assert!(args.len() == 1);
 
-    let value = args[0].unsafe_value().to_f64();
-    Quantity::new_f64(value.round(), args[0].unit().clone())
+    let arg = args[0].unsafe_as_quantity();
+
+    let value = arg.unsafe_value().to_f64();
+    Value::Quantity(Quantity::new_f64(value.round(), arg.unit().clone()))
 }
 
-fn floor(args: &[Quantity]) -> Quantity {
+fn floor(args: &[Value]) -> Value {
     assert!(args.len() == 1);
 
-    let value = args[0].unsafe_value().to_f64();
-    Quantity::new_f64(value.floor(), args[0].unit().clone())
+    let arg = args[0].unsafe_as_quantity();
+
+    let value = arg.unsafe_value().to_f64();
+    Value::Quantity(Quantity::new_f64(value.floor(), arg.unit().clone()))
 }
 
-fn ceil(args: &[Quantity]) -> Quantity {
+fn ceil(args: &[Value]) -> Value {
     assert!(args.len() == 1);
 
-    let value = args[0].unsafe_value().to_f64();
-    Quantity::new_f64(value.ceil(), args[0].unit().clone())
+    let arg = args[0].unsafe_as_quantity();
+
+    let value = arg.unsafe_value().to_f64();
+    Value::Quantity(Quantity::new_f64(value.ceil(), arg.unit().clone()))
 }
 
-fn sin(args: &[Quantity]) -> Quantity {
+fn sin(args: &[Value]) -> Value {
     assert!(args.len() == 1);
 
-    let input = args[0].as_scalar().unwrap().to_f64();
-    Quantity::from_scalar(input.sin())
+    let arg = args[0].unsafe_as_quantity();
+
+    let input = arg.as_scalar().unwrap().to_f64();
+    Value::Quantity(Quantity::from_scalar(input.sin()))
 }
 
-fn cos(args: &[Quantity]) -> Quantity {
+fn cos(args: &[Value]) -> Value {
     assert!(args.len() == 1);
 
-    let input = args[0].as_scalar().unwrap().to_f64();
-    Quantity::from_scalar(input.cos())
+    let arg = args[0].unsafe_as_quantity();
+
+    let input = arg.as_scalar().unwrap().to_f64();
+    Value::Quantity(Quantity::from_scalar(input.cos()))
 }
 
-fn tan(args: &[Quantity]) -> Quantity {
+fn tan(args: &[Value]) -> Value {
     assert!(args.len() == 1);
 
-    let input = args[0].as_scalar().unwrap().to_f64();
-    Quantity::from_scalar(input.tan())
+    let arg = args[0].unsafe_as_quantity();
+
+    let input = arg.as_scalar().unwrap().to_f64();
+    Value::Quantity(Quantity::from_scalar(input.tan()))
 }
 
-fn asin(args: &[Quantity]) -> Quantity {
+fn asin(args: &[Value]) -> Value {
     assert!(args.len() == 1);
 
-    let input = args[0].as_scalar().unwrap().to_f64();
-    Quantity::from_scalar(input.asin())
+    let arg = args[0].unsafe_as_quantity();
+
+    let input = arg.as_scalar().unwrap().to_f64();
+    Value::Quantity(Quantity::from_scalar(input.asin()))
 }
 
-fn acos(args: &[Quantity]) -> Quantity {
+fn acos(args: &[Value]) -> Value {
     assert!(args.len() == 1);
 
-    let input = args[0].as_scalar().unwrap().to_f64();
-    Quantity::from_scalar(input.acos())
+    let arg = args[0].unsafe_as_quantity();
+
+    let input = arg.as_scalar().unwrap().to_f64();
+    Value::Quantity(Quantity::from_scalar(input.acos()))
 }
 
-fn atan(args: &[Quantity]) -> Quantity {
+fn atan(args: &[Value]) -> Value {
     assert!(args.len() == 1);
 
-    let input = args[0].as_scalar().unwrap().to_f64();
-    Quantity::from_scalar(input.atan())
+    let arg = args[0].unsafe_as_quantity();
+
+    let input = arg.as_scalar().unwrap().to_f64();
+    Value::Quantity(Quantity::from_scalar(input.atan()))
 }
 
-fn atan2(args: &[Quantity]) -> Quantity {
+fn atan2(args: &[Value]) -> Value {
     assert!(args.len() == 2);
 
-    let input0 = args[0].unsafe_value().to_f64();
-    let input1 = args[1]
-        .convert_to(args[0].unit())
-        .unwrap()
-        .unsafe_value()
-        .to_f64();
-    Quantity::from_scalar(input0.atan2(input1))
+    let y = args[0].unsafe_as_quantity();
+    let x = args[1].unsafe_as_quantity();
+
+    let input0 = y.unsafe_value().to_f64();
+    let input1 = x.convert_to(y.unit()).unwrap().unsafe_value().to_f64();
+    Value::Quantity(Quantity::from_scalar(input0.atan2(input1)))
 }
 
-fn sinh(args: &[Quantity]) -> Quantity {
+fn sinh(args: &[Value]) -> Value {
     assert!(args.len() == 1);
 
-    let input = args[0].as_scalar().unwrap().to_f64();
-    Quantity::from_scalar(input.sinh())
+    let arg = args[0].unsafe_as_quantity();
+
+    let input = arg.as_scalar().unwrap().to_f64();
+    Value::Quantity(Quantity::from_scalar(input.sinh()))
 }
 
-fn cosh(args: &[Quantity]) -> Quantity {
+fn cosh(args: &[Value]) -> Value {
     assert!(args.len() == 1);
 
-    let input = args[0].as_scalar().unwrap().to_f64();
-    Quantity::from_scalar(input.cosh())
+    let arg = args[0].unsafe_as_quantity();
+
+    let input = arg.as_scalar().unwrap().to_f64();
+    Value::Quantity(Quantity::from_scalar(input.cosh()))
 }
 
-fn tanh(args: &[Quantity]) -> Quantity {
+fn tanh(args: &[Value]) -> Value {
     assert!(args.len() == 1);
 
-    let input = args[0].as_scalar().unwrap().to_f64();
-    Quantity::from_scalar(input.tanh())
+    let arg = args[0].unsafe_as_quantity();
+
+    let input = arg.as_scalar().unwrap().to_f64();
+    Value::Quantity(Quantity::from_scalar(input.tanh()))
 }
 
-fn asinh(args: &[Quantity]) -> Quantity {
+fn asinh(args: &[Value]) -> Value {
     assert!(args.len() == 1);
 
-    let input = args[0].as_scalar().unwrap().to_f64();
-    Quantity::from_scalar(input.asinh())
+    let arg = args[0].unsafe_as_quantity();
+
+    let input = arg.as_scalar().unwrap().to_f64();
+    Value::Quantity(Quantity::from_scalar(input.asinh()))
 }
 
-fn acosh(args: &[Quantity]) -> Quantity {
+fn acosh(args: &[Value]) -> Value {
     assert!(args.len() == 1);
 
-    let input = args[0].as_scalar().unwrap().to_f64();
-    Quantity::from_scalar(input.acosh())
+    let arg = args[0].unsafe_as_quantity();
+
+    let input = arg.as_scalar().unwrap().to_f64();
+    Value::Quantity(Quantity::from_scalar(input.acosh()))
 }
 
-fn atanh(args: &[Quantity]) -> Quantity {
+fn atanh(args: &[Value]) -> Value {
     assert!(args.len() == 1);
 
-    let input = args[0].as_scalar().unwrap().to_f64();
-    Quantity::from_scalar(input.atanh())
+    let arg = args[0].unsafe_as_quantity();
+
+    let input = arg.as_scalar().unwrap().to_f64();
+    Value::Quantity(Quantity::from_scalar(input.atanh()))
 }
 
-fn mod_(args: &[Quantity]) -> Quantity {
+fn mod_(args: &[Value]) -> Value {
     assert!(args.len() == 2);
 
-    let input0 = args[0].unsafe_value().to_f64();
-    let input1 = args[1]
-        .convert_to(args[0].unit())
-        .unwrap()
-        .unsafe_value()
-        .to_f64();
-    Quantity::new_f64(input0.rem_euclid(input1), args[0].unit().clone())
+    let x = args[0].unsafe_as_quantity();
+    let y = args[1].unsafe_as_quantity();
+
+    let input0 = x.unsafe_value().to_f64();
+    let input1 = y.convert_to(x.unit()).unwrap().unsafe_value().to_f64();
+    Value::Quantity(Quantity::new_f64(
+        input0.rem_euclid(input1),
+        x.unit().clone(),
+    ))
 }
 
-fn exp(args: &[Quantity]) -> Quantity {
+fn exp(args: &[Value]) -> Value {
     assert!(args.len() == 1);
 
-    let input = args[0].as_scalar().unwrap().to_f64();
-    Quantity::from_scalar(input.exp())
+    let arg = args[0].unsafe_as_quantity();
+
+    let input = arg.as_scalar().unwrap().to_f64();
+    Value::Quantity(Quantity::from_scalar(input.exp()))
 }
 
-fn ln(args: &[Quantity]) -> Quantity {
+fn ln(args: &[Value]) -> Value {
     assert!(args.len() == 1);
 
-    let input = args[0].as_scalar().unwrap().to_f64();
-    Quantity::from_scalar(input.ln())
+    let arg = args[0].unsafe_as_quantity();
+
+    let input = arg.as_scalar().unwrap().to_f64();
+    Value::Quantity(Quantity::from_scalar(input.ln()))
 }
 
-fn log10(args: &[Quantity]) -> Quantity {
+fn log10(args: &[Value]) -> Value {
     assert!(args.len() == 1);
 
-    let input = args[0].as_scalar().unwrap().to_f64();
-    Quantity::from_scalar(input.log10())
+    let arg = args[0].unsafe_as_quantity();
+
+    let input = arg.as_scalar().unwrap().to_f64();
+    Value::Quantity(Quantity::from_scalar(input.log10()))
 }
 
-fn log2(args: &[Quantity]) -> Quantity {
+fn log2(args: &[Value]) -> Value {
     assert!(args.len() == 1);
 
-    let input = args[0].as_scalar().unwrap().to_f64();
-    Quantity::from_scalar(input.log2())
+    let arg = args[0].unsafe_as_quantity();
+
+    let input = arg.as_scalar().unwrap().to_f64();
+    Value::Quantity(Quantity::from_scalar(input.log2()))
 }
 
-fn gamma(args: &[Quantity]) -> Quantity {
+fn gamma(args: &[Value]) -> Value {
     assert!(args.len() == 1);
 
-    let input = args[0].as_scalar().unwrap().to_f64();
-    Quantity::from_scalar(crate::gamma::gamma(input))
+    let arg = args[0].unsafe_as_quantity();
+
+    let input = arg.as_scalar().unwrap().to_f64();
+    Value::Quantity(Quantity::from_scalar(crate::gamma::gamma(input)))
 }
 
-fn mean(args: &[Quantity]) -> Quantity {
+fn mean(args: &[Value]) -> Value {
     assert!(!args.is_empty());
 
-    let output_unit = args[0].unit();
-    Quantity::new_f64(
+    let output_unit = args[0].unsafe_as_quantity().unit();
+    Value::Quantity(Quantity::new_f64(
         args.iter()
-            .map(|q| q.convert_to(output_unit).unwrap().unsafe_value().to_f64())
+            .map(|q| {
+                q.unsafe_as_quantity()
+                    .convert_to(output_unit)
+                    .unwrap()
+                    .unsafe_value()
+                    .to_f64()
+            })
             .sum::<f64>()
             / (args.len() as f64),
         output_unit.clone(),
-    )
+    ))
 }
 
-fn maximum(args: &[Quantity]) -> Quantity {
+fn maximum(args: &[Value]) -> Value {
     assert!(!args.is_empty());
 
-    let output_unit = args[0].unit();
-    Quantity::new(
+    let output_unit = args[0].unsafe_as_quantity().unit();
+    Value::Quantity(Quantity::new(
         args.iter()
-            .map(|q| *q.convert_to(output_unit).unwrap().unsafe_value())
+            .map(|q| {
+                *q.unsafe_as_quantity()
+                    .convert_to(output_unit)
+                    .unwrap()
+                    .unsafe_value()
+            })
             .max_by(|l, r| l.partial_cmp(r).unwrap())
             .unwrap(),
         output_unit.clone(),
-    )
+    ))
 }
 
-fn minimum(args: &[Quantity]) -> Quantity {
+fn minimum(args: &[Value]) -> Value {
     assert!(!args.is_empty());
 
-    let output_unit = args[0].unit();
-    Quantity::new(
+    let output_unit = args[0].unsafe_as_quantity().unit();
+    Value::Quantity(Quantity::new(
         args.iter()
-            .map(|q| *q.convert_to(output_unit).unwrap().unsafe_value())
+            .map(|q| {
+                *q.unsafe_as_quantity()
+                    .convert_to(output_unit)
+                    .unwrap()
+                    .unsafe_value()
+            })
             .min_by(|l, r| l.partial_cmp(r).unwrap())
             .unwrap(),
         output_unit.clone(),
-    )
+    ))
 }
 
 fn exchange_rate(rate: &'static str) -> BoxedFunction {
-    Box::new(|_args: &[Quantity]| -> Quantity {
+    Box::new(|_args: &[Value]| -> Value {
         let exchange_rates = ExchangeRatesCache::new();
-        Quantity::from_scalar(exchange_rates.get_rate(rate).unwrap_or(f64::NAN))
+        Value::Quantity(Quantity::from_scalar(
+            exchange_rates.get_rate(rate).unwrap_or(f64::NAN),
+        ))
     })
 }

+ 1 - 1
numbat/src/interpreter.rs

@@ -134,7 +134,7 @@ mod tests {
     fn assert_evaluates_to(input: &str, expected: Quantity) {
         if let InterpreterResult::Value(actual) = get_interpreter_result(input).unwrap() {
             let actual = actual.unsafe_as_quantity();
-            assert_eq!(actual, expected);
+            assert_eq!(actual, &expected);
         } else {
             panic!();
         }

+ 2 - 2
numbat/src/value.rs

@@ -7,9 +7,9 @@ pub enum Value {
 }
 
 impl Value {
-    pub fn unsafe_as_quantity(self) -> Quantity {
+    pub fn unsafe_as_quantity(&self) -> &Quantity {
         if let Value::Quantity(q) = self {
-            q
+            &q
         } else {
             panic!("Expected value to be a quantity");
         }

+ 6 - 3
numbat/src/vm.rs

@@ -460,7 +460,10 @@ impl Vm {
     }
 
     fn pop_quantity(&mut self) -> Quantity {
-        self.pop().unsafe_as_quantity()
+        match self.pop() {
+            Value::Quantity(q) => q,
+            _ => panic!("Expected quantity to be on the top of the stack"),
+        }
     }
 
     fn pop_bool(&mut self) -> bool {
@@ -637,14 +640,14 @@ impl Vm {
 
                     let mut args = vec![];
                     for _ in 0..num_args {
-                        args.push(self.pop_quantity());
+                        args.push(self.pop());
                     }
                     args.reverse(); // TODO: use a deque?
 
                     match &self.ffi_callables[function_idx].callable {
                         Callable::Function(function) => {
                             let result = (function)(&args[..]);
-                            self.push_quantity(result);
+                            self.push(result);
                         }
                         Callable::Procedure(procedure) => {
                             let result = (procedure)(ctx, &args[..]);