diff --git a/src/pygama/evt/tcm.py b/src/pygama/evt/tcm.py index fa8c537fa..930e04739 100644 --- a/src/pygama/evt/tcm.py +++ b/src/pygama/evt/tcm.py @@ -158,7 +158,13 @@ def _merge_sorted_tcms(a: ak.Array, b: ak.Array, coin_windows_local) -> ak.Array buffer = None if table_keys is None: - table_keys = np.arange(0, len(iterators)) + table_keys = list(np.arange(0, len(iterators))) + + # cache key-mapping helpers once; used to compute last_instance quickly each loop + table_keys_np = np.asarray(table_keys, dtype=np.int64) + _tk_sort = np.argsort(table_keys_np) + _tk_unsort = np.argsort(_tk_sort) + table_keys_sorted = table_keys_np[_tk_sort] while not at_end.all(): curr_mask = ~skip_mask & ~at_end @@ -223,8 +229,31 @@ def _merge_sorted_tcms(a: ak.Array, b: ak.Array, coin_windows_local) -> ak.Array # grab up to evt including last instance of a channel to know that all channels # have been included in previous evts table_key_np = ak.to_numpy(tcm["table_key"]) - last_instance = {arr_id: index for index, arr_id in enumerate(table_key_np)} - log.debug(f"last instance: {last_instance}") + row_in_table_np = ak.to_numpy(tcm["row_in_table"]) # reuse later for output + + # Fast last-occurrence computation (no per-hit Python loop): + # map table_key values -> [0..n_keys) via searchsorted on sorted keys, + # then take max index per key with np.maximum.at + key_pos = np.searchsorted(table_keys_sorted, table_key_np) + last_sorted = np.full(table_keys_sorted.size, -1, dtype=np.int64) + np.maximum.at( + last_sorted, key_pos, np.arange(table_key_np.size, dtype=np.int64) + ) + last_idx = last_sorted[_tk_unsort] + + # build dict only for keys that appear in current tcm buffer (small) + present = last_idx >= 0 + last_instance = { + int(k): int(v) for k, v in zip(table_keys_np[present], last_idx[present]) + } + + log.debug( + "tcm progress: tcm_len=%d, unique_table_keys_in_tcm=%d, at_end=%d/%d", + len(table_key_np), + len(last_instance), + int(at_end.sum()), + len(at_end), + ) for i, entry in enumerate(table_keys): if entry not in last_instance: @@ -232,8 +261,9 @@ def _merge_sorted_tcms(a: ak.Array, b: ak.Array, coin_windows_local) -> ak.Array if at_end[i]: last_instance[entry] = np.inf - if len(np.array(table_keys)[~at_end]) > 1: - comp_chan = np.array(table_keys)[~at_end][0] + active_keys = table_keys_np[~at_end] + if len(active_keys) > 1: + comp_chan = int(active_keys[0]) skip_mask = np.array( [(last_instance[arr] >= last_instance[comp_chan]) for arr in table_keys] ) @@ -244,25 +274,18 @@ def _merge_sorted_tcms(a: ak.Array, b: ak.Array, coin_windows_local) -> ak.Array # want to write entries only up to last entry of a channel to ensure all included in evt if at_end.all(): - log.debug("at end, writing all entries") write_mask = mask last_entry = None else: last_instance_min = int(np.min([last_instance[arr] for arr in table_keys])) last_entry = np.where(mask[:last_instance_min])[0] - log.debug(f"last instance: {last_instance_min}") - log.debug(f"last entry: {last_entry}") if len(last_entry) == 0: - log.debug("last entry 0, going to next iteration") - log.debug(tcm) continue else: last_entry = last_entry[-1] + 1 write_mask = mask[:last_entry] if len(write_mask) == 0: - log.debug("no entries, going to next iteration") - log.debug(tcm) continue # get cumulative_length @@ -275,14 +298,14 @@ def _merge_sorted_tcms(a: ak.Array, b: ak.Array, coin_windows_local) -> ak.Array "table_key", VectorOfVectors( cumulative_length=cumulative_length, - flattened_data=ak.to_numpy(tcm["table_key"])[:last_entry], + flattened_data=table_key_np[:last_entry], ), ) out_tbl.add_field( "row_in_table", VectorOfVectors( cumulative_length=cumulative_length, - flattened_data=ak.to_numpy(tcm["row_in_table"])[:last_entry], + flattened_data=row_in_table_np[:last_entry], ), ) if fields is not None: