Skip to content

Commit 40e3af0

Browse files
(Redis Cluster) - Fixes for using redis cluster + pipeline (BerriAI#8442)
* update RedisCluster creation
* update RedisClusterCache
* add redis ClusterCache
* update async_set_cache_pipeline
* cleanup redis cluster usage
* fix redis pipeline
* test_init_async_client_returns_same_instance
* fix redis cluster
* update mypy_path
* fix init_redis_cluster
* remove stub
* test redis commit
* ClusterPipeline
* fix import
* RedisCluster import
* fix redis cluster
* Potential fix for code scanning alert no. 2129: Clear-text logging of sensitive information
  Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
* fix naming of redis cluster integration
* test_redis_caching_ttl_pipeline
* fix async_set_cache_pipeline
---------
Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
1 parent b710407 commit 40e3af0

File tree

7 files changed

+112
-27
lines changed

7 files changed

+112
-27
lines changed

litellm/_redis.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def init_redis_cluster(redis_kwargs) -> redis.RedisCluster:
183183
)
184184

185185
verbose_logger.debug(
186-
"init_redis_cluster: startup nodes: ", redis_kwargs["startup_nodes"]
186+
"init_redis_cluster: startup nodes are being initialized."
187187
)
188188
from redis.cluster import ClusterNode
189189

@@ -266,7 +266,9 @@ def get_redis_client(**env_overrides):
266266
return redis.Redis(**redis_kwargs)
267267

268268

269-
def get_redis_async_client(**env_overrides) -> async_redis.Redis:
269+
def get_redis_async_client(
270+
**env_overrides,
271+
) -> async_redis.Redis:
270272
redis_kwargs = _get_redis_client_logic(**env_overrides)
271273
if "url" in redis_kwargs and redis_kwargs["url"] is not None:
272274
args = _get_redis_url_kwargs(client=async_redis.Redis.from_url)

litellm/caching/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,6 @@
44
from .in_memory_cache import InMemoryCache
55
from .qdrant_semantic_cache import QdrantSemanticCache
66
from .redis_cache import RedisCache
7+
from .redis_cluster_cache import RedisClusterCache
78
from .redis_semantic_cache import RedisSemanticCache
89
from .s3_cache import S3Cache

litellm/caching/caching.py

+18-8
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from .in_memory_cache import InMemoryCache
4242
from .qdrant_semantic_cache import QdrantSemanticCache
4343
from .redis_cache import RedisCache
44+
from .redis_cluster_cache import RedisClusterCache
4445
from .redis_semantic_cache import RedisSemanticCache
4546
from .s3_cache import S3Cache
4647

