Skip to content

[feat] Hybrid Mamba model with Mamba and discrete Mamba 2 layers #194

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 44 commits into
base: main
Choose a base branch
from

Conversation

oleksost
Copy link
Contributor

@oleksost oleksost commented Mar 20, 2025

✨ Description

This PR integrates Mamba1 and discrete Mamba2 blocks into fast-llm training pypeline, this is the initial step to address #68 .
It introduces a basic hybrid architecture that can interleave transformer and mamba-1 blocks.

Next steps:

The training with a simple hybrid model can be tested:

  1. Install mamba_ssm and 'causal-conv1d' dependency, pip install mamba_ssm[causal-conv1d]==2.2.4
    1. launch training by passing
        "args": [
                    "train",
                    "hybrid_ssm",
                    "--config",
                    "path/to/hybrid_config.yaml"
                ],
    

and the following simple config to build a hybrid model:

model:
  base_model:
    transformer:
      num_layers: 6
      use_flash_attention: no  
    ssm:
      dt_rank: auto
      state_size: 16
      expansion_factor: 2
      debug_ssm: false
    block_pattern: ["m", "t", "m", "m2", "m", "m"] # mixing transformer, mamba 1 and descrete mamba layers
  
  distributed:
    training_dtype: bf16
    tensor_parallel: 1 
    pipeline_parallel: 1
    world_size: 1 

training:
  train_iters: 1000  
  logs:
    interval: 10
  validation:
    iterations: 25
    interval: 1000
  wandb:  
    project_name: fast-llm-ssm-test
    group_name: ssm
    entity_name: null

data:
  datasets:
    Training:
      type: memmap
      path: /home/toolkit/dev/fast-llm-tutorial/dataset/shard_0_0
    Validation:
      type: memmap
      path: /home/toolkit/dev/fast-llm-tutorial/dataset/shard_0_0

To load Llamba1B model, add the following to the config:

pretrained:
  format: llamba
  path: /mnt/checkpoints/pretrained_models/Llamba-1B

🔍 Type of change

Select all that apply:

  • 🐛 Bug fix (non-breaking change that addresses a specific issue)
  • 🚀 New feature (non-breaking change that adds functionality)
  • ⚠️ Breaking change (a change that could affect existing functionality)
  • 📈 Performance improvement/optimization (improves speed, memory usage, or efficiency)
  • 🛠️ Code refactor (non-functional changes that improve code readability, structure, etc.)
  • 📦 Dependency bump (updates dependencies, including Dockerfile or package changes)
  • 📝 Documentation change (updates documentation, including new content or typo fixes)
  • 🔧 Infrastructure/Build change (affects build process, CI/CD, or dependencies)

📝 Changes

  • added minimal mamba-1 layer and block
  • added hybrid model and corresponding configs
  • the implementation follows the one from https://github.com/Zyphra/Zamba2 and https://github.com/state-spaces/mamba

✅ Checklist

Make sure the following tasks are completed before submitting the PR:

General

  • 📜 I have read and followed the contributing guidelines.
  • 🏷️ I am using a clear and descriptive PR title that summarizes the key change or feature introduced.
  • 🎉 The functionality is complete, and I have tested the changes.
  • 📝 I have updated the documentation if needed.
  • ⚠️ The change does not introduce any new issues (e.g., runtime warnings, type checker errors, linting problems, unhandled edge cases).
  • 🧩 I have commented my code, especially in hard-to-understand areas.

Dependencies and Configuration

  • 🐋 I have updated the Docker configuration or dependencies, if applicable.
  • 🔄 I have ensured compatibility with the existing setup after dependency changes.

Testing

  • 🧪 I have added or updated tests to cover my changes.
  • ✔️ New and existing tests pass locally with my changes.
  • 🚦 I have tested these changes on GPUs and verified training stability.
  • 🏋️ I have tested the changes on realistic training workloads, if applicable.

Performance Impact

  • 📊 I have run benchmarks where applicable to evaluate the performance impact.
  • ✅ The benchmarks show no performance regression.
  • 🚀 The benchmarks indicate a potential performance improvement.
  • ⚠️ The benchmarks indicate a potential performance degradation.
  • 📈 I have provided benchmark results and detailed any performance impact below, if applicable.

