diff --git a/include/nledl.h b/include/nledl.h
index b55bcac37..861e37898 100644
--- a/include/nledl.h
+++ b/include/nledl.h
@@ -28,4 +28,6 @@ void nle_end(nledl_ctx *);
 void nle_set_seed(nledl_ctx *, unsigned long, unsigned long, char);
 void nle_get_seed(nledl_ctx *, unsigned long *, unsigned long *, char *);
 
+int nle_save(nledl_ctx *);
+
 #endif /* NLEDL_H */
diff --git a/nle/env/base.py b/nle/env/base.py
index bdda5d3d2..9b6276a3a 100644
--- a/nle/env/base.py
+++ b/nle/env/base.py
@@ -198,6 +198,8 @@ def __init__(
         allow_all_modes=False,
         spawn_monsters=True,
         render_mode="human",
+        gamesavedir=None,
+        gameloaddir=None,
     ):
         """Constructs a new NLE environment.
 
@@ -317,6 +319,8 @@ def __init__(
             wizard=wizard,
             spawn_monsters=spawn_monsters,
             scoreprefix=scoreprefix,
+            gamesavedir=gamesavedir,
+            gameloaddir=gameloaddir,
         )
         self._close_nethack = weakref.finalize(self, self.nethack.close)
 
@@ -545,6 +549,10 @@ def render(self):
 
         return "\nInvalid render mode: " + mode
 
+    def save(self, gamesavedir=None):
+        """Save the game through the wrapped NetHack object; see its ``save``."""
+        return self.nethack.save(gamesavedir=gamesavedir)
+
     def __repr__(self):
         return "<%s>" % self.__class__.__name__
 
diff --git a/nle/nethack/nethack.py b/nle/nethack/nethack.py
index d0e2091ed..793a15477 100644
--- a/nle/nethack/nethack.py
+++ b/nle/nethack/nethack.py
@@ -96,15 +96,13 @@ def _new_dl_linux(vardir):
 
 def _new_dl(vardir):
     """Creates a copied .so file to allow for multiple independent NLE instances"""
-    if sys.platform == "linux":
-        return _new_dl_linux(vardir)
-
-    # MacOS has no memfd_create or O_TMPFILE. Using /dev/fd/{FD} as an argument
-    # to dlopen doesn't work after unlinking from the file system. So let's copy
-    # instead and hope vardir gets properly deleted at some point.
-    dl = tempfile.NamedTemporaryFile(suffix="libnethack.so", dir=vardir)
-    shutil.copyfile(DLPATH, dl.name)  # Might use fcopyfile.
-    return dl, dl.name
+    # Always copy to a fixed name inside vardir (no memfd, no random temp name):
+    # a saved copy of vardir then contains the library, which the gameloaddir
+    # branch of __init__ reopens as <vardir>/libnethack.so.
+    dlpath = os.path.join(vardir, "libnethack.so")
+    shutil.copyfile(DLPATH, dlpath)  # Might use fcopyfile.
+    # Open read-only, and only after the copy, so the handle never truncates it.
+    return open(dlpath, "rb"), dlpath
 
 
 def _close(pynethack, dl, tempdir, warn=True):
@@ -168,7 +166,11 @@ def __init__(
         hackdir=HACKDIR,
         spawn_monsters=True,
         scoreprefix="",
+        gamesavedir=None,
+        gameloaddir=None,
     ):
+        self.gamesavedir = gamesavedir
+        self.gameloaddir = gameloaddir
         self._copy = copy
 
         if not os.path.exists(hackdir) or not os.path.exists(
@@ -182,24 +184,31 @@ def __init__(
         self._tempdir = tempfile.TemporaryDirectory(prefix="nle")
         self._vardir = self._tempdir.name
 
-        # Symlink a nhdat.
-        os.symlink(os.path.join(hackdir, "nhdat"), os.path.join(self._vardir, "nhdat"))
+        if self.gameloaddir:
+            # Restore the saved game directory (it includes libnethack.so).
+            shutil.copytree(self.gameloaddir, self._vardir, dirs_exist_ok=True)
 
-        # Touch files, so lock_file() in files.c passes.
-        for fn in ["perm", "record", "logfile"]:
-            os.close(os.open(os.path.join(self._vardir, fn), os.O_CREAT))
-        if scoreprefix:
-            os.close(os.open(scoreprefix + "xlogfile", os.O_CREAT))
+            self.dlpath = os.path.join(self._vardir, "libnethack.so")
+            self._dl = open(self.dlpath, "rb")
         else:
-            os.close(os.open(os.path.join(self._vardir, "xlogfile"), os.O_CREAT))
+            # Symlink a nhdat.
+            os.symlink(os.path.join(hackdir, "nhdat"), os.path.join(self._vardir, "nhdat"))
+
+            # Touch files, so lock_file() in files.c passes.
+            for fn in ["perm", "record", "logfile"]:
+                os.close(os.open(os.path.join(self._vardir, fn), os.O_CREAT))
+            if scoreprefix:
+                os.close(os.open(scoreprefix + "xlogfile", os.O_CREAT))
+            else:
+                os.close(os.open(os.path.join(self._vardir, "xlogfile"), os.O_CREAT))
 
-        os.mkdir(os.path.join(self._vardir, "save"))
+            os.mkdir(os.path.join(self._vardir, "save"))
 
-        # An assortment of hacks:
-        # Copy our .so into self._vardir to load several copies of the dl.
-        # (Or use a memfd_create hack to create a file that gets deleted on
-        # process exit.)
-        self._dl, self.dlpath = _new_dl(self._vardir)
+            # An assortment of hacks:
+            # Copy our .so into self._vardir to load several copies of the dl.
+            # (Or use a memfd_create hack to create a file that gets deleted on
+            # process exit.)
+            self._dl, self.dlpath = _new_dl(self._vardir)
 
         # Finalize even when the rest of this constructor fails.
         self._finalizer = weakref.finalize(self, _close, None, self._dl, self._tempdir)
@@ -323,3 +332,25 @@ def in_normal_game(self):
 
     def how_done(self):
         return self._pynethack.how_done()
+
+    def save(self, gamesavedir=None):
+        """Save the running game and copy the game directory to a save directory.
+
+        Args:
+            gamesavedir: Directory to copy the game directory into. Defaults
+                to the ``gamesavedir`` given to the constructor.
+
+        Returns:
+            Whatever the underlying ``save()`` of the extension returns.
+
+        Raises:
+            ValueError: If neither this call nor the constructor got a
+                save directory.
+        """
+        savedir = gamesavedir or self.gamesavedir
+        if not savedir:
+            raise ValueError("no gamesavedir given to save() or the constructor")
+
+        success = self._pynethack.save()
+        shutil.copytree(self._vardir, savedir, dirs_exist_ok=True)
+        return success
diff --git a/nle/tests/test_envs.py b/nle/tests/test_envs.py
index b81267ff5..7b3621e42 100644
--- a/nle/tests/test_envs.py
+++ b/nle/tests/test_envs.py
@@ -340,7 +340,26 @@ def test_render_ansi(self, env_name, rollout_len):
         assert isinstance(output, str)
         assert len(output.replace("\n", "")) == np.prod(nle.env.DUNGEON_SHAPE)
 
-
+    def test_save_and_load(self, env_name, rollout_len):
+        with tempfile.TemporaryDirectory() as gamesavedir:
+            env = gym.make(env_name, gamesavedir=gamesavedir)
+
+            obs = env.reset()
+            for _ in range(rollout_len):
+                action = env.action_space.sample()
+                obs, _, done, _ = env.step(action)
+                if done:
+                    obs = env.reset()
+
+            env.save()
+
+            env = gym.make(env_name, gameloaddir=gamesavedir)
+            obsload = env.reset()
+
+            assert (obsload["blstats"] == obs["blstats"]).all()
+            assert (obsload["glyphs"] == obs["glyphs"]).all()
+
+
 class TestGymDynamics:
     """Tests a few game dynamics."""
 
diff --git a/sys/unix/nledl.c b/sys/unix/nledl.c
index 53de50263..acfca21cd 100644
--- a/sys/unix/nledl.c
+++ b/sys/unix/nledl.c
@@ -150,3 +150,16 @@ nle_get_seed(nledl_ctx *nledl, unsigned long *core, unsigned long *disp,
     get_seed(nledl->nle_ctx, core, disp, reseed);
 }
 #endif
+
+/* Saves the game through the loaded library's dosave0(); 0 if it is absent. */
+int
+nle_save(nledl_ctx *nledl)
+{
+    /* dosave0's result is returned as an int status; type the symbol so. */
+    int (*dosave0)(void);
+
+    dosave0 = dlsym(nledl->dlhandle, "dosave0");
+    if (!dosave0)
+        return 0;
+    return dosave0();
+}
diff --git a/win/rl/pynethack.cc b/win/rl/pynethack.cc
index 8d045683e..d3779c44d 100644
--- a/win/rl/pynethack.cc
+++ b/win/rl/pynethack.cc
@@ -101,6 +101,17 @@ checked_conversion(py::handle h, const std::vector &shape)
 class Nethack
 {
   public:
+    /* Asks the loaded library to save the game; see nle_save(). */
+    int
+    save()
+    {
+        /* nle_save() dereferences its argument, so refuse a null context.
+         * NOTE(review): assumes nle_ is null while no game exists - confirm. */
+        if (!nle_)
+            throw std::runtime_error("save called without a running game");
+        return nle_save(nle_);
+    }
+
     Nethack(std::string dlpath, std::string ttyrec, std::string hackdir,
             std::string nethackoptions, bool spawn_monsters,
             std::string scoreprefix)
@@ -404,6 +415,7 @@ PYBIND11_MODULE(_pynethack, m)
         .def("get_seeds", &Nethack::get_seeds)
         .def("in_normal_game", &Nethack::in_normal_game)
         .def("how_done", &Nethack::how_done)
+        .def("save", &Nethack::save)
         .def("set_wizkit", &Nethack::set_wizkit);
 
     py::module mn = m.def_submodule(