diff --git a/strax/context.py b/strax/context.py index 40b92a7b7..08217055e 100644 --- a/strax/context.py +++ b/strax/context.py @@ -24,7 +24,7 @@ RUN_DEFAULTS_KEY = "strax_defaults" TEMP_DATA_TYPE_PREFIX = "_temp_" -NOT_ALLOWED_PLUGINS = (strax.LoopPlugin, strax.OverlapWindowPlugin) +NOT_PER_CHUNK_ALLOWED_PLUGINS = (strax.LoopPlugin, strax.OverlapWindowPlugin) # use tqdm as loaded in utils (from tqdm.notebook when in a jupyter env) tqdm = strax.utils.tqdm @@ -740,17 +740,9 @@ def _context_hash(self): ) return strax.deterministic_hash(_base_hash_on_config) - def _plugins_are_cached( - self, - targets: ty.Union[ty.Tuple[str], ty.List[str]], - chunk_number: ty.Optional[ty.Dict[str, ty.List[int]]] = None, - ) -> bool: + def _plugins_are_cached(self, targets: ty.Union[ty.Tuple[str], ty.List[str]]) -> bool: """Check if all the requested targets are in the _fixed_plugin_cache.""" - if ( - self.context_config["use_per_run_defaults"] - or self._fixed_plugin_cache is None - or chunk_number is not None - ): + if self.context_config["use_per_run_defaults"] or self._fixed_plugin_cache is None: # There is no point in caching if plugins (lineage) can # change per run or the cache is empty. return False @@ -761,12 +753,8 @@ def _plugins_are_cached( plugin_cache = self._fixed_plugin_cache[context_hash] return all([t in plugin_cache for t in targets]) - def _plugins_to_cache( - self, - plugins: dict, - chunk_number: ty.Optional[ty.Dict[str, ty.List[int]]] = None, - ) -> None: - if self.context_config["use_per_run_defaults"] or chunk_number is not None: + def _plugins_to_cache(self, plugins: dict) -> None: + if self.context_config["use_per_run_defaults"]: # There is no point in caching if plugins (lineage) can change per run return context_hash = self._context_hash() @@ -867,9 +855,15 @@ def __get_plugin( ): """Get single plugin either from cache or initialize it.""" # Check if plugin for data_type is already cached - if self._plugins_are_cached((data_type,), chunk_number=chunk_number): + if self._plugins_are_cached((data_type,)): cached_plugins = self.__get_requested_plugins_from_cache(run_id, (data_type,)) - target_plugin = cached_plugins[data_type] + if chunk_number is not None: + target_plugin = cached_plugins[data_type].__copy__(True) + self.__assign_chunk_number_to_plugin(target_plugin, chunk_number=chunk_number) + target_plugin.run_id = run_id + target_plugin.fix_dtype() + else: + target_plugin = cached_plugins[data_type] return target_plugin if data_type not in self._plugin_class_registry: @@ -884,8 +878,7 @@ def __get_plugin( self._set_plugin_config(plugin, run_id, tolerant=True) plugin.deps = { - d_depends: self.__get_plugin(run_id, d_depends, chunk_number=chunk_number) - for d_depends in plugin.depends_on + d_depends: self.__get_plugin(run_id, d_depends) for d_depends in plugin.depends_on } if plugin.compute_takes_chunk_i: for k, v in plugin.deps.items(): @@ -900,7 +893,7 @@ def __get_plugin( "which is not supported." ) - self.__add_lineage_to_plugin(run_id, plugin, chunk_number=chunk_number) + self.__add_lineage_to_plugin(run_id, plugin) if not hasattr(plugin, "data_kind") and not plugin.multi_output: if len(plugin.depends_on): @@ -915,11 +908,17 @@ def __get_plugin( plugin.fix_dtype() # Add plugin to cache - self._plugins_to_cache( - {data_type: plugin for data_type in plugin.provides}, chunk_number=chunk_number - ) + self._plugins_to_cache({data_type: plugin for data_type in plugin.provides}) - return plugin + if chunk_number is not None: + target_plugin = plugin.__copy__(True) + self.__assign_chunk_number_to_plugin(target_plugin, chunk_number=chunk_number) + target_plugin.run_id = run_id + target_plugin.fix_dtype() + else: + target_plugin = plugin + + return target_plugin @staticmethod def _check_chunk_number(chunk_number: ty.List[int]): @@ -937,18 +936,13 @@ def _check_chunk_number(chunk_number: ty.List[int]): f"but got {chunk_number}" ) - def __add_lineage_to_plugin( - self, - run_id, - plugin, - chunk_number: ty.Optional[ty.Dict[str, ty.List[int]]] = None, - ): + def __add_lineage_to_plugin(self, run_id, plugin): """Adds lineage to plugin in place. Also adds parent infromation in case of a child plugin. """ - last_provide = [d_provides for d_provides in plugin.provides][-1] + last_provide = plugin.provides[-1] if plugin.child_plugin: # Plugin is a child of another plugin, hence we have to @@ -984,68 +978,91 @@ def __add_lineage_to_plugin( if plugin.takes_config[option].track } - # Set chunk_number in the lineage - if chunk_number is not None: - for d_depends in plugin.depends_on: - dependencies = self.get_dependencies(d_depends) | {d_depends} - for d in chunk_number.keys(): - if d not in dependencies: - continue - if issubclass(plugin.__class__, NOT_ALLOWED_PLUGINS): - raise ValueError( - f"Can not load per-chunk storage from {d} for {plugin.__class__} " - f"because it is subclass of one of {NOT_ALLOWED_PLUGINS}!" - ) - if d_depends in chunk_number: - if len(plugin.depends_on) > 1: - for d in plugin.depends_on: - dependencies = self.get_dependencies(d) | {d} - msg = ( - f"Can not assign chunk_number for {plugin.__class__} " - "because it has multiple dependencies and one of the " - f"dependencies {d} does not (eventually) depend on {d_depends}." - ) - mask = d_depends in dependencies - if not mask: - raise ValueError(msg) - # Make sure other dependencies depend on the same per-chunk data_type - for shortest in [False, True]: - levels = { - _d: self.tree_levels[shortest][_d]["level"] - for _d in dependencies - } - mask &= ( - len( - [ - k - for k, v in levels.items() - if v == levels.get(d_depends, -1) - ] - ) - == 1 - ) - if not mask: - raise ValueError(msg) - configs.setdefault("chunk_number", {}) - if d_depends in configs["chunk_number"]: - raise ValueError( - f"Chunk number for {d_depends} is already set in the lineage" - ) - self._check_chunk_number(chunk_number[d_depends]) - plugin.chunk_number = chunk_number[d_depends] - if plugin.compute_takes_chunk_i and plugin.deps[d_depends].rechunk_on_load: - raise ValueError( - "Can not assign chunk_number for a plugin that takes chunk_i as input " - "when dependency's rechunk_on_load is True." - ) - configs["chunk_number"][d_depends] = chunk_number[d_depends] - plugin.lineage = {last_provide: (plugin.__class__.__name__, plugin.version(), configs)} # This is why the lineage of a plugin contains all its dependencies for d_depends in plugin.depends_on: plugin.lineage.update(plugin.deps[d_depends].lineage) + def __assign_chunk_number_to_plugin( + self, + plugin, + chunk_number: ty.Optional[ty.Dict[str, ty.List[int]]] = None, + ): + """Assign chunk_number to plugin in place. + + :param plugin: Plugin to which we assign chunk_number + :param chunk_number: Dictionary with data_type as key and chunk_number as value. If None, do + nothing. + + """ + if chunk_number is None: + return + + if len(set(plugin.depends_on) & set(chunk_number)) > 1 and plugin.compute_takes_chunk_i: + raise ValueError( + "Can not assign chunk_number for a plugin that takes chunk_i as input " + "when multiple dependencies are per-chunk." + ) + + for d in plugin.depends_on: + if d not in chunk_number: + continue + # This attribute assignment is needed by p.iter + plugin.chunk_number = chunk_number[d] + if plugin.compute_takes_chunk_i and plugin.deps[d].rechunk_on_load: + raise ValueError( + "Can not assign chunk_number for a plugin that takes chunk_i as input " + "when dependency's rechunk_on_load is True." + ) + + # Iterate over the lineage of the plugin and check if chunk_number + # is needed to be set for the dependencies of the plugin. + for last_provide in plugin.lineage: + p = self.__get_plugin("0", last_provide) + if not (set(p.depends_on) & set(chunk_number)): + continue + + if issubclass(p.__class__, NOT_PER_CHUNK_ALLOWED_PLUGINS): + raise ValueError( + f"Can not load per-chunk storage from {chunk_number} for {p.__class__} " + f"because it is subclass of one of {NOT_PER_CHUNK_ALLOWED_PLUGINS}!" + ) + + # Set chunk_number in the lineage + for d in p.depends_on: + if d not in chunk_number: + continue + self._check_chunk_number(chunk_number[d]) + # Make sure that d is the connector of subplots + # For details: https://github.com/AxFoundation/strax/pull/996 + if len(p.depends_on) > 1: + for c in p.depends_on: + dependencies = self.get_dependencies(c) | {c} + msg = ( + f"Can not assign chunk_number for {p.__class__} " + "because it has multiple dependencies and one of the " + f"dependencies {c} does not (eventually) depend on {d}." + ) + mask = d in dependencies + if not mask: + raise ValueError(msg) + # Make sure other dependencies depend on the same per-chunk data_type + for shortest in [False, True]: + levels = { + _d: self.tree_levels[shortest][_d]["level"] for _d in dependencies + } + mask &= ( + len([k for k, v in levels.items() if v == levels.get(d, -1)]) == 1 + ) + if not mask: + raise ValueError(msg) + configs = plugin.lineage[last_provide][2] + configs.setdefault("chunk_number", {}) + if d in configs["chunk_number"]: + raise ValueError(f"Chunk number for {d} is already set in the lineage") + configs["chunk_number"][d] = chunk_number[d] + def _per_run_default_allowed_check(self, option_name, option): """Check if an option of a registered plugin is allowed.""" per_run_default = option.default_by_run != strax.OMITTED @@ -2079,7 +2096,7 @@ def key_for(self, run_id, target, chunk_number=None, combining=False): :return: strax.DataKey of the target """ - if self._plugins_are_cached((target,), chunk_number=chunk_number): + if self._plugins_are_cached((target,)): context_hash = self._context_hash() if context_hash in self._fixed_plugin_cache: plugins = self._fixed_plugin_cache[self._context_hash()] @@ -2088,12 +2105,18 @@ def key_for(self, run_id, target, chunk_number=None, combining=False): self.log.warning( f"Context hash changed to {context_hash} for {self._plugin_class_registry}?" ) - plugins = self._get_plugins((target,), run_id, chunk_number=chunk_number) + plugins = self._get_plugins((target,), run_id) + else: + plugins = self._get_plugins((target,), run_id) + + # Prevent modifying the cached plugin + if chunk_number is not None: + plugin = plugins[target].__copy__(True) + self.__assign_chunk_number_to_plugin(plugin, chunk_number=chunk_number) else: - plugins = self._get_plugins((target,), run_id, chunk_number=chunk_number) + plugin = plugins[target].__copy__(False) - lineage = plugins[target].lineage - return self.get_data_key(run_id, target, lineage, combining=combining) + return self.get_data_key(run_id, target, plugin.lineage, combining=combining) def get_metadata(self, run_id, target, chunk_number=None, combining=False) -> dict: """Return metadata for target for run_id, or raise DataNotAvailable if data is not yet diff --git a/strax/plugins/plugin.py b/strax/plugins/plugin.py index f5cf372d3..3a725637b 100644 --- a/strax/plugins/plugin.py +++ b/strax/plugins/plugin.py @@ -196,9 +196,6 @@ def __copy__(self, _deep_copy=False): plugin_copy.__setattr__(attribute, copy(source_value)) return plugin_copy - def __deepcopy__(self): - return self.__copy__(_deep_copy=True) - def __getattr__(self, name): """Allow access to config parameters as attributes this allows backwards compatibility in cases where a descriptor style config depends on a non descriptor style config.""" diff --git a/tests/test_core.py b/tests/test_core.py index a13b87d67..6e1c8e19d 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -514,7 +514,7 @@ def test_per_chunk_storage(): # Per-chunk storage not allowed for some plugins p = type("whatever", (strax.OverlapWindowPlugin,), dict(depends_on="records")) st.register(p) - with pytest.raises(ValueError): + with pytest.raises(NotImplementedError): st.make(run_id, "whatever", chunk_number={"records": [0]})