Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/nledl.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,6 @@ void nle_end(nledl_ctx *);
void nle_set_seed(nledl_ctx *, unsigned long, unsigned long, char);
void nle_get_seed(nledl_ctx *, unsigned long *, unsigned long *, char *);

int nle_save(nledl_ctx *);

#endif /* NLEDL_H */
7 changes: 7 additions & 0 deletions nle/env/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,8 @@ def __init__(
allow_all_modes=False,
spawn_monsters=True,
render_mode="human",
gamesavedir=None,
gameloaddir=None,
):
"""Constructs a new NLE environment.

Expand Down Expand Up @@ -317,6 +319,8 @@ def __init__(
wizard=wizard,
spawn_monsters=spawn_monsters,
scoreprefix=scoreprefix,
gamesavedir=gamesavedir,
gameloaddir=gameloaddir,
)
self._close_nethack = weakref.finalize(self, self.nethack.close)

Expand Down Expand Up @@ -545,6 +549,9 @@ def render(self):

return "\nInvalid render mode: " + mode

def save(self, gamesavedir=None):
return self.nethack.save(gamesavedir=gamesavedir)

def __repr__(self):
return "<%s>" % self.__class__.__name__

Expand Down
64 changes: 46 additions & 18 deletions nle/nethack/nethack.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,15 +96,20 @@ def _new_dl_linux(vardir):

def _new_dl(vardir):
"""Creates a copied .so file to allow for multiple independent NLE instances"""
if sys.platform == "linux":
return _new_dl_linux(vardir)
# if sys.platform == "linux":
# return _new_dl_linux(vardir)

# MacOS has no memfd_create or O_TMPFILE. Using /dev/fd/{FD} as an argument
# to dlopen doesn't work after unlinking from the file system. So let's copy
# instead and hope vardir gets properly deleted at some point.
dl = tempfile.NamedTemporaryFile(suffix="libnethack.so", dir=vardir)
# dl = tempfile.NamedTemporaryFile(suffix="libnethack.so", dir=vardir)
# shutil.copyfile(DLPATH, dl.name) # Might use fcopyfile.
# return dl, dl.name

dlpath = os.path.join(vardir, "libnethack.so")
dl = open(dlpath, "w")
shutil.copyfile(DLPATH, dl.name) # Might use fcopyfile.
return dl, dl.name
return dl, dlpath


def _close(pynethack, dl, tempdir, warn=True):
Expand Down Expand Up @@ -168,7 +173,11 @@ def __init__(
hackdir=HACKDIR,
spawn_monsters=True,
scoreprefix="",
gamesavedir=None,
gameloaddir=None,
):
self.gamesavedir = gamesavedir
self.gameloaddir = gameloaddir
self._copy = copy

if not os.path.exists(hackdir) or not os.path.exists(
Expand All @@ -182,24 +191,31 @@ def __init__(
self._tempdir = tempfile.TemporaryDirectory(prefix="nle")
self._vardir = self._tempdir.name

# Symlink a nhdat.
os.symlink(os.path.join(hackdir, "nhdat"), os.path.join(self._vardir, "nhdat"))
if self.gameloaddir:
# restore files (save) from directory
shutil.copytree(self.gameloaddir, self._vardir, dirs_exist_ok=True)

# Touch files, so lock_file() in files.c passes.
for fn in ["perm", "record", "logfile"]:
os.close(os.open(os.path.join(self._vardir, fn), os.O_CREAT))
if scoreprefix:
os.close(os.open(scoreprefix + "xlogfile", os.O_CREAT))
self.dlpath = os.path.join(self._vardir, "libnethack.so")
self._dl = open(self.dlpath, "r")
else:
os.close(os.open(os.path.join(self._vardir, "xlogfile"), os.O_CREAT))
# Symlink a nhdat.
os.symlink(os.path.join(hackdir, "nhdat"), os.path.join(self._vardir, "nhdat"))

# Touch files, so lock_file() in files.c passes.
for fn in ["perm", "record", "logfile"]:
os.close(os.open(os.path.join(self._vardir, fn), os.O_CREAT))
if scoreprefix:
os.close(os.open(scoreprefix + "xlogfile", os.O_CREAT))
else:
os.close(os.open(os.path.join(self._vardir, "xlogfile"), os.O_CREAT))

os.mkdir(os.path.join(self._vardir, "save"))
os.mkdir(os.path.join(self._vardir, "save"))

# An assortment of hacks:
# Copy our .so into self._vardir to load several copies of the dl.
# (Or use a memfd_create hack to create a file that gets deleted on
# process exit.)
self._dl, self.dlpath = _new_dl(self._vardir)
# An assortment of hacks:
# Copy our .so into self._vardir to load several copies of the dl.
# (Or use a memfd_create hack to create a file that gets deleted on
# process exit.)
self._dl, self.dlpath = _new_dl(self._vardir)

# Finalize even when the rest of this constructor fails.
self._finalizer = weakref.finalize(self, _close, None, self._dl, self._tempdir)
Expand Down Expand Up @@ -323,3 +339,15 @@ def in_normal_game(self):

def how_done(self):
return self._pynethack.how_done()

def save(self, gamesavedir=None):
if gamesavedir:
savedir = gamesavedir
else:
savedir = self.gamesavedir

assert savedir is not None

success = self._pynethack.save()
shutil.copytree(self._vardir, savedir, dirs_exist_ok=True)
return success
21 changes: 20 additions & 1 deletion nle/tests/test_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,26 @@ def test_render_ansi(self, env_name, rollout_len):
assert isinstance(output, str)
assert len(output.replace("\n", "")) == np.prod(nle.env.DUNGEON_SHAPE)


def test_save_and_load(self, env_name, rollout_len):
with tempfile.TemporaryDirectory() as gamesavedir:
env = gym.make(env_name, gamesavedir=gamesavedir)

obs = env.reset()
for _ in range(rollout_len):
action = env.action_space.sample()
obs, _, done, _ = env.step(action)
if done:
obs = env.reset()

env.save()

env = gym.make(env_name, gameloaddir=gamesavedir)
obsload = env.reset()

assert (obsload["blstats"] == obs["blstats"]).all()
assert (obsload["glyphs"] == obs["glyphs"]).all()


class TestGymDynamics:
"""Tests a few game dynamics."""

Expand Down
13 changes: 13 additions & 0 deletions sys/unix/nledl.c
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,16 @@ nle_get_seed(nledl_ctx *nledl, unsigned long *core, unsigned long *disp,
get_seed(nledl->nle_ctx, core, disp, reseed);
}
#endif


int
nle_save(nledl_ctx *nledl)
{
int success;
void *(*dosave0)();

dosave0 = dlsym(nledl->dlhandle, "dosave0");
success = dosave0();

return success;
}
5 changes: 5 additions & 0 deletions win/rl/pynethack.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ checked_conversion(py::handle h, const std::vector<ssize_t> &shape)
class Nethack
{
public:
int save() {
return nle_save(nle_);
}

Nethack(std::string dlpath, std::string ttyrec, std::string hackdir,
std::string nethackoptions, bool spawn_monsters,
std::string scoreprefix)
Expand Down Expand Up @@ -404,6 +408,7 @@ PYBIND11_MODULE(_pynethack, m)
.def("get_seeds", &Nethack::get_seeds)
.def("in_normal_game", &Nethack::in_normal_game)
.def("how_done", &Nethack::how_done)
.def("save", &Nethack::save)
.def("set_wizkit", &Nethack::set_wizkit);

py::module mn = m.def_submodule(
Expand Down