Skip to content

Commit 91d5759

Browse files
authored
Merge pull request #622 from ably/feat/channel-options
feat: introduce ChannelOptions for enhanced channel configuration
2 parents f67fa40 + 204138f commit 91d5759

File tree

2 files changed

+180
-8
lines changed

2 files changed

+180
-8
lines changed

ably/realtime/realtime_channel.py

Lines changed: 120 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22
import asyncio
33
import logging
4-
from typing import Optional, TYPE_CHECKING
4+
from typing import Optional, TYPE_CHECKING, Dict, Any
55
from ably.realtime.connection import ConnectionState
66
from ably.transport.websockettransport import ProtocolMessageAction
77
from ably.rest.channel import Channel, Channels as RestChannels
@@ -14,10 +14,75 @@
1414

1515
if TYPE_CHECKING:
1616
from ably.realtime.realtime import AblyRealtime
17+
from ably.util.crypto import CipherParams
1718

1819
log = logging.getLogger(__name__)
1920

2021

22+
class ChannelOptions:
23+
"""Channel options for Ably Realtime channels
24+
25+
Attributes
26+
----------
27+
cipher : CipherParams, optional
28+
Requests encryption for this channel when not null, and specifies encryption-related parameters.
29+
params : Dict[str, str], optional
30+
Channel parameters that configure the behavior of the channel.
31+
"""
32+
33+
def __init__(self, cipher: Optional[CipherParams] = None, params: Optional[dict] = None):
34+
self.__cipher = cipher
35+
self.__params = params
36+
# Validate params
37+
if self.__params and not isinstance(self.__params, dict):
38+
raise AblyException("params must be a dictionary", 40000, 400)
39+
40+
@property
41+
def cipher(self):
42+
"""Get cipher configuration"""
43+
return self.__cipher
44+
45+
@property
46+
def params(self) -> Dict[str, str]:
47+
"""Get channel parameters"""
48+
return self.__params
49+
50+
def __eq__(self, other):
51+
"""Check equality with another ChannelOptions instance"""
52+
if not isinstance(other, ChannelOptions):
53+
return False
54+
55+
return (self.__cipher == other.__cipher and
56+
self.__params == other.__params)
57+
58+
def __hash__(self):
59+
"""Make ChannelOptions hashable"""
60+
return hash((
61+
self.__cipher,
62+
tuple(sorted(self.__params.items())) if self.__params else None,
63+
))
64+
65+
def to_dict(self) -> Dict[str, Any]:
66+
"""Convert to dictionary representation"""
67+
result = {}
68+
if self.__cipher is not None:
69+
result['cipher'] = self.__cipher
70+
if self.__params:
71+
result['params'] = self.__params
72+
return result
73+
74+
@classmethod
75+
def from_dict(cls, options_dict: Dict[str, Any]) -> 'ChannelOptions':
76+
"""Create ChannelOptions from dictionary"""
77+
if not isinstance(options_dict, dict):
78+
raise AblyException("options must be a dictionary", 40000, 400)
79+
80+
return cls(
81+
cipher=options_dict.get('cipher'),
82+
params=options_dict.get('params'),
83+
)
84+
85+
2186
class RealtimeChannel(EventEmitter, Channel):
2287
"""
2388
Ably Realtime Channel
@@ -43,23 +108,44 @@ class RealtimeChannel(EventEmitter, Channel):
43108
Unsubscribe to messages from a channel
44109
"""
45110

46-
def __init__(self, realtime: AblyRealtime, name: str):
111+
def __init__(self, realtime: AblyRealtime, name: str, channel_options: Optional[ChannelOptions] = None):
47112
EventEmitter.__init__(self)
48113
self.__name = name
49114
self.__realtime = realtime
50115
self.__state = ChannelState.INITIALIZED
51116
self.__message_emitter = EventEmitter()
52117
self.__state_timer: Optional[Timer] = None
53118
self.__attach_resume = False
119+
self.__attach_serial: Optional[str] = None
54120
self.__channel_serial: Optional[str] = None
55121
self.__retry_timer: Optional[Timer] = None
56122
self.__error_reason: Optional[AblyException] = None
123+
self.__channel_options = channel_options or ChannelOptions()
124+
self.__params: Optional[Dict[str, str]] = None
57125

