diff --git a/core/src/main/java/com/google/adk/sessions/VertexAiClient.java b/core/src/main/java/com/google/adk/sessions/VertexAiClient.java index 2d8a13f2f..10e6524d8 100644 --- a/core/src/main/java/com/google/adk/sessions/VertexAiClient.java +++ b/core/src/main/java/com/google/adk/sessions/VertexAiClient.java @@ -8,6 +8,7 @@ import com.google.auth.oauth2.GoogleCredentials; import com.google.common.base.Splitter; import com.google.common.collect.Iterables; +import com.google.common.net.UrlEscapers; import com.google.genai.types.HttpOptions; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Maybe; @@ -111,11 +112,12 @@ Maybe listSessions(String reasoningEngineId, String userId) { .flatMapMaybe(VertexAiClient::getJsonResponse); } - Maybe listEvents(String reasoningEngineId, String sessionId) { - return performApiRequest( - "GET", - "reasoningEngines/" + reasoningEngineId + "/sessions/" + sessionId + "/events", - "") + Maybe listEvents(String reasoningEngineId, String sessionId, @Nullable String filter) { + String path = "reasoningEngines/" + reasoningEngineId + "/sessions/" + sessionId + "/events"; + if (filter != null) { + path += "?filter=" + UrlEscapers.urlFormParameterEscaper().escape(filter); + } + return performApiRequest("GET", path, "") .doOnSuccess(apiResponse -> logger.debug("List events response {}", apiResponse)) .flatMapMaybe(VertexAiClient::getJsonResponse); } diff --git a/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java b/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java index 279ecfa9f..a7bb8400c 100644 --- a/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java +++ b/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java @@ -164,9 +164,14 @@ private ListSessionsResponse parseListSessionsResponse( @Override public Single listEvents(String appName, String userId, String sessionId) { + return listEventsInternal(appName, sessionId, /* filter= */ null); + } + + private Single listEventsInternal( + String appName, String sessionId, @Nullable String filter) { String reasoningEngineId = parseReasoningEngineId(appName); return client - .listEvents(reasoningEngineId, sessionId) + .listEvents(reasoningEngineId, sessionId, filter) .map(this::parseListEventsResponse) .defaultIfEmpty(ListEventsResponse.builder().build()); } @@ -212,7 +217,7 @@ public Maybe getSession( new TypeReference>() {})); } - return listEvents(appName, userId, sessionId) + return listEventsInternal(appName, sessionId, afterTimestampFilter(config)) .map( response -> { Session.Builder sessionBuilder = @@ -232,48 +237,41 @@ public Maybe getSession( }); } + /** + * Builds the server-side events filter for {@code afterTimestamp}, mirroring the Python and Go + * implementations (inclusive {@code timestamp>=}). The filter is only applied when {@code + * numRecentEvents} is not set, matching the precedence in {@link #filterEvents}. + */ + private static @Nullable String afterTimestampFilter(Optional config) { + if (config.isPresent() + && config.get().numRecentEvents().isEmpty() + && config.get().afterTimestamp().isPresent()) { + return "timestamp>=\"" + config.get().afterTimestamp().get() + "\""; + } + return null; + } + private static List filterEvents( List originalEvents, Optional config) { // Preserve the full event stream that Vertex AI returns. Event timestamps are // assigned client-side while updateTime is assigned server-side, so filtering // on updateTime could silently drop the most recently appended event(s). + // afterTimestamp is filtered server-side (see afterTimestampFilter), so only + // numRecentEvents is applied here. List events = originalEvents.stream() .sorted(Comparator.comparingLong(Event::timestamp)) .collect(toCollection(ArrayList::new)); - if (config.isPresent()) { - if (config.get().numRecentEvents().isPresent()) { - int numRecentEvents = config.get().numRecentEvents().get(); - if (events.size() > numRecentEvents) { - events = events.subList(events.size() - numRecentEvents, events.size()); - } - } else if (config.get().afterTimestamp().isPresent()) { - long afterTimestampMillis = config.get().afterTimestamp().get().toEpochMilli(); - events = events.subList(firstIndexAtOrAfter(events, afterTimestampMillis), events.size()); + if (config.isPresent() && config.get().numRecentEvents().isPresent()) { + int numRecentEvents = config.get().numRecentEvents().get(); + if (events.size() > numRecentEvents) { + events = events.subList(events.size() - numRecentEvents, events.size()); } } return events; } - /** - * Returns the index of the first event whose timestamp is at or after {@code timestampMillis}, or - * the list size if there is none. {@code sortedEvents} must be sorted ascending by timestamp. - */ - private static int firstIndexAtOrAfter(List sortedEvents, long timestampMillis) { - int low = 0; - int high = sortedEvents.size(); - while (low < high) { - int mid = (low + high) >>> 1; - if (sortedEvents.get(mid).timestamp() < timestampMillis) { - low = mid + 1; - } else { - high = mid; - } - } - return low; - } - @Override public Completable deleteSession(String appName, String userId, String sessionId) { String reasoningEngineId = parseReasoningEngineId(appName); diff --git a/core/src/test/java/com/google/adk/sessions/MockApiAnswer.java b/core/src/test/java/com/google/adk/sessions/MockApiAnswer.java index 743bcee8d..c6d43c4f5 100644 --- a/core/src/test/java/com/google/adk/sessions/MockApiAnswer.java +++ b/core/src/test/java/com/google/adk/sessions/MockApiAnswer.java @@ -4,6 +4,9 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.google.adk.JsonBaseModel; import com.google.adk.events.Event; +import java.net.URLDecoder; +import java.nio.charset.StandardCharsets; +import java.time.Instant; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -30,7 +33,8 @@ class MockApiAnswer implements Answer { private static final Pattern APPEND_EVENT_REGEX = Pattern.compile("^reasoningEngines/([^/]+)/sessions/([^/]+):appendEvent$"); private static final Pattern EVENTS_REGEX = - Pattern.compile("^reasoningEngines/([^/]+)/sessions/([^/]+)/events$"); + Pattern.compile("^reasoningEngines/([^/]+)/sessions/([^/]+)/events(?:\\?filter=(.*))?$"); + private static final Pattern TIMESTAMP_FILTER_REGEX = Pattern.compile("timestamp>=\"(.*)\""); private static final MediaType JSON_MEDIA_TYPE = MediaType.parse("application/json; charset=utf-8"); @@ -200,8 +204,16 @@ private ApiResponse handleGetEvents(String path) throws Exception { return null; } String sessionId = matcher.group(2); + // The client URL-escapes the filter value; decode it as the real server would. + String filter = + matcher.group(3) == null + ? null + : URLDecoder.decode(matcher.group(3), StandardCharsets.UTF_8); String eventData = eventMap.get(sessionId); if (eventData != null) { + if (filter != null) { + eventData = applyTimestampFilter(eventData, filter); + } return responseWithBody( String.format( """ @@ -216,6 +228,25 @@ private ApiResponse handleGetEvents(String path) throws Exception { } } + /** Emulates the server-side inclusive {@code timestamp>=} filter on the events list. */ + private static String applyTimestampFilter(String eventData, String filter) throws Exception { + Matcher filterMatcher = TIMESTAMP_FILTER_REGEX.matcher(filter); + if (!filterMatcher.matches()) { + return eventData; + } + Instant threshold = Instant.parse(filterMatcher.group(1)); + List> events = + mapper.readValue(eventData, new TypeReference>>() {}); + List> kept = new ArrayList<>(); + for (Map event : events) { + Instant timestamp = Instant.parse((String) event.get("timestamp")); + if (!timestamp.isBefore(threshold)) { + kept.add(event); + } + } + return mapper.writeValueAsString(kept); + } + private ApiResponse handleGetLro(String path) { return responseWithBody( String.format( diff --git a/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java b/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java index a03529834..f386a4521 100644 --- a/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java +++ b/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java @@ -4,6 +4,9 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertThrows; import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import com.fasterxml.jackson.core.type.TypeReference; @@ -29,6 +32,7 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; @@ -444,6 +448,31 @@ public void getSession_afterTimestampBetweenEvents_dropsEventsBeforeThreshold() assertThat(session.events().stream().map(Event::id)).containsExactly("e3"); } + @Test + public void getSession_afterTimestampConfig_urlEscapesFilterInRequest() { + sessionMap.put("9", mockSessionJson("9", "2024-12-12T12:00:30.000000Z")); + eventMap.put("9", mockEventsJson(mockEventJson("e1", "2024-12-12T12:00:15.000000Z"))); + GetSessionConfig config = + GetSessionConfig.builder() + .afterTimestamp(Instant.parse("2024-12-12T12:00:10.000000Z")) + .build(); + + Object unused = + vertexAiSessionService.getSession("123", "user", "9", Optional.of(config)).blockingGet(); + + ArgumentCaptor pathCaptor = ArgumentCaptor.forClass(String.class); + verify(mockApiClient, atLeastOnce()).request(eq("GET"), pathCaptor.capture(), eq("")); + String eventsPath = + pathCaptor.getAllValues().stream() + .filter(path -> path.contains("/events")) + .findFirst() + .orElseThrow(() -> new AssertionError("No list-events request was made")); + // The filter operator and quotes are URL-escaped (>= -> %3E%3D, " -> %22), + // not sent raw. + assertThat(eventsPath).contains("filter=timestamp%3E%3D%22"); + assertThat(eventsPath).doesNotContain("timestamp>="); + } + @Test public void getSession_numRecentEventsConfig_returnsMostRecentEvents() { sessionMap.put("7", mockSessionJson("7", "2024-12-12T12:00:30.000000Z"));