Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
211 changes: 117 additions & 94 deletions strax/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

RUN_DEFAULTS_KEY = "strax_defaults"
TEMP_DATA_TYPE_PREFIX = "_temp_"
NOT_ALLOWED_PLUGINS = (strax.LoopPlugin, strax.OverlapWindowPlugin)
NOT_PER_CHUNK_ALLOWED_PLUGINS = (strax.LoopPlugin, strax.OverlapWindowPlugin)

# use tqdm as loaded in utils (from tqdm.notebook when in a jupyter env)
tqdm = strax.utils.tqdm
Expand Down Expand Up @@ -740,17 +740,9 @@ def _context_hash(self):
)
return strax.deterministic_hash(_base_hash_on_config)

def _plugins_are_cached(
self,
targets: ty.Union[ty.Tuple[str], ty.List[str]],
chunk_number: ty.Optional[ty.Dict[str, ty.List[int]]] = None,
) -> bool:
def _plugins_are_cached(self, targets: ty.Union[ty.Tuple[str], ty.List[str]]) -> bool:
"""Check if all the requested targets are in the _fixed_plugin_cache."""
if (
self.context_config["use_per_run_defaults"]
or self._fixed_plugin_cache is None
or chunk_number is not None
):
if self.context_config["use_per_run_defaults"] or self._fixed_plugin_cache is None:
# There is no point in caching if plugins (lineage) can
# change per run or the cache is empty.
return False
Expand All @@ -761,12 +753,8 @@ def _plugins_are_cached(
plugin_cache = self._fixed_plugin_cache[context_hash]
return all([t in plugin_cache for t in targets])

def _plugins_to_cache(
self,
plugins: dict,
chunk_number: ty.Optional[ty.Dict[str, ty.List[int]]] = None,
) -> None:
if self.context_config["use_per_run_defaults"] or chunk_number is not None:
def _plugins_to_cache(self, plugins: dict) -> None:
if self.context_config["use_per_run_defaults"]:
# There is no point in caching if plugins (lineage) can change per run
return
context_hash = self._context_hash()
Expand Down Expand Up @@ -867,9 +855,15 @@ def __get_plugin(
):
"""Get single plugin either from cache or initialize it."""
# Check if plugin for data_type is already cached
if self._plugins_are_cached((data_type,), chunk_number=chunk_number):
if self._plugins_are_cached((data_type,)):
cached_plugins = self.__get_requested_plugins_from_cache(run_id, (data_type,))
target_plugin = cached_plugins[data_type]
if chunk_number is not None:
target_plugin = cached_plugins[data_type].__copy__(True)
self.__assign_chunk_number_to_plugin(target_plugin, chunk_number=chunk_number)
target_plugin.run_id = run_id
target_plugin.fix_dtype()
else:
target_plugin = cached_plugins[data_type]
return target_plugin

if data_type not in self._plugin_class_registry:
Expand All @@ -884,8 +878,7 @@ def __get_plugin(
self._set_plugin_config(plugin, run_id, tolerant=True)

plugin.deps = {
d_depends: self.__get_plugin(run_id, d_depends, chunk_number=chunk_number)
for d_depends in plugin.depends_on
d_depends: self.__get_plugin(run_id, d_depends) for d_depends in plugin.depends_on
}
if plugin.compute_takes_chunk_i:
for k, v in plugin.deps.items():
Expand All @@ -900,7 +893,7 @@ def __get_plugin(
"which is not supported."
)

self.__add_lineage_to_plugin(run_id, plugin, chunk_number=chunk_number)
self.__add_lineage_to_plugin(run_id, plugin)

if not hasattr(plugin, "data_kind") and not plugin.multi_output:
if len(plugin.depends_on):
Expand All @@ -915,11 +908,17 @@ def __get_plugin(
plugin.fix_dtype()

# Add plugin to cache
self._plugins_to_cache(
{data_type: plugin for data_type in plugin.provides}, chunk_number=chunk_number
)
self._plugins_to_cache({data_type: plugin for data_type in plugin.provides})

return plugin
if chunk_number is not None:
target_plugin = plugin.__copy__(True)
self.__assign_chunk_number_to_plugin(target_plugin, chunk_number=chunk_number)
target_plugin.run_id = run_id
target_plugin.fix_dtype()
else:
target_plugin = plugin

return target_plugin

@staticmethod
def _check_chunk_number(chunk_number: ty.List[int]):
Expand All @@ -937,18 +936,13 @@ def _check_chunk_number(chunk_number: ty.List[int]):
f"but got {chunk_number}"
)

def __add_lineage_to_plugin(
self,
run_id,
plugin,
chunk_number: ty.Optional[ty.Dict[str, ty.List[int]]] = None,
):
def __add_lineage_to_plugin(self, run_id, plugin):
"""Adds lineage to plugin in place.

Also adds parent information in case of a child plugin.

"""
last_provide = [d_provides for d_provides in plugin.provides][-1]
last_provide = plugin.provides[-1]

if plugin.child_plugin:
# Plugin is a child of another plugin, hence we have to
Expand Down Expand Up @@ -984,68 +978,91 @@ def __add_lineage_to_plugin(
if plugin.takes_config[option].track
}

# Set chunk_number in the lineage
if chunk_number is not None:
for d_depends in plugin.depends_on:
dependencies = self.get_dependencies(d_depends) | {d_depends}
for d in chunk_number.keys():
if d not in dependencies:
continue
if issubclass(plugin.__class__, NOT_ALLOWED_PLUGINS):
raise ValueError(
f"Can not load per-chunk storage from {d} for {plugin.__class__} "
f"because it is subclass of one of {NOT_ALLOWED_PLUGINS}!"
)
if d_depends in chunk_number:
if len(plugin.depends_on) > 1:
for d in plugin.depends_on:
dependencies = self.get_dependencies(d) | {d}
msg = (
f"Can not assign chunk_number for {plugin.__class__} "
"because it has multiple dependencies and one of the "
f"dependencies {d} does not (eventually) depend on {d_depends}."
)
mask = d_depends in dependencies
if not mask:
raise ValueError(msg)
# Make sure other dependencies depend on the same per-chunk data_type
for shortest in [False, True]:
levels = {
_d: self.tree_levels[shortest][_d]["level"]
for _d in dependencies
}
mask &= (
len(
[
k
for k, v in levels.items()
if v == levels.get(d_depends, -1)
]
)
== 1
)
if not mask:
raise ValueError(msg)
configs.setdefault("chunk_number", {})
if d_depends in configs["chunk_number"]:
raise ValueError(
f"Chunk number for {d_depends} is already set in the lineage"
)
self._check_chunk_number(chunk_number[d_depends])
plugin.chunk_number = chunk_number[d_depends]
if plugin.compute_takes_chunk_i and plugin.deps[d_depends].rechunk_on_load:
raise ValueError(
"Can not assign chunk_number for a plugin that takes chunk_i as input "
"when dependency's rechunk_on_load is True."
)
configs["chunk_number"][d_depends] = chunk_number[d_depends]

plugin.lineage = {last_provide: (plugin.__class__.__name__, plugin.version(), configs)}

# This is why the lineage of a plugin contains all its dependencies
for d_depends in plugin.depends_on:
plugin.lineage.update(plugin.deps[d_depends].lineage)

def __assign_chunk_number_to_plugin(
    self,
    plugin,
    chunk_number: ty.Optional[ty.Dict[str, ty.List[int]]] = None,
):
    """Assign chunk_number to plugin in place.

    Two things are modified on ``plugin``:

    * ``plugin.chunk_number`` is set from the entry of ``chunk_number`` that belongs to a
      direct dependency of ``plugin`` (if there is one).
    * For every item of ``plugin.lineage`` whose plugin directly depends on a data_type
      listed in ``chunk_number``, the requested chunks are written into the
      ``"chunk_number"`` entry of that lineage item's config dictionary.

    :param plugin: Plugin to which we assign chunk_number
    :param chunk_number: Dictionary with data_type as key and chunk_number as value. If None, do
        nothing.
    :raises ValueError: If ``plugin`` takes chunk_i and either depends directly on more than
        one per-chunk data_type or on a dependency with ``rechunk_on_load`` set; if a plugin
        in the lineage that depends on a per-chunk data_type is a subclass of one of
        ``NOT_PER_CHUNK_ALLOWED_PLUGINS``; if such a plugin has several dependencies that do
        not all go through the per-chunk data_type; or if a chunk number for the same
        data_type is already present in the lineage. ``_check_chunk_number`` may raise as
        well.

    """
    if chunk_number is None:
        return

    # A plugin that takes chunk_i can only follow the chunk numbering of one
    # per-chunk dependency.
    if len(set(plugin.depends_on) & set(chunk_number)) > 1 and plugin.compute_takes_chunk_i:
        raise ValueError(
            "Can not assign chunk_number for a plugin that takes chunk_i as input "
            "when multiple dependencies are per-chunk."
        )

    # Only the direct dependencies of the plugin decide plugin.chunk_number
    for d in plugin.depends_on:
        if d not in chunk_number:
            continue
        # This attribute assignment is needed by p.iter
        plugin.chunk_number = chunk_number[d]
        if plugin.compute_takes_chunk_i and plugin.deps[d].rechunk_on_load:
            raise ValueError(
                "Can not assign chunk_number for a plugin that takes chunk_i as input "
                "when dependency's rechunk_on_load is True."
            )

    # Iterate over the lineage of the plugin and check if chunk_number
    # is needed to be set for the dependencies of the plugin.
    for last_provide in plugin.lineage:
        # NOTE(review): the run_id "0" is hard-coded and no chunk_number is passed;
        # presumably only run-independent attributes (class, depends_on) are read
        # from p - confirm.
        p = self.__get_plugin("0", last_provide)
        # Lineage items that do not directly depend on a per-chunk data_type are left as is
        if not (set(p.depends_on) & set(chunk_number)):
            continue

        if issubclass(p.__class__, NOT_PER_CHUNK_ALLOWED_PLUGINS):
            raise ValueError(
                f"Can not load per-chunk storage from {chunk_number} for {p.__class__} "
                f"because it is subclass of one of {NOT_PER_CHUNK_ALLOWED_PLUGINS}!"
            )

        # Set chunk_number in the lineage
        for d in p.depends_on:
            if d not in chunk_number:
                continue
            self._check_chunk_number(chunk_number[d])
            # Make sure that d is the connector of subplots
            # For details: https://github.com/AxFoundation/strax/pull/996
            if len(p.depends_on) > 1:
                for c in p.depends_on:
                    # c itself together with everything get_dependencies returns for c
                    dependencies = self.get_dependencies(c) | {c}
                    msg = (
                        f"Can not assign chunk_number for {p.__class__} "
                        "because it has multiple dependencies and one of the "
                        f"dependencies {c} does not (eventually) depend on {d}."
                    )
                    mask = d in dependencies
                    if not mask:
                        raise ValueError(msg)
                    # Make sure other dependencies depend on the same per-chunk data_type
                    # d has to be the only data_type at its level for both values of
                    # the tree_levels key (False and True).
                    for shortest in [False, True]:
                        levels = {
                            _d: self.tree_levels[shortest][_d]["level"] for _d in dependencies
                        }
                        mask &= (
                            len([k for k, v in levels.items() if v == levels.get(d, -1)]) == 1
                        )
                    if not mask:
                        raise ValueError(msg)
            # The third element of a lineage item is its config dictionary; it is
            # modified in place here, so plugin.lineage itself changes.
            configs = plugin.lineage[last_provide][2]
            configs.setdefault("chunk_number", {})
            if d in configs["chunk_number"]:
                raise ValueError(f"Chunk number for {d} is already set in the lineage")
            configs["chunk_number"][d] = chunk_number[d]

def _per_run_default_allowed_check(self, option_name, option):
"""Check if an option of a registered plugin is allowed."""
per_run_default = option.default_by_run != strax.OMITTED
Expand Down Expand Up @@ -2079,7 +2096,7 @@ def key_for(self, run_id, target, chunk_number=None, combining=False):
:return: strax.DataKey of the target

"""
if self._plugins_are_cached((target,), chunk_number=chunk_number):
if self._plugins_are_cached((target,)):
context_hash = self._context_hash()
if context_hash in self._fixed_plugin_cache:
plugins = self._fixed_plugin_cache[self._context_hash()]
Expand All @@ -2088,12 +2105,18 @@ def key_for(self, run_id, target, chunk_number=None, combining=False):
self.log.warning(
f"Context hash changed to {context_hash} for {self._plugin_class_registry}?"
)
plugins = self._get_plugins((target,), run_id, chunk_number=chunk_number)
plugins = self._get_plugins((target,), run_id)
else:
plugins = self._get_plugins((target,), run_id)

# Prevent modifying the cached plugin
if chunk_number is not None:
plugin = plugins[target].__copy__(True)
self.__assign_chunk_number_to_plugin(plugin, chunk_number=chunk_number)
else:
plugins = self._get_plugins((target,), run_id, chunk_number=chunk_number)
plugin = plugins[target].__copy__(False)

lineage = plugins[target].lineage
return self.get_data_key(run_id, target, lineage, combining=combining)
return self.get_data_key(run_id, target, plugin.lineage, combining=combining)

def get_metadata(self, run_id, target, chunk_number=None, combining=False) -> dict:
"""Return metadata for target for run_id, or raise DataNotAvailable if data is not yet
Expand Down
3 changes: 0 additions & 3 deletions strax/plugins/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,6 @@ def __copy__(self, _deep_copy=False):
plugin_copy.__setattr__(attribute, copy(source_value))
return plugin_copy

def __deepcopy__(self):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

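This is essentially a bug: `__deepcopy__` never worked. `plugin.deps` is a dictionary of plugins, so `deepcopy(plugin.deps)` inside `__copy__` calls `__deepcopy__` on every dependency as well, and `copy.deepcopy` passes a `memo` argument that this one-argument `__deepcopy__(self)` does not accept.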

return self.__copy__(_deep_copy=True)

def __getattr__(self, name):
"""Allow access to config parameters as attributes this allows backwards compatibility in
cases where a descriptor style config depends on a non descriptor style config."""
Expand Down
2 changes: 1 addition & 1 deletion tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ def test_per_chunk_storage():
# Per-chunk storage not allowed for some plugins
p = type("whatever", (strax.OverlapWindowPlugin,), dict(depends_on="records"))
st.register(p)
with pytest.raises(ValueError):
with pytest.raises(NotImplementedError):
st.make(run_id, "whatever", chunk_number={"records": [0]})


Expand Down