Pārlūkot izejas kodu

Add support for arbitrary types in function parameters

David Peter 2 gadi atpakaļ
vecāks
revīzija
6edef5f311

+ 7 - 11
examples/binomial_coefficient.nbt

@@ -6,17 +6,13 @@
 # TODO: This could really benefit from logical and/or operators
 
 fn binomial_coefficient(n: Scalar, k: Scalar) -> Scalar =
-    if k < 0
-        then 0
-        else if k > n
-            then 0
-            else if k > n - k # Take advantage of symmetry
-                then binomial_coefficient(n, n - k)
-                else if k == 0
-                    then 1
-                    else if n <= 1
-                        then 1
-                        else binomial_coefficient(n - 1, k) + binomial_coefficient(n - 1, k - 1)
+    if or(k < 0, k > n)
+      then 0
+      else if k > n - k # Take advantage of symmetry
+        then binomial_coefficient(n, n - k)
+        else if or(k == 0, n <= 1)
+          then 1
+          else binomial_coefficient(n - 1, k) + binomial_coefficient(n - 1, k - 1)
 
 assert_eq(binomial_coefficient(10, 0), 1)
 assert_eq(binomial_coefficient(10, 1), 10)

+ 2 - 0
numbat/modules/core/booleans.nbt

@@ -0,0 +1,2 @@
+fn and(a: bool, b: bool) = if a then b else false
+fn or(a: bool, b: bool) = if a then true else b

+ 4 - 0
numbat/modules/core/strings.nbt

@@ -0,0 +1,4 @@
+use core::scalar
+
+fn str_length(a: str) -> Scalar
+fn str_concat(a: str, b: str) -> str = "{a}{b}"

+ 2 - 0
numbat/modules/prelude.nbt

@@ -1,6 +1,8 @@
 use core::scalar
 use core::quantities
 use core::dimensions
+use core::booleans
+use core::strings
 
 use math::constants
 use math::functions

+ 1 - 1
numbat/src/ast.rs

@@ -300,7 +300,7 @@ pub enum Statement {
         function_name: String,
         type_parameters: Vec<(Span, String)>,
         /// Parameters, optionally with type annotations. The boolean argument specifies whether or not the parameter is variadic
-        parameters: Vec<(Span, String, Option<DimensionExpression>, bool)>,
+        parameters: Vec<(Span, String, Option<TypeAnnotation>, bool)>,
         /// Function body. If it is absent, the function is implemented via FFI
         body: Option<Expression>,
         return_type_annotation_span: Option<Span>,

+ 16 - 0
numbat/src/ffi.rs

@@ -291,6 +291,15 @@ pub(crate) fn functions() -> &'static HashMap<String, ForeignFunction> {
             );
         }
 
+        m.insert(
+            "str_length".to_string(),
+            ForeignFunction {
+                name: "str_length".into(),
+                arity: 1..=1,
+                callable: Callable::Function(Box::new(str_length)),
+            },
+        );
+
         m
     })
 }
@@ -637,3 +646,10 @@ fn exchange_rate(rate: &'static str) -> BoxedFunction {
         ))
     })
 }
+
+fn str_length(args: &[Value]) -> Value {
+    assert!(args.len() == 1);
+
+    let len = args[0].unsafe_as_string().len();
+    Value::Quantity(Quantity::from_scalar(len as f64))
+}

+ 30 - 23
numbat/src/parser.rs

