Skip to content

Commit

Permalink
subclass from transformers
Browse files Browse the repository at this point in the history
  • Loading branch information
markrogersjr committed Mar 7, 2024
1 parent 34076d6 commit 240a343
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
4 changes: 3 additions & 1 deletion mamba_ssm/models/config_mamba.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from dataclasses import dataclass, field

+from transformers import PretrainedConfig


@dataclass
-class MambaConfig:
+class MambaConfig(PretrainedConfig):

d_model: int = 2560
n_layer: int = 64
Expand Down
8 changes: 5 additions & 3 deletions mamba_ssm/models/mixer_seq_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

import torch
import torch.nn as nn
+from transformers import PreTrainedModel
+from transformers.modeling_outputs import CausalLMOutput

from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.modules.mamba_simple import Mamba, Block
Expand Down Expand Up @@ -173,7 +175,7 @@ def forward(self, input_ids, inference_params=None):
return hidden_states


-class MambaLMHeadModel(nn.Module, GenerationMixin):
+class MambaLMHeadModel(PreTrainedModel, GenerationMixin):

def __init__(
self,
Expand All @@ -193,7 +195,8 @@ def __init__(
pad_vocab_size_multiple = config.pad_vocab_size_multiple
factory_kwargs = {"device": device, "dtype": dtype}

-super().__init__()
+PreTrainedModel.__init__(self, config)
+GenerationMixin.__init__(self)
if vocab_size % pad_vocab_size_multiple != 0:
vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
self.backbone = MixerModel(
Expand Down Expand Up @@ -235,7 +238,6 @@ def forward(self, input_ids, position_ids=None, inference_params=None, num_last_
if num_last_tokens > 0:
hidden_states = hidden_states[:, -num_last_tokens:]
lm_logits = self.lm_head(hidden_states)
-CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
return CausalLMOutput(logits=lm_logits)

@classmethod
Expand Down

0 comments on commit 240a343

Please sign in to comment.