Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions src/main/core/conversations/impl/local-conversation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ const MAX_RESPAWNS = 2;

export class LocalConversationProvider implements ConversationProvider {
private sessions = new Map<string, Pty>();
private knownSessionIds = new Set<string>();
private respawnCounts = new Map<string, number>();
private readonly projectId: string;
private readonly taskPath: string;
Expand Down Expand Up @@ -73,6 +74,7 @@ export class LocalConversationProvider implements ConversationProvider {
conversation.taskId,
conversation.id
);
this.knownSessionIds.add(sessionId);
if (this.sessions.has(sessionId)) return;

await claudeTrustService.maybeAutoTrustLocal({
Expand Down Expand Up @@ -188,28 +190,31 @@ export class LocalConversationProvider implements ConversationProvider {

async stopSession(conversationId: string): Promise<void> {
const sessionId = makePtySessionId(this.projectId, this.taskId, conversationId);
this.knownSessionIds.delete(sessionId);
const pty = this.sessions.get(sessionId);
if (!pty) return;
try {
pty.kill();
} catch (e) {
log.warn('LocalAgentProvider: error killing PTY', { sessionId, error: String(e) });
if (pty) {
try {
pty.kill();
} catch (e) {
log.warn('LocalAgentProvider: error killing PTY', { sessionId, error: String(e) });
}
this.sessions.delete(sessionId);
ptySessionRegistry.unregister(sessionId);
}
this.sessions.delete(sessionId);
ptySessionRegistry.unregister(sessionId);
if (this.tmux) {
await killTmuxSession(this.exec, makeTmuxSessionName(sessionId));
}
}

async destroyAll(): Promise<void> {
const sessionIds = Array.from(this.sessions.keys());
const sessionIds = Array.from(this.knownSessionIds);
await this.detachAll();
if (this.tmux) {
await Promise.all(
sessionIds.map((id) => killTmuxSession(this.exec, makeTmuxSessionName(id)))
);
}
this.knownSessionIds.clear();
}

async detachAll(): Promise<void> {
Expand Down
21 changes: 13 additions & 8 deletions src/main/core/conversations/impl/ssh-conversation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ const MAX_RESPAWNS = 2;

export class SshConversationProvider implements ConversationProvider {
private sessions = new Map<string, Pty>();
private knownSessionIds = new Set<string>();
private respawnCounts = new Map<string, number>();
private readonly projectId: string;
private readonly taskPath: string;
Expand Down Expand Up @@ -74,6 +75,7 @@ export class SshConversationProvider implements ConversationProvider {
conversation.taskId,
conversation.id
);
this.knownSessionIds.add(sessionId);

if (this.sessions.has(sessionId)) return;

Expand Down Expand Up @@ -185,28 +187,31 @@ export class SshConversationProvider implements ConversationProvider {

async stopSession(conversationId: string): Promise<void> {
const sessionId = makePtySessionId(this.projectId, this.taskId, conversationId);
this.knownSessionIds.delete(sessionId);
const pty = this.sessions.get(sessionId);
if (!pty) return;
try {
pty.kill();
} catch (e) {
log.warn('SshAgentProvider: error killing PTY', { sessionId, error: String(e) });
if (pty) {
try {
pty.kill();
} catch (e) {
log.warn('SshAgentProvider: error killing PTY', { sessionId, error: String(e) });
}
this.sessions.delete(sessionId);
ptySessionRegistry.unregister(sessionId);
}
this.sessions.delete(sessionId);
ptySessionRegistry.unregister(sessionId);
if (this.tmux) {
await killTmuxSession(this.exec, makeTmuxSessionName(sessionId));
}
}

async destroyAll(): Promise<void> {
const sessionIds = Array.from(this.sessions.keys());
const sessionIds = Array.from(this.knownSessionIds);
await this.detachAll();
if (this.tmux) {
await Promise.all(
sessionIds.map((id) => killTmuxSession(this.exec, makeTmuxSessionName(id)))
);
}
this.knownSessionIds.clear();
}

async detachAll(): Promise<void> {
Expand Down
53 changes: 46 additions & 7 deletions src/main/core/projects/impl/local-project-provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import fs from 'node:fs';
import path from 'node:path';
import { Conversation } from '@shared/conversations';
import { LocalProject } from '@shared/projects';
import { makePtySessionId } from '@shared/ptySessionId';
import { err, ok, type Result } from '@shared/result';
import { getTaskEnvVars } from '@shared/task/envVars';
import { Task, type TaskBootstrapStatus } from '@shared/tasks';
Expand All @@ -15,9 +16,11 @@ import { GitService } from '@main/core/git/impl/git-service';
import { bareRefName } from '@main/core/git/impl/git-utils';
import type { GitProvider } from '@main/core/git/types';
import { githubAuthService } from '@main/core/github/services/github-auth-service';
import { killTmuxSession, makeTmuxSessionName } from '@main/core/pty/tmux-session-name';
import { appSettingsService } from '@main/core/settings/settings-service';
import { getTaskSessionLeafIds } from '@main/core/tasks/session-targets';
import { LocalTerminalProvider } from '@main/core/terminals/impl/local-terminal-provider';
import { getGitLocalExec } from '@main/core/utils/exec';
import { getGitLocalExec, getLocalExec } from '@main/core/utils/exec';
import type { Workspace } from '@main/core/workspaces/workspace';
import { WorkspaceLifecycleService } from '@main/core/workspaces/workspace-lifecycle-service';
import { WorkspaceRegistry } from '@main/core/workspaces/workspace-registry';
Expand All @@ -36,6 +39,7 @@ import { TimeoutSignal, withTimeout } from '../utils';
import { WorktreeService } from '../worktrees/worktree-service';

const TASK_TIMEOUT_MS = 60_000;
const TEARDOWN_SCRIPT_WAIT_MS = 10_000;

function toProvisionError(e: unknown): ProvisionTaskError {
if (e instanceof TimeoutSignal) return { type: 'timeout', message: e.message, timeout: e.ms };
Expand Down Expand Up @@ -72,6 +76,7 @@ export class LocalProjectProvider implements ProjectProvider {
private bootstrapErrors = new Map<string, ProvisionTaskError>();
private worktreeService: WorktreeService;
private workspaceRegistry = new WorkspaceRegistry();
private readonly localExec = getLocalExec();

constructor(
private readonly project: LocalProject,
Expand Down Expand Up @@ -307,15 +312,24 @@ export class LocalProjectProvider implements ProjectProvider {
async teardownTask(taskId: string): Promise<Result<void, TeardownTaskError>> {
if (this.tearingDownTasks.has(taskId)) return this.tearingDownTasks.get(taskId)!;
const task = this.tasks.get(taskId);
if (!task) return ok();
if (!task) {
await this.cleanupDetachedTmuxSessions(taskId);
return ok();
}

const promise = withTimeout(this.doTeardownTask(task), TASK_TIMEOUT_MS)
.then(() => ok<void>())
.catch((e) => {
.catch(async (e) => {
log.error('LocalProjectProvider: failed to teardown task', {
taskId,
error: String(e),
});
await this.cleanupDetachedTmuxSessions(taskId).catch((cleanupError) => {
log.warn('LocalProjectProvider: fallback tmux cleanup failed', {
taskId,
error: String(cleanupError),
});
});
return err<TeardownTaskError>(toTeardownError(e));
})
.finally(() => {
Expand Down Expand Up @@ -345,10 +359,25 @@ export class LocalProjectProvider implements ProjectProvider {
const scripts = settings.scripts;

if (scripts?.teardown && this.workspaceRegistry.refCount(wsId) === 1) {
await workspace.lifecycleService.runLifecycleScript(
{ type: 'teardown', script: scripts.teardown },
{ waitForExit: true, exit: true }
);
try {
const runTeardown = workspace.lifecycleService.runLifecycleScript(
{ type: 'teardown', script: scripts.teardown },
{ waitForExit: true, exit: true }
);
await withTimeout(runTeardown, TEARDOWN_SCRIPT_WAIT_MS);
} catch (error) {
if (error instanceof TimeoutSignal) {
log.debug('LocalProjectProvider: teardown script wait timed out', {
taskId: task.taskId,
timeoutMs: TEARDOWN_SCRIPT_WAIT_MS,
});
} else {
log.warn('LocalProjectProvider: teardown script failed (continuing cleanup)', {
taskId: task.taskId,
error: String(error),
});
}
}
}
}

Expand All @@ -357,6 +386,16 @@ export class LocalProjectProvider implements ProjectProvider {
await this.workspaceRegistry.release(wsId);
}

private async cleanupDetachedTmuxSessions(taskId: string): Promise<void> {
const { conversationIds, terminalIds } = await getTaskSessionLeafIds(this.project.id, taskId);
const sessionIds = [...conversationIds, ...terminalIds].map((leafId) =>
makePtySessionId(this.project.id, taskId, leafId)
);
await Promise.all(
sessionIds.map((sessionId) => killTmuxSession(this.localExec, makeTmuxSessionName(sessionId)))
);
}

async removeTaskWorktree(taskBranch: string): Promise<void> {
const worktreePath = await this.worktreeService.getWorktree(taskBranch);
if (worktreePath) {
Expand Down
51 changes: 45 additions & 6 deletions src/main/core/projects/impl/ssh-project-provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import path from 'node:path';
import type { SFTPWrapper } from 'ssh2';
import { Conversation } from '@shared/conversations';
import type { SshProject } from '@shared/projects';
import { makePtySessionId } from '@shared/ptySessionId';
import { err, ok, type Result } from '@shared/result';
import { getTaskEnvVars } from '@shared/task/envVars';
import { Task, type TaskBootstrapStatus } from '@shared/tasks';
Expand All @@ -15,8 +16,10 @@ import { GitService } from '@main/core/git/impl/git-service';
import { bareRefName } from '@main/core/git/impl/git-utils';
import type { GitProvider } from '@main/core/git/types';
import { githubAuthService } from '@main/core/github/services/github-auth-service';
import { killTmuxSession, makeTmuxSessionName } from '@main/core/pty/tmux-session-name';
import { SshClientProxy } from '@main/core/ssh/ssh-client-proxy';
import { SshConnectionEvent, sshConnectionManager } from '@main/core/ssh/ssh-connection-manager';
import { getTaskSessionLeafIds } from '@main/core/tasks/session-targets';
import { SshTerminalProvider } from '@main/core/terminals/impl/ssh-terminal-provider';
import { getGitSshExec, getSshExec } from '@main/core/utils/exec';
import type { Workspace } from '@main/core/workspaces/workspace';
Expand All @@ -37,6 +40,7 @@ import { TimeoutSignal, withTimeout } from '../utils';
import { WorktreeService } from '../worktrees/worktree-service';

const TASK_TIMEOUT_MS = 60_000;
const TEARDOWN_SCRIPT_WAIT_MS = 10_000;

function toProvisionError(e: unknown): ProvisionTaskError {
if (e instanceof TimeoutSignal) return { type: 'timeout', message: e.message, timeout: e.ms };
Expand Down Expand Up @@ -347,15 +351,24 @@ export class SshProjectProvider implements ProjectProvider {
async teardownTask(taskId: string): Promise<Result<void, TeardownTaskError>> {
if (this.tearingDownTasks.has(taskId)) return this.tearingDownTasks.get(taskId)!;
const task = this.tasks.get(taskId);
if (!task) return ok();
if (!task) {
await this.cleanupDetachedTmuxSessions(taskId);
return ok();
}

const promise = withTimeout(this.doTeardownTask(task), TASK_TIMEOUT_MS)
.then(() => ok<void>())
.catch((e) => {
.catch(async (e) => {
log.error('SshProjectProvider: failed to teardown task', {
taskId,
error: String(e),
});
await this.cleanupDetachedTmuxSessions(taskId).catch((cleanupError) => {
log.warn('SshProjectProvider: fallback tmux cleanup failed', {
taskId,
error: String(cleanupError),
});
});
return err<TeardownTaskError>(toTeardownError(e));
})
.finally(() => {
Expand Down Expand Up @@ -387,10 +400,25 @@ export class SshProjectProvider implements ProjectProvider {
const scripts = settings.scripts;

if (scripts?.teardown && this.workspaceRegistry.refCount(wsId) === 1) {
await workspace.lifecycleService.runLifecycleScript(
{ type: 'teardown', script: scripts.teardown },
{ waitForExit: true, exit: true }
);
try {
const runTeardown = workspace.lifecycleService.runLifecycleScript(
{ type: 'teardown', script: scripts.teardown },
{ waitForExit: true, exit: true }
);
await withTimeout(runTeardown, TEARDOWN_SCRIPT_WAIT_MS);
} catch (error) {
if (error instanceof TimeoutSignal) {
log.debug('SshProjectProvider: teardown script wait timed out', {
taskId: task.taskId,
timeoutMs: TEARDOWN_SCRIPT_WAIT_MS,
});
} else {
log.warn('SshProjectProvider: teardown script failed (continuing cleanup)', {
taskId: task.taskId,
error: String(error),
});
}
}
}
}

Expand All @@ -399,6 +427,17 @@ export class SshProjectProvider implements ProjectProvider {
await this.workspaceRegistry.release(wsId);
}

private async cleanupDetachedTmuxSessions(taskId: string): Promise<void> {
const { conversationIds, terminalIds } = await getTaskSessionLeafIds(this.project.id, taskId);
const sessionIds = [...conversationIds, ...terminalIds].map((leafId) =>
makePtySessionId(this.project.id, taskId, leafId)
);
const exec = getSshExec(this.proxy);
await Promise.all(
sessionIds.map((sessionId) => killTmuxSession(exec, makeTmuxSessionName(sessionId)))
);
}

async removeTaskWorktree(taskBranch: string): Promise<void> {
const worktreePath = await this.worktreeService.getWorktree(taskBranch);
if (worktreePath) {
Expand Down
22 changes: 11 additions & 11 deletions src/main/core/tasks/deleteTask.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,22 @@ export async function deleteTask(projectId: string, taskId: string): Promise<voi

const project = projectManager.getProject(projectId);

if (project) {
const teardownResult = await project.teardownTask(taskId).catch((e) => {
log.warn('deleteTask: teardown failed', { taskId, error: String(e) });
return null;
});

if (teardownResult && !teardownResult.success) {
log.warn('deleteTask: teardown failed', { taskId, error: teardownResult.error.message });
}
}

await db.delete(tasks).where(eq(tasks.id, taskId));
void viewStateService.del(`task:${taskId}`);
capture('task_deleted');

if (project) {
void project
.teardownTask(taskId)
.then((teardownResult) => {
if (!teardownResult.success) {
log.warn('deleteTask: teardown failed', { taskId, error: teardownResult.error.message });
}
})
.catch((e) => {
log.warn('deleteTask: teardown failed', { taskId, error: String(e) });
});

if (task.taskBranch) {
const siblings = await db
.select({ id: tasks.id })
Expand Down
29 changes: 29 additions & 0 deletions src/main/core/tasks/session-targets.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import { and, eq } from 'drizzle-orm';
import { db } from '@main/db/client';
import { conversations, terminals } from '@main/db/schema';

export type TaskSessionLeafIds = {
conversationIds: string[];
terminalIds: string[];
};

export async function getTaskSessionLeafIds(
projectId: string,
taskId: string
): Promise<TaskSessionLeafIds> {
const [conversationRows, terminalRows] = await Promise.all([
db
.select({ id: conversations.id })
.from(conversations)
.where(and(eq(conversations.projectId, projectId), eq(conversations.taskId, taskId))),
db
.select({ id: terminals.id })
.from(terminals)
.where(and(eq(terminals.projectId, projectId), eq(terminals.taskId, taskId))),
]);

return {
conversationIds: conversationRows.map((row) => row.id),
terminalIds: terminalRows.map((row) => row.id),
};
}
Loading