This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Added MMapIndexedCache #4611

Open · wants to merge 3 commits into base: main
Changes from 1 commit
251 changes: 251 additions & 0 deletions allennlp/common/mmap.py
@@ -0,0 +1,251 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from functools import lru_cache
import os
import shutil
import struct
from typing import Dict

import numpy as np
import torch

from allennlp.data import Instance

dtypes = {
1: np.uint8,
2: np.int8,
3: np.int16,
4: np.int32,
5: np.int64,
    6: np.float64,
7: np.double,
8: np.uint16,
}


def code(dtype):
for k in dtypes.keys():
if dtypes[k] == dtype:
return k
raise ValueError(dtype)


def index_file_path(prefix_path):
return f"{prefix_path}.idx"


def data_file_path(prefix_path):
return f"{prefix_path}.bin"


def _warmup_mmap_file(path):
    # Sequentially read the file in 100 MB chunks so the OS page cache is warm
    # before the memory map is accessed.
    with open(path, "rb") as stream:
        while stream.read(100 * 1024 * 1024):
            pass


class MMapIndexedCache(torch.utils.data.Dataset):
class Index(object):
_HDR_MAGIC = b"MMIDIDX\x00\x00"
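        # On-disk index layout (little endian): the 9-byte magic above, an
        # 8-byte version (always 1), a 1-byte dtype code, an 8-byte item count,
        # then an int32 array of per-item sizes followed by an int64 array of
        # byte offsets into the .bin data file.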

@classmethod
def writer(cls, path, dtype):
class _Writer(object):
def __enter__(self):
self._file = open(path, "wb")

self._file.write(cls._HDR_MAGIC)
self._file.write(struct.pack("<Q", 1))
self._file.write(struct.pack("<B", code(dtype)))

return self

@staticmethod
def _get_pointers(sizes):
dtype_size = dtype().itemsize
address = 0
pointers = []

for size in sizes:
pointers.append(address)
address += size * dtype_size

return pointers

def write(self, sizes):
pointers = self._get_pointers(sizes)

self._file.write(struct.pack("<Q", len(sizes)))

sizes = np.array(sizes, dtype=np.int32)
self._file.write(sizes.tobytes(order="C"))
del sizes

pointers = np.array(pointers, dtype=np.int64)
self._file.write(pointers.tobytes(order="C"))
del pointers

def __exit__(self, exc_type, exc_val, exc_tb):
self._file.close()

return _Writer()

def __init__(self, path):
with open(path, "rb") as stream:
magic_test = stream.read(9)
                assert self._HDR_MAGIC == magic_test, "Index file doesn't match expected format."
version = struct.unpack("<Q", stream.read(8))
assert (1,) == version

(dtype_code,) = struct.unpack("<B", stream.read(1))
self._dtype = dtypes[dtype_code]
self._dtype_size = self._dtype().itemsize

self._len = struct.unpack("<Q", stream.read(8))[0]
offset = stream.tell()

_warmup_mmap_file(path)

self._bin_buffer_mmap = np.memmap(path, mode="r", order="C")
self._bin_buffer = memoryview(self._bin_buffer_mmap)
self._sizes = np.frombuffer(
self._bin_buffer, dtype=np.int32, count=self._len, offset=offset
)
self._pointers = np.frombuffer(
self._bin_buffer,
dtype=np.int64,
count=self._len,
offset=offset + self._sizes.nbytes,
)

def __del__(self):
self._bin_buffer_mmap._mmap.close()
del self._bin_buffer_mmap

@property
def dtype(self):
return self._dtype

@property
def sizes(self):
return self._sizes

@lru_cache(maxsize=8)
def __getitem__(self, i):
return self._pointers[i], self._sizes[i]

def __len__(self):
return self._len

def __init__(self, path):
super().__init__()

self._path = None
self._index = None
self._bin_buffer = None

self._do_init(path)

def __getstate__(self):
return self._path

def __setstate__(self, state):
self._do_init(state)

def _do_init(self, path):
self._path = path
self._index = self.Index(index_file_path(self._path))

_warmup_mmap_file(data_file_path(self._path))
self._bin_buffer_mmap = np.memmap(data_file_path(self._path), mode="r", order="C")
self._bin_buffer = memoryview(self._bin_buffer_mmap)

def __del__(self):
self._bin_buffer_mmap._mmap.close()
del self._bin_buffer_mmap
del self._index

def __len__(self):
return len(self._index)

@lru_cache(maxsize=8)
def __getitem__(self, i):
ptr, size = self._index[i]
np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr)
if self._index.dtype != np.int64:
np_array = np_array.astype(np.int64)

return torch.from_numpy(np_array)

@property
def sizes(self):
return self._index.sizes

@staticmethod
def exists(path):
return os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))


class MMapIndexedCacheBuilder(object):
def __init__(self, out_file, vocab_size=None):
self._data_file = open(out_file, "wb")
self._sizes = []
self._field_names = None
if vocab_size is not None and vocab_size < 65500:
self._dtype = np.uint16
else:
self._dtype = np.int32

def add_instance(self, instance: Instance):
tensor_dict = instance.as_tensor_dict()
flattened_dict = self.flatten_dict(tensor_dict)
if not self._field_names:
self._field_names = list(sorted(flattened_dict.keys()))
assert self._field_names
        # TODO: what if some instances have a different set of field names
        # (i.e. missing some)? For test instances we don't have supervision.
# for now we will just write the name of every field next to the data.

for key, value in flattened_dict.items():
self.add_item(key, value)

@classmethod
    def flatten_dict(cls, tensor_dict: Dict, prefix=None):
        # Flatten nested tensor dicts into a single level, joining nested keys
        # with "___" (e.g. {"tokens": {"token_ids": t}} -> {"tokens___token_ids": t}).
        flat_dict = {}
        for field_name, value in tensor_dict.items():
            name = f"{prefix}___{field_name}" if prefix else field_name
            if isinstance(value, torch.Tensor):
                flat_dict[name] = value
            elif isinstance(value, dict):
                flat_dict.update(cls.flatten_dict(value, prefix=name))
            else:
                raise ValueError(
                    f"Cannot serialize field '{name}' of type {type(value)}; "
                    "only tensors and nested tensor dicts (not MetadataField) are supported."
                )
        return flat_dict

def add_item(self, name, tensor):
np_array = tensor.contiguous().detach().numpy()
np_array_b = np_array.tobytes(order="C")
name_b = name.encode()
self._sizes += [len(name_b), len(np_array_b), np_array.size]
self._data_file.write(name_b)
self._data_file.write(np_array_b)

def merge_file_(self, another_file):
# Concatenate index
index = MMapIndexedCache.Index(index_file_path(another_file))
assert index.dtype == self._dtype

for size in index.sizes:
self._sizes.append(size)

# Concatenate data
with open(data_file_path(another_file), "rb") as f:
shutil.copyfileobj(f, self._data_file)

def finalize(self, index_file):
self._data_file.close()

with MMapIndexedCache.Index.writer(index_file, self._dtype) as index:
index.write(self._sizes)
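
A minimal usage sketch of the builder/reader pair added in this PR. The path prefix and example tensors below are hypothetical, and per-item reads are omitted because the record layout written by add_instance (field names interleaved with data) is still marked as a TODO above:

import torch

from allennlp.common.mmap import (
    MMapIndexedCache,
    MMapIndexedCacheBuilder,
    data_file_path,
    index_file_path,
)

prefix = "/tmp/mmap_cache_demo"  # hypothetical path prefix

# Write two tensor fields to <prefix>.bin, then write the <prefix>.idx index.
builder = MMapIndexedCacheBuilder(data_file_path(prefix))
builder.add_item("tokens___token_ids", torch.tensor([4, 8, 15], dtype=torch.int32))
builder.add_item("label", torch.tensor([1], dtype=torch.int32))
builder.finalize(index_file_path(prefix))

# Memory-map the files back; the index exposes the record count and sizes.
assert MMapIndexedCache.exists(prefix)
cache = MMapIndexedCache(prefix)
print(len(cache), cache.sizes)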