diff --git a/typesense/client.go b/typesense/client.go index 33feb5e..1599aad 100644 --- a/typesense/client.go +++ b/typesense/client.go @@ -123,6 +123,7 @@ type ClientConfig struct { CircuitBreakerTimeout time.Duration CircuitBreakerReadyToTrip circuit.GoBreakerReadyToTripFunc CircuitBreakerOnStateChange circuit.GoBreakerOnStateChangeFunc + CustomHTTPClient *http.Client } type ClientOption func(*Client) @@ -272,6 +273,12 @@ func WithClientConfig(config *ClientConfig) ClientOption { } } +func WithCustomHTTPClient(client *http.Client) ClientOption { + return func(c *Client) { + c.apiConfig.CustomHTTPClient = client + } +} + func NewClient(opts ...ClientOption) *Client { c := &Client{apiConfig: &ClientConfig{ RetryInterval: defaultRetryInterval, @@ -296,12 +303,16 @@ func NewClient(opts ...ClientOption) *Client { circuit.WithGoBreakerReadyToTrip(c.apiConfig.CircuitBreakerReadyToTrip), circuit.WithGoBreakerOnStateChange(c.apiConfig.CircuitBreakerOnStateChange), ) + client := c.apiConfig.CustomHTTPClient + if client == nil { + client = &http.Client{ + Timeout: c.apiConfig.ConnectionTimeout, + } + } httpClient := circuit.NewHTTPClient( circuit.WithHTTPRequestDoer( NewAPICall( - &http.Client{ - Timeout: c.apiConfig.ConnectionTimeout, - }, + client, c.apiConfig, )), circuit.WithCircuitBreaker(cb), diff --git a/typesense/client_test.go b/typesense/client_test.go index 09cb820..3bf24e2 100644 --- a/typesense/client_test.go +++ b/typesense/client_test.go @@ -1,6 +1,7 @@ package typesense import ( + "net/http" "reflect" "testing" "time" @@ -238,6 +239,19 @@ func TestClientConfigOptions(t *testing.T) { assert.NotNil(t, client.apiClient) }, }, + { + name: "WithCustomHTTPClient", + options: []ClientOption{ + WithCustomHTTPClient(&http.Client{ + Timeout: 10 * time.Second, + }), + }, + verify: func(t *testing.T, client *Client) { + assert.NotNil(t, client.apiConfig.CustomHTTPClient) + assert.Equal(t, 10*time.Second, client.apiConfig.CustomHTTPClient.Timeout) + assert.NotNil(t, client.apiClient) + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) {