Skip to content

Commit 6295e9b

Browse files
committed
feat: fix inheritance test
1 parent 1e85d3b commit 6295e9b

File tree

3 files changed

+15
-17
lines changed

3 files changed

+15
-17
lines changed

aidial_sdk/chat_completion/form.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -169,8 +169,8 @@ def _create_class(cls: Type[_Model]) -> Type[_Model]:
169169
}
170170

171171
# Inject model config extensions
172+
model_config = ModelConfigWrapper.create(cls, namespace)
172173
if chat_message_input_disabled is not None:
173-
model_config = ModelConfigWrapper.create(cls, namespace)
174174
model_config["chat_message_input_disabled"] = (
175175
chat_message_input_disabled
176176
)

aidial_sdk/utils/_pydantic.py

+13-8
Original file line numberDiff line numberDiff line change
@@ -90,12 +90,10 @@ def create(
9090
cls, base_cls: Optional[Type[_Model]], namespace: Dict[str, Any]
9191
) -> ModelConfigBase:
9292
if (config_cls := namespace.get("Config")) is None:
93-
if base_cls:
94-
conf_base_cls = getattr(cls, "Config", None)
95-
else:
96-
conf_base_cls = None
93+
conf_base_cls = (
94+
None if base_cls is None else getattr(base_cls, "Config", None)
95+
)
9796

98-
# FIXME: add tests to confirm that the inheritance works
9997
config_cls = type("Config", (conf_base_cls or object,), {})
10098

10199
if module := namespace.get("__module__"):
@@ -128,8 +126,15 @@ def schema_extra_field(self) -> str:
128126
def create(
129127
cls, base_cls: Optional[Type[_Model]], namespace: Dict[str, Any]
130128
) -> ModelConfigBase:
131-
# FIXME: merge with the existing "base_cls.model_config"
132-
model_config = namespace["model_config"] = (
133-
namespace.get("model_config") or {}
129+
base_model_config = (
130+
{} if base_cls is None else getattr(base_cls, "model_config", {})
134131
)
132+
133+
curr_model_config = namespace.get("model_config") or {}
134+
135+
model_config = namespace["model_config"] = {
136+
**base_model_config,
137+
**curr_model_config,
138+
}
139+
135140
return cls(model_config)

tests/test_form.py

+1-8
Original file line numberDiff line numberDiff line change
@@ -239,11 +239,6 @@ def test_configuration_parsing_two_buttons_success():
239239

240240

241241
def test_dynamic_configuration_input_disabled_static():
242-
"""
243-
This test checks that DIAL specific features
244-
aren't inherited from the vanilla BaseModel class.
245-
"""
246-
247242
class Conf(BaseModel):
248243
if PYDANTIC_V2:
249244
model_config = ConfigDict(chat_message_input_disabled=True)
@@ -254,13 +249,12 @@ class Config:
254249

255250
conf = form()(Conf)
256251

257-
assert model_json_schema(conf).get("dial:chatMessageInputDisabled") is None
252+
assert model_json_schema(conf).get("dial:chatMessageInputDisabled") is True
258253

259254

260255
def test_dynamic_configuration_input_disabled_dynamic():
261256
class Conf(BaseModel):
262257
if PYDANTIC_V2:
263-
# FIXME: fix type ignore
264258
model_config = ConfigDict(extra="forbid") # type: ignore
265259
else:
266260

@@ -275,7 +269,6 @@ class Config:
275269
def test_dynamic_configuration_input_disabled_omitted():
276270
class Conf(BaseModel):
277271
if PYDANTIC_V2:
278-
# FIXME: fix type ignore
279272
model_config = ConfigDict(extra="forbid") # type: ignore
280273
else:
281274

0 commit comments

Comments
 (0)