diff --git a/hw02/homework2.py b/hw02/homework2.py index a96b1ef..8747bdf 100755 --- a/hw02/homework2.py +++ b/hw02/homework2.py @@ -291,6 +291,46 @@ def substitute (self,id,new_e): self._e2.substitute(id,new_e)) +class ECache (Exp): + def __init__ (self, e): + self._e = e + self._v = None + + def eval (self, prim_dict, fun_dict=FUNC_DICT): + if self._v == None: + self._v = self._e.eval(prim_dict) + return self._v + + def substitute (self, id, new_e): + self._e.substitute(id, new_e) + return self + + +class ELetN (Exp): + # "memoized" local binding + + def __init__ (self, id, e1, e2): + self._id = id + self._e1 = ECache(e1) + self._e2 = e2 + + def __str__ (self): + return "ELetN({},{},{})".format(id, self._e1, self._e2) + + def eval(self, prim_dict, fun_dict=FUNC_DICT): + new_e2 = self._e2.substitute(self._id, self._e1) + return new_e2.eval(prim_dict) + + def substitute (self, id, new_e): + if id == self._id: + return ELetN(self._id, + self._e1.substitute(id,new_e), + self._e2) + return ELetN(self._id, + self._e1.substitute(id,new_e), + self._e2.substitute(id,new_e)) + + class EId (Exp): # identifier diff --git a/hw02/homework2_test.py b/hw02/homework2_test.py index 56c6b93..2e97d5a 100644 --- a/hw02/homework2_test.py +++ b/hw02/homework2_test.py @@ -57,6 +57,18 @@ def test_ELetS(self): ("b",EId("a"))], EPrimCall("-",[EId("a"),EId("b")]))).expand().eval(INITIAL_PRIM_DICT).value, 0) + def test_ELetN(self): + self.assertEqual(ELetN("a",EInteger(10),EId("a")).eval(INITIAL_PRIM_DICT).value, 10) + self.assertEqual(ELetN("a",EInteger(10), + ELetN("b",EInteger(20),EId("a"))).eval(INITIAL_PRIM_DICT).value, 10) + self.assertEqual(ELetN("a",EInteger(10), + ELetN("a",EInteger(20),EId("a"))).eval(INITIAL_PRIM_DICT).value, 20) + self.assertEqual(ELetN("a",EPrimCall("+",[EInteger(10),EInteger(20)]), + ELetN("b",EInteger(20),EId("a"))).eval(INITIAL_PRIM_DICT).value, 30) + self.assertEqual(ELetN("a",EPrimCall("+",[EInteger(10),EInteger(20)]), + ELetN("b",EInteger(20), + EPrimCall("*",[EId("a"),EId("a")]))).eval(INITIAL_PRIM_DICT).value, 900) + def test_EDef(self): EDef("add1", ["x"], EPrimCall("+", [EInteger(1), EId("x")])).eval(INITIAL_PRIM_DICT) EDef("add2", "x", EPrimCall("+", [EInteger(2), EId("x")])).eval(INITIAL_PRIM_DICT)