diff --git a/bytelatent/data/patcher.py b/bytelatent/data/patcher.py
index 93be370..26331af 100644
--- a/bytelatent/data/patcher.py
+++ b/bytelatent/data/patcher.py
@@ -7,8 +7,6 @@
 from enum import Enum
 
 import torch
-from pydantic import BaseModel
-from torch.nn import functional as F
 
 from bytelatent.distributed import get_local_rank
 from bytelatent.entropy_model import load_entropy_model
@@ -16,6 +14,8 @@
 # from src.slurm import get_local_rank
 from bytelatent.tokenizers.blt_tokenizer import BPE_ID, OFFSET
 from bytelatent.tokenizers.constants import BPE_ID, OFFSET
+from pydantic import BaseModel
+from torch.nn import functional as F
 
 
 class PatchingModeEnum(str, Enum):
@@ -77,34 +77,40 @@ def calculate_entropies(
     grad_context = nullcontext() if enable_grad else torch.no_grad()
 
+    patching_batch_size = min(patching_batch_size, tokens.shape[0])
+
     with grad_context:
         entropies = []
         preds = []
         max_length = getattr(entropy_model, "max_length", 8192)
-        batch_numel = max_length * patching_batch_size
-        splits = torch.split(tokens.flatten(), batch_numel)
-        for split in splits:
-            pad_size = (max_length - (split.numel() % max_length)) % max_length
+
+        batch_numel = min(max_length, tokens.shape[0]) * patching_batch_size
+        for i in range(math.ceil(tokens.shape[0] / patching_batch_size)):
+            start = i * patching_batch_size
+            end = (i + 1) * patching_batch_size
+            tokens_batch = tokens[start:end]
+
+            pad_size = (max_length - (tokens_batch.shape[-1] % max_length)) % max_length
             pad = torch.zeros(
-                pad_size, dtype=split.dtype, device=split.device, requires_grad=False
+                (patching_batch_size, pad_size),
+                dtype=tokens.dtype,
+                device=tokens.device,
+                requires_grad=False,
             )
-            split = torch.cat((split, pad), dim=0)
-            split = split.reshape(-1, max_length)
-            if device is not None:
-                split = split.to(device)
-            # assert torch.all(split >= 0) and torch.all(split < 260)
-            pred = entropy_model(split)
-            pred = pred.reshape(-1, pred.shape[-1])[
-                : split.numel() - pad_size, :
-            ]  # [batch_size * seq_len, vocab]
+
+            tokens_batch = torch.cat((tokens_batch, pad), dim=1)
+            assert tokens_batch.shape[-1] == max_length
+
+            pred = entropy_model(tokens_batch)
+            pred = pred[
+                :, : tokens_batch.shape[-1] - pad_size, :
+            ]  # [batch_size, seq_len, vocab]
             preds.append(pred)
             pred_entropies = entropy(pred)
             entropies.append(pred_entropies)
 
     concat_entropies = torch.cat(entropies, dim=0)
-    concat_entropies = concat_entropies.reshape(tokens.shape)
     concat_preds = torch.cat(preds, dim=0)
-    concat_preds = concat_preds.reshape(tokens.shape[0], -1)
 
     return concat_entropies, concat_preds
@@ -568,6 +574,7 @@ def patch(
                     self.patching_batch_size,
                     self.device,
                 )
+
             if self.log_time:
                 self.log["calculate_entropies"] += time.time() - s
                 s = time.time()
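
For reference, here is a minimal standalone sketch (not part of the patch) of the per-batch padding and trimming path that the new `calculate_entropies` loop implements: slice `tokens` into row batches, zero-pad each batch to the entropy model's `max_length`, run the model, drop the padded positions, and concatenate along the batch dimension. The toy model, vocabulary size, and tensor shapes are illustrative assumptions; the softmax-based `entropy` helper mirrors the one the function already relies on.

```python
import math

import torch
from torch.nn import functional as F


def entropy(scores: torch.Tensor) -> torch.Tensor:
    # Per-token entropy from logits of shape [batch_size, seq_len, vocab].
    log_probs = F.log_softmax(scores, dim=-1)
    return -(log_probs.exp() * log_probs).sum(dim=-1)


class ToyEntropyModel(torch.nn.Module):
    # Stand-in for the byte-level entropy model: token ids -> vocab logits.
    max_length = 16  # the real model exposes max_length (the patch falls back to 8192)

    def __init__(self, vocab_size: int = 260):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab_size, vocab_size)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.emb(tokens)  # [batch_size, seq_len, vocab]


entropy_model = ToyEntropyModel()
tokens = torch.randint(0, 260, (5, 12))  # [batch_size=5, seq_len=12]
patching_batch_size = min(2, tokens.shape[0])

entropies, preds = [], []
with torch.no_grad():
    for i in range(math.ceil(tokens.shape[0] / patching_batch_size)):
        tokens_batch = tokens[i * patching_batch_size : (i + 1) * patching_batch_size]
        seq_len = tokens_batch.shape[-1]

        # Right-pad the sequence dimension so each forward pass sees exactly max_length tokens.
        pad_size = (entropy_model.max_length - seq_len % entropy_model.max_length) % entropy_model.max_length
        pad = torch.zeros(
            (tokens_batch.shape[0], pad_size), dtype=tokens.dtype, device=tokens.device
        )
        pred = entropy_model(torch.cat((tokens_batch, pad), dim=1))

        # Drop the logits produced for padding before computing per-token entropy.
        pred = pred[:, :seq_len, :]  # [batch_size, seq_len, vocab]
        preds.append(pred)
        entropies.append(entropy(pred))

concat_entropies = torch.cat(entropies, dim=0)  # [5, 12]
concat_preds = torch.cat(preds, dim=0)          # [5, 12, 260]
print(concat_entropies.shape, concat_preds.shape)
```

One design note on the sketch: it sizes the zero pad from `tokens_batch.shape[0]`, so a smaller final slice still concatenates cleanly, whereas the patch allocates the pad with `patching_batch_size` rows after clamping it to `tokens.shape[0]`, which matches when the row count divides evenly by the clamped batch size.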