Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from google.adk.tools.tool_context import ToolContext

from ap2.types.mandate import CartMandate
from ap2.types.payment_request import PAYMENT_METHOD_DATA_DATA_KEY
from common.a2a_message_builder import A2aMessageBuilder
from common import artifact_utils
Expand All @@ -41,7 +42,7 @@ async def get_payment_methods(
Returns:
A dictionary of the user's applicable payment methods.
"""
cart_mandate = tool_context.state["cart_mandate"]
cart_mandate = CartMandate.model_validate(tool_context.state["cart_mandate"])
message_builder = (
A2aMessageBuilder()
.set_context_id(tool_context.state["shopping_context_id"])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def create_intent_mandate(
datetime.now(timezone.utc) + timedelta(days=1)
).isoformat(),
)
tool_context.state["intent_mandate"] = intent_mandate
tool_context.state["intent_mandate"] = intent_mandate.model_dump()
return intent_mandate


Expand All @@ -84,9 +84,10 @@ async def find_products(
Raises:
RuntimeError: If the merchant agent fails to provide products.
"""
intent_mandate = tool_context.state["intent_mandate"]
if not intent_mandate:
intent_mandate_data = tool_context.state["intent_mandate"]
if not intent_mandate_data:
raise RuntimeError("No IntentMandate found in tool context state.")
intent_mandate = IntentMandate.model_validate(intent_mandate_data)
risk_data = _collect_risk_data(tool_context)
if not risk_data:
raise RuntimeError("No risk data found in tool context state.")
Expand All @@ -106,7 +107,7 @@ async def find_products(

tool_context.state["shopping_context_id"] = task.context_id
cart_mandates = _parse_cart_mandates(task.artifacts)
tool_context.state["cart_mandates"] = cart_mandates
tool_context.state["cart_mandates"] = [cm.model_dump() for cm in cart_mandates]
return cart_mandates


Expand All @@ -117,7 +118,10 @@ def update_chosen_cart_mandate(cart_id: str, tool_context: ToolContext) -> str:
cart_id: The ID of the chosen cart.
tool_context: The ADK supplied tool context.
"""
cart_mandates: list[CartMandate] = tool_context.state.get("cart_mandates", [])
cart_mandates = [
CartMandate.model_validate(cm)
for cm in tool_context.state.get("cart_mandates", [])
]
for cart in cart_mandates:
print(
f"Checking cart with ID: {cart.contents.id} with chosen ID: {cart_id}"
Expand Down
24 changes: 15 additions & 9 deletions samples/python/src/roles/shopping_agent/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ async def update_cart(
_parse_cart_mandates(task.artifacts)
)

tool_context.state["cart_mandate"] = updated_cart_mandate
tool_context.state["shipping_address"] = shipping_address
tool_context.state["cart_mandate"] = updated_cart_mandate.model_dump()
tool_context.state["shipping_address"] = shipping_address.model_dump()

return updated_cart_mandate

Expand Down Expand Up @@ -163,7 +163,7 @@ def store_receipt_if_present(task, tool_context: ToolContext) -> None:
)
if payment_receipts:
payment_receipt = artifact_utils.only(payment_receipts)
tool_context.state["payment_receipt"] = payment_receipt
tool_context.state["payment_receipt"] = payment_receipt.model_dump()


def create_payment_mandate(
Expand All @@ -181,10 +181,12 @@ def create_payment_mandate(
Returns:
The payment mandate.
"""
cart_mandate = tool_context.state["cart_mandate"]
cart_mandate = CartMandate.model_validate(tool_context.state["cart_mandate"])

payment_request = cart_mandate.contents.payment_request
shipping_address = tool_context.state["shipping_address"]
shipping_address = ContactAddress.model_validate(
tool_context.state["shipping_address"]
)

payment_method = os.environ.get("PAYMENT_METHOD", "CARD")
if payment_method == "x402":
Expand Down Expand Up @@ -215,7 +217,7 @@ def create_payment_mandate(
),
)

tool_context.state["payment_mandate"] = payment_mandate
tool_context.state["payment_mandate"] = payment_mandate.model_dump()
return payment_mandate


Expand All @@ -238,8 +240,12 @@ def sign_mandates_on_user_device(tool_context: ToolContext) -> str:
Returns:
A string representing the simulated user authorization signature (JWT).
"""
payment_mandate: PaymentMandate = tool_context.state["payment_mandate"]
cart_mandate: CartMandate = tool_context.state["cart_mandate"]
payment_mandate = PaymentMandate.model_validate(
tool_context.state["payment_mandate"]
)
cart_mandate = CartMandate.model_validate(
tool_context.state["cart_mandate"]
)
cart_mandate_hash = _generate_cart_mandate_hash(cart_mandate)
payment_mandate_hash = _generate_payment_mandate_hash(
payment_mandate.payment_mandate_contents
Expand All @@ -250,7 +256,7 @@ def sign_mandates_on_user_device(tool_context: ToolContext) -> str:
payment_mandate.user_authorization = (
cart_mandate_hash + "_" + payment_mandate_hash
)
tool_context.state["signed_payment_mandate"] = payment_mandate
tool_context.state["signed_payment_mandate"] = payment_mandate.model_dump()
return payment_mandate.user_authorization


Expand Down