|
@@ -940,6 +940,31 @@ impl TypeChecker {
|
|
})
|
|
})
|
|
.transpose()?;
|
|
.transpose()?;
|
|
|
|
|
|
|
|
+ let add_function_signature = |tc: &mut TypeChecker, return_type: DType| {
|
|
|
|
+ let parameter_types = typed_parameters
|
|
|
|
+ .iter()
|
|
|
|
+ .map(|(span, _, _, t)| (*span, t.clone()))
|
|
|
|
+ .collect();
|
|
|
|
+ tc.function_signatures.insert(
|
|
|
|
+ function_name.clone(),
|
|
|
|
+ (
|
|
|
|
+ *function_name_span,
|
|
|
|
+ type_parameters.clone(),
|
|
|
|
+ parameter_types,
|
|
|
|
+ is_variadic,
|
|
|
|
+ return_type,
|
|
|
|
+ ),
|
|
|
|
+ );
|
|
|
|
+ };
|
|
|
|
+
|
|
|
|
+ if let Some(ref return_type_specified) = return_type_specified {
|
|
|
|
+ // This is needed for recursive functions. If the return type
|
|
|
|
+ // has been specified, we can already provide a function
|
|
|
|
+ // signature before we check the body of the function. This
|
|
|
|
+ // way, the 'typechecker_fn' can resolve the recursive call.
|
|
|
|
+ add_function_signature(&mut typechecker_fn, return_type_specified.clone());
|
|
|
|
+ }
|
|
|
|
+
|
|
let body_checked = body
|
|
let body_checked = body
|
|
.clone()
|
|
.clone()
|
|
.map(|expr| typechecker_fn.check_expression(&expr))
|
|
.map(|expr| typechecker_fn.check_expression(&expr))
|
|
@@ -986,20 +1011,7 @@ impl TypeChecker {
|
|
})?
|
|
})?
|
|
};
|
|
};
|
|
|
|
|
|
- let parameter_types = typed_parameters
|
|
|
|
- .iter()
|
|
|
|
- .map(|(span, _, _, t)| (*span, t.clone()))
|
|
|
|
- .collect();
|
|
|
|
- self.function_signatures.insert(
|
|
|
|
- function_name.clone(),
|
|
|
|
- (
|
|
|
|
- *function_name_span,
|
|
|
|
- type_parameters.clone(),
|
|
|
|
- parameter_types,
|
|
|
|
- is_variadic,
|
|
|
|
- return_type.clone(),
|
|
|
|
- ),
|
|
|
|
- );
|
|
|
|
|
|
+ add_function_signature(self, return_type.clone());
|
|
|
|
|
|
typed_ast::Statement::DefineFunction(
|
|
typed_ast::Statement::DefineFunction(
|
|
function_name.clone(),
|
|
function_name.clone(),
|
|
@@ -1290,6 +1302,19 @@ mod tests {
|
|
));
|
|
));
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ #[test]
|
|
|
|
+ fn recursive_functions() {
|
|
|
|
+ assert_successful_typecheck("fn f(x: Scalar) -> Scalar = if x < 0 then f(-x) else x");
|
|
|
|
+ assert_successful_typecheck(
|
|
|
|
+ "fn factorial(n: Scalar) -> Scalar = if n < 0 then 1 else factorial(n - 1) * n",
|
|
|
|
+ );
|
|
|
|
+
|
|
|
|
+ assert!(matches!(
|
|
|
|
+ get_typecheck_error("fn f(x: Scalar) -> A = if x < 0 then f(-x) else 2 b"),
|
|
|
|
+ TypeCheckError::IncompatibleTypesInCondition(_, lhs, _, rhs, _) if lhs == Type::Dimension(type_a()) && rhs == Type::Dimension(type_b())
|
|
|
|
+ ));
|
|
|
|
+ }
|
|
|
|
+
|
|
#[test]
|
|
#[test]
|
|
fn generics_basic() {
|
|
fn generics_basic() {
|
|
assert_successful_typecheck(
|
|
assert_successful_typecheck(
|