Skip to content

Commit 46016a1

Browse files
Fix Pydantic Models to support V1 and V2 of Pydantic and fix deprecation warnings (mlflow#14535)
Signed-off-by: Nathan Brake <[email protected]> Signed-off-by: Nathan Brake <[email protected]> Co-authored-by: Yair Bonastre <[email protected]>
1 parent 2f4afbb commit 46016a1

File tree

6 files changed

+93
-81
lines changed

6 files changed

+93
-81
lines changed

docs/docs/llms/deployments/index.mdx

+2-2
Original file line numberDiff line numberDiff line change
@@ -1469,7 +1469,7 @@ and a config class that inherits from `mlflow.gateway.base_models.ConfigModel`.
14691469
import os
14701470
from typing import AsyncIterable
14711471

1472-
from pydantic import validator
1472+
from mlflow.utils.pydantic_utils import field_validator
14731473
from mlflow.gateway.base_models import ConfigModel
14741474
from mlflow.gateway.config import RouteConfig
14751475
from mlflow.gateway.providers import BaseProvider
@@ -1480,7 +1480,7 @@ class MyLLMConfig(ConfigModel):
14801480
# This model defines the configuration for the provider such as API keys
14811481
my_llm_api_key: str
14821482

1483-
@validator("my_llm_api_key", pre=True)
1483+
@field_validator("my_llm_api_key", mode="before")
14841484
def validate_my_llm_api_key(cls, value):
14851485
return os.environ[value.lstrip("$")]
14861486

examples/gateway/plugin/my-llm/my_llm/config.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
import os
22

3-
from pydantic import validator
4-
53
from mlflow.gateway.base_models import ConfigModel
4+
from mlflow.utils.pydantic_utils import field_validator
65

76

87
class MyLLMConfig(ConfigModel):
98
my_llm_api_key: str
109

11-
@validator("my_llm_api_key", pre=True)
10+
@field_validator("my_llm_api_key", mode="before")
1211
def validate_my_llm_api_key(cls, value):
1312
if value.startswith("$"):
1413
# This resolves the API key from an environment variable

mlflow/gateway/config.py

+38-46
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import pydantic
1010
import yaml
1111
from packaging.version import Version
12-
from pydantic import ConfigDict, Field, ValidationError, root_validator, validator
12+
from pydantic import ConfigDict, Field, ValidationError
1313
from pydantic.json import pydantic_encoder
1414

1515
from mlflow.exceptions import MlflowException
@@ -26,7 +26,7 @@
2626
is_valid_endpoint_name,
2727
is_valid_mosiacml_chat_model,
2828
)
29-
from mlflow.utils import IS_PYDANTIC_V2_OR_NEWER
29+
from mlflow.utils.pydantic_utils import IS_PYDANTIC_V2_OR_NEWER, field_validator, model_validator
3030

3131
_logger = logging.getLogger(__name__)
3232

@@ -59,7 +59,7 @@ def values(cls):
5959
class TogetherAIConfig(ConfigModel):
6060
togetherai_api_key: str
6161

62-
@validator("togetherai_api_key", pre=True)
62+
@field_validator("togetherai_api_key", mode="before")
6363
def validate_togetherai_api_key(cls, value):
6464
return _resolve_api_key_from_input(value)
6565

@@ -73,15 +73,15 @@ class RouteType(str, Enum):
7373
class CohereConfig(ConfigModel):
7474
cohere_api_key: str
7575

76-
@validator("cohere_api_key", pre=True)
76+
@field_validator("cohere_api_key", mode="before")
7777
def validate_cohere_api_key(cls, value):
7878
return _resolve_api_key_from_input(value)
7979

8080

8181
class AI21LabsConfig(ConfigModel):
8282
ai21labs_api_key: str
8383

84-
@validator("ai21labs_api_key", pre=True)
84+
@field_validator("ai21labs_api_key", mode="before")
8585
def validate_ai21labs_api_key(cls, value):
8686
return _resolve_api_key_from_input(value)
8787

