Commit fac3b2e

Add pyright to ci/cd + Fix remaining type-checking errors (BerriAI#6082)
* fix: fix type-checking errors
* fix: fix additional type-checking errors
* fix: additional type-checking error fixes
* fix: fix additional type-checking errors
* fix: additional type-check fixes
* fix: fix all type-checking errors + add pyright to ci/cd
* fix: fix incorrect import
* ci(config.yml): use mypy on ci/cd
* fix: fix type-checking errors in utils.py
* fix: fix all type-checking errors on main.py
* fix: fix mypy linting errors
* fix(anthropic/cost_calculator.py): fix linting errors
* fix: fix mypy linting errors
* fix: fix linting errors
1 parent f7ce117 commit fac3b2e


65 files changed: +620 −523 lines

.circleci/config.yml

+1 −1

@@ -315,10 +315,10 @@ jobs:
           python -m pip install --upgrade pip
           pip install ruff
           pip install pylint
+          pip install pyright
           pip install .
       - run: python -c "from litellm import *" || (echo '🚨 import failed, this means you introduced unprotected imports! 🚨'; exit 1)
       - run: ruff check ./litellm
-

   build_and_test:
     machine:
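
Most of the fixes in this commit reduce to one pattern: prove to pyright that an Optional value is not None before using it. A minimal sketch of the kind of error the new CI step catches, and the narrowing that fixes it (the function is illustrative, not from the codebase):

import os
from typing import Optional


def resolve_endpoint(override: Optional[str] = None) -> str:
    # os.getenv returns Optional[str]; returning it unchecked from a
    # function annotated `-> str` is a pyright error.
    endpoint = override or os.getenv("GENERIC_LOGGER_ENDPOINT")
    if endpoint is None:
        # After this guard, pyright narrows `endpoint` to `str`.
        raise ValueError("endpoint not set")
    return endpoint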

enterprise/enterprise_callbacks/generic_api_callback.py

+10 −4

@@ -8,7 +8,7 @@
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.caching import DualCache

-from typing import Literal, Union
+from typing import Literal, Union, Optional

 import traceback

@@ -26,19 +26,25 @@

 class GenericAPILogger:
     # Class variables or attributes
-    def __init__(self, endpoint=None, headers=None):
+    def __init__(self, endpoint: Optional[str] = None, headers: Optional[dict] = None):
         try:
-            if endpoint == None:
+            if endpoint is None:
                 # check env for "GENERIC_LOGGER_ENDPOINT"
                 if os.getenv("GENERIC_LOGGER_ENDPOINT"):
                     # Do something with the endpoint
                     endpoint = os.getenv("GENERIC_LOGGER_ENDPOINT")
                 else:
                     # Handle the case when the endpoint is not found in the environment variables
                     raise ValueError(
-                        f"endpoint not set for GenericAPILogger, GENERIC_LOGGER_ENDPOINT not found in environment variables"
+                        "endpoint not set for GenericAPILogger, GENERIC_LOGGER_ENDPOINT not found in environment variables"
                     )
             headers = headers or litellm.generic_logger_headers
+
+            if endpoint is None:
+                raise ValueError("endpoint not set for GenericAPILogger")
+            if headers is None:
+                raise ValueError("headers not set for GenericAPILogger")
+
             self.endpoint = endpoint
             self.headers = headers
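
With the constructor narrowed this way, misconfiguration surfaces as an immediate ValueError instead of a None-typed attribute leaking downstream. A hypothetical usage sketch (URL and token invented for illustration):

import os

os.environ["GENERIC_LOGGER_ENDPOINT"] = "https://logs.example.com/ingest"  # hypothetical

# The endpoint resolves from the env var; headers must be passed here or
# set via litellm.generic_logger_headers, otherwise __init__ raises.
logger = GenericAPILogger(headers={"Authorization": "Bearer <token>"})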

enterprise/enterprise_hooks/aporia_ai.py

+0 −2

@@ -48,8 +48,6 @@ def __init__(
         )
         self.aporia_api_key = api_key or os.environ["APORIO_API_KEY"]
         self.aporia_api_base = api_base or os.environ["APORIO_API_BASE"]
-        self.event_hook: GuardrailEventHooks
-
         super().__init__(**kwargs)

     #### CALL HOOKS - proxy only ####
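
The deleted line was an annotation-only declaration: it tells the checker that `event_hook` exists but binds no value, and the attribute is presumably already declared by the guardrail base class. A minimal illustration of why such declarations are fragile (class name invented):

class Guardrail:
    def __init__(self) -> None:
        # Annotation only: declares the attribute for the type checker,
        # but no value is bound at runtime.
        self.event_hook: str


g = Guardrail()
print(g.event_hook)  # AttributeError: the attribute was never assigned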

enterprise/enterprise_hooks/blocked_user_list.py

+1 −1

@@ -84,7 +84,7 @@ async def async_pre_call_hook(
         )

         cache_key = f"litellm:end_user_id:{user}"
-        end_user_cache_obj: LiteLLM_EndUserTable = cache.get_cache(
+        end_user_cache_obj: Optional[LiteLLM_EndUserTable] = cache.get_cache(  # type: ignore
             key=cache_key
         )
         if end_user_cache_obj is None and self.prisma_client is not None:

enterprise/enterprise_hooks/google_text_moderation.py

+3 −3

@@ -48,7 +48,7 @@ class _ENTERPRISE_GoogleTextModeration(CustomLogger):
     # Class variables or attributes
     def __init__(self):
         try:
-            from google.cloud import language_v1
+            from google.cloud import language_v1  # type: ignore
         except Exception:
             raise Exception(
                 "Missing google.cloud package. Run `pip install --upgrade google-cloud-language`"
@@ -57,8 +57,8 @@ def __init__(self):
         # Instantiates a client
         self.client = language_v1.LanguageServiceClient()
         self.moderate_text_request = language_v1.ModerateTextRequest
-        self.language_document = language_v1.types.Document
-        self.document_type = language_v1.types.Document.Type.PLAIN_TEXT
+        self.language_document = language_v1.types.Document  # type: ignore
+        self.document_type = language_v1.types.Document.Type.PLAIN_TEXT  # type: ignore

         default_confidence_threshold = (
             litellm.google_moderation_confidence_threshold or 0.8

enterprise/enterprise_hooks/llama_guard.py

+14 −3

@@ -8,6 +8,7 @@
 # Thank you users! We ❤️ you! - Krrish & Ishaan

 import sys, os
+from collections.abc import Iterable

 sys.path.insert(
     0, os.path.abspath("../..")
@@ -19,11 +20,12 @@
 from litellm.integrations.custom_logger import CustomLogger
 from fastapi import HTTPException
 from litellm._logging import verbose_proxy_logger
-from litellm.utils import (
+from litellm.types.utils import (
     ModelResponse,
     EmbeddingResponse,
     ImageResponse,
     StreamingChoices,
+    Choices,
 )
 from datetime import datetime
 import aiohttp, asyncio
@@ -34,7 +36,10 @@
 class _ENTERPRISE_LlamaGuard(CustomLogger):
     # Class variables or attributes
     def __init__(self, model_name: Optional[str] = None):
-        self.model = model_name or litellm.llamaguard_model_name
+        _model = model_name or litellm.llamaguard_model_name
+        if _model is None:
+            raise ValueError("model_name not set for LlamaGuard")
+        self.model = _model
         file_path = litellm.llamaguard_unsafe_content_categories
         data = None

@@ -124,7 +129,13 @@ async def async_moderation_hook(
             hf_model_name="meta-llama/LlamaGuard-7b",
         )

-        if "unsafe" in response.choices[0].message.content:
+        if (
+            isinstance(response, ModelResponse)
+            and isinstance(response.choices[0], Choices)
+            and response.choices[0].message.content is not None
+            and isinstance(response.choices[0].message.content, Iterable)
+            and "unsafe" in response.choices[0].message.content
+        ):
             raise HTTPException(
                 status_code=400, detail={"error": "Violated content safety policy"}
             )
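
The guard chain is long because completion() may return a streaming wrapper and choices[0] may be a StreamingChoices; each isinstance step narrows one layer. A sketch of the same checks factored into a helper (a hypothetical function, not part of the commit):

from typing import Optional

from litellm.types.utils import Choices, ModelResponse


def first_message_content(response: object) -> Optional[str]:
    # Narrow layer by layer so the type checker can follow:
    # non-streaming response -> non-streaming choice -> str content.
    if not isinstance(response, ModelResponse):
        return None
    choice = response.choices[0]
    if not isinstance(choice, Choices):
        return None
    content = choice.message.content
    return content if isinstance(content, str) else None

The moderation check would then read: content = first_message_content(response), followed by if content is not None and "unsafe" in content: raise HTTPException(...).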

enterprise/enterprise_hooks/llm_guard.py

+9 −3

@@ -8,7 +8,11 @@
 ## This provides an LLM Guard Integration for content moderation on the proxy

 from typing import Optional, Literal, Union
-import litellm, traceback, sys, uuid, os
+import litellm
+import traceback
+import sys
+import uuid
+import os
 from litellm.caching import DualCache
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.integrations.custom_logger import CustomLogger
@@ -21,8 +25,10 @@
     StreamingChoices,
 )
 from datetime import datetime
-import aiohttp, asyncio
+import aiohttp
+import asyncio
 from litellm.utils import get_formatted_prompt
+from litellm.secret_managers.main import get_secret_str

 litellm.set_verbose = True

@@ -38,7 +44,7 @@ def __init__(
         self.llm_guard_mode = litellm.llm_guard_mode
         if mock_testing == True:  # for testing purposes only
             return
-        self.llm_guard_api_base = litellm.get_secret("LLM_GUARD_API_BASE", None)
+        self.llm_guard_api_base = get_secret_str("LLM_GUARD_API_BASE", None)
         if self.llm_guard_api_base is None:
             raise Exception("Missing `LLM_GUARD_API_BASE` from environment")
         elif not self.llm_guard_api_base.endswith("/"):
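
The switch from litellm.get_secret to get_secret_str is presumably about return types: a helper annotated Optional[str] lets the later .endswith("/") call type-check once None is ruled out. A stand-in sketch of the pattern (this is not litellm's implementation):

import os
from typing import Optional


def get_secret_str_stub(name: str, default: Optional[str] = None) -> Optional[str]:
    # Returning Optional[str] (rather than Any) lets callers narrow
    # the result and then use str methods on it.
    return os.getenv(name, default)


api_base = get_secret_str_stub("LLM_GUARD_API_BASE")
if api_base is None:
    raise Exception("Missing `LLM_GUARD_API_BASE` from environment")
elif not api_base.endswith("/"):  # fine: api_base is str on this branch
    api_base = api_base + "/"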

enterprise/enterprise_hooks/openai_moderation.py

+2 −2

@@ -51,8 +51,8 @@ async def async_moderation_hook( ### 👈 KEY CHANGE ###
             "audio_transcription",
         ],
     ):
+        text = ""
         if "messages" in data and isinstance(data["messages"], list):
-            text = ""
             for m in data["messages"]:  # assume messages is a list
                 if "content" in m and isinstance(m["content"], str):
                     text += m["content"]
@@ -67,7 +67,7 @@ async def async_moderation_hook( ### 👈 KEY CHANGE ###
         )

         verbose_proxy_logger.debug("Moderation response: %s", moderation_response)
-        if moderation_response.results[0].flagged == True:
+        if moderation_response.results[0].flagged is True:
             raise HTTPException(
                 status_code=403, detail={"error": "Violated content safety policy"}
             )
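
Hoisting text = "" above the conditional fixes a classic "possibly unbound" report: if the messages branch never runs, a later use of text has no binding on that path. A stripped-down version of the fixed shape (function name invented):

def collect_text(data: dict) -> str:
    text = ""  # bind unconditionally so every path sees a str
    # Without the line above, pyright reports `"text" is possibly unbound`
    # at the return, since the branch below may be skipped.
    if "messages" in data and isinstance(data["messages"], list):
        for m in data["messages"]:
            if isinstance(m, dict) and isinstance(m.get("content"), str):
                text += m["content"]
    return text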

enterprise/utils.py

+3 −1

@@ -6,7 +6,9 @@
 from datetime import datetime


-async def get_spend_by_tags(start_date=None, end_date=None, prisma_client=None):
+async def get_spend_by_tags(
+    prisma_client: PrismaClient, start_date=None, end_date=None
+):
     response = await prisma_client.db.query_raw(
         """
         SELECT

litellm/_redis.py

+1 −1

@@ -191,7 +191,7 @@ def init_redis_cluster(redis_kwargs) -> redis.RedisCluster:
         new_startup_nodes.append(ClusterNode(**item))

     redis_kwargs.pop("startup_nodes")
-    return redis.RedisCluster(startup_nodes=new_startup_nodes, **cluster_kwargs)
+    return redis.RedisCluster(startup_nodes=new_startup_nodes, **cluster_kwargs)  # type: ignore


 def _init_redis_sentinel(redis_kwargs) -> redis.Redis:

litellm/assistants/utils.py

+4 −1

@@ -1,5 +1,8 @@
-import litellm
 from typing import Optional, Union
+
+import litellm
+
+from ..exceptions import UnsupportedParamsError
 from ..types.llms.openai import *


litellm/caching.py

+11 −8

@@ -10,6 +10,7 @@
 import ast
 import asyncio
 import hashlib
+import inspect
 import io
 import json
 import logging
@@ -245,7 +246,8 @@ def __init__(
         self.redis_flush_size = redis_flush_size
         self.redis_version = "Unknown"
         try:
-            self.redis_version = self.redis_client.info()["redis_version"]
+            if not inspect.iscoroutinefunction(self.redis_client):
+                self.redis_version = self.redis_client.info()["redis_version"]  # type: ignore
         except Exception:
             pass

@@ -266,7 +268,8 @@ def __init__(

         ### SYNC HEALTH PING ###
         try:
-            self.redis_client.ping()
+            if hasattr(self.redis_client, "ping"):
+                self.redis_client.ping()  # type: ignore
         except Exception as e:
             verbose_logger.error(
                 "Error connecting to Sync Redis client", extra={"error": str(e)}
@@ -308,7 +311,7 @@ def increment_cache(
         _redis_client = self.redis_client
         start_time = time.time()
         try:
-            result = _redis_client.incr(name=key, amount=value)
+            result: int = _redis_client.incr(name=key, amount=value)  # type: ignore

             if ttl is not None:
                 # check if key already has ttl, if not -> set ttl
@@ -561,7 +564,7 @@ async def async_set_cache_sadd(
                 f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
             )
             try:
-                await redis_client.sadd(key, *value)
+                await redis_client.sadd(key, *value)  # type: ignore
                 if ttl is not None:
                     _td = timedelta(seconds=ttl)
                     await redis_client.expire(key, _td)
@@ -712,7 +715,7 @@ def batch_get_cache(self, key_list) -> dict:
         for cache_key in key_list:
             cache_key = self.check_and_fix_namespace(key=cache_key)
             _keys.append(cache_key)
-        results = self.redis_client.mget(keys=_keys)
+        results: List = self.redis_client.mget(keys=_keys)  # type: ignore

         # Associate the results back with their keys.
         # 'results' is a list of values corresponding to the order of keys in 'key_list'.
@@ -842,7 +845,7 @@ def sync_ping(self) -> bool:
         print_verbose("Pinging Sync Redis Cache")
         start_time = time.time()
         try:
-            response = self.redis_client.ping()
+            response: bool = self.redis_client.ping()  # type: ignore
             print_verbose(f"Redis Cache PING: {response}")
             ## LOGGING ##
             end_time = time.time()
@@ -911,8 +914,8 @@ async def delete_cache_keys(self, keys):
         async with _redis_client as redis_client:
             await redis_client.delete(*keys)

-    def client_list(self):
-        client_list = self.redis_client.client_list()
+    def client_list(self) -> List:
+        client_list: List = self.redis_client.client_list()  # type: ignore
         return client_list

     def info(self):
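
Several of these spots exist because the cache may hold either a sync redis.Redis or an async redis.asyncio.Redis client, so no single call signature type-checks; the commit guards at runtime and silences the rest with # type: ignore. A hedged sketch of that guard pattern in isolation (helper name invented):

import inspect


def redis_version_of(client: object) -> str:
    # Mirror the commit's guard: only call the blocking info() when the
    # client exposes it as a regular (non-coroutine) method.
    version = "Unknown"
    try:
        info = getattr(client, "info", None)
        if info is not None and not inspect.iscoroutinefunction(info):
            version = info()["redis_version"]
    except Exception:
        pass
    return version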

litellm/cost_calculator.py

+1 −1

@@ -39,8 +39,8 @@
 )
 from litellm.llms.OpenAI.cost_calculation import cost_per_token as openai_cost_per_token
 from litellm.llms.together_ai.cost_calculator import get_model_params_and_category
-from litellm.rerank_api.types import RerankResponse
 from litellm.types.llms.openai import HttpxBinaryResponseContent
+from litellm.types.rerank import RerankResponse
 from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS
 from litellm.types.utils import PassthroughCallTypes, Usage
 from litellm.utils import (

litellm/integrations/SlackAlerting/slack_alerting.py

+1 −1

@@ -39,11 +39,11 @@
     VirtualKeyEvent,
     WebhookEvent,
 )
+from litellm.types.integrations.slack_alerting import *
 from litellm.types.router import LiteLLM_Params

 from ..email_templates.templates import *
 from .batching_handler import send_to_webhook, squash_payloads
-from .types import *
 from .utils import _add_langfuse_trace_id_to_alert, process_slack_alerting_variables

litellm/integrations/custom_logger.py

+2 −2

@@ -172,14 +172,14 @@ async def async_moderation_hook(
             "moderation",
             "audio_transcription",
         ],
-    ):
+    ) -> Any:
         pass

     async def async_post_call_streaming_hook(
         self,
         user_api_key_dict: UserAPIKeyAuth,
         response: str,
-    ):
+    ) -> Any:
         pass

     #### SINGLE-USE #### - https://docs.litellm.ai/docs/observability/custom_callback#using-your-custom-callback-function
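
Annotating these base hooks -> Any rather than leaving the return implicit is a small loosening: a bare pass body would be inferred as -> None, which constrains overrides that return a value. A hedged illustration of the inferred motivation (class names invented; the diff itself does not state the reason):

from typing import Any


class BaseHook:
    async def async_post_call_streaming_hook(self, response: str) -> Any:
        # Returning Any keeps overrides that return a transformed
        # response compatible with the base signature.
        pass


class RedactingHook(BaseHook):
    async def async_post_call_streaming_hook(self, response: str) -> str:
        return response.replace("sk-", "sk-***")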

litellm/integrations/email_alerting.py

+4 −2

@@ -42,8 +42,10 @@ async def get_all_team_member_emails(team_id: Optional[str] = None) -> list:
     )
     _team_member_user_ids: List[str] = []
     for member in _team_members:
-        if member and isinstance(member, dict) and member.get("user_id") is not None:
-            _team_member_user_ids.append(member.get("user_id"))
+        if member and isinstance(member, dict):
+            _user_id = member.get("user_id")
+            if _user_id and isinstance(_user_id, str):
+                _team_member_user_ids.append(_user_id)

     sql_query = """
     SELECT user_email
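
The two-step rewrite matters because type checkers do not narrow across separate .get() calls: checking member.get("user_id") is not None says nothing about the fresh Optional returned by the next member.get("user_id"). Binding once and narrowing the local is the shape pyright can follow. A small repro:

from typing import List

member = {"user_id": "u-123"}
ids: List[str] = []

# Rejected by pyright: the append argument is a new Optional[str]
# expression, unrelated (to the checker) to the earlier None test.
#   if member.get("user_id") is not None:
#       ids.append(member.get("user_id"))

_user_id = member.get("user_id")
if _user_id is not None:  # narrows the local binding itself
    ids.append(_user_id)  # OK: _user_id is str here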

litellm/integrations/lunary.py

+2 −2

@@ -149,7 +149,7 @@ def log_event(
         else:
             error_obj = None

-        self.lunary_client.track_event(
+        self.lunary_client.track_event(  # type: ignore
             type,
             "start",
             run_id,
@@ -164,7 +164,7 @@ def log_event(
             params=extra,
         )

-        self.lunary_client.track_event(
+        self.lunary_client.track_event(  # type: ignore
             type,
             event,
             run_id,
