Skip to content

Commit

Permalink
Conversion to sliced checkpoints now supports PathManager (facebookre…
Browse files Browse the repository at this point in the history
…search#387)

Summary:
Pull Request resolved: facebookresearch#387

The sliced were not created at the right place because of os.path.abspath

Reviewed By: iseessel

Differential Revision: D30109789

fbshipit-source-id: c332fbf5f5c52241a537bd1188e3268a2f5cb966
  • Loading branch information
QuentinDuval authored and facebook-github-bot committed Aug 4, 2021
1 parent 3660267 commit 5a622f4
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
8 changes: 4 additions & 4 deletions vissl/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from fvcore.common.file_io import PathManager
from vissl.config import AttrDict
from vissl.utils.env import get_machine_local_and_dist_rank
from vissl.utils.io import create_file_symlink, makedir
from vissl.utils.io import abspath, create_file_symlink, makedir
from vissl.utils.layer_memory_tracking import null_context


Expand Down Expand Up @@ -353,10 +353,10 @@ def save_slice(cls, checkpoint_path: str, param_path: str, param) -> str:
- return the created file name
"""
checkpoint_sub_folder = os.path.splitext(checkpoint_path)[0] + "_layers"
os.makedirs(checkpoint_sub_folder, exist_ok=True)
makedir(checkpoint_sub_folder)
hash_name = hashlib.sha1(param_path.encode()).hexdigest()
file_path = os.path.join(f"{checkpoint_sub_folder}", f"{hash_name}.torch")
file_path = os.path.abspath(file_path)
file_path = os.path.join(checkpoint_sub_folder, f"{hash_name}.torch")
file_path = abspath(file_path)
checkpoint_slice = {"type": CheckpointItemType.slice.name, "weight": param}
with PathManager.open(file_path, "wb") as f:
torch.save(checkpoint_slice, f)
Expand Down
12 changes: 12 additions & 0 deletions vissl/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,18 @@ def load_file(filename, mmap_mode=None):
return data


def abspath(resource_path: str):
"""
Make a path absolute, but take into account prefixes like
"http://" or "manifold://"
"""
regex = re.compile(r"^\w+://")
if regex.match(resource_path) is None:
return os.path.abspath(resource_path)
else:
return resource_path


def makedir(dir_path):
"""
Create the directory if it does not exist.
Expand Down

0 comments on commit 5a622f4

Please sign in to comment.