소스 검색

Allow non-annotated parameter types, see #29

David Peter 2 년 전
부모
커밋
53aa753949

+ 1 - 0
examples/typecheck_error/foreign_function_without_parameter_type.nbt

@@ -0,0 +1 @@
+fn f(x: Scalar, y) -> Scalar

+ 1 - 0
examples/typecheck_error/foreign_function_without_return_type.nbt

@@ -0,0 +1 @@
+fn f(x: Scalar)

+ 0 - 1
examples/typecheck_error/parameter_types_can_not_be_deduced.nbt

@@ -1 +0,0 @@
-fn f(x, y) -> Scalar = x * y

+ 3 - 7
numbat/src/diagnostic.rs

@@ -200,17 +200,13 @@ impl ErrorDiagnostic for TypeCheckError {
                         .diagnostic_label(LabelStyle::Primary)
                         .with_message(inner_error),
                 ]),
-            TypeCheckError::ForeignFunctionNeedsReturnTypeAnnotation(span, _) => {
-                d.with_labels(vec![span
+            TypeCheckError::ForeignFunctionNeedsTypeAnnotations(span, _) => d
+                .with_labels(vec![span
                     .diagnostic_label(LabelStyle::Primary)
-                    .with_message(inner_error)])
-            }
+                    .with_message(inner_error)]),
             TypeCheckError::UnknownForeignFunction(span, _) => d.with_labels(vec![span
                 .diagnostic_label(LabelStyle::Primary)
                 .with_message(inner_error)]),
-            TypeCheckError::ParameterTypesCanNotBeDeduced(span) => d.with_labels(vec![span
-                .diagnostic_label(LabelStyle::Primary)
-                .with_message(inner_error)]),
         }
     }
 }

+ 9 - 0
numbat/src/parser.rs

