diff --git a/days/w2d3/gpt_sol.py b/days/w2d3/gpt_sol.py index 6d59b4a59..c26ec8015 100644 --- a/days/w2d3/gpt_sol.py +++ b/days/w2d3/gpt_sol.py @@ -150,7 +150,7 @@ def next_token(self, input_ids, temperature, freq_penalty=2.0): return torch.distributions.categorical.Categorical(logits=logits).sample() def generate(self, text, max_length=30, temperature=1.0, freq_penalty=2.0): - self.empty_cache() + self.clear_cache() input_ids = self.tokenizer(text).input_ids generated = [] for i in range(max_length):