Masking test #6

Merged (8 commits) on Aug 20, 2025
82 changes: 82 additions & 0 deletions detokenize.py
@@ -0,0 +1,82 @@
import sys
from typing import List
import numpy as np
import sentencepiece as spm
from megatron.data.indexed_dataset_hacked import MMapIndexedDataset
import argparse
import json

def parse_args():
parser = argparse.ArgumentParser(description="Detokenize a given number of tokens from the provided .bin file.")

parser.add_argument(
"--input",
type=str,
required=True,
help="Path to bin file with or without extension."
)

parser.add_argument(
"--tokens",
type=int,
required=True,
help="Number of tokens to aim for, for detokenization."
)

parser.add_argument(
"--output",
type=str,
required=True,
help="Output location. Will output a jsonl there."
)

parser.add_argument(
"--tokenizer",
type=str,
required=True,
help="Path to sentencepiece model."
)

args = parser.parse_args()

return args

def load_indexed_dset(path_to_bin: str) -> MMapIndexedDataset:
"""Open <prefix>.idx / <prefix>.bin as a Megatron IndexedDataset."""
if path_to_bin.endswith(".bin"):
return MMapIndexedDataset(path_to_bin[:-4])
else:
return MMapIndexedDataset(path_to_bin)


def main():
args = parse_args()

model = spm.SentencePieceProcessor(model_file=args.tokenizer)
print(f"Loaded SentencePiece model ({model.get_piece_size()} pieces) from {args.tokenizer!r}")
ds = load_indexed_dset(args.input)

collected = 0
with open(args.output, "w", encoding="utf-8") as f:
chunk = -1
chunk_step = 0.025 * args.tokens
for doc_idx in range(len(ds)):
sample = ds.get(doc_idx)
text = model.decode(sample.tolist())
obj = {"text": text}
print(json.dumps(obj, ensure_ascii=False), file=f)
collected += len(sample)

new_chunk = collected // chunk_step
if new_chunk != chunk:
chunk = new_chunk
print(collected, "/", args.tokens, "detokenized.")

if collected >= args.tokens:
break

print("Detokenized", collected, "tokens.")


if __name__ == "__main__":
main()
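
For a quick sanity check of the script above, here is a minimal sketch, separate from the files changed in this PR, that re-encodes the emitted JSONL with the same SentencePiece model and reports the token count; the two paths are placeholders for the --output and --tokenizer values used when running detokenize.py.

import json
import sentencepiece as spm

jsonl_path = "detokenized.jsonl"    # placeholder for the --output path
tokenizer_path = "tokenizer.model"  # placeholder for the --tokenizer path

sp = spm.SentencePieceProcessor(model_file=tokenizer_path)

total = 0
with open(jsonl_path, encoding="utf-8") as f:
    for line in f:
        text = json.loads(line)["text"]
        total += len(sp.encode(text))  # re-encode and count token ids

print(f"Re-encoded {total} tokens from {jsonl_path}")

The re-encoded count will not necessarily match the script's collected total exactly, since decode followed by encode is not a strict round trip (special tokens and whitespace normalization can shift piece boundaries).
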
36 changes: 36 additions & 0 deletions megatron/neox_arguments/neox_args.py
@@ -852,6 +852,42 @@ class NeoXArgsOther(NeoXArgsTemplate):
eod_mask_loss: bool = False
"""
Mask loss for end-of-document tokens.
If set to true, whenever an EOD token appears as an input token, the loss on the corresponding output token is masked.
"""

loss_mask_pairs: list = None
"""
A flat list of length 2n holding n pairs of token indexes.
For each pair, if a sample contains the two tokens in order, the model's loss is masked where those tokens
are generated, as well as on every token between them.
Multiple segments are masked if the pair repeats.
Can produce undesired behaviour if samples are sliced, or if the same pair is nested, overlapped,
improperly opened, or improperly closed.
If both tokens of a pair are identical, the masked state simply alternates at each occurrence of the token.
Note: This feature most likely breaks when input tokens and labels are different sequences.
"""

loss_mask_individual: list = None
"""
A list of tokens that will not be trained on.
That is to say, if the label sequence contains one of these tokens, that token's loss will be masked.
Note: This feature most likely breaks when input tokens and labels are different sequences.
"""

loss_mask_alternating_individual: list = None
"""
A list of tokens for which every odd (1-based) occurrence in the full sample is loss masked.
An example makes this clearer.
Say the full sequence is [1, 2, 3, 1, 4, 5, 6, 1, 1, 1, 1] and token 1 is in this list.
The output sequence is then [2, 3, 1, 4, 5, 6, 1, 1, 1, 1],
and the corresponding masked positions are [F, F, F, F, F, F, T, F, T, F]:
every odd occurrence (1-based) of the token in the full sample has its loss masked.
This is done independently for each token in the list.
Note: Can produce undesirable behaviour if your unpacked samples contain an odd number of one of these
tokens; an even number is expected.
Note: This feature most likely breaks when input tokens and labels are different sequences.
"""

adlr_autoresume: bool = False
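
To make the three new arguments concrete, here is a hypothetical configuration sketch; the token ids 7, 8, 9 and 10 are illustrative placeholders rather than real vocabulary entries, and whether the values are supplied through YAML or a Python dict depends on how the NeoX config is loaded.

# Hypothetical masking-related excerpt of a NeoX config.
masking_args = {
    "eod_mask_loss": True,                     # mask the output token predicted right after each EOD token
    "loss_mask_pairs": [7, 8],                 # flat list of pairs: mask from token 7 through token 8 inclusive
    "loss_mask_individual": [9],               # never train on token 9 as a label
    "loss_mask_alternating_individual": [10],  # mask every odd (1-based) occurrence of token 10
}

Because loss_mask_pairs is a flat list of length 2n, a second pair would be appended in place, e.g. [7, 8, 11, 12].
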
19 changes: 12 additions & 7 deletions megatron/training.py
@@ -367,10 +367,13 @@ def _get_batch(neox_args, tokenizer, keys, data, datatype, label_mask_zero=False

# Get the masks and position ids.
attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
data=tokens,
data=tokens_, # Changed to whole sequence from input sequence
eod_token=neox_args.tokenizer.eod,
eod_mask_loss=neox_args.eod_mask_loss,
sliding_window_width=neox_args.sliding_window_width,
loss_mask_pairs=neox_args.loss_mask_pairs,
loss_mask_individual=neox_args.loss_mask_individual,
loss_mask_alternating_individual=neox_args.loss_mask_alternating_individual
)

# TODO: remove later
@@ -529,12 +532,14 @@ def get_batch_pipe(data, neox_args, curr_scheduler=None):

def get_batch_sequential(forward_input, neox_args):
"""A modification of get_batch() to work with the latest batch instead of an iterator."""
attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
data=forward_input[0],
eod_token=neox_args.tokenizer.eod,
eod_mask_loss=neox_args.eod_mask_loss,
)
return (forward_input[0], forward_input[1], attention_mask)
#attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
# data=forward_input[0],
# eod_token=neox_args.tokenizer.eod,
# eod_mask_loss=neox_args.eod_mask_loss,
#)
# It should be fine to simply re-forward the existing attention mask:
# for normal training and DPO this would otherwise recompute the exact same mask as before, for no benefit.
return (forward_input[0], forward_input[1], forward_input[2])


def average_losses_across_data_parallel_group(losses):
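
The _get_batch change above relies on the convention that inputs and labels are two views of one packed tensor; below is a minimal sketch of that relationship, with made-up shapes, assuming the usual [batch, seq_len + 1] layout described in the utils.py docstring further down.

import torch

batch_size, seq_len = 2, 8
# Full packed sequences of token ids, shape [batch, seq_len + 1].
tokens_ = torch.randint(1, 100, (batch_size, seq_len + 1))

tokens = tokens_[:, :-1].contiguous()  # what the model reads
labels = tokens_[:, 1:].contiguous()   # what the loss compares against

# A loss mask built from the full sequence aligns with `labels`, so masking rules can
# inspect both the input token and the predicted token at every position.
assert tokens.shape == labels.shape == (batch_size, seq_len)
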
154 changes: 97 additions & 57 deletions megatron/utils.py
@@ -86,16 +86,49 @@ def get_attn_mask(seq_length, device, sliding_window_width, batch_size=1):
return mask < 0.5


# noinspection PyUnreachableCode
def get_ltor_masks_and_position_ids(
data,
eod_token,
eod_mask_loss=False,
sliding_window_width=None,
loss_mask_pairs=None,
loss_mask_individual=None,
loss_mask_alternating_individual=None,
):
"""Build masks and position id for left to right model."""
"""Build masks and position id for left to right model.

data - tensor of shape [batch, seq_len + 1] (I think). Tensor of full sequences.
loss_mask_pairs: list[tuple[int, int]] - Loss will be masked where the pairs are visible in the output tokens,
as well as the pair internal tokens will also be masked. Example:
loss_mask_pairs = [[2, 3]]; data[:, 1:] = [[5,2,4,1,3,1,2,3]] -> loss_mask = [[0,1,1,1,1,0,1,1]]
That is to say if you have something like "<user> How is babby born? </user> According to science babies are..."
where you wouldn't want the model to learn to generate user questions, then you'd add the token ids of
(<user>, </user>) as a pair in this list.
Note: If you do weird nesting or overlapping, consider that undefined behaviour.

loss_mask_individual: list[int] - Loss will be masked where these tokens appear as output tokens.
That is to say - tokens that you don't wish the model to learn to generate.
loss_mask_individual=[2, 3, 5]; data[:, 1:] = [[5,2,4,1,3,1,2,3]] -> loss_mask = [[1,1,0,0,1,0,1,1]]

eod_mask_loss: Bool - Whether to mask loss on output token right after eod. (Added cause Martin wanted this)
Allegedly done just so the reported loss is not skewed by the high perplexity of the first token.

Note: During loss mask creation, assumes input sequences to transformer are data[:,:-1] and
output sequences are data[:, 1:].
That is to say, you don't have weird shenaniganry of differing input and labels.
"""

if loss_mask_pairs is None:
loss_mask_pairs = []
if loss_mask_individual is None:
loss_mask_individual = []
if loss_mask_alternating_individual is None:
loss_mask_alternating_individual = []

# Extract batch size and sequence length.
batch_size, seq_length = data.size()
seq_length = seq_length - 1

# Attention mask (lower triangular).
attention_mask = get_attn_mask(
@@ -105,12 +138,12 @@ def get_ltor_masks_and_position_ids(
batch_size=batch_size
)


in_seq = data[:, :-1] # Input sequence, i.e. the full sequence without its final token (as before).
out_seq = data[:, 1:]
if os.environ.get("CURSE_ATTENTION", "0") == "1":

# Step 2: Use EOD token to zero out cross-document attention
for b in range(batch_size):
eod_positions = (data[b] == eod_token).nonzero(as_tuple=False).flatten()
eod_positions = (in_seq[b] == eod_token).nonzero(as_tuple=False).flatten()
prev_idx = 0
for i in eod_positions:
attention_mask[b, 0, (i + 1):, prev_idx:(i + 1)] = 0.0 # Remove cross-document attention
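
As a standalone illustration of the cross-document attention idea in this branch, the sketch below builds a causal mask for a toy sequence and zeroes attention across an EOD boundary; the EOD id 0 and the sequence values are made up, and the loop is a simplified variant of the code above rather than a copy of it.

import torch

eod_token = 0  # assumed EOD id for this toy example
in_seq = torch.tensor([[5, 7, 0, 3, 4, 6]])
seq_len = in_seq.size(1)

# Causal (lower-triangular) mask of shape [batch, 1, seq_len, seq_len].
attention_mask = torch.tril(torch.ones(1, 1, seq_len, seq_len))

for b in range(in_seq.size(0)):
    eod_positions = (in_seq[b] == eod_token).nonzero(as_tuple=False).flatten()
    for i in eod_positions:
        # Tokens after the EOD may not attend to the EOD or anything before it.
        attention_mask[b, 0, (i + 1):, : (i + 1)] = 0.0

print(attention_mask[0, 0])  # rows 3-5 have zeros in columns 0-2
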
Expand All @@ -119,66 +152,73 @@ def get_ltor_masks_and_position_ids(
# convert to bool and flip
attention_mask = attention_mask < 0.5

# ---------------------------------------------------------------
# Loss-mask: 1 ⇒ keep token in loss, 0 ⇒ ignore
# ---------------------------------------------------------------
if os.environ.get("SFT_ENABLED", "0") == "1":
loss_mask = torch.ones_like(data, dtype=torch.float)
block_loss = torch.zeros_like(out_seq, dtype=torch.bool) # True where the loss will be masked.

# mask instruction blocks (add as many pairs as you need)
instr_pairs = [(7, 7)]
unclosed_pair = False

tok = data # (B, L)
instr_mask = torch.zeros_like(tok, dtype=torch.bool)

for b_id, e_id in instr_pairs:
# Perform pair loss masking.
# Adapted from Martin's original implementation.
if len(loss_mask_pairs) > 0:
if len(loss_mask_pairs) % 2 == 1:
raise ValueError("--loss-mask-pairs must contain an even number of elements (a flat list of pairs).")
for pair_start in range(0, len(loss_mask_pairs), 2):
b_id = loss_mask_pairs[pair_start]
e_id = loss_mask_pairs[pair_start + 1]
if b_id == e_id:
hits = (tok == b_id).cumsum(dim=1) # 1,2,3…
inside = (hits & 1).bool() # odd → between pair
unmatched = (hits[:, -1] & 1).bool() # open at EOS?
cs = (data == b_id).cumsum(dim=1)
pair_mask = (cs & 1).bool() # Masked where pair is not closed yet. (excluding closer token)
pair_mask[:, 1:] = pair_mask[:, 1:] | pair_mask[:, :-1] # Mask on pair closing too.
block_loss|= pair_mask[:, 1:]
# Checking whether we have a dangling unclosed pair.
if (cs[:, -1] & 1).bool().any():
print_rank_0(f"WARNING: Found unclosed pair {(b_id, e_id)} in one or more samples")
unclosed_pair = True
else:
opens = (tok == b_id).cumsum(dim=1)
closes = (tok == e_id).cumsum(dim=1)
inside = opens > closes # started but not closed
unmatched = (opens[:, -1] > closes[:, -1])

# union of: inside-region + both delimiters
instr_mask |= inside | (tok == b_id) | (tok == e_id)

# optional debug – costs ~0
if unmatched.any():
print_rank_0(
f"WARNING: {unmatched.sum().item()} sequence(s) end with an "
f"unclosed instruction (begin={b_id}, end={e_id})"
)

# # 3) NEW: mask **only** the opening tokens of ID 8
# open_id = 8
# hits8 = (tok == open_id).cumsum(dim=1) # how many 8s seen so far
# open8 = (hits8 & 1).bool() & (tok == open_id) # odd-count 8’s are “opening”
# instr_mask |= open8


loss_mask[instr_mask] = 0.0 # hide all instruction tokens

# end-of-document masking
if eod_mask_loss:
loss_mask[data == eod_token] = 0.0

# # FIXME: remove?
# loss_mask[data == 8] = 0.0

# padding mask
loss_mask[data == 0] = 0.0
else:
loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
# end-of-document masking
if eod_mask_loss:
loss_mask[data == eod_token] = 0.0
cs_open = (data == b_id).cumsum(dim=1)
cs_close = (data == e_id).cumsum(dim=1)
pair_mask = (cs_open - cs_close) > 0 # Masked where an opening token has not yet been closed; the closing token itself is excluded here.
if pair_mask[:, -1].any():
print_rank_0(f"WARNING: Found unclosed pair {(b_id, e_id)} in one or more samples")
unclosed_pair = True
pair_mask[:, 1:] = pair_mask[:, 1:] | pair_mask[:, :-1] # Also mask the closing token.
block_loss |= pair_mask[:, 1:]

# Perform individual token masking.
if len(loss_mask_individual) > 0:
for token_id in loss_mask_individual:
block_loss |= out_seq == token_id

# Perform post-EOD masking.
# Explainer: the original NeoX code masked the loss on output tokens predicted from an EOD input token,
# and Martin wants to keep that behaviour.
if eod_mask_loss:
block_loss |= in_seq == eod_token

# Mask every odd-numbered recurrence of the listed tokens.
# Requested by Martin; see the neox_args docstring for loss_mask_alternating_individual for details.
if len(loss_mask_alternating_individual) > 0:
for token_id in loss_mask_alternating_individual:
eq = data == token_id
cs = eq.cumsum(dim=1)
odd_pos = eq & (cs % 2 == 1) # Odd (1-based) occurrences within the sequence.
block_loss |= odd_pos[:, 1:]
if (cs[:, -1] % 2 == 1).any():
print_rank_0(f"WARNING: Found an odd number of occurrences of alternating individual token {token_id} in one or more samples.")
unclosed_pair = True

loss_mask = (~block_loss).to(torch.float) # Downstream code expects a float mask (1 = keep, 0 = masked).

# Samples or packing is malformed.
if unclosed_pair:
print_rank_0(
"WARNING: sequence(s) contained invalid loss_mask_* tokens. "
"Either your samples are too long, you're using the wrong packing method, or your loss_mask_* arguments "
"are incorrectly configured."
)

# Position ids.
position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
position_ids = position_ids.unsqueeze(0).expand_as(data)
position_ids = position_ids.unsqueeze(0).expand_as(out_seq)

return attention_mask, loss_mask, position_ids
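
To see the pair-masking arithmetic in action, here is a standalone sketch that reproduces the distinct-pair branch on the toy example from the docstring; it is a simplified re-implementation for illustration, not a call into the Megatron code.

import torch

# Full sequence: an arbitrary leading token followed by the docstring example
# data[:, 1:] = [[5, 2, 4, 1, 3, 1, 2, 3]], with the pair (2, 3).
data = torch.tensor([[9, 5, 2, 4, 1, 3, 1, 2, 3]])
b_id, e_id = 2, 3

cs_open = (data == b_id).cumsum(dim=1)
cs_close = (data == e_id).cumsum(dim=1)
pair_mask = (cs_open - cs_close) > 0                     # inside an open pair (closing token excluded)
pair_mask[:, 1:] = pair_mask[:, 1:] | pair_mask[:, :-1]  # include the closing token
block_loss = pair_mask[:, 1:]                            # align with the output tokens data[:, 1:]

loss_mask = (~block_loss).float()
print(loss_mask)  # tensor([[1., 0., 0., 0., 0., 1., 0., 0.]])

The same arithmetic with a single cumsum and a parity check, (cs & 1), covers the case where the opening and closing token ids are identical.
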
