-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy path__init__.py
More file actions
129 lines (110 loc) · 3.44 KB
/
__init__.py
File metadata and controls
129 lines (110 loc) · 3.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#!/usr/bin/env python3
"""
Render Module
This module provides rendering capabilities for GNN specifications to various
target languages and simulation environments.
"""
# Phase 6: render submodules are in-tree; unconditional imports.
from .activeinference_jl import render_gnn_to_activeinference_jl
from .discopy import render_gnn_to_discopy
from .generators import (
generate_activeinference_jl_code,
generate_discopy_code,
generate_pymdp_code,
generate_rxinfer_code,
)
from .numpyro import render_gnn_to_numpyro
from .pomdp_processor import POMDPRenderProcessor, process_pomdp_for_frameworks
from .processor import (
get_available_renderers,
get_module_info,
process_render,
render_gnn_spec,
)
from .pymdp import render_gnn_to_pymdp
from .pymdp.pymdp_renderer import PyMDPRenderer
from .pytorch import render_gnn_to_pytorch
from .rxinfer import render_gnn_to_rxinfer, render_gnn_to_rxinfer_toml
class JAXRenderer:
    """Class-shaped entry point for JAX rendering.

    Exists for callers that want to dispatch polymorphically on a renderer
    object. It holds no state of its own: ``render`` hands the spec to
    ``render_gnn_to_jax`` from ``render/jax/jax_renderer.py``, which does the
    actual work.
    """

    def render(self, spec) -> str:
        """Render ``spec`` through ``render_gnn_to_jax`` and return text.

        Args:
            spec: Passed unchanged as the only argument to
                ``render_gnn_to_jax``.

        Returns:
            The delegate's result as-is when it is already a ``str``;
            otherwise ``str()`` of whatever the delegate returned.
        """
        # The import happens at call time, not when the package is imported.
        from .jax.jax_renderer import render_gnn_to_jax

        rendered = render_gnn_to_jax(spec)
        if isinstance(rendered, str):
            return rendered
        return str(rendered)
def get_supported_frameworks():
    """List the framework names this package can render to.

    Returns:
        A newly built list of framework name strings, in a fixed order.
        Each call returns its own list, so callers may modify the result
        without affecting later calls.
    """
    framework_names = (
        "pymdp",
        "rxinfer",
        "activeinference_jl",
        "jax",
        "discopy",
        "pytorch",
        "numpyro",
    )
    return list(framework_names)
def validate_render(result, framework=None):
    """Check that a render result is neither missing nor an empty string.

    Args:
        result: The render output to check. Only ``None`` and the empty
            string are rejected; any other value is accepted.
        framework: Optional framework name. Accepted for callers that pass
            it, but not consulted by the current checks.

    Returns:
        True when ``result`` passes the checks.

    Raises:
        ValueError: If ``result`` is ``None`` or a zero-length string.
    """
    if result is None:
        problem = "Render result is None"
    elif isinstance(result, str) and not result:
        problem = "Render result is empty string"
    else:
        return True
    raise ValueError(problem)
# Public API of the package: the names bound by a star-import of this module.
# Kept as a plain literal list of strings.
# NOTE(review): `main`, imported at the end of this module, is not listed
# here — confirm that is intended.
__all__ = [
    # Imported from .processor
    "process_render",
    "render_gnn_spec",
    "get_module_info",
    "get_available_renderers",
    # Code generators imported from .generators
    "generate_pymdp_code",
    "generate_rxinfer_code",
    "generate_activeinference_jl_code",
    "generate_discopy_code",
    # Per-framework render functions, each imported from its own subpackage
    "render_gnn_to_pymdp",
    "render_gnn_to_rxinfer",
    "render_gnn_to_rxinfer_toml",
    "render_gnn_to_discopy",
    "render_gnn_to_activeinference_jl",
    "render_gnn_to_pytorch",
    "render_gnn_to_numpyro",
    # Renderer classes: PyMDPRenderer comes from .pymdp.pymdp_renderer,
    # JAXRenderer is defined in this module
    "PyMDPRenderer",
    "JAXRenderer",
    # POMDP processing, imported from .pomdp_processor
    "POMDPRenderProcessor",
    "process_pomdp_for_frameworks",
    # Helper functions defined in this module
    "get_supported_frameworks",
    "validate_render",
]
# Version string for this package.
__version__ = "1.6.0"

# Feature flags keyed by capability name; every flag is currently True.
# NOTE(review): nothing in the visible part of this module reads FEATURES —
# presumably it is consumed by callers as a capability map; confirm.
FEATURES = {
    # One flag per rendering target
    "pymdp_rendering": True,
    "rxinfer_rendering": True,
    "activeinference_jl_rendering": True,
    "discopy_rendering": True,
    "jax_rendering": True,
    "pytorch_rendering": True,
    "numpyro_rendering": True,
    # Remaining capability flags
    "mcp_integration": True,
    "pomdp_processing": True,
    "state_space_extraction": True,
    "modular_injection": True,
    "framework_specific_outputs": True,
    "structured_documentation": True,
}
from .render import main # expose CLI entry as attribute for tests