Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from google.adk.tools.tool_context import ToolContext

from ap2.types.mandate import CartMandate
from ap2.types.payment_request import PAYMENT_METHOD_DATA_DATA_KEY
from common.a2a_message_builder import A2aMessageBuilder
from common import artifact_utils
Expand All @@ -41,7 +42,7 @@ async def get_payment_methods(
Returns:
A dictionary of the user's applicable payment methods.
"""
cart_mandate = tool_context.state["cart_mandate"]
cart_mandate = CartMandate(**tool_context.state["cart_mandate"])
message_builder = (
A2aMessageBuilder()
.set_context_id(tool_context.state["shopping_context_id"])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def create_intent_mandate(
datetime.now(timezone.utc) + timedelta(days=1)
).isoformat(),
)
tool_context.state["intent_mandate"] = intent_mandate
tool_context.state["intent_mandate"] = intent_mandate.model_dump()
return intent_mandate


Expand Down Expand Up @@ -93,7 +93,7 @@ async def find_products(
message = (
A2aMessageBuilder()
.add_text("Find products that match the user's IntentMandate.")
.add_data(INTENT_MANDATE_DATA_KEY, intent_mandate.model_dump())
.add_data(INTENT_MANDATE_DATA_KEY, intent_mandate)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

low

This change is correct since intent_mandate is now a dictionary. For better clarity and consistency with other parts of this PR (e.g., update_chosen_cart_mandate where cart_mandate_dicts is used), consider renaming the intent_mandate variable to intent_mandate_dict on line 87. This would make it explicit that it's a dictionary and not a Pydantic model instance, improving code readability.

.add_data("risk_data", risk_data)
.add_data("debug_mode", debug_mode)
.add_data("shopping_agent_id", "trusted_shopping_agent")
Expand All @@ -106,7 +106,9 @@ async def find_products(

tool_context.state["shopping_context_id"] = task.context_id
cart_mandates = _parse_cart_mandates(task.artifacts)
tool_context.state["cart_mandates"] = cart_mandates
tool_context.state["cart_mandates"] = [
cm.model_dump() for cm in cart_mandates
]
return cart_mandates


Expand All @@ -117,7 +119,8 @@ def update_chosen_cart_mandate(cart_id: str, tool_context: ToolContext) -> str:
cart_id: The ID of the chosen cart.
tool_context: The ADK supplied tool context.
"""
cart_mandates: list[CartMandate] = tool_context.state.get("cart_mandates", [])
cart_mandate_dicts = tool_context.state.get("cart_mandates", [])
cart_mandates = [CartMandate(**cm) for cm in cart_mandate_dicts]
for cart in cart_mandates:
print(
f"Checking cart with ID: {cart.contents.id} with chosen ID: {cart_id}"
Expand Down
18 changes: 9 additions & 9 deletions samples/python/src/roles/shopping_agent/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ async def update_cart(
_parse_cart_mandates(task.artifacts)
)

tool_context.state["cart_mandate"] = updated_cart_mandate
tool_context.state["shipping_address"] = shipping_address
tool_context.state["cart_mandate"] = updated_cart_mandate.model_dump()
tool_context.state["shipping_address"] = shipping_address.model_dump()

return updated_cart_mandate

Expand Down Expand Up @@ -163,7 +163,7 @@ def store_receipt_if_present(task, tool_context: ToolContext) -> None:
)
if payment_receipts:
payment_receipt = artifact_utils.only(payment_receipts)
tool_context.state["payment_receipt"] = payment_receipt
tool_context.state["payment_receipt"] = payment_receipt.model_dump()


def create_payment_mandate(
Expand All @@ -181,10 +181,10 @@ def create_payment_mandate(
Returns:
The payment mandate.
"""
cart_mandate = tool_context.state["cart_mandate"]
cart_mandate = CartMandate(**tool_context.state["cart_mandate"])

payment_request = cart_mandate.contents.payment_request
shipping_address = tool_context.state["shipping_address"]
shipping_address = ContactAddress(**tool_context.state["shipping_address"])

payment_method = os.environ.get("PAYMENT_METHOD", "CARD")
if payment_method == "x402":
Expand Down Expand Up @@ -215,7 +215,7 @@ def create_payment_mandate(
),
)

tool_context.state["payment_mandate"] = payment_mandate
tool_context.state["payment_mandate"] = payment_mandate.model_dump()
return payment_mandate


Expand All @@ -238,8 +238,8 @@ def sign_mandates_on_user_device(tool_context: ToolContext) -> str:
Returns:
A string representing the simulated user authorization signature (JWT).
"""
payment_mandate: PaymentMandate = tool_context.state["payment_mandate"]
cart_mandate: CartMandate = tool_context.state["cart_mandate"]
payment_mandate = PaymentMandate(**tool_context.state["payment_mandate"])
cart_mandate = CartMandate(**tool_context.state["cart_mandate"])
cart_mandate_hash = _generate_cart_mandate_hash(cart_mandate)
payment_mandate_hash = _generate_payment_mandate_hash(
payment_mandate.payment_mandate_contents
Expand All @@ -250,7 +250,7 @@ def sign_mandates_on_user_device(tool_context: ToolContext) -> str:
payment_mandate.user_authorization = (
cart_mandate_hash + "_" + payment_mandate_hash
)
tool_context.state["signed_payment_mandate"] = payment_mandate
tool_context.state["signed_payment_mandate"] = payment_mandate.model_dump()
return payment_mandate.user_authorization


Expand Down
Loading