diff --git a/nemoguardrails/colang/v2_x/runtime/serialization.py b/nemoguardrails/colang/v2_x/runtime/serialization.py index 095bfbe0b..cf5bdf660 100644 --- a/nemoguardrails/colang/v2_x/runtime/serialization.py +++ b/nemoguardrails/colang/v2_x/runtime/serialization.py @@ -86,13 +86,21 @@ def encode_to_dict(obj: Any, refs: Dict[int, Any]): "value": {k: encode_to_dict(v, refs) for k, v in obj.items()}, } elif is_dataclass(obj): - value = { - "__type": type(obj).__name__, - "value": { - k: encode_to_dict(getattr(obj, k), refs) - for k in obj.__dataclass_fields__.keys() - }, + # Encode dataclasses. If it's a known class (present in name_to_class), + # keep its type tag so we can fully round-trip via json_to_state. + # Otherwise, fall back to a plain dict to avoid "Unknown d_type" on decode. + cls = type(obj) + encoded_fields = { + k: encode_to_dict(getattr(obj, k), refs) + for k in obj.__dataclass_fields__.keys() } + + if cls.__name__ in name_to_class and name_to_class[cls.__name__] is cls: + value = {"__type": cls.__name__, "value": encoded_fields} + else: + # Unknown dataclass → JSON-friendly dict + value = {"__type": "dict", "value": encoded_fields} + elif isinstance(obj, RailsConfig): value = { "__type": "RailsConfig", diff --git a/tests/v2_x/test_serialization_dataclass.py b/tests/v2_x/test_serialization_dataclass.py new file mode 100644 index 000000000..25ce80213 --- /dev/null +++ b/tests/v2_x/test_serialization_dataclass.py @@ -0,0 +1,35 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from dataclasses import dataclass + +from nemoguardrails.colang.v2_x.runtime.serialization import state_to_json + + +@dataclass +class Foo: + bar: str + baz: int + + +def test_state_to_json_unknown_dataclass_encodes_as_dict(): + js = state_to_json({"out": Foo("ok", 1)}) + d = json.loads(js) + # We expect unknown dataclasses to be encoded as plain dicts + assert d["__type"] == "dict" + out = d["value"]["out"] + assert out["__type"] == "dict" + assert out["value"] == {"bar": "ok", "baz": 1}