-
Notifications
You must be signed in to change notification settings - Fork 24
Expand file tree
/
Copy pathbenchmark.py
More file actions
128 lines (101 loc) · 3.75 KB
/
Copy pathbenchmark.py
File metadata and controls
128 lines (101 loc) · 3.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
# MODEL_URI="Overworld/Waypoint-1.5-1B" uv run --dev pytest examples/benchmark.py
import os
import pytest
import torch
import random
from world_engine import WorldEngine, CtrlInput
# Model to benchmark. Override it through the MODEL_URI environment variable,
# e.g. MODEL_URI="Overworld/Waypoint-1.5-1B" (see the usage line at the top of the file).
MODEL_URI = os.environ.get("MODEL_URI", "Overworld/Waypoint-1-Small")
def version_with_commit(pkg):
    """Return the installed version of *pkg*'s distribution, plus its VCS commit if known.

    The distribution is looked up by the top-level package name taken from
    ``pkg.__name__``. If the distribution was installed from a VCS URL, the
    installer records the commit in ``direct_url.json`` (PEP 610), and the first
    seven characters are appended, e.g. ``"0.1.0 @ abc1234"``.

    The result is only used for informational output, so this never raises:
    when no distribution metadata exists (e.g. the package is imported from a
    source checkout without being installed), it falls back to
    ``pkg.__version__`` or ``"unknown"``.

    Args:
        pkg: An imported module. Only ``__name__`` (and optionally
            ``__version__``) are read.

    Returns:
        A human-readable version string.
    """
    import json
    from importlib.metadata import PackageNotFoundError, distribution

    top_level = pkg.__name__.split(".")[0]
    try:
        dist = distribution(top_level)
    except PackageNotFoundError:
        return str(getattr(pkg, "__version__", "unknown"))

    version = dist.version

    # read_text normally returns None when the file is absent. The except is a
    # safety net for distribution backends that raise instead, and for
    # malformed JSON (JSONDecodeError is a ValueError).
    try:
        data = dist.read_text("direct_url.json")
        info = json.loads(data) if data else None
    except (OSError, ValueError):
        info = None

    commit = None
    if isinstance(info, dict):
        vcs_info = info.get("vcs_info")
        if isinstance(vcs_info, dict):
            commit = vcs_info.get("commit_id")

    if isinstance(commit, str) and commit:
        return f"{version} @ {commit[:7]}"
    return version
@pytest.fixture(scope="session", autouse=True)
def print_env_info():
    """Print software versions and host/GPU details once per test session.

    The fixture is session-scoped and autouse, so it runs once before the first
    test without any test requesting it. pytest captures stdout by default, so
    run with ``-s`` to see the output.
    """
    import platform
    import world_engine as world_engine_pkg

    print(
        "\n=== Environment ===\n"
        f"torch: {torch.__version__}\n"
        f"torch.cuda: {torch.version.cuda}\n"
        f"world_engine: {version_with_commit(world_engine_pkg)}\n\n"
        "=== Hardware ===\n"
        f"OS: {platform.system()} {platform.release()} ({platform.machine()})\n"
        f"CPU: {platform.processor() or 'unknown'}"
    )

    if not torch.cuda.is_available():
        print("GPU: none (CUDA not available)")
        return

    # Report on whichever device torch currently considers active.
    gpu = torch.cuda.get_device_properties(torch.cuda.current_device())
    print(
        f"GPU: {gpu.name}\n"
        f" capability {gpu.major}.{gpu.minor}\n"
        f" total memory: {gpu.total_memory / 1e9:.1f} GB"
    )
def get_warm_engine(model_uri, model_overrides=None, *, warmup_frames=3):
    """Build a CUDA ``WorldEngine`` without loading weights and run warm-up frames.

    Args:
        model_uri: Model identifier, passed straight to ``WorldEngine``.
        model_overrides: Optional mapping of model-config overrides. ``None``
            means no overrides. The mapping is copied, never aliased or mutated.
        warmup_frames: Number of ``gen_frame()`` calls made before returning.
            Defaults to 3, the previously hard-coded count.

    Returns:
        The warmed-up engine, with all queued CUDA work already finished.
    """
    engine = WorldEngine(
        model_uri,
        # Copy so a caller's dict (e.g. a shared parametrize value) is never aliased.
        model_config_overrides=dict(model_overrides or {}),
        quant=None,
        device="cuda",
        # Checkpoint weights are not loaded; presumably fine since only speed is measured.
        load_weights=False,
    )
    # Global warm-up so first-call overheads are paid before any benchmark runs.
    for _ in range(warmup_frames):
        engine.gen_frame()
    # CUDA kernel launches are asynchronous. Drain the warm-up work so it cannot
    # spill into the first timed measurement of a caller.
    torch.cuda.synchronize()
    return engine
