Skip to content

Commit 377056a

Browse files
committed
feat: graph visualization
To make flow_graph inspection easier, this commit adds a visualization tool. It takes a job config as input and renders the flow graph visually. To save the result as a PNG, the user must specify an output folder; the file is then written there as {job_id}.png.
1 parent e6d28c8 commit 377056a

File tree

2 files changed

+192
-0
lines changed

2 files changed

+192
-0
lines changed

pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@ dependencies = [
1616
"datasets >= 2.14.6",
1717
"fastapi >= 0.109.0",
1818
"grpcio >= 1.60.0",
19+
"matplotlib>=3.7",
1920
"multiworld >= 0.2.2",
21+
"networkx>=3.2",
2022
"protobuf >= 4.25.1",
2123
"psutil >= 5.9.5",
2224
"pydantic >= 2.4.2",
@@ -42,6 +44,9 @@ dev = [
4244
"python-lsp-server",
4345
]
4446

47+
[tool.setuptools.packages.find]
48+
exclude = ["tests*", "examples*", "integration-tests*", "tools*"]
49+
4550
# [tool.setuptools]
4651
# packages = ["infscale"]
4752

tools/visualise_flow_graph.py

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
import argparse
2+
import os
3+
4+
import matplotlib.pyplot as plt
5+
import networkx as nx
6+
import yaml
7+
8+
from infscale.configs.job import JobConfig
9+
10+
11+
def load_job_config(path: str) -> JobConfig:
    """Load a job configuration from a YAML file.

    Args:
        path: Path to the YAML file describing the job.

    Returns:
        The ``JobConfig`` built by passing the file's top-level mapping as
        keyword arguments.

    Raises:
        OSError: If the file cannot be opened (for example
            ``FileNotFoundError``).
        ValueError: If the document is empty or its top level is not a
            mapping.
    """
    # Read explicitly as UTF-8 so parsing does not depend on the platform's
    # default locale encoding.
    with open(path, encoding="utf-8") as f:
        data = yaml.safe_load(f)

    # An empty document parses to None and a list/scalar document is not a
    # mapping either; ``JobConfig(**data)`` would then fail with an opaque
    # TypeError, so report the real problem instead.
    if not isinstance(data, dict):
        raise ValueError(
            f"{path}: expected a YAML mapping at the top level, "
            f"got {type(data).__name__}"
        )

    return JobConfig(**data)
16+
17+
18+
def build_graph(job: JobConfig) -> tuple[nx.DiGraph, dict[str, int]]:
    """Turn a job configuration into a directed flow graph.

    Each key of ``job.flow_graph`` becomes a node carrying a ``stage``
    attribute, and each peer listed in one of that worker's worlds
    contributes an edge ``peer -> worker``. The edge label is the world's
    name followed, on separate lines, by its address and backend when those
    are set.

    Returns:
        A ``(graph, worker_stage)`` pair, where ``worker_stage`` maps every
        worker id in ``job.workers`` to the ``"start"`` value of its stage.
    """
    flow = nx.DiGraph()

    # Stage lookup for every declared worker.
    stage_by_worker = {}
    for worker in job.workers:
        stage_by_worker[worker.id] = worker.stage["start"]

    # Register the flow-graph workers as nodes (stage 0 when undeclared).
    for worker_id in job.flow_graph:
        flow.add_node(worker_id, stage=stage_by_worker.get(worker_id, 0))

    # Connect each peer to the worker that lists it, labelled by the world.
    for worker_id, worlds in job.flow_graph.items():
        for world in worlds:
            parts = [f"{world.name}"]
            if world.addr:
                parts.append(f"{world.addr}")
            if world.backend:
                parts.append(f"{world.backend}")
            text = "\n".join(parts)

            for peer in world.peers:
                flow.add_edge(peer, worker_id, label=text)

    return flow, stage_by_worker
40+
41+
42+
def _stage_layout(
    worker_stage: dict[str, int],
    x_spacing: int = 4,
    y_spacing: int = 2,
) -> dict[str, tuple[int, int]]:
    """Place nodes in one column per stage, stacked downwards within it."""
    stage_to_nodes = {}
    for node, stage in worker_stage.items():
        stage_to_nodes.setdefault(stage, []).append(node)

    pos = {}
    for col, stage in enumerate(sorted(stage_to_nodes)):
        # Sort node ids so the layout is stable from run to run.
        for row, node in enumerate(sorted(stage_to_nodes[stage])):
            pos[node] = (col * x_spacing, -row * y_spacing)
    return pos


def _edge_curvature(src_y: float, src_stage: int, dst_stage: int) -> float:
    """Return the ``arc3`` radius used to bend an edge."""
    if dst_stage == -1 and src_y < 0:
        # Edge into the server (stage -1) from a row below the first.
        return -0.58
    if dst_stage == -1 and src_y == 0:
        # Edge into the server (stage -1) from the first row.
        return 0.2
    # Edge between sibling nodes: bend one way going forward, the other
    # way going backward.
    return 0.05 if dst_stage >= src_stage else -0.05


def _label_anchor(
    artists,
    src_pos: tuple[float, float],
    dst_pos: tuple[float, float],
    t: float = 0.3,
) -> tuple[float, float]:
    """Return the point at parameter ``t`` along a drawn edge.

    The first three vertices of the edge patch are treated as the control
    points of a quadratic Bezier curve. When they are unavailable the point
    is interpolated on the straight line between the two node positions.
    """
    x1, y1 = src_pos
    x2, y2 = dst_pos
    straight = (x1 + (x2 - x1) * t, y1 + (y2 - y1) * t)

    try:
        verts = artists[0].get_path().vertices
        if len(verts) < 3:
            return straight
        p0, p1, p2 = verts[0], verts[1], verts[2]
    except Exception:
        # Label placement is purely cosmetic, so any failure to read the
        # patch geometry deliberately falls back to the straight line.
        return straight

    one_minus_t = 1 - t
    x = (one_minus_t**2) * p0[0] + 2 * one_minus_t * t * p1[0] + (t**2) * p2[0]
    y = (one_minus_t**2) * p0[1] + 2 * one_minus_t * t * p1[1] + (t**2) * p2[1]
    return x, y


def draw_graph(
    graph: nx.DiGraph,
    worker_stage: dict[str, int],
    job_id: str,
    output_path: str = "",
) -> None:
    """Draw the flow graph and either save it as a PNG or show it.

    Nodes are laid out horizontally by stage and stacked vertically within a
    stage. Each edge is drawn as a curved arrow and, when it has a ``label``
    attribute, annotated near its source end.

    Args:
        graph: Directed graph whose edges may carry a ``label`` attribute.
        worker_stage: Maps node id -> stage (start); it determines the node
            positions.
        job_id: Used as the output file name (``{job_id}.png``).
        output_path: Directory to save the image in; it is created if
            missing. When empty or ``None`` the graph is shown in a window
            instead.
    """
    pos = _stage_layout(worker_stage)

    fig = plt.figure(figsize=(12, 7))
    ax = fig.gca()

    nx.draw_networkx_nodes(
        graph, pos, node_size=2000, node_color="#5bf4a7", edgecolors="black", ax=ax
    )
    nx.draw_networkx_labels(graph, pos, font_size=10, font_weight="bold", ax=ax)

    # Edges are drawn one at a time because each gets its own curvature.
    for src, dst, label in graph.edges(data="label"):
        rad = _edge_curvature(
            pos[src][1],
            worker_stage.get(src, 0),
            worker_stage.get(dst, 0),
        )

        artists = nx.draw_networkx_edges(
            graph,
            pos,
            edgelist=[(src, dst)],
            arrows=True,
            arrowstyle="-|>",
            arrowsize=20,
            width=1.5,
            connectionstyle=f"arc3,rad={rad}",
            min_source_margin=25,
            min_target_margin=25,
            ax=ax,
        )

        if not label:
            continue

        x, y = _label_anchor(artists, pos[src], pos[dst])
        ax.text(
            x,
            y,
            label,
            fontsize=8,
            ha="center",
            va="center",
            rotation=0,
            bbox=dict(boxstyle="round,pad=0.2", fc="white", alpha=0.9),
            zorder=10,
        )

    ax.set_axis_off()
    plt.tight_layout()

    if not output_path:
        print("Graph opened in a new window.")
        plt.show()
        return

    os.makedirs(output_path, exist_ok=True)
    output_file = os.path.join(output_path, f"{job_id}.png")
    try:
        fig.savefig(output_file, dpi=300, bbox_inches="tight")
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
    print(f"Graph saved at: {output_file}")
161+
162+
163+
def main() -> None:
    """Command-line entry point: load a job config and draw its flow graph.

    Takes the YAML config path as a positional argument and an optional
    ``-o/--output`` directory. With an output directory the image is saved
    there; without one the graph is shown in a window.

    Exits with status 1 and a message on stderr when the config file cannot
    be read or parsed, or when the graph cannot be drawn.
    """
    parser = argparse.ArgumentParser(description="Visualize JobConfig flow graph")
    parser.add_argument("config_path", help="Path to job YAML config")
    parser.add_argument(
        "-o", "--output", help="Directory to save output image (optional)", default=None
    )
    args = parser.parse_args()

    try:
        config = load_job_config(args.config_path)
    except (OSError, yaml.YAMLError) as e:
        # OSError covers a missing file, a directory and permission problems;
        # YAMLError covers malformed YAML. A non-zero status lets calling
        # scripts detect the failure.
        parser.exit(1, f"Error while loading file: {e}\n")

    graph, worker_stage = build_graph(config)
    try:
        draw_graph(graph, worker_stage, config.job_id, args.output or "")
    except nx.exception.NetworkXError as e:
        parser.exit(1, f"Error while drawing graph: {e}\n")
    except KeyboardInterrupt:
        # Ctrl-C while the window is open is a normal way to quit.
        pass


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)