Skip to content

Commit 7c7ebb0

Browse files
authored
feat: retry parts (#3369)
1 parent 9def7cf commit 7c7ebb0

File tree

5 files changed

+485
-212
lines changed

5 files changed

+485
-212
lines changed

packages/opencode/src/session/compaction.ts

Lines changed: 159 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { streamText, type ModelMessage, LoadAPIKeyError } from "ai"
1+
import { streamText, type ModelMessage, LoadAPIKeyError, type StreamTextResult, type Tool as AITool } from "ai"
22
import { Session } from "."
33
import { Identifier } from "../id/id"
44
import { Instance } from "../project/instance"
@@ -14,8 +14,8 @@ import { Flag } from "../flag/flag"
1414
import { Token } from "../util/token"
1515
import { Log } from "../util/log"
1616
import { SessionLock } from "./lock"
17-
import { NamedError } from "../util/error"
1817
import { ProviderTransform } from "@/provider/transform"
18+
import { SessionRetry } from "./retry"
1919

2020
export namespace SessionCompaction {
2121
const log = Log.create({ service: "session.compaction" })
@@ -41,6 +41,7 @@ export namespace SessionCompaction {
4141

4242
export const PRUNE_MINIMUM = 20_000
4343
export const PRUNE_PROTECT = 40_000
44+
const MAX_RETRIES = 10
4445

4546
// goes backwards through parts until there are 40_000 tokens worth of tool
4647
// calls. then erases output of previous tool calls. idea is to throw away old
@@ -142,112 +143,173 @@ export namespace SessionCompaction {
142143
},
143144
})) as MessageV2.TextPart
144145

