diff --git a/rivetkit-typescript/packages/rivetkit/src/client/actor-conn.ts b/rivetkit-typescript/packages/rivetkit/src/client/actor-conn.ts index 6237c8a7fe..ff3454b7ed 100644 --- a/rivetkit-typescript/packages/rivetkit/src/client/actor-conn.ts +++ b/rivetkit-typescript/packages/rivetkit/src/client/actor-conn.ts @@ -56,7 +56,6 @@ import { type QueueSendResult, type QueueSendWaitOptions, } from "./queue"; -import { resolveGatewayTarget } from "./resolve-gateway-target"; import { type WebSocketMessage as ConnMessage, messageLength, @@ -578,9 +577,7 @@ export class ActorConnRaw { async #connectWebSocket() { const params = await this.#resolveConnectionParams(); - const target = this.#gatewayOptions.skipReadyWait - ? await this.#resolveGatewayTargetForSkipReadyWait() - : getGatewayTarget(this.#actorResolutionState); + const target = getGatewayTarget(this.#actorResolutionState); const ws = await this.#driver.openWebSocket( PATH_CONNECT, target, @@ -634,25 +631,6 @@ export class ActorConnRaw { }); } - async #resolveGatewayTargetForSkipReadyWait() { - if ("getForId" in this.#actorResolutionState) { - return { - directId: this.#actorResolutionState.getForId.actorId, - } as const; - } - - if (this.#actorId) { - return { directId: this.#actorId } as const; - } - - return { - directId: await resolveGatewayTarget( - this.#driver, - this.#actorResolutionState, - ), - } as const; - } - /** Called by the onopen event from drivers. */ #handleOnOpen() { // Connection was disposed before Init message arrived - close the websocket to avoid leak diff --git a/rivetkit-typescript/packages/rivetkit/src/client/actor-handle.ts b/rivetkit-typescript/packages/rivetkit/src/client/actor-handle.ts index c50027d715..e6c17a60c5 100644 --- a/rivetkit-typescript/packages/rivetkit/src/client/actor-handle.ts +++ b/rivetkit-typescript/packages/rivetkit/src/client/actor-handle.ts @@ -140,7 +140,13 @@ export class ActorHandleRaw { for (let attempt = 0; attempt < maxAttempts; attempt++) { let actorId: string | undefined; try { - const target = await this.#resolveActionTarget(useQueryTarget); + const gatewayOptions = resolveActorGatewayOptions( + this.#gatewayOptions, + ); + const target = await this.#resolveGatewayRequestTarget( + useQueryTarget, + gatewayOptions, + ); actorId = "directId" in target ? target.directId : undefined; return await createQueueSender({ @@ -150,9 +156,7 @@ export class ActorHandleRaw { return await this.#driver.sendRequest( target, request, - resolveActorGatewayOptions( - this.#gatewayOptions, - ), + gatewayOptions, ); }, }).send(name, body, options as any); @@ -269,7 +273,10 @@ export class ActorHandleRaw { for (let attempt = 0; attempt < maxAttempts; attempt++) { let actorId: string | undefined; try { - const target = await this.#resolveActionTarget(useQueryTarget); + const target = await this.#resolveGatewayRequestTarget( + useQueryTarget, + gatewayOptions, + ); actorId = "directId" in target ? target.directId : undefined; logger().debug( @@ -558,6 +565,17 @@ export class ActorHandleRaw { } } + async #resolveGatewayRequestTarget( + useQueryTarget: boolean, + gatewayOptions: ActorGatewayOptions, + ) { + if (gatewayOptions.skipReadyWait) { + return getGatewayTarget(this.#actorResolutionState); + } + + return await this.#resolveActionTarget(useQueryTarget); + } + /** * Establishes a persistent connection to the actor. * @@ -616,7 +634,10 @@ export class ActorHandleRaw { for (let attempt = 0; attempt < maxAttempts; attempt++) { let actorId: string | undefined; try { - const target = await this.#resolveActionTarget(useQueryTarget); + const target = await this.#resolveGatewayRequestTarget( + useQueryTarget, + gatewayOptions, + ); actorId = "directId" in target ? target.directId : undefined; const response = await rawHttpFetch( this.#driver, @@ -836,9 +857,10 @@ export class ActorHandleRaw { this.#gatewayOptions, options, ); - const target = gatewayOptions.skipReadyWait - ? await this.#resolveActionTarget(false) - : getGatewayTarget(this.#actorResolutionState); + const target = await this.#resolveGatewayRequestTarget( + false, + gatewayOptions, + ); return await rawWebSocket( this.#driver, target, diff --git a/rivetkit-typescript/packages/rivetkit/tests/remote-engine-client-public-token.test.ts b/rivetkit-typescript/packages/rivetkit/tests/remote-engine-client-public-token.test.ts index 1f20d69e5a..deee0d059c 100644 --- a/rivetkit-typescript/packages/rivetkit/tests/remote-engine-client-public-token.test.ts +++ b/rivetkit-typescript/packages/rivetkit/tests/remote-engine-client-public-token.test.ts @@ -1,5 +1,6 @@ import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; import { ClientConfigSchema } from "@/client/config"; +import { createClient } from "@/client/mod"; import { HEADER_RIVET_ACTOR, HEADER_RIVET_SKIP_READY_WAIT, @@ -10,7 +11,6 @@ import { WS_PROTOCOL_TARGET, WS_PROTOCOL_TOKEN, } from "@/common/actor-router-consts"; -import { createClient } from "@/client/mod"; import { RemoteEngineControlClient } from "@/engine-client/mod"; describe.sequential("RemoteEngineControlClient public token usage", () => { @@ -162,6 +162,48 @@ describe.sequential("RemoteEngineControlClient public token usage", () => { ); }); + test("query handle fetch keeps skip ready wait on gateway URL", async () => { + const fetchCalls: Request[] = []; + const fetchMock = vi.fn(async (input: Request | URL | string) => { + const request = normalizeRequest(input); + fetchCalls.push(request); + return new Response("ok"); + }); + vi.stubGlobal("fetch", fetchMock); + + const client = createClient({ + endpoint: "https://api.rivet.dev", + disableMetadataLookup: true, + gateway: { skipReadyWait: true }, + }); + const handle = client.getOrCreate("mockAgenticLoop", [ + "query-http-skip-ready-wait", + ]); + + const response = await handle.fetch("/skip-ready-wait"); + + expect(response.status).toBe(200); + expect(fetchCalls).toHaveLength(1); + + const actorRequest = fetchCalls[0]; + expect(actorRequest).toBeDefined(); + if (!actorRequest) throw new Error("missing actor request"); + const url = new URL(actorRequest.url); + expect(url.pathname).toBe( + "/gateway/mockAgenticLoop/request/skip-ready-wait", + ); + expect(url.searchParams.get("rvt-method")).toBe("getOrCreate"); + expect(url.searchParams.get("rvt-key")).toBe( + "query-http-skip-ready-wait", + ); + expect(url.searchParams.get("rvt-skip-ready-wait")).toBe("true"); + expect(actorRequest?.headers.get(HEADER_RIVET_TARGET)).toBeNull(); + expect(actorRequest?.headers.get(HEADER_RIVET_ACTOR)).toBeNull(); + expect(actorRequest?.headers.get(HEADER_RIVET_SKIP_READY_WAIT)).toBe( + "1", + ); + }); + test("uses metadata clientToken for actor websocket gateway requests", async () => { const fetchMock = vi.fn(async (input: Request | URL | string) => { const request = normalizeRequest(input); @@ -258,6 +300,36 @@ describe.sequential("RemoteEngineControlClient public token usage", () => { WS_PROTOCOL_SKIP_READY_WAIT, ]), ); + + const client = createClient({ + endpoint: "https://api.rivet.dev", + disableMetadataLookup: true, + gateway: { skipReadyWait: true }, + }); + const handle = client.getOrCreate("mockAgenticLoop", [ + "query-ws-skip-ready-wait", + ]); + + await handle.webSocket("/skip-ready-wait"); + + expect(fetchMock).toHaveBeenCalledTimes(1); + expect(sockets).toHaveLength(4); + const querySocket = sockets[3]; + expect(querySocket).toBeDefined(); + if (!querySocket) throw new Error("missing query websocket"); + const url = new URL(querySocket.url); + expect(url.pathname).toBe( + "/gateway/mockAgenticLoop/websocket/skip-ready-wait", + ); + expect(url.searchParams.get("rvt-method")).toBe("getOrCreate"); + expect(url.searchParams.get("rvt-key")).toBe( + "query-ws-skip-ready-wait", + ); + expect(url.searchParams.get("rvt-skip-ready-wait")).toBe("true"); + expect(querySocket.protocols).toContain(WS_PROTOCOL_SKIP_READY_WAIT); + expect(querySocket.protocols).not.toContain( + `${WS_PROTOCOL_TARGET}actor`, + ); }); });