Skip to content

Commit

Permalink
magic patch os.listdir (#695)
Browse files Browse the repository at this point in the history
* magic patch os.listdir

* Update node.py

* remove io patch
  • Loading branch information
PythonFZ authored Aug 16, 2023
1 parent 9d352dd commit 9d85511
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 4 deletions.
28 changes: 27 additions & 1 deletion tests/integration/test_fs_patch.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os

import git

import zntrack.examples
Expand All @@ -19,9 +21,33 @@ def test_patch_open(proj_path):

node = zntrack.examples.WriteDVCOuts.from_rev(rev=commit.hexsha)

with node.state.patch_open():
with node.state.magic_patch():
with open(node.outs, "r") as f:
assert f.read() == "Hello World"

with open(node.outs, "r") as f:
assert f.read() == "Lorem Ipsum"

listdir = os.listdir(node.nwd)

with node.state.magic_patch():
os.listdir(node.nwd) == listdir


def test_patch_list(proj_path):
node = zntrack.from_rev(
"HelloWorld",
remote="https://github.com/PythonFZ/ZnTrackExamples.git",
rev="fbb6ada",
)

def func(self, path):
return os.listdir(path)

type(node).list = func

with node.state.magic_patch():
assert "nodes/HelloWorld/random_number.json" in node.list(node.nwd)
assert "nodes/HelloWorld/node-meta.json" in node.list(node.nwd)

assert node.list(node.nwd) == []
15 changes: 12 additions & 3 deletions zntrack/core/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import functools
import json
import logging
import os
import pathlib
import time
import typing
Expand Down Expand Up @@ -77,13 +78,14 @@ def fs(self) -> dvc.api.DVCFileSystem:
raise dvc.utils.strictyaml.YAMLValidationError

@contextlib.contextmanager
def patch_open(self) -> typing.ContextManager:
def magic_patch(self) -> typing.ContextManager:
"""Patch the open function to use the Node's file system.
Opening a relative path will use the Node's file system.
Opening an absolute path will use the local file system.
"""
original_open = open
original_listdir = os.listdir

def _open(file, *args, **kwargs):
if file == "params.yaml":
Expand All @@ -94,10 +96,17 @@ def _open(file, *args, **kwargs):

return original_open(file, *args, **kwargs)

def _listdir(path, *args, **kwargs):
if not pathlib.Path(path).is_absolute():
return self.fs.listdir(path, detail=False)

return original_listdir(path, *args, **kwargs)

with unittest.mock.patch("builtins.open", _open):
with unittest.mock.patch("__main__.open", _open):
# Jupyter Notebooks replace open with io.open
yield
with unittest.mock.patch("os.listdir", _listdir):
# Jupyter Notebooks replace open with io.open
yield


class _NameDescriptor(zninit.Descriptor):
Expand Down

0 comments on commit 9d85511

Please sign in to comment.