diff --git a/examples/megatron_example.py b/examples/megatron_example.py
index 3d49895..c953cce 100644
--- a/examples/megatron_example.py
+++ b/examples/megatron_example.py
@@ -79,16 +79,32 @@ def tokenize_fn(prompts):
         tensor_parallel_size=2,
     )
 
-    # Create a hook function
-    module_name = "layers.28.feed_forward.w3"
-    hook_function = HookFunction(
-        module_name=module_name,
+    # Define some editing functions
+    def do_something(current_module, inputs, save_ctx, modules):
+        save_ctx.my_tensor = inputs
+        print(f"Saving tensor: {hash(inputs)}")
+        return inputs * 20
+
+    def do_another_thing(current_module, inputs, save_ctx, modules):
+        past_tensor = save_ctx.my_tensor
+        print(f"Retrieved tensor: {hash(past_tensor)}")
+        return inputs + past_tensor
+
+    # Define hook functions
+    my_hook_function_1 = HookFunction(
+        module_name="layers.20.feed_forward.w3",
         expected_shape=(None, None, 13824),
-        editing_function=None,
+        editing_function=do_something,
+    )
+    my_hook_function_2 = HookFunction(
+        module_name="layers.28.feed_forward.w3",
+        expected_shape=(None, None, 13824),
+        editing_function=do_another_thing,
     )
 
     # Register hook function with the model
-    model.register_forward_hook(hook_function)
+    model.register_forward_hook(my_hook_function_1)
+    model.register_forward_hook(my_hook_function_2)
 
     # Tokenize a prompt
     inputs = tokenize_fn(prompts)
@@ -99,12 +115,19 @@ def tokenize_fn(prompts):
     # Activations are only dumped to main process. Activations per-module key
     # are accumulated in a list.
     if torch.distributed.get_rank() == 0:
-        activation = activation_dict[module_name][0]
-        print(f"Activation shape: {activation.shape}")
-        print(activation)
+        activation_1 = activation_dict["layers.20.feed_forward.w3"][0]
+        print(f"Activation shape: {activation_1.shape}")
+        print(activation_1)
+
+        assert activation_1.shape[0] == 4
+        assert activation_1.shape[-1] == 13824
+
+        activation_2 = activation_dict["layers.28.feed_forward.w3"][0]
+        print(f"Activation shape: {activation_2.shape}")
+        print(activation_2)
 
-        assert activation.shape[0] == 4
-        assert activation.shape[-1] == 13824
+        assert activation_2.shape[0] == 4
+        assert activation_2.shape[-1] == 13824
 
 
 if __name__ == "__main__":
diff --git a/profiling/utils.py b/profiling/utils.py
index 2b76374..f726856 100644
--- a/profiling/utils.py
+++ b/profiling/utils.py
@@ -55,9 +55,11 @@ def __init__(self, model_dim, n_layers, is_distributed, config):
                 RowParallelLinear(
                     model_dim,
                     model_dim,
-                    input_is_parallel=True,
                     config=config,
                     init_method=init_method,
+                    bias=False,
+                    input_is_parallel=True,
+                    skip_bias_add=False,
                 ),
             )
             for _ in range(self.n_layers // 2)