Skip to content

Commit d158be4

Browse files
fix(JsonSchemaValidator): fix recursive loop and general LLM (claude, mistral...) compatibility (#7556)
* Feat: Fix recursive conversion in JsonSchemaValidator (autofix generated by ClaudeOpus). Modify the behaviour to build the error template in a single user_message instead of two separate. Modify the behaviour to only include latest message instead of full history (very costly if long looping pipeline) * Feat: Fix recursive conversion in JsonSchemaValidator (autofix generated by ClaudeOpus). Modify the behaviour to build the error template in a single user_message instead of two separate. Modify the behaviour to only include latest message instead of full history (very costly if long looping pipeline) * reno * fix test * Verify provided message contains JSON object to begin with * Minor detail --------- Co-authored-by: Vladimir Blagojevic <[email protected]>
1 parent 75ad76a commit d158be4

File tree

3 files changed

+62
-23
lines changed

3 files changed

+62
-23
lines changed

haystack/components/validators/json_schema.py

Lines changed: 44 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,20 @@
1313
from jsonschema import ValidationError, validate
1414

1515

16+
def is_valid_json(s: str) -> bool:
17+
"""
18+
Check if the provided string is a valid JSON.
19+
20+
:param s: The string to be checked.
21+
:returns: `True` if the string is a valid JSON; otherwise, `False`.
22+
"""
23+
try:
24+
json.loads(s)
25+
except ValueError:
26+
return False
27+
return True
28+
29+
1630
@component
1731
class JsonSchemaValidator:
1832
"""
@@ -77,13 +91,15 @@ def run(self, messages: List[ChatMessage]) -> dict:
7791

7892
# Default error description template
7993
default_error_template = (
80-
"The JSON content in the next message does not conform to the provided schema.\n"
94+
"The following generated JSON does not conform to the provided schema.\n"
95+
"Generated JSON: {failing_json}\n"
8196
"Error details:\n- Message: {error_message}\n"
8297
"- Error Path in JSON: {error_path}\n"
8398
"- Schema Path: {error_schema_path}\n"
8499
"Please match the following schema:\n"
85100
"{json_schema}\n"
86-
"and provide the corrected JSON content ONLY."
101+
"and provide the corrected JSON content ONLY. Please do not output anything else than the raw corrected "
102+
"JSON string, this is the most important part of the task. Don't use any markdown and don't add any comment."
87103
)
88104

89105
def __init__(self, json_schema: Optional[Dict[str, Any]] = None, error_template: Optional[str] = None):
@@ -125,14 +141,23 @@ def run(
125141
dictionaries.
126142
"""
127143
last_message = messages[-1]
128-
last_message_content = json.loads(last_message.content)
144+
if not is_valid_json(last_message.content):
145+
return {
146+
"validation_error": [
147+
ChatMessage.from_user(
148+
f"The message '{last_message.content}' is not a valid JSON object. "
149+
f"Please provide only a valid JSON object in string format."
150+
f"Don't use any markdown and don't add any comment."
151+
)
152+
]
153+
}
129154

155+
last_message_content = json.loads(last_message.content)
130156
json_schema = json_schema or self.json_schema
131157
error_template = error_template or self.error_template or self.default_error_template
132158

133159
if not json_schema:
134160
raise ValueError("Provide a JSON schema for validation either in the run method or in the component init.")
135-
136161
# fc payload is json object but subtree `parameters` is string - we need to convert to json object
137162
# we need complete json to validate it against schema
138163
last_message_json = self._recursive_json_to_object(last_message_content)
@@ -149,18 +174,22 @@ def run(
149174
else:
150175
validate(instance=content, schema=validation_schema)
151176

152-
return {"validated": messages}
177+
return {"validated": [last_message]}
153178
except ValidationError as e:
154179
error_path = " -> ".join(map(str, e.absolute_path)) if e.absolute_path else "N/A"
155180
error_schema_path = " -> ".join(map(str, e.absolute_schema_path)) if e.absolute_schema_path else "N/A"
156181

157182
error_template = error_template or self.default_error_template
158183

159184
recovery_prompt = self._construct_error_recovery_message(
160-
error_template, str(e), error_path, error_schema_path, validation_schema
185+
error_template,
186+
str(e),
187+
error_path,
188+
error_schema_path,
189+
validation_schema,
190+
failing_json=last_message.content,
161191
)
162-
complete_message_list = [ChatMessage.from_user(recovery_prompt)] + messages
163-
return {"validation_error": complete_message_list}
192+
return {"validation_error": [ChatMessage.from_user(recovery_prompt)]}
164193

165194
def _construct_error_recovery_message(
166195
self,
@@ -169,6 +198,7 @@ def _construct_error_recovery_message(
169198
error_path: str,
170199
error_schema_path: str,
171200
json_schema: Dict[str, Any],
201+
failing_json: str,
172202
) -> str:
173203
"""
174204
Constructs an error recovery message using a specified template or the default one if none is provided.
@@ -178,6 +208,7 @@ def _construct_error_recovery_message(
178208
:param error_path: The path in the JSON content where the error occurred.
179209
:param error_schema_path: The path in the JSON schema where the error occurred.
180210
:param json_schema: The JSON schema against which the content is validated.
211+
:param failing_json: The generated invalid JSON string.
181212
"""
182213
error_template = error_template or self.default_error_template
183214

@@ -186,6 +217,7 @@ def _construct_error_recovery_message(
186217
error_path=error_path,
187218
error_schema_path=error_schema_path,
188219
json_schema=json_schema,
220+
failing_json=failing_json,
189221
)
190222

191223
def _is_openai_function_calling_schema(self, json_schema: Dict[str, Any]) -> bool:
@@ -215,11 +247,10 @@ def _recursive_json_to_object(self, data: Any) -> Any:
215247
if isinstance(value, str):
216248
try:
217249
json_value = json.loads(value)
218-
new_dict[key] = (
219-
self._recursive_json_to_object(json_value)
220-
if isinstance(json_value, (dict, list))
221-
else json_value
222-
)
250+
if isinstance(json_value, (dict, list)):
251+
new_dict[key] = self._recursive_json_to_object(json_value)
252+
else:
253+
new_dict[key] = value # Preserve the original string value
223254
except json.JSONDecodeError:
224255
new_dict[key] = value
225256
elif isinstance(value, dict):
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
---
2+
enhancements:
3+
- |
4+
Made JSON schema validator compatible with all LLM by switching error template handling to a single user message.
5+
Also reduce cost by only including last error instead of full message history.
6+
fixes:
7+
- |
8+
Fix recursive JSON type conversion in the schema validator to be less aggressive (no infinite recursion).

test/components/validators/test_json_schema.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,10 @@
44
import json
55
from typing import List
66

7-
from haystack import component, Pipeline
8-
from haystack.components.validators import JsonSchemaValidator
9-
107
import pytest
118

9+
from haystack import Pipeline, component
10+
from haystack.components.validators import JsonSchemaValidator
1211
from haystack.dataclasses import ChatMessage
1312

1413

@@ -110,10 +109,9 @@ def test_validates_multiple_messages_against_json_schema(self, json_schema_githu
110109
]
111110

112111
result = validator.run(messages, json_schema_github_compare)
113-
114112
assert "validated" in result
115-
assert len(result["validated"]) == 2
116-
assert result["validated"] == messages
113+
assert len(result["validated"]) == 1
114+
assert result["validated"][0] == messages[1]
117115

118116
# Validates a message against an OpenAI function calling schema successfully.
119117
def test_validates_message_against_openai_function_calling_schema(
@@ -142,8 +140,8 @@ def test_validates_multiple_messages_against_openai_function_calling_schema(
142140
result = validator.run(messages, json_schema_github_compare_openai)
143141

144142
assert "validated" in result
145-
assert len(result["validated"]) == 2
146-
assert result["validated"] == messages
143+
assert len(result["validated"]) == 1
144+
assert result["validated"][0] == messages[1]
147145

148146
# Constructs a custom error recovery message when validation fails.
149147
def test_construct_custom_error_recovery_message(self):
@@ -155,10 +153,11 @@ def test_construct_custom_error_recovery_message(self):
155153
"- Schema Path: {error_schema_path}\n"
156154
"Please match the following schema:\n"
157155
"{json_schema}\n"
156+
"Failing Json: {failing_json}\n"
158157
)
159158

160159
recovery_message = validator._construct_error_recovery_message(
161-
new_error_template, "Error message", "Error path", "Error schema path", {"type": "object"}
160+
new_error_template, "Error message", "Error path", "Error schema path", {"type": "object"}, "Failing Json"
162161
)
163162

164163
expected_recovery_message = (
@@ -167,6 +166,7 @@ def test_construct_custom_error_recovery_message(self):
167166
"- Schema Path: Error schema path\n"
168167
"Please match the following schema:\n"
169168
"{'type': 'object'}\n"
169+
"Failing Json: Failing Json\n"
170170
)
171171
assert recovery_message == expected_recovery_message
172172

@@ -201,5 +201,5 @@ def run(self):
201201
pipe.connect("message_producer", "schema_validator")
202202
result = pipe.run(data={"schema_validator": {"json_schema": json_schema_github_compare}})
203203
assert "validation_error" in result["schema_validator"]
204-
assert len(result["schema_validator"]["validation_error"]) > 1
204+
assert len(result["schema_validator"]["validation_error"]) == 1
205205
assert "Error details" in result["schema_validator"]["validation_error"][0].content

0 commit comments

Comments
 (0)