Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update for local inference demo for LS 0.1 #163

Merged
merged 11 commits into from
Feb 6, 2025
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ struct ContentView: View {
public init () {
self.inference = LocalInference(queue: runnerQueue)
self.localAgents = LocalAgents(inference: self.inference)
self.remoteAgents = RemoteAgents(url: URL(string: "http://localhost:5000")!)
self.remoteAgents = RemoteAgents(url: URL(string: "https://llama-stack.together.ai")!)
}

var agents: Agents {
Expand Down Expand Up @@ -130,39 +130,39 @@ struct ContentView: View {
func summarizeConversation(prompt: String) async {
do {
let request = Components.Schemas.CreateAgentTurnRequest(
agent_id: self.agentId,
messages: [
.UserMessage(Components.Schemas.UserMessage(
content: .case1("Summarize the following conversation in 1-2 sentences:\n\n \(prompt)"),
role: .user
))
],
session_id: self.agenticSystemSessionId,
stream: true
)

for try await chunk in try await self.agents.createTurn(request: request) {
for try await chunk in try await self.agents.createTurn(agent_id: self.agentId, session_id: self.agenticSystemSessionId, request: request) {
let payload = chunk.event.payload
switch (payload) {
case .AgentTurnResponseStepStartPayload(_):
case .step_start(_):
break
case .AgentTurnResponseStepProgressPayload(let step):
if (step.model_response_text_delta != nil) {
case .step_progress(let step):
if (step.delta != nil) {
DispatchQueue.main.async {
withAnimation {
var message = messages.removeLast()
message.text += step.model_response_text_delta!
if case .text(let delta) = step.delta {
message.text += "\(delta.text)"
}
message.tokenCount += 2
message.dateUpdated = Date()
messages.append(message)
}
}
}
case .AgentTurnResponseStepCompletePayload(_):
case .step_complete(_):
break
case .AgentTurnResponseTurnStartPayload(_):
case .turn_start(_):
break
case .AgentTurnResponseTurnCompletePayload(_):
case .turn_complete(_):
break

}
Expand All @@ -175,103 +175,100 @@ struct ContentView: View {

func actionItems(prompt: String) async throws {
let request = Components.Schemas.CreateAgentTurnRequest(
agent_id: self.agentId,
messages: [
.UserMessage(Components.Schemas.UserMessage(
content: .case1("List out any action items based on this text:\n\n \(prompt)"),
role: .user
))
],
session_id: self.agenticSystemSessionId,
stream: true
)

for try await chunk in try await self.agents.createTurn(request: request) {
for try await chunk in try await self.agents.createTurn(agent_id: self.agentId, session_id: self.agenticSystemSessionId, request: request) {
let payload = chunk.event.payload
switch (payload) {
case .AgentTurnResponseStepStartPayload(_):
case .step_start(_):
break
case .AgentTurnResponseStepProgressPayload(let step):
if (step.model_response_text_delta != nil) {
DispatchQueue.main.async {
withAnimation {
var message = messages.removeLast()
message.text += step.model_response_text_delta!
message.tokenCount += 2
message.dateUpdated = Date()
messages.append(message)

self.actionItems += step.model_response_text_delta!
case .step_progress(let step):
DispatchQueue.main.async(execute: DispatchWorkItem {
withAnimation {
var message = messages.removeLast()

if case .text(let delta) = step.delta {
message.text += "\(delta.text)"
self.actionItems += "\(delta.text)"
}
message.tokenCount += 2
message.dateUpdated = Date()
messages.append(message)
}
}
case .AgentTurnResponseStepCompletePayload(_):
})
case .step_complete(_):
break
case .AgentTurnResponseTurnStartPayload(_):
case .turn_start(_):
break
case .AgentTurnResponseTurnCompletePayload(_):
case .turn_complete(_):
break
}
}
}

func callTools(prompt: String) async throws {
let request = Components.Schemas.CreateAgentTurnRequest(
agent_id: self.agentId,
messages: [
.UserMessage(Components.Schemas.UserMessage(
content: .case1("Call functions as needed to handle any actions in the following text:\n\n" + prompt),
role: .user
))
],
session_id: self.agenticSystemSessionId,
stream: true
)

for try await chunk in try await self.agents.createTurn(request: request) {
for try await chunk in try await self.agents.createTurn(agent_id: self.agentId, session_id: self.agenticSystemSessionId, request: request) {
let payload = chunk.event.payload
switch (payload) {
case .AgentTurnResponseStepStartPayload(_):
case .step_start(_):
break
case .AgentTurnResponseStepProgressPayload(let step):
if (step.tool_call_delta != nil) {
switch (step.tool_call_delta!.content) {
case .case1(_):
break
case .ToolCall(let call):
switch (call.tool_name) {
case .BuiltinTool(_):
break
case .case2(let toolName):
if (toolName == "create_event") {
var args: [String : String] = [:]
for (arg_name, arg) in call.arguments.additionalProperties {
switch (arg) {
case .case1(let s): // type string
args[arg_name] = s
case .case2(_), .case3(_), .case4(_), .case5(_), .case6(_):
break
case .step_progress(let step):
switch (step.delta) {
case .tool_call(let call):
if call.parse_status == .succeeded {
switch (call.tool_call) {
case .ToolCall(let toolCall):
var args: [String : String] = [:]
for (arg_name, arg) in toolCall.arguments.additionalProperties {
switch (arg) {
case .case1(let s):
args[arg_name] = s
case .case2(_), .case3(_), .case4(_), .case5(_), .case6(_):
break
}
}
}

let formatter = DateFormatter()
formatter.dateFormat = "yyyy-MM-dd HH:mm"
formatter.timeZone = TimeZone.current
formatter.locale = Locale.current
self.triggerAddEventToCalendar(
title: args["event_name"]!,
startDate: formatter.date(from: args["start"]!) ?? Date(),
endDate: formatter.date(from: args["end"]!) ?? Date()
)
let formatter = DateFormatter()
formatter.dateFormat = "yyyy-MM-dd HH:mm"
formatter.timeZone = TimeZone.current
formatter.locale = Locale.current
self.triggerAddEventToCalendar(
title: args["event_name"]!,
startDate: formatter.date(from: args["start"]!) ?? Date(),
endDate: formatter.date(from: args["end"]!) ?? Date()
)
case .case1(_):
break
}
}
case .text(let text):
break
case .image(_):
break
}
}
case .AgentTurnResponseStepCompletePayload(_):
break
case .AgentTurnResponseTurnStartPayload(_):
case .step_complete(_):
break
case .turn_start(_):
break
case .AgentTurnResponseTurnCompletePayload(_):
case .turn_complete(_):
break
}
}
Expand Down Expand Up @@ -308,22 +305,17 @@ struct ContentView: View {
let createSystemResponse = try await self.agents.create(
request: Components.Schemas.CreateAgentRequest(
agent_config: Components.Schemas.AgentConfig(
client_tools: [ CustomTools.getCreateEventToolForAgent() ],
enable_session_persistence: false,
instructions: "You are a helpful assistant",
max_infer_iters: 1,
model: "Llama3.1-8B-Instruct",
tools: [
Components.Schemas.AgentConfig.toolsPayloadPayload.FunctionCallToolDefinition(
CustomTools.getCreateEventTool()
)
]
model: "meta-llama/Llama-3.1-8B-Instruct"
)
)
)
self.agentId = createSystemResponse.agent_id

let createSessionResponse = try await self.agents.createSession(
request: Components.Schemas.CreateAgentSessionRequest(agent_id: self.agentId, session_name: "llama-assistant")
let createSessionResponse = try await self.agents.createSession(agent_id: self.agentId, request: Components.Schemas.CreateAgentSessionRequest(session_name: "llama-assistant")
)
self.agenticSystemSessionId = createSessionResponse.session_id

Expand Down