145-
const stream = streamText({
146-
maxRetries: 10,
147-
model: model.language,
148-
providerOptions: ProviderTransform.providerOptions(model.npm, model.providerID, model.info.options),
149-
abortSignal: signal,
150-
onError(error) {
151-
log.error("stream error", {
152-
error,
153-
})
154-
},
155-
messages: [
156-
...system.map(
157-
(x): ModelMessage => ({
158-
role: "system",
159-
content: x,
160-
}),
161-
),
162-
...MessageV2.toModelMessage(toSummarize),
163-
{
164-
role: "user",
165-
content: [
166-
{
167-
type: "text",
168-
text: "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next.",
169-
},
170-
],
146+
const doStream = () =>
147+
streamText({
148+
// set to 0, we handle loop
149+
maxRetries: 0,
150+
model: model.language,
151+
providerOptions: ProviderTransform.providerOptions(model.npm, model.providerID, model.info.options),
152+
abortSignal: signal,
153+
onError(error) {
154+
log.error("stream error", {
155+
error,
156+
})
171157
},
172-
],
173-
})
158+
messages: [
159+
...system.map(
160+
(x): ModelMessage => ({
161+
role: "system",
162+
content: x,
163+
}),
164+
),
165+
...MessageV2.toModelMessage(toSummarize),
166+
{
167+
role: "user",
168+
content: [
169+
{
170+
type: "text",
171+
text: "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next.",
172+
},
173+
],
174+
},
175+
],
176+
})
174177

175-
try {
176-
for await (const value of stream.fullStream) {
177-
signal.throwIfAborted()
178-
switch (value.type) {
179-
case "text-delta":
180-
part.text += value.text
181-
if (value.providerMetadata) part.metadata = value.providerMetadata
182-
if (part.text) await Session.updatePart(part)
183-
continue
184-
case "text-end": {
185-
part.text = part.text.trimEnd()
186-
part.time = {
187-
start: Date.now(),
188-
end: Date.now(),
178+
// TODO: reduce duplication between compaction.ts & prompt.ts
179+
const process = async (
180+
stream: StreamTextResult<Record<string, AITool>, never>,
181+
retries: { count: number; max: number },
182+
) => {
183+
let shouldRetry = false
184+
try {
185+
for await (const value of stream.fullStream) {
186+
signal.throwIfAborted()
187+
switch (value.type) {
188+
case "text-delta":
189+
part.text += value.text
190+
if (value.providerMetadata) part.metadata = value.providerMetadata
191+
if (part.text) await Session.updatePart(part)
192+
continue
193+
case "text-end": {
194+
part.text = part.text.trimEnd()
195+
part.time = {
196+
start: Date.now(),
197+
end: Date.now(),
198+
}
199+
if (value.providerMetadata) part.metadata = value.providerMetadata
200+
await Session.updatePart(part)
201+
continue
189202
}
190-
if (value.providerMetadata) part.metadata = value.providerMetadata
191-
await Session.updatePart(part)
192-
continue
193-
}
194-
case "finish-step": {
195-
const usage = Session.getUsage({
196-
model: model.info,
197-
usage: value.usage,
198-
metadata: value.providerMetadata,
199-
})
200-
msg.cost += usage.cost
201-
msg.tokens = usage.tokens
202-
await Session.updateMessage(msg)
203-
continue
203+
case "finish-step": {
204+
const usage = Session.getUsage({
205+
model: model.info,
206+
usage: value.usage,
207+
metadata: value.providerMetadata,
208+
})
209+
msg.cost += usage.cost
210+
msg.tokens = usage.tokens
211+
await Session.updateMessage(msg)
212+
continue
213+
}
214+
case "error":
215+
throw value.error
216+
default:
217+
continue
204218
}
205-
case "error":
206-
throw value.error
207-
default:
208-
continue
209219
}
210-
}
211-
} catch (e) {
212-
log.error("compaction error", {
213-
error: e,
214-
})
215-
switch (true) {
216-
case e instanceof DOMException && e.name === "AbortError":
217-
msg.error = new MessageV2.AbortedError(
218-
{ message: e.message },
219-
{
220-
cause: e,
221-
},
222-
).toObject()
223-
break
224-
case MessageV2.OutputLengthError.isInstance(e):
225-
msg.error = e
226-
break
227-
case LoadAPIKeyError.isInstance(e):
228-
msg.error = new MessageV2.AuthError(
229-
{
230-
providerID: model.providerID,
231-
message: e.message,
220+
} catch (e) {
221+
log.error("compaction error", {
222+
error: e,
223+
})
224+
const error = MessageV2.fromError(e, { providerID: input.providerID })
225+
if (retries.count < retries.max && MessageV2.APIError.isInstance(error) && error.data.isRetryable) {
226+
shouldRetry = true
227+
await Session.updatePart({
228+
id: Identifier.ascending("part"),
229+
messageID: msg.id,
230+
sessionID: msg.sessionID,
231+
type: "retry",
232+
attempt: retries.count + 1,
233+
time: {
234+
created: Date.now(),
232235
},
233-
{ cause: e },
234-
).toObject()
235-
break
236-
case e instanceof Error:
237-
msg.error = new NamedError.Unknown({ message: e.toString() }, { cause: e }).toObject()
236+
error,
237+
})
238+
} else {
239+
msg.error = error
240+
Bus.publish(Session.Event.Error, {
241+
sessionID: msg.sessionID,
242+
error: msg.error,
243+
})
244+
}
245+
}
246+
247+
const parts = await Session.getParts(msg.id)
248+
return {
249+
info: msg,
250+
parts,
251+
shouldRetry,
252+
}
253+
}
254+
255+
let stream = doStream()
256+
let result = await process(stream, {
257+
count: 0,
258+
max: MAX_RETRIES,
259+
})
260+
if (result.shouldRetry) {
261+
for (let retry = 1; retry < MAX_RETRIES; retry++) {
262+
const lastRetryPart = result.parts.findLast((p) => p.type === "retry")
263+
264+
if (lastRetryPart) {
265+
const delayMs = SessionRetry.getRetryDelayInMs(lastRetryPart.error, retry)
266+
267+
log.info("retrying with backoff", {
268+
attempt: retry,
269+
delayMs,
270+
})
271+
272+
const stop = await SessionRetry.sleep(delayMs, signal)
273+
.then(() => false)
274+
.catch((error) => {
275+
if (error instanceof DOMException && error.name === "AbortError") {
276+
const err = new MessageV2.AbortedError(
277+
{ message: error.message },
278+
{
279+
cause: error,
280+
},
281+
).toObject()
282+
result.info.error = err
283+
Bus.publish(Session.Event.Error, {
284+
sessionID: result.info.sessionID,
285+
error: result.info.error,
286+
})
287+
return true
288+
}
289+
throw error
290+
})
291+
292+
if (stop) break
293+
}
294+
295+
stream = doStream()
296+
result = await process(stream, {
297+
count: retry,
298+
max: MAX_RETRIES,
299+
})
300+
if (!result.shouldRetry) {
238301
break
239-
default:
240-
msg.error = new NamedError.Unknown({ message: JSON.stringify(e) }, { cause: e })
302+
}
241303
}
242-
Bus.publish(Session.Event.Error, {
243-
sessionID: input.sessionID,
244-
error: msg.error,
245-
})
246304
}
247305

248306
msg.time.completed = Date.now()
249307

250-
if (!msg.error || MessageV2.AbortedError.isInstance(msg.error)) {
308+
if (
309+
!msg.error ||
310+
(MessageV2.AbortedError.isInstance(msg.error) &&
311+
result.parts.some((part) => part.type === "text" && part.text.length > 0))
312+
) {
251313
msg.summary = true
252314
Bus.publish(Event.Compacted, {
253315
sessionID: input.sessionID,
@@ -257,7 +319,7 @@ export namespace SessionCompaction {
257319

258320
return {
259321
info: msg,
260-
parts: [part],
322+
parts: result.parts,
261323
}
262324
}
263325
}

0 commit comments

Comments
 (0)