diff --git a/.changes/behavior-attention-channel-bottle.json b/.changes/behavior-attention-channel-bottle.json new file mode 100644 index 00000000..30149f64 --- /dev/null +++ b/.changes/behavior-attention-channel-bottle.json @@ -0,0 +1 @@ +{"type":"MINOR","changes":["Add RequestOptions; configuration points for backend implementation details such as api version and timeout."]} diff --git a/.changes/condition-company-cloth-distribution.json b/.changes/condition-company-cloth-distribution.json new file mode 100644 index 00000000..4aeafa0b --- /dev/null +++ b/.changes/condition-company-cloth-distribution.json @@ -0,0 +1 @@ +{"type":"MAJOR","changes":["Support a general model naming schema"]} diff --git a/generativeai/build.gradle.kts b/generativeai/build.gradle.kts index d78d3079..39d5f72c 100644 --- a/generativeai/build.gradle.kts +++ b/generativeai/build.gradle.kts @@ -83,7 +83,7 @@ dependencies { implementation("org.jetbrains.kotlinx:kotlinx-serialization-json:1.5.1") implementation("androidx.core:core-ktx:1.12.0") - implementation("org.slf4j:slf4j-android:1.7.36") + implementation("org.slf4j:slf4j-nop:2.0.9") implementation("org.jetbrains.kotlinx:kotlinx-coroutines-android:1.7.3") implementation("org.jetbrains.kotlinx:kotlinx-coroutines-reactive:1.7.3") implementation("org.reactivestreams:reactive-streams:1.0.3") @@ -94,6 +94,7 @@ dependencies { testImplementation("junit:junit:4.13.2") testImplementation("io.kotest:kotest-assertions-core:4.0.7") testImplementation("io.kotest:kotest-assertions-jvm:4.0.7") + testImplementation("io.kotest:kotest-assertions-json:4.0.7") testImplementation("io.ktor:ktor-client-mock:$ktorVersion") androidTestImplementation("androidx.test.ext:junit:1.1.5") androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1") diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt index 15064aa7..94f4f9c5 100644 --- 
a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt @@ -29,6 +29,7 @@ import com.google.ai.client.generativeai.type.GenerateContentResponse import com.google.ai.client.generativeai.type.GenerationConfig import com.google.ai.client.generativeai.type.GoogleGenerativeAIException import com.google.ai.client.generativeai.type.PromptBlockedException +import com.google.ai.client.generativeai.type.RequestOptions import com.google.ai.client.generativeai.type.ResponseStoppedException import com.google.ai.client.generativeai.type.SafetySetting import com.google.ai.client.generativeai.type.SerializationException @@ -45,6 +46,7 @@ import kotlinx.coroutines.flow.map * @property generationConfig configuration parameters to use for content generation * @property safetySettings the safety bounds to use during alongside prompts during content * generation + * @property requestOptions configuration options to utilize during backend communication */ class GenerativeModel internal constructor( @@ -52,6 +54,7 @@ internal constructor( val apiKey: String, val generationConfig: GenerationConfig? = null, val safetySettings: List<SafetySetting>? = null, + val requestOptions: RequestOptions = RequestOptions(), private val controller: APIController ) { @@ -61,7 +64,15 @@ internal constructor( apiKey: String, generationConfig: GenerationConfig? = null, safetySettings: List<SafetySetting>? = null, - ) : this(modelName, apiKey, generationConfig, safetySettings, APIController(apiKey, modelName)) + requestOptions: RequestOptions = RequestOptions(), + ) : this( + modelName, + apiKey, + generationConfig, + safetySettings, + requestOptions, + APIController(apiKey, modelName, requestOptions.apiVersion, requestOptions.timeout) + ) /** * Generates a response from the backend with the provided [Content]s.
diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/APIController.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/APIController.kt index e46016ca..fd750111 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/APIController.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/APIController.kt @@ -37,14 +37,15 @@ import io.ktor.http.ContentType import io.ktor.http.HttpStatusCode import io.ktor.http.contentType import io.ktor.serialization.kotlinx.json.json +import kotlin.time.Duration import kotlinx.coroutines.CoroutineName import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.channelFlow +import kotlinx.coroutines.flow.timeout import kotlinx.coroutines.launch import kotlinx.serialization.json.Json -// TODO: Should these stay here or be moved elsewhere? -internal const val DOMAIN = "https://generativelanguage.googleapis.com/v1" +internal const val DOMAIN = "https://generativelanguage.googleapis.com" internal val JSON = Json { ignoreUnknownKeys = true @@ -60,42 +61,46 @@ internal val JSON = Json { * Exposed primarily for DI in tests. * @property key The API key used for authentication. * @property model The model to use for generation. + * @property apiVersion the endpoint version to communicate with. + * @property timeout the maximum amount of time for a request to take in the initial exchange. 
*/ internal class APIController( private val key: String, model: String, - httpEngine: HttpClientEngine = OkHttp.create() + private val apiVersion: String, + private val timeout: Duration, + httpEngine: HttpClientEngine = OkHttp.create(), ) { private val model = fullModelName(model) private val client = HttpClient(httpEngine) { install(HttpTimeout) { - requestTimeoutMillis = HttpTimeout.INFINITE_TIMEOUT_MS + requestTimeoutMillis = timeout.inWholeMilliseconds socketTimeoutMillis = 80_000 } install(ContentNegotiation) { json(JSON) } } - suspend fun generateContent(request: GenerateContentRequest): GenerateContentResponse { - return client - .post("$DOMAIN/$model:generateContent") { applyCommonConfiguration(request) } + suspend fun generateContent(request: GenerateContentRequest): GenerateContentResponse = + client + .post("$DOMAIN/$apiVersion/$model:generateContent") { applyCommonConfiguration(request) } .also { validateResponse(it) } .body() - } fun generateContentStream(request: GenerateContentRequest): Flow<GenerateContentResponse> { - return client.postStream("$DOMAIN/$model:streamGenerateContent?alt=sse") { + return client.postStream( + "$DOMAIN/$apiVersion/$model:streamGenerateContent?alt=sse" + ) { applyCommonConfiguration(request) } } - suspend fun countTokens(request: CountTokensRequest): CountTokensResponse { - return client - .post("$DOMAIN/$model:countTokens") { applyCommonConfiguration(request) } + suspend fun countTokens(request: CountTokensRequest): CountTokensResponse = + client + .post("$DOMAIN/$apiVersion/$model:countTokens") { applyCommonConfiguration(request) } .also { validateResponse(it) } .body() - } private fun HttpRequestBuilder.applyCommonConfiguration(request: Request) { when (request) { @@ -113,8 +118,7 @@ internal class APIController( * - * Models must be prepended with the `models/` prefix when communicating with the backend. + * Names without a `/` are prefixed with `models/`; names containing a `/` are sent unchanged.
*/ -private fun fullModelName(name: String): String = - name.takeIf { it.startsWith("models/") } ?: "models/$name" +private fun fullModelName(name: String): String = name.takeIf { it.contains("/") } ?: "models/$name" /** * Makes a POST request to the specified [url] and returns a [Flow] of deserialized response objects diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/Exceptions.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Exceptions.kt index f8fefb68..9171da3c 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/Exceptions.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Exceptions.kt @@ -18,6 +18,7 @@ package com.google.ai.client.generativeai.type import com.google.ai.client.generativeai.GenerativeModel import io.ktor.serialization.JsonConvertException +import kotlinx.coroutines.TimeoutCancellationException /** Parent class for any errors that occur from [GenerativeModel]. */ sealed class GoogleGenerativeAIException(message: String, cause: Throwable? = null) : @@ -39,6 +40,8 @@ sealed class GoogleGenerativeAIException(message: String, cause: Throwable? = nu "Something went wrong while trying to deserialize a response from the server.", cause ) + is TimeoutCancellationException -> + RequestTimeoutException("The request failed to complete in the allotted time.", cause) else -> UnknownException("Something unexpected happened.", cause) } } @@ -84,6 +87,14 @@ class ResponseStoppedException(val response: GenerateContentResponse, cause: Thr cause ) +/** + * A request took too long to complete. + * + * Usually occurs due to a user specified [timeout][RequestOptions.timeout]. + */ +class RequestTimeoutException(message: String, cause: Throwable? = null) : + GoogleGenerativeAIException(message, cause) + /** Catch all case for exceptions not explicitly expected. */ class UnknownException(message: String, cause: Throwable?
= null) : GoogleGenerativeAIException(message, cause) diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/RequestOptions.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/RequestOptions.kt new file mode 100644 index 00000000..cc9669d9 --- /dev/null +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/RequestOptions.kt @@ -0,0 +1,40 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.client.generativeai.type + +import io.ktor.client.plugins.HttpTimeout +import kotlin.time.Duration +import kotlin.time.DurationUnit +import kotlin.time.toDuration + +/** + * Configurable options unique to how requests to the backend are performed. + * + * @property timeout the maximum amount of time for a request to take, from the first request to + * first response. + * @property apiVersion the api endpoint to call. + */ +class RequestOptions(val timeout: Duration, val apiVersion: String = "v1") { + @JvmOverloads + constructor( + timeout: Long? 
= HttpTimeout.INFINITE_TIMEOUT_MS, + apiVersion: String = "v1" + ) : this( + (timeout ?: HttpTimeout.INFINITE_TIMEOUT_MS).toDuration(DurationUnit.MILLISECONDS), + apiVersion + ) +} diff --git a/generativeai/src/test/java/com/google/ai/client/generativeai/GenerativeModelTests.kt b/generativeai/src/test/java/com/google/ai/client/generativeai/GenerativeModelTests.kt index da0e90b6..15145389 100644 --- a/generativeai/src/test/java/com/google/ai/client/generativeai/GenerativeModelTests.kt +++ b/generativeai/src/test/java/com/google/ai/client/generativeai/GenerativeModelTests.kt @@ -16,16 +16,29 @@ package com.google.ai.client.generativeai +import com.google.ai.client.generativeai.type.RequestOptions +import com.google.ai.client.generativeai.type.RequestTimeoutException import com.google.ai.client.generativeai.util.commonTest +import com.google.ai.client.generativeai.util.createGenerativeModel import com.google.ai.client.generativeai.util.createResponses +import com.google.ai.client.generativeai.util.doBlocking import com.google.ai.client.generativeai.util.prepareStreamingResponse +import io.kotest.assertions.throwables.shouldThrow import io.kotest.matchers.shouldBe +import io.kotest.matchers.string.shouldContain +import io.ktor.client.engine.mock.MockEngine +import io.ktor.client.engine.mock.respond +import io.ktor.http.HttpHeaders +import io.ktor.http.HttpStatusCode +import io.ktor.http.headersOf +import io.ktor.utils.io.ByteChannel import io.ktor.utils.io.close import io.ktor.utils.io.writeFully import kotlin.time.Duration.Companion.seconds -import kotlinx.coroutines.flow.collect import kotlinx.coroutines.withTimeout import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.Parameterized internal class GenerativeModelTests { private val testTimeout = 5.seconds @@ -45,4 +58,49 @@ internal class GenerativeModelTests { } } } + + @Test + fun `(generateContent) respects a custom timeout`() = + commonTest(requestOptions = RequestOptions(2.seconds)) { 
+ shouldThrow<RequestTimeoutException> { + withTimeout(testTimeout) { model.generateContent("d") } + } + } +} + +@RunWith(Parameterized::class) +internal class ModelNamingTests(private val modelName: String, private val actualName: String) { + + @Test + fun `request should include right model name`() = doBlocking { + val channel = ByteChannel(autoFlush = true) + val mockEngine = MockEngine { + respond(channel, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json")) + } + prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) } + val model = + createGenerativeModel(modelName, "super_cool_test_key", RequestOptions(), mockEngine) + + withTimeout(5.seconds) { + model.generateContentStream().collect { + it.candidates.isEmpty() shouldBe false + channel.close() + } + } + + mockEngine.requestHistory.first().url.encodedPath shouldContain actualName + } + + companion object { + @JvmStatic + @Parameterized.Parameters + fun data() = + listOf( + arrayOf("gemini-pro", "models/gemini-pro"), + arrayOf("x/gemini-pro", "x/gemini-pro"), + arrayOf("models/gemini-pro", "models/gemini-pro"), + arrayOf("/modelname", "/modelname"), + arrayOf("modifiedNaming/mymodel", "modifiedNaming/mymodel"), + ) + } } diff --git a/generativeai/src/test/java/com/google/ai/client/generativeai/util/tests.kt b/generativeai/src/test/java/com/google/ai/client/generativeai/util/tests.kt index fe498754..092b141c 100644 --- a/generativeai/src/test/java/com/google/ai/client/generativeai/util/tests.kt +++ b/generativeai/src/test/java/com/google/ai/client/generativeai/util/tests.kt @@ -28,6 +28,7 @@ import com.google.ai.client.generativeai.internal.api.shared.Content import com.google.ai.client.generativeai.internal.api.shared.TextPart import com.google.ai.client.generativeai.internal.util.SSE_SEPARATOR import com.google.ai.client.generativeai.internal.util.send +import com.google.ai.client.generativeai.type.RequestOptions import io.ktor.client.engine.mock.MockEngine import
io.ktor.client.engine.mock.respond import io.ktor.http.HttpHeaders @@ -93,19 +94,43 @@ internal typealias CommonTest = suspend CommonTestScope.() -> Unit * ``` * * @param status An optional [HttpStatusCode] to return as a response + * @param requestOptions Optional [RequestOptions] to utilize in the underlying controller * @param block The test contents themselves, with the [CommonTestScope] implicitly provided * @see CommonTestScope */ -internal fun commonTest(status: HttpStatusCode = HttpStatusCode.OK, block: CommonTest) = - doBlocking { - val channel = ByteChannel(autoFlush = true) - val mockEngine = MockEngine { - respond(channel, status, headersOf(HttpHeaders.ContentType, "application/json")) - } - val controller = APIController("super_cool_test_key", "gemini-pro", mockEngine) - val model = GenerativeModel("gemini-pro", "super_cool_test_key", controller = controller) - CommonTestScope(channel, model).block() +internal fun commonTest( + status: HttpStatusCode = HttpStatusCode.OK, + requestOptions: RequestOptions = RequestOptions(), + block: CommonTest +) = doBlocking { + val channel = ByteChannel(autoFlush = true) + val mockEngine = MockEngine { + respond(channel, status, headersOf(HttpHeaders.ContentType, "application/json")) } + val model = createGenerativeModel("gemini-pro", "super_cool_test_key", requestOptions, mockEngine) + CommonTestScope(channel, model).block() +} + +/** Simple wrapper that guarantees the model and APIController are created using the same data */ +internal fun createGenerativeModel( + name: String, + apikey: String, + requestOptions: RequestOptions = RequestOptions(), + engine: MockEngine +) = + GenerativeModel( + name, + apikey, + requestOptions = requestOptions, + controller = + APIController( + apikey, + name, + requestOptions.apiVersion, + requestOptions.timeout, + engine + ) + ) /** * A variant of [commonTest] for performing *streaming-based* snapshot tests.