diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 51e1b8f25..1f7d924ab 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -313,6 +313,7 @@ private Single appendNewMessageToSession( throw new IllegalArgumentException("No parts in the new_message."); } + Completable saveArtifactsFlow = Completable.complete(); if (this.artifactService != null && saveInputBlobsAsArtifacts) { // The runner directly saves the artifacts (if applicable) in the user message and replaces // the artifact data with a file name placeholder. @@ -322,9 +323,11 @@ private Single appendNewMessageToSession( continue; } String fileName = "artifact_" + invocationContext.invocationId() + "_" + i; - var unused = - this.artifactService.saveArtifact( - this.appName, session.userId(), session.id(), fileName, part); + saveArtifactsFlow = + saveArtifactsFlow.andThen( + this.artifactService + .saveArtifact(this.appName, session.userId(), session.id(), fileName, part) + .ignoreElement()); newMessage .parts() @@ -349,7 +352,8 @@ private Single appendNewMessageToSession( EventActions.builder().stateDelta(new ConcurrentHashMap<>(stateDelta)).build()); } - return this.sessionService.appendEvent(session, eventBuilder.build()); + return saveArtifactsFlow.andThen( + this.sessionService.appendEvent(session, eventBuilder.build())); } /** See {@link #runAsync(String, String, Content, RunConfig, Map)}. */ diff --git a/core/src/test/java/com/google/adk/runner/RunnerTest.java b/core/src/test/java/com/google/adk/runner/RunnerTest.java index 2eb515fa2..a3e21cb73 100644 --- a/core/src/test/java/com/google/adk/runner/RunnerTest.java +++ b/core/src/test/java/com/google/adk/runner/RunnerTest.java @@ -24,6 +24,8 @@ import static com.google.adk.testing.TestUtils.createTextLlmResponse; import static com.google.adk.testing.TestUtils.simplifyEvents; import static com.google.common.truth.Truth.assertThat; +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.util.Arrays.stream; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.CALLS_REAL_METHODS; import static org.mockito.Mockito.mock; @@ -36,6 +38,7 @@ import com.google.adk.agents.LlmAgent; import com.google.adk.agents.RunConfig; import com.google.adk.apps.App; +import com.google.adk.artifacts.BaseArtifactService; import com.google.adk.events.Event; import com.google.adk.flows.llmflows.Functions; import com.google.adk.models.LlmResponse; @@ -65,12 +68,14 @@ import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.core.Single; import io.reactivex.rxjava3.subscribers.TestSubscriber; import java.util.List; import java.util.Objects; import java.util.Optional; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; import org.junit.After; import org.junit.Before; import org.junit.Rule; @@ -78,6 +83,7 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; @RunWith(JUnit4.class) public final class RunnerTest { @@ -849,6 +855,19 @@ private Content createContent(String text) { return Content.builder().parts(Part.builder().text(text).build()).build(); } + private static Content createInlineDataContent(byte[]... data) { + return Content.builder() + .parts( + stream(data) + .map(dataBytes -> Part.fromBytes(dataBytes, "example/octet-stream")) + .toArray(Part[]::new)) + .build(); + } + + private static Content createInlineDataContent(String... data) { + return createInlineDataContent(stream(data).map(d -> d.getBytes(UTF_8)).toArray(byte[][]::new)); + } + @Test public void runAsync_createsInvocationSpan() { var unused = @@ -1331,4 +1350,40 @@ public static ImmutableMap echoTool(String message) { return ImmutableMap.of("message", message); } } + + @Test + public void runner_executesSaveArtifactFlow() { + // arrange + final AtomicInteger artifactsSavedCounter = new AtomicInteger(); + BaseArtifactService mockArtifactService = Mockito.mock(BaseArtifactService.class); + when(mockArtifactService.saveArtifact(any(), any(), any(), any(), any())) + .thenReturn( + Single.defer( + () -> { + // we want to assert not only that the saveArtifact method was + // called, but also that the flow that it returned was run, so + // we need to record the call in a counter + artifactsSavedCounter.incrementAndGet(); + return Single.just(42); + })); + Runner runner = + Runner.builder() + .app(App.builder().name("test").rootAgent(agent).build()) + .artifactService(mockArtifactService) + .build(); + session = runner.sessionService().createSession("test", "user").blockingGet(); + // each inline data will be saved using our mock artifact service + Content content = createInlineDataContent("test data", "test data 2"); + RunConfig runConfig = RunConfig.builder().setSaveInputBlobsAsArtifacts(true).build(); + + // act + var events = runner.runAsync("user", session.id(), content, runConfig).test(); + + // assert + events.assertComplete(); + // artifacts were saved + assertThat(artifactsSavedCounter.get()).isEqualTo(2); + // agent was run + assertThat(simplifyEvents(events.values())).containsExactly("test agent: from llm"); + } }