diff --git a/test/provider/openai/oauth_test.exs b/test/provider/openai/oauth_test.exs new file mode 100644 index 00000000..48e6f963 --- /dev/null +++ b/test/provider/openai/oauth_test.exs @@ -0,0 +1,145 @@ +defmodule Provider.OpenAI.OAuthTest do + use ExUnit.Case, async: true + + alias ReqLLM.Providers.OpenAI.OAuth + + describe "refresh/2" do + test "returns refreshed credentials and derives account id from the access token" do + access_token = jwt_with_account_id("acct_123") + + assert {:ok, + %{ + "type" => "oauth", + "access" => ^access_token, + "refresh" => "fresh-refresh-token", + "expires" => expires, + "accountId" => "acct_123" + }} = + OAuth.refresh(%{"refresh" => "refresh-token-123"}, + oauth_http_options: [ + adapter: + response_adapter(200, %{ + "access_token" => access_token, + "refresh_token" => "fresh-refresh-token", + "expires_in" => 3600 + }) + ] + ) + + assert is_integer(expires) + assert expires > System.system_time(:millisecond) + end + + test "returns an error when the refresh response is missing access_token" do + assert {:error, "OpenAI OAuth refresh response did not include access_token"} = + OAuth.refresh(%{"refresh" => "refresh-token-123"}, + oauth_http_options: [adapter: response_adapter(200, "not-json")] + ) + end + + test "returns an error when the refresh response is missing refresh_token" do + assert {:error, "OpenAI OAuth refresh response did not include refresh_token"} = + OAuth.refresh(%{"refresh" => "refresh-token-123"}, + oauth_http_options: [ + adapter: + response_adapter(200, %{ + "access_token" => "fresh-access-token", + "expires_in" => 3600 + }) + ] + ) + end + + test "returns an error when expires_in is invalid" do + assert {:error, "OpenAI OAuth refresh response did not include expires_in"} = + OAuth.refresh(%{"refresh" => "refresh-token-123"}, + oauth_http_options: [ + adapter: + response_adapter(200, %{ + "access_token" => 123, + "refresh_token" => "fresh-refresh-token", + "expires_in" => "soon" + }) + ] + ) + end + + test "formats nested OAuth error messages from failed refresh responses" do + assert {:error, "OpenAI OAuth refresh failed with status 401: refresh denied"} = + OAuth.refresh(%{"refresh" => "refresh-token-123"}, + oauth_http_options: [ + adapter: + response_adapter(401, %{ + "error" => %{"message" => "refresh denied"} + }) + ] + ) + end + + test "formats string OAuth error messages from failed refresh responses" do + assert {:error, "OpenAI OAuth refresh failed with status 401: refresh denied"} = + OAuth.refresh(%{"refresh" => "refresh-token-123"}, + oauth_http_options: [ + adapter: response_adapter(401, %{"error" => "refresh denied"}) + ] + ) + end + + test "formats top-level OAuth error messages from failed refresh responses" do + assert {:error, "OpenAI OAuth refresh failed with status 400: refresh denied"} = + OAuth.refresh(%{"refresh" => "refresh-token-123"}, + oauth_http_options: [ + adapter: response_adapter(400, %{"message" => "refresh denied"}) + ] + ) + end + + test "falls back to a status-only error for unstructured failed refresh responses" do + assert {:error, "OpenAI OAuth refresh failed with status 500"} = + OAuth.refresh(%{"refresh" => "refresh-token-123"}, + oauth_http_options: [adapter: response_adapter(500, [:bad_response])] + ) + end + + test "returns adapter exceptions as OAuth refresh errors" do + assert {:error, "OpenAI OAuth refresh failed: boom"} = + OAuth.refresh(%{"refresh" => "refresh-token-123"}, + oauth_http_options: [adapter: error_adapter("boom")] + ) + end + end + + describe "account_id_from_token/1" do + test "returns nil for malformed or non-binary tokens" do + assert OAuth.account_id_from_token("not-a-jwt") == nil + assert OAuth.account_id_from_token("a.invalid-payload.sig") == nil + assert OAuth.account_id_from_token(123) == nil + end + end + + defp response_adapter(status, body) do + fn request -> + {request, %Req.Response{status: status, body: body}} + end + end + + defp error_adapter(message) do + fn request -> + {request, RuntimeError.exception(message)} + end + end + + defp jwt_with_account_id(account_id) do + header = + %{"alg" => "none", "typ" => "JWT"} |> Jason.encode!() |> Base.url_encode64(padding: false) + + payload = + %{ + "https://api.openai.com/auth" => %{"chatgpt_account_id" => account_id} + } + |> Jason.encode!() + |> Base.url_encode64(padding: false) + + "#{header}.#{payload}.sig" + end +end diff --git a/test/provider/openai/param_profiles_test.exs b/test/provider/openai/param_profiles_test.exs new file mode 100644 index 00000000..d82f6e88 --- /dev/null +++ b/test/provider/openai/param_profiles_test.exs @@ -0,0 +1,97 @@ +defmodule Provider.OpenAI.ParamProfilesTest do + use ExUnit.Case, async: true + + alias ReqLLM.Providers.OpenAI.ParamProfiles + + describe "steps_for/2" do + test "always includes canonical reasoning transforms" do + steps = ParamProfiles.steps_for(:chat, model(id: "gpt-4.1")) + translate_reasoning_effort = translation_fun(steps) + + assert Enum.any?(steps, &match?({:drop, :reasoning_token_budget, nil}, &1)) + assert translate_reasoning_effort.(:none) == "none" + assert translate_reasoning_effort.(:minimal) == "minimal" + assert translate_reasoning_effort.(:low) == "low" + assert translate_reasoning_effort.(:medium) == "medium" + assert translate_reasoning_effort.(:high) == "high" + assert translate_reasoning_effort.(:xhigh) == "xhigh" + assert translate_reasoning_effort.(:default) == nil + assert translate_reasoning_effort.("custom") == "custom" + end + + test "adds reasoning and no sampling profiles for o-series models" do + steps = ParamProfiles.steps_for(:chat, model(id: "o3-mini")) + + assert Enum.any?(steps, &match?({:rename, :max_tokens, :max_completion_tokens, _}, &1)) + assert Enum.any?(steps, &match?({:drop, :temperature, _}, &1)) + assert Enum.any?(steps, &match?({:drop, :top_p, _}, &1)) + assert Enum.any?(steps, &match?({:drop, :top_k, _}, &1)) + end + + test "adds reasoning profile from atom-key capabilities" do + steps = + ParamProfiles.steps_for( + :chat, + raw_model(id: "custom-openai", capabilities: %{reasoning: true}) + ) + + assert Enum.any?(steps, &match?({:rename, :max_tokens, :max_completion_tokens, _}, &1)) + refute Enum.any?(steps, &match?({:drop, :top_p, _}, &1)) + end + + test "adds reasoning profile from string-key capabilities map" do + steps = + ParamProfiles.steps_for( + :chat, + raw_model(id: "custom-openai", capabilities: %{"reasoning" => %{"enabled" => true}}) + ) + + assert Enum.any?(steps, &match?({:rename, :max_tokens, :max_completion_tokens, _}, &1)) + end + + test "does not apply chat-only profiles to other operations" do + steps = ParamProfiles.steps_for(:embedding, model(id: "o3-mini")) + + refute Enum.any?(steps, &match?({:rename, :max_tokens, :max_completion_tokens, _}, &1)) + refute Enum.any?(steps, &match?({:drop, :temperature, _}, &1)) + refute Enum.any?(steps, &match?({:drop, :top_p, _}, &1)) + refute Enum.any?(steps, &match?({:drop, :top_k, _}, &1)) + end + end + + defp model(attrs) do + attrs = Enum.into(attrs, %{}) + + LLMDB.Model.new!( + Map.merge( + %{ + provider: :openai, + id: "gpt-4.1" + }, + attrs + ) + ) + end + + defp raw_model(attrs) do + attrs = Enum.into(attrs, %{}) + + struct( + LLMDB.Model, + Map.merge( + %{ + provider: :openai, + id: "gpt-4.1" + }, + attrs + ) + ) + end + + defp translation_fun(steps) do + Enum.find_value(steps, fn + {:transform, :reasoning_effort, fun, nil} -> fun + _ -> nil + end) + end +end diff --git a/test/provider/openai/responses_api_unit_test.exs b/test/provider/openai/responses_api_unit_test.exs index 72c0a0d9..b47b7705 100644 --- a/test/provider/openai/responses_api_unit_test.exs +++ b/test/provider/openai/responses_api_unit_test.exs @@ -1274,6 +1274,65 @@ defmodule Provider.OpenAI.ResponsesAPIUnitTest do describe "ResponseBuilder - streaming reasoning_details extraction" do alias ReqLLM.Providers.OpenAI.ResponsesAPI.ResponseBuilder + test "upgrades stop finish reason to tool_calls when tool chunks are present" do + {:ok, model} = ReqLLM.model("openai:gpt-4o") + context = %ReqLLM.Context{messages: []} + + chunks = [ + ReqLLM.StreamChunk.tool_call("get_weather", %{"city" => "SF"}), + ReqLLM.StreamChunk.text("Calling a tool") + ] + + {:ok, response} = + ResponseBuilder.build_response( + chunks, + %{finish_reason: :stop}, + context: context, + model: model + ) + + assert response.finish_reason == :tool_calls + assert [%ReqLLM.ToolCall{function: %{name: "get_weather"}}] = response.message.tool_calls + end + + test "upgrades string stop finish reason to tool_calls when tool chunks are present" do + {:ok, model} = ReqLLM.model("openai:gpt-4o") + context = %ReqLLM.Context{messages: []} + + chunks = [ + ReqLLM.StreamChunk.tool_call("search", %{"query" => "docs"}) + ] + + {:ok, response} = + ResponseBuilder.build_response( + chunks, + %{finish_reason: "stop"}, + context: context, + model: model + ) + + assert response.finish_reason == :tool_calls + end + + test "preserves non-stop finish reason when tool chunks are present" do + {:ok, model} = ReqLLM.model("openai:gpt-4o") + context = %ReqLLM.Context{messages: []} + + chunks = [ + ReqLLM.StreamChunk.tool_call("search", %{"query" => "docs"}) + ] + + {:ok, response} = + ResponseBuilder.build_response( + chunks, + %{finish_reason: :length}, + context: context, + model: model + ) + + assert response.finish_reason == :length + end + test "extracts reasoning_details from thinking chunks" do {:ok, model} = ReqLLM.model("openai:gpt-4o") context = %ReqLLM.Context{messages: []} @@ -1363,5 +1422,24 @@ defmodule Provider.OpenAI.ResponsesAPIUnitTest do assert length(context_msg.reasoning_details) == 1 assert hd(context_msg.reasoning_details).text == "Deep thought" end + + test "leaves message metadata unchanged when response_id is absent" do + {:ok, model} = ReqLLM.model("openai:gpt-4o") + context = %ReqLLM.Context{messages: []} + + chunks = [ + ReqLLM.StreamChunk.text("No response id") + ] + + {:ok, response} = + ResponseBuilder.build_response( + chunks, + %{finish_reason: :stop}, + context: context, + model: model + ) + + assert response.message.metadata == %{} + end end end diff --git a/test/provider/openai/web_socket_test.exs b/test/provider/openai/web_socket_test.exs new file mode 100644 index 00000000..dee87b44 --- /dev/null +++ b/test/provider/openai/web_socket_test.exs @@ -0,0 +1,101 @@ +defmodule Provider.OpenAI.WebSocketTest do + use ExUnit.Case, async: true + + alias ReqLLM.Providers.OpenAI.WebSocket + + describe "headers/2" do + test "includes auth and custom headers" do + headers = + WebSocket.headers( + model(), + api_key: "socket-test-key", + req_http_options: [headers: [{"X-Test", "1"}]] + ) + + assert headers == [ + {"Authorization", "Bearer socket-test-key"}, + {"X-Test", "1"} + ] + end + end + + describe "responses_url/2" do + test "prefers model base_url over request options" do + url = + WebSocket.responses_url( + model(base_url: "http://localhost:4010/custom/"), + base_url: "https://ignored.example.com/v1" + ) + + assert url == "ws://localhost:4010/custom/responses" + end + end + + describe "realtime_url/2" do + test "uses provider_model_id and preserves existing query params" do + url = + WebSocket.realtime_url( + model( + base_url: "https://api.example.com/v1?api-version=2025-01-01", + provider_model_id: "gpt-5-deploy" + ), + [] + ) + + uri = URI.parse(url) + + assert uri.scheme == "wss" + assert uri.path == "/v1/realtime" + + assert URI.decode_query(uri.query) == %{ + "api-version" => "2025-01-01", + "model" => "gpt-5-deploy" + } + end + end + + describe "websocket_url/3" do + test "normalizes common schemes and joins paths" do + assert WebSocket.websocket_url("http://example.com/base/", "/responses") == + "ws://example.com/base/responses" + + assert WebSocket.websocket_url("https://example.com/v1", "/responses") == + "wss://example.com/v1/responses" + + assert WebSocket.websocket_url("ws://example.com/v1", "/responses") == + "ws://example.com/v1/responses" + + assert WebSocket.websocket_url("wss://example.com/v1", "/responses") == + "wss://example.com/v1/responses" + end + + test "returns a root path when the suffix is empty" do + assert WebSocket.websocket_url("https://example.com", "") == "wss://example.com/" + end + + test "preserves passthrough schemes and merges query values" do + url = + WebSocket.websocket_url("custom://example.com/v1?bad=%ZZ", "/responses", model: "gpt-5") + + uri = URI.parse(url) + + assert uri.scheme == "custom" + assert uri.path == "/v1/responses" + assert URI.decode_query(uri.query) == %{"bad" => "%ZZ", "model" => "gpt-5"} + end + end + + defp model(attrs \\ []) do + attrs = Enum.into(attrs, %{}) + + LLMDB.Model.new!( + Map.merge( + %{ + provider: :openai, + id: "gpt-5" + }, + attrs + ) + ) + end +end