🗒️ Additional Notes

  • currently some parameters that are used for defining hybrid model's architecture are in the transformer config, e.g. num_layers, but they should probably moved to higher level configs at some point

@oleksost oleksost marked this pull request as draft March 20, 2025 02:47
@oleksost oleksost changed the title Mamba 1 blocks [feat] Mamba 1 blocks Mar 20, 2025
@tscholak tscholak added the enhancement New feature or request label Mar 23, 2025
@oleksost oleksost changed the title [feat] Mamba 1 blocks [feat] Hybrid Mamba-1 model Mar 31, 2025
@oleksost oleksost requested a review from tscholak March 31, 2025 13:06
@oleksost oleksost requested a review from jlamypoirier March 31, 2025 17:22
@tscholak tscholak marked this pull request as ready for review March 31, 2025 20:01
@oleksost oleksost changed the title [feat] Hybrid Mamba-1 model [feat] Hybrid Mamba model with mamba and discrete mamba 2 layers Mar 31, 2025
This was linked to issues Apr 7, 2025
hint=FieldHint.core,
)

default_block: str = Field(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Redundant with block_pattern.

Copy link
Contributor Author

@oleksost oleksost Apr 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, this is needed to load Llamba model: we set default_block to m2 in _create_config_converters of LLambaHuggingfaceCheckpointHandler and the block_pattern is then created in the __post_init__ of HybridSSMBaseModelConfig. This is a bit cumbersome indeed.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why can't it set block_pattern? Also this would need to go in _validate, not __post_init__

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

afaiu _create_config_converters does not know about the num_layers in the loaded config, so we cannot set the block pattern as it depends on the number of layers.

Moved the post_init logic for block config to validate

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh so block_pattern is not specifying a repeated pattern, but the entire list? Why not just repeating the list up to num_layers instead, as the name suggests?

Copy link
Contributor Author

@oleksost oleksost Apr 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Renamed block_pattern into hybrid_block_layout and use default_block repeated 'num_layers' times in case 'hybrid_block_layout' is not specified.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, but why not a single variable with automated repetition?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@oleksost, try merging these but use this default factory:

    hybrid_block_layout: list[str] = Field(
        default_factory=lambda: ['m'],
        desc="Pattern of blocks to use in the model. 't' for Transformer, 'm' for Mamba1, 'm2' for Descrete Mamba2.",
        hint=FieldHint.core,
    )

that avoids the mutable default trap.

@oleksost oleksost requested a review from jlamypoirier April 8, 2025 15:42
hint=FieldHint.core,
)

default_block: str = Field(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh so block_pattern is not specifying a repeated pattern, but the entire list? Why not just repeating the list up to num_layers instead, as the name suggests?

@oleksost oleksost requested a review from jlamypoirier April 11, 2025 15:00


@config_class()
class SSMArchitectureConfig(BaseModelArchitectureConfig):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please adjust field names for our naming conventions.

hint=FieldHint.core,
)

dt_rank: str | int = Field(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use None for derived defaults. dt_rank: int = Field(default=None, ...

strict: bool = True,
flat: bool = False,
) -> typing.Self:
if "hybrid_block_layout" in default and isinstance(default["hybrid_block_layout"], dict):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why would it be a dict? There must be another problem elsewhere.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we have tests that check if serializing/deserializing lists with strings works correctly? the special property of strings is that they are also lists, kind-of. there could be issues because of that, and they may be not captured by tests if we only test integers of ints, say.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@oleksost I see a patch to set_nested_dict_value below. My hunch is that this isn't working properly

)
elif block_type == "m2":
# Create Mamba2 descrete block
mixer_cls = partial(DiscreteMamba2, layer_idx=i)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not needed, you're already passing layer_idx to the mixer_cls() call

return init_


class MambaLayer(torch.nn.Module):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't work with TP, need to explicitly prevent. (Not sure about PP).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we're gonna have TP eventually, not in scope for this PR though

@tscholak
Copy link
Collaborator

