Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
269 changes: 200 additions & 69 deletions kuksa-client/kuksa_client/grpc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
########################################################################
# Copyright (c) 2022 Robert Bosch GmbH
# Copyright (c) 2022-2025 Robert Bosch GmbH
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -690,13 +690,16 @@ def from_message(cls, message: val_v1.EntryUpdate):

@classmethod
def from_tuple(cls, path: str, dp: types_v2.Datapoint):
# we assume here that only one field of Value is set -> we use the first entry.
# This should always be the case.
# we assume here that at max one field of Value is set -> we use the first entry.
# If no field is set the value is currently unknown/not avaialable -> set to None.
data = dp.value.ListFields()
field_descriptor, value = data[0]
field_name = field_descriptor.name
value = getattr(dp.value, field_name)
if dp.timestamp.seconds == 0 and dp.timestamp.nanos == 0:
if data:
field_descriptor, value = data[0]
field_name = field_descriptor.name
value = getattr(dp.value, field_name)
else:
value = None
if dp.timestamp is None or (dp.timestamp.seconds == 0 and dp.timestamp.nanos == 0):
timestamp = None
else:
timestamp = dp.timestamp.ToDatetime(
Expand All @@ -709,6 +712,21 @@ def from_tuple(cls, path: str, dp: types_v2.Datapoint):
fields=[Field(value=types_v1.FIELD_VALUE)],
)

@classmethod
def from_actuate_value(cls, path: str, value: types_v2.Value):
# we assume here that exactly one field of Value is set -> we use the first entry.
# This should always be the case.
data = value.ListFields()
field_descriptor, target_value = data[0]
field_name = field_descriptor.name
target_value = getattr(value, field_name)
return cls(
entry=DataEntry(
path=path, actuator_target=Datapoint(value=target_value)
),
fields=[Field(value=types_v1.FIELD_ACTUATOR_TARGET)],
)

def to_message(self) -> val_v1.EntryUpdate:
message = val_v1.EntryUpdate(entry=self.entry.to_message())
message.fields.extend(field.value for field in self.fields)
Expand Down Expand Up @@ -750,7 +768,6 @@ def __init__(
connected: bool = False,
tls_server_name: Optional[str] = None,
):

