bytelatent/data/patcher.py (25 additions, 18 deletions)
@@ -7,15 +7,15 @@
 from enum import Enum
 
 import torch
-from pydantic import BaseModel
-from torch.nn import functional as F
 
 from bytelatent.distributed import get_local_rank
 from bytelatent.entropy_model import load_entropy_model
 
 # from src.slurm import get_local_rank
-from bytelatent.tokenizers.blt_tokenizer import BPE_ID, OFFSET
+from bytelatent.tokenizers.constants import BPE_ID, OFFSET
+from pydantic import BaseModel
+from torch.nn import functional as F
 
 
 class PatchingModeEnum(str, Enum):
@@ -77,34 +77,40 @@ def calculate_entropies(

     grad_context = nullcontext() if enable_grad else torch.no_grad()
 
+    patching_batch_size = min(patching_batch_size, tokens.shape[0])
+
     with grad_context:
         entropies = []
         preds = []
         max_length = getattr(entropy_model, "max_length", 8192)
-        batch_numel = max_length * patching_batch_size
-        splits = torch.split(tokens.flatten(), batch_numel)
-        for split in splits:
-            pad_size = (max_length - (split.numel() % max_length)) % max_length
+
+        # Batch along dim 0 so multi-dimensional inputs keep their row
+        # structure instead of being flattened into one long byte stream.
+        for i in range(math.ceil(tokens.shape[0] / patching_batch_size)):
+            start = i * patching_batch_size
+            end = (i + 1) * patching_batch_size
+            tokens_batch = tokens[start:end]
+
+            pad_size = (max_length - (tokens_batch.shape[-1] % max_length)) % max_length
             pad = torch.zeros(
-                pad_size, dtype=split.dtype, device=split.device, requires_grad=False
+                # the last batch may hold fewer than patching_batch_size rows
+                (tokens_batch.shape[0], pad_size),
+                dtype=tokens.dtype,
+                device=tokens.device,
+                requires_grad=False,
             )
-            split = torch.cat((split, pad), dim=0)
-            split = split.reshape(-1, max_length)
-            if device is not None:
-                split = split.to(device)
-            # assert torch.all(split >= 0) and torch.all(split < 260)
-            pred = entropy_model(split)
-            pred = pred.reshape(-1, pred.shape[-1])[
-                : split.numel() - pad_size, :
-            ]  # [batch_size * seq_len, vocab]
+
+            tokens_batch = torch.cat((tokens_batch, pad), dim=1)
+            assert tokens_batch.shape[-1] == max_length
+            if device is not None:
+                tokens_batch = tokens_batch.to(device)
+
+            pred = entropy_model(tokens_batch)
+            pred = pred[
+                :, : tokens_batch.shape[-1] - pad_size, :
+            ]  # [batch_size, seq_len, vocab]
             preds.append(pred)
             pred_entropies = entropy(pred)
             entropies.append(pred_entropies)
 
         concat_entropies = torch.cat(entropies, dim=0)
         concat_entropies = concat_entropies.reshape(tokens.shape)
         concat_preds = torch.cat(preds, dim=0)
         concat_preds = concat_preds.reshape(tokens.shape[0], -1)
     return concat_entropies, concat_preds
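
For reference, a minimal standalone sketch of the batching scheme above: rows are taken in groups of patching_batch_size, right-padded to the model's max_length, and the padded positions are trimmed from the logits before the entropy is taken. toy_entropy_model and token_entropy here are illustrative stand-ins (random logits over a 260-symbol vocabulary), not the repo's trained entropy model; the shapes are chosen so the last group is smaller than patching_batch_size, which is exactly the case the pad shape must handle.

import math

import torch
from torch.nn import functional as F


def token_entropy(logits: torch.Tensor) -> torch.Tensor:
    # Shannon entropy per position, H = -sum_v p(v) log p(v).
    logp = F.log_softmax(logits, dim=-1)
    return -(logp.exp() * logp).sum(dim=-1)


def toy_entropy_model(batch: torch.Tensor) -> torch.Tensor:
    # Stand-in for the trained model: random logits over 260 symbols.
    return torch.randn(batch.shape[0], batch.shape[1], 260)


tokens = torch.randint(0, 260, (5, 100))  # [rows, seq]; 5 not divisible by 2
max_length, patching_batch_size = 128, 2

entropies = []
for i in range(math.ceil(tokens.shape[0] / patching_batch_size)):
    tokens_batch = tokens[i * patching_batch_size : (i + 1) * patching_batch_size]
    pad_size = (max_length - tokens_batch.shape[-1] % max_length) % max_length
    # Pad with tokens_batch.shape[0] rows, since the last group may be short.
    pad = torch.zeros((tokens_batch.shape[0], pad_size), dtype=tokens.dtype)
    padded = torch.cat((tokens_batch, pad), dim=1)
    logits = toy_entropy_model(padded)[:, : padded.shape[-1] - pad_size, :]
    entropies.append(token_entropy(logits))

print(torch.cat(entropies, dim=0).shape)  # torch.Size([5, 100]) == tokens.shape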


@@ -568,6 +574,7 @@ def patch(
                 self.patching_batch_size,
                 self.device,
             )
+
             if self.log_time:
                 self.log["calculate_entropies"] += time.time() - s
                 s = time.time()