@pytest.fixture(scope="session")
def engine():
    """Session-wide warmed-up engine for ``MODEL_URI`` with no config overrides."""
    warmed = get_warm_engine(model_uri=MODEL_URI)
    return warmed
@pytest.fixture(scope="session")
def last_latent(engine):
    """Output of one ``gen_frame(return_img=False)`` call, detached from autograd.

    It is presumably the latent frame, given the name. It is computed once per
    session and reused as the decoder input in ``test_img_decoder_only``.
    """
    latent = engine.gen_frame(return_img=False)
    return latent.detach()
def test_img_decoder_only(benchmark, engine, last_latent):
    """Benchmark a single VAE decode of the cached ``last_latent`` under bf16 autocast."""

    def run():
        # no_grad: time pure inference. If the VAE parameters require grad,
        # autograd would otherwise record a graph on every call, inflating
        # both time and memory.
        with torch.no_grad(), torch.amp.autocast("cuda", torch.bfloat16):
            engine.vae.decode(last_latent)
        # CUDA work is asynchronous; wait so each timing covers the GPU work.
        torch.cuda.synchronize()

    benchmark(run)
# Model-config override dicts swept by test_ar_rollout (one benchmark per entry).
# None means no overrides: get_warm_engine turns it into an empty override dict.
MODEL_OVERRIDES = [None]
def _overrides_id(overrides):
    """pytest id for a model-overrides dict: ``"k=v,..."``, or ``""`` for None/empty."""
    return ",".join(f"{k}={v}" for k, v in overrides.items()) if overrides else ""


def _random_ctrls(n_frames, rng):
    """Build ``n_frames`` random ``CtrlInput`` objects drawn from the ``random.Random`` *rng*."""
    return [
        CtrlInput(
            button=set(rng.sample(range(1, 65), rng.randint(0, 10))),
            mouse=(rng.random(), rng.random()),
            scroll_wheel=rng.choice((-1, 0, 1)),
        )
        for _ in range(n_frames)
    ]


def _param_label(engine):
    """Return ``" | params=N | active=M"`` for ``engine.model``, omitting what can't be computed.

    Labelling is cosmetic, so failures never propagate. A failing
    ``get_active_parameters`` still leaves the total count in the label.
    """
    try:
        model = engine.model
        total = sum(p.numel() for p in model.parameters())
    except Exception:
        return ""
    label = f" | params={total:,}"
    try:
        label += f" | active={int(model.get_active_parameters()):,}"
    except Exception:
        pass  # the model may not report active parameters; keep the total
    return label


@pytest.mark.parametrize("blocking", [False])
@pytest.mark.parametrize("dit_only", [True])
@pytest.mark.parametrize("n_frames", [256])
@pytest.mark.parametrize("model_overrides", MODEL_OVERRIDES, ids=_overrides_id)
def test_ar_rollout(benchmark, dit_only, n_frames, model_overrides, blocking):
    """Benchmark an autoregressive rollout of ``n_frames`` controlled frames.

    Args:
        dit_only: Each frame is generated with ``return_img=not dit_only``
            (presumably skipping image decoding when True).
        n_frames: Frames generated per timed round.
        model_overrides: Model-config overrides for a fresh engine, or None.
        blocking: If True, synchronize after every frame. Otherwise frames are
            issued back-to-back and only the end of the rollout is synchronized.
    """
    engine = get_warm_engine(MODEL_URI, model_overrides=model_overrides)
    benchmark.name = f"{benchmark.name}{_param_label(engine)}"

    # Seeded so every run benchmarks the same input sequence.
    rng = random.Random(0)

    def setup():
        engine.reset()
        engine.gen_frame(return_img=not dit_only)
        # Build inputs here so their CPU cost is not part of the timed region.
        ctrls = _random_ctrls(n_frames, rng)
        torch.cuda.synchronize()
        return (ctrls,), {}

    def target(ctrls):
        for ctrl in ctrls:
            engine.gen_frame(ctrl=ctrl, return_img=not dit_only)
            if blocking:
                torch.cuda.synchronize()
        # Wait for queued GPU work so the timing covers the whole rollout,
        # not just kernel launches.
        torch.cuda.synchronize()

    benchmark.pedantic(target, setup=setup, rounds=20)