diff --git a/core/src/main/java/com/google/adk/models/ApigeeLlm.java b/core/src/main/java/com/google/adk/models/ApigeeLlm.java index 088e0af76..c2ad27b01 100644 --- a/core/src/main/java/com/google/adk/models/ApigeeLlm.java +++ b/core/src/main/java/com/google/adk/models/ApigeeLlm.java @@ -19,7 +19,9 @@ import static com.google.common.base.Strings.isNullOrEmpty; import com.google.adk.Version; +import com.google.adk.models.chat.ChatCompletionsHttpClient; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableMap; import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.genai.Client; @@ -28,6 +30,8 @@ import java.util.HashMap; import java.util.Map; import java.util.Objects; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * A {@link BaseLlm} implementation for calling an Apigee proxy. @@ -36,6 +40,7 @@ * allows for specifying the provider (Gemini or Vertex AI), API version, and model ID. */ public class ApigeeLlm extends BaseLlm { + private static final Logger logger = LoggerFactory.getLogger(ApigeeLlm.class); private static final String GOOGLE_GENAI_USE_VERTEXAI_ENV_VARIABLE_NAME = "GOOGLE_GENAI_USE_VERTEXAI"; private static final String APIGEE_PROXY_URL_ENV_VARIABLE_NAME = "APIGEE_PROXY_URL"; @@ -51,9 +56,18 @@ public class ApigeeLlm extends BaseLlm { "user-agent", versionHeaderValue); } + /** Defines the type of API to be used by the Apigee proxy. */ + public enum ApiType { + UNKNOWN, + CHAT_COMPLETIONS, + GENAI + } + private final Gemini geminiDelegate; + private final ChatCompletionsHttpClient chatCompletionsHttpClient; private final Client apiClient; private final HttpOptions httpOptions; + private final ApiType apiType; /** * Constructs a new ApigeeLlm instance. @@ -62,7 +76,8 @@ public class ApigeeLlm extends BaseLlm { * @param proxyUrl The URL of the Apigee proxy. * @param customHeaders A map of custom headers to be sent with the request. */ - private ApigeeLlm(String modelName, String proxyUrl, Map customHeaders) { + private ApigeeLlm( + String modelName, String proxyUrl, Map customHeaders, ApiType apiType) { super(modelName); if (!validateModelString(modelName)) { @@ -71,6 +86,16 @@ private ApigeeLlm(String modelName, String proxyUrl, Map customH + modelName); } + if (apiType == ApiType.UNKNOWN) { + if (modelName.startsWith("apigee/openai/")) { + this.apiType = ApiType.CHAT_COMPLETIONS; + } else { + this.apiType = ApiType.GENAI; + } + } else { + this.apiType = apiType; + } + String effectiveProxyUrl = proxyUrl; if (isNullOrEmpty(effectiveProxyUrl)) { effectiveProxyUrl = System.getenv(APIGEE_PROXY_URL_ENV_VARIABLE_NAME); @@ -96,13 +121,26 @@ private ApigeeLlm(String modelName, String proxyUrl, Map customH .buildOrThrow()); } this.httpOptions = httpOptionsBuilder.build(); - Client.Builder apiClientBuilder = Client.builder().httpOptions(this.httpOptions); - if (isVertexAiModel(modelName)) { - apiClientBuilder.vertexAI(true); + + if (this.apiType == ApiType.CHAT_COMPLETIONS) { + this.apiClient = null; + this.geminiDelegate = null; + this.chatCompletionsHttpClient = new ChatCompletionsHttpClient(this.httpOptions); + } else { + Client.Builder apiClientBuilder = Client.builder().httpOptions(this.httpOptions); + if (isVertexAiModel(modelName)) { + apiClientBuilder.vertexAI(true); + } + this.apiClient = apiClientBuilder.build(); + this.geminiDelegate = new Gemini(modelName, apiClient); + this.chatCompletionsHttpClient = null; } - this.apiClient = apiClientBuilder.build(); - this.geminiDelegate = new Gemini(modelName, apiClient); + logger.trace( + "ApigeeLlm constructed: modelName={} apiType={} effectiveProxyUrl={}", + modelName, + this.apiType, + effectiveProxyUrl); } /** @@ -113,10 +151,31 @@ private ApigeeLlm(String modelName, String proxyUrl, Map customH */ @VisibleForTesting ApigeeLlm(String modelName, Gemini geminiDelegate) { + this(modelName, geminiDelegate, null); + } + + /** + * Constructs a new ApigeeLlm instance for testing purposes. + * + * @param modelName The name of the Apigee model to use. + * @param geminiDelegate The Gemini delegate to use for making API calls. + * @param chatCompletionsHttpClient The ChatCompletionsHttpClient to use for making API calls. + */ + @VisibleForTesting + ApigeeLlm( + String modelName, + Gemini geminiDelegate, + ChatCompletionsHttpClient chatCompletionsHttpClient) { super(modelName); this.apiClient = null; this.httpOptions = null; this.geminiDelegate = geminiDelegate; + this.chatCompletionsHttpClient = chatCompletionsHttpClient; + if (chatCompletionsHttpClient != null) { + this.apiType = ApiType.CHAT_COMPLETIONS; + } else { + this.apiType = ApiType.GENAI; + } } /** @@ -178,6 +237,7 @@ public static class Builder { private String modelName; private String proxyUrl; private Map customHeaders = new HashMap<>(); + private ApiType apiType = ApiType.UNKNOWN; protected Builder() {} @@ -243,6 +303,19 @@ public Builder customHeaders(Map customHeaders) { return this; } + /** + * Sets the explicit {@link ApiType} to use (e.g., CHAT_COMPLETIONS or GENAI). + * + * @param apiType the type of API. + * @return this builder. + * @throws NullPointerException if {@code apiType} is null. + */ + @CanIgnoreReturnValue + public Builder apiType(ApiType apiType) { + this.apiType = Preconditions.checkNotNull(apiType); + return this; + } + /** * Builds the {@link ApigeeLlm} instance. * @@ -255,7 +328,7 @@ public ApigeeLlm build() { throw new IllegalArgumentException("Invalid model string: " + modelName); } - return new ApigeeLlm(modelName, proxyUrl, customHeaders); + return new ApigeeLlm(modelName, proxyUrl, customHeaders, apiType); } } @@ -264,11 +337,23 @@ public Flowable generateContent(LlmRequest llmRequest, boolean stre String modelToUse = llmRequest.model().orElse(model()); String modelId = getModelId(modelToUse); LlmRequest newLlmRequest = llmRequest.toBuilder().model(modelId).build(); + + logger.debug("ApigeeLlm.generateContent routing through {} for model {}", apiType, modelId); + + if (apiType == ApiType.CHAT_COMPLETIONS) { + return chatCompletionsHttpClient.complete(newLlmRequest, stream); + } + return geminiDelegate.generateContent(newLlmRequest, stream); } @Override public BaseLlmConnection connect(LlmRequest llmRequest) { + if (apiType == ApiType.CHAT_COMPLETIONS) { + throw new UnsupportedOperationException( + "Streaming connections are not supported for chat completions."); + } + String modelToUse = llmRequest.model().orElse(model()); String modelId = getModelId(modelToUse); LlmRequest newLlmRequest = llmRequest.toBuilder().model(modelId).build(); @@ -297,7 +382,9 @@ private static boolean validateModelString(String model) { return components[1].startsWith("v"); } if (components.length == 2) { - if (components[0].equals("vertex_ai") || components[0].equals("gemini")) { + if (components[0].equals("vertex_ai") + || components[0].equals("gemini") + || components[0].equals("openai")) { return true; } return components[0].startsWith("v"); diff --git a/core/src/test/java/com/google/adk/models/ApigeeLlmTest.java b/core/src/test/java/com/google/adk/models/ApigeeLlmTest.java index 6ba2832c0..fce26e812 100644 --- a/core/src/test/java/com/google/adk/models/ApigeeLlmTest.java +++ b/core/src/test/java/com/google/adk/models/ApigeeLlmTest.java @@ -18,6 +18,7 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertThrows; import static org.junit.Assume.assumeNotNull; +import static org.junit.Assume.assumeTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.eq; @@ -25,6 +26,8 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import com.google.adk.models.ApigeeLlm.ApiType; +import com.google.adk.models.chat.ChatCompletionsHttpClient; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.genai.types.Content; @@ -47,6 +50,7 @@ public class ApigeeLlmTest { @Rule public final MockitoRule mocks = MockitoJUnit.rule(); @Mock private Gemini mockGeminiDelegate; + @Mock private ChatCompletionsHttpClient mockCcClient; private static final String PROXY_URL = "https://test.apigee.net"; @@ -62,7 +66,8 @@ public void build_withValidModelStrings_succeeds() { "apigee/v1/whatever-model", "apigee/vertex_ai/whatever-model", "apigee/gemini/v1/whatever-model", - "apigee/vertex_ai/v1beta/whatever-model" + "apigee/vertex_ai/v1beta/whatever-model", + "apigee/openai/gpt-4" }; for (String modelName : validModelStrings) { @@ -84,9 +89,10 @@ public void build_withInvalidModelStrings_throwsException() { }; for (String modelName : invalidModelStrings) { - assertThrows( - IllegalArgumentException.class, - () -> ApigeeLlm.builder().modelName(modelName).proxyUrl(PROXY_URL).build()); + ApigeeLlm.Builder builder = ApigeeLlm.builder().modelName(modelName).proxyUrl(PROXY_URL); + IllegalArgumentException e = + assertThrows(IllegalArgumentException.class, () -> builder.build()); + assertThat(e).hasMessageThat().contains("Invalid model string: " + modelName); } } @@ -108,6 +114,47 @@ public void generateContent_stripsApigeePrefixAndSendsToDelegate() { assertThat(requestCaptor.getValue().model()).hasValue("whatever-model"); } + @Test + public void generateContent_withChatCompletionsApiType_sendsToCcClient() { + when(mockCcClient.complete(any(), anyBoolean())).thenReturn(Flowable.empty()); + + ApigeeLlm llm = new ApigeeLlm("apigee/openai/gpt-4", mockGeminiDelegate, mockCcClient); + + LlmRequest request = + LlmRequest.builder() + .model("apigee/openai/gpt-4") + .contents(ImmutableList.of(Content.builder().parts(Part.fromText("hello")).build())) + .build(); + llm.generateContent(request, false).test().assertNoErrors(); + + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(LlmRequest.class); + verify(mockCcClient).complete(requestCaptor.capture(), eq(false)); + verify(mockGeminiDelegate, never()).generateContent(any(), anyBoolean()); + assertThat(requestCaptor.getValue().model()).hasValue("gpt-4"); + } + + @Test + public void connect_withChatCompletionsApiType_throwsUnsupportedOperationException() { + ApigeeLlm llm = new ApigeeLlm("apigee/openai/gpt-4", mockGeminiDelegate, mockCcClient); + LlmRequest request = LlmRequest.builder().model("apigee/openai/gpt-4").build(); + UnsupportedOperationException e = + assertThrows(UnsupportedOperationException.class, () -> llm.connect(request)); + assertThat(e) + .hasMessageThat() + .contains("Streaming connections are not supported for chat completions."); + } + + @Test + public void build_withExplicitChatCompletionsApiType_success() { + ApigeeLlm llm = + ApigeeLlm.builder() + .modelName("apigee/whatever-model") + .proxyUrl(PROXY_URL) + .apiType(ApiType.CHAT_COMPLETIONS) + .build(); + assertThat(llm).isNotNull(); + } + // Add a test to verify the vertexAI flag is set correctly. @Test public void generateContent_setsVertexAiFlagCorrectly_withVertexAi() { @@ -187,23 +234,163 @@ public void build_withTrailingSlashInModel_parsesVersionAndModelId() { LlmRequest.builder() .contents(ImmutableList.of(Content.builder().parts(Part.fromText("hi")).build())) .build(); - assertThrows(IllegalArgumentException.class, () -> llm.generateContent(request, false)); + IllegalArgumentException e = + assertThrows(IllegalArgumentException.class, () -> llm.generateContent(request, false)); + assertThat(e) + .hasMessageThat() + .contains( + "Invalid model string, expected apigee/[/][/]: " + + "apigee/gemini/v1/"); verify(mockGeminiDelegate, never()).generateContent(any(), anyBoolean()); } @Test - public void build_withoutProxyUrl_readsFromEnvironment() { + public void build_withoutProxyUrlAndEnvVarSet_readsFromEnvironment() { + assumeNotNull(System.getenv("APIGEE_PROXY_URL")); String envProxyUrl = System.getenv("APIGEE_PROXY_URL"); - if (envProxyUrl != null) { - ApigeeLlm llm = ApigeeLlm.builder().modelName("apigee/whatever-model").build(); - assertThat(llm.getHttpOptions().baseUrl()).hasValue(envProxyUrl); - } else { - assertThrows( - IllegalArgumentException.class, - () -> ApigeeLlm.builder().modelName("apigee/whatever-model").build()); - ApigeeLlm llm = - ApigeeLlm.builder().proxyUrl(PROXY_URL).modelName("apigee/whatever-model").build(); - assertThat(llm.getHttpOptions().baseUrl()).hasValue(PROXY_URL); + ApigeeLlm llm = ApigeeLlm.builder().modelName("apigee/whatever-model").build(); + assertThat(llm.getHttpOptions().baseUrl()).hasValue(envProxyUrl); + } + + @Test + public void build_withoutProxyUrlAndEnvVarNotSet_throwsException() { + assumeTrue(System.getenv("APIGEE_PROXY_URL") == null); + ApigeeLlm.Builder builder = ApigeeLlm.builder().modelName("apigee/whatever-model"); + IllegalArgumentException e = + assertThrows(IllegalArgumentException.class, () -> builder.build()); + assertThat(e) + .hasMessageThat() + .contains( + "Apigee proxy URL is not set and not found in the environment variable" + + " APIGEE_PROXY_URL."); + } + + @Test + public void build_withProxyUrl_usesProvidedUrl() { + ApigeeLlm llm = + ApigeeLlm.builder().proxyUrl(PROXY_URL).modelName("apigee/whatever-model").build(); + assertThat(llm.getHttpOptions().baseUrl()).hasValue(PROXY_URL); + } + + @Test + public void generateContent_withChatCompletionsApiType_sendsToCcClient_streaming() { + when(mockCcClient.complete(any(), anyBoolean())).thenReturn(Flowable.empty()); + + ApigeeLlm llm = new ApigeeLlm("apigee/openai/gpt-4o", mockGeminiDelegate, mockCcClient); + LlmRequest request = + LlmRequest.builder() + .model("apigee/openai/gpt-4o") + .contents(ImmutableList.of(Content.builder().parts(Part.fromText("hello")).build())) + .build(); + llm.generateContent(request, true).test().assertNoErrors(); + + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(LlmRequest.class); + verify(mockCcClient).complete(requestCaptor.capture(), eq(true)); + verify(mockGeminiDelegate, never()).generateContent(any(), anyBoolean()); + assertThat(requestCaptor.getValue().model()).hasValue("gpt-4o"); + } + + @Test + public void generateContent_requestLevelModelOverride_extractedCorrectly() { + when(mockCcClient.complete(any(), anyBoolean())).thenReturn(Flowable.empty()); + + ApigeeLlm llm = new ApigeeLlm("apigee/openai/gpt-4o", mockGeminiDelegate, mockCcClient); + LlmRequest request = + LlmRequest.builder() + .model("apigee/openai/gpt-3.5-turbo") + .contents(ImmutableList.of(Content.builder().parts(Part.fromText("hello")).build())) + .build(); + + llm.generateContent(request, false).test().assertNoErrors(); + + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(LlmRequest.class); + verify(mockCcClient).complete(requestCaptor.capture(), eq(false)); + assertThat(requestCaptor.getValue().model()).hasValue("gpt-3.5-turbo"); + } + + @Test + public void validateModelString_rejectsOpenAiWithVersion() { + // 3-component model string (e.g. apigee/openai/v1/gpt-4o) fails because "openai" != "vertex_ai" + // and != "gemini" + ApigeeLlm.Builder builder = + ApigeeLlm.builder().modelName("apigee/openai/v1/gpt-4o").proxyUrl(PROXY_URL); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, builder::build); + assertThat(e).hasMessageThat().contains("Invalid model string: apigee/openai/v1/gpt-4o"); + } + + @Test + public void build_withCustomHeadersOverlappingTrackingHeaders_throwsException() { + ImmutableMap overlappingHeaders = ImmutableMap.of("user-agent", "custom-agent"); + ApigeeLlm.Builder builder = + ApigeeLlm.builder() + .modelName("apigee/openai/gpt-4o") + .proxyUrl(PROXY_URL) + .customHeaders(overlappingHeaders); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, builder::build); + assertThat(e).hasMessageThat().contains("Multiple entries with same key: user-agent="); + } + + @Test + public void build_withNullApiType_throwsNullPointerException() { + ApigeeLlm.Builder builder = + ApigeeLlm.builder().modelName("apigee/whatever-model").proxyUrl(PROXY_URL); + assertThrows(NullPointerException.class, () -> builder.apiType(null)); + } + + @Test + public void generateContent_crossApiTypeRequestOverride_routesBasedOnOriginalApiType() { + when(mockGeminiDelegate.generateContent(any(), anyBoolean())).thenReturn(Flowable.empty()); + + // Original ApiType is GENAI implicitly + ApigeeLlm llm = new ApigeeLlm("apigee/gemini/gemini-pro", mockGeminiDelegate); + + // Override specifies openai models + LlmRequest request = + LlmRequest.builder() + .model("apigee/openai/gpt-4o") + .contents(ImmutableList.of(Content.builder().parts(Part.fromText("hello")).build())) + .build(); + + llm.generateContent(request, false).test().assertNoErrors(); + + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(LlmRequest.class); + verify(mockGeminiDelegate).generateContent(requestCaptor.capture(), eq(false)); + // It still goes to gemini delegate, but model is stripped + assertThat(requestCaptor.getValue().model()).hasValue("gpt-4o"); + } + + @Test + public void generateContent_invalidRequestLevelOverride_throwsException() { + ApigeeLlm llm = new ApigeeLlm("apigee/openai/gpt-4o", mockGeminiDelegate, mockCcClient); + + LlmRequest request = + LlmRequest.builder() + .model("invalid-no-apigee-prefix") + .contents(ImmutableList.of(Content.builder().parts(Part.fromText("hello")).build())) + .build(); + + IllegalArgumentException e = + assertThrows(IllegalArgumentException.class, () -> llm.generateContent(request, false)); + assertThat(e) + .hasMessageThat() + .contains( + "Invalid model string, expected apigee/[/][/]: " + + "invalid-no-apigee-prefix"); + } + + @Test + public void build_nullModelName_throwsNullPointerException() { + ApigeeLlm.Builder builder = ApigeeLlm.builder().modelName(null).proxyUrl(PROXY_URL); + assertThrows(NullPointerException.class, builder::build); + } + + @Test + public void build_malformedModelsWithTrailingSlashes_throwsException() { + String[] malformedModels = {"apigee/openai/", "apigee/openai/gpt-4o/"}; + for (String modelName : malformedModels) { + ApigeeLlm.Builder builder = ApigeeLlm.builder().modelName(modelName).proxyUrl(PROXY_URL); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, builder::build); + assertThat(e).hasMessageThat().contains("Invalid model string: " + modelName); } } } diff --git a/core/src/test/java/com/google/adk/models/chat/ChatCompletionsRequestTest.java b/core/src/test/java/com/google/adk/models/chat/ChatCompletionsRequestTest.java index eb00b3770..ec4246f90 100644 --- a/core/src/test/java/com/google/adk/models/chat/ChatCompletionsRequestTest.java +++ b/core/src/test/java/com/google/adk/models/chat/ChatCompletionsRequestTest.java @@ -710,10 +710,8 @@ private static void assertThoughtSignatureExtraContent( assertThat(extraContent).containsKey("google"); @SuppressWarnings("unchecked") // This code won't run in production and it is a JSON object. Map google = (Map) extraContent.get("google"); - assertThat(google).containsKey("thought_signature"); - Object sigObj = google.get("thought_signature"); - assertThat(sigObj).isInstanceOf(String.class); - assertThat(Base64.getDecoder().decode((String) sigObj)).isEqualTo(expected); + String expectedB64 = Base64.getEncoder().encodeToString(expected); + assertThat(google).containsEntry("thought_signature", expectedB64); } @Test