diff --git a/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java b/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java index 7164991f3..1e6267dde 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java @@ -36,13 +36,15 @@ import com.google.genai.types.FunctionCall; import com.google.genai.types.FunctionResponse; import com.google.genai.types.Part; -import java.util.HashMap; +import java.util.ArrayList; +import java.util.ConcurrentModificationException; +import java.util.Iterator; import java.util.List; +import java.util.ListIterator; import java.util.Map; import java.util.Objects; import java.util.Optional; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Stream; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -787,15 +789,49 @@ public void processRequest_notEmptyContent() { public void processRequest_concurrentReadAndWrite_noException() throws Exception { LlmAgent agent = LlmAgent.builder().name(AGENT).includeContents(LlmAgent.IncludeContents.DEFAULT).build(); + List customEvents = + new ArrayList() { + private void checkLock() { + if (!Thread.holdsLock(this)) { + throw new ConcurrentModificationException("Unsynchronized iteration detected!"); + } + } + + @Override + public Iterator iterator() { + checkLock(); + return super.iterator(); + } + + @Override + public ListIterator listIterator() { + checkLock(); + return super.listIterator(); + } + + @Override + public ListIterator listIterator(int index) { + checkLock(); + return super.listIterator(index); + } + + @Override + public Stream stream() { + checkLock(); + return super.stream(); + } + }; + Session session = - sessionService - .createSession("test-app", "test-user", new HashMap<>(), "test-session") - .blockingGet(); + Session.builder("test-session") + .appName("test-app") + .userId("test-user") + .events(customEvents) + .build(); - // Seed with dummy events to widen the race capability - for (int i = 0; i < 5000; i++) { - session.events().add(createUserEvent("dummy" + i, "dummy")); - } + // The list must have at least one element so that operations interacting with events trigger + // iteration. + customEvents.add(createUserEvent("dummy", "dummy")); InvocationContext context = InvocationContext.builder() @@ -807,37 +843,8 @@ public void processRequest_concurrentReadAndWrite_noException() throws Exception LlmRequest initialRequest = LlmRequest.builder().build(); - AtomicReference writerError = new AtomicReference<>(); - CountDownLatch startLatch = new CountDownLatch(1); - - Thread writerThread = - new Thread( - () -> { - startLatch.countDown(); - try { - for (int i = 0; i < 2000; i++) { - session.events().add(createUserEvent("writer" + i, "new data")); - } - } catch (Throwable t) { - writerError.set(t); - } - }); - - writerThread.start(); - startLatch.await(); // wait for writer to be ready - - // Process (read) requests concurrently to trigger race conditions - for (int i = 0; i < 200; i++) { - var unused = contentsProcessor.processRequest(context, initialRequest).blockingGet(); - if (writerError.get() != null) { - throw new RuntimeException("Writer failed", writerError.get()); - } - } - - writerThread.join(); - if (writerError.get() != null) { - throw new RuntimeException("Writer failed", writerError.get()); - } + // This single call will throw the exception if the list is accessed insecurely. + var unused = contentsProcessor.processRequest(context, initialRequest).blockingGet(); } private static Event createUserEvent(String id, String text) {