@@ -90,7 +90,7 @@ class MosaicMLConfig(ConfigModel):
9090
mosaicml_api_key: str
9191
mosaicml_api_base: Optional[str] = None
9292

93-
@validator("mosaicml_api_key", pre=True)
93+
@field_validator("mosaicml_api_key", mode="before")
9494
def validate_mosaicml_api_key(cls, value):
9595
return _resolve_api_key_from_input(value)
9696

@@ -120,7 +120,7 @@ class OpenAIConfig(ConfigModel):
120120
openai_deployment_name: Optional[str] = None
121121
openai_organization: Optional[str] = None
122122

123-
@validator("openai_api_key", pre=True)
123+
@field_validator("openai_api_key", mode="before")
124124
def validate_openai_api_key(cls, value):
125125
return _resolve_api_key_from_input(value)
126126

@@ -158,34 +158,24 @@ def _validate_field_compatibility(cls, info: dict[str, Any]):
158158

159159
return info
160160

161-
if IS_PYDANTIC_V2_OR_NEWER:
162-
from pydantic import model_validator as _model_validator
163-
164-
@_model_validator(mode="before")
165-
def validate_field_compatibility(cls, info: dict[str, Any]):
166-
return cls._validate_field_compatibility(info)
167-
168-
else:
169-
from pydantic import root_validator as _root_validator
170-
171-
@_root_validator(pre=False)
172-
def validate_field_compatibility(cls, config: dict[str, Any]):
173-
return cls._validate_field_compatibility(config)
161+
@model_validator(mode="before")
162+
def validate_field_compatibility(cls, info: dict[str, Any]):
163+
return cls._validate_field_compatibility(info)
174164

175165

176166
class AnthropicConfig(ConfigModel):
177167
anthropic_api_key: str
178168
anthropic_version: str = "2023-06-01"
179169

180-
@validator("anthropic_api_key", pre=True)
170+
@field_validator("anthropic_api_key", mode="before")
181171
def validate_anthropic_api_key(cls, value):
182172
return _resolve_api_key_from_input(value)
183173

184174

185175
class PaLMConfig(ConfigModel):
186176
palm_api_key: str
187177

188-
@validator("palm_api_key", pre=True)
178+
@field_validator("palm_api_key", mode="before")
189179
def validate_palm_api_key(cls, value):
190180
return _resolve_api_key_from_input(value)
191181

@@ -225,7 +215,7 @@ class AmazonBedrockConfig(ConfigModel):
225215
class MistralConfig(ConfigModel):
226216
mistral_api_key: str
227217

228-
@validator("mistral_api_key", pre=True)
218+
@field_validator("mistral_api_key", mode="before")
229219
def validate_mistral_api_key(cls, value):
230220
return _resolve_api_key_from_input(value)
231221

@@ -284,7 +274,7 @@ class Model(ConfigModel):
284274
else:
285275
config: Optional[ConfigModel] = None
286276

287-
@validator("provider", pre=True)
277+
@field_validator("provider", mode="before")
288278
def validate_provider(cls, value):
289279
from mlflow.gateway.provider_registry import provider_registry
290280

@@ -298,28 +288,26 @@ def validate_provider(cls, value):
298288
raise MlflowException.invalid_parameter_value(f"The provider '{value}' is not supported.")
299289

300290
@classmethod
301-
def _validate_config(cls, info, values):
291+
def _validate_config(cls, val, context):
302292
from mlflow.gateway.provider_registry import provider_registry
303293

304-
if provider := values.get("provider"):
305-
config_type = provider_registry.get(provider).CONFIG_TYPE
306-
return config_type(**info)
294+
# For Pydantic v2: 'context' is a ValidationInfo object with a 'data' attribute.
295+
# For Pydantic v1: 'context' is dict-like 'values'.
296+
if IS_PYDANTIC_V2_OR_NEWER:
297+
provider = context.data.get("provider")
298+
else:
299+
provider = context.get("provider") if context else None
307300

301+
if provider:
302+
config_type = provider_registry.get(provider).CONFIG_TYPE
303+
return config_type(**val) if isinstance(val, dict) else val
308304
raise MlflowException.invalid_parameter_value(
309305
"A provider must be provided for each gateway route."
310306
)
311307

