-
Notifications
You must be signed in to change notification settings - Fork 72
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Replace pygraphviz with neo4j-viz for graph visualization #306
base: main
Are you sure you want to change the base?
Changes from all commits
b2971a9
9a92182
08248bf
7cb41d3
a07c206
7b8ef44
8aa4684
ba2a78c
7db25a8
1cb21f1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -23,9 +23,16 @@ | |||||
from neo4j_graphrag.utils.logging import prettify | ||||||
|
||||||
try: | ||||||
import pygraphviz as pgv | ||||||
from neo4j_viz import ( | ||||||
Node, | ||||||
Relationship, | ||||||
VisualizationGraph as NeoVizGraph, | ||||||
CaptionAlignment, | ||||||
) | ||||||
|
||||||
HAS_NEO4J_VIZ = True | ||||||
except ImportError: | ||||||
pgv = None | ||||||
HAS_NEO4J_VIZ = False | ||||||
|
||||||
from pydantic import BaseModel | ||||||
|
||||||
|
@@ -182,40 +189,100 @@ def show_as_dict(self) -> dict[str, Any]: | |||||
return pipeline_config.model_dump() | ||||||
|
||||||
def draw( | ||||||
self, path: str, layout: str = "dot", hide_unused_outputs: bool = True | ||||||
self, path: str, layout: str = "force", hide_unused_outputs: bool = True | ||||||
) -> Any: | ||||||
G = self.get_pygraphviz_graph(hide_unused_outputs) | ||||||
G.layout(layout) | ||||||
G.draw(path) | ||||||
"""Draw the pipeline graph using neo4j-viz. | ||||||
|
||||||
def get_pygraphviz_graph(self, hide_unused_outputs: bool = True) -> pgv.AGraph: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't remember why we haven't made this method private, but since we are removing it we should mention it in the changelog. |
||||||
if pgv is None: | ||||||
Args: | ||||||
path (str): Path to save the visualization. If the path ends with .html, it will save an HTML file. | ||||||
Otherwise, it will save a PNG image. | ||||||
layout (str): Layout algorithm to use. Default is "force". | ||||||
hide_unused_outputs (bool): Whether to hide unused outputs. Default is True. | ||||||
|
||||||
Returns: | ||||||
Any: The visualization object. | ||||||
""" | ||||||
G = self.get_neo4j_viz_graph(hide_unused_outputs) | ||||||
if path.endswith(".html"): | ||||||
# Save as HTML file | ||||||
with open(path, "w") as f: | ||||||
f.write(G.render()._repr_html_()) | ||||||
else: | ||||||
# For other formats, we'll use the render method and save the image | ||||||
G.render() | ||||||
# Note: neo4j-viz doesn't support direct saving to image formats | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is the comment saying that it can't save to PNG? Because code in the 'else' looks very similar to the one in the former 'if' block to me. |
||||||
# If image format is needed, consider using a screenshot or other methods | ||||||
with open(path, "w") as f: | ||||||
f.write(G.render()._repr_html_()) | ||||||
|
||||||
def get_neo4j_viz_graph(self, hide_unused_outputs: bool = True) -> NeoVizGraph: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. And maybe make this method private? |
||||||
"""Create a neo4j-viz visualization graph from the pipeline. | ||||||
|
||||||
Args: | ||||||
hide_unused_outputs (bool): Whether to hide unused outputs. Default is True. | ||||||
|
||||||
Returns: | ||||||
NeoVizGraph: The neo4j-viz visualization graph. | ||||||
""" | ||||||
if not HAS_NEO4J_VIZ: | ||||||
raise ImportError( | ||||||
"Could not import pygraphviz. " | ||||||
"Follow installation instruction in pygraphviz documentation " | ||||||
"to get it up and running on your system." | ||||||
"Could not import neo4j-viz. " | ||||||
"Install it with 'pip install neo4j-viz'." | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think using extra groups instead reduces a bit dependency issues:
Suggested change
|
||||||
) | ||||||
self.validate_parameter_mapping() | ||||||
G = pgv.AGraph(strict=False, directed=True) | ||||||
# create a node for each component | ||||||
|
||||||
nodes = [] | ||||||
relationships = [] | ||||||
node_ids = {} | ||||||
node_counter = 0 | ||||||
|
||||||
# Create nodes for each component | ||||||
for n, node in self._nodes.items(): | ||||||
comp_inputs = ",".join( | ||||||
f"{i}: {d['annotation']}" | ||||||
for i, d in node.component.component_inputs.items() | ||||||
) | ||||||
G.add_node( | ||||||
n, | ||||||
node_type="component", | ||||||
shape="rectangle", | ||||||
label=f"{node.component.__class__.__name__}: {n}({comp_inputs})", | ||||||
node_ids[n] = node_counter | ||||||
nodes.append( | ||||||
Node( | ||||||
id=node_counter, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. According to the doc, it looks like |
||||||
caption=f"{node.component.__class__.__name__}: {n}({comp_inputs})", | ||||||
size=20, # Component nodes are larger | ||||||
color="#4C8BF5", # Blue for component nodes | ||||||
caption_alignment=CaptionAlignment.CENTER, | ||||||
caption_size=3, | ||||||
) | ||||||
) | ||||||
# create a node for each output field and connect them it to its component | ||||||
node_counter += 1 | ||||||
|
||||||
# Create nodes for each output field | ||||||
for o in node.component.component_outputs: | ||||||
param_node_name = f"{n}.{o}" | ||||||
G.add_node(param_node_name, label=o, node_type="output") | ||||||
G.add_edge(n, param_node_name) | ||||||
# then we create the edges between a component output | ||||||
# and the component it gets added to | ||||||
node_ids[param_node_name] = node_counter | ||||||
nodes.append( | ||||||
Node( | ||||||
id=node_counter, | ||||||
caption=o, | ||||||
size=10, # Output nodes are smaller | ||||||
color="#34A853", # Green for output nodes | ||||||
caption_alignment=CaptionAlignment.CENTER, | ||||||
caption_size=3, | ||||||
) | ||||||
) | ||||||
# Connect component to its output | ||||||
relationships.append( | ||||||
Relationship( | ||||||
source=node_ids[n], | ||||||
target=node_ids[param_node_name], | ||||||
caption="", | ||||||
caption_align=CaptionAlignment.CENTER, | ||||||
caption_size=10, | ||||||
color="#000000", | ||||||
) | ||||||
) | ||||||
node_counter += 1 | ||||||
|
||||||
# Create edges between components and their inputs | ||||||
for component_name, params in self.param_mapping.items(): | ||||||
for param, mapping in params.items(): | ||||||
source_component = mapping["component"] | ||||||
|
@@ -224,13 +291,30 @@ def get_pygraphviz_graph(self, hide_unused_outputs: bool = True) -> pgv.AGraph: | |||||
source_output_node = f"{source_component}.{source_param_name}" | ||||||
else: | ||||||
source_output_node = source_component | ||||||
G.add_edge(source_output_node, component_name, label=param) | ||||||
# remove outputs that are not mapped | ||||||
|
||||||
if source_output_node in node_ids and component_name in node_ids: | ||||||
relationships.append( | ||||||
Relationship( | ||||||
source=node_ids[source_output_node], | ||||||
target=node_ids[component_name], | ||||||
caption=param, | ||||||
color="#EA4335", # Red for parameter connections | ||||||
caption_align=CaptionAlignment.CENTER, | ||||||
caption_size=10, | ||||||
) | ||||||
) | ||||||
|
||||||
# Filter unused outputs if requested | ||||||
if hide_unused_outputs: | ||||||
for n in G.nodes(): | ||||||
if n.attr["node_type"] == "output" and G.out_degree(n) == 0: # type: ignore | ||||||
G.remove_node(n) | ||||||
return G | ||||||
used_nodes = set() | ||||||
for rel in relationships: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think there is a logic issue here: we do not want to filter out isolated nodes (this should not happen I think?), but node of type "output" with no outgoing relationship. |
||||||
used_nodes.add(rel.source) | ||||||
used_nodes.add(rel.target) | ||||||
|
||||||
filtered_nodes = [node for node in nodes if node.id in used_nodes] | ||||||
return NeoVizGraph(nodes=filtered_nodes, relationships=relationships) | ||||||
|
||||||
return NeoVizGraph(nodes=nodes, relationships=relationships) | ||||||
|
||||||
def add_component(self, component: Component, name: str) -> None: | ||||||
"""Add a new component. Components are uniquely identified | ||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -379,39 +379,48 @@ async def test_pipeline_async() -> None: | |
assert pipeline_result[1].result == {"add": {"result": 12}} | ||
|
||
|
||
def test_pipeline_to_pgv() -> None: | ||
def test_pipeline_to_neo4j_viz() -> None: | ||
pipe = Pipeline() | ||
component_a = ComponentAdd() | ||
component_b = ComponentMultiply() | ||
pipe.add_component(component_a, "a") | ||
pipe.add_component(component_b, "b") | ||
pipe.connect("a", "b", {"number1": "a.result"}) | ||
g = pipe.get_pygraphviz_graph() | ||
# 3 nodes: | ||
g = pipe.get_neo4j_viz_graph() | ||
# 4 nodes: | ||
# - 2 components 'a' and 'b' | ||
# - 1 output 'a.result' | ||
assert len(g.nodes()) == 3 | ||
g = pipe.get_pygraphviz_graph(hide_unused_outputs=False) | ||
# - 2 outputs 'a.result' and 'b.result' (neo4j-viz implementation includes both) | ||
assert len(g.nodes) == 4 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Related to my comment above about logic, the test should not need to be updated. |
||
|
||
# Count component nodes | ||
component_nodes = [node for node in g.nodes if node.size == 20] | ||
assert len(component_nodes) == 2 | ||
|
||
# Count output nodes | ||
output_nodes = [node for node in g.nodes if node.size == 10] | ||
assert len(output_nodes) == 2 | ||
|
||
g = pipe.get_neo4j_viz_graph(hide_unused_outputs=False) | ||
# 4 nodes: | ||
# - 2 components 'a' and 'b' | ||
# - 2 output 'a.result' and 'b.result' | ||
assert len(g.nodes()) == 4 | ||
# - 2 outputs 'a.result' and 'b.result' | ||
assert len(g.nodes) == 4 | ||
|
||
|
||
def test_pipeline_draw() -> None: | ||
pipe = Pipeline() | ||
pipe.add_component(ComponentAdd(), "add") | ||
t = tempfile.NamedTemporaryFile() | ||
t = tempfile.NamedTemporaryFile(suffix=".html") | ||
pipe.draw(t.name) | ||
content = t.file.read() | ||
assert len(content) > 0 | ||
|
||
|
||
@patch("neo4j_graphrag.experimental.pipeline.pipeline.pgv", None) | ||
def test_pipeline_draw_missing_pygraphviz_dep() -> None: | ||
@patch("neo4j_graphrag.experimental.pipeline.pipeline.HAS_NEO4J_VIZ", False) | ||
def test_pipeline_draw_missing_neo4j_viz_dep() -> None: | ||
pipe = Pipeline() | ||
pipe.add_component(ComponentAdd(), "add") | ||
t = tempfile.NamedTemporaryFile() | ||
t = tempfile.NamedTemporaryFile(suffix=".html") | ||
with pytest.raises(ImportError): | ||
pipe.draw(t.name) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this better compared to the previous approach? (out of curiosity)