self.authorization_header = self.get_authorization_header(token)
self.target_host = f"{host}:{port}"
self.root_certificates = root_certificates
Expand Down Expand Up @@ -868,26 +885,29 @@ def _prepare_subscribe_request(
logger.debug("%s: %s", type(req).__name__, req)
return req

def _prepare_subscribev2_request(
self,
entries: Iterable[SubscribeEntry],
def _prepare_v2_subscribe_request(
self, paths: Iterable[str]
) -> val_v2.SubscribeRequest:
paths = []
for entry in entries:
paths.append(entry.path)
req = val_v2.SubscribeRequest(signal_paths=paths)
logger.debug("%s: %s", type(req).__name__, req)
return req

for field in entry.fields:
if field != Field.VALUE:
raise VSSClientError(
error={
"code": grpc.StatusCode.INVALID_ARGUMENT.value[0],
"reason": grpc.StatusCode.INVALID_ARGUMENT.value[1],
"message": "Cannot use v2 if specifiying fields other than value",
},
errors=[],
)
def _prepare_v2_provide_actuation_request(
self,
paths: Iterable[str],
) -> List[val_v2.OpenProviderStreamRequest]:
signals = []
for path in paths:
signals.append(types_v2.SignalID(path=path))
provide_req = val_v2.ProvideActuationRequest(actuator_identifiers=signals)
req = val_v2.OpenProviderStreamRequest(provide_actuation_request=provide_req)
logger.debug("%s: %s", type(req).__name__, req)
return [req]

req = val_v2.SubscribeRequest(signal_paths=paths)
def _prepare_v2_list_metadata_request(
self, path: str
) -> val_v2.ListMetadataRequest:
req = val_v2.ListMetadataRequest(root=path)
logger.debug("%s: %s", type(req).__name__, req)
return req

Expand Down Expand Up @@ -947,6 +967,8 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.channel = None
self.exit_stack = contextlib.ExitStack()
self.path_to_id_mapping: Dict[str, int] = dict()
self.id_to_path_mapping: Dict[int, str] = dict()

def __enter__(self):
self.connect()
Expand All @@ -969,6 +991,12 @@ def wrapper(self, *args, **kwargs):
return wrapper

def connect(self, target_host=None):
# Reset the id mapping on each new connection to the data broker because the broker
# could have been restarted and assigned new ids to paths in between.
# Furthermore, the specified target host could have changed.
self.path_to_id_mapping.clear()
self.id_to_path_mapping.clear()

creds = self._load_creds()
if target_host is None:
target_host = self.target_host
Expand Down Expand Up @@ -1161,15 +1189,25 @@ def subscribe_current_values(
for path, dp in updates.items():
print(f"Current value for {path} is now: {dp.value}")
"""
for updates in self.subscribe(
entries=(
SubscribeEntry(path, View.CURRENT_VALUE, (Field.VALUE,))
for path in paths
),
try_v2=True,
**rpc_kwargs,
):
yield {update.entry.path: update.entry.value for update in updates}
try:
logger.debug("Try to subscribe current values via v2")
for updates in self.v2_subscribe(paths, **rpc_kwargs):
yield {
update.entry.path: update.entry.value for update in updates
}
except VSSClientError as exc:
if exc.error["code"] != grpc.StatusCode.UNIMPLEMENTED.value[0]:
raise

logger.debug("v2 not available - falling back to v1 subscribe current values")
for updates in self.subscribe(
entries=(
SubscribeEntry(path, View.CURRENT_VALUE, (Field.VALUE,))
for path in paths
),
**rpc_kwargs,
):
yield {update.entry.path: update.entry.value for update in updates}

@check_connected
def subscribe_target_values(
Expand All @@ -1186,16 +1224,27 @@ def subscribe_target_values(
for path, dp in updates.items():
print(f"Target value for {path} is now: {dp.value}")
"""
for updates in self.subscribe(
entries=(
SubscribeEntry(path, View.TARGET_VALUE, (Field.ACTUATOR_TARGET,))
for path in paths
),
**rpc_kwargs,
):
yield {
update.entry.path: update.entry.actuator_target for update in updates
}
try:
logger.debug("Try to subscribe actuation requests via v2")
for updates in self.v2_subscribe_actuation_requests(paths, **rpc_kwargs):
yield {
update.entry.path: update.entry.actuator_target for update in updates
}
except VSSClientError as exc:
if exc.error["code"] != grpc.StatusCode.UNIMPLEMENTED.value[0]:
raise

logger.debug("v2 not available - falling back to v1 subscribe target values")
for updates in self.subscribe(
entries=(
SubscribeEntry(path, View.TARGET_VALUE, (Field.ACTUATOR_TARGET,))
for path in paths
),
**rpc_kwargs,
):
yield {
update.entry.path: update.entry.actuator_target for update in updates
}

@check_connected
def subscribe_metadata(
Expand Down Expand Up @@ -1295,6 +1344,41 @@ def set(
raise VSSClientError.from_grpc_error(exc) from exc
self._process_set_response(resp)

def get_path(self, signal_id: types_v2.SignalID) -> str:
if signal_id.HasField("path"):
return signal_id.path
elif signal_id.HasField("id") and signal_id.id in self.id_to_path_mapping:
return self.id_to_path_mapping[signal_id.id]
return "<unknown signal>"

def ensure_id_mapping(self, paths: Iterable[str], **rpc_kwargs):
for path in paths:
if path not in self.path_to_id_mapping:
# Prevent duplicate requests for the same path
self.path_to_id_mapping[path] = None
req = self._prepare_v2_list_metadata_request(path)
try:
resp = self.client_stub_v2.ListMetadata(req, **rpc_kwargs)
logger.debug("%s: %s", type(resp).__name__, resp)
if len(resp.metadata) == 1:
self.path_to_id_mapping[path] = resp.metadata[0].id
self.id_to_path_mapping[resp.metadata[0].id] = path
else:
del self.path_to_id_mapping[path]
raise VSSClientError(
error={
"code": grpc.StatusCode.NOT_FOUND.value[0],
"reason": grpc.StatusCode.NOT_FOUND.value[1],
"message": f"Path {path} not found on server",
},
errors=[],
)
except RpcError as exc:
if exc.code() == grpc.StatusCode.UNIMPLEMENTED:
logger.debug("v2 not available - skip querying ids")
return
raise VSSClientError.from_grpc_error(exc) from exc

@check_connected
def subscribe(
self, entries: Iterable[SubscribeEntry], try_v2: bool = False, **rpc_kwargs
Expand All @@ -1305,36 +1389,83 @@ def subscribe(
grpc.*MultiCallable kwargs e.g. timeout, metadata, credentials.
"""

if try_v2:
raise VSSClientError(
error={
"code": grpc.StatusCode.INVALID_ARGUMENT.value[0],
"reason": grpc.StatusCode.INVALID_ARGUMENT.value[1],
"message": ("Method subscribe supports v1, only. "
"Use v2_subscribe or v2_subscribe_actuation_requests instead."),
},
errors=[],
)

logger.debug("Try subscribing via v1")
rpc_kwargs["metadata"] = self.generate_metadata_header(
rpc_kwargs.get("metadata")
)
if try_v2:
logger.debug("Trying v2")
req = self._prepare_subscribev2_request(entries)
resp_stream = self.client_stub_v2.Subscribe(req, **rpc_kwargs)
try:
for resp in resp_stream:
logger.debug("%s: %s", type(resp).__name__, resp)
req = self._prepare_subscribe_request(entries)
resp_stream = self.client_stub_v1.Subscribe(req, **rpc_kwargs)
try:
for resp in resp_stream:
logger.debug("%s: %s", type(resp).__name__, resp)
yield [EntryUpdate.from_message(update) for update in resp.updates]
except RpcError as exc:
raise VSSClientError.from_grpc_error(exc) from exc

@check_connected
def v2_subscribe(
self, paths: Iterable[str], **rpc_kwargs
) -> Iterator[List[EntryUpdate]]:
"""
Parameters:
rpc_kwargs
grpc.*MultiCallable kwargs e.g. timeout, metadata, credentials.
"""

logger.debug("Subscribe current values via v2")
rpc_kwargs["metadata"] = self.generate_metadata_header(
rpc_kwargs.get("metadata")
)
req = self._prepare_v2_subscribe_request(paths)
resp_stream = self.client_stub_v2.Subscribe(req, **rpc_kwargs)
try:
for resp in resp_stream:
logger.debug("%s: %s", type(resp).__name__, resp)
yield [
EntryUpdate.from_tuple(path, dp)
for path, dp in resp.entries.items()
]
except RpcError as exc:
raise VSSClientError.from_grpc_error(exc) from exc

@check_connected
def v2_subscribe_actuation_requests(
self, paths: Iterable[str], **rpc_kwargs
) -> Iterator[List[EntryUpdate]]:
"""
Parameters:
rpc_kwargs
grpc.*MultiCallable kwargs e.g. timeout, metadata, credentials.
"""

logger.debug("Subscribe actuation requests via v2")
rpc_kwargs["metadata"] = self.generate_metadata_header(
rpc_kwargs.get("metadata")
)
self.ensure_id_mapping(paths, **rpc_kwargs)
req = self._prepare_v2_provide_actuation_request(paths)
resp_stream = self.client_stub_v2.OpenProviderStream(iter(req), **rpc_kwargs)
try:
for resp in resp_stream:
logger.debug("batch %s: %s", type(resp).__name__, resp)
if resp.HasField("batch_actuate_stream_request"):
yield [
EntryUpdate.from_tuple(path, dp)
for path, dp in resp.entries.items()
EntryUpdate.from_actuate_value(self.get_path(actuate_req.signal_id), actuate_req.value)
for actuate_req in resp.batch_actuate_stream_request.actuate_requests
]
except RpcError as exc:
if exc.code() == grpc.StatusCode.UNIMPLEMENTED:
logger.debug("v2 not available fall back to v1 instead")
self.subscribe(entries)
else:
raise VSSClientError.from_grpc_error(exc) from exc
else:
logger.debug("Trying v1")
req = self._prepare_subscribe_request(entries)
resp_stream = self.client_stub_v1.Subscribe(req, **rpc_kwargs)
try:
for resp in resp_stream:
logger.debug("%s: %s", type(resp).__name__, resp)
yield [EntryUpdate.from_message(update) for update in resp.updates]
except RpcError as exc:
raise VSSClientError.from_grpc_error(exc) from exc
except RpcError as exc:
raise VSSClientError.from_grpc_error(exc) from exc

@check_connected
def authorize(self, token: str, **rpc_kwargs) -> str:
Expand Down
Loading