58126
# Used to listen to state changes internally, if we use the public event emitter interface then internals
59127
# will be disrupted if the user called .off() to remove all listeners
60128
self.__internal_state_emitter = EventEmitter()
61129

62-
Channel.__init__(self, realtime, name, {})
130+
# Pass channel options as dictionary to parent Channel class
131+
Channel.__init__(self, realtime, name, self.__channel_options.to_dict())
132+
133+
async def set_options(self, channel_options: ChannelOptions) -> None:
134+
"""Set channel options"""
135+
should_reattach = self.should_reattach_to_set_options(channel_options)
136+
self.set_options_without_reattach(channel_options)
137+
138+
if should_reattach:
139+
self._attach_impl()
140+
state_change = await self.__internal_state_emitter.once_async()
141+
if state_change.current in (ChannelState.SUSPENDED, ChannelState.FAILED):
142+
raise state_change.reason
143+
144+
def set_options_without_reattach(self, channel_options: ChannelOptions) -> None:
145+
"""Internal method"""
146+
self.__channel_options = channel_options
147+
# Update parent class options
148+
self.options = channel_options.to_dict()
63149

64150
# RTL4
65151
async def attach(self) -> None:
@@ -108,6 +194,7 @@ def _attach_impl(self):
108194
# RTL4c
109195
attach_msg = {
110196
"action": ProtocolMessageAction.ATTACH,
197+
"params": self.__channel_options.params,
111198
"channel": self.name,
112199
}
113200

@@ -292,8 +379,6 @@ def _on_message(self, proto_msg: dict) -> None:
292379
action = proto_msg.get('action')
293380
# RTL4c1
294381
channel_serial = proto_msg.get('channelSerial')
295-
if channel_serial:
296-
self.__channel_serial = channel_serial
297382
# TM2a, TM2c, TM2f
298383
Message.update_inner_message_fields(proto_msg)
299384

@@ -303,6 +388,10 @@ def _on_message(self, proto_msg: dict) -> None:
303388
exception = None
304389
resumed = False
305390

391+
self.__attach_serial = channel_serial
392+
self.__channel_serial = channel_serial
393+
self.__params = proto_msg.get('params')
394+
306395
if error:
307396
exception = AblyException.from_dict(error)
308397

@@ -327,6 +416,7 @@ def _on_message(self, proto_msg: dict) -> None:
327416
self._request_state(ChannelState.ATTACHING)
328417
elif action == ProtocolMessageAction.MESSAGE:
329418
messages = Message.from_encoded_array(proto_msg.get('messages'))
419+
self.__channel_serial = channel_serial
330420
for message in messages:
331421
self.__message_emitter._emit(message.name, message)
332422
elif action == ProtocolMessageAction.ERROR:
@@ -431,6 +521,12 @@ def __on_retry_timer_expire(self) -> None:
431521
log.info("RealtimeChannel retry timer expired, attempting a new attach")
432522
self._request_state(ChannelState.ATTACHING)
433523

524+
def should_reattach_to_set_options(self, new_options: ChannelOptions) -> bool:
525+
"""Internal method"""
526+
if self.state != ChannelState.ATTACHING and self.state != ChannelState.ATTACHED:
527+
return False
528+
return self.__channel_options != new_options
529+
434530
# RTL23
435531
@property
436532
def name(self) -> str:
@@ -453,6 +549,11 @@ def error_reason(self) -> Optional[AblyException]:
453549
"""An AblyException instance describing the last error which occurred on the channel, if any."""
454550
return self.__error_reason
455551

552+
@property
553+
def params(self) -> Dict[str, str]:
554+
"""Get channel parameters"""
555+
return self.__params
556+
456557

457558
class Channels(RestChannels):
458559
"""Creates and destroys RealtimeChannel objects.
@@ -466,19 +567,31 @@ class Channels(RestChannels):
466567
"""
467568

