Jelajahi Sumber

Support for recursive functions

Turing completeness!

closes #32
David Peter 2 tahun lalu
induk
melakukan
b75bdb9c83
2 mengubah file dengan 49 tambahan dan 14 penghapusan
  1. 10 0
      examples/factorial.nbt
  2. 39 14
      numbat/src/typechecker.rs

+ 10 - 0
examples/factorial.nbt

@@ -0,0 +1,10 @@
+# Naive factorial implementation to showcase recursive
+# functions and conditionals.
+
+fn factorial(n: Scalar) -> Scalar =
+  if n < 1
+    then 1
+    else n × factorial(n - 1)
+
+# Compare result with the builtin factorial operator
+assert_eq(factorial(10), 10!)

+ 39 - 14
numbat/src/typechecker.rs

@@ -940,6 +940,31 @@ impl TypeChecker {
                     })
                     .transpose()?;
 
+                let add_function_signature = |tc: &mut TypeChecker, return_type: DType| {
+                    let parameter_types = typed_parameters
+                        .iter()
+                        .map(|(span, _, _, t)| (*span, t.clone()))
+                        .collect();
+                    tc.function_signatures.insert(
+                        function_name.clone(),
+                        (
+                            *function_name_span,
+                            type_parameters.clone(),
+                            parameter_types,
+                            is_variadic,
+                            return_type,
+                        ),
+                    );
+                };
+
+                if let Some(ref return_type_specified) = return_type_specified {
+                    // This is needed for recursive functions. If the return type
+                    // has been specified, we can already provide a function
+                    // signature before we check the body of the function. This
+                    // way, the 'typechecker_fn' can resolve the recursive call.
+                    add_function_signature(&mut typechecker_fn, return_type_specified.clone());
+                }
+
                 let body_checked = body
                     .clone()
                     .map(|expr| typechecker_fn.check_expression(&expr))
@@ -986,20 +1011,7 @@ impl TypeChecker {
                     })?
                 };
 
-                let parameter_types = typed_parameters
-                    .iter()
-                    .map(|(span, _, _, t)| (*span, t.clone()))
-                    .collect();
-                self.function_signatures.insert(
-                    function_name.clone(),
-                    (
-                        *function_name_span,
-                        type_parameters.clone(),
-                        parameter_types,
-                        is_variadic,
-                        return_type.clone(),
-                    ),
-                );
+                add_function_signature(self, return_type.clone());
 
                 typed_ast::Statement::DefineFunction(
                     function_name.clone(),
@@ -1290,6 +1302,19 @@ mod tests {
         ));
     }
 
+    #[test]
+    fn recursive_functions() {
+        assert_successful_typecheck("fn f(x: Scalar) -> Scalar = if x < 0 then f(-x) else x");
+        assert_successful_typecheck(
+            "fn factorial(n: Scalar) -> Scalar = if n < 0 then 1 else factorial(n - 1) * n",
+        );
+
+        assert!(matches!(
+            get_typecheck_error("fn f(x: Scalar) -> A = if x < 0 then f(-x) else 2 b"),
+            TypeCheckError::IncompatibleTypesInCondition(_, lhs, _, rhs, _) if lhs == Type::Dimension(type_a()) && rhs == Type::Dimension(type_b())
+        ));
+    }
+
     #[test]
     fn generics_basic() {
         assert_successful_typecheck(