Improve generation with some useful features
LuisLechugaRuiz committed Dec 11, 2023
1 parent 2ee7fd2 commit 553f882
Showing 5 changed files with 78 additions and 14 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -0,0 +1,4 @@
*__pycache__/
*.egg-info/
build/
**.so
20 changes: 20 additions & 0 deletions mamba_ssm/models/config_mamba.py
@@ -0,0 +1,20 @@
class MambaConfig:
    """Configuration container for MambaLMHeadModel."""

    def __init__(
        self,
        d_model: int = 2560,
        n_layer: int = 64,
        vocab_size: int = 50277,
        ssm_cfg: dict = None,
        rms_norm: bool = True,
        residual_in_fp32: bool = True,
        fused_add_norm: bool = True,
        pad_vocab_size_multiple: int = 8,
    ):
        self.d_model = d_model
        self.n_layer = n_layer
        self.vocab_size = vocab_size
        # Avoid a mutable default argument; fall back to an empty dict.
        self.ssm_cfg = ssm_cfg if ssm_cfg is not None else {}
        self.rms_norm = rms_norm
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.pad_vocab_size_multiple = pad_vocab_size_multiple
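
For context, a minimal sketch of how this config round-trips through a plain dict, which is how from_pretrained (MambaConfig(**config_data)) and save_pretrained (json.dump(self.config.__dict__, f)) use it later in this commit; the override values below are illustrative, not taken from the diff.

import json

from mamba_ssm.models.config_mamba import MambaConfig

# Build a config with illustrative overrides, serialize its attribute dict to
# JSON, and reconstruct an equivalent config from the parsed dict.
cfg = MambaConfig(d_model=768, n_layer=24)
serialized = json.dumps(cfg.__dict__)
restored = MambaConfig(**json.loads(serialized))
assert restored.d_model == cfg.d_model and restored.n_layer == cfg.n_layer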
47 changes: 39 additions & 8 deletions mamba_ssm/models/mixer_seq_simple.py
@@ -2,12 +2,15 @@

import math
from functools import partial
import json
import os

from collections import namedtuple

import torch
import torch.nn as nn

from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.modules.mamba_simple import Mamba, Block
from mamba_ssm.utils.generation import GenerationMixin
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
@@ -174,25 +177,34 @@ class MambaLMHeadModel(nn.Module, GenerationMixin):

def __init__(
self,
d_model: int,
n_layer: int,
vocab_size: int,
config: MambaConfig,
initializer_cfg=None,
pad_vocab_size_multiple: int = 1,
device=None,
dtype=None,
**backbone_kwargs,
) -> None:
self.config = config
d_model = config.d_model
n_layer = config.n_layer
vocab_size = config.vocab_size
ssm_cfg = config.ssm_cfg
rms_norm = config.rms_norm
residual_in_fp32 = config.residual_in_fp32
fused_add_norm = config.fused_add_norm
pad_vocab_size_multiple = config.pad_vocab_size_multiple
factory_kwargs = {"device": device, "dtype": dtype}

super().__init__()
if vocab_size % pad_vocab_size_multiple != 0:
vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
self.backbone = MixerModel(
d_model=d_model,
n_layer=n_layer,
vocab_size=vocab_size,
ssm_cfg=ssm_cfg,
rms_norm=rms_norm,
initializer_cfg=initializer_cfg,
**backbone_kwargs,
fused_add_norm=fused_add_norm,
residual_in_fp32=residual_in_fp32,
**factory_kwargs,
)
self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
@@ -227,7 +239,26 @@ def forward(self, input_ids, position_ids=None, inference_params=None, num_last_

@classmethod
def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
config = load_config_hf(pretrained_model_name)
model = cls(**config, device=device, dtype=dtype, **kwargs)
config_data = load_config_hf(pretrained_model_name)
config = MambaConfig(**config_data)
model = cls(config, device=device, dtype=dtype, **kwargs)
model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype))
return model

def save_pretrained(self, save_directory):
"""
Minimal implementation of save_pretrained for MambaLMHeadModel.
Save the model and its configuration file to a directory.
"""
# Ensure save_directory exists
if not os.path.exists(save_directory):
os.makedirs(save_directory)

# Save the model's state_dict
model_path = os.path.join(save_directory, 'pytorch_model.bin')
torch.save(self.state_dict(), model_path)

# Save the configuration of the model
config_path = os.path.join(save_directory, 'config.json')
with open(config_path, 'w') as f:
json.dump(self.config.__dict__, f)
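
As a usage note, here is a minimal sketch of the new config-driven constructor together with the from_pretrained/save_pretrained round trip added above; the checkpoint name, save path, model dimensions, and device/dtype choices are illustrative assumptions, not part of this diff.

import torch

from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

# The constructor now takes a single MambaConfig instead of separate
# d_model / n_layer / vocab_size arguments.
config = MambaConfig(d_model=768, n_layer=24, vocab_size=50277)
model = MambaLMHeadModel(config, device="cuda", dtype=torch.float16)

# Load a published checkpoint, then write the weights (pytorch_model.bin) and
# configuration (config.json) back out with the new save_pretrained helper.
pretrained = MambaLMHeadModel.from_pretrained(
    "state-spaces/mamba-130m", device="cuda", dtype=torch.float16
)
pretrained.save_pretrained("./mamba-130m-local")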
8 changes: 4 additions & 4 deletions mamba_ssm/modules/mamba_simple.py
@@ -171,10 +171,10 @@ def forward(self, hidden_states, inference_params=None):
else:
assert self.activation in ["silu", "swish"]
x = causal_conv1d_fn(
x,
rearrange(self.conv1d.weight, "d 1 w -> d w"),
self.conv1d.bias,
self.activation,
x=x,
weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
bias=self.conv1d.bias,
activation=self.activation,
)

# We're careful here about the layout, to avoid extra transposes.
13 changes: 11 additions & 2 deletions mamba_ssm/utils/generation.py
@@ -11,7 +11,7 @@
from einops import rearrange, repeat
from torch import Tensor
from torch.profiler import ProfilerActivity, profile, record_function
from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput
from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput, TextStreamer


@dataclass
@@ -103,6 +103,7 @@ def decode(
tensor_parallel=1,
cg=False,
enable_timing=False,
streamer: Optional[TextStreamer] = None
):
"""Decoding, either greedy or with top-k or top-p sampling.
If top-k = 0, don't limit the number of candidates (pure sampling).
@@ -119,6 +120,9 @@
sequences: (batch, max_length)
scores: tuples of (batch, vocab_size)
"""
if streamer is not None:
streamer.put(input_ids.cpu())

batch_size, seqlen_og = input_ids.shape
teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
if cg:
@@ -189,7 +193,12 @@ def should_stop(current_token, inference_params):
while not should_stop(sequences[-1], inference_params):
scores.append(get_logits(sequences[-1], inference_params))
inference_params.seqlen_offset += sequences[-1].shape[1]
sequences.append(sample_tokens(scores[-1], inference_params))
sampled_tokens = sample_tokens(scores[-1], inference_params)
sequences.append(sampled_tokens)
if streamer is not None:
streamer.put(sampled_tokens.cpu())
if streamer is not None:
streamer.end()
if enable_timing:
end.record()
if tensor_parallel > 1:
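
Finally, a minimal sketch of how the new streamer argument might be used end to end, assuming generate forwards extra keyword arguments through to decode as in the existing GenerationMixin; the tokenizer, checkpoint name, prompt, and generation settings are illustrative assumptions, not part of this diff.

import torch
from transformers import AutoTokenizer, TextStreamer

from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = MambaLMHeadModel.from_pretrained(
    "state-spaces/mamba-130m", device="cuda", dtype=torch.float16
)

input_ids = tokenizer("Mamba is a state space model", return_tensors="pt").input_ids.to("cuda")
streamer = TextStreamer(tokenizer)

# decode() now pushes the prompt and each newly sampled token to the streamer,
# so the generated text is printed incrementally instead of only at the end.
out = model.generate(input_ids, max_length=64, streamer=streamer)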
