Browse Source

Add comparison of strings, bools

David Peter 2 years ago
parent
commit
bee36b7523

+ 7 - 0
numbat/modules/core/strings.nbt

@@ -6,6 +6,13 @@ fn str_slice(s: str, start: Scalar, end: Scalar) -> str
 
 
 fn str_append(a: str, b: str) -> str = "{a}{b}"
 fn str_append(a: str, b: str) -> str = "{a}{b}"
 
 
+fn str_contains(haystack: str, needle: str) -> bool =
+  if str_length(haystack) == 0
+    then false
+    else if str_slice(haystack, 0, str_length(needle)) == needle
+      then true
+      else str_contains(str_slice(haystack, 1, str_length(haystack)), needle)
+
 fn str_repeat(a: str, n: Scalar) -> str =
 fn str_repeat(a: str, n: Scalar) -> str =
   if n > 0
   if n > 0
     then str_append(a, str_repeat(a, n - 1))
     then str_append(a, str_repeat(a, n - 1))

+ 17 - 0
numbat/src/diagnostic.rs

@@ -226,6 +226,23 @@ impl ErrorDiagnostic for TypeCheckError {
                     "Incompatible types in 'then' and 'else' branches of conditional",
                     "Incompatible types in 'then' and 'else' branches of conditional",
                 ),
                 ),
             ]),
             ]),
+            TypeCheckError::IncompatibleTypesInComparison(
+                op_span,
+                lhs_type,
+                lhs_span,
+                rhs_type,
+                rhs_span,
+            ) => d.with_labels(vec![
+                lhs_span
+                    .diagnostic_label(LabelStyle::Secondary)
+                    .with_message(lhs_type.to_string()),
+                rhs_span
+                    .diagnostic_label(LabelStyle::Secondary)
+                    .with_message(rhs_type.to_string()),
+                op_span
+                    .diagnostic_label(LabelStyle::Primary)
+                    .with_message("Incompatible types comparison operator"),
+            ]),
             TypeCheckError::IncompatibleTypeInAssert(procedure_span, type_, type_span) => d
             TypeCheckError::IncompatibleTypeInAssert(procedure_span, type_, type_span) => d
                 .with_labels(vec![
                 .with_labels(vec![
                     type_span
                     type_span

+ 23 - 3
numbat/src/typechecker.rs

@@ -250,6 +250,9 @@ pub enum TypeCheckError {
 
 
     #[error("Incompatible types in {0}")]
     #[error("Incompatible types in {0}")]
     IncompatibleTypesInAnnotation(String, Span, Type, Span, Type, Span),
     IncompatibleTypesInAnnotation(String, Span, Type, Span, Type, Span),
+
+    #[error("Incompatible types in comparison operator")]
+    IncompatibleTypesInComparison(Span, Type, Span, Type, Span),
 }
 }
 
 
 type Result<T> = std::result::Result<T, TypeCheckError>;
 type Result<T> = std::result::Result<T, TypeCheckError>;
@@ -525,12 +528,29 @@ impl TypeChecker {
                     typed_ast::BinaryOperator::LessThan
                     typed_ast::BinaryOperator::LessThan
                     | typed_ast::BinaryOperator::GreaterThan
                     | typed_ast::BinaryOperator::GreaterThan
                     | typed_ast::BinaryOperator::LessOrEqual
                     | typed_ast::BinaryOperator::LessOrEqual
-                    | typed_ast::BinaryOperator::GreaterOrEqual
-                    | typed_ast::BinaryOperator::Equal
-                    | typed_ast::BinaryOperator::NotEqual => {
+                    | typed_ast::BinaryOperator::GreaterOrEqual => {
                         let _ = get_type_and_assert_equality()?;
                         let _ = get_type_and_assert_equality()?;
                         Type::Boolean
                         Type::Boolean
                     }
                     }
+                    typed_ast::BinaryOperator::Equal | typed_ast::BinaryOperator::NotEqual => {
+                        let lhs_type = lhs_checked.get_type();
+                        let rhs_type = rhs_checked.get_type();
+                        if lhs_type.is_dtype() || rhs_type.is_dtype() {
+                            let _ = get_type_and_assert_equality()?;
+                        } else {
+                            if lhs_type != rhs_type {
+                                return Err(TypeCheckError::IncompatibleTypesInComparison(
+                                    span_op.unwrap(),
+                                    lhs_type,
+                                    lhs.full_span(),
+                                    rhs_type,
+                                    rhs.full_span(),
+                                ));
+                            }
+                        }
+
+                        Type::Boolean
+                    }
                 };
                 };
 
 
                 typed_ast::Expression::BinaryOperator(
                 typed_ast::Expression::BinaryOperator(

+ 4 - 0
numbat/src/typed_ast.rs

@@ -78,6 +78,10 @@ impl Type {
     pub fn scalar() -> Type {
     pub fn scalar() -> Type {
         Type::Dimension(DType::unity())
         Type::Dimension(DType::unity())
     }
     }
+
+    pub fn is_dtype(&self) -> bool {
+        matches!(self, Type::Dimension(..))
+    }
 }
 }
 
 
 #[derive(Debug, Clone, PartialEq)]
 #[derive(Debug, Clone, PartialEq)]

+ 11 - 7
numbat/src/vm.rs

@@ -595,12 +595,7 @@ impl Vm {
                     };
                     };
                     self.push_quantity(result.map_err(RuntimeError::QuantityError)?);
                     self.push_quantity(result.map_err(RuntimeError::QuantityError)?);
                 }
                 }
-                op @ (Op::LessThan
-                | Op::GreaterThan
-                | Op::LessOrEqual
-                | Op::GreatorOrEqual
-                | Op::Equal
-                | Op::NotEqual) => {
+                op @ (Op::LessThan | Op::GreaterThan | Op::LessOrEqual | Op::GreatorOrEqual) => {
                     let rhs = self.pop_quantity();
                     let rhs = self.pop_quantity();
                     let lhs = self.pop_quantity();
                     let lhs = self.pop_quantity();
 
 
@@ -609,11 +604,20 @@ impl Vm {
                         Op::GreaterThan => lhs > rhs,
                         Op::GreaterThan => lhs > rhs,
                         Op::LessOrEqual => lhs <= rhs,
                         Op::LessOrEqual => lhs <= rhs,
                         Op::GreatorOrEqual => lhs >= rhs,
                         Op::GreatorOrEqual => lhs >= rhs,
+                        _ => unreachable!(),
+                    };
+
+                    self.push(Value::Boolean(result));
+                }
+                op @ (Op::Equal | Op::NotEqual) => {
+                    let rhs = self.pop();
+                    let lhs = self.pop();
+
+                    let result = match op {
                         Op::Equal => lhs == rhs,
                         Op::Equal => lhs == rhs,
                         Op::NotEqual => lhs != rhs,
                         Op::NotEqual => lhs != rhs,
                         _ => unreachable!(),
                         _ => unreachable!(),
                     };
                     };
-
                     self.push(Value::Boolean(result));
                     self.push(Value::Boolean(result));
                 }
                 }
                 Op::Negate => {
                 Op::Negate => {