[feat] Hybrid Mamba model with Mamba and discrete Mamba 2 layers #194
base: main
Conversation
    hint=FieldHint.core,
)

default_block: str = Field(
Redundant with `block_pattern`.
Currently, this is needed to load the Llamba model: we set `default_block` to `m2` in `_create_config_converters` of `LLambaHuggingfaceCheckpointHandler`, and the `block_pattern` is then created in the `__post_init__` of `HybridSSMBaseModelConfig`. This is a bit cumbersome indeed.
Why can't it set `block_pattern`? Also, this would need to go in `_validate`, not `__post_init__`.
AFAIU, `_create_config_converters` does not know about `num_layers` in the loaded config, so we cannot set the block pattern there, since it depends on the number of layers.
Moved the `__post_init__` logic for the block config to `_validate`.
Oh, so `block_pattern` is not specifying a repeated pattern but the entire list? Why not just repeat it up to `num_layers` instead, as the name suggests?
Renamed `block_pattern` to `hybrid_block_layout`; `default_block` is now repeated `num_layers` times in case `hybrid_block_layout` is not specified.
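For readers following along, here is a rough, self-contained sketch of that behaviour (the class, field, and method names are simplified assumptions, not the actual Fast-LLM config classes): when `hybrid_block_layout` is not given, `default_block` is repeated `num_layers` times at validation time.

```python
import dataclasses


@dataclasses.dataclass
class HybridLayoutSketch:
    num_layers: int
    default_block: str = "m"
    hybrid_block_layout: list[str] | None = None

    def validate(self) -> None:
        # Fill in the layout from the default block if the user did not specify one.
        if self.hybrid_block_layout is None:
            self.hybrid_block_layout = [self.default_block] * self.num_layers
        if len(self.hybrid_block_layout) != self.num_layers:
            raise ValueError("hybrid_block_layout length must match num_layers")


cfg = HybridLayoutSketch(num_layers=4, default_block="m2")
cfg.validate()
print(cfg.hybrid_block_layout)  # ['m2', 'm2', 'm2', 'm2']
```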
Ok, but why not a single variable with automated repetition?
@oleksost, try merging these, but use this default factory:

hybrid_block_layout: list[str] = Field(
    default_factory=lambda: ['m'],
    desc="Pattern of blocks to use in the model. 't' for Transformer, 'm' for Mamba1, 'm2' for discrete Mamba2.",
    hint=FieldHint.core,
)

That avoids the mutable default trap.
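As a side note, a minimal illustration of the trap in question, using a plain dataclass rather than Fast-LLM's `Field`: a `default_factory` gives every config instance its own list, whereas a shared mutable default would let one instance's mutation leak into all the others.

```python
import dataclasses


@dataclasses.dataclass
class LayoutSketch:
    # default_factory builds a fresh list per instance; a bare shared-list default
    # would be a single object mutated by every instance that touches it.
    hybrid_block_layout: list[str] = dataclasses.field(default_factory=lambda: ["m"])


a, b = LayoutSketch(), LayoutSketch()
a.hybrid_block_layout.append("t")
print(b.hybrid_block_layout)  # ['m'] -> unaffected, because each instance got its own list
```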
@config_class()
class SSMArchitectureConfig(BaseModelArchitectureConfig):
Please adjust field names for our naming conventions.
    hint=FieldHint.core,
)

dt_rank: str | int = Field(
Please use `None` for derived defaults: `dt_rank: int = Field(default=None, ...)`.
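A hedged sketch of what that could look like: the field defaults to `None` and the concrete value is derived during validation. The `ceil(hidden_size / 16)` rule below mirrors the common Mamba convention for an "auto" `dt_rank`; treat the exact formula and the helper name as assumptions rather than this PR's code.

```python
import math


def resolve_dt_rank(dt_rank: int | None, hidden_size: int) -> int:
    """Derive dt_rank when it was left as None in the config."""
    if dt_rank is None:
        return math.ceil(hidden_size / 16)
    return dt_rank


print(resolve_dt_rank(None, 2048))  # 128, derived from hidden_size
print(resolve_dt_rank(64, 2048))    # 64, explicit value wins
```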
    strict: bool = True,
    flat: bool = False,
) -> typing.Self:
    if "hybrid_block_layout" in default and isinstance(default["hybrid_block_layout"], dict):
Why would it be a dict? There must be another problem elsewhere.
Do we have tests that check whether serializing/deserializing lists of strings works correctly? The special property of strings is that they are themselves sequences, kind of. There could be issues because of that, and they may not be captured by tests if we only test lists of ints, say.
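Something along these lines would cover it; the config class, its import, and the `from_dict`/`to_dict` round trip are assumptions about the API, not an excerpt from the existing test suite.

```python
def test_hybrid_block_layout_round_trip():
    # HybridSSMBaseModelConfig and its from_dict/to_dict methods are assumed here;
    # the import is omitted on purpose since the exact module path is not shown in this PR.
    original = HybridSSMBaseModelConfig.from_dict({"hybrid_block_layout": ["t", "m", "m2"]})
    restored = HybridSSMBaseModelConfig.from_dict(original.to_dict())
    # Each string must survive as one list element, not get exploded into characters.
    assert restored.hybrid_block_layout == ["t", "m", "m2"]
```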
@oleksost I see a patch to `set_nested_dict_value` below. My hunch is that this isn't working properly.
    )
elif block_type == "m2":
    # Create discrete Mamba2 block
    mixer_cls = partial(DiscreteMamba2, layer_idx=i)
Not needed, you're already passing `layer_idx` to the `mixer_cls()` call.
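To spell out why the `partial` keyword is redundant: `functools.partial` merges call-time keywords over the stored ones, so a `layer_idx` passed at the call site simply overrides the one captured in the `partial`. A tiny standalone illustration with a dummy mixer class (not the real `DiscreteMamba2`):

```python
import functools


class DummyMixer:
    def __init__(self, layer_idx: int):
        self.layer_idx = layer_idx


mixer_cls = functools.partial(DummyMixer, layer_idx=0)
mixer = mixer_cls(layer_idx=3)  # the call-site keyword overrides the partial's
print(mixer.layer_idx)  # 3 -> the layer_idx baked into the partial had no effect
```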
    return init_


class MambaLayer(torch.nn.Module):
This doesn't work with TP, need to explicitly prevent. (Not sure about PP).
we're gonna have TP eventually, not in scope for this PR though
@jlamypoirier, naming-wise, "layer" implies a single functional unit, like attention or an MLP. But here, the repeated unit is a composite: It includes a mixer (attention, SSM, etc.), normalization, residuals, and post-mixer processing (MLPs, MoEs, etc.). This structure isn't atomic. It's multiple layers stitched into a reusable computation unit, which is typically referred to as a "block" in other model families (e.g., ResNet, Swin, and even “Transformer blocks” in many papers and blogs).
Keeping the name "layer" for this composite unit obscures that. Also, and I find this the most important argument: internally and casually, we already call them blocks. Making the code match mental models reduces friction.
@tscholak Sounds reasonable, problem is these things are called "layers" everywhere else in Fast-LLM. Should we think about renaming these too?
@@ -32,7 +32,6 @@ def fast_llm(args=None):
         sys.exit(1)
     except Exception:  # noqa
         logger.critical(traceback.format_exc())
-        sys.exit(1)
I know we need this line to be removed for the debugger to work. You could do this instead:

except Exception:  # noqa
    if sys.gettrace():
        raise
    logger.critical(traceback.format_exc())
    sys.exit(1)
That would work, we do need that line outside a debugger. Same thing needed for `ValidationError` above.
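A sketch of the CLI entry point with the guard applied to both branches. Only the `sys.gettrace()` guard comes from the thread; the body of the `ValidationError` branch is an assumption, and a stand-in exception class is used so the snippet stands alone rather than guessing the real import path.

```python
import logging
import sys
import traceback

logger = logging.getLogger(__name__)


class ValidationError(Exception):
    """Stand-in for Fast-LLM's ValidationError, only to keep the sketch self-contained."""


def fast_llm(args=None):
    try:
        ...  # parse the arguments and run the requested command
    except ValidationError:
        if sys.gettrace():
            raise  # let the debugger surface the exception
        logger.error(traceback.format_exc())
        sys.exit(1)
    except Exception:  # noqa
        if sys.gettrace():
            raise
        logger.critical(traceback.format_exc())
        sys.exit(1)
```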
        else:
            d[int(keys[-1])] = value
    else:
        d[keys[-1]] = value
This whole function is mysterious. It blindly creates empty dictionaries (`setdefault(key, {})`), even when we might actually want lists.
On top of this, do we need this patch? What does it do? The special case for lists is confusing. Do we have tests for the added behavior?
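For context, a minimal illustration of the pattern being questioned (not the actual Fast-LLM helper): walking the keys with `setdefault(key, {})` always materializes dicts, so there is no principled way to decide that an intermediate container should have been a list, and the leaf value can simply be assigned a list of strings anyway.

```python
def set_nested_value(d: dict, keys: tuple, value) -> None:
    # Illustrative nested setter: every missing intermediate level becomes a dict.
    for key in keys[:-1]:
        d = d.setdefault(key, {})
    d[keys[-1]] = value


config: dict = {}
set_nested_value(config, ("model", "base_model", "hybrid_block_layout"), ["t", "m", "m2"])
print(config)  # {'model': {'base_model': {'hybrid_block_layout': ['t', 'm', 'm2']}}}
```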
This method can't really support lists, let's not try to make up a new handling rule for them. I don't really see what the use for it here would be anyway, since it's easy to just pass a list of strings.
Also, `set_nested_dict_value` has been rewritten in main...
Looks already very good to me, thanks @oleksost! This should go in ASAP.
I had a few comments and suggestions. Mostly, I think we don't want to be too picky with this PR at this point, because we're going to be actively working on improving many parts of it anyway in the next weeks.
@@ -91,7 +94,7 @@ def _debug_log(self, tensor: torch.Tensor | None, name: str, kwargs: dict[str, t
             distributed=self._tensor_space.distributed,
         )

-    def forward(
+    def _forward_impl(
Why the renaming?
I guess we can ignore some issues with the current SSM implementation, if we handle it properly. This means clear warnings that the model is experimental and may break at any point, e.g. at model runtime and/or in file headers. We still need to fix the code changes outside the model though (`transformers.py` and `set_nested_dict_value`).
Also keep in mind that future modifications may break experiment configs, pretrained models and checkpoints, hence the importance of getting good config and parameter structures as soon as possible.
@@ -1,5 +1,6 @@
 import logging
 import typing
+from abc import ABC, abstractmethod
Import style doesn't follow our conventions.
This is a good moment to reflect on the pattern.
Right now, Fast-LLM expects contributors to internalize a set of nuanced import rules (see https://servicenow.github.io/Fast-LLM/contributing/style-guide/#imports) that go beyond what most Python projects require. That may have worked at some point, but it doesn't scale. New contributors can't memorize this, and even returning ones keep tripping over it.
If this style is important, it needs to be enforced and automatically fixable through linting or a pre-commit hook. If it can't be, we should let go of it. Patterns that can't be learned quickly or applied automatically create friction and slow down the team.
@jlamypoirier, could you file a ticket outlining what it would take to automate this rule? That's the only sustainable way forward.
@oleksost, can you emit a warning in the logger when someone tries to instantiate the config for this model class? That should be enough for now.
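A minimal sketch of such a warning, assuming it is emitted from `_validate`; the stubbed base class is only there so the snippet stands alone and is not the real Fast-LLM config hierarchy.

```python
import logging

logger = logging.getLogger(__name__)


class _BaseModelConfigStub:
    """Stand-in for the real Fast-LLM base config class."""

    def _validate(self) -> None:
        pass


class HybridSSMBaseModelConfig(_BaseModelConfigStub):
    def _validate(self) -> None:
        # Warn as soon as the config is validated/instantiated, as requested above.
        logger.warning(
            "The hybrid SSM model is experimental and may break at any point; "
            "configs and checkpoints are not yet guaranteed to stay compatible."
        )
        super()._validate()
```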
✨ Description
This PR integrates Mamba1 and discrete Mamba2 blocks into the Fast-LLM training pipeline; this is the initial step to address #68.
It introduces a basic hybrid architecture that can interleave transformer and mamba-1 blocks.
Next steps:
Training with a simple hybrid model can be tested with:
- the `mamba_ssm` and `causal-conv1d` dependencies, installed via `pip install mamba_ssm[causal-conv1d]==2.2.4`,
- and the following simple config to build a hybrid model:
To load the Llamba1B model, add the following to the config:
🔍 Type of change
Select all that apply:
📝 Changes
https://github.com/Zyphra/Zamba2 and https://github.com/state-spaces/mamba
✅ Checklist
Make sure the following tasks are completed before submitting the PR:
General
Dependencies and Configuration
Testing
Performance Impact
🗒️ Additional Notes
`num_layers`, but they should probably be moved to higher-level configs at some point.