@@ -158,14 +159,23 @@ def __init__(
158159
None. Cache is set as a litellm param
159160
"""
160161
if type == LiteLLMCacheType.REDIS:
161-
self.cache: BaseCache = RedisCache(
162-
host=host,
163-
port=port,
164-
password=password,
165-
redis_flush_size=redis_flush_size,
166-
startup_nodes=redis_startup_nodes,
167-
**kwargs,
168-
)
162+
if redis_startup_nodes:
163+
self.cache: BaseCache = RedisClusterCache(
164+
host=host,
165+
port=port,
166+
password=password,
167+
redis_flush_size=redis_flush_size,
168+
startup_nodes=redis_startup_nodes,
169+
**kwargs,
170+
)
171+
else:
172+
self.cache = RedisCache(
173+
host=host,
174+
port=port,
175+
password=password,
176+
redis_flush_size=redis_flush_size,
177+
**kwargs,
178+
)
169179
elif type == LiteLLMCacheType.REDIS_SEMANTIC:
170180
self.cache = RedisSemanticCache(
171181
host=host,

litellm/caching/redis_cache.py

+35-14
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import json
1515
import time
1616
from datetime import timedelta
17-
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
17+
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
1818

1919
import litellm
2020
from litellm._logging import print_verbose, verbose_logger
@@ -26,15 +26,20 @@
2626

2727
if TYPE_CHECKING:
2828
from opentelemetry.trace import Span as _Span
29-
from redis.asyncio import Redis
29+
from redis.asyncio import Redis, RedisCluster
3030
from redis.asyncio.client import Pipeline
31+
from redis.asyncio.cluster import ClusterPipeline
3132

3233
pipeline = Pipeline
34+
cluster_pipeline = ClusterPipeline
3335
async_redis_client = Redis
36+
async_redis_cluster_client = RedisCluster
3437
Span = _Span
3538
else:
3639
pipeline = Any
40+
cluster_pipeline = Any
3741
async_redis_client = Any
42+
async_redis_cluster_client = Any
3843
Span = Any
3944

4045

@@ -122,7 +127,9 @@ def __init__(
122127
else:
123128
super().__init__() # defaults to 60s
124129

125-
def init_async_client(self):
130+
def init_async_client(
131+
self,
132+
) -> Union[async_redis_client, async_redis_cluster_client]:
126133
from .._redis import get_redis_async_client
127134

128135
return get_redis_async_client(
@@ -345,8 +352,14 @@ async def async_set_cache(self, key, value, **kwargs):
345352
)
346353

347354
async def _pipeline_helper(
348-
self, pipe: pipeline, cache_list: List[Tuple[Any, Any]], ttl: Optional[float]
355+
self,
356+
pipe: Union[pipeline, cluster_pipeline],
357+
cache_list: List[Tuple[Any, Any]],
358+
ttl: Optional[float],
349359
) -> List:
360+
"""
361+
Helper function for executing a pipeline of set operations on Redis
362+
"""
350363
ttl = self.get_ttl(ttl=ttl)
351364
# Iterate through each key-value pair in the cache_list and set them in the pipeline.
352365
for cache_key, cache_value in cache_list:
@@ -359,7 +372,11 @@ async def _pipeline_helper(
359372
_td: Optional[timedelta] = None
360373
if ttl is not None:
361374
_td = timedelta(seconds=ttl)
362-
pipe.set(cache_key, json_cache_value, ex=_td)
375+
pipe.set( # type: ignore
376+
name=cache_key,
377+
value=json_cache_value,
378+
ex=_td,
379+
)
363380
# Execute the pipeline and return the results.
364381
results = await pipe.execute()
365382
return results
@@ -373,9 +390,8 @@ async def async_set_cache_pipeline(
373390
# don't waste a network request if there's nothing to set
374391
if len(cache_list) == 0:
375392
return
376-
from redis.asyncio import Redis
377393

378-
_redis_client: Redis = self.init_async_client() # type: ignore
394+
_redis_client = self.init_async_client()
379395
start_time = time.time()
380396

381397
print_verbose(
@@ -384,7 +400,7 @@ async def async_set_cache_pipeline(
384400
cache_value: Any = None
385401
try:
386402
async with _redis_client as redis_client:
387-
async with redis_client.pipeline(transaction=True) as pipe:
403+
async with redis_client.pipeline(transaction=False) as pipe:
388404
results = await self._pipeline_helper(pipe, cache_list, ttl)
389405

390406
print_verbose(f"pipeline results: {results}")
@@ -730,7 +746,8 @@ async def async_batch_get_cache(
730746
"""
731747
Use Redis for bulk read operations
732748
"""
733-
_redis_client = await self.init_async_client()
749+
# typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `mget`
750+
_redis_client: Any = self.init_async_client()
734751
key_value_dict = {}
735752
start_time = time.time()
736753
try:
@@ -822,7 +839,8 @@ def sync_ping(self) -> bool:
822839
raise e
823840

824841
async def ping(self) -> bool:
825-
_redis_client = self.init_async_client()
842+
# typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `ping`
843+
_redis_client: Any = self.init_async_client()
826844
start_time = time.time()
827845
async with _redis_client as redis_client:
828846
print_verbose("Pinging Async Redis Cache")
@@ -858,7 +876,8 @@ async def ping(self) -> bool:
858876
raise e
859877

860878
async def delete_cache_keys(self, keys):
861-
_redis_client = self.init_async_client()
879+
# typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `delete`
880+
_redis_client: Any = self.init_async_client()
862881
# keys is a list, unpack it so it gets passed as individual elements to delete
863882
async with _redis_client as redis_client:
864883
await redis_client.delete(*keys)
@@ -881,7 +900,8 @@ async def disconnect(self):
881900
await self.async_redis_conn_pool.disconnect(inuse_connections=True)
882901

883902
async def async_delete_cache(self, key: str):
884-
_redis_client = self.init_async_client()
903+
# typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `delete`
904+
_redis_client: Any = self.init_async_client()
885905
# keys is str
886906
async with _redis_client as redis_client:
887907
await redis_client.delete(key)
@@ -936,7 +956,7 @@ async def async_increment_pipeline(
936956

937957
try:
938958
async with _redis_client as redis_client:
939-
async with redis_client.pipeline(transaction=True) as pipe:
959+
async with redis_client.pipeline(transaction=False) as pipe:
940960
results = await self._pipeline_increment_helper(
941961
pipe, increment_list
942962
)
@@ -991,7 +1011,8 @@ async def async_get_ttl(self, key: str) -> Optional[int]:
9911011
Redis ref: https://redis.io/docs/latest/commands/ttl/
9921012
"""
9931013
try:
994-
_redis_client = await self.init_async_client()
1014+
# typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `ttl`
1015+
_redis_client: Any = self.init_async_client()
9951016
async with _redis_client as redis_client:
9961017
ttl = await redis_client.ttl(key)
9971018
if ttl <= -1: # -1 means the key exists but has no expiry, -2 means the key does not exist
+44
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
"""
2+
Redis Cluster Cache implementation
3+
4+
Key differences:
5+
- RedisClient NEEDS to be re-used across requests; re-creating it adds 3000ms of latency
6+
"""
7+
8+
from typing import TYPE_CHECKING, Any, Optional
9+
10+
from litellm.caching.redis_cache import RedisCache
11+
12+
if TYPE_CHECKING:
13+
from opentelemetry.trace import Span as _Span
14+
from redis.asyncio import Redis, RedisCluster
15+
from redis.asyncio.client import Pipeline
16+
17+
pipeline = Pipeline
18+
async_redis_client = Redis
19+
Span = _Span
20+
else:
21+
pipeline = Any
22+
async_redis_client = Any
23+
Span = Any
24+
25+
26+
class RedisClusterCache(RedisCache):
27+
def __init__(self, *args, **kwargs):
28+
super().__init__(*args, **kwargs)
29+
self.redis_cluster_client: Optional[RedisCluster] = None
30+
31+
def init_async_client(self):
32+
from redis.asyncio import RedisCluster
33+
34+
from .._redis import get_redis_async_client
35+
36+
if self.redis_cluster_client:
37+
return self.redis_cluster_client
38+
39+
_redis_client = get_redis_async_client(
40+
connection_pool=self.async_redis_conn_pool, **self.redis_kwargs
41+
)
42+
if isinstance(_redis_client, RedisCluster):
43+
self.redis_cluster_client = _redis_client
44+
return _redis_client

mypy.ini

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
[mypy]
22
warn_return_any = False
33
ignore_missing_imports = True
4+
mypy_path = litellm/stubs
45

56
[mypy-google.*]
67
ignore_missing_imports = True

tests/local_testing/test_caching.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121
import litellm
2222
from litellm import aembedding, completion, embedding
2323
from litellm.caching.caching import Cache
24-
24+
from redis.asyncio import RedisCluster
25+
from litellm.caching.redis_cluster_cache import RedisClusterCache
2526
from unittest.mock import AsyncMock, patch, MagicMock, call
2627
import datetime
2728
from datetime import timedelta
@@ -2328,8 +2329,12 @@ async def test_redis_caching_ttl_pipeline():
23282329
# Verify that the set method was called on the mock Redis instance
23292330
mock_set.assert_has_calls(
23302331
[
2331-
call.set("test_key1", '"test_value1"', ex=expected_timedelta),
2332-
call.set("test_key2", '"test_value2"', ex=expected_timedelta),
2332+
call.set(
2333+
name="test_key1", value='"test_value1"', ex=expected_timedelta
2334+
),
2335+
call.set(
2336+
name="test_key2", value='"test_value2"', ex=expected_timedelta
2337+
),
23332338
]
23342339
)
23352340

@@ -2388,6 +2393,7 @@ async def test_redis_increment_pipeline():
23882393
from litellm.caching.redis_cache import RedisCache
23892394

23902395
litellm.set_verbose = True
2396+
litellm._turn_on_debug()
23912397
redis_cache = RedisCache(
23922398
host=os.environ["REDIS_HOST"],
23932399
port=os.environ["REDIS_PORT"],

0 commit comments

Comments
 (0)