diff --git a/openai-java-client-okhttp/src/main/kotlin/com/openai/client/okhttp/OkHttpClient.kt b/openai-java-client-okhttp/src/main/kotlin/com/openai/client/okhttp/OkHttpClient.kt index 1efd82e5..6cbaeacc 100644 --- a/openai-java-client-okhttp/src/main/kotlin/com/openai/client/okhttp/OkHttpClient.kt +++ b/openai-java-client-okhttp/src/main/kotlin/com/openai/client/okhttp/OkHttpClient.kt @@ -31,38 +31,11 @@ class OkHttpClient private constructor(private val okHttpClient: okhttp3.OkHttpClient, private val baseUrl: HttpUrl) : HttpClient { - private fun getClient(requestOptions: RequestOptions): okhttp3.OkHttpClient { - val clientBuilder = okHttpClient.newBuilder() - - val logLevel = - when (System.getenv("OPENAI_LOG")?.lowercase()) { - "info" -> HttpLoggingInterceptor.Level.BASIC - "debug" -> HttpLoggingInterceptor.Level.BODY - else -> null - } - if (logLevel != null) { - clientBuilder.addNetworkInterceptor( - HttpLoggingInterceptor().setLevel(logLevel).apply { redactHeader("Authorization") } - ) - } - - val timeout = requestOptions.timeout - if (timeout != null) { - clientBuilder - .connectTimeout(timeout) - .readTimeout(timeout) - .writeTimeout(timeout) - .callTimeout(if (timeout.seconds == 0L) timeout else timeout.plusSeconds(30)) - } - - return clientBuilder.build() - } - override fun execute( request: HttpRequest, requestOptions: RequestOptions, ): HttpResponse { - val call = getClient(requestOptions).newCall(request.toRequest()) + val call = newCall(request, requestOptions) return try { call.execute().toResponse() @@ -81,18 +54,18 @@ private constructor(private val okHttpClient: okhttp3.OkHttpClient, private val request.body?.run { future.whenComplete { _, _ -> close() } } - val call = getClient(requestOptions).newCall(request.toRequest()) - call.enqueue( - object : Callback { - override fun onResponse(call: Call, response: Response) { - future.complete(response.toResponse()) - } + newCall(request, requestOptions) + .enqueue( + object : Callback { + override fun onResponse(call: Call, response: Response) { + future.complete(response.toResponse()) + } - override fun onFailure(call: Call, e: IOException) { - future.completeExceptionally(OpenAIIoException("Request failed", e)) + override fun onFailure(call: Call, e: IOException) { + future.completeExceptionally(OpenAIIoException("Request failed", e)) + } } - } - ) + ) return future } @@ -103,7 +76,35 @@ private constructor(private val okHttpClient: okhttp3.OkHttpClient, private val okHttpClient.cache?.close() } - private fun HttpRequest.toRequest(): Request { + private fun newCall(request: HttpRequest, requestOptions: RequestOptions): Call { + val clientBuilder = okHttpClient.newBuilder() + + val logLevel = + when (System.getenv("OPENAI_LOG")?.lowercase()) { + "info" -> HttpLoggingInterceptor.Level.BASIC + "debug" -> HttpLoggingInterceptor.Level.BODY + else -> null + } + if (logLevel != null) { + clientBuilder.addNetworkInterceptor( + HttpLoggingInterceptor().setLevel(logLevel).apply { redactHeader("Authorization") } + ) + } + + val timeout = requestOptions.timeout + if (timeout != null) { + clientBuilder + .connectTimeout(timeout) + .readTimeout(timeout) + .writeTimeout(timeout) + .callTimeout(if (timeout.seconds == 0L) timeout else timeout.plusSeconds(30)) + } + + val client = clientBuilder.build() + return client.newCall(request.toRequest(client)) + } + + private fun HttpRequest.toRequest(client: okhttp3.OkHttpClient): Request { var body: RequestBody? = body?.toRequestBody() // OkHttpClient always requires a request body for PUT and POST methods. if (body == null && (method == HttpMethod.PUT || method == HttpMethod.POST)) { @@ -115,6 +116,21 @@ private constructor(private val okHttpClient: okhttp3.OkHttpClient, private val headers.values(name).forEach { builder.header(name, it) } } + if ( + !headers.names().contains("X-Stainless-Read-Timeout") && client.readTimeoutMillis != 0 + ) { + builder.header( + "X-Stainless-Read-Timeout", + Duration.ofMillis(client.readTimeoutMillis.toLong()).seconds.toString() + ) + } + if (!headers.names().contains("X-Stainless-Timeout") && client.callTimeoutMillis != 0) { + builder.header( + "X-Stainless-Timeout", + Duration.ofMillis(client.callTimeoutMillis.toLong()).seconds.toString() + ) + } + return builder.build() }