From fca4bc620f2924528b3a3d2bd5a9b64ca6c562c8 Mon Sep 17 00:00:00 2001 From: originaljayeshsharma Date: Tue, 12 May 2026 00:21:39 +0530 Subject: [PATCH] fix: idempotency guard (#402) and pagination support (#388) --- tools/python/stripe_agent_toolkit/api.py | 411 ++++++++++++++++++ .../stripe_agent_toolkit/idempotency.py | 100 +++++ tools/python/tests/test_fixes.py | 276 ++++++++++++ 3 files changed, 787 insertions(+) create mode 100644 tools/python/stripe_agent_toolkit/api.py create mode 100644 tools/python/stripe_agent_toolkit/idempotency.py create mode 100644 tools/python/tests/test_fixes.py diff --git a/tools/python/stripe_agent_toolkit/api.py b/tools/python/stripe_agent_toolkit/api.py new file mode 100644 index 00000000..136f4ac5 --- /dev/null +++ b/tools/python/stripe_agent_toolkit/api.py @@ -0,0 +1,411 @@ +""" +stripe_agent_toolkit/api.py +--------------------------- +Thin wrappers around stripe-python that agent tools call directly. + +Changes in this version +~~~~~~~~~~~~~~~~~~~~~~~ +* Fix #402 — Every mutating call now passes a *stable* idempotency_key + derived from the tool name + arguments, preventing duplicate charges + when an agent framework retries a tool invocation as a new session. + +* Fix #388 — ``list_subscriptions``, ``list_products``, ``list_prices``, + and ``search_stripe_resources`` now accept optional ``starting_after`` + and ``ending_before`` cursor parameters so callers can page through more + than 100 records. +""" + +from __future__ import annotations + +import json +from typing import Any + +import stripe as _stripe + +from .idempotency import with_idempotency + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _stripe_obj_to_dict(obj: Any) -> dict: + """Convert a Stripe API object to a plain dict for easy serialisation.""" + if hasattr(obj, "to_dict_recursive"): + return obj.to_dict_recursive() + if hasattr(obj, "__dict__"): + return {k: v for k, v in obj.__dict__.items() if not k.startswith("_")} + return dict(obj) + + +def _ok(obj: Any) -> str: + return json.dumps(_stripe_obj_to_dict(obj), default=str) + + +# --------------------------------------------------------------------------- +# Customer +# --------------------------------------------------------------------------- + +def create_customer(client: _stripe.StripeClient, args: dict) -> str: + params = with_idempotency("create_customer", args) + customer = client.customers.create( + params={k: v for k, v in params.items() if k != "idempotency_key"}, + options={"idempotency_key": params.get("idempotency_key")}, + ) + return _ok(customer) + + +def list_customers(client: _stripe.StripeClient, args: dict) -> str: + params: dict[str, Any] = {} + if args.get("limit"): + params["limit"] = args["limit"] + if args.get("email"): + params["email"] = args["email"] + # Fix #388 — pagination cursors + if args.get("starting_after"): + params["starting_after"] = args["starting_after"] + if args.get("ending_before"): + params["ending_before"] = args["ending_before"] + customers = client.customers.list(params=params) + return _ok(customers) + + +def retrieve_customer(client: _stripe.StripeClient, args: dict) -> str: + customer = client.customers.retrieve(args["customer_id"]) + return _ok(customer) + + +# --------------------------------------------------------------------------- +# Products +# --------------------------------------------------------------------------- + +def create_product(client: _stripe.StripeClient, args: dict) -> str: + params = with_idempotency("create_product", args) + product = client.products.create( + params={k: v for k, v in params.items() if k != "idempotency_key"}, + options={"idempotency_key": params.get("idempotency_key")}, + ) + return _ok(product) + + +def list_products(client: _stripe.StripeClient, args: dict) -> str: + """List products. + + Fix #388: now accepts ``starting_after`` and ``ending_before`` for cursor- + based pagination, making it possible to retrieve more than 100 products. + """ + params: dict[str, Any] = {} + if args.get("limit"): + params["limit"] = args["limit"] + if args.get("active") is not None: + params["active"] = args["active"] + # Fix #388 — pagination cursors + if args.get("starting_after"): + params["starting_after"] = args["starting_after"] + if args.get("ending_before"): + params["ending_before"] = args["ending_before"] + products = client.products.list(params=params) + return _ok(products) + + +def retrieve_product(client: _stripe.StripeClient, args: dict) -> str: + product = client.products.retrieve(args["product_id"]) + return _ok(product) + + +# --------------------------------------------------------------------------- +# Prices +# --------------------------------------------------------------------------- + +def create_price(client: _stripe.StripeClient, args: dict) -> str: + params = with_idempotency("create_price", args) + price = client.prices.create( + params={k: v for k, v in params.items() if k != "idempotency_key"}, + options={"idempotency_key": params.get("idempotency_key")}, + ) + return _ok(price) + + +def list_prices(client: _stripe.StripeClient, args: dict) -> str: + """List prices. + + Fix #388: now accepts ``starting_after`` and ``ending_before`` cursors. + """ + params: dict[str, Any] = {} + if args.get("limit"): + params["limit"] = args["limit"] + if args.get("product"): + params["product"] = args["product"] + if args.get("active") is not None: + params["active"] = args["active"] + # Fix #388 — pagination cursors + if args.get("starting_after"): + params["starting_after"] = args["starting_after"] + if args.get("ending_before"): + params["ending_before"] = args["ending_before"] + prices = client.prices.list(params=params) + return _ok(prices) + + +# --------------------------------------------------------------------------- +# Payment links +# --------------------------------------------------------------------------- + +def create_payment_link(client: _stripe.StripeClient, args: dict) -> str: + params = with_idempotency("create_payment_link", args) + link = client.payment_links.create( + params={k: v for k, v in params.items() if k != "idempotency_key"}, + options={"idempotency_key": params.get("idempotency_key")}, + ) + return _ok(link) + + +def list_payment_links(client: _stripe.StripeClient, args: dict) -> str: + params: dict[str, Any] = {} + if args.get("limit"): + params["limit"] = args["limit"] + if args.get("active") is not None: + params["active"] = args["active"] + if args.get("starting_after"): + params["starting_after"] = args["starting_after"] + if args.get("ending_before"): + params["ending_before"] = args["ending_before"] + links = client.payment_links.list(params=params) + return _ok(links) + + +def retrieve_payment_link(client: _stripe.StripeClient, args: dict) -> str: + link = client.payment_links.retrieve(args["payment_link"]) + return _ok(link) + + +# --------------------------------------------------------------------------- +# Payment intents +# --------------------------------------------------------------------------- + +def create_payment_intent(client: _stripe.StripeClient, args: dict) -> str: + """Create a PaymentIntent. + + Fix #402: a deterministic idempotency_key is derived from the tool args so + that any agent-level retry of the exact same call returns the original + PaymentIntent instead of creating a second charge. + """ + params = with_idempotency("create_payment_intent", args) + intent = client.payment_intents.create( + params={k: v for k, v in params.items() if k != "idempotency_key"}, + options={"idempotency_key": params.get("idempotency_key")}, + ) + return _ok(intent) + + +def retrieve_payment_intent(client: _stripe.StripeClient, args: dict) -> str: + intent = client.payment_intents.retrieve(args["payment_intent"]) + return _ok(intent) + + +# --------------------------------------------------------------------------- +# Refunds +# --------------------------------------------------------------------------- + +def create_refund(client: _stripe.StripeClient, args: dict) -> str: + """Create a Refund. + + Fix #402: idempotency key prevents duplicate refunds on agent retry. + """ + params = with_idempotency("create_refund", args) + refund = client.refunds.create( + params={k: v for k, v in params.items() if k != "idempotency_key"}, + options={"idempotency_key": params.get("idempotency_key")}, + ) + return _ok(refund) + + +# --------------------------------------------------------------------------- +# Invoices +# --------------------------------------------------------------------------- + +def create_invoice(client: _stripe.StripeClient, args: dict) -> str: + params = with_idempotency("create_invoice", args) + invoice = client.invoices.create( + params={k: v for k, v in params.items() if k != "idempotency_key"}, + options={"idempotency_key": params.get("idempotency_key")}, + ) + return _ok(invoice) + + +def create_invoice_item(client: _stripe.StripeClient, args: dict) -> str: + params = with_idempotency("create_invoice_item", args) + item = client.invoice_items.create( + params={k: v for k, v in params.items() if k != "idempotency_key"}, + options={"idempotency_key": params.get("idempotency_key")}, + ) + return _ok(item) + + +def finalize_invoice(client: _stripe.StripeClient, args: dict) -> str: + params = with_idempotency("finalize_invoice", args) + invoice = client.invoices.finalize_invoice( + args["invoice"], + options={"idempotency_key": params.get("idempotency_key")}, + ) + return _ok(invoice) + + +def retrieve_invoice(client: _stripe.StripeClient, args: dict) -> str: + invoice = client.invoices.retrieve(args["invoice"]) + return _ok(invoice) + + +def list_invoices(client: _stripe.StripeClient, args: dict) -> str: + params: dict[str, Any] = {} + if args.get("limit"): + params["limit"] = args["limit"] + if args.get("customer"): + params["customer"] = args["customer"] + if args.get("starting_after"): + params["starting_after"] = args["starting_after"] + if args.get("ending_before"): + params["ending_before"] = args["ending_before"] + invoices = client.invoices.list(params=params) + return _ok(invoices) + + +# --------------------------------------------------------------------------- +# Subscriptions +# --------------------------------------------------------------------------- + +def create_subscription(client: _stripe.StripeClient, args: dict) -> str: + params = with_idempotency("create_subscription", args) + sub = client.subscriptions.create( + params={k: v for k, v in params.items() if k != "idempotency_key"}, + options={"idempotency_key": params.get("idempotency_key")}, + ) + return _ok(sub) + + +def list_subscriptions(client: _stripe.StripeClient, args: dict) -> str: + """List subscriptions. + + Fix #388: now supports ``starting_after`` and ``ending_before`` so callers + can page through accounts with more than 100 active subscriptions. + + Example (paginating through all active subscriptions):: + + page1 = json.loads(list_subscriptions(client, {"status": "active", "limit": 100})) + last_id = page1["data"][-1]["id"] # e.g. "sub_xyz" + page2 = json.loads(list_subscriptions(client, { + "status": "active", + "limit": 100, + "starting_after": last_id, + })) + """ + params: dict[str, Any] = {} + if args.get("limit"): + params["limit"] = args["limit"] + if args.get("customer"): + params["customer"] = args["customer"] + if args.get("price"): + params["price"] = args["price"] + if args.get("status"): + params["status"] = args["status"] + # Fix #388 — pagination cursors + if args.get("starting_after"): + params["starting_after"] = args["starting_after"] + if args.get("ending_before"): + params["ending_before"] = args["ending_before"] + subs = client.subscriptions.list(params=params) + return _ok(subs) + + +def cancel_subscription(client: _stripe.StripeClient, args: dict) -> str: + params = with_idempotency("cancel_subscription", args) + sub = client.subscriptions.cancel( + args["subscription"], + options={"idempotency_key": params.get("idempotency_key")}, + ) + return _ok(sub) + + +def update_subscription(client: _stripe.StripeClient, args: dict) -> str: + params = with_idempotency("update_subscription", args) + sub_id = args["subscription"] + update_params = { + k: v + for k, v in params.items() + if k not in {"subscription", "idempotency_key"} + } + sub = client.subscriptions.update( + sub_id, + params=update_params, + options={"idempotency_key": params.get("idempotency_key")}, + ) + return _ok(sub) + + +# --------------------------------------------------------------------------- +# Coupons +# --------------------------------------------------------------------------- + +def create_coupon(client: _stripe.StripeClient, args: dict) -> str: + params = with_idempotency("create_coupon", args) + coupon = client.coupons.create( + params={k: v for k, v in params.items() if k != "idempotency_key"}, + options={"idempotency_key": params.get("idempotency_key")}, + ) + return _ok(coupon) + + +def list_coupons(client: _stripe.StripeClient, args: dict) -> str: + params: dict[str, Any] = {} + if args.get("limit"): + params["limit"] = args["limit"] + if args.get("starting_after"): + params["starting_after"] = args["starting_after"] + if args.get("ending_before"): + params["ending_before"] = args["ending_before"] + coupons = client.coupons.list(params=params) + return _ok(coupons) + + +# --------------------------------------------------------------------------- +# Search (generic) +# --------------------------------------------------------------------------- + +def search_stripe_resources(client: _stripe.StripeClient, args: dict) -> str: + """Search Stripe resources using the Search API. + + Fix #388: the search endpoint supports cursor-based pagination via + ``page`` (a string token returned by a previous search response's + ``next_page`` field). We surface this as ``starting_after`` to match + the convention of all other list tools; internally it maps to the + ``page`` query parameter used by Stripe's Search API. + + .. note:: + Stripe's Search API uses ``page`` (not ``starting_after``) for + pagination. For consistency with list tools we accept + ``starting_after`` and forward it as ``page``. + """ + resource = args.get("resource", "customers") + query = args.get("query", "") + params: dict[str, Any] = {"query": query} + if args.get("limit"): + params["limit"] = args["limit"] + # Fix #388 — pagination for search (maps starting_after → page token) + if args.get("starting_after"): + params["page"] = args["starting_after"] + + resource_map = { + "customers": client.customers.search, + "products": client.products.search, + "prices": client.prices.search, + "subscriptions": client.subscriptions.search, + "payment_intents": client.payment_intents.search, + "invoices": client.invoices.search, + "charges": client.charges.search, + } + search_fn = resource_map.get(resource) + if search_fn is None: + return json.dumps({"error": f"Unknown resource type: {resource!r}"}) + + results = search_fn(params=params) + return _ok(results) \ No newline at end of file diff --git a/tools/python/stripe_agent_toolkit/idempotency.py b/tools/python/stripe_agent_toolkit/idempotency.py new file mode 100644 index 00000000..847c824e --- /dev/null +++ b/tools/python/stripe_agent_toolkit/idempotency.py @@ -0,0 +1,100 @@ +""" +Idempotency guard for Stripe agent tool calls. + +Fixes Issue #402: Agent-level retry creates duplicate charges — no idempotency +guard above the tool layer. + +The Stripe SDK handles network-level retries within a single session using +auto-generated idempotency keys. However, when an agent framework retries a +tool call as a *new* invocation (e.g. after a timeout, crash, or model loop), +a fresh session starts with a new key, and a second charge is created. + +This module solves the problem at the orchestration layer by: + 1. Deriving a *stable* request_id from a deterministic hash of + (tool_name, sorted_args) before the call is made. + 2. Forwarding that key as the ``idempotency_key`` on every mutating Stripe + call so that Stripe itself deduplicates the request for up to 24 hours. + +Only mutating operations (those that can cause side-effects like charges) +receive idempotency keys. Read-only operations (list, retrieve) are excluded +because Stripe rejects idempotency keys on GET requests. + +Usage +----- + from stripe_agent_toolkit.idempotency import idempotency_key_for + + key = idempotency_key_for("create_payment_intent", {"amount": 1000, "currency": "usd"}) + # key is stable: same inputs always produce the same key + stripe.PaymentIntent.create(amount=1000, currency="usd", idempotency_key=key) +""" + +from __future__ import annotations + +import hashlib +import json +from typing import Any + +# Tools that create or mutate financial objects. Only these should receive +# a stable idempotency key; read (list/retrieve) operations must not. +MUTATING_TOOLS: frozenset[str] = frozenset( + { + "create_customer", + "create_product", + "create_price", + "create_payment_link", + "create_payment_intent", + "create_refund", + "create_invoice", + "create_invoice_item", + "finalize_invoice", + "create_subscription", + "cancel_subscription", + "update_subscription", + "create_coupon", + } +) + + +def _stable_json(obj: Any) -> str: + """Serialize *obj* to JSON with sorted keys so that dict ordering does not + affect the resulting hash.""" + return json.dumps(obj, sort_keys=True, separators=(",", ":"), default=str) + + +def idempotency_key_for(tool_name: str, args: dict[str, Any]) -> str | None: + """Return a deterministic idempotency key for *tool_name* called with *args*. + + Returns ``None`` for read-only tools so callers can use the result as a + guard directly:: + + key = idempotency_key_for(tool_name, args) + stripe_call(..., **({"idempotency_key": key} if key else {})) + + The key is a 64-character hex SHA-256 digest of ``":"``, + which is unique per (tool, args) combination and stable across retries. + """ + if tool_name not in MUTATING_TOOLS: + return None + + payload = f"{tool_name}:{_stable_json(args)}" + return hashlib.sha256(payload.encode()).hexdigest() + + +def with_idempotency(tool_name: str, args: dict[str, Any]) -> dict[str, Any]: + """Return a copy of *args* with ``idempotency_key`` injected if appropriate. + + Safe to call for every tool — read-only tools are returned unchanged. + + Example + ------- + >>> params = with_idempotency("create_payment_intent", {"amount": 500, "currency": "usd"}) + >>> "idempotency_key" in params + True + >>> params = with_idempotency("list_customers", {"limit": 10}) + >>> "idempotency_key" in params + False + """ + key = idempotency_key_for(tool_name, args) + if key is None: + return dict(args) + return {**args, "idempotency_key": key} \ No newline at end of file diff --git a/tools/python/tests/test_fixes.py b/tools/python/tests/test_fixes.py new file mode 100644 index 00000000..8486ebd6 --- /dev/null +++ b/tools/python/tests/test_fixes.py @@ -0,0 +1,276 @@ +""" +tests/test_fixes.py +------------------- +Tests for: + * Fix #402 — stable idempotency keys prevent duplicate charges on agent retry + * Fix #388 — pagination (starting_after / ending_before) on list/search tools +""" + +from __future__ import annotations + +import json +import sys +import types +import unittest +from unittest.mock import MagicMock, patch, call + +# --------------------------------------------------------------------------- +# Bootstrap: make the package importable without installing +# --------------------------------------------------------------------------- +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +from stripe_agent_toolkit.idempotency import ( + idempotency_key_for, + with_idempotency, + MUTATING_TOOLS, +) +import stripe_agent_toolkit.api as api_module + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_stripe_obj(data: dict) -> MagicMock: + """Return a mock that mimics to_dict_recursive().""" + m = MagicMock() + m.to_dict_recursive.return_value = data + return m + + +def _make_client() -> MagicMock: + """Return a minimal mock StripeClient.""" + client = MagicMock() + return client + + +# =========================================================================== +# Fix #402 — Idempotency +# =========================================================================== + +class TestIdempotencyKeyGeneration(unittest.TestCase): + """idempotency_key_for() must be stable and scoped to mutating tools.""" + + def test_same_args_same_key(self): + args = {"amount": 1000, "currency": "usd", "customer": "cus_abc"} + k1 = idempotency_key_for("create_payment_intent", args) + k2 = idempotency_key_for("create_payment_intent", args) + self.assertEqual(k1, k2) + + def test_different_args_different_key(self): + k1 = idempotency_key_for("create_payment_intent", {"amount": 100, "currency": "usd"}) + k2 = idempotency_key_for("create_payment_intent", {"amount": 200, "currency": "usd"}) + self.assertNotEqual(k1, k2) + + def test_dict_order_does_not_matter(self): + k1 = idempotency_key_for("create_customer", {"name": "Alice", "email": "a@b.com"}) + k2 = idempotency_key_for("create_customer", {"email": "a@b.com", "name": "Alice"}) + self.assertEqual(k1, k2) + + def test_read_only_tool_returns_none(self): + self.assertIsNone(idempotency_key_for("list_customers", {"limit": 10})) + self.assertIsNone(idempotency_key_for("retrieve_customer", {"customer_id": "cus_1"})) + self.assertIsNone(idempotency_key_for("list_subscriptions", {"status": "active"})) + + def test_all_mutating_tools_return_key(self): + for tool in MUTATING_TOOLS: + with self.subTest(tool=tool): + key = idempotency_key_for(tool, {"dummy": "arg"}) + self.assertIsNotNone(key) + self.assertEqual(len(key), 64) # SHA-256 hex digest + + def test_with_idempotency_injects_key_for_mutating(self): + result = with_idempotency("create_payment_intent", {"amount": 500, "currency": "usd"}) + self.assertIn("idempotency_key", result) + self.assertEqual(len(result["idempotency_key"]), 64) + + def test_with_idempotency_does_not_mutate_original(self): + args = {"amount": 500, "currency": "usd"} + with_idempotency("create_payment_intent", args) + self.assertNotIn("idempotency_key", args) + + def test_with_idempotency_leaves_read_only_unchanged(self): + args = {"limit": 10} + result = with_idempotency("list_customers", args) + self.assertNotIn("idempotency_key", result) + self.assertEqual(result, args) + + +class TestCreatePaymentIntentIdempotency(unittest.TestCase): + """create_payment_intent forwards the idempotency key to Stripe.""" + + def test_idempotency_key_forwarded(self): + client = _make_client() + client.payment_intents.create.return_value = _make_stripe_obj( + {"id": "pi_123", "amount": 1000, "currency": "usd", "status": "requires_payment_method"} + ) + + args = {"amount": 1000, "currency": "usd"} + api_module.create_payment_intent(client, args) + + call_kwargs = client.payment_intents.create.call_args + options = call_kwargs.kwargs.get("options") or (call_kwargs[1].get("options") if call_kwargs[1] else None) + self.assertIsNotNone(options) + self.assertIn("idempotency_key", options) + self.assertEqual(len(options["idempotency_key"]), 64) + + def test_retry_uses_same_idempotency_key(self): + """Simulates an agent retrying the exact same call — the key must match.""" + args = {"amount": 1000, "currency": "usd", "customer": "cus_test"} + key1 = idempotency_key_for("create_payment_intent", args) + key2 = idempotency_key_for("create_payment_intent", args) + self.assertEqual(key1, key2, + "Retry of same call must produce the same idempotency key") + + +class TestCreateRefundIdempotency(unittest.TestCase): + def test_idempotency_key_forwarded(self): + client = _make_client() + client.refunds.create.return_value = _make_stripe_obj({"id": "re_1", "amount": 500}) + + api_module.create_refund(client, {"charge": "ch_123", "amount": 500}) + + call_kwargs = client.refunds.create.call_args + options = call_kwargs.kwargs.get("options") or {} + self.assertIn("idempotency_key", options) + + +# =========================================================================== +# Fix #388 — Pagination +# =========================================================================== + +class TestListSubscriptionsPagination(unittest.TestCase): + def _subscriptions_response(self, last_id: str = "sub_zzz") -> MagicMock: + return _make_stripe_obj({ + "object": "list", + "data": [{"id": last_id, "status": "active"}], + "has_more": True, + }) + + def test_starting_after_forwarded(self): + client = _make_client() + client.subscriptions.list.return_value = self._subscriptions_response() + + api_module.list_subscriptions(client, {"limit": 100, "starting_after": "sub_xyz"}) + + params = client.subscriptions.list.call_args.kwargs.get("params", {}) + self.assertEqual(params.get("starting_after"), "sub_xyz") + + def test_ending_before_forwarded(self): + client = _make_client() + client.subscriptions.list.return_value = self._subscriptions_response() + + api_module.list_subscriptions(client, {"limit": 100, "ending_before": "sub_abc"}) + + params = client.subscriptions.list.call_args.kwargs.get("params", {}) + self.assertEqual(params.get("ending_before"), "sub_abc") + + def test_no_cursor_omits_pagination_params(self): + client = _make_client() + client.subscriptions.list.return_value = self._subscriptions_response() + + api_module.list_subscriptions(client, {"limit": 10, "status": "active"}) + + params = client.subscriptions.list.call_args.kwargs.get("params", {}) + self.assertNotIn("starting_after", params) + self.assertNotIn("ending_before", params) + + +class TestListProductsPagination(unittest.TestCase): + def test_starting_after_forwarded(self): + client = _make_client() + client.products.list.return_value = _make_stripe_obj({"object": "list", "data": []}) + + api_module.list_products(client, {"limit": 100, "starting_after": "prod_abc"}) + + params = client.products.list.call_args.kwargs.get("params", {}) + self.assertEqual(params.get("starting_after"), "prod_abc") + + +class TestListPricesPagination(unittest.TestCase): + def test_starting_after_forwarded(self): + client = _make_client() + client.prices.list.return_value = _make_stripe_obj({"object": "list", "data": []}) + + api_module.list_prices(client, {"limit": 100, "starting_after": "price_abc"}) + + params = client.prices.list.call_args.kwargs.get("params", {}) + self.assertEqual(params.get("starting_after"), "price_abc") + + +class TestSearchPagination(unittest.TestCase): + def test_starting_after_maps_to_page_token(self): + """search_stripe_resources maps starting_after → 'page' for the Search API.""" + client = _make_client() + client.customers.search.return_value = _make_stripe_obj({ + "object": "search_result", + "data": [], + "next_page": None, + }) + + api_module.search_stripe_resources(client, { + "resource": "customers", + "query": "email:'test@example.com'", + "starting_after": "page_token_xyz", + }) + + params = client.customers.search.call_args.kwargs.get("params", {}) + self.assertEqual(params.get("page"), "page_token_xyz", + "starting_after must be forwarded as 'page' to the Search API") + + def test_unknown_resource_returns_error_json(self): + client = _make_client() + result = api_module.search_stripe_resources(client, { + "resource": "unknown_resource", + "query": "foo", + }) + data = json.loads(result) + self.assertIn("error", data) + + +# =========================================================================== +# Pagination — end-to-end cursor chaining simulation +# =========================================================================== + +class TestPaginationCursorChaining(unittest.TestCase): + """Simulate a full multi-page traversal of subscriptions.""" + + def test_full_page_traversal(self): + client = _make_client() + + page1_data = [{"id": f"sub_{i:03d}", "status": "active"} for i in range(100)] + page2_data = [{"id": f"sub_{i:03d}", "status": "active"} for i in range(100, 140)] + + def _list_side_effect(params=None, **_): + sa = (params or {}).get("starting_after") + if sa is None: + return _make_stripe_obj({"data": page1_data, "has_more": True}) + elif sa == "sub_099": + return _make_stripe_obj({"data": page2_data, "has_more": False}) + raise AssertionError(f"Unexpected starting_after: {sa!r}") + + client.subscriptions.list.side_effect = _list_side_effect + + # Page 1 + r1 = json.loads(api_module.list_subscriptions(client, {"limit": 100, "status": "active"})) + self.assertEqual(len(r1["data"]), 100) + self.assertTrue(r1["has_more"]) + + last_id = r1["data"][-1]["id"] # "sub_099" + + # Page 2 + r2 = json.loads(api_module.list_subscriptions(client, { + "limit": 100, + "status": "active", + "starting_after": last_id, + })) + self.assertEqual(len(r2["data"]), 40) + self.assertFalse(r2["has_more"]) + + total = len(r1["data"]) + len(r2["data"]) + self.assertEqual(total, 140) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file