@jlamypoirier, naming-wise, "layer" implies a single functional unit, like attention or an MLP. But here, the repeated unit is a composite: It includes a mixer (attention, SSM, etc.), normalization, residuals, and post-mixer processing (MLPs, MoEs, etc.). This structure isn't atomic. It's multiple layers stitched into a reusable computation unit, which is typically referred to as a "block" in other model families (e.g., ResNet, Swin, and even “Transformer blocks” in many papers and blogs).
Now that we're supporting architectures beyond transformers (Mamba, etc.), the term “block” avoids misleading assumptions:

  1. "Layer" strongly implies a fixed layout rooted in the transformer design.
  2. "Block" reflects the actual structure: a swappable, composite computation unit.

Keeping the name BaseBlock lets us consistently subclass it for Mamba, attention, etc., without implying that all of them are "layers" in the same architectural tradition. In this sense, "TransformerLayer" can be a Block instance, but not every Block is a TransformerLayer.

Also, and I find this the most important argument: internally and casually, we already call them blocks. Making the code match mental models reduces friction.

@jlamypoirier
Copy link
Collaborator

@tscholak Sounds reasonable, problem s these things are called "layers" everywhere else in Fast-LLM. Should we think about renaming these too?

@@ -32,7 +32,6 @@ def fast_llm(args=None):
sys.exit(1)
except Exception: # noqa
logger.critical(traceback.format_exc())
sys.exit(1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know we need this line to be removed for the debugger to work. You could do this instead:

    except Exception:  # noqa
        if sys.gettrace():
            raise
        logger.critical(traceback.format_exc())
        sys.exit(1)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That would work, we do need that line outside a debugger. Same thing needed for ValidationError above.

else:
d[int(keys[-1])] = value
else:
d[keys[-1]] = value
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This whole function is mysterious. It blindly creates empty dictionaries (setdefault(key, {})), even when we might actually want lists.
On top of this, do we need this patch? What does this do? The special case for lists is confusing. Do we have tests for the added behavior?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method can't really support list, let's not try to make up a new handling rule for them.
I don't really see what would be the use for it here anyway, since it's easy to just pass a list of strings.
Also set_nested_dict_value has been rewritten in main...

Copy link
Collaborator

@tscholak tscholak left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks already very good to me, thanks @oleksost
This should go in asap!
I had a few comments and suggestions. Mostly, I think we don't want to be too picky with this PR at this point because we're going to be actively working on improving many parts of this anyway in the next weeks.

@@ -91,7 +94,7 @@ def _debug_log(self, tensor: torch.Tensor | None, name: str, kwargs: dict[str, t
distributed=self._tensor_space.distributed,
)

def forward(
def _forward_impl(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why the renaming?

Copy link
Collaborator

@jlamypoirier jlamypoirier left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we can ignore some issues with the current SSM implementation, if we handle it properly. This means clear warnings that the model is experimental and may break at any point, ex. at model runtime and/or in file headers. We still need to fix code changes outside the model though (transformers.py and set_nested_dict_value)

Also keep in mind that future modifications may break experiment configs, pretrained models and checkpoints, hence the importance of getting good config and parameter structures as soon as possible.

@@ -32,7 +32,6 @@ def fast_llm(args=None):
sys.exit(1)
except Exception: # noqa
logger.critical(traceback.format_exc())
sys.exit(1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That would work, we do need that line outside a debugger. Same thing needed for ValidationError above.

@@ -1,5 +1,6 @@
import logging
import typing
from abc import ABC, abstractmethod
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Import

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good moment to reflect on the pattern.

Right now, Fast-LLM expects contributors to internalize a set of nuanced import rules (see https://servicenow.github.io/Fast-LLM/contributing/style-guide/#imports) that go beyond what most Python projects require. That may have worked at some point, but it doesn't scale. New contributors can't memorize this, and even returning ones keep tripping over it.

If this style is important, it needs to be enforced and automatically fixable through linting or a pre-commit hook. If it can't be, we should let go of it. Patterns that can't be learned quickly or applied automatically create friction and slow down the team.

@jlamypoirier, could you file a ticket outlining what it would take to automate this rule? That's the only sustainable way forward.

@tscholak
Copy link
Collaborator

This means clear warnings that the model is experimental and may break at any point, ex. at model runtime and/or in file headers.

@oleksost, can you emit a warning in the logger when someone tries to instantiate the config for this model class? that should be enough for now.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Llamba support [feat] Support Mamba 2 blocks
3 participants