diff --git a/.gitignore b/.gitignore
index f0a15d2..b8fa4b7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,5 +1,6 @@
 # Compiled python modules.
 *.pyc
+venv
 
 # Byte-compiled
 _pycache__/
diff --git a/gemma/gm/text/_chat_sampler.py b/gemma/gm/text/_chat_sampler.py
index bed39b2..4a42db9 100644
--- a/gemma/gm/text/_chat_sampler.py
+++ b/gemma/gm/text/_chat_sampler.py
@@ -28,6 +28,26 @@ from gemma.gm.text import _tokenizer
 # from gemma.gm.vision import _token_utils
 from kauldron.typing import PRNGKeyLike, UInt8  # pylint: disable=g-multiple-import,g-importing-member
+import jax.numpy as jnp
+
+
+def resize_tensor(cache_tensor, new_length):
+  """Resizes a KV-cache tensor (e.g. shape (1, length, 1, 256)) along length.
+
+  If new_length > old_length, pads with zeros along the length (second) axis;
+  if new_length < old_length, truncates to the first `new_length` positions.
+  """
+  cache_tensor = jnp.asarray(cache_tensor)  # Ensure it's a JAX array.
+  old_length = cache_tensor.shape[1]
+
+  if new_length > old_length:
+    # Pad with zeros along the length dimension.
+    pad_width = [(0, 0), (0, new_length - old_length), (0, 0), (0, 0)]
+    return jnp.pad(cache_tensor, pad_width, mode='constant', constant_values=0)
+  elif new_length < old_length:
+    # Truncate, keeping the first `new_length` positions.
+    return cache_tensor[:, :new_length, :, :]
+  return cache_tensor  # Sizes already match; return unchanged.
 
 
 @dataclasses.dataclass(frozen=True, kw_only=True, eq=False)
 class ChatSampler:
@@ -119,6 +139,41 @@ def sampler(self) -> _sampler.Sampler:
         forbidden_tokens=self.forbidden_tokens,
         cache_length=self.cache_length,
     )
+
+  def resize_cache(self, new_length: int):
+    """Updates the cache length (and any saved conversation state) in place."""
+    object.__setattr__(self, 'cache_length', new_length)
+
+    # Rebuild `last_state` with the new cache size, if there is one.
+    if self.last_state is not None:
+      # The batch size is the leading dimension of a cached tensor, not the
+      # number of layers in the cache dict.
+      first_layer = next(iter(self.last_state.cache.values()))
+      batch_size = first_layer['k'].shape[0]
+
+      # Initialize a new cache with the updated size.
+      updated_cache = self.model.init_cache(
+          batch_size=batch_size,
+          dtype=self._dtype,
+          cache_length=new_length,
+      )
+
+      # Resize the existing cache data and copy it into the new cache.
+      for layer in self.last_state.cache:
+        updated_cache[layer]['v'] = resize_tensor(self.last_state.cache[layer]['v'], new_length)
+        updated_cache[layer]['k'] = resize_tensor(self.last_state.cache[layer]['k'], new_length)
+
+      # Rebuild the sampling state with the updated cache, preserving all
+      # other fields.
+      updated_last_state = dataclasses.replace(
+          self.last_state,
+          cache=updated_cache,
+          used_cache_length=min(self.last_state.used_cache_length, new_length),
+      )
+      object.__setattr__(self, 'last_state', updated_last_state)
+
+    # Invalidate the cached `sampler` property so it gets recomputed.
+    self.__dict__.pop('sampler', None)
 
   def chat(
       self,
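
A minimal usage sketch of the new resize_cache method, not part of the patch: it assumes the public gm API as documented in the gemma README (gm.nn.Gemma3_4B, gm.ckpts.load_params, gm.text.ChatSampler); the checkpoint, the multi_turn flag, and the cache sizes below are illustrative.

# Usage sketch for `ChatSampler.resize_cache`; not part of the patch.
# Assumes gemma's public `gm` API; checkpoint and sizes are illustrative.
from gemma import gm

model = gm.nn.Gemma3_4B()
params = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA3_4B_IT)

sampler = gm.text.ChatSampler(
    model=model,
    params=params,
    multi_turn=True,    # Keep conversation state in `last_state`.
    cache_length=1024,  # Initial KV-cache size.
)

answer = sampler.chat('Tell me a short story.')

# Grow the KV cache before a longer follow-up turn: the saved conversation
# state is zero-padded into the larger cache, so the chat can continue.
sampler.resize_cache(4096)
answer = sampler.chat('Now retell it with much more detail.')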