diff --git a/rulog_vm/src/environment.rs b/rulog_vm/src/environment.rs index 0110360..e143478 100644 --- a/rulog_vm/src/environment.rs +++ b/rulog_vm/src/environment.rs @@ -1,18 +1,12 @@ +use rulog_core::types::ast::Term; use std::collections::HashMap; -use rulog_core::types::ast::Term; -#[derive(Debug, PartialEq, Eq, Clone)] +#[derive(Debug, Clone, Default, PartialEq, Eq)] pub struct Environment { pub bindings: HashMap, } impl Environment { - pub fn new() -> Self { - Environment { - bindings: HashMap::new(), - } - } - pub fn bind(&mut self, var: String, term: Term) { self.bindings.insert(var, term); } @@ -20,11 +14,16 @@ impl Environment { pub fn lookup(&self, var: &String) -> Option<&Term> { self.bindings.get(var) } + + pub fn extend(mut self, var: String, term: Term) -> Self { + self.bind(var, term); + self + } } impl FromIterator<(std::string::String, Term)> for Environment { fn from_iter>(iter: T) -> Self { - let mut env = Environment::new(); + let mut env = Environment::default(); for (var, term) in iter { env.bind(var, term); } diff --git a/rulog_vm/src/interpreter.rs b/rulog_vm/src/interpreter.rs index 8f56f0f..6cb8acb 100644 --- a/rulog_vm/src/interpreter.rs +++ b/rulog_vm/src/interpreter.rs @@ -9,43 +9,32 @@ use crate::{ resolver::{QuerySolution, QuerySolver}, types::InterpretingError, }; +pub trait SolutionHandler { + fn handle_solution(&self, solution: Option<&QuerySolution>) -> bool; +} +#[derive(Default)] pub struct Interpreter { clauses: Vec<(Predicate, Vec)>, operator_definitions: HashMap, - - on_solution: Option bool>>, } impl Interpreter { - pub fn new() -> Self { - Interpreter { - clauses: Vec::new(), - operator_definitions: HashMap::new(), - on_solution: None, - } - } - - pub fn on_solution(&mut self, f: F) - where - F: Fn(&QuerySolution) -> bool + 'static, - { - self.on_solution = Some(Box::new(f)); - } - - pub fn eval(&mut self, input: &str) -> Result<(), InterpretingError> { + pub fn eval( + &mut self, + input: &str, + handler: Option<&dyn SolutionHandler>, + ) -> Result<(), InterpretingError> { let program = parse(input).map_err(InterpretingError::ParseError)?; for clause in program.0 { let ret = match clause { Clause::Directive(directive) => self.handle_directive(directive), - Clause::Query(query) => self.handle_query(query), + Clause::Query(query) => self.handle_query(query, handler), Clause::Fact(fact) => self.handle_fact(fact), Clause::Rule(rule_head, rule_body) => self.handle_rule(rule_head, rule_body), }; - if let Err(e) = ret { - return Err(e); - } + ret? } Ok(()) @@ -65,21 +54,27 @@ impl Interpreter { Ok(()) } - fn handle_query(&mut self, query: Query) -> Result<(), InterpretingError> { - log::trace!("handle query resolved: {:?}", query); - let mut query_solver = QuerySolver::new(self.clauses.clone(), query); - if let Some(ref on_solution) = self.on_solution { - while let Some(solution) = query_solver.next() { - if !on_solution(&solution) { - break; - } - } - } else { - for solution in query_solver { - println!("solution: {:?}", solution); + fn handle_query( + &mut self, + query: Query, + handler: Option<&dyn SolutionHandler>, + ) -> Result<(), InterpretingError> { + log::trace!("handle query: {:?}", query); + let handler = handler.unwrap_or(&PrintSolutionHandler); + let query_solver = QuerySolver::new(self.clauses.clone(), query); + + let mut has_solution = false; + for solution in query_solver { + has_solution = true; + if !handler.handle_solution(Some(&solution)) { + break; } } + if !has_solution { + handler.handle_solution(None); + } + Ok(()) } @@ -100,20 +95,69 @@ impl Interpreter { } } +pub struct PrintSolutionHandler; + +impl SolutionHandler for PrintSolutionHandler { + fn handle_solution(&self, solution: Option<&QuerySolution>) -> bool { + println!("solution: {:?}", solution); + true // Continue processing + } +} + #[cfg(test)] mod tests { + use std::cell::RefCell; + + use crate::environment::Environment; + use super::*; + use rulog_core::types::ast::Term; use rulog_test_util::setup_logger; + struct TestSolutionHandler { + expected_solutions: Vec>, + index: RefCell, + } + + impl TestSolutionHandler { + fn new(expected_solutions: Vec>) -> Self { + Self { + expected_solutions, + index: RefCell::new(0), + } + } + } + + impl SolutionHandler for TestSolutionHandler { + fn handle_solution(&self, solution: Option<&QuerySolution>) -> bool { + let size = self.index.borrow().clone(); + if size < self.expected_solutions.len() { + assert_eq!( + solution, + self.expected_solutions[size].as_ref(), + "expected solution: {:?}, actual solution: {:?}", + self.expected_solutions[size], + solution + ); + self.index.replace(size + 1); + true + } else { + false + } + } + } #[test] fn test_parent_true() { setup_logger(); - let mut vm = Interpreter::new(); + let mut vm = Interpreter::default(); let ret = vm.eval( r#" parent(tom, liz). ?- parent(tom, liz). "#, + Some(&TestSolutionHandler::new(vec![Some( + QuerySolution::default(), + )])), ); assert!(ret.is_ok(), "{:?}", ret); } @@ -121,12 +165,17 @@ mod tests { #[test] fn test_parent_false() { setup_logger(); - let mut vm = Interpreter::new(); + let mut vm = Interpreter::default(); let ret = vm.eval( r#" parent(tom, liz). ?- parent(liz, tom). + ?- parent(tom, liz). "#, + Some(&TestSolutionHandler::new(vec![ + None, + Some(QuerySolution::default()), + ])), ); assert!(ret.is_ok(), "{:?}", ret); } @@ -134,12 +183,15 @@ mod tests { #[test] fn test_parent_var() { setup_logger(); - let mut vm = Interpreter::new(); + let mut vm = Interpreter::default(); let ret = vm.eval( r#" parent(tom, liz). ?- parent(X, liz). "#, + Some(&TestSolutionHandler::new(vec![Some(QuerySolution { + env: Environment::default().extend("X".to_string(), Term::Atom("tom".to_string())), + })])), ); assert!(ret.is_ok(), "{:?}", ret); } @@ -147,13 +199,16 @@ mod tests { #[test] fn test_parent_var_multiple() { setup_logger(); - let mut vm = Interpreter::new(); + let mut vm = Interpreter::default(); let ret = vm.eval( r#" parent(tom, liz). parent(tom, bob). ?- parent(X, liz). "#, + Some(&TestSolutionHandler::new(vec![Some(QuerySolution { + env: Environment::default().extend("X".to_string(), Term::Atom("tom".to_string())), + })])), ); assert!(ret.is_ok(), "{:?}", ret); } @@ -161,13 +216,18 @@ mod tests { #[test] fn test_parent_var_multiple_children() { setup_logger(); - let mut vm = Interpreter::new(); + let mut vm = Interpreter::default(); let ret = vm.eval( r#" parent(tom, liz). parent(tom, bob). ?- parent(tom, X). "#, + Some(&TestSolutionHandler::new(vec![Some(QuerySolution { + env: Environment::default() + .extend("X".to_string(), Term::Atom("bob".to_string())) + .extend("X".to_string(), Term::Atom("liz".to_string())), + })])), ); assert!(ret.is_ok(), "{:?}", ret); } diff --git a/rulog_vm/src/resolver.rs b/rulog_vm/src/resolver.rs index 9632b88..18c533f 100644 --- a/rulog_vm/src/resolver.rs +++ b/rulog_vm/src/resolver.rs @@ -2,7 +2,7 @@ use rulog_core::types::ast::{Predicate, Query, Term}; use crate::environment::Environment; -#[derive(Debug, PartialEq, Eq, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Default)] pub struct QuerySolution { pub env: Environment, } @@ -19,7 +19,7 @@ impl QuerySolver { let initial_state = query .predicates .iter() - .map(|predicate| (predicate.clone(), Environment::new())) + .map(|predicate| (predicate.clone(), Environment::default())) .collect(); QuerySolver { @@ -65,7 +65,7 @@ impl QuerySolver { env: &Environment, clause: &(Predicate, Vec), ) -> Option { - // First, check if the goal can be unified with the head of the clause. + // Check if the goal can be unified with the head of the clause. if goal.name != clause.0.name { return None; } @@ -73,7 +73,7 @@ impl QuerySolver { // Attempt unification of the terms of the goal and the clause. if let Some(new_env) = unify_terms(&goal.terms, &clause.0.terms) { // Compose the new environment with the existing one. - let new_env = compose(&env, &new_env); + let new_env = compose(env, &new_env); // If the clause has a body, we need to expand the state with the new sub-goals. if !clause.1.is_empty() { @@ -96,6 +96,11 @@ impl QuerySolver { return Some(new_env); } + // If goal and clause head match exactly, return the current environment. + if goal == &clause.0 { + return Some(env.clone()); + } + None } } @@ -121,12 +126,59 @@ fn test_query_solver_no_var() { assert_eq!( next_solution, Some(QuerySolution { - env: Environment::new() + env: Environment::default() + }) + ); + assert_eq!(query_solver.next(), None); +} + +#[test] +fn test_query_solver_no_var_true() { + let rules = vec![( + Predicate { + name: "parent".to_string(), + terms: vec![Term::Atom("tom".to_string()), Term::Atom("liz".to_string())], + }, + vec![], + )]; + let query = Query { + predicates: vec![Predicate { + name: "parent".to_string(), + terms: vec![Term::Atom("tom".to_string()), Term::Atom("liz".to_string())], + }], + }; + + let mut query_solver = QuerySolver::new(rules, query); + let next_solution = query_solver.next(); + assert_eq!( + next_solution, + Some(QuerySolution { + env: Environment::default() }) ); assert_eq!(query_solver.next(), None); } +#[test] +fn test_query_solver_no_match() { + let rules = vec![( + Predicate { + name: "parent".to_string(), + terms: vec![Term::Atom("tom".to_string()), Term::Atom("liz".to_string())], + }, + vec![], + )]; + let query = Query { + predicates: vec![Predicate { + name: "parent".to_string(), + terms: vec![Term::Atom("tom".to_string()), Term::Atom("bob".to_string())], + }], + }; + + let mut query_solver = QuerySolver::new(rules, query); + assert_eq!(query_solver.next(), None); +} + #[test] fn test_query_solver_with_var() { /* @@ -220,23 +272,20 @@ fn test_query_solver() { /// Composes two environments. fn compose(env1: &Environment, env2: &Environment) -> Environment { - let mut env = Environment::new(); - for (var, term) in env1.bindings.iter() { - env.bind(var.clone(), apply_env(term, env2)); - } + let mut env = env1.clone(); for (var, term) in env2.bindings.iter() { - env.bind(var.clone(), apply_env(term, env1)); + env.bind(var.clone(), apply_env(term, &env)); } env } #[test] fn test_compose() { - let mut env1 = Environment::new(); + let mut env1 = Environment::default(); env1.bind("X".to_string(), Term::Integer(1)); env1.bind("Y".to_string(), Term::Integer(2)); env1.bind("Z".to_string(), Term::Integer(3)); - let mut env2 = Environment::new(); + let mut env2 = Environment::default(); env2.bind("X".to_string(), Term::Integer(4)); env2.bind("Y".to_string(), Term::Integer(5)); env2.bind("W".to_string(), Term::Integer(6)); @@ -276,7 +325,7 @@ fn apply_env(term: &Term, env: &Environment) -> Term { #[test] fn test_apply_env() { - let mut env = Environment::new(); + let mut env = Environment::default(); env.bind("X".to_string(), Term::Integer(1)); env.bind("Y".to_string(), Term::Integer(2)); env.bind("Z".to_string(), Term::Integer(3)); @@ -303,17 +352,17 @@ fn apply_env_terms(terms: &[Term], env: &Environment) -> Vec { terms.iter().map(|t| apply_env(t, env)).collect() } -fn unify(term1: &Term, term2: &Term) -> Option { - let mut env = Environment::new(); - if unify_helper(term1, term2, &mut env) { - Some(env) - } else { - None - } -} +// fn unify(term1: &Term, term2: &Term) -> Option { +// let mut env = Environment::default(); +// if unify_helper(term1, term2, &mut env) { +// Some(env) +// } else { +// None +// } +// } fn unify_terms(terms1: &[Term], terms2: &[Term]) -> Option { - let mut env = Environment::new(); + let mut env = Environment::default(); if terms1.len() != terms2.len() { return None; } @@ -337,10 +386,10 @@ fn unify_helper(term1: &Term, term2: &Term, env: &mut Environment) -> bool { // if the variable is already bound, unify the bound term with the other term if let Some(binding) = env.lookup(v) { let binding = binding.clone(); - return unify_helper(&binding, t, env); + unify_helper(&binding, t, env) } else { env.bind(v.clone(), t.clone()); - return true; + true } } // if both terms are lists and have the same length, unify the pairs of items @@ -370,7 +419,7 @@ fn unify_helper(term1: &Term, term2: &Term, env: &mut Environment) -> bool { #[test] fn test_unify_helper() { - let mut env = Environment::new(); + let mut env = Environment::default(); assert_eq!( unify_helper( &Term::Structure( @@ -401,3 +450,141 @@ fn test_unify_helper() { .collect() ); } +#[test] +fn test_unify_with_nested_structures() { + let mut env = Environment::default(); + assert_eq!( + unify_helper( + &Term::Structure( + "parent".to_string(), + vec![ + Term::Structure("person".to_string(), vec![Term::Variable("X".to_string())]), + Term::Structure("person".to_string(), vec![Term::Variable("Y".to_string())]) + ] + ), + &Term::Structure( + "parent".to_string(), + vec![ + Term::Structure("person".to_string(), vec![Term::Atom("alice".to_string())]), + Term::Structure("person".to_string(), vec![Term::Atom("bob".to_string())]) + ] + ), + &mut env + ), + true + ); + assert_eq!( + env.bindings, + [ + ("X".to_string(), Term::Atom("alice".to_string())), + ("Y".to_string(), Term::Atom("bob".to_string())) + ] + .iter() + .cloned() + .collect() + ); +} + +#[test] +fn test_unify_with_lists() { + let mut env = Environment::default(); + assert_eq!( + unify_helper( + &Term::List(vec![ + Term::Variable("X".to_string()), + Term::Variable("Y".to_string()) + ]), + &Term::List(vec![Term::Integer(1), Term::Integer(2)]), + &mut env + ), + true + ); + assert_eq!( + env.bindings, + [ + ("X".to_string(), Term::Integer(1)), + ("Y".to_string(), Term::Integer(2)) + ] + .iter() + .cloned() + .collect() + ); +} + +#[test] +fn test_unify_with_recursive_structures() { + let mut env = Environment::default(); + assert_eq!( + unify_helper( + &Term::Structure( + "node".to_string(), + vec![ + Term::Variable("X".to_string()), + Term::Structure( + "node".to_string(), + vec![ + Term::Variable("Y".to_string()), + Term::Variable("Z".to_string()) + ] + ) + ] + ), + &Term::Structure( + "node".to_string(), + vec![ + Term::Integer(1), + Term::Structure("node".to_string(), vec![Term::Integer(2), Term::Integer(3)]) + ] + ), + &mut env + ), + true + ); + assert_eq!( + env.bindings, + [ + ("X".to_string(), Term::Integer(1)), + ("Y".to_string(), Term::Integer(2)), + ("Z".to_string(), Term::Integer(3)) + ] + .iter() + .cloned() + .collect() + ); +} + +#[test] +fn test_unify_with_failure() { + let mut env = Environment::default(); + assert_eq!( + unify_helper( + &Term::Structure("foo".to_string(), vec![Term::Variable("X".to_string())]), + &Term::Structure("bar".to_string(), vec![Term::Integer(1)]), + &mut env + ), + false + ); + assert!(env.bindings.is_empty()); +} + +#[test] +fn test_unify_with_existing_bindings() { + let mut env = Environment::default(); + env.bindings.insert("X".to_string(), Term::Integer(1)); + + assert_eq!( + unify_helper( + &Term::Variable("X".to_string()), + &Term::Integer(1), + &mut env + ), + true + ); + assert_eq!( + env.bindings, + [("X".to_string(), Term::Integer(1))] + .iter() + .cloned() + .collect() + ); +}