Browse Source

Add 'assert(…)' procedure, closes #188

David Peter 2 years ago
parent
commit
81956076bd

+ 1 - 1
assets/numbat.sublime-syntax

@@ -7,7 +7,7 @@ file_extensions:
 scope: source.nbt
 contexts:
   main:
-    - match: \b(per|to|let|fn|dimension|unit|use|long|short|both|none|print|assert_eq|type|if|then|else|true|false|bool|str)\b
+    - match: \b(per|to|let|fn|dimension|unit|use|long|short|both|none|print|assert|assert_eq|type|if|then|else|true|false|bool|str)\b
       scope: keyword.control.nbt
     - match: '#(.*)'
       scope: comment.line.nbt

+ 1 - 1
assets/numbat.vim

@@ -5,7 +5,7 @@ if exists("b:current_syntax")
 endif
 
 " Numbat Keywords
-syn keyword numbatKeywords per to let fn dimension unit use long short both none print assert_eq type if then else true false bool str
+syn keyword numbatKeywords per to let fn dimension unit use long short both none print assert assert_eq type if then else true false bool str
 highlight default link numbatKeywords Keyword
 
 " Physical dimensions (every capitalized word)

+ 1 - 1
book/numbat.js

@@ -4,7 +4,7 @@ hljs.registerLanguage('numbat', function(hljs) {
     aliases: ['nbt'],
     case_insensitive: false,
     keywords: {
-      keyword: 'per to let fn dimension unit use long short both none print assert_eq type if then else true false bool str',
+      keyword: 'per to let fn dimension unit use long short both none print assert assert_eq type if then else true false bool str',
     },
     contains: [
       hljs.HASH_COMMENT_MODE,

+ 7 - 0
book/src/procedures.md

@@ -48,6 +48,13 @@ assert_eq(alpha, 1 / 137, 1e-4)
 assert_eq(3.3 ft, 1 m, 1 cm)
 ```
 
+There is also a plain `assert` procedure that can test any boolean condition. For example:
+
+```nbt
+assert(1 yard < 1 meter)
+assert(π != 3)
+```
+
 A runtime error is thrown if an assertion fails. Otherwise, nothing happens.
 
 ## Debugging

+ 11 - 7
examples/binomial_coefficient.nbt

@@ -6,13 +6,17 @@
 # TODO: This could really benefit from logical and/or operators
 
 fn binomial_coefficient(n: Scalar, k: Scalar) -> Scalar =
-    if or(k < 0, k > n)
-      then 0
-      else if k > n - k # Take advantage of symmetry
-        then binomial_coefficient(n, n - k)
-        else if or(k == 0, n <= 1)
-          then 1
-          else binomial_coefficient(n - 1, k) + binomial_coefficient(n - 1, k - 1)
+    if k < 0
+        then 0
+        else if k > n
+            then 0
+            else if k > n - k # Take advantage of symmetry
+                then binomial_coefficient(n, n - k)
+                else if k == 0
+                    then 1
+                    else if n <= 1
+                        then 1
+                        else binomial_coefficient(n - 1, k) + binomial_coefficient(n - 1, k - 1)
 
 assert_eq(binomial_coefficient(10, 0), 1)
 assert_eq(binomial_coefficient(10, 1), 10)

+ 16 - 0
examples/booleans.nbt

@@ -0,0 +1,16 @@
+fn not(a: bool) = if a then false else true
+fn and(a: bool, b: bool) = if a then b else false
+fn or(a: bool, b: bool) = if a then true else b
+
+assert(not(false))
+assert(not(not(true)))
+
+assert(and(true, true))
+assert(not(and(true, false)))
+assert(not(and(false, true)))
+assert(not(and(false, false)))
+
+assert(or(true, true))
+assert(or(true, false))
+assert(or(false, true))
+assert(not(or(false, false)))

+ 0 - 2
numbat/modules/core/booleans.nbt

@@ -1,2 +0,0 @@
-fn and(a: bool, b: bool) = if a then b else false
-fn or(a: bool, b: bool) = if a then true else b

+ 0 - 1
numbat/modules/prelude.nbt

@@ -1,7 +1,6 @@
 use core::scalar
 use core::quantities
 use core::dimensions
-use core::booleans
 use core::strings
 
 use math::constants

+ 1 - 0
numbat/src/ast.rs

@@ -282,6 +282,7 @@ impl PrettyPrint for DimensionExpression {
 #[derive(Debug, Clone, PartialEq, Eq, Hash)]
 pub enum ProcedureKind {
     Print,
+    Assert,
     AssertEq,
     Type,
 }

+ 9 - 0
numbat/src/diagnostic.rs

@@ -226,6 +226,15 @@ impl ErrorDiagnostic for TypeCheckError {
                     "Incompatible types in 'then' and 'else' branches of conditional",
                 ),
             ]),
+            TypeCheckError::IncompatibleTypeInAssert(procedure_span, type_, type_span) => d
+                .with_labels(vec![
+                    type_span
+                        .diagnostic_label(LabelStyle::Secondary)
+                        .with_message(type_.to_string()),
+                    procedure_span
+                        .diagnostic_label(LabelStyle::Primary)
+                        .with_message("Non-boolean type in 'assert' call"),
+                ]),
             TypeCheckError::IncompatibleTypesInAssertEq(
                 procedure_span,
                 first_type,

+ 18 - 0
numbat/src/ffi.rs

@@ -41,6 +41,14 @@ pub(crate) fn procedures() -> &'static HashMap<ProcedureKind, ForeignFunction> {
                 callable: Callable::Procedure(print),
             },
         );
+        m.insert(
+            ProcedureKind::Assert,
+            ForeignFunction {
+                name: "assert".into(),
+                arity: 1..=1,
+                callable: Callable::Procedure(assert),
+            },
+        );
         m.insert(
             ProcedureKind::AssertEq,
             ForeignFunction {
@@ -319,6 +327,16 @@ fn print(ctx: &mut ExecutionContext, args: &[Value]) -> ControlFlow {
     ControlFlow::Continue(())
 }
 
+fn assert(_: &mut ExecutionContext, args: &[Value]) -> ControlFlow {
+    assert!(args.len() == 1);
+
+    if args[0].unsafe_as_bool() {
+        ControlFlow::Continue(())
+    } else {
+        ControlFlow::Break(RuntimeError::AssertFailed)
+    }
+}
+
 fn assert_eq(_: &mut ExecutionContext, args: &[Value]) -> ControlFlow {
     assert!(args.len() == 2 || args.len() == 3);
 

+ 2 - 0
numbat/src/interpreter.rs

@@ -21,6 +21,8 @@ pub enum RuntimeError {
     UnitRegistryError(UnitRegistryError), // TODO: can this even be triggered?
     #[error("{0}")]
     QuantityError(QuantityError),
+    #[error("Assertion failed")]
+    AssertFailed,
     #[error(
         "Assertion failed because the following two quantities are not the same:\n  {0}\n  {1}"
     )]

+ 3 - 1
numbat/src/parser.rs

@@ -11,7 +11,7 @@
 //! dimension_decl  ::=   "dimension" identifier ( "=" dimension_expr ) *
 //! unit_decl       ::=   decorator * "unit" ( ":" dimension_expr ) ? ( "=" expression ) ?
 //! module_import   ::=   "use" ident ( "::" ident) *
-//! procedure_call  ::=   ( "print" | "assert_eq" | "type" ) "(" arguments? ")"
+//! procedure_call  ::=   ( "print" | "assert" | "assert_eq" | "type" ) "(" arguments? ")"
 //!
 //! decorator       ::=   "@" ( "metric_prefixes" | "binary_prefixes" | ( "aliases(" list_of_aliases ")" ) )
 //!
@@ -207,6 +207,7 @@ type Result<T> = std::result::Result<T, ParseError>;
 
 static PROCEDURES: &[TokenKind] = &[
     TokenKind::ProcedurePrint,
+    TokenKind::ProcedureAssert,
     TokenKind::ProcedureAssertEq,
     TokenKind::ProcedureType,
 ];
@@ -632,6 +633,7 @@ impl<'a> Parser<'a> {
             let span = self.last().unwrap().span;
             let procedure_kind = match self.last().unwrap().kind {
                 TokenKind::ProcedurePrint => ProcedureKind::Print,
+                TokenKind::ProcedureAssert => ProcedureKind::Assert,
                 TokenKind::ProcedureAssertEq => ProcedureKind::AssertEq,
                 TokenKind::ProcedureType => ProcedureKind::Type,
                 _ => unreachable!(),

+ 2 - 0
numbat/src/tokenizer.rs

@@ -99,6 +99,7 @@ pub enum TokenKind {
 
     // Procedure calls
     ProcedurePrint,
+    ProcedureAssert,
     ProcedureAssertEq,
     ProcedureType,
 
@@ -330,6 +331,7 @@ impl Tokenizer {
             m.insert("both", TokenKind::Both);
             m.insert("none", TokenKind::None);
             m.insert("print", TokenKind::ProcedurePrint);
+            m.insert("assert", TokenKind::ProcedureAssert);
             m.insert("assert_eq", TokenKind::ProcedureAssertEq);
             m.insert("type", TokenKind::ProcedureType);
             m.insert("bool", TokenKind::Bool);

+ 12 - 0
numbat/src/typechecker.rs

@@ -242,6 +242,9 @@ pub enum TypeCheckError {
     #[error("Incompatible types in condition")]
     IncompatibleTypesInCondition(Span, Type, Span, Type, Span),
 
+    #[error("Argument types in assert call must be boolean")]
+    IncompatibleTypeInAssert(Span, Type, Span),
+
     #[error("Argument types in assert_eq calls must match")]
     IncompatibleTypesInAssertEq(Span, Type, Span, Type, Span),
 
@@ -1208,6 +1211,15 @@ impl TypeChecker {
                     ProcedureKind::Print => {
                         // no argument type checks required, everything can be printed
                     }
+                    ProcedureKind::Assert => {
+                        if checked_args[0].get_type() != Type::Boolean {
+                            return Err(TypeCheckError::IncompatibleTypeInAssert(
+                                *span,
+                                checked_args[0].get_type(),
+                                checked_args[0].full_span(),
+                            ));
+                        }
+                    }
                     ProcedureKind::AssertEq => {
                         let type_first = dtype(&checked_args[0])?;
                         for arg in &checked_args[1..] {

+ 1 - 0
numbat/src/typed_ast.rs

@@ -350,6 +350,7 @@ impl PrettyPrint for Statement {
             Statement::ProcedureCall(kind, args) => {
                 let identifier = match kind {
                     ProcedureKind::Print => "print",
+                    ProcedureKind::Assert => "assert",
                     ProcedureKind::AssertEq => "assert_eq",
                     ProcedureKind::Type => "type",
                 };

+ 2 - 2
numbat/src/value.rs

@@ -16,9 +16,9 @@ impl Value {
         }
     }
 
-    pub fn unsafe_as_bool(self) -> bool {
+    pub fn unsafe_as_bool(&self) -> bool {
         if let Value::Boolean(b) = self {
-            b
+            *b
         } else {
             panic!("Expected value to be a bool");
         }

+ 1 - 1
vscode-extension/syntaxes/numbat.tmLanguage.json

@@ -32,7 +32,7 @@
             "patterns": [
                 {
                     "name": "keyword.control.numbat",
-                    "match": "\\b(per|to|let|fn|dimension|unit|use|long|short|both|none|print|assert_eq|type|if|then|else|true|false|bool|str)\\b"
+                    "match": "\\b(per|to|let|fn|dimension|unit|use|long|short|both|none|print|assert|assert_eq|type|if|then|else|true|false|bool|str)\\b"
                 }
             ]
         },