Browse Source

Added consuming `DType::into_factors`, removing need for mutable access to factors
Changed `from_factors` to take an owned `Vec`

Robert Bennett 1 year ago
parent
commit
3a25ac736e
3 changed files with 22 additions and 30 deletions
  1. 1 4
      numbat/src/typechecker/constraints.rs
  2. 6 5
      numbat/src/typechecker/mod.rs
  3. 15 21
      numbat/src/typed_ast.rs

+ 1 - 4
numbat/src/typechecker/constraints.rs

@@ -297,10 +297,7 @@ impl Constraint {
             Constraint::EqualScalar(dtype) => match dtype.split_first_factor() {
                 Some(((DTypeFactor::TVar(tv), k), rest)) => {
                     let result = DType::from_factors(
-                        &rest
-                            .iter()
-                            .map(|(f, j)| (f.clone(), -j / k))
-                            .collect::<Vec<_>>(),
+                        rest.iter().map(|(f, j)| (f.clone(), -j / k)).collect(),
                     );
                     Some(Satisfied::with_substitution(Substitution::single(
                         tv.clone(),

+ 6 - 5
numbat/src/typechecker/mod.rs

@@ -118,14 +118,15 @@ impl TypeChecker {
                     }
                 }
 
-                let mut dtype: DType = self
+                let mut factors = self
                     .registry
                     .get_base_representation(dexpr)
-                    .map(|br| br.into())
-                    .map_err(TypeCheckError::RegistryError)?;
+                    .map(DType::from)
+                    .map_err(TypeCheckError::RegistryError)?
+                    .into_factors();
 
                 // Replace BaseDimension("D") with TVar("D") for all type parameters
-                for (factor, _) in dtype.factors_mut() {
+                for (factor, _) in &mut factors {
                     *factor = match factor {
                         DTypeFactor::BaseDimension(ref n)
                             if self
@@ -140,7 +141,7 @@ impl TypeChecker {
                     }
                 }
 
-                Ok(Type::Dimension(dtype))
+                Ok(Type::Dimension(DType::from_factors(factors)))
             }
             TypeAnnotation::Bool(_) => Ok(Type::Boolean),
             TypeAnnotation::String(_) => Ok(Type::String),

+ 15 - 21
numbat/src/typed_ast.rs

@@ -45,28 +45,22 @@ pub struct DType {
 }
 
 impl DType {
-    pub fn new(factors: Vec<DtypeFactorPower>) -> Self {
-        Self { factors }
-    }
-
     pub fn factors(&self) -> &[DtypeFactorPower] {
         &self.factors
     }
 
-    pub fn factors_mut(&mut self) -> &mut [DtypeFactorPower] {
-        &mut self.factors
+    pub fn into_factors(self) -> Vec<DtypeFactorPower> {
+        self.factors
     }
 
-    pub fn from_factors(factors: &[DtypeFactorPower]) -> DType {
-        let mut dtype = DType {
-            factors: factors.into(),
-        };
+    pub fn from_factors(factors: Vec<DtypeFactorPower>) -> DType {
+        let mut dtype = DType { factors };
         dtype.canonicalize();
         dtype
     }
 
     pub fn scalar() -> DType {
-        DType::from_factors(&[])
+        DType::from_factors(vec![])
     }
 
     pub fn is_scalar(&self) -> bool {
@@ -105,11 +99,11 @@ impl DType {
     }
 
     pub fn from_type_variable(v: TypeVariable) -> DType {
-        DType::from_factors(&[(DTypeFactor::TVar(v), Exponent::from_integer(1))])
+        DType::from_factors(vec![(DTypeFactor::TVar(v), Exponent::from_integer(1))])
     }
 
     pub fn from_type_parameter(name: String) -> DType {
-        DType::from_factors(&[(DTypeFactor::TPar(name), Exponent::from_integer(1))])
+        DType::from_factors(vec![(DTypeFactor::TPar(name), Exponent::from_integer(1))])
     }
 
     pub fn deconstruct_as_single_type_variable(&self) -> Option<TypeVariable> {
@@ -122,14 +116,14 @@ impl DType {
     }
 
     pub fn from_tgen(i: usize) -> DType {
-        DType::from_factors(&[(
+        DType::from_factors(vec![(
             DTypeFactor::TVar(TypeVariable::Quantified(i)),
             Exponent::from_integer(1),
         )])
     }
 
     pub fn base_dimension(name: &str) -> DType {
-        DType::from_factors(&[(
+        DType::from_factors(vec![(
             DTypeFactor::BaseDimension(name.into()),
             Exponent::from_integer(1),
         )])
@@ -170,16 +164,16 @@ impl DType {
     pub fn multiply(&self, other: &DType) -> DType {
         let mut factors = self.factors.clone();
         factors.extend(other.factors.clone());
-        DType::from_factors(&factors)
+        DType::from_factors(factors)
     }
 
     pub fn power(&self, n: Exponent) -> DType {
-        let factors: Vec<_> = self
+        let factors = self
             .factors
             .iter()
             .map(|(f, m)| (f.clone(), n * m))
             .collect();
-        DType::from_factors(&factors)
+        DType::from_factors(factors)
     }
 
     pub fn inverse(&self) -> DType {
@@ -233,7 +227,7 @@ impl DType {
                 }
             }
         }
-        Self::from_factors(&factors)
+        Self::from_factors(factors)
     }
 
     pub fn to_base_representation(&self) -> BaseRepresentation {
@@ -272,11 +266,11 @@ impl std::fmt::Display for DType {
 
 impl From<BaseRepresentation> for DType {
     fn from(base_representation: BaseRepresentation) -> Self {
-        let factors: Vec<_> = base_representation
+        let factors = base_representation
             .into_iter()
             .map(|BaseRepresentationFactor(name, exp)| (DTypeFactor::BaseDimension(name), exp))
             .collect();
-        DType::from_factors(&factors)
+        DType::from_factors(factors)
     }
 }