Skip to content

Commit 5dd225d

Browse files
authored
[SB] fix to capture properties when using batches (#42598)
* fix to capture properties when using batches * return none for partition key * clean up comments * fix some spacing issues * enable partitions for queues and add tests * disable other tests * fix sleep time and get async cred * fix test * reset tests * reset tests for mgmt * vendor in to EH * check partition key * fix pylint * update changelog
1 parent 087d6f9 commit 5dd225d

File tree

7 files changed

+111
-0
lines changed

7 files changed

+111
-0
lines changed

sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# license information.
55
# --------------------------------------------------------------------------
66
import datetime
7+
from typing import TYPE_CHECKING
78
from base64 import b64encode
89
from hashlib import sha256
910
from hmac import HMAC
@@ -12,6 +13,10 @@
1213
from datetime import timezone
1314
from .types import TYPE, VALUE, AMQPTypes
1415
from ._encode import encode_payload
16+
from .message import Properties
17+
18+
if TYPE_CHECKING:
19+
from .message import Message
1520

1621
TZ_UTC: timezone = timezone.utc
1722
# Number of seconds between the Unix epoch (1/1/1970) and year 1 CE.
@@ -75,6 +80,13 @@ def add_batch(batch, message):
7580
encode_payload(output, message)
7681
batch[5].append(output)
7782

83+
def set_message_properties(message, properties: list):
84+
if not message[3]:
85+
message[3] = Properties(*properties)
86+
87+
def set_message_annotations(message, annotations: dict):
88+
if not message[2]:
89+
message[2] = annotations
7890

7991
def encode_str(data, encoding="utf-8"):
8092
try:

sdk/servicebus/azure-servicebus/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
### Bugs Fixed
1010

11+
- Fixed a bug where batched messages couldn't be sent to a queue that had session & partitions enabled. ([#42598](https://github.com/Azure/azure-sdk-for-python/pull/42598))
12+
1113
### Other Changes
1214

1315
## 7.14.2 (2025-04-09)

sdk/servicebus/azure-servicebus/azure/servicebus/_common/message.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from .._pyamqp._message_backcompat import LegacyMessage, LegacyBatchMessage
1616
from .._pyamqp.message import Message as pyamqp_Message
17+
from .._pyamqp.utils import set_message_properties, set_message_annotations
1718
from .._transport._pyamqp_transport import PyamqpTransport
1819

1920
from .constants import (
@@ -309,6 +310,7 @@ def partition_key(self) -> Optional[str]:
309310
310311
:rtype: str or None
311312
"""
313+
opt_p_key = None
312314
try:
313315
opt_p_key = self._raw_amqp_message.annotations.get(_X_OPT_PARTITION_KEY) # type: ignore
314316
if opt_p_key is not None:
@@ -675,6 +677,16 @@ def _add(self, add_message: Union[ServiceBusMessage, Mapping[str, Any], AmqpAnno
675677
self._count += 1
676678
self._messages.append(outgoing_sb_message)
677679

680+
if self._count == 1: # Populate properties on the batch envelope from the first message
681+
if outgoing_sb_message.message_id or outgoing_sb_message.session_id:
682+
properties: List[Optional[str]] = [None] * 13
683+
properties[0] = outgoing_sb_message.message_id
684+
properties[10] = outgoing_sb_message.session_id
685+
set_message_properties(self._message, properties)
686+
687+
if outgoing_sb_message.partition_key:
688+
set_message_annotations(self._message, {_X_OPT_PARTITION_KEY: outgoing_sb_message.partition_key})
689+
678690
@property
679691
def message(self) -> Union["BatchMessage", LegacyBatchMessage]:
680692
"""DEPRECATED: Get the underlying uamqp.BatchMessage or LegacyBatchMessage.

sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# license information.
55
# --------------------------------------------------------------------------
66
import datetime
7+
from typing import TYPE_CHECKING
78
from base64 import b64encode
89
from hashlib import sha256
910
from hmac import HMAC
@@ -12,6 +13,10 @@
1213
from datetime import timezone
1314
from .types import TYPE, VALUE, AMQPTypes
1415
from ._encode import encode_payload
16+
from .message import Properties
17+
18+
if TYPE_CHECKING:
19+
from .message import Message
1520

1621
TZ_UTC: timezone = timezone.utc
1722
# Number of seconds between the Unix epoch (1/1/1970) and year 1 CE.
@@ -74,6 +79,13 @@ def add_batch(batch, message):
7479
encode_payload(output, message)
7580
batch[5].append(output)
7681

82+
def set_message_properties(message, properties: list):
83+
if not message[3]:
84+
message[3] = Properties(*properties)
85+
86+
def set_message_annotations(message, annotations: dict):
87+
if not message[2]:
88+
message[2] = annotations
7789

7890
def encode_str(data, encoding="utf-8"):
7991
try:

sdk/servicebus/azure-servicebus/tests/async_tests/test_sessions_async.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1533,3 +1533,38 @@ async def test_async_next_available_session_timeout_value(
15331533
async for msg in receiver:
15341534
pass
15351535
assert time.time() - start_time2 > 65 # Default service timeout value is 65 seconds
1536+
1537+
@pytest.mark.asyncio
1538+
@pytest.mark.liveTest
1539+
@pytest.mark.live_test_only
1540+
@CachedServiceBusResourceGroupPreparer(name_prefix="servicebustest")
1541+
@CachedServiceBusNamespacePreparer(name_prefix="servicebustest")
1542+
@ServiceBusQueuePreparer(name_prefix="servicebustest", requires_session=True, enable_partitioning=True)
1543+
@pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids)
1544+
@ArgPasserAsync()
1545+
async def test_async_session_partition_batch(self, uamqp_transport, *, servicebus_namespace=None, servicebus_queue=None, **kwargs):
1546+
1547+
fully_qualified_namespace = f"{servicebus_namespace.name}{SERVICEBUS_ENDPOINT_SUFFIX}"
1548+
credential = get_credential(is_async=True)
1549+
messages = [
1550+
ServiceBusMessage("Message 1", session_id="mySessionId", message_id=uuid.uuid4(), partition_key="mySessionId"),
1551+
ServiceBusMessage("Message 2", session_id="mySessionId", message_id=uuid.uuid4(), partition_key="mySessionId")
1552+
]
1553+
async with ServiceBusClient(
1554+
fully_qualified_namespace=fully_qualified_namespace,
1555+
credential=credential,
1556+
logging_enable=False,
1557+
uamqp_transport=uamqp_transport,
1558+
) as sb_client:
1559+
1560+
sender = sb_client.get_queue_sender(servicebus_queue.name)
1561+
async with sender:
1562+
await sender.send_messages(messages)
1563+
1564+
received_messages = []
1565+
async with sb_client.get_queue_receiver(servicebus_queue.name, session_id="mySessionId", max_wait_time=10) as receiver:
1566+
async for message in receiver:
1567+
received_messages.append(message)
1568+
assert len(received_messages) == 2
1569+
assert all(msg.session_id == "mySessionId" for msg in received_messages)
1570+
assert all(msg.partition_key == "mySessionId" for msg in received_messages)

sdk/servicebus/azure-servicebus/tests/servicebus_preparer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,7 @@ def __init__(
427427
requires_duplicate_detection=False,
428428
dead_lettering_on_message_expiration=False,
429429
requires_session=False,
430+
enable_partitioning=False,
430431
lock_duration="PT30S",
431432
parameter_name=SERVICEBUS_QUEUE_PARAM,
432433
resource_group_parameter_name=RESOURCE_GROUP_PARAM,
@@ -452,12 +453,14 @@ def __init__(
452453
dead_lettering_on_message_expiration,
453454
requires_session,
454455
lock_duration,
456+
enable_partitioning,
455457
)
456458

457459
# Queue parameters
458460
self.requires_duplicate_detection = requires_duplicate_detection
459461
self.dead_lettering_on_message_expiration = dead_lettering_on_message_expiration
460462
self.requires_session = requires_session
463+
self.enable_partitioning = enable_partitioning
461464
self.lock_duration = lock_duration
462465
if random_name_enabled:
463466
self.resource_moniker = self.name_prefix + "sbqueue"
@@ -481,6 +484,7 @@ def create_resource(self, name, **kwargs):
481484
requires_duplicate_detection=self.requires_duplicate_detection,
482485
dead_lettering_on_message_expiration=self.dead_lettering_on_message_expiration,
483486
requires_session=self.requires_session,
487+
enable_partitioning=self.enable_partitioning,
484488
),
485489
)
486490
break

sdk/servicebus/azure-servicebus/tests/test_sessions.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1676,3 +1676,37 @@ def test_next_available_session_timeout_value(
16761676
pass
16771677

16781678
assert time.time() - start_time2 > 65 # Default service operation timeout is 65 seconds
1679+
1680+
@pytest.mark.liveTest
1681+
@pytest.mark.live_test_only
1682+
@CachedServiceBusResourceGroupPreparer(name_prefix="servicebustest")
1683+
@CachedServiceBusNamespacePreparer(name_prefix="servicebustest")
1684+
@ServiceBusQueuePreparer(name_prefix="servicebustest", requires_session=True, enable_partitioning=True)
1685+
@pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids)
1686+
@ArgPasser()
1687+
def test_session_partition_batch(self, uamqp_transport, *, servicebus_namespace=None, servicebus_queue=None, **kwargs):
1688+
1689+
fully_qualified_namespace = f"{servicebus_namespace.name}{SERVICEBUS_ENDPOINT_SUFFIX}"
1690+
credential = get_credential()
1691+
messages = [
1692+
ServiceBusMessage("Message 1", session_id="mySessionId", message_id=uuid.uuid4(), partition_key="mySessionId"),
1693+
ServiceBusMessage("Message 2", session_id="mySessionId", message_id=uuid.uuid4(), partition_key="mySessionId")
1694+
]
1695+
with ServiceBusClient(
1696+
fully_qualified_namespace=fully_qualified_namespace,
1697+
credential=credential,
1698+
logging_enable=False,
1699+
uamqp_transport=uamqp_transport,
1700+
) as sb_client:
1701+
1702+
with sb_client.get_queue_sender(servicebus_queue.name) as sender:
1703+
sender.send_messages(messages)
1704+
1705+
received_messages = []
1706+
with sb_client.get_queue_receiver(servicebus_queue.name, session_id="mySessionId", max_wait_time=10) as receiver:
1707+
for message in receiver:
1708+
received_messages.append(message)
1709+
1710+
assert len(received_messages) == 2
1711+
assert all(msg.session_id == "mySessionId" for msg in received_messages)
1712+
assert all(msg.partition_key == "mySessionId" for msg in received_messages)

0 commit comments

Comments
 (0)