diff --git a/README.md b/README.md index 1fa96db..422a809 100644 --- a/README.md +++ b/README.md @@ -419,17 +419,21 @@ module MyGlobalRedisInstrumentation MyMonitoringService.instrument("redis.connect") { super } end - def call(command, redis_config) + def call(command, redis_config, context = nil) MyMonitoringService.instrument("redis.query") { super } end - def call_pipelined(commands, redis_config) + def call_pipelined(commands, redis_config, context = nil) MyMonitoringService.instrument("redis.pipeline") { super } end end RedisClient.register(MyGlobalRedisInstrumentation) ``` +Middleware callbacks can optionally accept a third `context` Hash. When present it carries extra information about the current stage of the client. +For instance, during the connection prelude the context includes `stage: :connection_prelude` and the underlying `connection`, which allows a middleware +to temporarily tweak socket timeouts around the initial `AUTH/HELLO` handshake without affecting other commands. + Note that `RedisClient.register` is global and apply to all `RedisClient` instances. To add middlewares to only a single client, you can provide them when creating the config: @@ -447,11 +451,11 @@ module MyGlobalRedisInstrumentation MyMonitoringService.instrument("redis.connect", tags: redis_config.custom[:tags]) { super } end - def call(command, redis_config) + def call(command, redis_config, context = nil) MyMonitoringService.instrument("redis.query", tags: redis_config.custom[:tags]) { super } end - def call_pipelined(commands, redis_config) + def call_pipelined(commands, redis_config, context = nil) MyMonitoringService.instrument("redis.pipeline", tags: redis_config.custom[:tags]) { super } end end @@ -469,7 +473,7 @@ In many cases you may want to ignore retriable errors, or report them differentl ```ruby module MyGlobalRedisInstrumentation - def call(command, redis_config) + def call(command, redis_config, context = nil) super rescue RedisClient::Error => error if error.final? @@ -503,6 +507,16 @@ RedisClient.config( All timeout values are specified in seconds. +You can also configure a specific timeout to apply only to authentication during the connection handshake: + +```ruby +RedisClient.config( + username: "app", + password: "secret", + auth_timeout: 0.2, # applies to AUTH (or HELLO ... AUTH) during connect +).new +``` + ### Reconnection `redis-client` support automatic reconnection after network errors via the `reconnect_attempts:` configuration option. diff --git a/lib/redis_client.rb b/lib/redis_client.rb index d7b813b..bf7238a 100644 --- a/lib/redis_client.rb +++ b/lib/redis_client.rb @@ -326,7 +326,7 @@ def pubsub def measure_round_trip_delay ensure_connected do |connection| - @middlewares.call(["PING"], config) do + @middlewares.call_with_context(["PING"], config) do connection.measure_round_trip_delay end end @@ -335,7 +335,7 @@ def measure_round_trip_delay def call(*command, **kwargs) command = @command_builder.generate(command, kwargs) result = ensure_connected do |connection| - @middlewares.call(command, config) do + @middlewares.call_with_context(command, config) do connection.call(command, nil) end end @@ -350,7 +350,7 @@ def call(*command, **kwargs) def call_v(command) command = @command_builder.generate(command) result = ensure_connected do |connection| - @middlewares.call(command, config) do + @middlewares.call_with_context(command, config) do connection.call(command, nil) end end @@ -365,7 +365,7 @@ def call_v(command) def call_once(*command, **kwargs) command = @command_builder.generate(command, kwargs) result = ensure_connected(retryable: false) do |connection| - @middlewares.call(command, config) do + @middlewares.call_with_context(command, config) do connection.call(command, nil) end end @@ -380,7 +380,7 @@ def call_once(*command, **kwargs) def call_once_v(command) command = @command_builder.generate(command) result = ensure_connected(retryable: false) do |connection| - @middlewares.call(command, config) do + @middlewares.call_with_context(command, config) do connection.call(command, nil) end end @@ -396,7 +396,7 @@ def blocking_call(timeout, *command, **kwargs) command = @command_builder.generate(command, kwargs) error = nil result = ensure_connected do |connection| - @middlewares.call(command, config) do + @middlewares.call_with_context(command, config) do connection.call(command, timeout) end rescue ReadTimeoutError => error @@ -416,7 +416,7 @@ def blocking_call_v(timeout, command) command = @command_builder.generate(command) error = nil result = ensure_connected do |connection| - @middlewares.call(command, config) do + @middlewares.call_with_context(command, config) do connection.call(command, timeout) end rescue ReadTimeoutError => error @@ -490,7 +490,7 @@ def pipelined(exception: true) else results = ensure_connected(retryable: pipeline._retryable?) do |connection| commands = pipeline._commands - @middlewares.call_pipelined(commands, config) do + @middlewares.call_pipelined_with_context(commands, config) do connection.call_pipelined(commands, pipeline._timeouts, exception: exception) end end @@ -510,7 +510,7 @@ def multi(watch: nil, &block) begin if transaction = build_transaction(&block) commands = transaction._commands - results = @middlewares.call_pipelined(commands, config) do + results = @middlewares.call_pipelined_with_context(commands, config) do connection.call_pipelined(commands, nil) end.last else @@ -529,7 +529,7 @@ def multi(watch: nil, &block) else ensure_connected(retryable: transaction._retryable?) do |connection| commands = transaction._commands - @middlewares.call_pipelined(commands, config) do + @middlewares.call_pipelined_with_context(commands, config) do connection.call_pipelined(commands, nil) end.last end @@ -805,13 +805,14 @@ def raw_connection def connect @pid = PIDCache.pid + connect_context = { stage: :connect } if @raw_connection&.revalidate - @middlewares.connect(config) do + @middlewares.connect_with_context(config, connect_context) do @raw_connection.reconnect end else - @raw_connection = @middlewares.connect(config) do + @raw_connection = @middlewares.connect_with_context(config, connect_context) do config.driver.new( config, connect_timeout: connect_timeout, @@ -823,22 +824,27 @@ def connect @raw_connection.retry_attempt = @retry_attempt prelude = config.connection_prelude.dup + timeouts = build_prelude_timeouts(prelude, config.auth_timeout) if id prelude << ["CLIENT", "SETNAME", id] + timeouts << nil if timeouts end - # The connection prelude is deliberately not sent to Middlewares + prelude_context = { stage: :connection_prelude, connection: @raw_connection } + + # The connection prelude goes through middlewares with a dedicated context. if config.sentinel? prelude << ["ROLE"] - role, = @middlewares.call_pipelined(prelude, config) do - @raw_connection.call_pipelined(prelude, nil).last + timeouts << nil if timeouts + role, = @middlewares.call_pipelined_with_context(prelude, config, prelude_context) do + @raw_connection.call_pipelined(prelude, timeouts).last end config.check_role!(role) else unless prelude.empty? - @middlewares.call_pipelined(prelude, config) do - @raw_connection.call_pipelined(prelude, nil) + @middlewares.call_pipelined_with_context(prelude, config, prelude_context) do + @raw_connection.call_pipelined(prelude, timeouts) end end end @@ -857,6 +863,36 @@ def connect raise end end + + # Build the per-command timeouts for the connection prelude. + # Only AUTH-related steps should be bounded by auth_timeout. + # Returns nil if no timeout applies so downstream can skip passing it. + def build_prelude_timeouts(prelude, auth_timeout) + return nil unless auth_timeout + + auth_seen = false + timeouts = prelude.map do |command| + next unless auth_command?(command) + + auth_seen = true + auth_timeout + end + + auth_seen ? timeouts : nil + end + + def auth_command?(command) + return false unless command&.any? + + case command.first + when "AUTH" + true + when "HELLO" + command.size >= 3 && command[2] == "AUTH" + else + false + end + end end require "redis_client/pooled" diff --git a/lib/redis_client/config.rb b/lib/redis_client/config.rb index 94252ab..cf3a04d 100644 --- a/lib/redis_client/config.rb +++ b/lib/redis_client/config.rb @@ -27,6 +27,7 @@ def initialize( read_timeout: timeout, write_timeout: timeout, connect_timeout: timeout, + auth_timeout: nil, ssl: nil, custom: {}, ssl_params: nil, @@ -54,6 +55,8 @@ def initialize( @connect_timeout = connect_timeout @read_timeout = read_timeout @write_timeout = write_timeout + @auth_timeout = auth_timeout + @auth_timeout = nil if @auth_timeout && @auth_timeout <= 0 @driver = driver ? RedisClient.driver(driver) : RedisClient.default_driver @@ -123,6 +126,10 @@ def username @username || DEFAULT_USERNAME end + def auth_timeout + @auth_timeout + end + def resolved? true end diff --git a/lib/redis_client/middlewares.rb b/lib/redis_client/middlewares.rb index f090bd2..618e1cd 100644 --- a/lib/redis_client/middlewares.rb +++ b/lib/redis_client/middlewares.rb @@ -16,6 +16,39 @@ def call(command, _config) yield command end alias_method :call_pipelined, :call + + # These helpers keep backward compatibility with two-argument middlewares + # while allowing newer ones to accept a third `context` parameter. + def connect_with_context(config, context = nil, &block) + invoke_with_optional_context(:connect, [config], context, &block) + end + + def call_with_context(command, config, context = nil, &block) + invoke_with_optional_context(:call, [command, config], context, &block) + end + + def call_pipelined_with_context(commands, config, context = nil, &block) + invoke_with_optional_context(:call_pipelined, [commands, config], context, &block) + end + + private + + def invoke_with_optional_context(method_name, args, context, &block) + method_obj = method(method_name) + if context && accepts_extra_positional_arg?(method_obj, args.length) + method_obj.call(*args, context, &block) + else + method_obj.call(*args, &block) + end + end + + def accepts_extra_positional_arg?(method_obj, required_args) + parameters = method_obj.parameters + return true if parameters.any? { |type, _| type == :rest } + + positional_count = parameters.count { |type, _| type == :req || type == :opt } + positional_count >= (required_args + 1) + end end class Middlewares < BasicMiddleware diff --git a/test/redis_client/middlewares_test.rb b/test/redis_client/middlewares_test.rb index f3f0374..3b0ba14 100644 --- a/test/redis_client/middlewares_test.rb +++ b/test/redis_client/middlewares_test.rb @@ -223,12 +223,42 @@ def call_pipelined(commands, _config, &_) end end + module PreludeContextMiddleware + class << self + attr_accessor :contexts, :client + end + @contexts = [] + + def initialize(client) + super + PreludeContextMiddleware.client = client + end + + def call_pipelined(commands, config, context = nil, &block) + PreludeContextMiddleware.contexts << context if context + super + end + end + def test_instance_middleware second_client = new_client(middlewares: [DummyMiddleware]) assert_equal ["GET", "2"], second_client.call("GET", 2) assert_equal([["GET", "2"]], second_client.pipelined { |p| p.call("GET", 2) }) end + def test_prelude_context_is_exposed + client = new_client(middlewares: [PreludeContextMiddleware]) + client.call("PING") + + context = PreludeContextMiddleware.contexts.find { |ctx| ctx && ctx[:stage] == :connection_prelude } + refute_nil context + assert_equal :connection_prelude, context[:stage] + refute_nil context[:connection] + assert_kind_of RedisClient, PreludeContextMiddleware.client + ensure + PreludeContextMiddleware.contexts.clear + end + private def assert_call(call) diff --git a/test/redis_client_test.rb b/test/redis_client_test.rb index 8b1e361..0bdcf79 100644 --- a/test/redis_client_test.rb +++ b/test/redis_client_test.rb @@ -15,6 +15,70 @@ def test_preselect_database assert_includes client.call("CLIENT", "INFO"), " db=5 " end + def test_auth_timeout_applied_resp3 + capturing_driver = Class.new(RedisClient::RubyConnection) do + class << self + attr_accessor :last_timeouts + end + def call_pipelined(commands, timeouts, exception: true) # rubocop:disable Lint/UnusedMethodArgument + self.class.last_timeouts = timeouts + Array.new(commands.size, "OK") + end + end + client = new_client( + driver: capturing_driver, + username: "user", + password: "pass", + auth_timeout: 0.123, + protocol: 3, + ) + client.call("PING") + assert_equal [0.123], capturing_driver.last_timeouts + end + + def test_auth_timeout_applied_resp2 + capturing_driver = Class.new(RedisClient::RubyConnection) do + class << self + attr_accessor :last_timeouts + end + def call_pipelined(commands, timeouts, exception: true) # rubocop:disable Lint/UnusedMethodArgument + self.class.last_timeouts = timeouts + Array.new(commands.size, "OK") + end + end + client = new_client( + driver: capturing_driver, + username: "user", + password: "pass", + auth_timeout: 0.456, + protocol: 2, + ) + client.call("PING") + assert_equal [0.456], capturing_driver.last_timeouts + end + + def test_auth_timeout_only_applies_to_auth_commands + capturing_driver = Class.new(RedisClient::RubyConnection) do + class << self + attr_accessor :last_timeouts + end + def call_pipelined(commands, timeouts, exception: true) # rubocop:disable Lint/UnusedMethodArgument + self.class.last_timeouts = timeouts + Array.new(commands.size, "OK") + end + end + client = new_client( + driver: capturing_driver, + username: "user", + password: "pass", + db: 5, + auth_timeout: 0.789, + protocol: 2, + ) + client.call("PING") + assert_equal [0.789, nil], capturing_driver.last_timeouts + end + def test_set_client_id client = new_client(id: "peter") assert_includes client.call("CLIENT", "INFO"), " name=peter "