312-
if IS_PYDANTIC_V2_OR_NEWER:
313-
314-
@validator("config", pre=True)
315-
def validate_config(cls, info, values):
316-
return cls._validate_config(info, values)
317-
318-
else:
319-
320-
@validator("config", pre=True)
321-
def validate_config(cls, config, values):
322-
return cls._validate_config(config, values)
308+
@field_validator("config", mode="before")
309+
def validate_config(cls, info, values):
310+
return cls._validate_config(info, values)
323311

324312

325313
class AliasedConfigModel(ConfigModel):
@@ -351,7 +339,7 @@ class RouteConfig(AliasedConfigModel):
351339
model: Model
352340
limit: Optional[Limit] = None
353341

354-
@validator("name")
342+
@field_validator("name")
355343
def validate_endpoint_name(cls, route_name):
356344
if not is_valid_endpoint_name(route_name):
357345
raise MlflowException.invalid_parameter_value(
@@ -361,7 +349,7 @@ def validate_endpoint_name(cls, route_name):
361349
)
362350
return route_name
363351

364-
@validator("model", pre=True)
352+
@field_validator("model", mode="before")
365353
def validate_model(cls, model):
366354
if model:
367355
model_instance = Model(**model)
@@ -372,10 +360,14 @@ def validate_model(cls, model):
372360
)
373361
return model
374362

375-
@root_validator(skip_on_failure=True)
363+
@model_validator(mode="after", skip_on_failure=True)
376364
def validate_route_type_and_model_name(cls, values):
377-
route_type = values.get("route_type")
378-
model = values.get("model")
365+
if IS_PYDANTIC_V2_OR_NEWER:
366+
route_type = values.route_type
367+
model = values.model
368+
else:
369+
route_type = values.get("route_type")
370+
model = values.get("model")
379371
if (
380372
model
381373
and model.provider == "mosaicml"
@@ -394,13 +386,13 @@ def validate_route_type_and_model_name(cls, values):
394386
)
395387
return values
396388

397-
@validator("route_type", pre=True)
389+
@field_validator("route_type", mode="before")
398390
def validate_route_type(cls, value):
399391
if value in RouteType._value2member_map_:
400392
return value
401393
raise MlflowException.invalid_parameter_value(f"The route_type '{value}' is not supported.")
402394

403-
@validator("limit", pre=True)
395+
@field_validator("limit", mode="before")
404396
def validate_limit(cls, value):
405397
from limits import parse
406398

mlflow/gateway/providers/mlflow.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
11
import time
22

3-
from pydantic import BaseModel, StrictFloat, StrictStr, ValidationError, validator
3+
from pydantic import BaseModel, StrictFloat, StrictStr, ValidationError
44

55
from mlflow.gateway.config import MlflowModelServingConfig, RouteConfig
66
from mlflow.gateway.constants import MLFLOW_SERVING_RESPONSE_KEY
77
from mlflow.gateway.exceptions import AIGatewayException
88
from mlflow.gateway.providers.base import BaseProvider
99
from mlflow.gateway.providers.utils import send_request
1010
from mlflow.gateway.schemas import chat, completions, embeddings
11+
from mlflow.utils.pydantic_utils import field_validator
1112

1213

1314
class ServingTextResponse(BaseModel):
1415
predictions: list[StrictStr]
1516

16-
@validator("predictions", pre=True)
17+
@field_validator("predictions", mode="before")
1718
def extract_choices(cls, predictions):
1819
if isinstance(predictions, list) and not predictions:
1920
raise ValueError("The input list is empty")
@@ -35,7 +36,7 @@ def extract_choices(cls, predictions):
3536
class EmbeddingsResponse(BaseModel):
3637
predictions: list[list[StrictFloat]]
3738

38-
@validator("predictions", pre=True)
39+
@field_validator("predictions", mode="before")
3940
def validate_predictions(cls, predictions):
4041
if isinstance(predictions, list) and not predictions:
4142
raise ValueError("The input list is empty")

