Skip to content

Commit

Permalink
fix(JsonSchemaValidator): fix recursive loop and general LLM (claude,…
Browse files Browse the repository at this point in the history
… mistral...) compatibility (#7556)

* Feat: Fix recursive conversion in JsonSchemaValidator (autofix generated by ClaudeOpus). Modify the behaviour to build the error template in a single user_message instead of two separate. Modify the behaviour to only include latest message instead of full history (very costly if long looping pipeline)

* Feat: Fix recursive conversion in JsonSchemaValidator (autofix generated by ClaudeOpus). Modify the behaviour to build the error template in a single user_message instead of two separate. Modify the behaviour to only include latest message instead of full history (very costly if long looping pipeline)

* reno

* fix test

* Verify provided message contains JSON object to begin with

* Minor detail

---------

Co-authored-by: Vladimir Blagojevic <[email protected]>
  • Loading branch information
lambda-science and vblagoje authored Jun 21, 2024
1 parent 75ad76a commit d158be4
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 23 deletions.
57 changes: 44 additions & 13 deletions haystack/components/validators/json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,20 @@
from jsonschema import ValidationError, validate


def is_valid_json(s: str) -> bool:
"""
Check if the provided string is a valid JSON.
:param s: The string to be checked.
:returns: `True` if the string is a valid JSON; otherwise, `False`.
"""
try:
json.loads(s)
except ValueError:
return False
return True


@component
class JsonSchemaValidator:
"""
Expand Down Expand Up @@ -77,13 +91,15 @@ def run(self, messages: List[ChatMessage]) -> dict:

# Default error description template
default_error_template = (
"The JSON content in the next message does not conform to the provided schema.\n"
"The following generated JSON does not conform to the provided schema.\n"
"Generated JSON: {failing_json}\n"
"Error details:\n- Message: {error_message}\n"
"- Error Path in JSON: {error_path}\n"
"- Schema Path: {error_schema_path}\n"
"Please match the following schema:\n"
"{json_schema}\n"
"and provide the corrected JSON content ONLY."
"and provide the corrected JSON content ONLY. Please do not output anything else than the raw corrected "
"JSON string, this is the most important part of the task. Don't use any markdown and don't add any comment."
)

def __init__(self, json_schema: Optional[Dict[str, Any]] = None, error_template: Optional[str] = None):
Expand Down Expand Up @@ -125,14 +141,23 @@ def run(
dictionaries.
"""
last_message = messages[-1]
last_message_content = json.loads(last_message.content)
if not is_valid_json(last_message.content):
return {
"validation_error": [
ChatMessage.from_user(
f"The message '{last_message.content}' is not a valid JSON object. "
f"Please provide only a valid JSON object in string format."
f"Don't use any markdown and don't add any comment."
)
]
}

last_message_content = json.loads(last_message.content)
json_schema = json_schema or self.json_schema
error_template = error_template or self.error_template or self.default_error_template

if not json_schema:
raise ValueError("Provide a JSON schema for validation either in the run method or in the component init.")

# fc payload is json object but subtree `parameters` is string - we need to convert to json object
# we need complete json to validate it against schema
last_message_json = self._recursive_json_to_object(last_message_content)
Expand All @@ -149,18 +174,22 @@ def run(
else:
validate(instance=content, schema=validation_schema)

return {"validated": messages}
return {"validated": [last_message]}
except ValidationError as e:
error_path = " -> ".join(map(str, e.absolute_path)) if e.absolute_path else "N/A"
error_schema_path = " -> ".join(map(str, e.absolute_schema_path)) if e.absolute_schema_path else "N/A"

error_template = error_template or self.default_error_template

recovery_prompt = self._construct_error_recovery_message(
error_template, str(e), error_path, error_schema_path, validation_schema
error_template,
str(e),
error_path,
error_schema_path,
validation_schema,
failing_json=last_message.content,
)
complete_message_list = [ChatMessage.from_user(recovery_prompt)] + messages
return {"validation_error": complete_message_list}
return {"validation_error": [ChatMessage.from_user(recovery_prompt)]}

def _construct_error_recovery_message(
self,
Expand All @@ -169,6 +198,7 @@ def _construct_error_recovery_message(
error_path: str,
error_schema_path: str,
json_schema: Dict[str, Any],
failing_json: str,
) -> str:
"""
Constructs an error recovery message using a specified template or the default one if none is provided.
Expand All @@ -178,6 +208,7 @@ def _construct_error_recovery_message(
:param error_path: The path in the JSON content where the error occurred.
:param error_schema_path: The path in the JSON schema where the error occurred.
:param json_schema: The JSON schema against which the content is validated.
:param failing_json: The generated invalid JSON string.
"""
error_template = error_template or self.default_error_template

Expand All @@ -186,6 +217,7 @@ def _construct_error_recovery_message(
error_path=error_path,
error_schema_path=error_schema_path,
json_schema=json_schema,
failing_json=failing_json,
)

def _is_openai_function_calling_schema(self, json_schema: Dict[str, Any]) -> bool:
Expand Down Expand Up @@ -215,11 +247,10 @@ def _recursive_json_to_object(self, data: Any) -> Any:
if isinstance(value, str):
try:
json_value = json.loads(value)
new_dict[key] = (
self._recursive_json_to_object(json_value)
if isinstance(json_value, (dict, list))
else json_value
)
if isinstance(json_value, (dict, list)):
new_dict[key] = self._recursive_json_to_object(json_value)
else:
new_dict[key] = value # Preserve the original string value
except json.JSONDecodeError:
new_dict[key] = value
elif isinstance(value, dict):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
---
enhancements:
- |
Made JSON schema validator compatible with all LLM by switching error template handling to a single user message.
Also reduce cost by only including last error instead of full message history.
fixes:
- |
Fix recursive JSON type conversion in the schema validator to be less aggressive (no infinite recursion).
20 changes: 10 additions & 10 deletions test/components/validators/test_json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@
import json
from typing import List

from haystack import component, Pipeline
from haystack.components.validators import JsonSchemaValidator

import pytest

from haystack import Pipeline, component
from haystack.components.validators import JsonSchemaValidator
from haystack.dataclasses import ChatMessage


Expand Down Expand Up @@ -110,10 +109,9 @@ def test_validates_multiple_messages_against_json_schema(self, json_schema_githu
]

result = validator.run(messages, json_schema_github_compare)

assert "validated" in result
assert len(result["validated"]) == 2
assert result["validated"] == messages
assert len(result["validated"]) == 1
assert result["validated"][0] == messages[1]

# Validates a message against an OpenAI function calling schema successfully.
def test_validates_message_against_openai_function_calling_schema(
Expand Down Expand Up @@ -142,8 +140,8 @@ def test_validates_multiple_messages_against_openai_function_calling_schema(
result = validator.run(messages, json_schema_github_compare_openai)

assert "validated" in result
assert len(result["validated"]) == 2
assert result["validated"] == messages
assert len(result["validated"]) == 1
assert result["validated"][0] == messages[1]

# Constructs a custom error recovery message when validation fails.
def test_construct_custom_error_recovery_message(self):
Expand All @@ -155,10 +153,11 @@ def test_construct_custom_error_recovery_message(self):
"- Schema Path: {error_schema_path}\n"
"Please match the following schema:\n"
"{json_schema}\n"
"Failing Json: {failing_json}\n"
)

recovery_message = validator._construct_error_recovery_message(
new_error_template, "Error message", "Error path", "Error schema path", {"type": "object"}
new_error_template, "Error message", "Error path", "Error schema path", {"type": "object"}, "Failing Json"
)

expected_recovery_message = (
Expand All @@ -167,6 +166,7 @@ def test_construct_custom_error_recovery_message(self):
"- Schema Path: Error schema path\n"
"Please match the following schema:\n"
"{'type': 'object'}\n"
"Failing Json: Failing Json\n"
)
assert recovery_message == expected_recovery_message

Expand Down Expand Up @@ -201,5 +201,5 @@ def run(self):
pipe.connect("message_producer", "schema_validator")
result = pipe.run(data={"schema_validator": {"json_schema": json_schema_github_compare}})
assert "validation_error" in result["schema_validator"]
assert len(result["schema_validator"]["validation_error"]) > 1
assert len(result["schema_validator"]["validation_error"]) == 1
assert "Error details" in result["schema_validator"]["validation_error"][0].content

0 comments on commit d158be4

Please sign in to comment.