diff --git a/apps/code/src/main/services/agent/service.ts b/apps/code/src/main/services/agent/service.ts index f6348191ad..3a54f52785 100644 --- a/apps/code/src/main/services/agent/service.ts +++ b/apps/code/src/main/services/agent/service.ts @@ -1,4 +1,4 @@ -import fs, { mkdirSync, symlinkSync } from "node:fs"; +import fs, { promises as fsPromises, mkdirSync, symlinkSync } from "node:fs"; import { homedir, tmpdir } from "node:os"; import { isAbsolute, join, relative, resolve, sep } from "node:path"; import { @@ -18,7 +18,10 @@ import { POSTHOG_NOTIFICATIONS, } from "@posthog/agent"; import type { McpToolApprovals } from "@posthog/agent/adapters/claude/mcp/tool-metadata"; -import { hydrateSessionJsonl } from "@posthog/agent/adapters/claude/session/jsonl-hydration"; +import { + getSessionJsonlPath, + hydrateSessionJsonl, +} from "@posthog/agent/adapters/claude/session/jsonl-hydration"; import { getReasoningEffortOptions } from "@posthog/agent/adapters/reasoning-effort"; import { Agent } from "@posthog/agent/agent"; import { @@ -38,10 +41,12 @@ import { getLlmGatewayUrl } from "@posthog/agent/posthog-api"; import { extractCreatedPrUrl } from "@posthog/agent/pr-url-detector"; import type * as AgentTypes from "@posthog/agent/types"; import { getCurrentBranch } from "@posthog/git/queries"; +import { CaptureCheckpointSaga } from "@posthog/git/sagas/checkpoint"; import type { IAppMeta } from "@posthog/platform/app-meta"; import type { IBundledResources } from "@posthog/platform/bundled-resources"; import type { IPowerManager } from "@posthog/platform/power-manager"; import type { IStoragePaths } from "@posthog/platform/storage-paths"; +import { DATA_DIR } from "@shared/constants"; import { isAuthError } from "@shared/errors"; import type { AcpMessage } from "@shared/types/session-events"; import { inject, injectable, preDestroy } from "inversify"; @@ -286,6 +291,13 @@ export class AgentService extends TypedEventEmitter { private sessions = new Map(); private pendingPermissions = new Map(); + /** taskRunIds that need force-refetch of JSONL on next reconnect (checkpoint restore). */ + private checkpointRestoreTaskRunIds = new Set(); + /** Checkpoint notifications captured per taskRunId, in capture order. Survives session reconnect. */ + private sessionCheckpoints = new Map< + string, + Array<{ checkpointId: string; ts: number; promptId: number | undefined }> + >(); private mockNodeReady = false; private idleTimeouts = new Map< string, @@ -758,6 +770,11 @@ When creating pull requests, add the following footer at the end of the PR descr if (adapter !== "codex") { const posthogAPI = agent.getPosthogAPI(); if (posthogAPI) { + const forceRefetch = + this.checkpointRestoreTaskRunIds.has(taskRunId); + if (forceRefetch) { + this.checkpointRestoreTaskRunIds.delete(taskRunId); + } const hasSession = await hydrateSessionJsonl({ sessionId: existingSessionId, cwd: repoPath, @@ -766,6 +783,7 @@ When creating pull requests, add the following footer at the end of the PR descr permissionMode: config.permissionMode, posthogAPI, log, + forceRefetch, }); if (!hasSession) { log.info( @@ -780,6 +798,11 @@ When creating pull requests, add the following footer at the end of the PR descr if (isReconnect && config.sessionId) { const existingSessionId = config.sessionId; + log.info("Reconnecting with existing sessionId", { + taskId, + taskRunId, + sessionId: existingSessionId, + }); // Both adapters implement resumeSession: // - Claude: delegates to SDK's resumeSession with JSONL hydration @@ -993,6 +1016,106 @@ When creating pull requests, add the following footer at the end of the PR descr return this.sessions.get(taskRunId); } + getSessionInfo(taskRunId: string): + | { + sessionId: string; + repoPath: string; + taskId: string; + apiHost: string; + projectId: number; + adapter: "claude" | "codex" | undefined; + } + | undefined { + const session = this.sessions.get(taskRunId); + if (!session?.config.sessionId) return undefined; + return { + sessionId: session.config.sessionId, + repoPath: session.repoPath, + taskId: session.config.taskId, + apiHost: session.config.credentials.apiHost, + projectId: session.config.credentials.projectId, + adapter: session.config.adapter, + }; + } + + /** + * Re-emit stored checkpoint notifications through the SessionEvent channel. + * Called by the renderer after its subscription is set up so no events are lost. + */ + replayCheckpoints(taskRunId: string): number { + const checkpoints = this.sessionCheckpoints.get(taskRunId) ?? []; + if (checkpoints.length === 0) return 0; + + log.info("Replaying stored checkpoints via SessionEvent", { + taskRunId, + count: checkpoints.length, + }); + + for (const { checkpointId, ts, promptId } of checkpoints) { + this.emit(AgentServiceEvent.SessionEvent, { + taskRunId, + payload: { + type: "acp_message", + ts, + message: { + jsonrpc: "2.0" as const, + method: POSTHOG_NOTIFICATIONS.GIT_CHECKPOINT, + // Mark as replay so renderer doesn't re-sync to S3 + params: { checkpointId, promptId, replay: true }, + }, + }, + }); + } + + return checkpoints.length; + } + + /** + * Get the promptId for a checkpoint. Used when truncating S3 log so the + * backend can find the correct turn boundary by promptId. + */ + getCheckpointPromptId( + taskRunId: string, + checkpointId: string, + ): number | undefined { + const checkpoints = this.sessionCheckpoints.get(taskRunId) ?? []; + const cp = checkpoints.find((c) => c.checkpointId === checkpointId); + return cp?.promptId; + } + + /** + * Remove stored checkpoints that come AFTER the given checkpointId (inclusive + * of the target). Called after a restore so replayCheckpoints only re-emits + * the surviving checkpoints and not orphaned ones whose git refs were deleted. + */ + truncateCheckpoints(taskRunId: string, keepUpToCheckpointId: string): number { + const checkpoints = this.sessionCheckpoints.get(taskRunId) ?? []; + const idx = checkpoints.findIndex( + (cp) => cp.checkpointId === keepUpToCheckpointId, + ); + if (idx === -1) return 0; + this.sessionCheckpoints.set(taskRunId, checkpoints.slice(0, idx + 1)); + log.info("Truncated stored checkpoints after restore", { + taskRunId, + keepUpTo: keepUpToCheckpointId, + kept: idx + 1, + removed: checkpoints.length - idx - 1, + }); + return idx + 1; + } + + /** + * Mark a taskRunId so the next reconnect forces JSONL re-hydration from S3 + * instead of reusing an existing (stale) JSONL file. Called before cancelSession + * during checkpoint restore so hydrateSessionJsonl skips the "already exists" guard. + */ + markCheckpointRestore(taskRunId: string): void { + this.checkpointRestoreTaskRunIds.add(taskRunId); + log.info("Marked taskRunId for force JSONL refetch on reconnect", { + taskRunId, + }); + } + async setSessionConfigOption( sessionId: string, configId: string, @@ -1263,6 +1386,10 @@ For git operations while detached: }); }; + // Track the most recent session/prompt request ID so the checkpoint + // notification can be tagged with the turn it belongs to. + let latestPromptId: number | undefined; + const onAcpMessage = (message: unknown) => { const acpMessage: AcpMessage = { type: "acp_message", @@ -1271,8 +1398,25 @@ For git operations while detached: }; emitToRenderer(acpMessage); + // Track session/prompt request IDs for turn-tagging + const raw = message as { method?: string; id?: number }; + if (raw.method === "session/prompt" && raw.id !== undefined) { + latestPromptId = raw.id; + log.debug("Tracked session/prompt id", { taskRunId, promptId: raw.id }); + } + // Inspect tool call updates for PR URLs and file activity this.handleToolCallUpdate(taskRunId, message as AcpMessage["message"]); + + // Capture a local git checkpoint when a turn completes. + // Intercepted here (raw stream tap) rather than extNotification because + // the ACP SDK does not reliably route _posthog/ notifications to that callback. + this.handleTurnCompleteForCheckpoint( + taskRunId, + message, + latestPromptId, + emitToRenderer, + ); }; const tappedReadable = createTappedReadableStream( @@ -1729,6 +1873,166 @@ For git operations while detached: }); } + private handleTurnCompleteForCheckpoint( + taskRunId: string, + message: unknown, + promptId: number | undefined, + emitToRenderer: (payload: unknown) => void, + ): void { + const msg = message as { method?: string }; + if (!isNotification(msg.method, POSTHOG_NOTIFICATIONS.TURN_COMPLETE)) + return; + + const session = this.sessions.get(taskRunId); + if (!session?.config.repoPath) { + log.debug("TURN_COMPLETE in stream — no repoPath, skipping checkpoint", { + taskRunId, + }); + return; + } + + log.info("TURN_COMPLETE in stream — capturing local checkpoint", { + taskRunId, + repoPath: session.config.repoPath, + promptId, + }); + + this.captureLocalCheckpoint( + taskRunId, + session.config.repoPath, + session.config.sessionId, + promptId, + emitToRenderer, + ).catch((err) => { + log.warn("Local checkpoint capture failed", { + taskRunId, + error: err instanceof Error ? err.message : String(err), + }); + }); + } + + /** + * Capture a local git checkpoint after a turn completes, emit the + * `_posthog/git_checkpoint` notification to the renderer, and append it to + * the session JSONL so it survives page reload. + */ + private async captureLocalCheckpoint( + taskRunId: string, + repoPath: string, + sessionId: string | undefined, + promptId: number | undefined, + emitToRenderer: (payload: unknown) => void, + ): Promise { + log.info("Capturing local checkpoint after turn", { taskRunId, repoPath }); + + const saga = new CaptureCheckpointSaga(); + const sagaResult = await saga.run({ baseDir: repoPath }); + if (!sagaResult.success) { + log.warn("CaptureCheckpointSaga failed — no checkpoint for this turn", { + taskRunId, + error: sagaResult.error, + }); + return; + } + + const result = sagaResult.data; + log.info("Local checkpoint captured", { + taskRunId, + checkpointId: result.checkpointId, + commit: result.commit, + branch: result.branch, + }); + + // Persist mapping so we can re-inject on reconnect, with promptId for + // correct turn association regardless of when the notification arrives. + const ts = Date.now(); + const existing = this.sessionCheckpoints.get(taskRunId) ?? []; + existing.push({ checkpointId: result.checkpointId, ts, promptId }); + this.sessionCheckpoints.set(taskRunId, existing); + log.info("Stored checkpoint for reconnect replay", { + taskRunId, + checkpointId: result.checkpointId, + promptId, + totalStored: existing.length, + }); + + const notification = { + jsonrpc: "2.0" as const, + method: POSTHOG_NOTIFICATIONS.GIT_CHECKPOINT, + params: { checkpointId: result.checkpointId, promptId }, + }; + + // Emit to renderer so the restore button activates on the completed turn + const acpMessage: AcpMessage = { + type: "acp_message", + ts: Date.now(), + message: notification as AcpMessage["message"], + }; + emitToRenderer(acpMessage); + + log.info("Emitted GIT_CHECKPOINT notification to renderer", { + taskRunId, + checkpointId: result.checkpointId, + }); + + // Append to the session JSONL so restore can find the checkpoint on reload + if (sessionId) { + try { + const jsonlPath = getSessionJsonlPath(sessionId, repoPath); + const line = `${JSON.stringify({ notification })}\n`; + await fsPromises.appendFile(jsonlPath, line, "utf-8"); + log.info("Checkpoint appended to JSONL", { + taskRunId, + checkpointId: result.checkpointId, + jsonlPath, + }); + } catch (err) { + log.warn( + "Failed to append checkpoint to JSONL (restore may not survive reload)", + { + taskRunId, + error: err instanceof Error ? err.message : String(err), + }, + ); + } + } else { + log.warn("No sessionId yet — checkpoint not written to JSONL", { + taskRunId, + }); + } + + // Also append to the local logs.ndjson cache. The renderer's fetchSessionLogs + // reads logs.ndjson first (before S3), and the in-memory sessionCheckpoints + // map is lost when the main process restarts. Without this, the checkpoint + // notification lives only in S3 + the in-memory map, so after the app is + // reopened the cold load reads a checkpoint-less local cache and every + // restore icon goes disabled. This append (matching the SessionLogWriter + // tap's line-by-line model) keeps checkpoints visible across restarts. + try { + const sessionDir = join(homedir(), DATA_DIR, "sessions", taskRunId); + await fsPromises.mkdir(sessionDir, { recursive: true }); + const entry = { + type: "notification" as const, + timestamp: new Date().toISOString(), + notification, + }; + await fsPromises.appendFile( + join(sessionDir, "logs.ndjson"), + `${JSON.stringify(entry)}\n`, + "utf-8", + ); + log.info("Checkpoint appended to local logs.ndjson", { + taskRunId, + checkpointId: result.checkpointId, + }); + } catch (err) { + log.warn("Failed to append checkpoint to local logs.ndjson", { + taskRunId, + error: err instanceof Error ? err.message : String(err), + }); + } + } + async getGatewayModels(apiHost: string) { const gatewayUrl = getLlmGatewayUrl(apiHost); const models = await fetchGatewayModels({ gatewayUrl }); diff --git a/apps/code/src/main/services/local-logs/service.test.ts b/apps/code/src/main/services/local-logs/service.test.ts index 80b735e739..24195d3215 100644 --- a/apps/code/src/main/services/local-logs/service.test.ts +++ b/apps/code/src/main/services/local-logs/service.test.ts @@ -2,11 +2,14 @@ import os from "node:os"; import path from "node:path"; import { beforeEach, describe, expect, it, vi } from "vitest"; -const { mockMkdir, mockWriteFile, mockReadFile } = vi.hoisted(() => ({ - mockMkdir: vi.fn(), - mockWriteFile: vi.fn(), - mockReadFile: vi.fn(), -})); +const { mockMkdir, mockWriteFile, mockReadFile, mockRename, mockUnlink } = + vi.hoisted(() => ({ + mockMkdir: vi.fn(), + mockWriteFile: vi.fn(), + mockReadFile: vi.fn(), + mockRename: vi.fn(), + mockUnlink: vi.fn(), + })); vi.mock("node:fs", () => ({ default: { @@ -14,6 +17,8 @@ vi.mock("node:fs", () => ({ mkdir: mockMkdir, writeFile: mockWriteFile, readFile: mockReadFile, + rename: mockRename, + unlink: mockUnlink, }, }, })); @@ -63,6 +68,8 @@ describe("LocalLogsService", () => { mockMkdir.mockReset().mockResolvedValue(undefined); mockWriteFile.mockReset().mockResolvedValue(undefined); mockReadFile.mockReset(); + mockRename.mockReset().mockResolvedValue(undefined); + mockUnlink.mockReset().mockResolvedValue(undefined); }); describe("readLocalLogs", () => { @@ -231,4 +238,140 @@ describe("LocalLogsService", () => { expect(mockMkdir).toHaveBeenCalledTimes(1); }); }); + + describe("truncateLocalLogsAtPromptBoundary", () => { + // A boundary line is the JSON-RPC response to a session/prompt request: + // type "notification" with a matching notification.id and a result field. + const responseLine = (id: number) => + JSON.stringify({ + type: "notification", + notification: { id, result: {} }, + }); + + it("truncates to the prompt boundary and returns true", async () => { + const kept = `${[responseLine(2), responseLine(3)].join("\n")}\n`; + mockReadFile.mockResolvedValue( + `${[responseLine(2), responseLine(3), responseLine(4)].join("\n")}\n`, + ); + + const service = new LocalLogsService(); + const result = await service.truncateLocalLogsAtPromptBoundary(RUN_ID, 3); + + expect(result).toBe(true); + // Atomic write: tmp file written then renamed over the real path. + expect(mockRename).toHaveBeenCalledTimes(1); + const wroteTruncated = mockWriteFile.mock.calls.some( + (call) => call[1] === kept, + ); + expect(wroteTruncated).toBe(true); + await flushMicrotasks(); + }); + + it("returns false and does not rewrite when the boundary is not found", async () => { + mockReadFile.mockResolvedValue( + `${[responseLine(2), responseLine(3)].join("\n")}\n`, + ); + + const service = new LocalLogsService(); + const result = await service.truncateLocalLogsAtPromptBoundary( + RUN_ID, + 99, + ); + + expect(result).toBe(false); + expect(mockRename).not.toHaveBeenCalled(); + }); + + it("returns true without rewriting when the boundary is already last", async () => { + mockReadFile.mockResolvedValue( + `${[responseLine(2), responseLine(3)].join("\n")}\n`, + ); + + const service = new LocalLogsService(); + const result = await service.truncateLocalLogsAtPromptBoundary(RUN_ID, 3); + + expect(result).toBe(true); + expect(mockRename).not.toHaveBeenCalled(); + }); + + it("returns true when the log file does not exist", async () => { + mockReadFile.mockRejectedValue( + Object.assign(new Error("nope"), { code: "ENOENT" }), + ); + + const service = new LocalLogsService(); + const result = await service.truncateLocalLogsAtPromptBoundary(RUN_ID, 3); + + expect(result).toBe(true); + expect(mockRename).not.toHaveBeenCalled(); + }); + + it("returns false on a non-ENOENT read error", async () => { + mockReadFile.mockRejectedValue(new Error("boom")); + + const service = new LocalLogsService(); + const result = await service.truncateLocalLogsAtPromptBoundary(RUN_ID, 3); + + expect(result).toBe(false); + }); + + // The restored turn's git_checkpoint notification sits after its prompt + // response, so the trim drops it; the restore flow re-adds it so the + // restored turn stays restorable after an app restart. + it("re-appends preserved trailing entries after the boundary", async () => { + const cpEntry = JSON.stringify({ + type: "notification", + timestamp: "t", + notification: { + jsonrpc: "2.0", + method: "_posthog/git_checkpoint", + params: { checkpointId: "cp1", promptId: 3 }, + }, + }); + mockReadFile.mockResolvedValue( + `${[responseLine(2), responseLine(3), responseLine(4)].join("\n")}\n`, + ); + + const service = new LocalLogsService(); + const result = await service.truncateLocalLogsAtPromptBoundary( + RUN_ID, + 3, + [cpEntry], + ); + + expect(result).toBe(true); + const expected = `${[responseLine(2), responseLine(3), cpEntry].join("\n")}\n`; + expect( + mockWriteFile.mock.calls.some((call) => call[1] === expected), + ).toBe(true); + await flushMicrotasks(); + }); + + it("re-appends preserved entries even when the boundary is the last line", async () => { + const cpEntry = JSON.stringify({ + type: "notification", + timestamp: "t", + notification: { + jsonrpc: "2.0", + method: "_posthog/git_checkpoint", + params: { checkpointId: "cp1", promptId: 3 }, + }, + }); + mockReadFile.mockResolvedValue( + `${[responseLine(2), responseLine(3)].join("\n")}\n`, + ); + + const service = new LocalLogsService(); + const result = await service.truncateLocalLogsAtPromptBoundary( + RUN_ID, + 3, + [cpEntry], + ); + + expect(result).toBe(true); + // Boundary was the last line, but the preserved entry forces a rewrite. + expect(mockRename).toHaveBeenCalledTimes(1); + await flushMicrotasks(); + }); + }); }); diff --git a/apps/code/src/main/services/local-logs/service.ts b/apps/code/src/main/services/local-logs/service.ts index 4c4281bf2f..f3e64719d2 100644 --- a/apps/code/src/main/services/local-logs/service.ts +++ b/apps/code/src/main/services/local-logs/service.ts @@ -25,6 +25,147 @@ export class LocalLogsService { { state: WriteState; inFlight: Promise } >(); + async truncateLocalLogs(taskRunId: string, lineCount: number): Promise { + const logPath = this.getLocalLogPath(taskRunId); + let content: string; + try { + content = await fs.promises.readFile(logPath, "utf-8"); + } catch (error) { + if ((error as NodeJS.ErrnoException).code !== "ENOENT") { + log.warn("Failed to read local logs for truncation:", error); + } + return; + } + const lines = content.split("\n").filter((l) => l.trim()); + if (lines.length <= lineCount) return; + const truncated = `${lines.slice(0, lineCount).join("\n")}\n`; + const tmpPath = `${logPath}.tmp.${Date.now()}`; + try { + await fs.promises.writeFile(tmpPath, truncated, "utf-8"); + await fs.promises.rename(tmpPath, logPath); + } catch (error) { + log.warn("Failed to write truncated local logs:", error); + await fs.promises.unlink(tmpPath).catch(() => {}); + } + } + + /** + * After truncateLocalLogs, any in-flight drain may overwrite the truncated + * file with stale accumulated content. Clearing pending prevents an extra + * write after the current doWrite completes. The renderer's first writeLocalLogs + * after reconnect will restore the correct truncated content anyway. + */ + cancelPendingWrite(taskRunId: string): void { + const entry = this.writes.get(taskRunId); + if (entry) { + entry.state.pending = undefined; + } + } + + /** + * Trims local logs.ndjson to the JSON-RPC response for the given prompt id + * (the turn-completion boundary). Returns false when the trim could not be + * guaranteed — a real read/write error, or the boundary line was not found + * while there was content to search — so the caller can warn the user that + * the restored view may still contain post-checkpoint turns. A missing log + * file (nothing to trim) returns true. + */ + async truncateLocalLogsAtPromptBoundary( + taskRunId: string, + promptId: number, + preserveTrailingEntries: string[] = [], + ): Promise { + const logPath = this.getLocalLogPath(taskRunId); + let content: string; + try { + content = await fs.promises.readFile(logPath, "utf-8"); + } catch (error) { + if ((error as NodeJS.ErrnoException).code !== "ENOENT") { + log.warn( + "Failed to read local logs for prompt-boundary truncation:", + error, + ); + return false; + } + // No log file yet — nothing to truncate. + return true; + } + + const lines = content.split("\n").filter((l) => l.trim()); + let boundaryIdx = -1; + + for (let i = 0; i < lines.length; i++) { + try { + const parsed = JSON.parse(lines[i]) as { + type?: string; + notification?: { id?: number; result?: unknown }; + }; + const notif = parsed.notification; + if ( + parsed.type === "notification" && + notif != null && + typeof notif.id === "number" && + notif.id === promptId && + "result" in notif + ) { + boundaryIdx = i; + } + } catch { + // skip unparseable lines + } + } + + if (boundaryIdx === -1) { + log.warn("truncateLocalLogsAtPromptBoundary: prompt boundary not found", { + taskRunId, + promptId, + }); + return false; + } + + const keptLines = lines.slice(0, boundaryIdx + 1); + // Re-add entries that belong to the kept turns but were appended after the + // boundary — e.g. the restored turn's git_checkpoint notification, captured + // on TURN_COMPLETE after the prompt response, which the trim would otherwise + // drop (leaving the restored turn with a disabled restore icon). Skip dups. + const preserved = preserveTrailingEntries + .map((e) => e.trim()) + .filter((e) => e.length > 0 && !keptLines.includes(e)); + + // Nothing after the boundary and nothing to re-add → already trimmed. + if (boundaryIdx + 1 >= lines.length && preserved.length === 0) { + return true; + } + + const truncated = `${[...keptLines, ...preserved].join("\n")}\n`; + const tmpPath = `${logPath}.tmp.${Date.now()}`; + try { + await fs.promises.writeFile(tmpPath, truncated, "utf-8"); + await fs.promises.rename(tmpPath, logPath); + } catch (error) { + log.warn("Failed to write prompt-boundary truncated local logs:", error); + await fs.promises.unlink(tmpPath).catch(() => {}); + return false; + } + + log.info( + "truncateLocalLogsAtPromptBoundary: truncated to prompt boundary", + { + taskRunId, + promptId, + keptLines: keptLines.length, + preserved: preserved.length, + originalLines: lines.length, + }, + ); + + // Queue the truncated content as the next write so any in-flight drain + // that overwrites the file with pre-truncate content is followed by a + // corrective write of the properly-trimmed content. + this.writeLocalLogs(taskRunId, truncated); + return true; + } + async readLocalLogs(taskRunId: string): Promise { const logPath = this.getLocalLogPath(taskRunId); try { diff --git a/apps/code/src/main/trpc/router.ts b/apps/code/src/main/trpc/router.ts index f0f8dd9eb5..08cf091103 100644 --- a/apps/code/src/main/trpc/router.ts +++ b/apps/code/src/main/trpc/router.ts @@ -1,6 +1,7 @@ import { additionalDirectoriesRouter } from "./routers/additional-directories"; import { agentRouter } from "./routers/agent"; import { analyticsRouter } from "./routers/analytics"; +import { checkpointRouter } from "./routers/checkpoint"; import { archiveRouter } from "./routers/archive"; import { authRouter } from "./routers/auth"; import { cloudTaskRouter } from "./routers/cloud-task"; @@ -46,6 +47,7 @@ export const trpcRouter = router({ analytics: analyticsRouter, archive: archiveRouter, auth: authRouter, + checkpoint: checkpointRouter, cloudTask: cloudTaskRouter, connectivity: connectivityRouter, contextMenu: contextMenuRouter, diff --git a/apps/code/src/main/trpc/routers/checkpoint.test.ts b/apps/code/src/main/trpc/routers/checkpoint.test.ts new file mode 100644 index 0000000000..f504ee1d8a --- /dev/null +++ b/apps/code/src/main/trpc/routers/checkpoint.test.ts @@ -0,0 +1,121 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; + +// The restore mutation guards against overlapping restores for the same +// session with an in-flight lock. A restore with no taskRunId skips the whole +// truncation block, so these tests only need to control the revert saga to +// exercise the lock's acquire/release behaviour. +const sagaRunMock = vi.hoisted(() => vi.fn()); + +vi.mock("@posthog/git/sagas/checkpoint", () => ({ + RevertCheckpointSaga: class { + run = sagaRunMock; + }, + deleteCheckpoint: vi.fn(), +})); + +vi.mock("../../di/container", () => ({ + container: { get: vi.fn() }, +})); + +vi.mock("../../utils/logger", () => ({ + logger: { + scope: () => ({ + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + debug: vi.fn(), + }), + }, +})); + +import { checkpointRouter } from "./checkpoint"; + +function deferred(): { + promise: Promise; + resolve: (v: T) => void; +} { + let resolve!: (v: T) => void; + const promise = new Promise((res) => { + resolve = res; + }); + return { promise, resolve }; +} + +async function flushMicrotasks(): Promise { + for (let i = 0; i < 10; i++) await Promise.resolve(); +} + +describe("checkpointRouter.restore concurrency lock", () => { + const caller = checkpointRouter.createCaller({}); + + beforeEach(() => { + sagaRunMock.mockReset(); + }); + + it("rejects a second restore for the same repo while one is in flight", async () => { + // First revert hangs so the first restore stays in flight and holds the lock. + const gate = deferred<{ success: boolean }>(); + sagaRunMock.mockReturnValueOnce(gate.promise); + + const first = caller.restore({ checkpointId: "cp1", repoPath: "/repo" }); + await flushMicrotasks(); + + await expect( + caller.restore({ checkpointId: "cp2", repoPath: "/repo" }), + ).rejects.toThrow(/already in progress/i); + + gate.resolve({ success: true }); + await expect(first).resolves.toEqual({ + restoredSessionId: undefined, + truncationFailed: false, + }); + }); + + it("allows a different repo to restore concurrently", async () => { + const gate = deferred<{ success: boolean }>(); + sagaRunMock.mockReturnValueOnce(gate.promise); + sagaRunMock.mockResolvedValueOnce({ success: true }); + + const first = caller.restore({ checkpointId: "cp1", repoPath: "/repo-a" }); + await flushMicrotasks(); + + // Different repoPath → different lock key → not blocked. + await expect( + caller.restore({ checkpointId: "cp2", repoPath: "/repo-b" }), + ).resolves.toEqual({ + restoredSessionId: undefined, + truncationFailed: false, + }); + + gate.resolve({ success: true }); + await first; + }); + + it("releases the lock after a successful restore", async () => { + sagaRunMock.mockResolvedValue({ success: true }); + + await caller.restore({ checkpointId: "cp1", repoPath: "/repo" }); + await expect( + caller.restore({ checkpointId: "cp2", repoPath: "/repo" }), + ).resolves.toEqual({ + restoredSessionId: undefined, + truncationFailed: false, + }); + }); + + it("releases the lock even when the revert fails", async () => { + sagaRunMock.mockResolvedValueOnce({ success: false, error: "revert boom" }); + await expect( + caller.restore({ checkpointId: "cp1", repoPath: "/repo" }), + ).rejects.toThrow(/revert boom/); + + // Lock released by the finally block → a later restore proceeds. + sagaRunMock.mockResolvedValueOnce({ success: true }); + await expect( + caller.restore({ checkpointId: "cp2", repoPath: "/repo" }), + ).resolves.toEqual({ + restoredSessionId: undefined, + truncationFailed: false, + }); + }); +}); diff --git a/apps/code/src/main/trpc/routers/checkpoint.ts b/apps/code/src/main/trpc/routers/checkpoint.ts new file mode 100644 index 0000000000..70ad252218 --- /dev/null +++ b/apps/code/src/main/trpc/routers/checkpoint.ts @@ -0,0 +1,341 @@ +import { POSTHOG_NOTIFICATIONS } from "@posthog/agent"; +import { truncateCodexRollout } from "@posthog/agent/adapters/codex/rollout"; +import { createGitClient } from "@posthog/git/client"; +import { + deleteCheckpoint, + RevertCheckpointSaga, +} from "@posthog/git/sagas/checkpoint"; +import { z } from "zod"; +import { container } from "../../di/container"; +import { MAIN_TOKENS } from "../../di/tokens"; +import type { AgentService } from "../../services/agent/service"; +import type { AuthService } from "../../services/auth/service"; +import type { LocalLogsService } from "../../services/local-logs/service"; +import { logger } from "../../utils/logger"; +import { publicProcedure, router } from "../trpc"; + +const log = logger.scope("checkpoint-router"); + +// Guards against concurrent restores for the same session. Two restores racing +// would truncate logs.ndjson + the rollout at different offsets and corrupt +// both. Keyed by taskRunId (falling back to repoPath when there is no run). +const restoreInFlight = new Set(); + +const getAgentService = () => + container.get(MAIN_TOKENS.AgentService); + +const getAuthService = () => + container.get(MAIN_TOKENS.AuthService); + +const getLocalLogsService = () => + container.get(MAIN_TOKENS.LocalLogsService); + +const restoreInput = z.object({ + checkpointId: z.string(), + repoPath: z.string(), + taskRunId: z.string().optional(), +}); + +export const checkpointRouter = router({ + /** + * Re-emit stored checkpoint notifications for a session through the existing + * SessionEvent channel so the renderer receives them after a reconnect. + * Called after subscribeToChannel so events are not lost. + */ + replayCheckpoints: publicProcedure + .input(z.object({ taskRunId: z.string() })) + .mutation(({ input }) => { + const agentService = getAgentService(); + const count = agentService.replayCheckpoints(input.taskRunId); + log.info("replayCheckpoints mutation", { + taskRunId: input.taskRunId, + count, + }); + return { count }; + }), + + restore: publicProcedure.input(restoreInput).mutation(async ({ input }) => { + // Reject overlapping restores for the same session (see restoreInFlight). + const lockKey = input.taskRunId ?? input.repoPath; + if (restoreInFlight.has(lockKey)) { + throw new Error( + "A checkpoint restore is already in progress for this session. Please wait for it to finish.", + ); + } + restoreInFlight.add(lockKey); + try { + return await runRestore(input); + } finally { + restoreInFlight.delete(lockKey); + } + }), +}); + +/** + * Performs the actual checkpoint restore. Extracted so the mutation can wrap it + * in the restoreInFlight lock with a try/finally. + * + * Returns `truncationFailed: true` when any log/rollout truncation step errored. + * The git revert still succeeded in that case, but the agent may keep memory + * past the checkpoint, so the renderer surfaces a warning to the user. + */ +async function runRestore(input: { + checkpointId: string; + repoPath: string; + taskRunId?: string; +}): Promise<{ restoredSessionId?: string; truncationFailed: boolean }> { + // 1. Revert git files to checkpoint state + const saga = new RevertCheckpointSaga(); + const result = await saga.run({ + baseDir: input.repoPath, + checkpointId: input.checkpointId, + }); + if (!result.success) { + throw new Error(result.error ?? "Failed to revert checkpoint"); + } + + // 2. Truncate logs, clean up orphaned refs, and restart the agent. + // Everything here is non-fatal: git files were already reverted. + let restoredSessionId: string | undefined; + // Tracks whether any truncation step failed, so the renderer can warn that + // the restore was partial (agent memory may extend past the checkpoint). + let truncationFailed = false; + if (input.taskRunId) { + try { + const agentService = getAgentService(); + const info = agentService.getSessionInfo(input.taskRunId); + if (info) { + restoredSessionId = info.sessionId; + let orphanedCheckpointIds: string[] = []; + + // Truncate S3 + local cache BEFORE cancelling the session. + // cancelSession triggers reconnect; if reconnect reads local cache + // before truncation, the stale full history would be loaded. + try { + const authService = getAuthService(); + const url = `${info.apiHost}/api/projects/${info.projectId}/tasks/${info.taskId}/runs/${input.taskRunId}/truncate_log/`; + const promptId = agentService.getCheckpointPromptId( + input.taskRunId, + input.checkpointId, + ); + const response = await authService.authenticatedFetch(fetch, url, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + checkpoint_id: input.checkpointId, + prompt_id: promptId, + }), + }); + if (response.ok) { + const s3Result = (await response.json()) as { + truncated: boolean; + original_line_count: number; + truncated_line_count: number; + orphaned_checkpoint_ids: string[]; + }; + log.info("S3 log truncated after checkpoint restore", { + taskRunId: input.taskRunId, + checkpointId: input.checkpointId, + promptId, + truncated: s3Result.truncated, + originalLines: s3Result.original_line_count, + truncatedLines: s3Result.truncated_line_count, + orphanedCheckpoints: s3Result.orphaned_checkpoint_ids, + }); + const localLogsSvc = getLocalLogsService(); + if (s3Result.truncated) { + orphanedCheckpointIds = s3Result.orphaned_checkpoint_ids ?? []; + // Coarse trim: align local cache with the S3 line count. + await localLogsSvc + .truncateLocalLogs( + input.taskRunId, + s3Result.truncated_line_count, + ) + .catch((err: unknown) => { + truncationFailed = true; + log.warn("Failed to truncate local log cache (non-fatal)", { + taskRunId: input.taskRunId, + error: err instanceof Error ? err.message : String(err), + }); + }); + } else { + log.warn( + "S3 log was not truncated — will rely on client-side prompt-boundary trim", + { + taskRunId: input.taskRunId, + checkpointId: input.checkpointId, + promptId, + }, + ); + } + // The restored turn's own git_checkpoint notification sits after its + // prompt response, so BOTH the S3 truncate and the local prompt- + // boundary trim drop it — leaving the restored turn with a disabled + // restore icon after a restart. Re-add it to both stores so the + // restored turn stays restorable. + const restoredCheckpointEntry = + promptId != null + ? { + type: "notification" as const, + timestamp: new Date().toISOString(), + notification: { + jsonrpc: "2.0" as const, + method: POSTHOG_NOTIFICATIONS.GIT_CHECKPOINT, + params: { + checkpointId: input.checkpointId, + promptId, + }, + }, + } + : undefined; + + // Re-append to S3 only when the truncate actually removed it (else + // the checkpoint is still there and re-appending would duplicate). + if (s3Result.truncated && restoredCheckpointEntry) { + const appendUrl = `${info.apiHost}/api/projects/${info.projectId}/tasks/${info.taskId}/runs/${input.taskRunId}/append_log/`; + await authService + .authenticatedFetch(fetch, appendUrl, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ entries: [restoredCheckpointEntry] }), + }) + .then((res) => { + if (!res.ok) { + log.warn("Failed to re-append restored checkpoint to S3", { + taskRunId: input.taskRunId, + status: res.status, + }); + } + }) + .catch((err: unknown) => { + log.warn("Failed to re-append restored checkpoint to S3", { + taskRunId: input.taskRunId, + error: err instanceof Error ? err.message : String(err), + }); + }); + } + + // Fine-trim the local cache to the exact prompt boundary regardless + // of the S3 result (the primary defense for Codex sessions, whose + // events are all type:"notification"). Cancel any in-flight drain + // first so our write wins, and re-add the restored checkpoint so the + // restored turn keeps its icon after a restart. + localLogsSvc.cancelPendingWrite(input.taskRunId); + const boundaryTrimmed = await localLogsSvc + .truncateLocalLogsAtPromptBoundary( + input.taskRunId, + promptId ?? -1, + restoredCheckpointEntry + ? [JSON.stringify(restoredCheckpointEntry)] + : [], + ) + .catch((err: unknown) => { + log.warn( + "Failed to truncate local logs at prompt boundary (non-fatal)", + { + taskRunId: input.taskRunId, + error: err instanceof Error ? err.message : String(err), + }, + ); + return false; + }); + // Flag a missing boundary only when we had a real prompt id to + // anchor to — a checkpoint without one (-1) has nothing to find. + if (!boundaryTrimmed && promptId != null) { + truncationFailed = true; + log.warn( + "Prompt-boundary trim did not complete; restored view may keep post-checkpoint turns", + { taskRunId: input.taskRunId, promptId }, + ); + } + } else { + log.warn("S3 log truncation returned non-ok status", { + taskRunId: input.taskRunId, + status: response.status, + }); + } + } catch (err) { + truncationFailed = true; + log.warn("Failed to truncate S3 log (non-fatal)", { + taskRunId: input.taskRunId, + error: err instanceof Error ? err.message : String(err), + }); + } + + // Clean up git refs for orphaned checkpoints + if (orphanedCheckpointIds.length > 0) { + const git = createGitClient(input.repoPath); + await Promise.all( + orphanedCheckpointIds.map((id) => + deleteCheckpoint(git, id).catch(() => {}), + ), + ); + log.info("Deleted orphaned checkpoint refs", { + orphanedCheckpointIds, + }); + } + + // Trim in-memory checkpoints so replayCheckpoints only re-emits + // survivors — must happen before cancelSession triggers reconnect. + // Returns the number of surviving turns (used as keepTurns for Codex). + const survivingTurns = agentService.truncateCheckpoints( + input.taskRunId, + input.checkpointId, + ); + + // Mark this taskRunId so hydrateSessionJsonl force-refetches from + // the truncated S3 on reconnect, bypassing any stale existing JSONL. + // Must be set before cancelSession triggers the reconnect. + agentService.markCheckpointRestore(input.taskRunId); + + // Cancel the session — renderer reconnects, hydrateSessionJsonl + // fetches the truncated S3 log and overwrites the stale JSONL. + await agentService.cancelSession(input.taskRunId); + log.info("Agent session cancelled for checkpoint restore", { + taskRunId: input.taskRunId, + checkpointId: input.checkpointId, + }); + + // For Codex: truncate the on-disk rollout AFTER the subprocess is + // killed (file is now closed) so the resumed session has memory only + // up to the checkpoint. Must happen after cancelSession. + // survivingTurns=0 means the checkpoint wasn't found — skip truncation. + log.info("Checkpoint restore rollout decision", { + taskRunId: input.taskRunId, + adapter: info.adapter, + sessionId: info.sessionId, + survivingTurns, + }); + if (info.adapter === "codex" && survivingTurns > 0) { + await truncateCodexRollout(info.sessionId, survivingTurns, log).catch( + (err: unknown) => { + truncationFailed = true; + log.warn("Failed to truncate codex rollout (non-fatal)", { + taskRunId: input.taskRunId, + sessionId: info.sessionId, + error: err instanceof Error ? err.message : String(err), + }); + }, + ); + } + } else { + log.warn("No active session found for checkpoint restore", { + taskRunId: input.taskRunId, + }); + } + } catch (err) { + truncationFailed = true; + log.warn("Failed to truncate agent session", { + taskRunId: input.taskRunId, + error: err instanceof Error ? err.message : String(err), + }); + } + } + + log.info("Checkpoint restore complete", { + taskRunId: input.taskRunId, + restoredSessionId, + truncationFailed, + }); + return { restoredSessionId, truncationFailed }; +} diff --git a/apps/code/src/renderer/api/posthogClient.ts b/apps/code/src/renderer/api/posthogClient.ts index e47c4e63e2..a27faf8775 100644 --- a/apps/code/src/renderer/api/posthogClient.ts +++ b/apps/code/src/renderer/api/posthogClient.ts @@ -1410,6 +1410,34 @@ export class PostHogAPIClient { } } + async truncateTaskRunLog( + taskId: string, + runId: string, + checkpointId: string, + ): Promise<{ + truncated: boolean; + original_line_count: number; + truncated_line_count: number; + orphaned_checkpoint_ids: string[]; + }> { + const teamId = await this.getTeamId(); + const url = `${this.api.baseUrl}/api/projects/${teamId}/tasks/${taskId}/runs/${runId}/truncate_log/`; + const response = await this.api.fetcher.fetch({ + method: "post", + url: new URL(url), + path: url, + overrides: { + body: JSON.stringify({ checkpoint_id: checkpointId }), + }, + }); + if (!response.ok) { + throw new Error( + `Failed to truncate task run log: ${response.statusText}`, + ); + } + return response.json(); + } + async getTaskRunSessionLogs( taskId: string, runId: string, diff --git a/apps/code/src/renderer/constants/keyboard-shortcuts.ts b/apps/code/src/renderer/constants/keyboard-shortcuts.ts index b162013bbc..8e2b9c650d 100644 --- a/apps/code/src/renderer/constants/keyboard-shortcuts.ts +++ b/apps/code/src/renderer/constants/keyboard-shortcuts.ts @@ -22,6 +22,7 @@ export const SHORTCUTS = { SPACE_UP: "mod+up", SPACE_DOWN: "mod+down", FIND_IN_CONVERSATION: "mod+f", + FILE_PICKER: "mod+p", BLUR: "escape", SUBMIT_BLUR: "mod+enter", } as const; @@ -35,6 +36,8 @@ export interface KeyboardShortcut { category: ShortcutCategory; context?: string; alternateKeys?: string; + /** Whether this shortcut's keybinding can be customized by the user. */ + configurable?: boolean; } export const KEYBOARD_SHORTCUTS: KeyboardShortcut[] = [ @@ -160,6 +163,14 @@ export const KEYBOARD_SHORTCUTS: KeyboardShortcut[] = [ category: "panels", context: "Task detail", }, + { + id: "file-picker", + keys: SHORTCUTS.FILE_PICKER, + description: "Open file picker", + category: "panels", + context: "Task detail", + configurable: true, + }, { id: "paste-as-file", keys: SHORTCUTS.PASTE_AS_FILE, @@ -218,6 +229,56 @@ export const CATEGORY_LABELS: Record = { editor: "Editor", }; +export const CONFIGURABLE_SHORTCUT_IDS = [ + "command-menu", + "new-task", + "settings", + "shortcuts", + "inbox", + "prev-task", + "next-task", + "space-up", + "space-down", + "go-back", + "go-forward", + "toggle-left-sidebar", + "toggle-review-panel", + "close-tab", + "open-in-editor", + "copy-path", + "toggle-focus", + "file-picker", + "paste-as-file", + "prompt-history-prev", + "prompt-history-next", +] as const; + +export type ConfigurableShortcutId = (typeof CONFIGURABLE_SHORTCUT_IDS)[number]; + +export const DEFAULT_KEYBINDINGS: Record = { + "command-menu": SHORTCUTS.COMMAND_MENU, + "new-task": SHORTCUTS.NEW_TASK, + settings: SHORTCUTS.SETTINGS, + shortcuts: SHORTCUTS.SHORTCUTS_SHEET, + inbox: SHORTCUTS.INBOX, + "prev-task": SHORTCUTS.PREV_TASK, + "next-task": SHORTCUTS.NEXT_TASK, + "space-up": SHORTCUTS.SPACE_UP, + "space-down": SHORTCUTS.SPACE_DOWN, + "go-back": SHORTCUTS.GO_BACK, + "go-forward": SHORTCUTS.GO_FORWARD, + "toggle-left-sidebar": SHORTCUTS.TOGGLE_LEFT_SIDEBAR, + "toggle-review-panel": SHORTCUTS.TOGGLE_REVIEW_PANEL, + "close-tab": SHORTCUTS.CLOSE_TAB, + "open-in-editor": SHORTCUTS.OPEN_IN_EDITOR, + "copy-path": SHORTCUTS.COPY_PATH, + "toggle-focus": SHORTCUTS.TOGGLE_FOCUS, + "file-picker": SHORTCUTS.FILE_PICKER, + "paste-as-file": SHORTCUTS.PASTE_AS_FILE, + "prompt-history-prev": "shift+up", + "prompt-history-next": "shift+down", +}; + export function getShortcutsByCategory(): Record< ShortcutCategory, KeyboardShortcut[] diff --git a/apps/code/src/renderer/features/sessions/components/ConversationView.tsx b/apps/code/src/renderer/features/sessions/components/ConversationView.tsx index 4afb50fd67..9a9d950ee9 100644 --- a/apps/code/src/renderer/features/sessions/components/ConversationView.tsx +++ b/apps/code/src/renderer/features/sessions/components/ConversationView.tsx @@ -1,6 +1,7 @@ import { CHAT_CONTENT_MAX_WIDTH } from "@features/sessions/constants"; import { useContextUsage } from "@features/sessions/hooks/useContextUsage"; import { useConversationSearch } from "@features/sessions/hooks/useConversationSearch"; +import { useRestoreCheckpoint } from "@features/sessions/hooks/useRestoreCheckpoint"; import { SessionTaskIdProvider } from "@features/sessions/hooks/useSessionTaskId"; import { sessionStoreSetters, @@ -27,6 +28,7 @@ import { ConversationSearchBar } from "./ConversationSearchBar"; import { GitActionMessage } from "./GitActionMessage"; import { GitActionResult } from "./GitActionResult"; import { mergeConversationItems } from "./mergeConversationItems"; +import { RestoreCheckpointDialog } from "./RestoreCheckpointDialog"; import { SessionFooter } from "./SessionFooter"; import { QueuedMessageView } from "./session-update/QueuedMessageView"; import { @@ -126,6 +128,12 @@ export function ConversationView({ const isCloud = session?.isCloud ?? false; + const restore = useRestoreCheckpoint({ + repoPath: repoPath ?? undefined, + taskId, + taskRunId: session?.taskRunId, + }); + const items = useMemo( () => mergeConversationItems({ @@ -197,6 +205,19 @@ export function ConversationView({ ? slackThreadUrl : undefined } + onRestoreCheckpoint={ + item.turnContext?.lastCheckpointId && !isCloud + ? () => + restore.requestRestore( + item.turnContext?.lastCheckpointId as string, + ) + : undefined + } + restoreDisabledReason={ + isCloud + ? "Checkpoint restore isn't available for cloud tasks" + : "No checkpoint was captured for this turn" + } /> ); case "git_action": @@ -240,7 +261,15 @@ export function ConversationView({ ); } }, - [repoPath, taskId, slackThreadUrl, firstUserMessageId, initialItemIds], + [ + repoPath, + taskId, + slackThreadUrl, + firstUserMessageId, + initialItemIds, + restore.requestRestore, + isCloud, + ], ); const getItemKey = useCallback((item: ConversationItem) => item.id, []); @@ -310,6 +339,13 @@ export function ConversationView({ )} + ); } diff --git a/apps/code/src/renderer/features/sessions/components/RestoreCheckpointDialog.tsx b/apps/code/src/renderer/features/sessions/components/RestoreCheckpointDialog.tsx new file mode 100644 index 0000000000..eb301c202d --- /dev/null +++ b/apps/code/src/renderer/features/sessions/components/RestoreCheckpointDialog.tsx @@ -0,0 +1,52 @@ +import { Button, Dialog, Flex, Text } from "@radix-ui/themes"; + +interface RestoreCheckpointDialogProps { + open: boolean; + onOpenChange: (open: boolean) => void; + onConfirm: () => void; + isLoading: boolean; + /** True when the agent is mid-response; restoring will stop it. */ + isTurnInProgress?: boolean; +} + +export function RestoreCheckpointDialog({ + open, + onOpenChange, + onConfirm, + isLoading, + isTurnInProgress = false, +}: RestoreCheckpointDialogProps) { + return ( + + + + + Restore checkpoint + + + This will revert all file changes made after this point. This action + cannot be undone. + {isTurnInProgress + ? " The agent is still responding — restoring will stop the current response." + : ""} + + + + + + + + + + + ); +} diff --git a/apps/code/src/renderer/features/sessions/components/buildConversationItems.test.ts b/apps/code/src/renderer/features/sessions/components/buildConversationItems.test.ts index 0bdc3d3d88..451bc30a81 100644 --- a/apps/code/src/renderer/features/sessions/components/buildConversationItems.test.ts +++ b/apps/code/src/renderer/features/sessions/components/buildConversationItems.test.ts @@ -1,3 +1,4 @@ +import { POSTHOG_NOTIFICATIONS } from "@posthog/agent"; import type { AcpMessage } from "@shared/types/session-events"; import { makeAttachmentUri } from "@utils/promptContent"; import { describe, expect, it } from "vitest"; @@ -74,6 +75,22 @@ function turnCompleteMsg(ts: number, stopReason = "end_turn"): AcpMessage { }; } +function gitCheckpointMsg( + ts: number, + checkpointId: string, + promptId: number, +): AcpMessage { + return { + type: "acp_message", + ts, + message: { + jsonrpc: "2.0", + method: POSTHOG_NOTIFICATIONS.GIT_CHECKPOINT, + params: { checkpointId, promptId }, + }, + }; +} + describe("buildConversationItems", () => { it("extracts cloud prompt attachments into user messages", () => { const uri = makeAttachmentUri("/tmp/hello world.txt"); @@ -117,10 +134,53 @@ describe("buildConversationItems", () => { label: "hello world.txt", }, ], + turnContext: expect.anything(), }, ]); }); + // Restore-icon gating: a turn is restorable iff its user_message carries a + // lastCheckpointId, set from the GIT_CHECKPOINT notification. This must work + // when the checkpoint is loaded from logs.ndjson on a cold start (the only + // place it lives after the in-memory map is gone on app restart). + it("associates a git_checkpoint notification with its turn via promptId", () => { + const result = buildConversationItems( + [ + userPromptMsg(1, 1, "first"), + promptResponseMsg(2, 1), + turnCompleteMsg(3), + gitCheckpointMsg(4, "cp-abc", 1), + ], + null, + ); + + const userMsg = result.items.find((i) => i.type === "user_message"); + expect(userMsg?.type).toBe("user_message"); + expect( + userMsg?.type === "user_message" + ? userMsg.turnContext?.lastCheckpointId + : undefined, + ).toBe("cp-abc"); + }); + + it("leaves lastCheckpointId unset for a turn with no checkpoint", () => { + const result = buildConversationItems( + [ + userPromptMsg(1, 1, "first"), + promptResponseMsg(2, 1), + turnCompleteMsg(3), + ], + null, + ); + + const userMsg = result.items.find((i) => i.type === "user_message"); + expect( + userMsg?.type === "user_message" + ? (userMsg.turnContext?.lastCheckpointId ?? null) + : "missing", + ).toBeNull(); + }); + it("marks cloud turns complete from structured turn completion notifications", () => { const result = buildConversationItems( [userPromptMsg(10, 42, "hello"), turnCompleteMsg(25)], @@ -175,6 +235,7 @@ describe("buildConversationItems", () => { label: "test.txt", }, ], + turnContext: expect.anything(), }, ]); }); @@ -218,6 +279,7 @@ describe("buildConversationItems", () => { label: "Receipt-2264-0277.pdf", }, ], + turnContext: expect.anything(), }, ]); }); diff --git a/apps/code/src/renderer/features/sessions/components/buildConversationItems.ts b/apps/code/src/renderer/features/sessions/components/buildConversationItems.ts index fbd0d1ee4b..c75a862a81 100644 --- a/apps/code/src/renderer/features/sessions/components/buildConversationItems.ts +++ b/apps/code/src/renderer/features/sessions/components/buildConversationItems.ts @@ -28,6 +28,7 @@ export interface TurnContext { childItems: Map; turnCancelled: boolean; turnComplete: boolean; + lastCheckpointId: string | null; } export type ConversationItem = @@ -38,6 +39,7 @@ export type ConversationItem = timestamp: number; attachments?: UserMessageAttachment[]; pinToTop?: boolean; + turnContext?: TurnContext; } | { type: "git_action"; id: string; actionType: GitActionType } | { type: "skill_button_action"; id: string; buttonId: SkillButtonId } @@ -92,12 +94,22 @@ interface TurnState { context: TurnContext; gitAction: ReturnType; itemCount: number; + lastCheckpointId: string | null; } interface ItemBuilder { items: ConversationItem[]; currentTurn: TurnState | null; pendingPrompts: Map; + /** All turns ever created, keyed by promptId. Used to associate GIT_CHECKPOINT + * notifications with the correct turn regardless of stream order. */ + allTurns: Map; + /** GIT_CHECKPOINT notifications deferred until after all turns are built, + * so allTurns.get(promptId) is guaranteed to find the right entry. */ + pendingCheckpoints: Array<{ + checkpointId: string; + promptId: number | undefined; + }>; shellExecutes: Map; isCompacting: boolean; nextId: () => number; @@ -114,6 +126,8 @@ function createItemBuilder(): ItemBuilder { items: [], currentTurn: null, pendingPrompts: new Map(), + allTurns: new Map(), + pendingCheckpoints: [], shellExecutes: new Map(), isCompacting: false, nextId: () => idCounter++, @@ -206,6 +220,18 @@ export function buildConversationItems( b.currentTurn.context.turnComplete = true; } + // Post-pass: assign deferred checkpoints now that allTurns is fully populated. + for (const { checkpointId, promptId } of b.pendingCheckpoints) { + const targetTurn = + promptId !== undefined + ? (b.allTurns.get(promptId) ?? b.currentTurn) + : b.currentTurn; + if (targetTurn) { + targetTurn.lastCheckpointId = checkpointId; + targetTurn.context.lastCheckpointId = checkpointId; + } + } + markThoughtCompletion(b.items); const lastTurnInfo: LastTurnInfo | null = b.currentTurn @@ -248,6 +274,7 @@ function handlePromptRequest( childItems, turnCancelled: false, turnComplete: false, + lastCheckpointId: null, }; b.currentTurn = { @@ -259,9 +286,11 @@ function handlePromptRequest( context, gitAction, itemCount: 0, + lastCheckpointId: null, }; b.pendingPrompts.set(msg.id, b.currentTurn); + b.allTurns.set(msg.id, b.currentTurn); if (gitAction.isGitAction && gitAction.actionType) { b.items.push({ @@ -282,6 +311,7 @@ function handlePromptRequest( content: userContent, timestamp: ts, attachments: userPrompt.attachments, + turnContext: context, }); } } @@ -439,6 +469,19 @@ function handleNotification( }); return; } + + if (isNotification(msg.method, POSTHOG_NOTIFICATIONS.GIT_CHECKPOINT)) { + const params = msg.params as { checkpointId?: string; promptId?: number }; + if (!params?.checkpointId) return; + // Defer until after all turns are built so allTurns.get(promptId) always + // finds the right entry, even if this notification arrives before the + // session/prompt events (e.g. after a reconnect replay). + b.pendingCheckpoints.push({ + checkpointId: params.checkpointId, + promptId: params.promptId, + }); + return; + } } function ensureProgressCardForGroup( @@ -533,6 +576,7 @@ function ensureImplicitTurn(b: ItemBuilder, ts: number) { childItems, turnCancelled: false, turnComplete: false, + lastCheckpointId: null, }; b.currentTurn = { @@ -544,6 +588,7 @@ function ensureImplicitTurn(b: ItemBuilder, ts: number) { context, gitAction: { isGitAction: false, actionType: null, prompt: "" }, itemCount: 0, + lastCheckpointId: null, }; } diff --git a/apps/code/src/renderer/features/sessions/components/session-update/AgentMessage.tsx b/apps/code/src/renderer/features/sessions/components/session-update/AgentMessage.tsx index 0cf7d00083..33e8fb81b4 100644 --- a/apps/code/src/renderer/features/sessions/components/session-update/AgentMessage.tsx +++ b/apps/code/src/renderer/features/sessions/components/session-update/AgentMessage.tsx @@ -8,7 +8,7 @@ import { useCwd } from "@features/sidebar/hooks/useCwd"; import type { FileItem } from "@hooks/useRepoFiles"; import { useRepoFiles } from "@hooks/useRepoFiles"; import { Check, Copy } from "@phosphor-icons/react"; -import { Box, Code, IconButton } from "@radix-ui/themes"; +import { Box, Code, Flex, IconButton } from "@radix-ui/themes"; import { memo, useCallback, useMemo, useState } from "react"; import type { Components } from "react-markdown"; @@ -154,7 +154,10 @@ export const AgentMessage = memo(function AgentMessage({ content={content} componentsOverride={agentComponents} /> - + : } - + ); }); diff --git a/apps/code/src/renderer/features/sessions/components/session-update/ToolCallBlock.tsx b/apps/code/src/renderer/features/sessions/components/session-update/ToolCallBlock.tsx index 5ebe91129c..a5bf4a6766 100644 --- a/apps/code/src/renderer/features/sessions/components/session-update/ToolCallBlock.tsx +++ b/apps/code/src/renderer/features/sessions/components/session-update/ToolCallBlock.tsx @@ -52,6 +52,7 @@ export function ToolCallBlock({ childItems: childItemsMap ?? new Map(), turnCancelled: turnCancelled ?? false, turnComplete: turnComplete ?? false, + lastCheckpointId: null, }; return ( diff --git a/apps/code/src/renderer/features/sessions/components/session-update/UserMessage.tsx b/apps/code/src/renderer/features/sessions/components/session-update/UserMessage.tsx index aeb82a09b1..635f640d6a 100644 --- a/apps/code/src/renderer/features/sessions/components/session-update/UserMessage.tsx +++ b/apps/code/src/renderer/features/sessions/components/session-update/UserMessage.tsx @@ -1,6 +1,7 @@ import { Tooltip } from "@components/ui/Tooltip"; import { MarkdownRenderer } from "@features/editor/components/MarkdownRenderer"; import { + ArrowCounterClockwise, CaretDown, CaretUp, Check, @@ -30,6 +31,13 @@ interface UserMessageProps { sourceUrl?: string; attachments?: UserMessageAttachment[]; animate?: boolean; + onRestoreCheckpoint?: () => void; + /** + * Tooltip shown when restore is unavailable for this turn. When + * onRestoreCheckpoint is absent the button still renders but is disabled, + * with this text explaining why (no checkpoint, or cloud task). + */ + restoreDisabledReason?: string; } function formatTimestamp(ts: number): string { @@ -49,6 +57,8 @@ export function UserMessage({ sourceUrl, attachments = [], animate = true, + onRestoreCheckpoint, + restoreDisabledReason, }: UserMessageProps) { const containsFileMentions = hasFileMentions(content); const showAttachmentChips = attachments.length > 0 && !containsFileMentions; @@ -161,6 +171,28 @@ export function UserMessage({ {formatTimestamp(timestamp)} )} + + {/* span wrapper keeps the tooltip hoverable when the button is + disabled (Radix disables pointer events on disabled buttons) */} + + + + + + ( + null, + ); + const [isRestoring, setIsRestoring] = useState(false); + + const requestRestore = useCallback((checkpointId: string) => { + setPendingCheckpointId(checkpointId); + setDialogOpen(true); + }, []); + + const confirmRestore = useCallback(async () => { + if (!pendingCheckpointId || !repoPath) return; + + // Checkpoint restore only works for local sessions: it reverts local git + // state and truncates the on-disk rollout/logs. Cloud sessions have no + // local agent to reconnect, so block here (before any destructive revert) + // and tell the user why instead of failing silently. + const session = taskId + ? sessionStoreSetters.getSessionByTaskId(taskId) + : undefined; + if (session?.isCloud) { + toast.error("Checkpoint restore isn't available for cloud sessions"); + setDialogOpen(false); + setPendingCheckpointId(null); + return; + } + + setIsRestoring(true); + try { + const restoreResult = await trpcClient.checkpoint.restore.mutate({ + checkpointId: pendingCheckpointId, + repoPath, + taskRunId, + }); + if (taskId) { + sessionStoreSetters.truncateEventsToCheckpoint( + taskId, + pendingCheckpointId, + ); + // Reconnect the agent, resuming the same Codex/Claude session so the + // agent has memory only up to the restored checkpoint. + getSessionService() + .restoreCheckpointReconnect( + taskId, + repoPath, + restoreResult?.restoredSessionId, + ) + .catch(() => {}); + } + if (restoreResult?.truncationFailed) { + toast.warning( + "Checkpoint restored, but trimming the agent's history failed — it may still remember messages after this point.", + ); + } else { + toast.success("Checkpoint restored successfully"); + } + setDialogOpen(false); + setPendingCheckpointId(null); + } catch (error) { + const message = + error instanceof Error ? error.message : "Failed to restore checkpoint"; + toast.error(message); + } finally { + setIsRestoring(false); + } + }, [pendingCheckpointId, repoPath, taskId, taskRunId]); + + const cancelRestore = useCallback(() => { + setDialogOpen(false); + setPendingCheckpointId(null); + }, []); + + return { + dialogOpen, + setDialogOpen, + isRestoring, + requestRestore, + confirmRestore, + cancelRestore, + }; +} diff --git a/apps/code/src/renderer/features/sessions/service/service.ts b/apps/code/src/renderer/features/sessions/service/service.ts index c0903429bd..a12a2f1ee9 100644 --- a/apps/code/src/renderer/features/sessions/service/service.ts +++ b/apps/code/src/renderer/features/sessions/service/service.ts @@ -619,8 +619,29 @@ export class SessionService { } sessionStoreSetters.setSession(session); + this.subscribeToChannel(taskRunId); + // Re-emit stored checkpoint notifications through the SessionEvent channel + // so restore buttons survive reconnects. Must be called AFTER subscribeToChannel + // so the subscription is active and the emitted events are not lost. + trpcClient.checkpoint.replayCheckpoints + .mutate({ taskRunId }) + .then(({ count }) => { + if (count > 0) { + log.info("Checkpoint replay triggered after reconnect", { + taskRunId, + count, + }); + } + }) + .catch((err) => { + log.warn("Failed to replay checkpoints after reconnect", { + taskRunId, + error: err, + }); + }); + try { const modeOpt = getConfigOptionByCategory(persistedConfigOptions, "mode"); const persistedMode = @@ -647,6 +668,7 @@ export class SessionService { }); const { customInstructions } = useSettingsStore.getState(); + const result = await trpcClient.agent.reconnect.mutate({ taskId, taskRunId, @@ -1402,6 +1424,56 @@ export class SessionService { this.drainQueuedMessages(taskRunId, session); } + + // Sync GIT_CHECKPOINT notifications to S3 so truncate_log can find them + if ( + "method" in msg && + isNotification(msg.method, POSTHOG_NOTIFICATIONS.GIT_CHECKPOINT) + ) { + const params = msg.params as { + checkpointId?: string; + promptId?: number; + replay?: boolean; + }; + // Skip replayed checkpoints - they're already persisted and syncing them + // would undo S3 log truncation after a restore + if (params?.replay) { + log.debug("Skipping checkpoint sync (replayed)", { + taskRunId, + checkpointId: params.checkpointId, + }); + return; + } + if (params?.checkpointId) { + const storedEntry: StoredLogEntry = { + type: "notification", + timestamp: new Date().toISOString(), + notification: msg as StoredLogEntry["notification"], + }; + getAuthenticatedClient() + .then((client) => { + if (client) { + return client.appendTaskRunLog(session.taskId, taskRunId, [ + storedEntry, + ]); + } + return undefined; + }) + .then(() => { + log.info("Checkpoint synced to S3", { + taskRunId, + checkpointId: params.checkpointId, + }); + }) + .catch((error) => { + log.warn("Failed to sync checkpoint to S3", { + taskRunId, + checkpointId: params.checkpointId, + error, + }); + }); + } + } } private drainQueuedMessages( @@ -2618,6 +2690,113 @@ export class SessionService { await this.reconnectInPlace(taskId, repoPath, null); } + /** + * Reconnect after a checkpoint restore without overwriting the renderer's + * already-truncated events or replaying S3 history. + * + * The normal reconnect path (reconnectToLocalSession) always overwrites + * session.events from the JSONL and then hydrates the agent from S3, which + * causes both the truncated conversation to disappear and all post-checkpoint + * turns to reappear. This method: + * 1. Keeps the existing session.events untouched (truncated by the caller). + * 2. Cancels the old agent and re-subscribes. + * 3. Starts a FRESH agent session (sessionId=undefined) so codex-acp / claude + * has no memory of turns after the restored checkpoint. + * 4. Disables the prompt input while reconnecting rather than showing a + * full "Connecting to agent…" overlay. + */ + async restoreCheckpointReconnect( + taskId: string, + repoPath: string, + restoredSessionId?: string | null, + ): Promise { + this.localRepoPaths.set(taskId, repoPath); + const session = sessionStoreSetters.getSessionByTaskId(taskId); + if (!session || session.isCloud) return; + + const { taskRunId, logUrl } = session; + + // Block prompt input while reconnecting, but keep status=connected so + // the UI doesn't show a "Connecting to agent…" overlay. + sessionStoreSetters.updateSession(taskRunId, { + isPromptPending: true, + promptStartedAt: null, + }); + + try { + await trpcClient.agent.cancel.mutate({ sessionId: taskRunId }); + } catch { + // expected if agent already exited + } + this.unsubscribeFromChannel(taskRunId); + this.subscribeToChannel(taskRunId); + + // Do NOT call replayCheckpoints here — truncateEventsToCheckpoint already + // preserved the surviving checkpoints in session.events. Replaying would + // duplicate them. + + const auth = await this.getAuthCredentials(); + if (!auth) { + sessionStoreSetters.updateSession(taskRunId, { + status: "error", + isPromptPending: false, + errorMessage: "Authentication required. Please sign in.", + }); + return; + } + + const persistedConfigOptions = getPersistedConfigOptions(taskRunId); + const modeOpt = getConfigOptionByCategory(persistedConfigOptions, "mode"); + const persistedMode = + modeOpt?.type === "select" ? String(modeOpt.currentValue) : undefined; + const { customInstructions } = useSettingsStore.getState(); + + log.info( + "restoreCheckpointReconnect: resuming with session from restore mutation", + { + taskRunId, + restoredSessionId, + willResumeSession: !!restoredSessionId, + }, + ); + + try { + const result = await trpcClient.agent.reconnect.mutate({ + taskId, + taskRunId, + repoPath, + apiHost: auth.apiHost, + projectId: auth.projectId, + logUrl, + sessionId: restoredSessionId ?? undefined, + adapter: session.adapter, + permissionMode: persistedMode, + customInstructions: customInstructions || undefined, + }); + + if (result) { + sessionStoreSetters.updateSession(taskRunId, { + status: "connected", + isPromptPending: false, + }); + } else { + sessionStoreSetters.updateSession(taskRunId, { + status: "error", + isPromptPending: false, + errorMessage: + "Session could not resume after restore. Please retry or start a new session.", + }); + } + } catch (err) { + const msg = err instanceof Error ? err.message : String(err); + sessionStoreSetters.updateSession(taskRunId, { + status: "error", + isPromptPending: false, + errorMessage: msg || "Failed to reconnect after restore.", + }); + } + } + /** * Cancel the current backend agent and reconnect under the same taskRunId. * Does NOT remove the session from the store (avoids connect effect loop). diff --git a/apps/code/src/renderer/features/sessions/stores/sessionStore.ts b/apps/code/src/renderer/features/sessions/stores/sessionStore.ts index 718206228b..da27e152dc 100644 --- a/apps/code/src/renderer/features/sessions/stores/sessionStore.ts +++ b/apps/code/src/renderer/features/sessions/stores/sessionStore.ts @@ -5,9 +5,14 @@ import type { SessionConfigSelectOption, SessionConfigSelectOptions, } from "@agentclientprotocol/sdk"; +import { isNotification, POSTHOG_NOTIFICATIONS } from "@posthog/agent"; import type { ExecutionMode, TaskRunStatus } from "@shared/types"; import type { SkillButtonId } from "@shared/types/analytics"; -import type { AcpMessage } from "@shared/types/session-events"; +import { + type AcpMessage, + isJsonRpcNotification, + isJsonRpcRequest, +} from "@shared/types/session-events"; import { create } from "zustand"; import { immer } from "zustand/middleware/immer"; import type { PermissionRequest } from "../utils/parseSessionLogs"; @@ -498,6 +503,51 @@ export const sessionStoreSetters = { return useSessionStore.getState().sessions; }, + truncateEventsToCheckpoint: ( + taskId: string, + checkpointId: string, + ): boolean => { + const state = useSessionStore.getState(); + const taskRunId = state.taskIdIndex[taskId]; + if (!taskRunId) return false; + const session = state.sessions[taskRunId]; + if (!session) return false; + + const events = session.events; + let checkpointEventIdx = -1; + for (let i = 0; i < events.length; i++) { + const msg = events[i].message; + if (!isJsonRpcNotification(msg)) continue; + if (!isNotification(msg.method, POSTHOG_NOTIFICATIONS.GIT_CHECKPOINT)) + continue; + const params = msg.params as { checkpointId?: string } | undefined; + if (params?.checkpointId === checkpointId) { + checkpointEventIdx = i; + break; + } + } + if (checkpointEventIdx === -1) return false; + + let cutoff = events.length; + for (let i = checkpointEventIdx + 1; i < events.length; i++) { + const msg = events[i].message; + if (isJsonRpcRequest(msg) && msg.method === "session/prompt") { + cutoff = i; + break; + } + } + + useSessionStore.setState((draft) => { + const trid = draft.taskIdIndex[taskId]; + if (!trid) return; + const s = draft.sessions[trid]; + if (s) { + s.events = s.events.slice(0, cutoff); + } + }); + return true; + }, + clearAll: () => { useSessionStore.setState((state) => { state.sessions = {}; diff --git a/apps/code/src/renderer/features/sessions/utils/extractSearchableText.test.ts b/apps/code/src/renderer/features/sessions/utils/extractSearchableText.test.ts index 01f217b0e5..efa2e99ff8 100644 --- a/apps/code/src/renderer/features/sessions/utils/extractSearchableText.test.ts +++ b/apps/code/src/renderer/features/sessions/utils/extractSearchableText.test.ts @@ -26,6 +26,7 @@ describe("extractSearchableText", () => { childItems: new Map(), turnCancelled: false, turnComplete: true, + lastCheckpointId: null, }, }; expect(extractSearchableText(item)).toBe("agent reply"); @@ -44,6 +45,7 @@ describe("extractSearchableText", () => { childItems: new Map(), turnCancelled: false, turnComplete: true, + lastCheckpointId: null, }, }; expect(extractSearchableText(item)).toBe("thinking..."); @@ -66,6 +68,7 @@ describe("extractSearchableText", () => { childItems: new Map(), turnCancelled: false, turnComplete: true, + lastCheckpointId: null, }, }; expect(extractSearchableText(item)).toBe(""); @@ -85,6 +88,7 @@ describe("extractSearchableText", () => { childItems: new Map(), turnCancelled: false, turnComplete: true, + lastCheckpointId: null, }, }; expect(extractSearchableText(item)).toBe(""); @@ -104,6 +108,7 @@ describe("extractSearchableText", () => { childItems: new Map(), turnCancelled: false, turnComplete: true, + lastCheckpointId: null, }, }; expect(extractSearchableText(item)).toBe("console output"); @@ -123,6 +128,7 @@ describe("extractSearchableText", () => { childItems: new Map(), turnCancelled: false, turnComplete: true, + lastCheckpointId: null, }, }; expect(extractSearchableText(item)).toBe("something broke"); @@ -141,6 +147,7 @@ describe("extractSearchableText", () => { childItems: new Map(), turnCancelled: false, turnComplete: true, + lastCheckpointId: null, }, }; expect(extractSearchableText(item)).toBe("running"); @@ -162,6 +169,7 @@ describe("extractSearchableText", () => { childItems: new Map(), turnCancelled: false, turnComplete: true, + lastCheckpointId: null, }, }; expect(extractSearchableText(item)).toBe("task done"); diff --git a/apps/code/src/renderer/hooks/useShortcut.ts b/apps/code/src/renderer/hooks/useShortcut.ts new file mode 100644 index 0000000000..36a4521cf4 --- /dev/null +++ b/apps/code/src/renderer/hooks/useShortcut.ts @@ -0,0 +1,6 @@ +import type { ConfigurableShortcutId } from "@renderer/constants/keyboard-shortcuts"; +import { resolveKey, useKeybindingsStore } from "@stores/keybindingsStore"; + +export function useShortcut(id: ConfigurableShortcutId): string { + return useKeybindingsStore((s) => resolveKey(s.customKeybindings, id)); +} diff --git a/apps/code/src/renderer/stores/keybindingsStore.ts b/apps/code/src/renderer/stores/keybindingsStore.ts new file mode 100644 index 0000000000..cbb20ad207 --- /dev/null +++ b/apps/code/src/renderer/stores/keybindingsStore.ts @@ -0,0 +1,152 @@ +import { + CONFIGURABLE_SHORTCUT_IDS, + type ConfigurableShortcutId, + DEFAULT_KEYBINDINGS, + KEYBOARD_SHORTCUTS, +} from "@renderer/constants/keyboard-shortcuts"; +import { electronStorage } from "@utils/electronStorage"; +import { create } from "zustand"; +import { persist } from "zustand/middleware"; + +export const MAX_CUSTOM_BINDINGS = 2; + +interface KeybindingsState { + customKeybindings: Partial>; + getKey: (id: ConfigurableShortcutId) => string; + addKeybinding: (id: ConfigurableShortcutId, key: string) => void; + updateKeybinding: ( + id: ConfigurableShortcutId, + oldKey: string, + newKey: string, + ) => void; + removeKeybinding: (id: ConfigurableShortcutId, key: string) => void; + resetShortcut: (id: ConfigurableShortcutId) => void; + resetAll: () => void; +} + +export function resolveKey( + customKeybindings: Partial>, + id: ConfigurableShortcutId, +): string { + const customs = customKeybindings[id]; + if (customs && customs.length > 0) return customs.join(","); + return DEFAULT_KEYBINDINGS[id]; +} + +/** + * Split a keybinding string by comma, but preserve commas that are part of a + * key combo (e.g. "mod+," must not be split at the trailing comma). + * A valid separator comma is one NOT immediately preceded by "+". + */ +export function splitBindings(keyStr: string): string[] { + return keyStr + .split(/(? k.trim()) + .filter(Boolean); +} + +export interface ConflictResult { + id: ConfigurableShortcutId | null; + description: string | null; + /** true when the conflicting shortcut is not user-configurable */ + isFixed: boolean; +} + +export function findConflict( + newKey: string, + excludeId: ConfigurableShortcutId, +): ConflictResult { + const state = useKeybindingsStore.getState(); + + for (const id of CONFIGURABLE_SHORTCUT_IDS) { + if (id === excludeId) continue; + const keyStr = state.getKey(id); + const parts = splitBindings(keyStr); + if (parts.includes(newKey)) { + const entry = KEYBOARD_SHORTCUTS.find((s) => s.id === id); + return { id, description: entry?.description ?? id, isFixed: false }; + } + } + + for (const shortcut of KEYBOARD_SHORTCUTS) { + if ( + CONFIGURABLE_SHORTCUT_IDS.includes(shortcut.id as ConfigurableShortcutId) + ) + continue; + const parts = splitBindings(shortcut.keys); + if (shortcut.alternateKeys) { + parts.push(...splitBindings(shortcut.alternateKeys)); + } + if (parts.includes(newKey)) { + return { id: null, description: shortcut.description, isFixed: true }; + } + } + + return { id: null, description: null, isFixed: false }; +} + +export const useKeybindingsStore = create()( + persist( + (set, get) => ({ + customKeybindings: {}, + + getKey: (id) => resolveKey(get().customKeybindings, id), + + addKeybinding: (id, key) => { + const existing = get().customKeybindings[id] ?? []; + if (existing.includes(key)) return; + if (existing.length >= MAX_CUSTOM_BINDINGS) return; + set({ + customKeybindings: { + ...get().customKeybindings, + [id]: [...existing, key], + }, + }); + }, + + updateKeybinding: (id, oldKey, newKey) => { + const existing = get().customKeybindings[id] ?? []; + // When editing a default binding, copy all defaults first so the other + // defaults are preserved — only the edited key gets replaced. + const base = + existing.length > 0 + ? existing + : splitBindings(DEFAULT_KEYBINDINGS[id]); + const updated = base.map((k) => (k === oldKey ? newKey : k)); + // Deduplicate — conflict detection excludes the edited shortcut's own bindings, + // so editing one binding to match another on the same shortcut can slip through. + const deduped = [...new Set(updated)]; + set({ + customKeybindings: { ...get().customKeybindings, [id]: deduped }, + }); + }, + + removeKeybinding: (id, key) => { + const existing = get().customKeybindings[id] ?? []; + const updated = existing.filter((k) => k !== key); + set({ + customKeybindings: { + ...get().customKeybindings, + [id]: updated, + }, + }); + }, + + resetShortcut: (id) => { + const { [id]: _removed, ...rest } = get().customKeybindings; + set({ + customKeybindings: rest as Partial< + Record + >, + }); + }, + + resetAll: () => set({ customKeybindings: {} }), + }), + { + name: "keybindings-storage", + storage: electronStorage, + partialize: (state) => ({ customKeybindings: state.customKeybindings }), + }, + ), +); diff --git a/packages/agent/package.json b/packages/agent/package.json index 858be1e41a..d8347821a8 100644 --- a/packages/agent/package.json +++ b/packages/agent/package.json @@ -44,6 +44,10 @@ "types": "./dist/adapters/claude/conversion/tool-use-to-acp.d.ts", "import": "./dist/adapters/claude/conversion/tool-use-to-acp.js" }, + "./adapters/codex/rollout": { + "types": "./dist/adapters/codex/rollout.d.ts", + "import": "./dist/adapters/codex/rollout.js" + }, "./adapters/claude/session/jsonl-hydration": { "types": "./dist/adapters/claude/session/jsonl-hydration.d.ts", "import": "./dist/adapters/claude/session/jsonl-hydration.js" diff --git a/packages/agent/src/adapters/claude/session/jsonl-hydration.hydrate.test.ts b/packages/agent/src/adapters/claude/session/jsonl-hydration.hydrate.test.ts new file mode 100644 index 0000000000..6aef2db413 --- /dev/null +++ b/packages/agent/src/adapters/claude/session/jsonl-hydration.hydrate.test.ts @@ -0,0 +1,137 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; +import type { PostHogAPIClient } from "../../../posthog-api"; +import type { StoredEntry } from "../../../types"; + +// Mock fs so hydrateSessionJsonl's access/mkdir/writeFile/rename are observable +// without touching disk. The Claude checkpoint-restore memory truncation works +// by force-refetching the (already backend-truncated) S3 log and rewriting the +// JSONL — there is no session/update replay, so there is no duplication risk; +// these tests lock in that contract. +const fsMock = vi.hoisted(() => ({ + access: vi.fn(), + mkdir: vi.fn(), + writeFile: vi.fn(), + rename: vi.fn(), +})); +vi.mock("node:fs/promises", () => fsMock); + +import { hydrateSessionJsonl } from "./jsonl-hydration"; + +function s3Entry( + sessionUpdate: string, + extra: Record = {}, +): StoredEntry { + return { + type: "notification", + timestamp: "2026-03-03T12:00:00.000Z", + notification: { + jsonrpc: "2.0", + method: "session/update", + params: { update: { sessionUpdate, ...extra } }, + }, + }; +} + +function makeApi(overrides: { + logUrl?: string | null; + entries?: StoredEntry[]; +}): PostHogAPIClient & { + getTaskRun: ReturnType; + fetchTaskRunLogs: ReturnType; +} { + const api = { + getTaskRun: vi.fn(async () => ({ log_url: overrides.logUrl ?? "s3://x" })), + fetchTaskRunLogs: vi.fn(async () => overrides.entries ?? []), + }; + return api as unknown as PostHogAPIClient & { + getTaskRun: ReturnType; + fetchTaskRunLogs: ReturnType; + }; +} + +const log = { info: vi.fn(), warn: vi.fn() }; + +const baseParams = { + sessionId: "sess-1", + cwd: "/repo", + taskId: "task-1", + runId: "run-1", + log, +}; + +describe("hydrateSessionJsonl checkpoint restore", () => { + let prevConfigDir: string | undefined; + + beforeEach(() => { + vi.clearAllMocks(); + fsMock.mkdir.mockResolvedValue(undefined); + fsMock.writeFile.mockResolvedValue(undefined); + fsMock.rename.mockResolvedValue(undefined); + prevConfigDir = process.env.CLAUDE_CONFIG_DIR; + process.env.CLAUDE_CONFIG_DIR = "/tmp/claude-hydrate-test"; + }); + + afterEach(() => { + if (prevConfigDir === undefined) delete process.env.CLAUDE_CONFIG_DIR; + else process.env.CLAUDE_CONFIG_DIR = prevConfigDir; + }); + + it("reuses the existing JSONL without refetching when forceRefetch is false", async () => { + fsMock.access.mockResolvedValue(undefined); // file exists + const api = makeApi({ entries: [] }); + + const ok = await hydrateSessionJsonl({ ...baseParams, posthogAPI: api }); + + expect(ok).toBe(true); + expect(api.getTaskRun).not.toHaveBeenCalled(); + expect(fsMock.writeFile).not.toHaveBeenCalled(); + }); + + it("force-refetches from truncated S3 even when a stale JSONL exists", async () => { + fsMock.access.mockResolvedValue(undefined); // stale file exists + // Truncated S3: only the first turn survives the checkpoint restore. + const api = makeApi({ + entries: [ + s3Entry("user_message", { + content: { type: "text", text: "greeting" }, + }), + s3Entry("agent_message", { content: { type: "text", text: "done" } }), + ], + }); + + const ok = await hydrateSessionJsonl({ + ...baseParams, + posthogAPI: api, + forceRefetch: true, + }); + + expect(ok).toBe(true); + // The existing-file shortcut is bypassed; S3 is refetched and rewritten. + expect(api.getTaskRun).toHaveBeenCalledTimes(1); + expect(api.fetchTaskRunLogs).toHaveBeenCalledTimes(1); + expect(fsMock.rename).toHaveBeenCalledTimes(1); + + // The rewritten JSONL contains only the surviving turn — memory parity: + // truncated S3 in, truncated memory out. + const written = fsMock.writeFile.mock.calls[0][1] as string; + expect(written).toContain("greeting"); + expect(written).toContain("done"); + expect(written).not.toContain("post-checkpoint"); + }); + + it("returns false when the task run has no log URL", async () => { + fsMock.access.mockRejectedValue( + Object.assign(new Error("nope"), { code: "ENOENT" }), + ); + const api = makeApi({ logUrl: null }); + + const ok = await hydrateSessionJsonl({ + ...baseParams, + posthogAPI: api, + forceRefetch: true, + }); + + expect(ok).toBe(false); + expect(fsMock.writeFile).not.toHaveBeenCalled(); + }); +}); diff --git a/packages/agent/src/adapters/claude/session/jsonl-hydration.ts b/packages/agent/src/adapters/claude/session/jsonl-hydration.ts index 14b39a71b6..33203abd7c 100644 --- a/packages/agent/src/adapters/claude/session/jsonl-hydration.ts +++ b/packages/agent/src/adapters/claude/session/jsonl-hydration.ts @@ -497,16 +497,28 @@ export async function hydrateSessionJsonl(params: { permissionMode?: string; posthogAPI: PostHogAPIClient; log: HydrationLog; + forceRefetch?: boolean; }): Promise { const { posthogAPI, log } = params; try { const jsonlPath = getSessionJsonlPath(params.sessionId, params.cwd); - try { - await fs.access(jsonlPath); - return true; - } catch { - // File doesn't exist, proceed with hydration + if (!params.forceRefetch) { + try { + await fs.access(jsonlPath); + log.info("Session JSONL already exists, reusing without S3 fetch", { + sessionId: params.sessionId, + jsonlPath, + }); + return true; + } catch { + // File doesn't exist, proceed with hydration + } + } else { + log.info( + "Force-refetching JSONL from S3 (checkpoint restore), skipping existing file check", + { sessionId: params.sessionId, jsonlPath }, + ); } const taskRun = await posthogAPI.getTaskRun(params.taskId, params.runId); diff --git a/packages/agent/src/adapters/codex/codex-agent.test.ts b/packages/agent/src/adapters/codex/codex-agent.test.ts index 382fe61319..cf84150d1c 100644 --- a/packages/agent/src/adapters/codex/codex-agent.test.ts +++ b/packages/agent/src/adapters/codex/codex-agent.test.ts @@ -145,6 +145,50 @@ describe("CodexAcpAgent", () => { ); }); + it("suppresses rollout replay during loadSession and clears it after", async () => { + const { agent } = createAgent(); + const state = ( + agent as unknown as { sessionState: { suppressReplay?: boolean } } + ).sessionState; + + // Capture the flag's value at the moment codex-acp would re-stream the + // rollout (i.e. while loadSession is in flight). + let duringLoad: boolean | undefined; + mockCodexConnection.loadSession.mockImplementation(async () => { + duringLoad = state.suppressReplay; + return { + modes: { currentModeId: "auto", availableModes: [] }, + configOptions: [], + } satisfies Partial; + }); + + expect(state.suppressReplay).toBe(false); + await agent.loadSession({ + sessionId: "session-1", + cwd: process.cwd(), + } as never); + + expect(duringLoad).toBe(true); + expect(state.suppressReplay).toBe(false); + }); + + it("clears suppressReplay even when loadSession throws", async () => { + const { agent } = createAgent(); + const state = ( + agent as unknown as { sessionState: { suppressReplay?: boolean } } + ).sessionState; + + mockCodexConnection.loadSession.mockRejectedValue(new Error("load failed")); + + await expect( + agent.loadSession({ + sessionId: "session-1", + cwd: process.cwd(), + } as never), + ).rejects.toThrow("load failed"); + expect(state.suppressReplay).toBe(false); + }); + it("does not emit SDK_SESSION on loadSession when taskRunId is absent", async () => { const { agent, client } = createAgent(); mockCodexConnection.loadSession.mockResolvedValue({ diff --git a/packages/agent/src/adapters/codex/codex-agent.ts b/packages/agent/src/adapters/codex/codex-agent.ts index f064778ec5..72e7fa7133 100644 --- a/packages/agent/src/adapters/codex/codex-agent.ts +++ b/packages/agent/src/adapters/codex/codex-agent.ts @@ -455,7 +455,15 @@ export class CodexAcpAgent extends BaseAcpAgent { this.applyStructuredOutput(params, meta), meta, ); - const response = await this.codexConnection.loadSession(injectedParams); + // Suppress replayed session/update events during load so codex-acp's + // rollout re-stream isn't re-persisted (logs.ndjson + S3) or re-displayed. + this.sessionState.suppressReplay = true; + let response: LoadSessionResponse; + try { + response = await this.codexConnection.loadSession(injectedParams); + } finally { + this.sessionState.suppressReplay = false; + } response.configOptions = normalizeCodexConfigOptions( response.configOptions, ); @@ -506,8 +514,15 @@ export class CodexAcpAgent extends BaseAcpAgent { meta, ); - // codex-acp doesn't support resume natively, use loadSession instead - const loadResponse = await this.codexConnection.loadSession(injectedParams); + // codex-acp doesn't support resume natively, use loadSession instead. + // Suppress the rollout replay so it isn't re-persisted or re-displayed. + this.sessionState.suppressReplay = true; + let loadResponse: LoadSessionResponse; + try { + loadResponse = await this.codexConnection.loadSession(injectedParams); + } finally { + this.sessionState.suppressReplay = false; + } loadResponse.configOptions = normalizeCodexConfigOptions( loadResponse.configOptions, ); @@ -899,11 +914,18 @@ export class CodexAcpAgent extends BaseAcpAgent { protocolVersion: 1, }; await newConnection.initialize(initRequest); - await newConnection.loadSession({ - sessionId: this.sessionId, - cwd: this.sessionState.cwd, - mcpServers, - }); + // Suppress the rollout replay during rehydration so it isn't re-persisted + // or re-displayed. newConnection shares this.sessionState via its client. + this.sessionState.suppressReplay = true; + try { + await newConnection.loadSession({ + sessionId: this.sessionId, + cwd: this.sessionState.cwd, + mcpServers, + }); + } finally { + this.sessionState.suppressReplay = false; + } // Swap everything at once so closeSession/prompt/cancel target the new // subprocess going forward. Preserve sessionState (accumulatedUsage, diff --git a/packages/agent/src/adapters/codex/codex-client.test.ts b/packages/agent/src/adapters/codex/codex-client.test.ts index 56a8d8b1d0..fde6113e4c 100644 --- a/packages/agent/src/adapters/codex/codex-client.test.ts +++ b/packages/agent/src/adapters/codex/codex-client.test.ts @@ -289,6 +289,96 @@ describe("createCodexClient onStructuredOutput", () => { }); }); +describe("createCodexClient suppressReplay", () => { + const logger = new Logger({ debug: false, prefix: "[test]" }); + + function makeUpstream(): AgentSideConnection { + return { + sessionUpdate: vi.fn(async () => {}), + requestPermission: vi.fn(), + readTextFile: vi.fn(), + writeTextFile: vi.fn(), + createTerminal: vi.fn(), + terminalOutput: vi.fn(), + releaseTerminal: vi.fn(), + waitForTerminalExit: vi.fn(), + killTerminal: vi.fn(), + extMethod: vi.fn(), + extNotification: vi.fn(), + } as unknown as AgentSideConnection; + } + + function notification(update: Record): SessionNotification { + return { sessionId: "sess", update } as unknown as SessionNotification; + } + + // Core of the checkpoint-restore duplication fix: while a loadSession replay + // is in flight, codex-acp re-streams the whole rollout as session/update + // events. Forwarding them re-persists history (logs.ndjson + S3) and re-fires + // callbacks. The client must drop them at the top of sessionUpdate. + test("drops session updates while suppressReplay is set (no upstream forward)", async () => { + const sessionState = createSessionState("sess", "/tmp"); + sessionState.suppressReplay = true; + const upstream = makeUpstream(); + const client = createCodexClient(upstream, logger, sessionState); + + await client.sessionUpdate?.( + notification({ + sessionUpdate: "agent_message_chunk", + content: { type: "text", text: "replayed history" }, + }), + ); + + expect(upstream.sessionUpdate).not.toHaveBeenCalled(); + }); + + test("forwards session updates once suppressReplay is cleared", async () => { + const sessionState = createSessionState("sess", "/tmp"); + sessionState.suppressReplay = true; + const upstream = makeUpstream(); + const client = createCodexClient(upstream, logger, sessionState); + + const note = notification({ + sessionUpdate: "agent_message_chunk", + content: { type: "text", text: "live turn" }, + }); + + await client.sessionUpdate?.(note); + expect(upstream.sessionUpdate).not.toHaveBeenCalled(); + + // Replay finished — subsequent live events flow through normally. + sessionState.suppressReplay = false; + await client.sessionUpdate?.(note); + expect(upstream.sessionUpdate).toHaveBeenCalledTimes(1); + expect(upstream.sessionUpdate).toHaveBeenCalledWith(note); + }); + + test("does not re-fire onStructuredOutput for a replayed create_output", async () => { + const sessionState = createSessionState("sess", "/tmp"); + sessionState.suppressReplay = true; + const onStructuredOutput = vi.fn(async () => {}); + const upstream = makeUpstream(); + const client = createCodexClient(upstream, logger, sessionState, { + onStructuredOutput, + }); + + // A historical create_output completion arriving during replay must not + // re-trigger the structured-output callback. + await client.sessionUpdate?.( + notification({ + sessionUpdate: "tool_call", + toolCallId: "tc-replay", + title: "create_output", + status: "completed", + rawInput: { result: "stale" }, + }), + ); + + expect(onStructuredOutput).not.toHaveBeenCalled(); + expect(upstream.sessionUpdate).not.toHaveBeenCalled(); + }); +}); + describe("createCodexClient usage_update propagation", () => { const logger = new Logger({ debug: false, prefix: "[test]" }); diff --git a/packages/agent/src/adapters/codex/codex-client.ts b/packages/agent/src/adapters/codex/codex-client.ts index ebe05a4fe2..62d8154b86 100644 --- a/packages/agent/src/adapters/codex/codex-client.ts +++ b/packages/agent/src/adapters/codex/codex-client.ts @@ -165,6 +165,12 @@ export function createCodexClient( }, async sessionUpdate(params: SessionNotification): Promise { + // During loadSession replay, codex-acp re-streams the entire rollout as + // session/update notifications. Forwarding them would re-persist history + // (logs.ndjson + S3) and re-fire structured-output callbacks. Drop them; + // post-load state is re-established by resetSessionState. + if (sessionState.suppressReplay) return; + const update = params.update as Record | undefined; if ( diff --git a/packages/agent/src/adapters/codex/rollout.retry.test.ts b/packages/agent/src/adapters/codex/rollout.retry.test.ts new file mode 100644 index 0000000000..69c4300571 --- /dev/null +++ b/packages/agent/src/adapters/codex/rollout.retry.test.ts @@ -0,0 +1,90 @@ +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; + +// Mocks node:fs/promises so we can simulate the Windows file-lock race where +// the just-killed codex-acp process still holds the rollout handle and rename +// fails with EPERM/EACCES/EBUSY for a short window. writeFileWithRetry is +// internal, so we drive it through truncateCodexRollout. +const fsMock = vi.hoisted(() => ({ + readdir: vi.fn(), + readFile: vi.fn(), + writeFile: vi.fn(), + rename: vi.fn(), + unlink: vi.fn(), +})); +vi.mock("node:fs/promises", () => fsMock); + +import { truncateCodexRollout } from "./rollout"; + +const sessionId = "retry-session"; +const ROLLOUT = `${[ + JSON.stringify({ type: "session_meta", payload: {} }), + JSON.stringify({ type: "event_msg", payload: { type: "task_started" } }), + JSON.stringify({ type: "response_item", payload: { text: "t1" } }), + JSON.stringify({ type: "event_msg", payload: { type: "task_complete" } }), +].join("\n")}\n`; + +function errWithCode(code: string): NodeJS.ErrnoException { + const e = new Error(code) as NodeJS.ErrnoException; + e.code = code; + return e; +} + +describe("truncateCodexRollout file-lock retry", () => { + beforeEach(() => { + vi.clearAllMocks(); + fsMock.readdir.mockResolvedValue([ + { + isFile: () => true, + name: `rollout-x-${sessionId}.jsonl`, + parentPath: "/fake/sessions/2026/06/08", + }, + ]); + fsMock.readFile.mockResolvedValue(ROLLOUT); + fsMock.writeFile.mockResolvedValue(undefined); + fsMock.unlink.mockResolvedValue(undefined); + }); + + afterEach(() => { + vi.useRealTimers(); + }); + + test("retries on EPERM and succeeds once the lock releases", async () => { + vi.useFakeTimers(); + fsMock.rename + .mockRejectedValueOnce(errWithCode("EPERM")) + .mockRejectedValueOnce(errWithCode("EBUSY")) + .mockResolvedValueOnce(undefined); + + const p = truncateCodexRollout(sessionId, 1); + await vi.runAllTimersAsync(); + const ok = await p; + + expect(ok).toBe(true); + expect(fsMock.rename).toHaveBeenCalledTimes(3); + }); + + test("returns false after exhausting retries on a persistent lock", async () => { + vi.useFakeTimers(); + fsMock.rename.mockRejectedValue(errWithCode("EACCES")); + + const p = truncateCodexRollout(sessionId, 1); + await vi.runAllTimersAsync(); + const ok = await p; + + expect(ok).toBe(false); + // delaysMs = [0, 50, 100, 200, 400, 800] → 6 attempts. + expect(fsMock.rename).toHaveBeenCalledTimes(6); + }); + + test("fails fast on a non-lock error without retrying", async () => { + vi.useFakeTimers(); + fsMock.rename.mockRejectedValue(errWithCode("ENOSPC")); + + const p = truncateCodexRollout(sessionId, 1); + await vi.runAllTimersAsync(); + const ok = await p; + + expect(ok).toBe(false); + expect(fsMock.rename).toHaveBeenCalledTimes(1); + }); +}); diff --git a/packages/agent/src/adapters/codex/rollout.test.ts b/packages/agent/src/adapters/codex/rollout.test.ts new file mode 100644 index 0000000000..6a11521a7a --- /dev/null +++ b/packages/agent/src/adapters/codex/rollout.test.ts @@ -0,0 +1,115 @@ +import * as fs from "node:fs/promises"; +import * as os from "node:os"; +import * as path from "node:path"; +import { afterEach, beforeEach, describe, expect, test } from "vitest"; +import { findCodexRollout, truncateCodexRollout } from "./rollout"; + +// truncateCodexRollout is the memory-truncation half of checkpoint restore: +// after a restore it trims the on-disk codex-acp rollout to the first N turns +// so the resumed session only remembers up to the checkpoint. These tests use +// a real temp CODEX_HOME so the file-walk + atomic write are exercised end to +// end (the writeFileWithRetry success path included). + +const META = JSON.stringify({ type: "session_meta", payload: { id: "s" } }); + +function turn(n: number): string[] { + return [ + JSON.stringify({ type: "event_msg", payload: { type: "task_started" } }), + JSON.stringify({ type: "response_item", payload: { text: `turn${n}` } }), + JSON.stringify({ type: "event_msg", payload: { type: "task_complete" } }), + ]; +} + +function rollout(turns: number): string { + const lines = [META]; + for (let i = 1; i <= turns; i++) lines.push(...turn(i)); + return `${lines.join("\n")}\n`; +} + +describe("truncateCodexRollout", () => { + let codexHome: string; + let prevCodexHome: string | undefined; + const sessionId = "test-session-abc"; + + async function writeRollout(content: string): Promise { + const dir = path.join(codexHome, "sessions", "2026", "06", "08"); + await fs.mkdir(dir, { recursive: true }); + const file = path.join( + dir, + `rollout-2026-06-08T00-00-00-${sessionId}.jsonl`, + ); + await fs.writeFile(file, content, "utf-8"); + return file; + } + + async function readLines(file: string): Promise { + const content = await fs.readFile(file, "utf-8"); + return content.split("\n").filter((l) => l.trim()); + } + + beforeEach(async () => { + prevCodexHome = process.env.CODEX_HOME; + codexHome = await fs.mkdtemp(path.join(os.tmpdir(), "codex-rollout-")); + process.env.CODEX_HOME = codexHome; + }); + + afterEach(async () => { + if (prevCodexHome === undefined) delete process.env.CODEX_HOME; + else process.env.CODEX_HOME = prevCodexHome; + await fs.rm(codexHome, { recursive: true, force: true }); + }); + + test("findCodexRollout locates the rollout file for a session", async () => { + const file = await writeRollout(rollout(2)); + const found = await findCodexRollout(sessionId); + expect(found).toBe(file); + }); + + test("keeps only the first turn when keepTurns=1 (restore to first turn)", async () => { + const file = await writeRollout(rollout(3)); + const ok = await truncateCodexRollout(sessionId, 1); + expect(ok).toBe(true); + + const lines = await readLines(file); + // session_meta + 3 lines of turn 1 + expect(lines).toHaveLength(4); + expect(lines[0]).toBe(META); + expect(lines.some((l) => l.includes("turn1"))).toBe(true); + expect(lines.some((l) => l.includes("turn2"))).toBe(false); + expect(lines.some((l) => l.includes("turn3"))).toBe(false); + }); + + test("keeps the first two turns when keepTurns=2 (restore to a middle turn)", async () => { + const file = await writeRollout(rollout(3)); + const ok = await truncateCodexRollout(sessionId, 2); + expect(ok).toBe(true); + + const lines = await readLines(file); + // session_meta + 2 turns * 3 lines + expect(lines).toHaveLength(7); + expect(lines.some((l) => l.includes("turn2"))).toBe(true); + expect(lines.some((l) => l.includes("turn3"))).toBe(false); + }); + + test("returns false when the session rollout file is not found", async () => { + const ok = await truncateCodexRollout("nonexistent-session", 1); + expect(ok).toBe(false); + }); + + test("returns false and leaves the file intact when no complete turns exist", async () => { + // Only a started turn with no task_complete. + const content = `${[ + META, + JSON.stringify({ type: "event_msg", payload: { type: "task_started" } }), + JSON.stringify({ type: "response_item", payload: { text: "partial" } }), + ].join("\n")}\n`; + const file = await writeRollout(content); + + const ok = await truncateCodexRollout(sessionId, 1); + expect(ok).toBe(false); + + // File untouched. + const lines = await readLines(file); + expect(lines).toHaveLength(3); + }); +}); diff --git a/packages/agent/src/adapters/codex/rollout.ts b/packages/agent/src/adapters/codex/rollout.ts new file mode 100644 index 0000000000..4cf6fd5aba --- /dev/null +++ b/packages/agent/src/adapters/codex/rollout.ts @@ -0,0 +1,201 @@ +import * as fs from "node:fs/promises"; +import * as os from "node:os"; +import * as path from "node:path"; + +interface RolloutLog { + info: (msg: string, data?: unknown) => void; + warn: (msg: string, data?: unknown) => void; +} + +function getCodexHome(): string { + return process.env.CODEX_HOME ?? path.join(os.homedir(), ".codex"); +} + +const sleep = (ms: number): Promise => + new Promise((resolve) => setTimeout(resolve, ms)); + +/** + * On Windows the just-killed codex-acp subprocess can keep the rollout file + * handle open for a short window after cancelSession, so an immediate + * write/rename fails with EPERM/EACCES/EBUSY. Retry across a ~1.5s window so + * the handle has time to release. Non-lock errors fail fast. + */ +async function writeFileWithRetry( + targetPath: string, + content: string, + log?: RolloutLog, +): Promise { + const delaysMs = [0, 50, 100, 200, 400, 800]; + let lastError: unknown; + for (let attempt = 0; attempt < delaysMs.length; attempt++) { + if (delaysMs[attempt] > 0) await sleep(delaysMs[attempt]); + const tmpPath = `${targetPath}.tmp.${Date.now()}.${attempt}`; + try { + await fs.writeFile(tmpPath, content, "utf-8"); + await fs.rename(tmpPath, targetPath); + return true; + } catch (err) { + lastError = err; + await fs.unlink(tmpPath).catch(() => {}); + const code = (err as NodeJS.ErrnoException).code; + if (code !== "EPERM" && code !== "EACCES" && code !== "EBUSY") break; + log?.info("Rollout write blocked by file lock, retrying", { + code, + attempt, + nextDelayMs: delaysMs[attempt + 1], + }); + } + } + log?.warn("Failed to write truncated codex rollout", { + error: lastError instanceof Error ? lastError.message : String(lastError), + }); + return false; +} + +/** + * Find the codex-acp rollout file for a given session ID. + * Rollout files live at: /sessions////rollout-*-.jsonl + */ +export async function findCodexRollout( + sessionId: string, +): Promise { + const sessionsDir = path.join(getCodexHome(), "sessions"); + try { + const entries = await fs.readdir(sessionsDir, { + recursive: true, + withFileTypes: true, + }); + for (const entry of entries) { + if ( + entry.isFile() && + entry.name.startsWith("rollout-") && + entry.name.endsWith(`-${sessionId}.jsonl`) + ) { + return path.join(entry.parentPath, entry.name); + } + } + } catch { + // Sessions dir may not exist yet + } + return undefined; +} + +/** + * Truncate the codex-acp rollout file for a session to the first `keepTurns` + * complete turns. A turn is bounded by event_msg:task_started → event_msg:task_complete. + * The session_meta line (line 1) is always preserved. + * + * Must be called AFTER the codex-acp subprocess is killed so the file is not + * being written to concurrently. + * + * Returns true if the file was successfully truncated, false otherwise (including + * when the file is not found — fail-safe so git restore still works for display). + */ +export async function truncateCodexRollout( + sessionId: string, + keepTurns: number, + log?: RolloutLog, +): Promise { + log?.info("Truncating codex rollout for checkpoint restore", { + sessionId, + keepTurns, + }); + + const rolloutPath = await findCodexRollout(sessionId); + if (!rolloutPath) { + log?.warn("Could not find codex rollout file for truncation", { + sessionId, + }); + return false; + } + + log?.info("Found codex rollout file", { sessionId, rolloutPath }); + + let content: string; + try { + content = await fs.readFile(rolloutPath, "utf-8"); + } catch (err) { + log?.warn("Failed to read codex rollout file", { + sessionId, + error: err instanceof Error ? err.message : String(err), + }); + return false; + } + + const lines = content.split("\n").filter((l) => l.trim()); + log?.info("Codex rollout file read", { + sessionId, + totalLines: lines.length, + keepTurns, + }); + if (lines.length === 0) return false; + + const keepLines: string[] = []; + + // Always keep session_meta (first line) + keepLines.push(lines[0]); + + // Walk remaining lines grouping into turn blocks bounded by + // event_msg:task_started → event_msg:task_complete. + // Keep the first keepTurns complete blocks; drop everything after. + let turnsKept = 0; + let inTurn = false; + let currentTurnLines: string[] = []; + + for (let i = 1; i < lines.length; i++) { + const line = lines[i]; + let parsed: { type?: string; payload?: { type?: string } } | undefined; + try { + parsed = JSON.parse(line) as { + type?: string; + payload?: { type?: string }; + }; + } catch { + if (inTurn) currentTurnLines.push(line); + continue; + } + + if ( + parsed.type === "event_msg" && + parsed.payload?.type === "task_started" + ) { + if (turnsKept >= keepTurns) break; // Discard this and all following turns + inTurn = true; + currentTurnLines = [line]; + } else if ( + parsed.type === "event_msg" && + parsed.payload?.type === "task_complete" + ) { + if (inTurn) { + currentTurnLines.push(line); + keepLines.push(...currentTurnLines); + turnsKept++; + inTurn = false; + currentTurnLines = []; + } + } else if (inTurn) { + currentTurnLines.push(line); + } + } + + if (turnsKept === 0 && keepTurns > 0) { + log?.warn( + "No complete turns found in codex rollout — leaving file intact", + { sessionId, keepTurns, totalLines: lines.length }, + ); + return false; + } + + const truncated = `${keepLines.join("\n")}\n`; + const wrote = await writeFileWithRetry(rolloutPath, truncated, log); + if (wrote) { + log?.info("Truncated codex rollout to checkpoint boundary", { + sessionId, + keepTurns, + turnsKept, + keptLines: keepLines.length, + originalLines: lines.length, + }); + } + return wrote; +} diff --git a/packages/agent/src/adapters/codex/session-state.ts b/packages/agent/src/adapters/codex/session-state.ts index 9aa8694a85..b72f0afdd3 100644 --- a/packages/agent/src/adapters/codex/session-state.ts +++ b/packages/agent/src/adapters/codex/session-state.ts @@ -22,6 +22,13 @@ export interface CodexSessionState { permissionMode: PermissionMode; taskRunId?: string; taskId?: string; + /** + * True while a loadSession replay is in flight. codex-acp re-streams the + * entire rollout as session/update notifications on load; forwarding them + * would re-persist history (logs.ndjson + S3) and re-fire callbacks. The + * codex-client drops session updates while this is set. + */ + suppressReplay?: boolean; } export function createSessionState( @@ -50,6 +57,7 @@ export function createSessionState( permissionMode: opts?.permissionMode ?? "auto", taskRunId: opts?.taskRunId, taskId: opts?.taskId, + suppressReplay: false, }; } @@ -85,6 +93,7 @@ export function resetSessionState( state.permissionMode = opts?.permissionMode ?? "auto"; state.taskRunId = opts?.taskRunId; state.taskId = opts?.taskId; + state.suppressReplay = false; } export function resetUsage(state: CodexSessionState): void { diff --git a/packages/agent/tsup.config.ts b/packages/agent/tsup.config.ts index e704a62e9b..ca575744f1 100644 --- a/packages/agent/tsup.config.ts +++ b/packages/agent/tsup.config.ts @@ -113,6 +113,7 @@ export default defineConfig([ "src/adapters/claude/conversion/tool-use-to-acp.ts", "src/adapters/claude/session/jsonl-hydration.ts", "src/adapters/claude/session/models.ts", + "src/adapters/codex/rollout.ts", "src/adapters/codex/models.ts", "src/adapters/claude/mcp/tool-metadata.ts", "src/adapters/codex/structured-output-mcp-server.ts",