9
9
import pydantic
10
10
import yaml
11
11
from packaging .version import Version
12
- from pydantic import ConfigDict , Field , ValidationError , root_validator , validator
12
+ from pydantic import ConfigDict , Field , ValidationError
13
13
from pydantic .json import pydantic_encoder
14
14
15
15
from mlflow .exceptions import MlflowException
26
26
is_valid_endpoint_name ,
27
27
is_valid_mosiacml_chat_model ,
28
28
)
29
- from mlflow .utils import IS_PYDANTIC_V2_OR_NEWER
29
+ from mlflow .utils . pydantic_utils import IS_PYDANTIC_V2_OR_NEWER , field_validator , model_validator
30
30
31
31
_logger = logging .getLogger (__name__ )
32
32
@@ -59,7 +59,7 @@ def values(cls):
59
59
class TogetherAIConfig (ConfigModel ):
60
60
togetherai_api_key : str
61
61
62
- @validator ("togetherai_api_key" , pre = True )
62
+ @field_validator ("togetherai_api_key" , mode = "before" )
63
63
def validate_togetherai_api_key (cls , value ):
64
64
return _resolve_api_key_from_input (value )
65
65
@@ -73,15 +73,15 @@ class RouteType(str, Enum):
73
73
class CohereConfig (ConfigModel ):
74
74
cohere_api_key : str
75
75
76
- @validator ("cohere_api_key" , pre = True )
76
+ @field_validator ("cohere_api_key" , mode = "before" )
77
77
def validate_cohere_api_key (cls , value ):
78
78
return _resolve_api_key_from_input (value )
79
79
80
80
81
81
class AI21LabsConfig (ConfigModel ):
82
82
ai21labs_api_key : str
83
83
84
- @validator ("ai21labs_api_key" , pre = True )
84
+ @field_validator ("ai21labs_api_key" , mode = "before" )
85
85
def validate_ai21labs_api_key (cls , value ):
86
86
return _resolve_api_key_from_input (value )
87
87
@@ -90,7 +90,7 @@ class MosaicMLConfig(ConfigModel):
90
90
mosaicml_api_key : str
91
91
mosaicml_api_base : Optional [str ] = None
92
92
93
- @validator ("mosaicml_api_key" , pre = True )
93
+ @field_validator ("mosaicml_api_key" , mode = "before" )
94
94
def validate_mosaicml_api_key (cls , value ):
95
95
return _resolve_api_key_from_input (value )
96
96
@@ -120,7 +120,7 @@ class OpenAIConfig(ConfigModel):
120
120
openai_deployment_name : Optional [str ] = None
121
121
openai_organization : Optional [str ] = None
122
122
123
- @validator ("openai_api_key" , pre = True )
123
+ @field_validator ("openai_api_key" , mode = "before" )
124
124
def validate_openai_api_key (cls , value ):
125
125
return _resolve_api_key_from_input (value )
126
126
@@ -158,34 +158,24 @@ def _validate_field_compatibility(cls, info: dict[str, Any]):
158
158
159
159
return info
160
160
161
- if IS_PYDANTIC_V2_OR_NEWER :
162
- from pydantic import model_validator as _model_validator
163
-
164
- @_model_validator (mode = "before" )
165
- def validate_field_compatibility (cls , info : dict [str , Any ]):
166
- return cls ._validate_field_compatibility (info )
167
-
168
- else :
169
- from pydantic import root_validator as _root_validator
170
-
171
- @_root_validator (pre = False )
172
- def validate_field_compatibility (cls , config : dict [str , Any ]):
173
- return cls ._validate_field_compatibility (config )
161
+ @model_validator (mode = "before" )
162
+ def validate_field_compatibility (cls , info : dict [str , Any ]):
163
+ return cls ._validate_field_compatibility (info )
174
164
175
165
176
166
class AnthropicConfig (ConfigModel ):
177
167
anthropic_api_key : str
178
168
anthropic_version : str = "2023-06-01"
179
169
180
- @validator ("anthropic_api_key" , pre = True )
170
+ @field_validator ("anthropic_api_key" , mode = "before" )
181
171
def validate_anthropic_api_key (cls , value ):
182
172
return _resolve_api_key_from_input (value )
183
173
184
174
185
175
class PaLMConfig (ConfigModel ):
186
176
palm_api_key : str
187
177
188
- @validator ("palm_api_key" , pre = True )
178
+ @field_validator ("palm_api_key" , mode = "before" )
189
179
def validate_palm_api_key (cls , value ):
190
180
return _resolve_api_key_from_input (value )
191
181
@@ -225,7 +215,7 @@ class AmazonBedrockConfig(ConfigModel):
225
215
class MistralConfig (ConfigModel ):
226
216
mistral_api_key : str
227
217
228
- @validator ("mistral_api_key" , pre = True )
218
+ @field_validator ("mistral_api_key" , mode = "before" )
229
219
def validate_mistral_api_key (cls , value ):
230
220
return _resolve_api_key_from_input (value )
231
221
@@ -284,7 +274,7 @@ class Model(ConfigModel):
284
274
else :
285
275
config : Optional [ConfigModel ] = None
286
276
287
- @validator ("provider" , pre = True )
277
+ @field_validator ("provider" , mode = "before" )
288
278
def validate_provider (cls , value ):
289
279
from mlflow .gateway .provider_registry import provider_registry
290
280
@@ -298,28 +288,26 @@ def validate_provider(cls, value):
298
288
raise MlflowException .invalid_parameter_value (f"The provider '{ value } ' is not supported." )
299
289
300
290
@classmethod
301
- def _validate_config (cls , info , values ):
291
+ def _validate_config (cls , val , context ):
302
292
from mlflow .gateway .provider_registry import provider_registry
303
293
304
- if provider := values .get ("provider" ):
305
- config_type = provider_registry .get (provider ).CONFIG_TYPE
306
- return config_type (** info )
294
+ # For Pydantic v2: 'context' is a ValidationInfo object with a 'data' attribute.
295
+ # For Pydantic v1: 'context' is dict-like 'values'.
296
+ if IS_PYDANTIC_V2_OR_NEWER :
297
+ provider = context .data .get ("provider" )
298
+ else :
299
+ provider = context .get ("provider" ) if context else None
307
300
301
+ if provider :
302
+ config_type = provider_registry .get (provider ).CONFIG_TYPE
303
+ return config_type (** val ) if isinstance (val , dict ) else val
308
304
raise MlflowException .invalid_parameter_value (
309
305
"A provider must be provided for each gateway route."
310
306
)
311
307
312
- if IS_PYDANTIC_V2_OR_NEWER :
313
-
314
- @validator ("config" , pre = True )
315
- def validate_config (cls , info , values ):
316
- return cls ._validate_config (info , values )
317
-
318
- else :
319
-
320
- @validator ("config" , pre = True )
321
- def validate_config (cls , config , values ):
322
- return cls ._validate_config (config , values )
308
+ @field_validator ("config" , mode = "before" )
309
+ def validate_config (cls , info , values ):
310
+ return cls ._validate_config (info , values )
323
311
324
312
325
313
class AliasedConfigModel (ConfigModel ):
@@ -351,7 +339,7 @@ class RouteConfig(AliasedConfigModel):
351
339
model : Model
352
340
limit : Optional [Limit ] = None
353
341
354
- @validator ("name" )
342
+ @field_validator ("name" )
355
343
def validate_endpoint_name (cls , route_name ):
356
344
if not is_valid_endpoint_name (route_name ):
357
345
raise MlflowException .invalid_parameter_value (
@@ -361,7 +349,7 @@ def validate_endpoint_name(cls, route_name):
361
349
)
362
350
return route_name
363
351
364
- @validator ("model" , pre = True )
352
+ @field_validator ("model" , mode = "before" )
365
353
def validate_model (cls , model ):
366
354
if model :
367
355
model_instance = Model (** model )
@@ -372,10 +360,14 @@ def validate_model(cls, model):
372
360
)
373
361
return model
374
362
375
- @root_validator ( skip_on_failure = True )
363
+ @model_validator ( mode = "after" , skip_on_failure = True )
376
364
def validate_route_type_and_model_name (cls , values ):
377
- route_type = values .get ("route_type" )
378
- model = values .get ("model" )
365
+ if IS_PYDANTIC_V2_OR_NEWER :
366
+ route_type = values .route_type
367
+ model = values .model
368
+ else :
369
+ route_type = values .get ("route_type" )
370
+ model = values .get ("model" )
379
371
if (
380
372
model
381
373
and model .provider == "mosaicml"
@@ -394,13 +386,13 @@ def validate_route_type_and_model_name(cls, values):
394
386
)
395
387
return values
396
388
397
- @validator ("route_type" , pre = True )
389
+ @field_validator ("route_type" , mode = "before" )
398
390
def validate_route_type (cls , value ):
399
391
if value in RouteType ._value2member_map_ :
400
392
return value
401
393
raise MlflowException .invalid_parameter_value (f"The route_type '{ value } ' is not supported." )
402
394
403
- @validator ("limit" , pre = True )
395
+ @field_validator ("limit" , mode = "before" )
404
396
def validate_limit (cls , value ):
405
397
from limits import parse
406
398
0 commit comments