@@ -389,6 +389,12 @@ impl<'a> Parser<'a> {
             }
         } else if self.match_exact(TokenKind::Dimension).is_some() {
             if let Some(identifier) = self.match_exact(TokenKind::Identifier) {
+                if identifier.lexeme.starts_with("__") {
+                    todo!(
+                        "Parse error: double-underscore type names are reserved for internal use"
+                    );
+                }
+
                 if self.match_exact(TokenKind::Equal).is_some() {
                     self.skip_empty_lines();
                     let mut dexprs = vec![self.dimension_expression()?];
@@ -949,6 +955,9 @@ impl<'a> Parser<'a> {
             self.peek().span,
         ));
         if let Some(token) = self.match_exact(TokenKind::Identifier) {
+            if token.lexeme.starts_with("__") {
+                todo!("Parse error: double-underscore type names are reserved for internal use");
+            }
             let span = self.last().unwrap().span;
             Ok(DimensionExpression::Dimension(span, token.lexeme.clone()))
         } else if let Some(number) = self.match_exact(TokenKind::Number) {

+ 46 - 20
numbat/src/typechecker.rs

@@ -9,6 +9,7 @@ use crate::span::Span;
 use crate::typed_ast::{self, Type};
 use crate::{ast, decorator, ffi};
 
+use ast::DimensionExpression;
 use num_traits::{FromPrimitive, Zero};
 use thiserror::Error;
 
@@ -65,11 +66,8 @@ pub enum TypeCheckError {
     #[error("Multiple unresolved generic parameters in a single function parameter type are not (yet) supported. Consider reordering the function parameters")]
     MultipleUnresolvedTypeParameters(Span, Span),
 
-    #[error("Parameter types can not (yet) be deduced, they have to be specified manually: f(x: Length, y: Time) -> …")]
-    ParameterTypesCanNotBeDeduced(Span),
-
-    #[error("Foreign function definition (without body) '{1}' needs a return type annotation.")]
-    ForeignFunctionNeedsReturnTypeAnnotation(Span, String),
+    #[error("Foreign function definition (without body) '{1}' needs parameter and return type annotations.")]
+    ForeignFunctionNeedsTypeAnnotations(Span, String),
 
     #[error("Unknown foreign function (without body) '{1}'")]
     UnknownForeignFunction(Span, String),
@@ -534,9 +532,11 @@ impl TypeChecker {
                 return_type_annotation,
             } => {
                 let mut typechecker_fn = self.clone();
+                let is_ffi_function = body.is_none();
+                let mut type_parameters = type_parameters.clone();
 
-                for (span, type_parameter) in type_parameters {
-                    match typechecker_fn.registry.add_base_dimension(type_parameter) {
+                for (span, type_parameter) in &type_parameters {
+                    match typechecker_fn.registry.add_base_dimension(&type_parameter) {
                         Err(RegistryError::EntryExists(name)) => {
                             return Err(TypeCheckError::TypeParameterNameClash(*span, name))
                         }
@@ -547,15 +547,37 @@ impl TypeChecker {
 
                 let mut typed_parameters = vec![];
                 let mut is_variadic = false;
-                for (parameter_span, parameter, optional_dexpr, p_is_variadic) in parameters {
-                    let parameter_type = typechecker_fn
-                        .registry
-                        .get_base_representation(&optional_dexpr.clone().ok_or(
-                            TypeCheckError::ParameterTypesCanNotBeDeduced(*parameter_span),
-                        )?)
-                        // TODO: add type inference, see https://github.com/sharkdp/numbat/issues/29
-                        // TODO: once we add type inference, make sure that annotations are required for foreign functions
-                        .map_err(TypeCheckError::RegistryError)?;
+                let mut free_type_parameters = vec![];
+                for (parameter_span, parameter, type_annotation, p_is_variadic) in parameters {
+                    let parameter_type = if let Some(type_) = type_annotation {
+                        typechecker_fn
+                            .registry
+                            .get_base_representation(&type_)
+                            .map_err(TypeCheckError::RegistryError)?
+                    } else if is_ffi_function {
+                        return Err(TypeCheckError::ForeignFunctionNeedsTypeAnnotations(
+                            *function_name_span,
+                            function_name.clone(),
+                        ));
+                    } else {
+                        let free_type_parameter =
+                            format!("__T{num}", num = free_type_parameters.len());
+                        free_type_parameters.push((parameter.clone(), free_type_parameter.clone()));
+
+                        typechecker_fn
+                            .registry
+                            .add_base_dimension(&free_type_parameter)
+                            .expect("double-underscore identifiers are only used internally");
+                        type_parameters.push((parameter_span.clone(), free_type_parameter.clone()));
+                        typechecker_fn
+                            .registry
+                            .get_base_representation(&DimensionExpression::Dimension(
+                                parameter_span.clone(),
+                                free_type_parameter,
+                            ))
+                            .map_err(TypeCheckError::RegistryError)?
+                    };
+
                     typechecker_fn
                         .identifiers
                         .insert(parameter.clone(), parameter_type.clone());
@@ -569,12 +591,16 @@ impl TypeChecker {
                     is_variadic |= p_is_variadic;
                 }
 
+                if free_type_parameters.len() > 0 {
+                    // TODO: Perform type inference
+                }
+
                 let return_type_specified = return_type_annotation
                     .clone()
-                    .map(|ref return_type_dexpr| {
+                    .map(|ref annotation| {
                         typechecker_fn
                             .registry
-                            .get_base_representation(return_type_dexpr)
+                            .get_base_representation(annotation)
                             .map_err(TypeCheckError::RegistryError)
                     })
                     .transpose()?;
@@ -610,7 +636,7 @@ impl TypeChecker {
                     }
 
                     return_type_specified.ok_or_else(|| {
-                        TypeCheckError::ForeignFunctionNeedsReturnTypeAnnotation(
+                        TypeCheckError::ForeignFunctionNeedsTypeAnnotations(
                             *function_name_span,
                             function_name.clone(),
                         )
@@ -1041,7 +1067,7 @@ mod tests {
     fn foreign_function_with_missing_return_type() {
         assert!(matches!(
             get_typecheck_error("fn sin(x: Scalar)"),
-            TypeCheckError::ForeignFunctionNeedsReturnTypeAnnotation(_, name) if name == "sin"
+            TypeCheckError::ForeignFunctionNeedsTypeAnnotations(_, name) if name == "sin"
         ));
     }