Jelajahi Sumber

Fix assert/assert_eq

David Peter 1 tahun lalu
induk
melakukan
e92120e7ba
2 mengubah file dengan 22 tambahan dan 6 penghapusan
  1. 12 3
      examples/numerical_diff.nbt
  2. 10 3
      numbat/src/typechecker/mod.rs

+ 12 - 3
examples/numerical_diff.nbt

@@ -3,9 +3,18 @@ let eps = 1e-10
 fn diff<X, Y>(f: Fn[(X) -> Y], x: X) -> Y / X =
   (f(x + eps · unit_of(x)) - f(x)) / (eps · unit_of(x))
 
-# assert_eq(diff(log, 2.0), 0.5, 1e-5)
-# assert_eq(diff(sin, 0.0), 1.0, 1e-5)
+assert_eq(diff(log, 2.0), 0.5, 1e-5)
+assert_eq(diff(sin, 0.0), 1.0, 1e-5)
+
+assert_eq(diff(sqrt, 1.0), 0.5, 1e-5)
 
 fn f(x: Scalar) -> Scalar = x * x + 4 * x + 1
 
-# assert_eq(diff(f, 2.0), 8.0, 1e-5)
+assert_eq(diff(f, 2.0), 8.0, 1e-5)
+
+fn dist(t: Time) -> Length = 0.5 g0 t^2
+fn velocity(t: Time) -> Velocity = diff(dist, t)
+
+assert_eq(velocity(2.0 s), 2.0 s × g0, 1e-3 m/s)
+
+diff

+ 10 - 3
numbat/src/typechecker/mod.rs

@@ -1561,10 +1561,17 @@ impl TypeChecker {
                         }
                     }
                     ProcedureKind::AssertEq => {
-                        let type_first = dtype(&checked_args[0])?;
+                        let type_first = &checked_args[0].get_type();
+                        self.enforce_dtype(type_first, checked_args[0].full_span())?;
+
                         for arg in &checked_args[1..] {
-                            let type_arg = dtype(arg)?;
-                            if type_arg != type_first {
+                            let type_arg = arg.get_type();
+                            self.enforce_dtype(&type_arg, arg.full_span())?;
+
+                            if self
+                                .add_equal_constraint(type_first, &type_arg)
+                                .is_trivially_violated()
+                            {
                                 return Err(TypeCheckError::IncompatibleTypesInAssertEq(
                                     *span,
                                     checked_args[0].get_type(),