Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

LocalInferenceImpl update for LS 0.1 #911

Merged
merged 1 commit on
Feb 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public class LocalInference: Inference {

public func chatCompletion(request: Components.Schemas.ChatCompletionRequest) -> AsyncStream<Components.Schemas.ChatCompletionResponseStreamChunk> {
return AsyncStream { continuation in
runnerQueue.async {
let workItem = DispatchWorkItem {
do {
var tokens: [String] = []

Expand Down Expand Up @@ -69,9 +69,10 @@ public class LocalInference: Inference {
continuation.yield(
Components.Schemas.ChatCompletionResponseStreamChunk(
event: Components.Schemas.ChatCompletionResponseEvent(
delta: .ToolCallDelta(Components.Schemas.ToolCallDelta(
content: .case1(""),
parse_status: Components.Schemas.ToolCallParseStatus.started
delta: .tool_call(Components.Schemas.ToolCallDelta(
parse_status: Components.Schemas.ToolCallParseStatus.started,
tool_call: .case1(""),
_type: Components.Schemas.ToolCallDelta._typePayload.tool_call
)
),
event_type: .progress
Expand All @@ -95,14 +96,18 @@ public class LocalInference: Inference {
text = token
}

var delta: Components.Schemas.ChatCompletionResponseEvent.deltaPayload
var delta: Components.Schemas.ContentDelta
if ipython {
delta = .ToolCallDelta(Components.Schemas.ToolCallDelta(
content: .case1(text),
parse_status: .in_progress
delta = .tool_call(Components.Schemas.ToolCallDelta(
parse_status: .in_progress,
tool_call: .case1(text),
_type: .tool_call
))
} else {
delta = .case1(text)
delta = .text(Components.Schemas.TextDelta(
text: text,
_type: Components.Schemas.TextDelta._typePayload.text)
)
}

if stopReason == nil {
Expand All @@ -129,7 +134,12 @@ public class LocalInference: Inference {
continuation.yield(
Components.Schemas.ChatCompletionResponseStreamChunk(
event: Components.Schemas.ChatCompletionResponseEvent(
delta: .ToolCallDelta(Components.Schemas.ToolCallDelta(content: .case1(""), parse_status: .failure)),
delta: .tool_call(Components.Schemas.ToolCallDelta(
parse_status: Components.Schemas.ToolCallParseStatus.failed,
tool_call: .case1(""),
_type: Components.Schemas.ToolCallDelta._typePayload.tool_call
)
),
event_type: .progress
)
// TODO: stopReason
Expand All @@ -141,10 +151,12 @@ public class LocalInference: Inference {
continuation.yield(
Components.Schemas.ChatCompletionResponseStreamChunk(
event: Components.Schemas.ChatCompletionResponseEvent(
delta: .ToolCallDelta(Components.Schemas.ToolCallDelta(
content: .ToolCall(toolCall),
parse_status: .success
)),
delta: .tool_call(Components.Schemas.ToolCallDelta(
parse_status: Components.Schemas.ToolCallParseStatus.succeeded,
tool_call: Components.Schemas.ToolCallDelta.tool_callPayload.ToolCall(toolCall),
_type: Components.Schemas.ToolCallDelta._typePayload.tool_call
)
),
event_type: .progress
)
// TODO: stopReason
Expand All @@ -155,7 +167,10 @@ public class LocalInference: Inference {
continuation.yield(
Components.Schemas.ChatCompletionResponseStreamChunk(
event: Components.Schemas.ChatCompletionResponseEvent(
delta: .case1(""),
delta: .text(Components.Schemas.TextDelta(
text: "",
_type: Components.Schemas.TextDelta._typePayload.text)
),
event_type: .complete
)
// TODO: stopReason
Expand All @@ -166,6 +181,7 @@ public class LocalInference: Inference {
print("Inference error: " + error.localizedDescription)
}
}
runnerQueue.async(execute: workItem)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ func encodeHeader(role: String) -> String {
return "<|start_header_id|>\(role)<|end_header_id|>\n\n"
}

func encodeDialogPrompt(messages: [Components.Schemas.ChatCompletionRequest.messagesPayloadPayload]) -> String {
func encodeDialogPrompt(messages: [Components.Schemas.Message]) -> String {
var prompt = ""

prompt.append("<|begin_of_text|>")
Expand All @@ -20,24 +20,24 @@ func encodeDialogPrompt(messages: [Components.Schemas.ChatCompletionRequest.mess
return prompt
}

func getRole(message: Components.Schemas.ChatCompletionRequest.messagesPayloadPayload) -> String {
func getRole(message: Components.Schemas.Message) -> String {
switch (message) {
case .UserMessage(let m):
case .user(let m):
return m.role.rawValue
case .SystemMessage(let m):
case .system(let m):
return m.role.rawValue
case .ToolResponseMessage(let m):
case .tool(let m):
return m.role.rawValue
case .CompletionMessage(let m):
case .assistant(let m):
return m.role.rawValue
}
}

func encodeMessage(message: Components.Schemas.ChatCompletionRequest.messagesPayloadPayload) -> String {
func encodeMessage(message: Components.Schemas.Message) -> String {
var prompt = encodeHeader(role: getRole(message: message))

switch (message) {
case .CompletionMessage(let m):
case .assistant(let m):
if (m.tool_calls.count > 0) {
prompt += "<|python_tag|>"
}
Expand All @@ -64,37 +64,37 @@ func encodeMessage(message: Components.Schemas.ChatCompletionRequest.messagesPay
}

switch (message) {
case .UserMessage(let m):
case .user(let m):
prompt += _processContent(m.content)
case .SystemMessage(let m):
case .system(let m):
prompt += _processContent(m.content)
case .ToolResponseMessage(let m):
case .tool(let m):
prompt += _processContent(m.content)
case .CompletionMessage(let m):
case .assistant(let m):
prompt += _processContent(m.content)
}

var eom = false

switch (message) {
case .UserMessage(let m):
case .user(let m):
switch (m.content) {
case .case1(let c):
prompt += _processContent(c)
case .ImageMedia(let c):
case .InterleavedContentItem(let c):
prompt += _processContent(c)
case .case3(let c):
prompt += _processContent(c)
}
case .CompletionMessage(let m):
case .assistant(let m):
// TODO: Support encoding past tool call history
// for t in m.tool_calls {
// _processContent(t.)
//}
eom = m.stop_reason == Components.Schemas.StopReason.end_of_message
case .SystemMessage(_):
case .system(_):
break
case .ToolResponseMessage(_):
case .tool(_):
break
}

Expand All @@ -107,12 +107,12 @@ func encodeMessage(message: Components.Schemas.ChatCompletionRequest.messagesPay
return prompt
}

func prepareMessages(request: Components.Schemas.ChatCompletionRequest) throws -> [Components.Schemas.ChatCompletionRequest.messagesPayloadPayload] {
func prepareMessages(request: Components.Schemas.ChatCompletionRequest) throws -> [Components.Schemas.Message] {
var existingMessages = request.messages
var existingSystemMessage: Components.Schemas.ChatCompletionRequest.messagesPayloadPayload?
var existingSystemMessage: Components.Schemas.Message?
// TODO: Existing system message

var messages: [Components.Schemas.ChatCompletionRequest.messagesPayloadPayload] = []
var messages: [Components.Schemas.Message] = []

let defaultGen = SystemDefaultGenerator()
let defaultTemplate = defaultGen.gen()
Expand All @@ -123,7 +123,7 @@ func prepareMessages(request: Components.Schemas.ChatCompletionRequest) throws -

sysContent += try defaultTemplate.render()

messages.append(.SystemMessage(Components.Schemas.SystemMessage(
messages.append(.system(Components.Schemas.SystemMessage(
content: .case1(sysContent),
role: .system))
)
Expand All @@ -133,7 +133,7 @@ func prepareMessages(request: Components.Schemas.ChatCompletionRequest) throws -
let toolGen = FunctionTagCustomToolGenerator()
let toolTemplate = try toolGen.gen(customTools: request.tools!)
let tools = try toolTemplate.render()
messages.append(.UserMessage(Components.Schemas.UserMessage(
messages.append(.user(Components.Schemas.UserMessage(
content: .case1(tools),
role: .user)
))
Expand Down