Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 37 additions & 14 deletions src/pygama/evt/tcm.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,13 @@ def _merge_sorted_tcms(a: ak.Array, b: ak.Array, coin_windows_local) -> ak.Array
buffer = None

if table_keys is None:
table_keys = np.arange(0, len(iterators))
table_keys = list(np.arange(0, len(iterators)))

# cache key-mapping helpers once; used to compute last_instance quickly each loop
table_keys_np = np.asarray(table_keys, dtype=np.int64)
_tk_sort = np.argsort(table_keys_np)
_tk_unsort = np.argsort(_tk_sort)
table_keys_sorted = table_keys_np[_tk_sort]

while not at_end.all():
curr_mask = ~skip_mask & ~at_end
Expand Down Expand Up @@ -223,17 +229,41 @@ def _merge_sorted_tcms(a: ak.Array, b: ak.Array, coin_windows_local) -> ak.Array
# grab entries up to the evt that includes the last instance of a channel, so we know
# that all channels have been included in the previous evts
table_key_np = ak.to_numpy(tcm["table_key"])
last_instance = {arr_id: index for index, arr_id in enumerate(table_key_np)}
log.debug(f"last instance: {last_instance}")
row_in_table_np = ak.to_numpy(tcm["row_in_table"]) # reuse later for output

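# Fast last-occurrence computation (no per-hit Python loop):
# map table_key values -> [0..n_keys) via searchsorted on sorted keys,
# then take max index per key with np.maximum.at
# NOTE(review): assumes every value in table_key_np is present in table_keys;
# otherwise searchsorted yields a wrong or out-of-range slot -- confirm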
key_pos = np.searchsorted(table_keys_sorted, table_key_np)
last_sorted = np.full(table_keys_sorted.size, -1, dtype=np.int64)
np.maximum.at(
last_sorted, key_pos, np.arange(table_key_np.size, dtype=np.int64)
)
last_idx = last_sorted[_tk_unsort]

# build dict only for keys that appear in current tcm buffer (small)
present = last_idx >= 0
last_instance = {
int(k): int(v) for k, v in zip(table_keys_np[present], last_idx[present])
}

log.debug(
"tcm progress: tcm_len=%d, unique_table_keys_in_tcm=%d, at_end=%d/%d",
len(table_key_np),
len(last_instance),
int(at_end.sum()),
len(at_end),
)

for i, entry in enumerate(table_keys):
if entry not in last_instance:
last_instance[entry] = np.inf
if at_end[i]:
last_instance[entry] = np.inf

if len(np.array(table_keys)[~at_end]) > 1:
comp_chan = np.array(table_keys)[~at_end][0]
active_keys = table_keys_np[~at_end]
if len(active_keys) > 1:
comp_chan = int(active_keys[0])
skip_mask = np.array(
[(last_instance[arr] >= last_instance[comp_chan]) for arr in table_keys]
)
Expand All @@ -244,25 +274,18 @@ def _merge_sorted_tcms(a: ak.Array, b: ak.Array, coin_windows_local) -> ak.Array

# only write entries up to the last entry of a channel, to ensure all entries are included in the evt
if at_end.all():
log.debug("at end, writing all entries")
write_mask = mask
last_entry = None
else:
last_instance_min = int(np.min([last_instance[arr] for arr in table_keys]))
last_entry = np.where(mask[:last_instance_min])[0]
log.debug(f"last instance: {last_instance_min}")
log.debug(f"last entry: {last_entry}")

if len(last_entry) == 0:
log.debug("last entry 0, going to next iteration")
log.debug(tcm)
continue
else:
last_entry = last_entry[-1] + 1
write_mask = mask[:last_entry]
if len(write_mask) == 0:
log.debug("no entries, going to next iteration")
log.debug(tcm)
continue

# get cumulative_length
Expand All @@ -275,14 +298,14 @@ def _merge_sorted_tcms(a: ak.Array, b: ak.Array, coin_windows_local) -> ak.Array
"table_key",
VectorOfVectors(
cumulative_length=cumulative_length,
flattened_data=ak.to_numpy(tcm["table_key"])[:last_entry],
flattened_data=table_key_np[:last_entry],
),
)
out_tbl.add_field(
"row_in_table",
VectorOfVectors(
cumulative_length=cumulative_length,
flattened_data=ak.to_numpy(tcm["row_in_table"])[:last_entry],
flattened_data=row_in_table_np[:last_entry],
),
)
if fields is not None:
Expand Down