subclass generation mixin
markrogersjr committed Mar 7, 2024
1 parent 240a343 commit 49e6513
Showing 3 changed files with 6 additions and 6 deletions.
benchmarks/benchmark_generation_mamba_simple.py (2 changes: 1 addition & 1 deletion)
@@ -54,7 +54,7 @@

 if is_mamba:
     fn = lambda: model.generate(
-        input_ids=input_ids,
+        inputs=input_ids,
         max_length=max_length,
         cg=True,
         return_dict_in_generate=True,
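The benchmark only needs its keyword argument updated to match the renamed parameter on the mixin's generate (see the last file below). A minimal sketch of the updated call site, assuming model, input_ids, and max_length are defined by the benchmark's setup code outside this hunk:

if is_mamba:
    fn = lambda: model.generate(
        inputs=input_ids,  # renamed keyword; previously input_ids=
        max_length=max_length,
        cg=True,  # decode with CUDA-graph caching
        return_dict_in_generate=True,
    )
    out = fn()  # out.sequences holds the generated token ids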
mamba_ssm/models/mixer_seq_simple.py (2 changes: 1 addition & 1 deletion)
@@ -14,7 +14,7 @@

 from mamba_ssm.models.config_mamba import MambaConfig
 from mamba_ssm.modules.mamba_simple import Mamba, Block
-from mamba_ssm.utils.generation import GenerationMixin
+from mamba_ssm.utils.generation import MambaGenerationMixin as GenerationMixin
 from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf

 try:
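Aliasing the renamed class back to GenerationMixin keeps every other line of mixer_seq_simple.py untouched. A sketch of the pattern, assuming the module's model class lists the mixin as a base (the real class body, with forward and from_pretrained, is elided):

import torch.nn as nn
from mamba_ssm.utils.generation import MambaGenerationMixin as GenerationMixin

class MambaLMHeadModel(nn.Module, GenerationMixin):
    # The alias means the base-class list and all call sites in this
    # module are unaffected by the rename in mamba_ssm/utils/generation.py.
    ...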
mamba_ssm/utils/generation.py (8 changes: 4 additions & 4 deletions)
@@ -11,7 +11,7 @@
 from einops import rearrange, repeat
 from torch import Tensor
 from torch.profiler import ProfilerActivity, profile, record_function
-from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput, TextStreamer
+from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput, TextStreamer, GenerationMixin


 @dataclass
@@ -241,13 +241,13 @@ def should_stop(current_token, inference_params):
     return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))


-class GenerationMixin:
+class MambaGenerationMixin(GenerationMixin):
     def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
         raise NotImplementedError

     def generate(
         self,
-        input_ids,
+        inputs,
         max_length,
         top_k=1,
         top_p=0.0,
@@ -258,7 +258,7 @@ def generate(
         **kwargs,
     ):
         output = decode(
-            input_ids, self, max_length, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature, **kwargs
+            inputs, self, max_length, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature, **kwargs
         )
         if not output_scores:
             output.scores = None
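Subclassing transformers' GenerationMixin lets Mamba models inherit its helper surface (which, for example, lets them satisfy isinstance checks in Hugging Face tooling), while generate stays overridden to call Mamba's own decode loop; renaming the first parameter to inputs lines the signature up with the upstream generate. An illustrative end-to-end sketch, where the checkpoint and tokenizer names are assumptions and not part of this commit:

import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

# Assumed model/tokenizer pairing, shown for illustration only.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m",
                                         device="cuda", dtype=torch.float16)

input_ids = tokenizer("Mamba is", return_tensors="pt").input_ids.to("cuda")
out = model.generate(
    inputs=input_ids,  # keyword renamed in this commit; previously input_ids=
    max_length=64,
    return_dict_in_generate=True,
)
print(tokenizer.batch_decode(out.sequences))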
