diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..6c75ad4
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/pipegoose.iml b/.idea/pipegoose.iml
new file mode 100644
index 0000000..d0876a7
--- /dev/null
+++ b/.idea/pipegoose.iml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/pipegoose/nn/parallel_mapping.py b/pipegoose/nn/parallel_mapping.py
index 108e566..58c47d6 100644
--- a/pipegoose/nn/parallel_mapping.py
+++ b/pipegoose/nn/parallel_mapping.py
@@ -8,12 +8,51 @@ def __init__(self, module_name: Tuple[str], **kwargs):
 
 
 class ParallelMapping:
+
+    def __init__(self, model):
+        # NOTE: requires `from transformers.utils.fx import symbolic_trace` at module level
+        traced = symbolic_trace(model, input_names=["input_ids", "attention_mask"])
+        self.node_list = list(traced.graph.nodes)
+        self.model = model
+
+    def extract_module_from_node(self, node):
+        """Resolve a traced graph node back to its submodule, or return None."""
+        # Only call_module nodes carry a dotted module path as their target
+        if not isinstance(node.target, str):
+            return None
+
+        # Split the target string into components
+        target_path = node.target.split('.')
+
+        # Traverse the model hierarchy based on the target path
+        current_module = self.model
+        try:
+            for attr in target_path:
+                current_module = getattr(current_module, attr)
+        except AttributeError:
+            return None
+
+        return current_module
+
+    def extract_node_target_from_module(self, submodule, root=None, prefix=''):
+        """Return the dotted path of `submodule` inside the model, or None if not found."""
+        root = self.model if root is None else root
+        for name, module in root.named_children():
+            if module is submodule:
+                return f'{prefix}{name}'
+            # Recurse into children, extending the dotted prefix
+            submodule_path = self.extract_node_target_from_module(submodule, root=module, prefix=f'{prefix}{name}.')
+            if submodule_path is not None:
+                return submodule_path
+        return None
+
+    @staticmethod
+    def _extract_module_name(module_name: str) -> str:
+        if "." in module_name:
+            # NOTE: transformer.h.0.self_attention.dense -> self_attention.dense
+            SEPARATOR = "."
+            sections = module_name.split(SEPARATOR)
+            return SEPARATOR.join(sections[-2:])
+        return module_name
+
     @staticmethod
     def _search(module_name: str) -> Optional[ParallelInfo]:
         """
         Search for module_name in mappings.
         """
         module_name = ParallelMapping._extract_module_name(module_name)
         for child_class in ParallelMapping.__subclasses__():
             if hasattr(child_class, "__MAPPING__"):
                 for items in child_class.__MAPPING__.values():
@@ -25,13 +64,96 @@ def _search(module_name: str) -> Optional[ParallelInfo]:
                 break
         return None
 
+    def is_column_parallel(self, node_target) -> bool:
+        """Returns True iff the module is the first linear layer in an MLP layer,
+        a query, key, value linear or a fused qkv linear of an attention layer,
+        or an input embedding."""
+        if not isinstance(node_target, str):
+            return False
+
+        # Check if the node is the first linear layer in an MLP layer
+        if node_target.endswith('mlp.dense_h_to_4h'):
+            return True
+
+        # Check if the node is the fused QKV linear layer of an attention layer
+        if 'self_attention.query_key_value' in node_target:
+            return True
+
+        # Check if the node is part of the input embedding layer
+        if 'word_embeddings' in node_target:
+            return True
+
+        return False
+
+    def is_row_parallel(self, node_target) -> bool:
+        """Returns True iff the module is the second linear layer in an MLP layer,
+        or the output projection of an attention layer."""
+        if not isinstance(node_target, str):
+            return False
+
+        # Check if the node is the second linear layer in an MLP layer
+        if node_target.endswith('mlp.dense_4h_to_h'):
+            return True
+
+        # Check if the node is the output projection of an attention layer
+        if node_target.endswith('self_attention.dense'):
+            return True
+
+        return False
+
+    def is_lm_head(self, node_target) -> bool:
+        """Returns True iff the module is the language model head."""
+        return isinstance(node_target, str) and 'lm_head' in node_target
+
+    def is_text_embedding(self, node_target) -> bool:
+        """Returns True iff the module is a text embedding module."""
+        return isinstance(node_target, str) and 'embeddings' in node_target
+
+
+if __name__ == "__main__":
+    # Smoke test against bigscience/bloom-560m
+    from transformers import AutoModelForCausalLM, AutoTokenizer
+    from transformers.utils.fx import symbolic_trace  # binds symbolic_trace for ParallelMapping.__init__ when run as a script
+
+    tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
+    model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
+    inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+
+    pm = ParallelMapping(model)
+    row_parallels = []
+    column_parallels = []
+    lm_heads = []
+    text_embeddings = []
+    for module_name, module in pm.model.named_modules():
+        node_target = pm.extract_node_target_from_module(module)
+        if node_target is None:
+            continue
+        if pm.is_row_parallel(node_target):
+            row_parallels.append(module_name)
+        if pm.is_column_parallel(node_target):
+            column_parallels.append(module_name)
+        if pm.is_lm_head(node_target):
+            lm_heads.append(module_name)
+        if pm.is_text_embedding(node_target):
+            text_embeddings.append(module_name)
+
+    assert len(row_parallels) == 48
+    assert len(column_parallels) == 50
+    assert len(lm_heads) == 1
+    assert len(text_embeddings) == 2
 
-    @staticmethod
-    def _extract_module_name(module_name: str) -> str:
-        if "." in module_name:
-            # NOTE: transformer.h.0.self_attention.dense -> self_attention.dense
-            SEPARATOR = "."
-            sections = module_name.split(SEPARATOR)
-            return SEPARATOR.join(sections[-2:])
-        return module_name