diff --git a/.github/container/nsys_jax/nsys_jax/analysis.py b/.github/container/nsys_jax/nsys_jax/analysis.py
old mode 100755
new mode 100644
index c877863cb..88bf5efcd
--- a/.github/container/nsys_jax/nsys_jax/analysis.py
+++ b/.github/container/nsys_jax/nsys_jax/analysis.py
@@ -28,9 +28,9 @@ def align_profiler_data_timestamps(
     # Error if the communication frame doesn't exist at all, but not if it is empty.
     # Calling this on a profile that does not contain any communication should
     # gracefully yield empty results.
-    assert frames.communication is not None, (
-        "align_profiler_data_timestamps requires a communication frame"
-    )
+    assert (
+        frames.communication is not None
+    ), "align_profiler_data_timestamps requires a communication frame"
     if not len(frames.communication):
         # Nothing to be done, return an empty result
         return frames, {}
@@ -43,9 +43,9 @@ def align_profiler_data_timestamps(
             f"WARNING: cannot align {num_profiled_devices} devices because max collective size is 1"
         )
         return frames, {}
-    assert num_profiled_devices == max_collective_size, (
-        f"Aligning {num_profiled_devices} using collectives of size {max_collective_size} is not implemented"
-    )
+    assert (
+        num_profiled_devices == max_collective_size
+    ), f"Aligning {num_profiled_devices} using collectives of size {max_collective_size} is not implemented"
     # Find the collectives that will be used
     align_df = comm_df[comm_df["CollectiveSize"] == max_collective_size]
     # Calculate the collectives' end times
@@ -189,18 +189,19 @@ def _get_message_size(
 ) -> tuple[int, str, int, float, float]:
     _, inst = module_proto.find_instruction(instruction)
     comm_inst = inst.communication_proto()
-    assert comm_inst.opcode in {
-        "all-gather-start",
-        "all-reduce-start",
-        "all-to-all",
-        "collective-broadcast",
-        "collective-permute-start",
-        "dynamic-slice",
-        "dynamic-update-slice",
-        "reduce-scatter",
-    }, (
-        f"{instruction}: message size calculation for {comm_inst.opcode} has not yet been validated"
-    )
+    assert (
+        comm_inst.opcode
+        in {
+            "all-gather-start",
+            "all-reduce-start",
+            "all-to-all",
+            "collective-broadcast",
+            "collective-permute-start",
+            "dynamic-slice",
+            "dynamic-update-slice",
+            "reduce-scatter",
+        }
+    ), f"{instruction}: message size calculation for {comm_inst.opcode} has not yet been validated"
 
     def _byte_size(inst) -> int:
         size_bits = math.prod(
@@ -254,9 +255,9 @@ def _byte_size(inst) -> int:
         collective_size = iota_group_list.num_devices_per_group
     else:
         collective_sizes = set(len(group.replica_ids) for group in replica_groups)
-        assert len(collective_sizes) == 1, (
-            f"Heterogeneous collective {comm_inst} could not be interpreted"
-        )
+        assert (
+            len(collective_sizes) == 1
+        ), f"Heterogeneous collective {comm_inst} could not be interpreted"
         collective_size = next(iter(collective_sizes))
     total_msg_size = 0
     for operand_id in comm_inst.operand_ids:
@@ -391,10 +392,29 @@ def generate_compilation_statistics(compile_df: pd.DataFrame) -> pd.DataFrame:
         # Assuming there's only one parallel region inside `launcher_row`
         parallel_start = child_df.loc[~is_main, "StartMs"].min()
         parallel_end = child_ends[~is_main].max()
-        # Assert that there are no main-thread tasks during this period
-        main_before = is_main & (child_ends < parallel_start)
-        main_after = is_main & (child_df["StartMs"] > parallel_end)
+        # Check for main-thread tasks that don't overlap with the parallel period
+        main_before = is_main & (child_ends <= parallel_start)
+        main_after = is_main & (child_df["StartMs"] >= parallel_end)
+        # Identify any main-thread tasks that overlap with the parallel period
+        main_overlap = is_main & ~(main_before | main_after)
+        if main_overlap.any():
+            # Some main-thread tasks overlap with the parallel region.
+            # This can happen with NCCL operations or other intermediate operations.
+            # Classify them based on where most of their duration falls.
+            overlap_tasks = child_df.loc[main_overlap]
+            for idx in overlap_tasks.index:
+                task_start = child_df.loc[idx, "StartMs"]
+                task_end = child_ends.loc[idx]
+                # Calculate how much of the task falls before and after the parallel region
+                before_time = max(0, min(task_end, parallel_start) - task_start)
+                after_time = max(0, task_end - max(task_start, parallel_end))
+                # Classify based on which is larger
+                if before_time > after_time:
+                    main_before.loc[idx] = True
+                else:
+                    main_after.loc[idx] = True
         assert ((main_before | main_after) == is_main).all()
+
         # Aggregate statistics for how the worker threads spend their time and use that
         # distribution to divide up the [parallel_start, parallel_end] range of the overall
         # compilation time.
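
The duration-based classification added to generate_compilation_statistics can be
exercised in isolation. Below is a minimal sketch with made-up data: the starts/ends
Series and the [10.0, 20.0] window are hypothetical stand-ins for the main-thread
task timestamps and the parallel region, not values taken from nsys_jax.

import pandas as pd

# Hypothetical main-thread task intervals (ms); task 1 straddles the start of
# the parallel region and task 2 straddles its end.
starts = pd.Series([0.0, 8.0, 19.0, 25.0])
ends = pd.Series([5.0, 12.0, 26.0, 30.0])
parallel_start, parallel_end = 10.0, 20.0

before = ends <= parallel_start
after = starts >= parallel_end
# Attribute each straddling task to whichever side holds more of its duration
for idx in starts.index[(~(before | after)).to_numpy()]:
    before_time = max(0.0, min(ends[idx], parallel_start) - starts[idx])
    after_time = max(0.0, ends[idx] - max(starts[idx], parallel_end))
    if before_time > after_time:
        before[idx] = True
    else:
        after[idx] = True

assert (before | after).all()
print(before.tolist())  # [True, True, False, False]
print(after.tolist())   # [False, False, True, True]

A task lying entirely inside the region is still classified (as "after", since both
overhangs are zero), which is what keeps the subsequent
(main_before | main_after) == is_main assertion from firing.
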
diff --git a/.github/container/nsys_jax/nsys_jax/data_loaders.py b/.github/container/nsys_jax/nsys_jax/data_loaders.py
index 7f7e7cca3..5f66f6c44 100644
--- a/.github/container/nsys_jax/nsys_jax/data_loaders.py
+++ b/.github/container/nsys_jax/nsys_jax/data_loaders.py
@@ -249,17 +249,50 @@ def _load_nvtx_gpu_proj_trace_single(
         # get the names of the ranges referred to by ModuleId
         mod_id_names = df.loc[mod_ids, "Name"]
         assert mod_ids.shape == mod_id_names.shape
+
         # Get a mask in mod_id_names of entries where ModuleId in the original
         # Thunk is not referring to a Module yet. Intermediate levels of the
         # hierarchy can be other thunks (e.g. an individual graph node may
-        # have a thunk representing the whole graph as a parent).
+        # have a thunk representing the whole graph as a parent) or other
+        # range types like NCCL operations.
         mask = ~mod_id_names.str.startswith(module_prefix)
-        assert (mask == mod_id_names.str.startswith(thunk_prefix)).all()
+
+        # Identify which non-module entries are thunks vs other range types
+        is_thunk = mod_id_names.str.startswith(thunk_prefix)
+        is_module = mod_id_names.str.startswith(module_prefix)
+
+        # Assert that we only have modules, thunks, or known intermediate ranges.
+        # This catches unexpected range types in the hierarchy.
+        is_recognized = (
+            is_thunk | is_module | mod_id_names.str.contains("nccl", case=False)
+        )
+        assert is_recognized.all(), f"Found unrecognized range types in hierarchy: {mod_id_names[~is_recognized].unique()}"
+
+        # Assert that mask is consistent with module detection
+        assert (mask == ~is_module).all(), "Mask inconsistency with module detection"
+
+        # Convert to numpy arrays for cross-indexing operations
+        # (mod_ids and mask have different pandas indices)
+        mod_ids_array = mod_ids.values
+        mask_array = mask.values
+        is_thunk_array = is_thunk.values
+
+        # Assert that all non-module entries have valid parent IDs to continue navigation
+        non_module_mod_ids = mod_ids_array[mask_array]
+        assert (
+            df.loc[non_module_mod_ids, "ParentId"].notna().all()
+        ), "Found non-module entries without valid parent IDs, cannot navigate up hierarchy"
+
         assert mask.shape == mod_ids.shape
+
         # We want to end up without all_thunks containing thunks with child
-        # thunks, as noted above.
-        thunk_ids_with_child_thunks = mod_ids.array[mask]
-        all_thunks[thunk_ids_with_child_thunks] = False
+        # thunks, as noted above. Only filter out thunks, not other range types.
+        thunk_ids_with_child_thunks = mod_ids_array[mask_array & is_thunk_array]
+        # Only update indices that actually exist in all_thunks to prevent reindexing
+        existing_indices = all_thunks.index.intersection(thunk_ids_with_child_thunks)
+        if len(existing_indices) > 0:
+            all_thunks[existing_indices] = False
+
         # Set thunk_ids to be the (shorter) list of indices (in df) of the
         # Thunks whose ModuleId values need to be updated
         thunk_ids = thunk_ids[mask]
@@ -271,6 +304,10 @@ def _load_nvtx_gpu_proj_trace_single(
 
     # Now all the Thunks should have ModuleId pointing to an XlaModule range.
     mod_ids = sorted(set(df.loc[all_thunks, "ModuleId"].astype(np.int32)))
+
+    # Ensure all_thunks only contains indices that exist in df
+    all_thunks = all_thunks.reindex(df.index, fill_value=False)
+
     assert df.loc[all_thunks, "Name"].str.startswith(thunk_prefix).all()
     assert df.loc[mod_ids, "Name"].str.startswith(module_prefix).all()
 
@@ -397,6 +434,10 @@ def clean_data_frame(d):
     if "thunk" in frames:
         # At this point there should be no need to look beyond the rows for individual
         # thunks + the protobuf data, and we can further clean up the data.
+
+        # Ensure all_thunks is aligned with df one final time before using it
+        all_thunks = all_thunks.reindex(df.index, fill_value=False)
+
         thunk_df = clean_data_frame(df[all_thunks])
         thunk_df["Name"] = thunk_df["Name"].str.replace(
             pat=f"^{tsl_prefix}Thunk:#(?:name=.*?,|)hlo_op=([a-z0-9._-]+)#$",
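
The intersection/reindex pattern introduced in data_loaders.py is easy to get wrong
because pandas aligns boolean indexers on labels rather than positions. Below is a
minimal sketch of the failure mode the two changes guard against; df, all_thunks,
and ids_to_clear are hypothetical stand-ins, not real nsys_jax data.

import pandas as pd

df = pd.DataFrame(
    {"Name": ["TSL:Thunk:#a#", "TSL:XlaModule:#m#", "nccl-kernel"]},
    index=[10, 11, 12],
)
# Boolean row mask for df that is missing label 12
all_thunks = pd.Series([True, False], index=[10, 11])

# Clearing entries via Index.intersection never raises on absent labels (99 here)
ids_to_clear = [10, 99]
existing = all_thunks.index.intersection(ids_to_clear)
if len(existing) > 0:
    all_thunks[existing] = False

# Realign with df.index before using the mask as a row filter; without this,
# df[all_thunks] raises IndexingError because the boolean Series is unalignable
all_thunks = all_thunks.reindex(df.index, fill_value=False)
print(df[all_thunks])  # empty: row 10 was cleared, rows 11 and 12 were never thunks
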