@@ -398,7 +398,7 @@ impl<'a> Parser<'a> {
                     if let Some(param_name) = self.match_exact(TokenKind::Identifier) {
                         let span = self.last().unwrap().span;
                         let param_type_dexpr = if self.match_exact(TokenKind::Colon).is_some() {
-                            Some(self.dimension_expression()?)
+                            Some(self.type_annotation()?)
                         } else {
                             None
                         };
@@ -1953,41 +1953,44 @@ mod tests {
                     (
                         Span::dummy(),
                         "x".into(),
-                        Some(DimensionExpression::Dimension(
-                            Span::dummy(),
-                            "Length".into(),
+                        Some(TypeAnnotation::DimensionExpression(
+                            DimensionExpression::Dimension(Span::dummy(), "Length".into()),
                         )),
                         false,
                     ),
                     (
                         Span::dummy(),
                         "y".into(),
-                        Some(DimensionExpression::Dimension(Span::dummy(), "Time".into())),
+                        Some(TypeAnnotation::DimensionExpression(
+                            DimensionExpression::Dimension(Span::dummy(), "Time".into()),
+                        )),
                         false,
                     ),
                     (
                         Span::dummy(),
                         "z".into(),
-                        Some(DimensionExpression::Multiply(
-                            Span::dummy(),
-                            Box::new(DimensionExpression::Power(
-                                Some(Span::dummy()),
-                                Box::new(DimensionExpression::Dimension(
+                        Some(TypeAnnotation::DimensionExpression(
+                            DimensionExpression::Multiply(
+                                Span::dummy(),
+                                Box::new(DimensionExpression::Power(
+                                    Some(Span::dummy()),
+                                    Box::new(DimensionExpression::Dimension(
+                                        Span::dummy(),
+                                        "Length".into(),
+                                    )),
                                     Span::dummy(),
-                                    "Length".into(),
+                                    Rational::new(3, 1),
                                 )),
-                                Span::dummy(),
-                                Rational::new(3, 1),
-                            )),
-                            Box::new(DimensionExpression::Power(
-                                Some(Span::dummy()),
-                                Box::new(DimensionExpression::Dimension(
+                                Box::new(DimensionExpression::Power(
+                                    Some(Span::dummy()),
+                                    Box::new(DimensionExpression::Dimension(
+                                        Span::dummy(),
+                                        "Time".into(),
+                                    )),
                                     Span::dummy(),
-                                    "Time".into(),
+                                    Rational::new(2, 1),
                                 )),
-                                Span::dummy(),
-                                Rational::new(2, 1),
-                            )),
+                            ),
                         )),
                         false,
                     ),
@@ -2009,7 +2012,9 @@ mod tests {
                 parameters: vec![(
                     Span::dummy(),
                     "x".into(),
-                    Some(DimensionExpression::Dimension(Span::dummy(), "X".into())),
+                    Some(TypeAnnotation::DimensionExpression(
+                        DimensionExpression::Dimension(Span::dummy(), "X".into()),
+                    )),
                     false,
                 )],
                 body: Some(scalar!(1.0)),
@@ -2027,7 +2032,9 @@ mod tests {
                 parameters: vec![(
                     Span::dummy(),
                     "x".into(),
-                    Some(DimensionExpression::Dimension(Span::dummy(), "D".into())),
+                    Some(TypeAnnotation::DimensionExpression(
+                        DimensionExpression::Dimension(Span::dummy(), "D".into()),
+                    )),
                     true,
                 )],
                 body: None,

+ 91 - 94
numbat/src/typechecker.rs

@@ -255,13 +255,6 @@ fn to_rational_exponent(exponent_f64: f64) -> Option<Exponent> {
     Rational::from_f64(exponent_f64)
 }
 
-fn expect_dtype(span: &Span, t: Type) -> Result<DType> {
-    match t {
-        Type::Dimension(dtype) => Ok(dtype),
-        _ => Err(TypeCheckError::ExpectedDimensionType(span.clone(), t)),
-    }
-}
-
 fn dtype(e: &Expression) -> Result<DType> {
     match e.get_type() {
         Type::Dimension(dtype) => Ok(dtype),
@@ -579,8 +572,8 @@ impl TypeChecker {
                     .collect::<Result<Vec<_>>>()?;
                 let argument_types = arguments_checked
                     .iter()
-                    .map(|e| dtype(e))
-                    .collect::<Result<Vec<DType>>>()?;
+                    .map(|e| e.get_type())
+                    .collect::<Vec<Type>>();
 
                 let mut substitutions: Vec<(String, DType)> = vec![];
 
@@ -599,12 +592,8 @@ impl TypeChecker {
                     result_type
                 };
 
-                let mut parameter_types = parameter_types
-                    .into_iter()
-                    .map(|(span, t)| {
-                        expect_dtype(span, t.clone()).map(|dtype| (span.clone(), dtype))
-                    })
-                    .collect::<Result<Vec<(Span, DType)>>>()?;
+                let mut parameter_types = parameter_types.clone();
+
                 if *is_variadic {
                     // For a variadic function, we simply duplicate the parameter type
                     // N times, where N is the number of arguments given.
@@ -618,70 +607,82 @@ impl TypeChecker {
                 for (idx, ((parameter_span, parameter_type), argument_type)) in
                     parameter_types.iter().zip(argument_types).enumerate()
                 {
-                    let mut parameter_type = substitute(&substitutions, parameter_type);
-
-                    let remaining_generic_subtypes: Vec<_> = parameter_type
-                        .iter()
-                        .filter(|BaseRepresentationFactor(name, _)| {
-                            type_parameters.iter().any(|(_, n)| name == n)
-                        })
-                        .collect();
-
-                    if remaining_generic_subtypes.len() > 1 {
-                        return Err(TypeCheckError::MultipleUnresolvedTypeParameters(
-                            *span,
-                            *parameter_span,
-                        ));
-                    }
+                    match (parameter_type, argument_type) {
+                        (Type::Dimension(parameter_type), Type::Dimension(argument_type)) => {
+                            let mut parameter_type = substitute(&substitutions, parameter_type);
+
+                            let remaining_generic_subtypes: Vec<_> = parameter_type
+                                .iter()
+                                .filter(|BaseRepresentationFactor(name, _)| {
+                                    type_parameters.iter().any(|(_, n)| name == n)
+                                })
+                                .collect();
+
+                            if remaining_generic_subtypes.len() > 1 {
+                                return Err(TypeCheckError::MultipleUnresolvedTypeParameters(
+                                    *span,
+                                    *parameter_span,
+                                ));
+                            }
 
-                    if let Some(&generic_subtype_factor) = remaining_generic_subtypes.first() {
-                        let generic_subtype = DType::from_factor(generic_subtype_factor.clone());
-
-                        // The type of the idx-th parameter of the called function has a generic type
-                        // parameter inside. We can now instantiate that generic parameter by solving
-                        // the equation "parameter_type == argument_type" for the generic parameter.
-                        // In order to do this, let's assume `generic_subtype = D^alpha`, then we have
-                        //
-                        //                                parameter_type == argument_type
-                        //    parameter_type / generic_subtype * D^alpha == argument_type
-                        //                                       D^alpha == argument_type / (parameter_type / generic_subtype)
-                        //                                             D == [argument_type / (parameter_type / generic_subtype)]^(1/alpha)
-                        //
-
-                        let alpha = Rational::from_integer(1) / generic_subtype_factor.1;
-                        let d = (argument_type.clone()
-                            / (parameter_type.clone() / generic_subtype))
-                            .power(alpha);
-
-                        // We can now substitute that generic parameter in all subsequent expressions
-                        substitutions.push((generic_subtype_factor.0.clone(), d));
-
-                        parameter_type = substitute(&substitutions, &parameter_type);
-                    }
+                            if let Some(&generic_subtype_factor) =
+                                remaining_generic_subtypes.first()
+                            {
+                                let generic_subtype =
+                                    DType::from_factor(generic_subtype_factor.clone());
+
+                                // The type of the idx-th parameter of the called function has a generic type
+                                // parameter inside. We can now instantiate that generic parameter by solving
+                                // the equation "parameter_type == argument_type" for the generic parameter.
+                                // In order to do this, let's assume `generic_subtype = D^alpha`, then we have
+                                //
+                                //                                parameter_type == argument_type
+                                //    parameter_type / generic_subtype * D^alpha == argument_type
+                                //                                       D^alpha == argument_type / (parameter_type / generic_subtype)
+                                //                                             D == [argument_type / (parameter_type / generic_subtype)]^(1/alpha)
+                                //
+
+                                let alpha = Rational::from_integer(1) / generic_subtype_factor.1;
+                                let d = (argument_type.clone()
+                                    / (parameter_type.clone() / generic_subtype))
+                                    .power(alpha);
+
+                                // We can now substitute that generic parameter in all subsequent expressions
+                                substitutions.push((generic_subtype_factor.0.clone(), d));
+
+                                parameter_type = substitute(&substitutions, &parameter_type);
+                            }
 
-                    if parameter_type != argument_type {
-                        return Err(TypeCheckError::IncompatibleDimensions(
-                            IncompatibleDimensionsError {
-                                span_operation: *span,
-                                operation: format!(
-                                    "argument {num} of function call to '{name}'",
-                                    num = idx + 1,
-                                    name = function_name
-                                ),
-                                span_expected: parameter_types[idx].0,
-                                expected_name: "parameter type",
-                                expected_dimensions: self
-                                    .registry
-                                    .get_derived_entry_names_for(&parameter_type),
-                                expected_type: parameter_type,
-                                span_actual: args[idx].full_span(),
-                                actual_name: " argument type",
-                                actual_dimensions: self
-                                    .registry
-                                    .get_derived_entry_names_for(&argument_type),
-                                actual_type: argument_type,
-                            },
-                        ));
+                            if parameter_type != argument_type {
+                                return Err(TypeCheckError::IncompatibleDimensions(
+                                    IncompatibleDimensionsError {
+                                        span_operation: *span,
+                                        operation: format!(
+                                            "argument {num} of function call to '{name}'",
+                                            num = idx + 1,
+                                            name = function_name
+                                        ),
+                                        span_expected: parameter_types[idx].0,
+                                        expected_name: "parameter type",
+                                        expected_dimensions: self
+                                            .registry
+                                            .get_derived_entry_names_for(&parameter_type),
+                                        expected_type: parameter_type,
+                                        span_actual: args[idx].full_span(),
+                                        actual_name: " argument type",
+                                        actual_dimensions: self
+                                            .registry
+                                            .get_derived_entry_names_for(&argument_type),
+                                        actual_type: argument_type,
+                                    },
+                                ));
+                            }
+                        }
+                        (parameter_type, argument_type) => {
+                            if parameter_type != &argument_type {
+                                todo!()
+                            }
+                        }
                     }
                 }
 
@@ -960,11 +961,8 @@ impl TypeChecker {
                 let mut is_variadic = false;
                 let mut free_type_parameters = vec![];
                 for (parameter_span, parameter, type_annotation, p_is_variadic) in parameters {
-                    let parameter_type = if let Some(type_) = type_annotation {
-                        typechecker_fn
-                            .registry
-                            .get_base_representation(type_)
-                            .map_err(TypeCheckError::RegistryError)?
+                    let parameter_type = if let Some(type_annotation) = type_annotation {
+                        typechecker_fn.type_from_annotation(type_annotation)?
                     } else if is_ffi_function {
                         return Err(TypeCheckError::ForeignFunctionNeedsTypeAnnotations(
                             *function_name_span,
@@ -986,21 +984,20 @@ impl TypeChecker {
                             .add_base_dimension(&free_type_parameter)
                             .expect("we selected a name that is free");
                         type_parameters.push((*parameter_span, free_type_parameter.clone()));
-                        typechecker_fn
-                            .registry
-                            .get_base_representation(&DimensionExpression::Dimension(
-                                *parameter_span,
-                                free_type_parameter,
-                            ))
-                            .map_err(TypeCheckError::RegistryError)?
+                        Type::Dimension(
+                            typechecker_fn
+                                .registry
+                                .get_base_representation(&DimensionExpression::Dimension(
+                                    *parameter_span,
+                                    free_type_parameter,
+                                ))
+                                .map_err(TypeCheckError::RegistryError)?,
+                        )
                     };
 
                     typechecker_fn.identifiers.insert(
                         parameter.clone(),
-                        (
-                            Type::Dimension(parameter_type.clone()),
-                            IdentifierKind::Variable,
-                        ),
+                        (parameter_type.clone(), IdentifierKind::Variable),
                     );
                     typed_parameters.push((
                         *parameter_span,
@@ -1010,7 +1007,7 @@ impl TypeChecker {
                             .as_ref()
                             .map(|d| d.pretty_print())
                             .unwrap_or_else(|| parameter_type.to_readable_type(&self.registry)),
-                        Type::Dimension(parameter_type),
+                        parameter_type,
                     ));
 
                     is_variadic |= p_is_variadic;

+ 8 - 0
numbat/src/value.rs

@@ -23,6 +23,14 @@ impl Value {
             panic!("Expected value to be a bool");
         }
     }
+
+    pub fn unsafe_as_string(&self) -> &str {
+        if let Value::String(s) = self {
+            s
+        } else {
+            panic!("Expected value to be a bool");
+        }
+    }
 }
 
 impl std::fmt::Display for Value {