Skip to content

Commit 466eb04

Browse files
committed
Implement dry_run
1 parent a25bbd5 commit 466eb04

File tree

7 files changed

+287
-5
lines changed

7 files changed

+287
-5
lines changed

pinecone/db_data/dataclasses/update_response.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from dataclasses import dataclass, field
2-
from typing import cast
2+
from typing import Optional, cast
33

44
from .utils import DictLike
55
from pinecone.utils.response_info import ResponseInfo
@@ -10,9 +10,13 @@ class UpdateResponse(DictLike):
1010
"""Response from an update operation.
1111
1212
Attributes:
13+
matched_records: The number of records that matched the filter (if a filter was provided).
14+
updated_records: The number of records that were actually updated.
1315
_response_info: Response metadata including LSN headers.
1416
"""
1517

18+
matched_records: Optional[int] = None
19+
updated_records: Optional[int] = None
1620
_response_info: ResponseInfo = field(
1721
default_factory=lambda: cast(ResponseInfo, {"raw_headers": {}}), repr=True, compare=False
1822
)

pinecone/db_data/index.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -683,6 +683,8 @@ def update(
683683
)
684684
# Extract response info from result if it's an OpenAPI model with _response_info
685685
response_info = None
686+
matched_records = None
687+
updated_records = None
686688
if hasattr(result, "_response_info"):
687689
response_info = result._response_info
688690
else:
@@ -691,7 +693,30 @@ def update(
691693

692694
response_info = extract_response_info({})
693695

694-
return UpdateResponse(_response_info=response_info)
696+
# Extract matched_records and updated_records from OpenAPI model
697+
if hasattr(result, "matched_records"):
698+
matched_records = result.matched_records
699+
if hasattr(result, "updated_records"):
700+
updated_records = result.updated_records
701+
# Also check for camelCase in case it's in the raw response
702+
if updated_records is None and hasattr(result, "updatedRecords"):
703+
updated_records = result.updatedRecords
704+
# Check _data_store for fields not in the OpenAPI spec
705+
if hasattr(result, "_data_store"):
706+
if updated_records is None:
707+
updated_records = result._data_store.get(
708+
"updatedRecords"
709+
) or result._data_store.get("updated_records")
710+
if matched_records is None:
711+
matched_records = result._data_store.get(
712+
"matchedRecords"
713+
) or result._data_store.get("matched_records")
714+
715+
return UpdateResponse(
716+
matched_records=matched_records,
717+
updated_records=updated_records,
718+
_response_info=response_info,
719+
)
695720

696721
@validate_and_convert_errors
697722
def describe_index_stats(

pinecone/db_data/index_asyncio.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -654,6 +654,8 @@ async def update(
654654
)
655655
# Extract response info from result if it's an OpenAPI model with _response_info
656656
response_info = None
657+
matched_records = None
658+
updated_records = None
657659
if hasattr(result, "_response_info"):
658660
response_info = result._response_info
659661
else:
@@ -662,7 +664,30 @@ async def update(
662664

663665
response_info = extract_response_info({})
664666

665-
return UpdateResponse(_response_info=response_info)
667+
# Extract matched_records and updated_records from OpenAPI model
668+
if hasattr(result, "matched_records"):
669+
matched_records = result.matched_records
670+
if hasattr(result, "updated_records"):
671+
updated_records = result.updated_records
672+
# Also check for camelCase in case it's in the raw response
673+
if updated_records is None and hasattr(result, "updatedRecords"):
674+
updated_records = result.updatedRecords
675+
# Check _data_store for fields not in the OpenAPI spec
676+
if hasattr(result, "_data_store"):
677+
if updated_records is None:
678+
updated_records = result._data_store.get(
679+
"updatedRecords"
680+
) or result._data_store.get("updated_records")
681+
if matched_records is None:
682+
matched_records = result._data_store.get(
683+
"matchedRecords"
684+
) or result._data_store.get("matched_records")
685+
686+
return UpdateResponse(
687+
matched_records=matched_records,
688+
updated_records=updated_records,
689+
_response_info=response_info,
690+
)
666691

667692
@validate_and_convert_errors
668693
async def describe_index_stats(

pinecone/grpc/utils.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,12 +152,34 @@ def parse_update_response(
152152
):
153153
from pinecone.db_data.dataclasses import UpdateResponse
154154
from pinecone.utils.response_info import extract_response_info
155+
from google.protobuf import json_format
155156

156157
# Extract response info from initial metadata
157158
metadata = initial_metadata or {}
158159
response_info = extract_response_info(metadata)
159160

160-
return UpdateResponse(_response_info=response_info)
161+
# Extract matched_records and updated_records from response
162+
matched_records = None
163+
updated_records = None
164+
if isinstance(response, Message):
165+
# GRPC response - convert to dict to extract matched_records and updated_records
166+
json_response = json_format.MessageToDict(response)
167+
matched_records = json_response.get("matchedRecords") or json_response.get(
168+
"matched_records"
169+
)
170+
updated_records = json_response.get("updatedRecords") or json_response.get(
171+
"updated_records"
172+
)
173+
elif isinstance(response, dict):
174+
# Dict response - extract directly
175+
matched_records = response.get("matchedRecords") or response.get("matched_records")
176+
updated_records = response.get("updatedRecords") or response.get("updated_records")
177+
178+
return UpdateResponse(
179+
matched_records=matched_records,
180+
updated_records=updated_records,
181+
_response_info=response_info,
182+
)
161183

162184

163185
def parse_delete_response(

tests/integration/rest_asyncio/db/data/test_update.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,3 +69,75 @@ async def test_update_metadata(self, index_host, dimension, target_namespace):
6969
fetched_vec = await asyncio_idx.fetch(ids=["2"], namespace=target_namespace)
7070
assert fetched_vec.vectors["2"].metadata == {"genre": "comedy"}
7171
await asyncio_idx.close()
72+
73+
async def test_update_with_filter_and_dry_run(self, index_host, dimension, target_namespace):
74+
"""Test update with filter and dry_run=True to verify matched_records and updated_records are returned."""
75+
asyncio_idx = build_asyncioindex_client(index_host)
76+
77+
# Upsert vectors with different genres
78+
upsert1 = await asyncio_idx.upsert(
79+
vectors=[
80+
Vector(
81+
id=str(i),
82+
values=embedding_values(dimension),
83+
metadata={"genre": "comedy" if i % 2 == 0 else "drama", "status": "active"},
84+
)
85+
for i in range(10)
86+
],
87+
namespace=target_namespace,
88+
batch_size=10,
89+
show_progress=False,
90+
)
91+
92+
await poll_until_lsn_reconciled_async(
93+
asyncio_idx, upsert1._response_info, namespace=target_namespace
94+
)
95+
96+
# Test dry_run=True - should return matched_records without updating
97+
dry_run_response = await asyncio_idx.update(
98+
filter={"genre": {"$eq": "comedy"}},
99+
set_metadata={"status": "updated"},
100+
dry_run=True,
101+
namespace=target_namespace,
102+
)
103+
104+
# Verify matched_records is returned and correct (5 comedy vectors)
105+
assert dry_run_response.matched_records is not None
106+
assert dry_run_response.matched_records == 5
107+
# In dry run, updated_records should be 0 or None since no records are actually updated
108+
assert dry_run_response.updated_records is None or dry_run_response.updated_records == 0
109+
110+
# Verify the vectors were NOT actually updated (dry run)
111+
fetched_before = await asyncio_idx.fetch(
112+
ids=["0", "2", "4", "6", "8"], namespace=target_namespace
113+
)
114+
for vec_id in ["0", "2", "4", "6", "8"]:
115+
assert fetched_before.vectors[vec_id].metadata.get("status") == "active"
116+
117+
# Now do the actual update
118+
update_response = await asyncio_idx.update(
119+
filter={"genre": {"$eq": "comedy"}},
120+
set_metadata={"status": "updated"},
121+
namespace=target_namespace,
122+
)
123+
124+
# Verify matched_records and updated_records are returned
125+
assert update_response.matched_records is not None
126+
assert update_response.matched_records == 5
127+
# updated_records should match the number of records actually updated (if returned by API)
128+
if update_response.updated_records is not None:
129+
assert update_response.updated_records == 5
130+
131+
await poll_until_lsn_reconciled_async(
132+
asyncio_idx, update_response._response_info, namespace=target_namespace
133+
)
134+
135+
# Verify the vectors were actually updated
136+
fetched_after = await asyncio_idx.fetch(
137+
ids=["0", "2", "4", "6", "8"], namespace=target_namespace
138+
)
139+
for vec_id in ["0", "2", "4", "6", "8"]:
140+
assert fetched_after.vectors[vec_id].metadata.get("status") == "updated"
141+
assert fetched_after.vectors[vec_id].metadata.get("genre") == "comedy"
142+
143+
await asyncio_idx.close()
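(file path not captured in this page extract — newly added file containing the synchronous integration test for update)

Lines changed: 70 additions & 0 deletions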
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import pytest
2+
from pinecone import Vector
3+
from tests.integration.helpers import poll_until_lsn_reconciled, embedding_values, random_string
4+
5+
6+
@pytest.fixture(scope="session")
7+
def update_namespace():
8+
return random_string(10)
9+
10+
11+
class TestUpdate:
12+
def test_update_with_filter_and_dry_run(self, idx, update_namespace):
13+
"""Test update with filter and dry_run=True to verify matched_records and updated_records are returned."""
14+
target_namespace = update_namespace
15+
16+
# Upsert vectors with different genres
17+
upsert1 = idx.upsert(
18+
vectors=[
19+
Vector(
20+
id=str(i),
21+
values=embedding_values(),
22+
metadata={"genre": "comedy" if i % 2 == 0 else "drama", "status": "active"},
23+
)
24+
for i in range(10)
25+
],
26+
namespace=target_namespace,
27+
)
28+
29+
poll_until_lsn_reconciled(idx, upsert1._response_info, namespace=target_namespace)
30+
31+
# Test dry_run=True - should return matched_records without updating
32+
dry_run_response = idx.update(
33+
filter={"genre": {"$eq": "comedy"}},
34+
set_metadata={"status": "updated"},
35+
dry_run=True,
36+
namespace=target_namespace,
37+
)
38+
39+
# Verify matched_records is returned and correct (5 comedy vectors)
40+
assert dry_run_response.matched_records is not None
41+
assert dry_run_response.matched_records == 5
42+
# In dry run, updated_records should be 0 or None since no records are actually updated
43+
assert dry_run_response.updated_records is None or dry_run_response.updated_records == 0
44+
45+
# Verify the vectors were NOT actually updated (dry run)
46+
fetched_before = idx.fetch(ids=["0", "2", "4", "6", "8"], namespace=target_namespace)
47+
for vec_id in ["0", "2", "4", "6", "8"]:
48+
assert fetched_before.vectors[vec_id].metadata.get("status") == "active"
49+
50+
# Now do the actual update
51+
update_response = idx.update(
52+
filter={"genre": {"$eq": "comedy"}},
53+
set_metadata={"status": "updated"},
54+
namespace=target_namespace,
55+
)
56+
57+
# Verify matched_records and updated_records are returned
58+
assert update_response.matched_records is not None
59+
assert update_response.matched_records == 5
60+
# updated_records should match the number of records actually updated (if returned by API)
61+
if update_response.updated_records is not None:
62+
assert update_response.updated_records == 5
63+
64+
poll_until_lsn_reconciled(idx, update_response._response_info, namespace=target_namespace)
65+
66+
# Verify the vectors were actually updated
67+
fetched_after = idx.fetch(ids=["0", "2", "4", "6", "8"], namespace=target_namespace)
68+
for vec_id in ["0", "2", "4", "6", "8"]:
69+
assert fetched_after.vectors[vec_id].metadata.get("status") == "updated"
70+
assert fetched_after.vectors[vec_id].metadata.get("genre") == "comedy"

tests/unit/test_index.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pandas as pd
22
import pytest
33

4-
from pinecone.db_data import _Index
4+
from pinecone.db_data import _Index, _IndexAsyncio
55
import pinecone.core.openapi.db_data.models as oai
66
from pinecone import QueryResponse, UpsertResponse, Vector
77

@@ -631,6 +631,70 @@ def test_update_withDryRunAndAllParams_updateWithDryRunAndAllParams(self, mocker
631631

632632
# endregion
633633

634+
# region: asyncio update tests
635+
636+
@pytest.mark.asyncio
637+
async def test_asyncio_update_withDryRun_updateWithDryRun(self, mocker):
638+
"""Test asyncio update with dry_run parameter."""
639+
asyncio_index = _IndexAsyncio(api_key="asdf", host="https://test.pinecone.io")
640+
mocker.patch.object(asyncio_index._vector_api, "update_vector", autospec=True)
641+
await asyncio_index.update(filter=self.filter1, dry_run=True, namespace="ns")
642+
asyncio_index._vector_api.update_vector.assert_called_once_with(
643+
oai.UpdateRequest(filter=self.filter1, dry_run=True, namespace="ns")
644+
)
645+
646+
@pytest.mark.asyncio
647+
async def test_asyncio_update_withDryRunAndSetMetadata_updateWithDryRunAndSetMetadata(
648+
self, mocker
649+
):
650+
"""Test asyncio update with dry_run and set_metadata."""
651+
asyncio_index = _IndexAsyncio(api_key="asdf", host="https://test.pinecone.io")
652+
mocker.patch.object(asyncio_index._vector_api, "update_vector", autospec=True)
653+
await asyncio_index.update(
654+
set_metadata=self.md1, filter=self.filter1, dry_run=True, namespace="ns"
655+
)
656+
asyncio_index._vector_api.update_vector.assert_called_once_with(
657+
oai.UpdateRequest(
658+
set_metadata=self.md1, filter=self.filter1, dry_run=True, namespace="ns"
659+
)
660+
)
661+
662+
@pytest.mark.asyncio
663+
async def test_asyncio_update_withDryRunFalse_updateWithDryRunFalse(self, mocker):
664+
"""Test asyncio update with dry_run=False."""
665+
asyncio_index = _IndexAsyncio(api_key="asdf", host="https://test.pinecone.io")
666+
mocker.patch.object(asyncio_index._vector_api, "update_vector", autospec=True)
667+
await asyncio_index.update(filter=self.filter1, dry_run=False, namespace="ns")
668+
asyncio_index._vector_api.update_vector.assert_called_once_with(
669+
oai.UpdateRequest(filter=self.filter1, dry_run=False, namespace="ns")
670+
)
671+
672+
@pytest.mark.asyncio
673+
async def test_asyncio_update_withDryRunAndAllParams_updateWithDryRunAndAllParams(self, mocker):
674+
"""Test asyncio update with dry_run and all parameters."""
675+
asyncio_index = _IndexAsyncio(api_key="asdf", host="https://test.pinecone.io")
676+
mocker.patch.object(asyncio_index._vector_api, "update_vector", autospec=True)
677+
await asyncio_index.update(
678+
values=self.vals1,
679+
set_metadata=self.md1,
680+
sparse_values=self.sv1,
681+
filter=self.filter1,
682+
dry_run=True,
683+
namespace="ns",
684+
)
685+
asyncio_index._vector_api.update_vector.assert_called_once_with(
686+
oai.UpdateRequest(
687+
values=self.vals1,
688+
set_metadata=self.md1,
689+
sparse_values=oai.SparseValues(indices=self.svi1, values=self.svv1),
690+
filter=self.filter1,
691+
dry_run=True,
692+
namespace="ns",
693+
)
694+
)
695+
696+
# endregion
697+
634698
# region: describe index tests
635699

636700
def test_describeIndexStats_callWithoutFilter_CalledWithoutFilter(self, mocker):

0 commit comments

Comments
 (0)