# parallelization.py
import asyncio
import os
from coagent.agents import (
Aggregator,
AggregationResult,
ChatAgent,
ChatMessage,
ModelClient,
Parallel,
)
from coagent.core import AgentSpec, new, set_stderr_logger
from coagent.runtimes import LocalRuntime
# Model client passed to every ChatAgent spec in this file. Each setting is
# looked up in the process environment and is None when its variable is unset.
client = ModelClient(
    model=os.environ.get("MODEL_NAME"),
    api_base=os.environ.get("MODEL_API_BASE"),
    api_version=os.environ.get("MODEL_API_VERSION"),
    api_key=os.environ.get("MODEL_API_KEY"),
)
# System prompt for the customer stakeholder agent. The trailing backslashes
# keep the text free of a leading and a trailing newline.
_CUSTOMER_PROMPT = """\
Customers:
- Price sensitive
- Want better tech
- Environmental concerns\
"""

# Spec named "customer": a ChatAgent built with the prompt above and the
# module-level model client.
customer = AgentSpec(
    "customer",
    new(ChatAgent, system=_CUSTOMER_PROMPT, client=client),
)
# System prompt for the employee stakeholder agent. The trailing backslashes
# keep the text free of a leading and a trailing newline.
_EMPLOYEE_PROMPT = """\
Employees:
- Job security worries
- Need new skills
- Want clear direction\
"""

# Spec named "employee": a ChatAgent built with the prompt above and the
# module-level model client.
employee = AgentSpec(
    "employee",
    new(ChatAgent, system=_EMPLOYEE_PROMPT, client=client),
)
# System prompt for the investor stakeholder agent. The trailing backslashes
# keep the text free of a leading and a trailing newline.
_INVESTOR_PROMPT = """\
Investors:
- Expect growth
- Want cost control
- Risk concerns\
"""

# Spec named "investor": a ChatAgent built with the prompt above and the
# module-level model client.
investor = AgentSpec(
    "investor",
    new(ChatAgent, system=_INVESTOR_PROMPT, client=client),
)
# System prompt for the supplier stakeholder agent. The trailing backslashes
# keep the text free of a leading and a trailing newline.
_SUPPLIER_PROMPT = """\
Suppliers:
- Capacity constraints
- Price pressures
- Tech transitions\
"""

# Spec named "supplier": a ChatAgent built with the prompt above and the
# module-level model client.
supplier = AgentSpec(
    "supplier",
    new(ChatAgent, system=_SUPPLIER_PROMPT, client=client),
)
# Aggregator spec; the Parallel spec below refers to it by the name
# "aggregator".
aggregator = AgentSpec(
    "aggregator",
    new(Aggregator),
)

# Names of the stakeholder agent specs handed to Parallel, in the order they
# are passed as positional arguments.
_BRANCH_NAMES = ("customer", "employee", "investor", "supplier")

# Parallel spec wired, by name, to the four stakeholder agents and to the
# aggregator spec.
parallel = AgentSpec(
    "parallel",
    new(Parallel, *_BRANCH_NAMES, aggregator="aggregator"),
)
async def main():
    """Register all agent specs, run the parallel spec once, and print the replies.

    Inside a ``LocalRuntime`` context, every spec defined in this module is
    registered, a single user ``ChatMessage`` is sent through ``parallel``,
    the response is decoded as an ``AggregationResult``, and the ``content``
    of each entry in its ``results`` is printed after decoding it as a
    ``ChatMessage``.
    """
    # The same request text is used for the single run of the parallel spec.
    prompt_text = """\
Analyze how market changes will impact this stakeholder group.
Provide specific impacts and recommended actions.
Format with clear sections and priorities.\
"""
    specs = (customer, employee, investor, supplier, aggregator, parallel)

    async with LocalRuntime() as runtime:
        # Every spec is registered before the parallel spec is run.
        for spec in specs:
            await runtime.register(spec)

        request = ChatMessage(role="user", content=prompt_text)
        raw_response = await parallel.run(request.encode())

        aggregated = AggregationResult.decode(raw_response)
        for raw_reply in aggregated.results:
            reply = ChatMessage.decode(raw_reply)
            print(reply.content)
if __name__ == "__main__":
    # Script entry point: set up the stderr logger first, then drive main()
    # to completion on a fresh event loop.
    set_stderr_logger()
    entry_coroutine = main()
    asyncio.run(entry_coroutine)