Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -419,17 +419,21 @@ module MyGlobalRedisInstrumentation
MyMonitoringService.instrument("redis.connect") { super }
end

def call(command, redis_config)
def call(command, redis_config, context = nil)
MyMonitoringService.instrument("redis.query") { super }
end

def call_pipelined(commands, redis_config)
def call_pipelined(commands, redis_config, context = nil)
MyMonitoringService.instrument("redis.pipeline") { super }
end
end
RedisClient.register(MyGlobalRedisInstrumentation)
```

Middleware callbacks can optionally accept a third `context` Hash. When present it carries extra information about the current stage of the client.
For instance, during the connection prelude the context includes `stage: :connection_prelude` and the underlying `connection`, which allows a middleware
to temporarily tweak socket timeouts around the initial `AUTH/HELLO` handshake without affecting other commands.

Note that `RedisClient.register` is global and apply to all `RedisClient` instances.

To add middlewares to only a single client, you can provide them when creating the config:
Expand All @@ -447,11 +451,11 @@ module MyGlobalRedisInstrumentation
MyMonitoringService.instrument("redis.connect", tags: redis_config.custom[:tags]) { super }
end

def call(command, redis_config)
def call(command, redis_config, context = nil)
MyMonitoringService.instrument("redis.query", tags: redis_config.custom[:tags]) { super }
end

def call_pipelined(commands, redis_config)
def call_pipelined(commands, redis_config, context = nil)
MyMonitoringService.instrument("redis.pipeline", tags: redis_config.custom[:tags]) { super }
end
end
Expand All @@ -469,7 +473,7 @@ In many cases you may want to ignore retriable errors, or report them differentl

```ruby
module MyGlobalRedisInstrumentation
def call(command, redis_config)
def call(command, redis_config, context = nil)
super
rescue RedisClient::Error => error
if error.final?
Expand Down Expand Up @@ -503,6 +507,16 @@ RedisClient.config(

All timeout values are specified in seconds.

You can also configure a specific timeout to apply only to authentication during the connection handshake:

```ruby
RedisClient.config(
username: "app",
password: "secret",
auth_timeout: 0.2, # applies to AUTH (or HELLO ... AUTH) during connect
).new
```

### Reconnection

`redis-client` support automatic reconnection after network errors via the `reconnect_attempts:` configuration option.
Expand Down
70 changes: 53 additions & 17 deletions lib/redis_client.rb
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def pubsub

def measure_round_trip_delay
ensure_connected do |connection|
@middlewares.call(["PING"], config) do
@middlewares.call_with_context(["PING"], config) do
connection.measure_round_trip_delay
end
end
Expand All @@ -335,7 +335,7 @@ def measure_round_trip_delay
def call(*command, **kwargs)
command = @command_builder.generate(command, kwargs)
result = ensure_connected do |connection|
@middlewares.call(command, config) do
@middlewares.call_with_context(command, config) do
connection.call(command, nil)
end
end
Expand All @@ -350,7 +350,7 @@ def call(*command, **kwargs)
def call_v(command)
command = @command_builder.generate(command)
result = ensure_connected do |connection|
@middlewares.call(command, config) do
@middlewares.call_with_context(command, config) do
connection.call(command, nil)
end
end
Expand All @@ -365,7 +365,7 @@ def call_v(command)
def call_once(*command, **kwargs)
command = @command_builder.generate(command, kwargs)
result = ensure_connected(retryable: false) do |connection|
@middlewares.call(command, config) do
@middlewares.call_with_context(command, config) do
connection.call(command, nil)
end
end
Expand All @@ -380,7 +380,7 @@ def call_once(*command, **kwargs)
def call_once_v(command)
command = @command_builder.generate(command)
result = ensure_connected(retryable: false) do |connection|
@middlewares.call(command, config) do
@middlewares.call_with_context(command, config) do
connection.call(command, nil)
end
end
Expand All @@ -396,7 +396,7 @@ def blocking_call(timeout, *command, **kwargs)
command = @command_builder.generate(command, kwargs)
error = nil
result = ensure_connected do |connection|
@middlewares.call(command, config) do
@middlewares.call_with_context(command, config) do
connection.call(command, timeout)
end
rescue ReadTimeoutError => error
Expand All @@ -416,7 +416,7 @@ def blocking_call_v(timeout, command)
command = @command_builder.generate(command)
error = nil
result = ensure_connected do |connection|
@middlewares.call(command, config) do
@middlewares.call_with_context(command, config) do
connection.call(command, timeout)
end
rescue ReadTimeoutError => error
Expand Down Expand Up @@ -490,7 +490,7 @@ def pipelined(exception: true)
else
results = ensure_connected(retryable: pipeline._retryable?) do |connection|
commands = pipeline._commands
@middlewares.call_pipelined(commands, config) do
@middlewares.call_pipelined_with_context(commands, config) do
connection.call_pipelined(commands, pipeline._timeouts, exception: exception)
end
end
Expand All @@ -510,7 +510,7 @@ def multi(watch: nil, &block)
begin
if transaction = build_transaction(&block)
commands = transaction._commands
results = @middlewares.call_pipelined(commands, config) do
results = @middlewares.call_pipelined_with_context(commands, config) do
connection.call_pipelined(commands, nil)
end.last
else
Expand All @@ -529,7 +529,7 @@ def multi(watch: nil, &block)
else
ensure_connected(retryable: transaction._retryable?) do |connection|
commands = transaction._commands
@middlewares.call_pipelined(commands, config) do
@middlewares.call_pipelined_with_context(commands, config) do
connection.call_pipelined(commands, nil)
end.last
end
Expand Down Expand Up @@ -805,13 +805,14 @@ def raw_connection

def connect
@pid = PIDCache.pid
connect_context = { stage: :connect }

if @raw_connection&.revalidate
@middlewares.connect(config) do
@middlewares.connect_with_context(config, connect_context) do
@raw_connection.reconnect
end
else
@raw_connection = @middlewares.connect(config) do
@raw_connection = @middlewares.connect_with_context(config, connect_context) do
config.driver.new(
config,
connect_timeout: connect_timeout,
Expand All @@ -823,22 +824,27 @@ def connect
@raw_connection.retry_attempt = @retry_attempt

prelude = config.connection_prelude.dup
timeouts = build_prelude_timeouts(prelude, config.auth_timeout)

if id
prelude << ["CLIENT", "SETNAME", id]
timeouts << nil if timeouts
end

# The connection prelude is deliberately not sent to Middlewares
prelude_context = { stage: :connection_prelude, connection: @raw_connection }

# The connection prelude goes through middlewares with a dedicated context.
if config.sentinel?
prelude << ["ROLE"]
role, = @middlewares.call_pipelined(prelude, config) do
@raw_connection.call_pipelined(prelude, nil).last
timeouts << nil if timeouts
role, = @middlewares.call_pipelined_with_context(prelude, config, prelude_context) do
@raw_connection.call_pipelined(prelude, timeouts).last
end
config.check_role!(role)
else
unless prelude.empty?
@middlewares.call_pipelined(prelude, config) do
@raw_connection.call_pipelined(prelude, nil)
@middlewares.call_pipelined_with_context(prelude, config, prelude_context) do
@raw_connection.call_pipelined(prelude, timeouts)
end
end
end
Expand All @@ -857,6 +863,36 @@ def connect
raise
end
end

# Build the per-command timeouts for the connection prelude.
# Only AUTH-related steps should be bounded by auth_timeout.
# Returns nil if no timeout applies so downstream can skip passing it.
def build_prelude_timeouts(prelude, auth_timeout)
return nil unless auth_timeout

auth_seen = false
timeouts = prelude.map do |command|
next unless auth_command?(command)

auth_seen = true
auth_timeout
end

auth_seen ? timeouts : nil
end

def auth_command?(command)
return false unless command&.any?

case command.first
when "AUTH"
true
when "HELLO"
command.size >= 3 && command[2] == "AUTH"
else
false
end
end
end

require "redis_client/pooled"
Expand Down
7 changes: 7 additions & 0 deletions lib/redis_client/config.rb
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def initialize(
read_timeout: timeout,
write_timeout: timeout,
connect_timeout: timeout,
auth_timeout: nil,
ssl: nil,
custom: {},
ssl_params: nil,
Expand Down Expand Up @@ -54,6 +55,8 @@ def initialize(
@connect_timeout = connect_timeout
@read_timeout = read_timeout
@write_timeout = write_timeout
@auth_timeout = auth_timeout
@auth_timeout = nil if @auth_timeout && @auth_timeout <= 0

@driver = driver ? RedisClient.driver(driver) : RedisClient.default_driver

Expand Down Expand Up @@ -123,6 +126,10 @@ def username
@username || DEFAULT_USERNAME
end

def auth_timeout
@auth_timeout
end

def resolved?
true
end
Expand Down
33 changes: 33 additions & 0 deletions lib/redis_client/middlewares.rb
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,39 @@ def call(command, _config)
yield command
end
alias_method :call_pipelined, :call

# These helpers keep backward compatibility with two-argument middlewares
# while allowing newer ones to accept a third `context` parameter.
def connect_with_context(config, context = nil, &block)
invoke_with_optional_context(:connect, [config], context, &block)
end

def call_with_context(command, config, context = nil, &block)
invoke_with_optional_context(:call, [command, config], context, &block)
end

def call_pipelined_with_context(commands, config, context = nil, &block)
invoke_with_optional_context(:call_pipelined, [commands, config], context, &block)
end

private

def invoke_with_optional_context(method_name, args, context, &block)
method_obj = method(method_name)
if context && accepts_extra_positional_arg?(method_obj, args.length)
method_obj.call(*args, context, &block)
else
method_obj.call(*args, &block)
end
end

def accepts_extra_positional_arg?(method_obj, required_args)
parameters = method_obj.parameters
return true if parameters.any? { |type, _| type == :rest }

positional_count = parameters.count { |type, _| type == :req || type == :opt }
positional_count >= (required_args + 1)
end
end

class Middlewares < BasicMiddleware
Expand Down
30 changes: 30 additions & 0 deletions test/redis_client/middlewares_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -223,12 +223,42 @@ def call_pipelined(commands, _config, &_)
end
end

module PreludeContextMiddleware
class << self
attr_accessor :contexts, :client
end
@contexts = []

def initialize(client)
super
PreludeContextMiddleware.client = client
end

def call_pipelined(commands, config, context = nil, &block)
PreludeContextMiddleware.contexts << context if context
super
end
end

def test_instance_middleware
second_client = new_client(middlewares: [DummyMiddleware])
assert_equal ["GET", "2"], second_client.call("GET", 2)
assert_equal([["GET", "2"]], second_client.pipelined { |p| p.call("GET", 2) })
end

def test_prelude_context_is_exposed
client = new_client(middlewares: [PreludeContextMiddleware])
client.call("PING")

context = PreludeContextMiddleware.contexts.find { |ctx| ctx && ctx[:stage] == :connection_prelude }
refute_nil context
assert_equal :connection_prelude, context[:stage]
refute_nil context[:connection]
assert_kind_of RedisClient, PreludeContextMiddleware.client
ensure
PreludeContextMiddleware.contexts.clear
end

private

def assert_call(call)
Expand Down
Loading