Browse Source

Add better code-source handling

David Peter 2 years ago
parent
commit
ae2e50d6ff

+ 52 - 30
numbat-cli/src/main.rs

@@ -8,7 +8,7 @@ use highlighter::NumbatHighlighter;
 
 use numbat::markup;
 use numbat::pretty_print::PrettyPrint;
-use numbat::resolver::FileSystemImporter;
+use numbat::resolver::{CodeSource, FileSystemImporter};
 use numbat::{Context, ExitStatus, InterpreterResult, NumbatError, ParseError};
 
 use anyhow::{bail, Context as AnyhowContext, Result};
@@ -86,7 +86,6 @@ struct NumbatHelper {
 struct Cli {
     args: Args,
     context: Arc<Mutex<Context>>,
-    current_filename: Option<PathBuf>,
 }
 
 impl Cli {
@@ -102,7 +101,6 @@ impl Cli {
         Self {
             context: Arc::new(Mutex::new(context)),
             args,
-            current_filename: None,
         }
     }
 
@@ -114,12 +112,16 @@ impl Cli {
             let modules_path = Self::get_modules_path();
             let prelude_path = modules_path.join("prelude.nbt");
 
-            self.current_filename = Some(prelude_path.clone());
             let prelude_code = fs::read_to_string(&prelude_path).context(format!(
                 "Error while reading prelude from '{}'",
                 prelude_path.to_string_lossy()
             ))?;
-            let result = self.parse_and_evaluate(&prelude_code, ExecutionMode::Normal, false);
+            let result = self.parse_and_evaluate(
+                &prelude_code,
+                CodeSource::File(prelude_path),
+                ExecutionMode::Normal,
+                false,
+            );
             if result.is_break() {
                 bail!("Interpreter error in Prelude code")
             }
@@ -128,29 +130,39 @@ impl Cli {
         if load_init {
             let user_init_path = Self::get_config_path().join("init.nbt");
 
-            self.current_filename = Some(user_init_path.clone());
             if let Ok(user_init_code) = fs::read_to_string(&user_init_path) {
-                let result = self.parse_and_evaluate(&user_init_code, ExecutionMode::Normal, false);
+                let result = self.parse_and_evaluate(
+                    &user_init_code,
+                    CodeSource::File(user_init_path),
+                    ExecutionMode::Normal,
+                    false,
+                );
                 if result.is_break() {
                     bail!("Interpreter error in user initialization code")
                 }
             }
         }
 
-        let code: Option<String> = if let Some(ref path) = self.args.file {
-            self.current_filename = Some(path.clone());
-            Some(fs::read_to_string(path).context(format!(
-                "Could not load source file '{}'",
-                path.to_string_lossy()
-            ))?)
-        } else {
-            self.current_filename = None;
-            self.args.expression.clone()
-        };
+        let (code, code_source): (Option<String>, CodeSource) =
+            if let Some(ref path) = self.args.file {
+                (
+                    Some(fs::read_to_string(path).context(format!(
+                        "Could not load source file '{}'",
+                        path.to_string_lossy()
+                    ))?),
+                    CodeSource::File(path.clone()),
+                )
+            } else {
+                (self.args.expression.clone(), CodeSource::Text)
+            };
 
         if let Some(code) = code {
-            let result =
-                self.parse_and_evaluate(&code, ExecutionMode::Normal, self.args.pretty_print);
+            let result = self.parse_and_evaluate(
+                &code,
+                code_source,
+                ExecutionMode::Normal,
+                self.args.pretty_print,
+            );
 
             match result {
                 std::ops::ControlFlow::Continue(()) => Ok(()),
@@ -250,6 +262,7 @@ impl Cli {
                             _ => {
                                 let result = self.parse_and_evaluate(
                                     &line,
+                                    CodeSource::Text,
                                     ExecutionMode::Interactive,
                                     self.args.pretty_print,
                                 );
@@ -282,10 +295,11 @@ impl Cli {
     fn parse_and_evaluate(
         &mut self,
         input: &str,
+        code_source: CodeSource,
         execution_mode: ExecutionMode,
         pretty_print: bool,
     ) -> ControlFlow {
-        let result = { self.context.lock().unwrap().interpret(input) };
+        let result = { self.context.lock().unwrap().interpret(input, code_source) };
 
         match result {
             Ok((statements, interpreter_result)) => {
@@ -315,17 +329,25 @@ impl Cli {
                     InterpreterResult::Exit(exit_status) => ControlFlow::Break(exit_status),
                 }
             }
-            Err(NumbatError::ParseError(ref e @ ParseError { ref span, .. })) => {
-                let line = input.lines().nth(span.line - 1).unwrap();
-
-                let filename = self
-                    .current_filename
-                    .as_deref()
-                    .map(|p| p.to_string_lossy())
-                    .unwrap_or_else(|| "<input>".into());
+            Err(NumbatError::ParseError {
+                inner: ref e @ ParseError { ref span, .. },
+                code_source,
+            }) => {
+                let line = input.lines().nth(span.line - 1).unwrap(); // TODO
+
+                let code_source_text = match code_source {
+                    CodeSource::Text => "<input>".to_string(),
+                    CodeSource::File(path) => format!("File {}", path.to_string_lossy()),
+                    CodeSource::Module(module_path, path) => format!(
+                        "Module '{module_path}', File {path}",
+                        module_path = itertools::join(module_path.0.iter(), "::"),
+                        path = path
+                            .map(|p| p.to_string_lossy().to_string())
+                            .unwrap_or("?".into()),
+                    ),
+                };
                 eprintln!(
-                    "File {filename}:{line_number}:{position}",
-                    filename = filename,
+                    "{code_source_text}:{line_number}:{position}",
                     line_number = span.line,
                     position = span.position
                 );

+ 15 - 5
numbat/src/lib.rs

@@ -29,6 +29,7 @@ use bytecode_interpreter::BytecodeInterpreter;
 use interpreter::{Interpreter, RuntimeError};
 use name_resolution::NameResolutionError;
 use prefix_transformer::Transformer;
+use resolver::CodeSource;
 use resolver::ModuleImporter;
 use resolver::NullImporter;
 use resolver::Resolver;
@@ -43,8 +44,11 @@ pub use parser::ParseError;
 
 #[derive(Debug, Error)]
 pub enum NumbatError {
-    #[error("{0}")]
-    ParseError(ParseError),
+    #[error("{inner}")]
+    ParseError {
+        inner: ParseError,
+        code_source: CodeSource,
+    },
     #[error("{0}")]
     ResolverError(ResolverError),
     #[error("{0}")]
@@ -98,11 +102,17 @@ impl Context {
         &self.prefix_transformer.dimension_names
     }
 
-    pub fn interpret(&mut self, code: &str) -> Result<(Vec<Statement>, InterpreterResult)> {
+    pub fn interpret(
+        &mut self,
+        code: &str,
+        code_source: CodeSource,
+    ) -> Result<(Vec<Statement>, InterpreterResult)> {
         let resolver = Resolver::new(self.module_importer.as_ref());
 
-        let statements = resolver.resolve(code).map_err(|e| match e {
-            ResolverError::ParseError(e) => NumbatError::ParseError(e),
+        let statements = resolver.resolve(code, code_source).map_err(|e| match e {
+            ResolverError::ParseError { inner, code_source } => {
+                NumbatError::ParseError { inner, code_source }
+            }
             e => NumbatError::ResolverError(e),
         })?;
 

+ 35 - 17
numbat/src/resolver.rs

@@ -16,13 +16,28 @@ impl std::fmt::Display for ModulePath {
     }
 }
 
+#[derive(Debug, Clone)]
+pub enum CodeSource {
+    /// User input from the command line or a REPL
+    Text,
+
+    /// A file that has been read in
+    File(PathBuf),
+
+    /// A module that has been imported
+    Module(ModulePath, Option<PathBuf>),
+}
+
 #[derive(Error, Debug)]
 pub enum ResolverError {
     #[error("Unknown module '{0}'.")]
     UnknownModule(ModulePath),
 
-    #[error("{0}")]
-    ParseError(ParseError),
+    #[error("{inner}")]
+    ParseError {
+        inner: ParseError,
+        code_source: CodeSource,
+    },
 }
 
 type Result<T> = std::result::Result<T, ResolverError>;
@@ -36,8 +51,8 @@ impl<'a> Resolver<'a> {
         Self { importer }
     }
 
-    fn parse(&self, code: &str) -> Result<Vec<Statement>> {
-        parse(code).map_err(ResolverError::ParseError)
+    fn parse(&self, code: &str, code_source: CodeSource) -> Result<Vec<Statement>> {
+        parse(code).map_err(|inner| ResolverError::ParseError { inner, code_source })
     }
 
     fn inlining_pass(&self, program: &[Statement]) -> Result<(Vec<Statement>, bool)> {
@@ -47,8 +62,11 @@ impl<'a> Resolver<'a> {
         for statement in program {
             match statement {
                 Statement::ModuleImport(module_path) => {
-                    if let Some(code) = self.importer.import(module_path) {
-                        for statement in parse(&code).map_err(ResolverError::ParseError)? {
+                    if let Some((code, filesystem_path)) = self.importer.import(module_path) {
+                        for statement in self.parse(
+                            &code,
+                            CodeSource::Module(module_path.clone(), filesystem_path),
+                        )? {
                             new_program.push(statement);
                         }
                         performed_imports = true;
@@ -63,10 +81,10 @@ impl<'a> Resolver<'a> {
         Ok((new_program, performed_imports))
     }
 
-    pub fn resolve(&self, code: &str) -> Result<Vec<Statement>> {
+    pub fn resolve(&self, code: &str, code_source: CodeSource) -> Result<Vec<Statement>> {
         // TODO: handle cyclic dependencies & infinite loops
 
-        let mut statements = self.parse(code)?;
+        let mut statements = self.parse(code, code_source)?;
 
         loop {
             let result = self.inlining_pass(&statements)?;
@@ -79,7 +97,7 @@ impl<'a> Resolver<'a> {
 }
 
 pub trait ModuleImporter {
-    fn import(&self, path: &ModulePath) -> Option<String>;
+    fn import(&self, path: &ModulePath) -> Option<(String, Option<PathBuf>)>;
 }
 
 pub struct NullImporter {}
@@ -91,7 +109,7 @@ impl NullImporter {
 }
 
 impl ModuleImporter for NullImporter {
-    fn import(&self, _: &ModulePath) -> Option<String> {
+    fn import(&self, _: &ModulePath) -> Option<(String, Option<PathBuf>)> {
         None
     }
 }
@@ -111,7 +129,7 @@ impl FileSystemImporter {
 }
 
 impl ModuleImporter for FileSystemImporter {
-    fn import(&self, module_path: &ModulePath) -> Option<String> {
+    fn import(&self, module_path: &ModulePath) -> Option<(String, Option<PathBuf>)> {
         for path in &self.root_paths {
             let mut path = path.clone();
             for part in &module_path.0 {
@@ -120,8 +138,8 @@ impl ModuleImporter for FileSystemImporter {
 
             path.set_extension("nbt");
 
-            if let Ok(code) = fs::read_to_string(path) {
-                return Some(code);
+            if let Ok(code) = fs::read_to_string(&path) {
+                return Some((code, Some(path.to_owned())));
             }
         }
 
@@ -141,10 +159,10 @@ mod tests {
     struct TestImporter {}
 
     impl ModuleImporter for TestImporter {
-        fn import(&self, path: &ModulePath) -> Option<String> {
+        fn import(&self, path: &ModulePath) -> Option<(String, Option<PathBuf>)> {
             match path {
-                ModulePath(p) if p == &["foo", "bar"] => Some("use foo::baz".into()),
-                ModulePath(p) if p == &["foo", "baz"] => Some("let a = 1".into()),
+                ModulePath(p) if p == &["foo", "bar"] => Some(("use foo::baz".into(), None)),
+                ModulePath(p) if p == &["foo", "baz"] => Some(("let a = 1".into(), None)),
                 _ => None,
             }
         }
@@ -160,7 +178,7 @@ mod tests {
         let importer = TestImporter {};
 
         let resolver = Resolver::new(&importer);
-        let program_inlined = resolver.resolve(program).unwrap();
+        let program_inlined = resolver.resolve(program, CodeSource::Text).unwrap();
 
         assert_eq!(
             &program_inlined,

+ 8 - 2
numbat/tests/common.rs

@@ -1,6 +1,9 @@
 use std::path::Path;
 
-use numbat::{resolver::FileSystemImporter, Context};
+use numbat::{
+    resolver::{CodeSource, FileSystemImporter},
+    Context,
+};
 
 pub fn get_test_context() -> Context {
     let module_path = Path::new("../modules");
@@ -11,7 +14,10 @@ pub fn get_test_context() -> Context {
     let mut context = Context::new(importer);
 
     assert!(context
-        .interpret(&std::fs::read_to_string(module_path.join("prelude.nbt")).unwrap())
+        .interpret(
+            &std::fs::read_to_string(module_path.join("prelude.nbt")).unwrap(),
+            CodeSource::Text
+        )
         .expect("Error while running prelude")
         .1
         .is_success());

+ 6 - 5
numbat/tests/interpreter.rs

@@ -3,10 +3,11 @@ mod common;
 use common::get_test_context;
 
 use numbat::markup::{Formatter, PlainTextFormatter};
+use numbat::resolver::CodeSource;
 use numbat::{pretty_print::PrettyPrint, Context, InterpreterResult};
 
 fn expect_output_with_context(ctx: &mut Context, code: &str, expected_output: &str) {
-    if let InterpreterResult::Quantity(q) = ctx.interpret(code).unwrap().1 {
+    if let InterpreterResult::Quantity(q) = ctx.interpret(code, CodeSource::Text).unwrap().1 {
         let fmt = PlainTextFormatter {};
 
         let actual_output = fmt.format(&q.pretty_print(), false);
@@ -23,7 +24,7 @@ fn expect_output(code: &str, expected_output: &str) {
 
 fn expect_failure(code: &str, msg_part: &str) {
     let mut ctx = get_test_context();
-    if let Err(e) = ctx.interpret(code) {
+    if let Err(e) = ctx.interpret(code, CodeSource::Text) {
         let error_message = e.to_string();
         assert!(error_message.contains(msg_part));
     } else {
@@ -62,7 +63,7 @@ fn test_conversions() {
 fn test_implicit_conversion() {
     let mut ctx = get_test_context();
 
-    let _ = ctx.interpret("let x = 5 m").unwrap();
+    let _ = ctx.interpret("let x = 5 m", CodeSource::Text).unwrap();
 
     expect_output_with_context(&mut ctx, "x", "5 m");
     expect_output_with_context(&mut ctx, "2x", "10 m");
@@ -152,10 +153,10 @@ fn test_other_functions() {
 fn test_last_result_identifier() {
     let mut ctx = get_test_context();
 
-    let _ = ctx.interpret("2 + 3").unwrap();
+    let _ = ctx.interpret("2 + 3", CodeSource::Text).unwrap();
     expect_output_with_context(&mut ctx, "ans", "5");
 
-    let _ = ctx.interpret("1 + 2").unwrap();
+    let _ = ctx.interpret("1 + 2", CodeSource::Text).unwrap();
     expect_output_with_context(&mut ctx, "_", "3");
 }
 

+ 7 - 6
numbat/tests/prelude_and_examples.rs

@@ -2,13 +2,14 @@ mod common;
 
 use common::get_test_context;
 
+use numbat::resolver::CodeSource;
 use numbat::{InterpreterResult, NumbatError};
 
 use std::ffi::OsStr;
 use std::fs;
 
 fn assert_typechecks_and_runs(code: &str) {
-    let result = get_test_context().interpret(code);
+    let result = get_test_context().interpret(code, CodeSource::Text);
     assert!(result.is_ok());
     assert!(matches!(
         result.unwrap().1,
@@ -18,28 +19,28 @@ fn assert_typechecks_and_runs(code: &str) {
 
 fn assert_parse_error(code: &str) {
     assert!(matches!(
-        get_test_context().interpret(code),
-        Err(NumbatError::ParseError(_))
+        get_test_context().interpret(code, CodeSource::Text),
+        Err(NumbatError::ParseError { .. })
     ));
 }
 
 fn assert_name_resolution_error(code: &str) {
     assert!(matches!(
-        get_test_context().interpret(code),
+        get_test_context().interpret(code, CodeSource::Text),
         Err(NumbatError::NameResolutionError(_))
     ));
 }
 
 fn assert_typecheck_error(code: &str) {
     assert!(matches!(
-        get_test_context().interpret(code),
+        get_test_context().interpret(code, CodeSource::Text),
         Err(NumbatError::TypeCheckError(_))
     ));
 }
 
 fn assert_runtime_error(code: &str) {
     assert!(matches!(
-        get_test_context().interpret(code),
+        get_test_context().interpret(code, CodeSource::Text),
         Err(NumbatError::RuntimeError(_))
     ));
 }