-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy patheval.py
More file actions
131 lines (109 loc) · 4.54 KB
/
Copy patheval.py
File metadata and controls
131 lines (109 loc) · 4.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import argparse
import json
import math
import sys
from pathlib import Path
# Make the in-repo ``src`` directory importable when this script is run
# directly, without installing the package. ROOT is the parent of the
# directory containing this file.
ROOT = Path(__file__).resolve().parent.parent
SRC = ROOT.joinpath("src")
if str(SRC) not in sys.path:
    # Prepend so the in-repo sources win over any installed copy.
    sys.path.insert(0, str(SRC))
import torch
from nano_transformer import Tokenizer, TransformerLM
from nano_transformer.training import Trainer, TrainerConfig, load_checkpoint, prepare_pretraining_data
def default_device() -> str:
    """Return the name of the preferred available torch device.

    CUDA is preferred, then Apple's MPS backend, and finally the CPU.
    """
    if not torch.cuda.is_available():
        # MPS is only probed when CUDA is absent, mirroring the preference order.
        return "mps" if torch.backends.mps.is_available() else "cpu"
    return "cuda"
def parse_args() -> argparse.Namespace:
    """Parse command-line options for checkpoint evaluation.

    Returns:
        Namespace with ``run_dir`` (Path), ``checkpoint`` (Path or None),
        ``device`` (str) and ``eval_iters`` (int or None).
    """
    parser = argparse.ArgumentParser(description="Evaluate a trained Nano-Transformer checkpoint.")
    parser.add_argument(
        "--run-dir",
        type=Path,
        required=True,
        help="Training run directory containing run_config.json.",
    )
    parser.add_argument(
        "--checkpoint",
        type=Path,
        help="Checkpoint file to evaluate (default: <run-dir>/latest.pt).",
    )
    parser.add_argument(
        "--device",
        type=str,
        # Evaluated once, when the parser is built.
        default=default_device(),
        help="Torch device to evaluate on (default: %(default)s).",
    )
    parser.add_argument(
        "--eval-iters",
        type=int,
        help="Override the eval_iters value from the run's trainer config.",
    )
    return parser.parse_args()