468569
# RTS3
469-
def get(self, name: str) -> RealtimeChannel:
570+
def get(self, name: str, options: Optional[ChannelOptions] = None) -> RealtimeChannel:
470571
"""Creates a new RealtimeChannel object, or returns the existing channel object.
471572
472573
Parameters
473574
----------
474575
475576
name: str
476577
Channel name
578+
options: ChannelOptions or dict, optional
579+
Channel options for the channel
477580
"""
478581
if name not in self.__all:
479-
channel = self.__all[name] = RealtimeChannel(self.__ably, name)
582+
channel = self.__all[name] = RealtimeChannel(self.__ably, name, options)
480583
else:
481584
channel = self.__all[name]
585+
# Update options if channel is not attached or currently attaching
586+
if options and channel.should_reattach_to_set_options(options):
587+
raise AblyException(
588+
'Channels.get() cannot be used to set channel options that would cause the channel to '
589+
'reattach. Please, use RealtimeChannel.setOptions() instead.',
590+
400,
591+
40000
592+
)
593+
elif options:
594+
channel.set_options_without_reattach(options)
482595
return channel
483596

484597
# RTS4

test/ably/realtime/realtimechannel_test.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import asyncio
22
import pytest
3-
from ably.realtime.realtime_channel import ChannelState, RealtimeChannel
3+
from ably.realtime.realtime_channel import ChannelState, RealtimeChannel, ChannelOptions
44
from ably.transport.websockettransport import ProtocolMessageAction
55
from ably.types.message import Message
66
from test.ably.testapp import TestApp
@@ -468,3 +468,62 @@ async def test_channel_error_cleared_upon_connect_from_terminal_state(self):
468468
assert channel.error_reason is None
469469

470470
await ably.close()
471+
472+
async def test_channel_params_received_by_relatime(self):
473+
ably = await TestApp.get_ably_realtime()
474+
channel_name = random_string(5)
475+
channel = ably.channels.get(channel_name, ChannelOptions(params={
476+
"rewind": "1"
477+
}))
478+
await channel.attach()
479+
assert channel.params["rewind"] == "1"
480+
481+
await ably.close()
482+
483+
async def test_channel_params_unknown_params_skipped_by_relatime(self):
484+
ably = await TestApp.get_ably_realtime()
485+
channel_name = random_string(5)
486+
channel = ably.channels.get(channel_name, ChannelOptions(params={
487+
"rewind": "1",
488+
"foo": "bar"
489+
}))
490+
await channel.attach()
491+
assert channel.params["rewind"] == "1"
492+
assert channel.params.get("foo") is None
493+
494+
await ably.close()
495+
496+
async def test_channel_params_as_dict(self):
497+
ably = await TestApp.get_ably_realtime()
498+
channel_name = random_string(5)
499+
channel = ably.channels.get(channel_name, ChannelOptions(params={"delta": "vcdiff"}))
500+
await channel.attach()
501+
assert channel.params["delta"] == "vcdiff"
502+
503+
await ably.close()
504+
505+
async def test_channel_get_channel_with_same_params(self):
506+
ably = await TestApp.get_ably_realtime()
507+
channel_name = random_string(5)
508+
channel = ably.channels.get(channel_name, ChannelOptions(params={"rewind": "1"}))
509+
await channel.attach()
510+
same_channel = ably.channels.get(channel_name, ChannelOptions(params={"rewind": "1"}))
511+
assert channel == same_channel
512+
513+
await ably.close()
514+
515+
async def test_channel_get_channel_with_different_params(self):
516+
ably = await TestApp.get_ably_realtime()
517+
channel_name = random_string(5)
518+
channel = ably.channels.get(channel_name, ChannelOptions(params={"rewind": "1"}))
519+
await channel.attach()
520+
521+
with pytest.raises(AblyException) as exception:
522+
ably.channels.get(channel_name, ChannelOptions(params={"delta": "vcdiff"}))
523+
524+
assert exception.value.code == 40000
525+
assert exception.value.status_code == 400
526+
527+
assert channel.params == {"rewind": "1"}
528+
529+
await ably.close()

0 commit comments

Comments
 (0)