Forráskód Böngészése

Add parsing of type parameter bounds

David Peter 1 éve
szülő
commit
ccf17e1081

+ 7 - 2
numbat/src/ast.rs

@@ -386,6 +386,11 @@ pub enum ProcedureKind {
     Type,
 }
 
+#[derive(Debug, Clone, PartialEq)]
+pub enum TypeParameterBound {
+    Dim,
+}
+
 #[derive(Debug, Clone, PartialEq)]
 pub enum Statement {
     Expression(Expression),
@@ -399,7 +404,7 @@ pub enum Statement {
     DefineFunction {
         function_name_span: Span,
         function_name: String,
-        type_parameters: Vec<(Span, String)>,
+        type_parameters: Vec<(Span, String, Option<TypeParameterBound>)>,
         /// Parameters, optionally with type annotations.
         parameters: Vec<(Span, String, Option<TypeAnnotation>)>,
         /// Function body. If it is absent, the function is implemented via FFI
@@ -598,7 +603,7 @@ impl ReplaceSpans for Statement {
                 function_name: function_name.clone(),
                 type_parameters: type_parameters
                     .iter()
-                    .map(|(_, name)| (Span::dummy(), name.clone()))
+                    .map(|(_, name, bound)| (Span::dummy(), name.clone(), bound.clone()))
                     .collect(),
                 parameters: parameters
                     .iter()

+ 57 - 3
numbat/src/parser.rs

@@ -64,7 +64,7 @@
 use crate::arithmetic::{Exponent, Rational};
 use crate::ast::{
     BinaryOperator, Expression, ProcedureKind, Statement, StringPart, TypeAnnotation,
-    TypeExpression, UnaryOperator,
+    TypeExpression, TypeParameterBound, UnaryOperator,
 };
 use crate::decorator::{self, Decorator};
 use crate::number::Number;
@@ -220,6 +220,12 @@ pub enum ParseErrorKind {
 
     #[error("Expected ',' or ']' in list expression")]
     ExpectedCommaOrRightBracketInList,
+
+    #[error("Unknown bound '{0}' in type parameter definition")]
+    UnknownBound(String),
+
+    #[error("Expected bound in type parameter definition")]
+    ExpectedBoundInTypeParameterDefinition,
 }
 
 #[derive(Debug, Clone, Error)]
@@ -431,8 +437,36 @@ impl<'a> Parser<'a> {
                 if self.match_exact(TokenKind::LessThan).is_some() {
                     while self.match_exact(TokenKind::GreaterThan).is_none() {
                         if let Some(type_parameter_name) = self.match_exact(TokenKind::Identifier) {
+                            let bound = if self.match_exact(TokenKind::Colon).is_some() {
+                                match self.match_exact(TokenKind::Identifier) {
+                                    Some(token) if token.lexeme == "Dim" => {
+                                        Some(TypeParameterBound::Dim)
+                                    }
+                                    Some(token) => {
+                                        return Err(ParseError {
+                                            kind: ParseErrorKind::UnknownBound(
+                                                token.lexeme.clone(),
+                                            ),
+                                            span: token.span,
+                                        });
+                                    }
+                                    None => {
+                                        return Err(ParseError {
+                                            kind: ParseErrorKind::ExpectedBoundInTypeParameterDefinition,
+                                            span: self.peek().span,
+                                        });
+                                    }
+                                }
+                            } else {
+                                None
+                            };
+
                             let span = self.last().unwrap().span;
-                            type_parameters.push((span, type_parameter_name.lexeme.to_string()));
+                            type_parameters.push((
+                                span,
+                                type_parameter_name.lexeme.to_string(),
+                                bound,
+                            ));
 
                             if self.match_exact(TokenKind::Comma).is_none()
                                 && self.peek().kind != TokenKind::GreaterThan
@@ -2418,7 +2452,27 @@ mod tests {
             Statement::DefineFunction {
                 function_name_span: Span::dummy(),
                 function_name: "foo".into(),
-                type_parameters: vec![(Span::dummy(), "X".into())],
+                type_parameters: vec![(Span::dummy(), "X".into(), None)],
+                parameters: vec![(
+                    Span::dummy(),
+                    "x".into(),
+                    Some(TypeAnnotation::TypeExpression(
+                        TypeExpression::TypeIdentifier(Span::dummy(), "X".into()),
+                    )),
+                )],
+                body: Some(scalar!(1.0)),
+                return_type_annotation_span: None,
+                return_type_annotation: None,
+                decorators: vec![],
+            },
+        );
+
+        parse_as(
+            &["fn foo<X: Dim>(x: X) = 1"],
+            Statement::DefineFunction {
+                function_name_span: Span::dummy(),
+                function_name: "foo".into(),
+                type_parameters: vec![(Span::dummy(), "X".into(), Some(TypeParameterBound::Dim))],
                 parameters: vec![(
                     Span::dummy(),
                     "x".into(),

+ 2 - 1
numbat/src/typechecker/environment.rs

@@ -1,3 +1,4 @@
+use crate::ast::TypeParameterBound;
 use crate::span::Span;
 use crate::type_variable::TypeVariable;
 use crate::Type;
@@ -12,7 +13,7 @@ type Identifier = String; // TODO ?
 #[derive(Clone, Debug)]
 pub struct FunctionSignature {
     pub definition_span: Span,
-    pub type_parameters: Vec<(Span, String)>,
+    pub type_parameters: Vec<(Span, String, Option<TypeParameterBound>)>,
     pub parameters: Vec<(Span, String)>,
     pub fn_type: TypeScheme,
 }

+ 2 - 2
numbat/src/typechecker/mod.rs

@@ -1269,7 +1269,7 @@ impl TypeChecker {
                 let mut typechecker_fn = self.clone(); // TODO: is this even needed?
                 let is_ffi_function = body.is_none();
 
-                for (span, type_parameter) in type_parameters {
+                for (span, type_parameter, _bound) in type_parameters {
                     if typechecker_fn.type_namespace.has_identifier(type_parameter) {
                         return Err(TypeCheckError::TypeParameterNameClash(
                             *span,
@@ -1467,7 +1467,7 @@ impl TypeChecker {
                     decorators.clone(),
                     type_parameters
                         .iter()
-                        .map(|(_, name)| name.clone())
+                        .map(|(_, name, bound)| (name.clone(), bound.clone()))
                         .collect(),
                     typed_parameters
                         .iter()

+ 3 - 3
numbat/src/typed_ast.rs

@@ -2,8 +2,8 @@ use indexmap::IndexMap;
 use itertools::Itertools;
 
 use crate::arithmetic::Exponent;
-use crate::ast::ProcedureKind;
 pub use crate::ast::{BinaryOperator, TypeExpression, UnaryOperator};
+use crate::ast::{ProcedureKind, TypeParameterBound};
 use crate::dimension::DimensionRegistry;
 use crate::traversal::ForAllTypeSchemes;
 use crate::type_variable::TypeVariable;
@@ -556,8 +556,8 @@ pub enum Statement {
     DefineVariable(String, Vec<Decorator>, Expression, TypeScheme),
     DefineFunction(
         String,
-        Vec<Decorator>, // decorators
-        Vec<String>,    // type parameters
+        Vec<Decorator>,                            // decorators
+        Vec<(String, Option<TypeParameterBound>)>, // type parameters
         Vec<(
             // parameters:
             Span,   // span of the parameter