Skip to content

Commit 43ce844

Browse files
committed
feat: graph visualization
In order to ease flow_graph inspection, we added a visualization tool. It accepts a job config as an input and it generates a visual graph. In order to save it as PNG, user must specify folder path and the file name will be {job_id}.png
1 parent e6d28c8 commit 43ce844

File tree

2 files changed

+199
-0
lines changed

2 files changed

+199
-0
lines changed

pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@ dependencies = [
1616
"datasets >= 2.14.6",
1717
"fastapi >= 0.109.0",
1818
"grpcio >= 1.60.0",
19+
"matplotlib>=3.7",
1920
"multiworld >= 0.2.2",
21+
"networkx>=3.2",
2022
"protobuf >= 4.25.1",
2123
"psutil >= 5.9.5",
2224
"pydantic >= 2.4.2",
@@ -42,6 +44,9 @@ dev = [
4244
"python-lsp-server",
4345
]
4446

47+
[tool.setuptools.packages.find]
48+
exclude = ["tests*", "examples*", "integration-tests*", "tools*"]
49+
4550
# [tool.setuptools]
4651
# packages = ["infscale"]
4752

tools/visualise_flow_graph.py

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
import argparse
2+
import os
3+
4+
import matplotlib.pyplot as plt
5+
import networkx as nx
6+
import yaml
7+
8+
from infscale.configs.job import JobConfig
9+
10+
11+
def load_job_config(path: str) -> JobConfig:
12+
with open(path) as f:
13+
data = yaml.safe_load(f)
14+
15+
return JobConfig(**data)
16+
17+
18+
def build_graph(job: JobConfig) -> tuple[nx.DiGraph, dict[str, int]]:
19+
"""Build graph using job configuration."""
20+
graph = nx.DiGraph()
21+
22+
# Add all worker nodes
23+
worker_stage = {w.id: w.stage["start"] for w in job.workers}
24+
25+
for wid in job.flow_graph:
26+
graph.add_node(wid, stage=worker_stage.get(wid, 0))
27+
28+
# Add edges (peer -> wid) with labels from worlds
29+
for wid, worlds in job.flow_graph.items():
30+
for world in worlds:
31+
for peer in world.peers:
32+
label = f"{world.name}"
33+
if world.addr:
34+
label += f"\n{world.addr}"
35+
if world.backend:
36+
label += f"\n{world.backend}"
37+
graph.add_edge(peer, wid, label=label)
38+
39+
return graph, worker_stage
40+
41+
42+
def draw_graph(
43+
graph: nx.DiGraph,
44+
worker_stage: dict[str, int],
45+
job_id: str,
46+
output_path: str = "",
47+
) -> None:
48+
"""Draw graph where worker_stage maps node -> stage (start)."""
49+
# build positions (horizontal by stage, vertical stacked)
50+
stage_to_nodes = {}
51+
for node, stage in worker_stage.items():
52+
stage_to_nodes.setdefault(stage, []).append(node)
53+
54+
sorted_stages = sorted(stage_to_nodes.keys())
55+
pos = {}
56+
x_spacing = 4
57+
y_spacing = 2
58+
59+
for x, stage in enumerate(sorted_stages):
60+
nodes = sorted(stage_to_nodes[stage]) # sort for stable layout
61+
for y, node in enumerate(nodes):
62+
pos[node] = (x * x_spacing, -y * y_spacing)
63+
64+
plt.figure(figsize=(12, 7))
65+
ax = plt.gca()
66+
67+
nx.draw_networkx_nodes(
68+
graph, pos, node_size=2000, node_color="#5bf4a7", edgecolors="black", ax=ax
69+
)
70+
nx.draw_networkx_labels(graph, pos, font_size=10, font_weight="bold", ax=ax)
71+
72+
# draw edges and capture artists. Use explicit edgelist to ensure order.
73+
edgelist = list(graph.edges())
74+
edge_artists = list(
75+
nx.draw_networkx_edges(
76+
graph,
77+
pos,
78+
edgelist=edgelist,
79+
arrows=True,
80+
arrowstyle="-|>",
81+
arrowsize=20,
82+
width=1.5,
83+
connectionstyle=f"arc3,rad={0.15}",
84+
min_source_margin=25,
85+
min_target_margin=25,
86+
ax=ax,
87+
)
88+
)
89+
90+
# get edge labels from graph attributes (we created them earlier as "label")
91+
edge_labels = nx.get_edge_attributes(graph, "label")
92+
93+
# For each edge artist, get its path and compute the point at t along the curve.
94+
for (src, dst), artist in zip(edgelist, edge_artists):
95+
label = edge_labels.get((src, dst), "")
96+
if not label:
97+
continue
98+
99+
# try to extract the actual path used by the artist.
100+
# for FancyArrowPatch, get_path() returns a Path containing
101+
# MOVETO and CURVE3 (quadratic bezier) segments for arc3.
102+
try:
103+
path = artist.get_path()
104+
verts = path.vertices
105+
except Exception:
106+
verts = None
107+
108+
label_x, label_y = None, None
109+
110+
if verts is not None and len(verts) >= 3:
111+
# Common pattern for quad bezier: [P0, P1, P2] or path with MOVETO + CURVE3 + CURVE3...
112+
# We'll attempt to find a quadratic bezier triplet P0, P1, P2.
113+
# Find first MOVETO index
114+
try:
115+
# Find indices for MOVETO and first CURVE3 sequence
116+
# Fallback: take the first three vertices
117+
P0 = verts[0]
118+
P1 = verts[1]
119+
P2 = verts[2]
120+
# If the path contains more points, using the first three typically represents the curve.
121+
# Evaluate quadratic Bezier at t
122+
t = float(0.2)
123+
one_minus_t = 1 - t
124+
bx = (
125+
(one_minus_t**2) * P0[0]
126+
+ 2 * one_minus_t * t * P1[0]
127+
+ (t**2) * P2[0]
128+
)
129+
by = (
130+
(one_minus_t**2) * P0[1]
131+
+ 2 * one_minus_t * t * P1[1]
132+
+ (t**2) * P2[1]
133+
)
134+
label_x, label_y = bx, by
135+
except Exception:
136+
label_x, label_y = None, None
137+
138+
if label_x is None:
139+
# Fallback: straight-line t% between source and target positions
140+
x1, y1 = pos[src]
141+
x2, y2 = pos[dst]
142+
label_x = x1 + (x2 - x1) * 0.2
143+
label_y = y1 + (y2 - y1) * 0.2
144+
145+
# Draw the label at computed position
146+
ax.text(
147+
label_x,
148+
label_y,
149+
label,
150+
fontsize=8,
151+
ha="center",
152+
va="center",
153+
rotation=0,
154+
bbox=dict(boxstyle="round,pad=0.2", fc="white", alpha=0.9),
155+
zorder=10,
156+
)
157+
158+
ax.set_axis_off()
159+
plt.tight_layout()
160+
if output_path:
161+
os.makedirs(output_path, exist_ok=True)
162+
output_file = os.path.join(output_path, f"{job_id}.png")
163+
plt.savefig(output_file, dpi=300, bbox_inches="tight")
164+
print(f"Graph saved at: {output_file}")
165+
else:
166+
print("Graph opened in a new window.")
167+
plt.show()
168+
169+
170+
def main():
171+
parser = argparse.ArgumentParser(description="Visualize JobConfig flow graph")
172+
parser.add_argument("config_path", help="Path to job YAML config")
173+
parser.add_argument(
174+
"-o", "--output", help="Directory to save output image (optional)", default=None
175+
)
176+
args = parser.parse_args()
177+
178+
try:
179+
config = load_job_config(args.config_path)
180+
except FileNotFoundError as e:
181+
print(f"Error while loading file: {e}")
182+
return
183+
184+
graph, worker_stage = build_graph(config)
185+
try:
186+
draw_graph(graph, worker_stage, config.job_id, args.output)
187+
except nx.exception.NetworkXError as e:
188+
print(f"Error while drawing graph: {e}")
189+
except KeyboardInterrupt:
190+
pass
191+
192+
193+
if __name__ == "__main__":
194+
main()

0 commit comments

Comments
 (0)