Skip to content

Commit 73d77ac

Browse files
committed
feat: fix inheritance test
1 parent 1e85d3b commit 73d77ac

File tree

4 files changed

+21
-24
lines changed

4 files changed

+21
-24
lines changed

aidial_sdk/chat_completion/form.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -169,8 +169,8 @@ def _create_class(cls: Type[_Model]) -> Type[_Model]:
169169
}
170170

171171
# Inject model config extensions
172+
model_config = ModelConfigWrapper.create(cls, namespace)
172173
if chat_message_input_disabled is not None:
173-
model_config = ModelConfigWrapper.create(cls, namespace)
174174
model_config["chat_message_input_disabled"] = (
175175
chat_message_input_disabled
176176
)

aidial_sdk/utils/_pydantic.py

+13-8
Original file line numberDiff line numberDiff line change
@@ -90,12 +90,10 @@ def create(
9090
cls, base_cls: Optional[Type[_Model]], namespace: Dict[str, Any]
9191
) -> ModelConfigBase:
9292
if (config_cls := namespace.get("Config")) is None:
93-
if base_cls:
94-
conf_base_cls = getattr(cls, "Config", None)
95-
else:
96-
conf_base_cls = None
93+
conf_base_cls = (
94+
None if base_cls is None else getattr(base_cls, "Config", None)
95+
)
9796

98-
# FIXME: add tests to confirm that the inheritance works
9997
config_cls = type("Config", (conf_base_cls or object,), {})
10098

10199
if module := namespace.get("__module__"):
@@ -128,8 +126,15 @@ def schema_extra_field(self) -> str:
128126
def create(
129127
cls, base_cls: Optional[Type[_Model]], namespace: Dict[str, Any]
130128
) -> ModelConfigBase:
131-
# FIXME: merge with the existing "base_cls.model_config"
132-
model_config = namespace["model_config"] = (
133-
namespace.get("model_config") or {}
129+
base_model_config = (
130+
{} if base_cls is None else getattr(base_cls, "model_config", {})
134131
)
132+
133+
curr_model_config = namespace.get("model_config") or {}
134+
135+
model_config = namespace["model_config"] = {
136+
**base_model_config,
137+
**curr_model_config,
138+
}
139+
135140
return cls(model_config)

examples/tic_tac_toe/app/main.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,12 @@ class InitConfiguration(BaseModel, metaclass=FormMetaclass):
6767
# The form defines the move the user in making during the tic-tac-toe game.
6868
# The bot suggest a list of available moves to the user.
6969
# The user pick one of the move by its index in the list and returns this data structure to the application.
70-
# Note that the form doesn't have any buttons, since the moves are determined dynamically. Nor does it have a DIAL specific configuration, since it inherits
71-
# from vanilla Pydantic BaseModel that it's aware of any DIAL specific features.
72-
# The buttons and configuration will be added to the model dynamically
73-
# in the `chat_completion`` handler via the class decorator "form".
70+
# Note that the form doesn't have any buttons, since the moves are determined dynamically.
71+
# The buttons will be added to the model dynamically
72+
# in the `chat_completion` handler via the class decorator "form".
7473
class MoveForm(BaseModel):
74+
model_config = ConfigDict(chat_message_input_disabled=True)
75+
7576
move: int
7677

7778

@@ -142,9 +143,7 @@ async def chat_completion(
142143
)
143144

144145
# Use the form decorator to add buttons to the form.
145-
_MoveForm = form(
146-
chat_message_input_disabled=True, move=move_selector
147-
)(MoveForm)
146+
_MoveForm = form(move=move_selector)(MoveForm)
148147

149148
# Save the form schema in the bot message
150149
choice.set_form_schema(_MoveForm.model_json_schema())

tests/test_form.py

+1-8
Original file line numberDiff line numberDiff line change
@@ -239,11 +239,6 @@ def test_configuration_parsing_two_buttons_success():
239239

240240

241241
def test_dynamic_configuration_input_disabled_static():
242-
"""
243-
This test checks that DIAL specific features
244-
aren't inherited from the vanilla BaseModel class.
245-
"""
246-
247242
class Conf(BaseModel):
248243
if PYDANTIC_V2:
249244
model_config = ConfigDict(chat_message_input_disabled=True)
@@ -254,13 +249,12 @@ class Config:
254249

255250
conf = form()(Conf)
256251

257-
assert model_json_schema(conf).get("dial:chatMessageInputDisabled") is None
252+
assert model_json_schema(conf).get("dial:chatMessageInputDisabled") is True
258253

259254

260255
def test_dynamic_configuration_input_disabled_dynamic():
261256
class Conf(BaseModel):
262257
if PYDANTIC_V2:
263-
# FIXME: fix type ignore
264258
model_config = ConfigDict(extra="forbid") # type: ignore
265259
else:
266260

@@ -275,7 +269,6 @@ class Config:
275269
def test_dynamic_configuration_input_disabled_omitted():
276270
class Conf(BaseModel):
277271
if PYDANTIC_V2:
278-
# FIXME: fix type ignore
279272
model_config = ConfigDict(extra="forbid") # type: ignore
280273
else:
281274

0 commit comments

Comments
 (0)