mlflow/types/agent.py

+14-26
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,7 @@
1818
Property,
1919
Schema,
2020
)
21-
from mlflow.utils import IS_PYDANTIC_V2_OR_NEWER
22-
23-
if IS_PYDANTIC_V2_OR_NEWER:
24-
from pydantic import model_validator
25-
else:
26-
from pydantic import root_validator
21+
from mlflow.utils.pydantic_utils import IS_PYDANTIC_V2_OR_NEWER, model_validator
2722

2823

2924
class ChatAgentMessage(BaseModel):
@@ -53,28 +48,21 @@ class ChatAgentMessage(BaseModel):
5348
# TODO make this a pydantic class with subtypes once we have more details on usage
5449
attachments: Optional[dict[str, str]] = None
5550

56-
if IS_PYDANTIC_V2_OR_NEWER:
57-
58-
@model_validator(mode="after")
59-
def check_content_and_tool_calls(cls, chat_agent_msg):
60-
"""
61-
Ensure at least one of 'content' or 'tool_calls' is set.
62-
"""
63-
if not chat_agent_msg.content and not chat_agent_msg.tool_calls:
64-
raise ValueError("Either 'content' or 'tool_calls' must be provided.")
65-
return chat_agent_msg
66-
else:
67-
68-
@root_validator
69-
def check_content_and_tool_calls(cls, values):
70-
"""
71-
Ensure at least one of 'content' or 'tool_calls' is set.
72-
"""
51+
@model_validator(mode="after")
52+
def check_content_and_tool_calls(cls, values):
53+
"""
54+
Ensure at least one of 'content' or 'tool_calls' is set.
55+
"""
56+
if IS_PYDANTIC_V2_OR_NEWER:
57+
content = values.content
58+
tool_calls = values.tool_calls
59+
else:
7360
content = values.get("content")
7461
tool_calls = values.get("tool_calls")
75-
if not content and not tool_calls:
76-
raise ValueError("Either 'content' or 'tool_calls' must be provided.")
77-
return values
62+
63+
if not content and not tool_calls:
64+
raise ValueError("Either 'content' or 'tool_calls' must be provided.")
65+
return values
7866

7967

8068
class ChatContext(BaseModel):

mlflow/utils/pydantic_utils.py

+33-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any
1+
from typing import Any, Callable
22

33
import pydantic
44
from packaging.version import Version
@@ -7,6 +7,38 @@
77
IS_PYDANTIC_V2_OR_NEWER = Version(pydantic.VERSION).major >= 2
88

99

10+
def field_validator(field: str, mode: str = "before"):
11+
def decorator(func: Callable) -> Callable:
12+
if IS_PYDANTIC_V2_OR_NEWER:
13+
from pydantic import field_validator as pydantic_field_validator
14+
15+
return pydantic_field_validator(field, mode=mode)(func)
16+
else:
17+
from pydantic import validator as pydantic_field_validator
18+
19+
return pydantic_field_validator(field, pre=mode == "before")(func)
20+
21+
return decorator
22+
23+
24+
def model_validator(mode: str, skip_on_failure: bool = False):
25+
"""A wrapper for Pydantic model validator that is compatible with Pydantic v1 and v2.
26+
Note that the `skip_on_failure` argument is only available in Pydantic v1.
27+
"""
28+
29+
def decorator(func: Callable) -> Callable:
30+
if IS_PYDANTIC_V2_OR_NEWER:
31+
from pydantic import model_validator as pydantic_model_validator
32+
33+
return pydantic_model_validator(mode=mode)(func)
34+
else:
35+
from pydantic import root_validator
36+
37+
return root_validator(pre=mode == "before", skip_on_failure=skip_on_failure)(func)
38+
39+
return decorator
40+
41+
1042
def model_dump_compat(pydantic_model: BaseModel, **kwargs: Any) -> dict[str, Any]:
1143
"""
1244
Dump the Pydantic model to dictionary, in a compatible way for Pydantic v1 and v2.

0 commit comments

Comments
 (0)