Skip to content

Add/156 dynamic cache resize #187

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Compiled python modules.
*.pyc
venv

# Byte-compiled
_pycache__/
Expand Down
52 changes: 52 additions & 0 deletions gemma/gm/text/_chat_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,27 @@
from gemma.gm.text import _tokenizer
# from gemma.gm.vision import _token_utils
from kauldron.typing import PRNGKeyLike, UInt8 # pylint: disable=g-multiple-import,g-importing-member
import jax.numpy as jnp

def resize_tensor(cache_tensor, new_length):
"""
Resize a cache tensor from shape (1, old_length, 1, 256) to (1, new_length, 1, 256).
If new_length > old_length, pad with zeros.
If new_length < old_length, truncate.
"""
cache_tensor = jnp.asarray(cache_tensor) # Ensure it's a JAX array
old_length = cache_tensor.shape[1]

if new_length > old_length:
# Padding needed (pad along the second dimension)
pad_width = [(0, 0), (0, new_length - old_length), (0, 0), (0, 0)]
return jnp.pad(cache_tensor, pad_width, mode='constant', constant_values=0)

elif new_length < old_length:
# Truncate (keep the first `new_length` values)
return cache_tensor[:, :new_length, :, :]

return cache_tensor # If sizes are the same, return as is


@dataclasses.dataclass(frozen=True, kw_only=True, eq=False)
Expand Down Expand Up @@ -119,6 +140,37 @@ def sampler(self) -> _sampler.Sampler:
forbidden_tokens=self.forbidden_tokens,
cache_length=self.cache_length,
)
def resize_cache(self, new_length: int):
"""Update the cache length dynamically"""
object.__setattr__(self, 'cache_length', new_length)

# Update the last_state with the new cache size if last_state is not None
if self.last_state is not None:
# Initialize a new cache with the updated size
updated_cache = self.model.init_cache(
batch_size=len(self.last_state.cache),
dtype=self._dtype,
cache_length=new_length,
)

# Resize the existing cache data and copy it into the new cache
for layer in self.last_state.cache.keys():
updated_cache[layer]["v"] = resize_tensor(self.last_state.cache[layer]["v"], new_length)
updated_cache[layer]["k"] = resize_tensor(self.last_state.cache[layer]["k"], new_length)

# Create a new SamplingState with the updated cache
updated_last_state = _sampler_call.SamplingState(
cache=updated_cache,
predicted_tokens=self.last_state.predicted_tokens,
used_cache_length=min(self.last_state.used_cache_length, new_length),
# Add any other fields that are part of SamplingState
)
object.__setattr__(self, 'last_state', updated_last_state)

# Invalidate the cached sampler so it gets recomputed
if 'sampler' in self.__dict__:
del self.__dict__['sampler']


def chat(
self,
Expand Down