|
27 | 27 | from . import max_logging |
28 | 28 | from . import max_utils |
29 | 29 | from .models.wan.wan_utils import CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH, WAN_21_FUSION_X_MODEL_NAME_OR_PATH |
30 | | -from maxdiffusion.common_types import LENGTH, KV_LENGTH |
| 30 | +from maxdiffusion.common_types import LENGTH, KV_LENGTH, RING_ATTENTION_AXIS_RULES, SELF_ATTN_HEAD, SELF_ATTN_KV_LENGTH, SELF_ATTN_Q_LENGTH, CROSS_ATTN_HEAD, CROSS_ATTN_KV_LENGTH, CROSS_ATTN_Q_LENGTH |
31 | 31 |
|
32 | 32 |
|
33 | 33 | def string_to_bool(s: str) -> bool: |
@@ -180,14 +180,22 @@ def user_init(raw_keys): |
180 | 180 | raw_keys["logical_axis_rules"] = _lists_to_tuples(raw_keys["logical_axis_rules"]) |
181 | 181 | # Verify qkv is sharded across sequence. |
182 | 182 | if raw_keys["attention"] == "ring": |
| 183 | + max_logging.log("Using ring attention, adding sequence sharding to q and kv if not already present.") |
183 | 184 | logical_axis_rules = list(raw_keys["logical_axis_rules"]) |
| 185 | + max_logging.log(f"Initial logical axis rules: {logical_axis_rules}") |
| 186 | + new_rules = [] |
184 | 187 | q_seq_sharding = (LENGTH, "fsdp") |
185 | 188 | kv_seq_sharding = (KV_LENGTH, "fsdp") |
186 | 189 | if q_seq_sharding not in logical_axis_rules: |
187 | 190 | logical_axis_rules.append(q_seq_sharding) |
188 | 191 | if kv_seq_sharding not in logical_axis_rules: |
189 | 192 | logical_axis_rules.append(kv_seq_sharding) |
190 | | - raw_keys["logical_axis_rules"] = tuple(logical_axis_rules) |
| 193 | + for ring_attention_axis_rule in RING_ATTENTION_AXIS_RULES: |
| 194 | + if ring_attention_axis_rule not in logical_axis_rules: |
| 195 | + max_logging.log(f"Adding ring attention axis rule {ring_attention_axis_rule}") |
| 196 | + new_rules.append(ring_attention_axis_rule) |
| 197 | + raw_keys["logical_axis_rules"] = tuple(new_rules) + tuple(logical_axis_rules) |
| 198 | + max_logging.log(f"Final logical axis rules: {raw_keys['logical_axis_rules']}") |
191 | 199 |
|
192 | 200 | raw_keys["data_sharding"] = _lists_to_tuples(raw_keys["data_sharding"]) |
193 | 201 |
|
|
0 commit comments