def load_run_config(run_dir: Path) -> dict:
    """Load the ``run_config.json`` saved in ``run_dir`` and return it as a dict.

    Raises:
        FileNotFoundError: if ``run_dir / "run_config.json"`` does not exist.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    config_path = run_dir / "run_config.json"
    return json.loads(config_path.read_text(encoding="utf-8"))
def bits_per_byte(loss_nats: float, num_tokens: int, num_bytes: int) -> float:
    """Convert a per-token loss in nats into bits per byte.

    The loss is divided by ``ln 2`` to get bits per token, then scaled by the
    tokens-per-byte ratio ``num_tokens / num_bytes``.

    Returns NaN when either count is non-positive, because the ratio is then
    undefined.
    """
    if num_tokens > 0 and num_bytes > 0:
        tokens_per_byte = num_tokens / num_bytes
        return (loss_nats / math.log(2)) * tokens_per_byte
    return float("nan")
def count_evaluated_text_units(
    text_path: str | Path,
    tokenizer: Tokenizer,
    *,
    max_document_length: int | None,
) -> tuple[int, int]:
    """Count the text tokens and UTF-8 bytes covered by evaluation.

    Each line of ``text_path`` is tokenized on its own. When
    ``max_document_length`` is given, each line's tokens are truncated to that
    budget. One slot is reserved out of the budget when the tokenizer's first
    special token maps to a token id. Bytes are measured by decoding the kept
    tokens back to text and UTF-8 encoding the result, so truncation is
    reflected in the byte count.

    Args:
        text_path: Path to a UTF-8 text file.
        tokenizer: Object exposing ``special_tokens``, ``bytes_to_token_id``,
            ``encode`` and ``decode``.
        max_document_length: Per-line token cap, or ``None`` for no cap.

    Returns:
        ``(num_tokens, num_bytes)`` summed over all lines.

    NOTE(review): the reserved slot suggests the data pipeline appends an
    end-of-document token to each document. Those tokens are not counted in
    ``num_tokens`` here. If the validation loss is also averaged over them,
    the derived bits-per-byte is slightly optimistic; confirm against
    ``prepare_pretraining_data``. This function also assumes one document per
    line; TODO confirm.
    """
    path = Path(text_path)

    eos_id = None
    if tokenizer.special_tokens:
        # Presumably the first special token is the end-of-document marker.
        eos_bytes = tokenizer.special_tokens[0].encode("utf-8")
        eos_id = tokenizer.bytes_to_token_id.get(eos_bytes)

    budget = max_document_length
    if budget is not None and eos_id is not None:
        # Leave room for the end-of-document token, never going below zero.
        budget = max(budget - 1, 0)

    total_tokens = 0
    total_bytes = 0
    with path.open("r", encoding="utf-8") as handle:
        for line in handle:
            ids = tokenizer.encode(line)
            if budget is not None:
                ids = ids[:budget]
            total_tokens += len(ids)
            if not ids:
                continue
            total_bytes += len(tokenizer.decode(ids).encode("utf-8"))
    return total_tokens, total_bytes
def main() -> None:
    """Evaluate a checkpoint from a training run directory and print metrics.

    Steps:
    - Rebuild the datasets, tokenizer and model from ``run_config.json`` in
      ``--run-dir``.
    - Load weights from ``--checkpoint`` (default ``<run-dir>/latest.pt``).
    - Run ``Trainer.evaluate()`` and print the returned metrics as JSON.
    - Print validation bits-per-byte derived from ``valid_ce_loss``.
    """
    args = parse_args()
    cfg = load_run_config(args.run_dir)
    checkpoint_path = args.checkpoint if args.checkpoint else args.run_dir / "latest.pt"

    max_doc_len = cfg.get("max_document_length")
    train_doc_ids_path = cfg.get("train_document_ids_path")
    # Document ids are requested only when the run was trained with them.
    with_doc_ids = train_doc_ids_path is not None

    prepared = prepare_pretraining_data(
        train_text_path=cfg["train_text_path"],
        valid_text_path=cfg["valid_text_path"],
        vocab_path=cfg["vocab_path"],
        merges_path=cfg["merges_path"],
        train_cache_path=cfg["train_cache_path"],
        valid_cache_path=cfg["valid_cache_path"],
        train_document_ids_path=train_doc_ids_path,
        valid_document_ids_path=cfg.get("valid_document_ids_path"),
        special_tokens=cfg["special_tokens"],
        force_rebuild=False,
        return_document_ids=with_doc_ids,
        max_document_length=max_doc_len,
    )
    if with_doc_ids:
        train_ds, valid_ds, tokenizer, train_doc_ids, valid_doc_ids = prepared
    else:
        train_ds, valid_ds, tokenizer = prepared
        train_doc_ids = valid_doc_ids = None

    model = TransformerLM.from_config(cfg["model_config"])

    # Evaluation-time overrides on top of the saved trainer settings.
    overrides = {"device": args.device, "compile_model": False, "write_config": False}
    if args.eval_iters is not None:
        overrides["eval_iters"] = args.eval_iters
    trainer_cfg = TrainerConfig(**{**cfg["trainer_config"], **overrides})

    trainer = Trainer(
        model,
        train_ds,
        valid_ds,
        trainer_cfg,
        train_document_ids=train_doc_ids,
        valid_document_ids=valid_doc_ids,
    )
    trainer.iteration = load_checkpoint(checkpoint_path, trainer.raw_model, optimizer=None)
    metrics = trainer.evaluate()

    # NOTE(review): presumably valid_ce_loss is a per-token mean (possibly
    # estimated over eval_iters batches), while the token/byte counts below
    # cover the whole validation file; confirm against Trainer.evaluate.
    n_tokens, n_bytes = count_evaluated_text_units(
        cfg["valid_text_path"],
        tokenizer,
        max_document_length=max_doc_len,
    )
    bpb = bits_per_byte(metrics["valid_ce_loss"], n_tokens, n_bytes)

    print(json.dumps(metrics, indent=2))
    print(f"valid_bits_per_byte={bpb:.4f}")


if __name__ == "__main__":
    main()