diff --git a/.gitignore b/.gitignore index bfccfbd0b..1b0cf3843 100644 --- a/.gitignore +++ b/.gitignore @@ -29,3 +29,5 @@ tmp-KafkaCluster .venv venv_test venv_examples +.vscode/ +.dmypy.json \ No newline at end of file diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 000000000..94b52b05f --- /dev/null +++ b/mypy.ini @@ -0,0 +1,6 @@ +[mypy] +show_error_codes=true +disallow_untyped_defs=true +disallow_untyped_calls=true +warn_redundant_casts=true +strict_optional=true \ No newline at end of file diff --git a/setup.py b/setup.py index c5d3e1c27..fd63523a8 100755 --- a/setup.py +++ b/setup.py @@ -12,13 +12,20 @@ INSTALL_REQUIRES = [ 'futures;python_version<"3.2"', 'enum34;python_version<"3.4"', + 'six' ] TEST_REQUIRES = [ 'pytest==4.6.4;python_version<"3.0"', 'pytest;python_version>="3.0"', 'pytest-timeout', - 'flake8' + 'flake8', + # Cap the version to avoid issues with newer editions. Should be periodically updated! + 'mypy<=0.991', + 'types-protobuf', + 'types-jsonschema', + 'types-requests', + 'types-six' ] DOC_REQUIRES = ['sphinx', 'sphinx-rtd-theme'] @@ -27,7 +34,7 @@ AVRO_REQUIRES = ['fastavro>=0.23.0,<1.0;python_version<"3.0"', 'fastavro>=1.0;python_version>"3.0"', - 'avro>=1.11.1,<2', + 'avro>=1.11.2,<2', ] + SCHEMA_REGISTRY_REQUIRES JSON_REQUIRES = ['pyrsistent==0.16.1;python_version<"3.0"', @@ -81,6 +88,7 @@ def get_install_requirements(path): author_email='support@confluent.io', url='https://github.com/confluentinc/confluent-kafka-python', ext_modules=[module], + package_data={"confluent_kafka": ["py.typed"]}, packages=find_packages('src'), package_dir={'': 'src'}, data_files=[('', [os.path.join(work_dir, 'LICENSE.txt')])], diff --git a/src/confluent_kafka/__init__.py b/src/confluent_kafka/__init__.py index d477ba198..6b808bfc6 100644 --- a/src/confluent_kafka/__init__.py +++ b/src/confluent_kafka/__init__.py @@ -62,19 +62,19 @@ class ThrottleEvent(object): :ivar float throttle_time: The amount of time (in seconds) the broker throttled (delayed) the request """ - def __init__(self, broker_name, - broker_id, - throttle_time): + def __init__(self, broker_name: str, + broker_id: int, + throttle_time: float): self.broker_name = broker_name self.broker_id = broker_id self.throttle_time = throttle_time - def __str__(self): + def __str__(self) -> str: return "{}/{} throttled for {} ms".format(self.broker_name, self.broker_id, int(self.throttle_time * 1000)) -def _resolve_plugins(plugins): +def _resolve_plugins(plugins: str) -> str: """ Resolve embedded plugins from the wheel's library directory. For internal module use only. diff --git a/src/confluent_kafka/_model/__init__.py b/src/confluent_kafka/_model/__init__.py index 2bab6a1bd..dd24c0eb9 100644 --- a/src/confluent_kafka/_model/__init__.py +++ b/src/confluent_kafka/_model/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. from enum import Enum +from typing import List, Optional from .. import cimpl @@ -34,7 +35,7 @@ class Node: rack: str The rack for this node. """ - def __init__(self, id, host, port, rack=None): + def __init__(self, id: int, host: str, port: int, rack: Optional[str]=None): self.id = id self.id_string = str(id) self.host = host @@ -55,7 +56,7 @@ class ConsumerGroupTopicPartitions: topic_partitions: list(TopicPartition) List of topic partitions information. 
""" - def __init__(self, group_id, topic_partitions=None): + def __init__(self, group_id: str, topic_partitions: Optional[List[cimpl.TopicPartition]]=None): self.group_id = group_id self.topic_partitions = topic_partitions @@ -85,7 +86,7 @@ class ConsumerGroupState(Enum): DEAD = cimpl.CONSUMER_GROUP_STATE_DEAD EMPTY = cimpl.CONSUMER_GROUP_STATE_EMPTY - def __lt__(self, other): - if self.__class__ != other.__class__: + def __lt__(self, other: object) -> bool: + if not isinstance(other, self.__class__): return NotImplemented return self.value < other.value diff --git a/src/confluent_kafka/_util/conversion_util.py b/src/confluent_kafka/_util/conversion_util.py index 82c9b7018..60dcb5741 100644 --- a/src/confluent_kafka/_util/conversion_util.py +++ b/src/confluent_kafka/_util/conversion_util.py @@ -13,11 +13,12 @@ # limitations under the License. from enum import Enum +from typing import Any, Type class ConversionUtil: @staticmethod - def convert_to_enum(val, enum_clazz): + def convert_to_enum(val: object, enum_clazz: Type) -> Any: if type(enum_clazz) is not type(Enum): raise TypeError("'enum_clazz' must be of type Enum") diff --git a/src/confluent_kafka/_util/validation_util.py b/src/confluent_kafka/_util/validation_util.py index ffe5785f2..cd4011f5d 100644 --- a/src/confluent_kafka/_util/validation_util.py +++ b/src/confluent_kafka/_util/validation_util.py @@ -12,38 +12,35 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import List, Optional from ..cimpl import KafkaError -try: - string_type = basestring -except NameError: - string_type = str - +import six class ValidationUtil: @staticmethod - def check_multiple_not_none(obj, vars_to_check): + def check_multiple_not_none(obj: object, vars_to_check: List[str]) -> None: for param in vars_to_check: ValidationUtil.check_not_none(obj, param) @staticmethod - def check_not_none(obj, param): + def check_not_none(obj: object, param: str) -> None: if getattr(obj, param) is None: raise ValueError("Expected %s to be not None" % (param,)) @staticmethod - def check_multiple_is_string(obj, vars_to_check): + def check_multiple_is_string(obj: object, vars_to_check: List[str]) -> None: for param in vars_to_check: ValidationUtil.check_is_string(obj, param) @staticmethod - def check_is_string(obj, param): + def check_is_string(obj: object, param: str) -> None: param_value = getattr(obj, param) - if param_value is not None and not isinstance(param_value, string_type): + if param_value is not None and not isinstance(param_value, six.string_types): raise TypeError("Expected %s to be a string" % (param,)) @staticmethod - def check_kafka_errors(errors): + def check_kafka_errors(errors: Optional[List]) -> None: if not isinstance(errors, list): raise TypeError("errors should be None or a list") for error in errors: @@ -51,6 +48,6 @@ def check_kafka_errors(errors): raise TypeError("Expected list of KafkaError") @staticmethod - def check_kafka_error(error): + def check_kafka_error(error: object) -> None: if not isinstance(error, KafkaError): raise TypeError("Expected error to be a KafkaError") diff --git a/src/confluent_kafka/admin/__init__.py b/src/confluent_kafka/admin/__init__.py index 324038744..5962abcfa 100644 --- a/src/confluent_kafka/admin/__init__.py +++ b/src/confluent_kafka/admin/__init__.py @@ -15,9 +15,13 @@ """ Kafka admin client: create, view, alter, and delete topics and resources. 
""" +from typing import Callable, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Type, TypeVar, Union, cast +from typing_extensions import TypeAlias import warnings import concurrent.futures +from concurrent.futures import Future + # Unused imports are keeped to be accessible using this public module from ._config import (ConfigSource, # noqa: F401 ConfigEntry, @@ -73,13 +77,10 @@ as _ConsumerGroupState -try: - string_type = basestring -except NameError: - string_type = str +import six -class AdminClient (_AdminClientImpl): +class AdminClient(_AdminClientImpl): """ AdminClient provides admin operations for Kafka brokers, topics, groups, and other resource types supported by the broker. @@ -101,7 +102,7 @@ class AdminClient (_AdminClientImpl): Requires broker version v0.11.0.0 or later. """ - def __init__(self, conf): + def __init__(self, conf: Dict): """ Create a new AdminClient using the provided configuration dictionary. @@ -114,13 +115,14 @@ def __init__(self, conf): super(AdminClient, self).__init__(conf) @staticmethod - def _make_topics_result(f, futmap): + def _make_topics_result(f: Future, futmap: Dict[str, Future]) -> None: """ Map per-topic results to per-topic futures in futmap. The result value of each (successful) future is None. """ try: result = f.result() + assert isinstance(result, Dict) for topic, error in result.items(): fut = futmap.get(topic, None) if fut is None: @@ -138,13 +140,14 @@ def _make_topics_result(f, futmap): fut.set_exception(e) @staticmethod - def _make_resource_result(f, futmap): + def _make_resource_result(f: Future, futmap: Dict[ConfigResource, Future]) -> None: """ Map per-resource results to per-resource futures in futmap. The result value of each (successful) future is a ConfigResource. """ try: result = f.result() + assert isinstance(result, Dict) for resource, configs in result.items(): fut = futmap.get(resource, None) if fut is None: @@ -163,11 +166,11 @@ def _make_resource_result(f, futmap): fut.set_exception(e) @staticmethod - def _make_list_consumer_groups_result(f, futmap): + def _make_list_consumer_groups_result(f: Future, futmap: Dict[str, Future]) -> None: pass @staticmethod - def _make_consumer_groups_result(f, futmap): + def _make_consumer_groups_result(f: Future, futmap: Dict[str, Future]) -> None: """ Map per-group results to per-group futures in futmap. """ @@ -192,7 +195,7 @@ def _make_consumer_groups_result(f, futmap): fut.set_exception(e) @staticmethod - def _make_consumer_group_offsets_result(f, futmap): + def _make_consumer_group_offsets_result(f: Future, futmap: Dict[str, Future]) -> None: """ Map per-group results to per-group futures in futmap. The result value of each (successful) future is ConsumerGroupTopicPartitions. @@ -218,7 +221,7 @@ def _make_consumer_group_offsets_result(f, futmap): fut.set_exception(e) @staticmethod - def _make_acls_result(f, futmap): + def _make_acls_result(f: Future, futmap: Dict[AclBinding, Future]) -> None: """ Map create ACL binding results to corresponding futures in futmap. For create_acls the result value of each (successful) future is None. 
@@ -244,7 +247,7 @@ def _make_acls_result(f, futmap): fut.set_exception(e) @staticmethod - def _make_user_scram_credentials_result(f, futmap): + def _make_user_scram_credentials_result(f: Future, futmap: Dict[str, Future]) -> None: try: results = f.result() len_results = len(results) @@ -266,14 +269,16 @@ def _make_user_scram_credentials_result(f, futmap): fut.set_exception(e) @staticmethod - def _create_future(): - f = concurrent.futures.Future() + def _create_future() -> Future: + f: Future = concurrent.futures.Future() if not f.set_running_or_notify_cancel(): raise RuntimeError("Future was cancelled prematurely") return f + _futures_map_key = TypeVar("_futures_map_key") + @staticmethod - def _make_futures(futmap_keys, class_check, make_result_fn): + def _make_futures(futmap_keys: List[_futures_map_key], class_check: Optional[Type], make_result_fn: Callable[[Future, Dict[_futures_map_key, Future]], None]) -> Tuple[Future, Dict[_futures_map_key, Future]]: """ Create futures and a futuremap for the keys in futmap_keys, and create a request-level future to be bassed to the C API. @@ -295,7 +300,7 @@ def _make_futures(futmap_keys, class_check, make_result_fn): return f, futmap @staticmethod - def _make_futures_v2(futmap_keys, class_check, make_result_fn): + def _make_futures_v2(futmap_keys: Iterable[_futures_map_key], class_check: Optional[Type], make_result_fn: Callable[[Future, Dict[_futures_map_key, Future]], None]) -> Tuple[Future, Dict[_futures_map_key, Future]]: """ Create futures and a futuremap for the keys in futmap_keys, and create a request-level future to be bassed to the C API. @@ -315,11 +320,11 @@ def _make_futures_v2(futmap_keys, class_check, make_result_fn): return f, futmap @staticmethod - def _has_duplicates(items): + def _has_duplicates(items: Sequence[object]) -> bool: return len(set(items)) != len(items) @staticmethod - def _check_list_consumer_group_offsets_request(request): + def _check_list_consumer_group_offsets_request(request: List[_ConsumerGroupTopicPartitions]) -> None: if request is None: raise TypeError("request cannot be None") if not isinstance(request, list): @@ -332,7 +337,7 @@ def _check_list_consumer_group_offsets_request(request): if req.group_id is None: raise TypeError("'group_id' cannot be None") - if not isinstance(req.group_id, string_type): + if not isinstance(req.group_id, six.string_types): raise TypeError("'group_id' must be a string") if not req.group_id: raise ValueError("'group_id' cannot be empty") @@ -358,7 +363,7 @@ def _check_list_consumer_group_offsets_request(request): raise ValueError("Element of 'topic_partitions' must not have 'offset' value") @staticmethod - def _check_alter_consumer_group_offsets_request(request): + def _check_alter_consumer_group_offsets_request(request: List[_ConsumerGroupTopicPartitions]) -> None: if request is None: raise TypeError("request cannot be None") if not isinstance(request, list): @@ -370,7 +375,7 @@ def _check_alter_consumer_group_offsets_request(request): raise TypeError("Expected list of 'ConsumerGroupTopicPartitions'") if req.group_id is None: raise TypeError("'group_id' cannot be None") - if not isinstance(req.group_id, string_type): + if not isinstance(req.group_id, six.string_types): raise TypeError("'group_id' must be a string") if not req.group_id: raise ValueError("'group_id' cannot be empty") @@ -397,17 +402,17 @@ def _check_alter_consumer_group_offsets_request(request): "Element of 'topic_partitions' must not have negative value for 'offset' field") @staticmethod - def 
_check_describe_user_scram_credentials_request(users): + def _check_describe_user_scram_credentials_request(users: List) -> None: if not isinstance(users, list): raise TypeError("Expected input to be list of String") for user in users: - if not isinstance(user, string_type): + if not isinstance(user, six.string_types): raise TypeError("Each value should be a string") if not user: raise ValueError("'user' cannot be empty") @staticmethod - def _check_alter_user_scram_credentials_request(alterations): + def _check_alter_user_scram_credentials_request(alterations: List[UserScramCredentialAlteration]) -> None: if not isinstance(alterations, list): raise TypeError("Expected input to be list") if len(alterations) == 0: @@ -417,7 +422,7 @@ def _check_alter_user_scram_credentials_request(alterations): raise TypeError("Expected each element of list to be subclass of UserScramCredentialAlteration") if alteration.user is None: raise TypeError("'user' cannot be None") - if not isinstance(alteration.user, string_type): + if not isinstance(alteration.user, six.string_types): raise TypeError("'user' must be a string") if not alteration.user: raise ValueError("'user' cannot be empty") @@ -449,7 +454,7 @@ def _check_alter_user_scram_credentials_request(alterations): "to be either a UserScramCredentialUpsertion or a " + "UserScramCredentialDeletion") - def create_topics(self, new_topics, **kwargs): + def create_topics(self, new_topics: List[NewTopic], **kwargs: object) -> Dict[str, Future]: # type: ignore[override] """ Create one or more new topics. @@ -483,7 +488,7 @@ def create_topics(self, new_topics, **kwargs): return futmap - def delete_topics(self, topics, **kwargs): + def delete_topics(self, topics: List[str], **kwargs: object) -> Dict[str, Future]: # type: ignore[override] """ Delete one or more topics. @@ -513,15 +518,15 @@ def delete_topics(self, topics, **kwargs): return futmap - def list_topics(self, *args, **kwargs): + def list_topics(self, *args: object, **kwargs: object) -> object: return super(AdminClient, self).list_topics(*args, **kwargs) - def list_groups(self, *args, **kwargs): + def list_groups(self, *args: object, **kwargs: object) -> object: return super(AdminClient, self).list_groups(*args, **kwargs) - def create_partitions(self, new_partitions, **kwargs): + def create_partitions(self, new_partitions: List[NewPartitions], **kwargs: object) -> Dict[str, Future]: # type: ignore[override] """ Create additional partitions for the given topics. @@ -554,7 +559,7 @@ def create_partitions(self, new_partitions, **kwargs): return futmap - def describe_configs(self, resources, **kwargs): + def describe_configs(self, resources: List[ConfigResource], **kwargs: object) -> Dict[ConfigResource, Future]: # type: ignore[override] """ Get the configuration of the specified resources. @@ -586,7 +591,7 @@ def describe_configs(self, resources, **kwargs): return futmap - def alter_configs(self, resources, **kwargs): + def alter_configs(self, resources: List[ConfigResource], **kwargs: object) -> Dict[ConfigResource, Future]: # type: ignore[override] """ .. deprecated:: 2.2.0 @@ -634,7 +639,7 @@ def alter_configs(self, resources, **kwargs): return futmap - def incremental_alter_configs(self, resources, **kwargs): + def incremental_alter_configs(self, resources: List[ConfigResource], **kwargs: object) -> Dict[ConfigResource, Future]: # type: ignore[override] """ Update configuration properties for the specified resources. 
Updates are incremental, i.e only the values mentioned are changed @@ -667,7 +672,7 @@ def incremental_alter_configs(self, resources, **kwargs): return futmap - def create_acls(self, acls, **kwargs): + def create_acls(self, acls: List[AclBinding], **kwargs: object) -> Dict[AclBinding, Future]: # type: ignore[override] """ Create one or more ACL bindings. @@ -696,7 +701,7 @@ def create_acls(self, acls, **kwargs): return futmap - def describe_acls(self, acl_binding_filter, **kwargs): + def describe_acls(self, acl_binding_filter: List[AclBindingFilter], **kwargs: object) -> Future: # type: ignore[override] """ Match ACL bindings by filter. @@ -731,7 +736,7 @@ def describe_acls(self, acl_binding_filter, **kwargs): return f - def delete_acls(self, acl_binding_filters, **kwargs): + def delete_acls(self, acl_binding_filters: List[AclBindingFilter], **kwargs: object) -> Dict[AclBindingFilter, Future]: # type: ignore[override] """ Delete ACL bindings matching one or more ACL binding filters. @@ -764,13 +769,13 @@ def delete_acls(self, acl_binding_filters, **kwargs): raise ValueError("duplicate ACL binding filters not allowed") f, futmap = AdminClient._make_futures(acl_binding_filters, AclBindingFilter, - AdminClient._make_acls_result) + cast(Callable[[Future, Dict[AclBindingFilter, Future]], None], AdminClient._make_acls_result)) super(AdminClient, self).delete_acls(acl_binding_filters, f, **kwargs) return futmap - def list_consumer_groups(self, **kwargs): + def list_consumer_groups(self, states:Optional[Set[_ConsumerGroupState]] = None, **kwargs: object) -> Future: # type: ignore[override] """ List consumer groups. @@ -788,24 +793,21 @@ def list_consumer_groups(self, **kwargs): :raises TypeException: Invalid input. :raises ValueException: Invalid input. """ - if "states" in kwargs: - states = kwargs["states"] - if states is not None: - if not isinstance(states, set): - raise TypeError("'states' must be a set") - for state in states: - if not isinstance(state, _ConsumerGroupState): - raise TypeError("All elements of states must be of type ConsumerGroupState") - kwargs["states_int"] = [state.value for state in states] - kwargs.pop("states") + if states is not None: + if not isinstance(states, set): + raise TypeError("'states' must be a set") + for state in states: + if not isinstance(state, _ConsumerGroupState): + raise TypeError("All elements of states must be of type ConsumerGroupState") + kwargs["states_int"] = [state.value for state in states] - f, _ = AdminClient._make_futures([], None, AdminClient._make_list_consumer_groups_result) + f, _ = AdminClient._make_futures(cast(List[str], []), None, AdminClient._make_list_consumer_groups_result) super(AdminClient, self).list_consumer_groups(f, **kwargs) return f - def describe_consumer_groups(self, group_ids, **kwargs): + def describe_consumer_groups(self, group_ids: List[str], **kwargs: object) -> Dict[str, Future]: # type: ignore[override] """ Describe consumer groups. @@ -837,7 +839,7 @@ def describe_consumer_groups(self, group_ids, **kwargs): return futmap - def delete_consumer_groups(self, group_ids, **kwargs): + def delete_consumer_groups(self, group_ids: List[str], **kwargs: object) -> Dict[str, Future]: # type: ignore[override] """ Delete the given consumer groups. 
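With states promoted from a **kwargs entry to an explicit typed parameter on list_consumer_groups, callers keep the same keyword spelling but mypy can now check the element type. A small sketch of the intended call, assuming a reachable broker (the state filter chosen here is only an example):

from confluent_kafka import ConsumerGroupState
from confluent_kafka.admin import AdminClient

admin = AdminClient({"bootstrap.servers": "localhost:9092"})

# list_consumer_groups() returns a single request-level Future whose result is
# a ListConsumerGroupsResult carrying .valid and .errors lists.
future = admin.list_consumer_groups(
    states={ConsumerGroupState.STABLE, ConsumerGroupState.EMPTY})
result = future.result()

for group in (result.valid or []):
    print(group.group_id, group.is_simple_consumer_group, group.state)
for error in (result.errors or []):
    print("listing error:", error)
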
@@ -861,13 +863,13 @@ def delete_consumer_groups(self, group_ids, **kwargs): if len(group_ids) == 0: raise ValueError("Expected at least one group to be deleted") - f, futmap = AdminClient._make_futures(group_ids, string_type, AdminClient._make_consumer_groups_result) + f, futmap = AdminClient._make_futures(group_ids, str, AdminClient._make_consumer_groups_result) super(AdminClient, self).delete_consumer_groups(group_ids, f, **kwargs) return futmap - def list_consumer_group_offsets(self, list_consumer_group_offsets_request, **kwargs): + def list_consumer_group_offsets(self, list_consumer_group_offsets_request: List[_ConsumerGroupTopicPartitions], **kwargs: object) -> Dict[str, Future]: # type: ignore[override] """ List offset information for the consumer group and (optional) topic partition provided in the request. @@ -896,14 +898,14 @@ def list_consumer_group_offsets(self, list_consumer_group_offsets_request, **kwa AdminClient._check_list_consumer_group_offsets_request(list_consumer_group_offsets_request) f, futmap = AdminClient._make_futures([request.group_id for request in list_consumer_group_offsets_request], - string_type, + str, AdminClient._make_consumer_group_offsets_result) super(AdminClient, self).list_consumer_group_offsets(list_consumer_group_offsets_request, f, **kwargs) return futmap - def alter_consumer_group_offsets(self, alter_consumer_group_offsets_request, **kwargs): + def alter_consumer_group_offsets(self, alter_consumer_group_offsets_request: List[_ConsumerGroupTopicPartitions], **kwargs: object) -> Dict[str, Future]: # type: ignore[override] """ Alter offset for the consumer group and topic partition provided in the request. @@ -929,14 +931,14 @@ def alter_consumer_group_offsets(self, alter_consumer_group_offsets_request, **k AdminClient._check_alter_consumer_group_offsets_request(alter_consumer_group_offsets_request) f, futmap = AdminClient._make_futures([request.group_id for request in alter_consumer_group_offsets_request], - string_type, + str, AdminClient._make_consumer_group_offsets_result) super(AdminClient, self).alter_consumer_group_offsets(alter_consumer_group_offsets_request, f, **kwargs) return futmap - def set_sasl_credentials(self, username, password): + def set_sasl_credentials(self, username: str, password: str) -> None: """ Sets the SASL credentials used for this client. These credentials will overwrite the old ones, and will be used the @@ -955,7 +957,7 @@ def set_sasl_credentials(self, username, password): """ super(AdminClient, self).set_sasl_credentials(username, password) - def describe_user_scram_credentials(self, users, **kwargs): + def describe_user_scram_credentials(self, users: List[str], **kwargs: object) -> Dict[str, Future]: # type: ignore[override] """ Describe user SASL/SCRAM credentials. @@ -984,7 +986,7 @@ def describe_user_scram_credentials(self, users, **kwargs): return futmap - def alter_user_scram_credentials(self, alterations, **kwargs): + def alter_user_scram_credentials(self, alterations: List[UserScramCredentialAlteration], **kwargs: object) -> Dict: # type: ignore[override] """ Alter user SASL/SCRAM credentials. diff --git a/src/confluent_kafka/admin/_acl.py b/src/confluent_kafka/admin/_acl.py index 3512a74ca..d37118fe3 100644 --- a/src/confluent_kafka/admin/_acl.py +++ b/src/confluent_kafka/admin/_acl.py @@ -14,14 +14,14 @@ from enum import Enum import functools +from typing import Dict, List, Tuple, Type, TypeVar, cast from .. 
import cimpl as _cimpl from ._resource import ResourceType, ResourcePatternType from .._util import ValidationUtil, ConversionUtil -try: - string_type = basestring -except NameError: - string_type = str +import six + +string_type = six.string_types[0] class AclOperation(Enum): @@ -42,9 +42,10 @@ class AclOperation(Enum): ALTER_CONFIGS = _cimpl.ACL_OPERATION_ALTER_CONFIGS #: ALTER_CONFIGS operation IDEMPOTENT_WRITE = _cimpl.ACL_OPERATION_IDEMPOTENT_WRITE #: IDEMPOTENT_WRITE operation - def __lt__(self, other): + def __lt__(self, other: object) -> bool: if self.__class__ != other.__class__: return NotImplemented + assert isinstance(other, self.__class__) return self.value < other.value @@ -57,9 +58,10 @@ class AclPermissionType(Enum): DENY = _cimpl.ACL_PERMISSION_TYPE_DENY #: Disallows access ALLOW = _cimpl.ACL_PERMISSION_TYPE_ALLOW #: Grants access - def __lt__(self, other): + def __lt__(self, other: object) -> bool: if self.__class__ != other.__class__: return NotImplemented + assert isinstance(other, self.__class__) return self.value < other.value @@ -89,9 +91,9 @@ class AclBinding(object): The permission type for the specified operation. """ - def __init__(self, restype, name, - resource_pattern_type, principal, host, - operation, permission_type): + def __init__(self, restype: ResourceType, name: str, + resource_pattern_type: ResourcePatternType, principal: str, host: str, + operation: AclOperation, permission_type: AclPermissionType): self.restype = restype self.name = name self.resource_pattern_type = resource_pattern_type @@ -106,7 +108,7 @@ def __init__(self, restype, name, self.operation_int = int(self.operation.value) self.permission_type_int = int(self.permission_type.value) - def _convert_enums(self): + def _convert_enums(self) -> None: self.restype = ConversionUtil.convert_to_enum(self.restype, ResourceType) self.resource_pattern_type = ConversionUtil.convert_to_enum( self.resource_pattern_type, ResourcePatternType) @@ -115,20 +117,20 @@ def _convert_enums(self): self.permission_type = ConversionUtil.convert_to_enum( self.permission_type, AclPermissionType) - def _check_forbidden_enums(self, forbidden_enums): + def _check_forbidden_enums(self, forbidden_enums: Dict[str, List]) -> None: for k, v in forbidden_enums.items(): enum_value = getattr(self, k) if enum_value in v: raise ValueError("Cannot use enum %s, value %s in this class" % (k, enum_value.name)) - def _not_none_args(self): + def _not_none_args(self) -> List[str]: return ["restype", "name", "resource_pattern_type", "principal", "host", "operation", "permission_type"] - def _string_args(self): + def _string_args(self) -> List[str]: return ["name", "principal", "host"] - def _forbidden_enums(self): + def _forbidden_enums(self) -> Dict[str, List[Enum]]: return { "restype": [ResourceType.ANY], "resource_pattern_type": [ResourcePatternType.ANY, @@ -137,7 +139,7 @@ def _forbidden_enums(self): "permission_type": [AclPermissionType.ANY] } - def _convert_args(self): + def _convert_args(self) -> None: not_none_args = self._not_none_args() string_args = self._string_args() forbidden_enums = self._forbidden_enums() @@ -146,26 +148,28 @@ def _convert_args(self): self._convert_enums() self._check_forbidden_enums(forbidden_enums) - def __repr__(self): + def __repr__(self) -> str: type_name = type(self).__name__ return "%s(%s,%s,%s,%s,%s,%s,%s)" % ((type_name,) + self._to_tuple()) - def _to_tuple(self): + def _to_tuple(self) -> Tuple: return (self.restype, self.name, self.resource_pattern_type, self.principal, self.host, 
self.operation, self.permission_type) - def __hash__(self): + def __hash__(self) -> int: return hash(self._to_tuple()) - def __lt__(self, other): + def __lt__(self, other: object) -> bool: if self.__class__ != other.__class__: return NotImplemented + assert isinstance(other, self.__class__) return self._to_tuple() < other._to_tuple() - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if self.__class__ != other.__class__: return NotImplemented + assert isinstance(other, self.__class__) return self._to_tuple() == other._to_tuple() @@ -194,11 +198,11 @@ class AclBindingFilter(AclBinding): The permission type to match or :attr:`AclPermissionType.ANY` to match any value. """ - def _not_none_args(self): + def _not_none_args(self) -> List[str]: return ["restype", "resource_pattern_type", "operation", "permission_type"] - def _forbidden_enums(self): + def _forbidden_enums(self) -> Dict[str, List[Enum]]: return { "restype": [ResourceType.UNKNOWN], "resource_pattern_type": [ResourcePatternType.UNKNOWN], diff --git a/src/confluent_kafka/admin/_config.py b/src/confluent_kafka/admin/_config.py index 008a18ee1..3d1586ab7 100644 --- a/src/confluent_kafka/admin/_config.py +++ b/src/confluent_kafka/admin/_config.py @@ -14,6 +14,7 @@ from enum import Enum import functools +from typing import Dict, List, Optional, Union from .. import cimpl as _cimpl from ._resource import ResourceType @@ -66,14 +67,14 @@ class ConfigEntry(object): This class is typically not user instantiated. """ - def __init__(self, name, value, - source=ConfigSource.UNKNOWN_CONFIG, - is_read_only=False, - is_default=False, - is_sensitive=False, - is_synonym=False, - synonyms=[], - incremental_operation=None): + def __init__(self, name: str, value: str, + source: ConfigSource=ConfigSource.UNKNOWN_CONFIG, + is_read_only: bool=False, + is_default: bool=False, + is_sensitive: bool=False, + is_synonym: bool=False, + synonyms: List[str]=[], + incremental_operation: Optional[AlterConfigOpType]=None): """ This class is typically not user instantiated. """ @@ -103,10 +104,10 @@ def __init__(self, name, value, self.incremental_operation = incremental_operation """The incremental operation (AlterConfigOpType) to use in incremental_alter_configs.""" - def __repr__(self): + def __repr__(self) -> str: return "ConfigEntry(%s=\"%s\")" % (self.name, self.value) - def __str__(self): + def __str__(self) -> str: return "%s=\"%s\"" % (self.name, self.value) @@ -129,9 +130,8 @@ class ConfigResource(object): Type = ResourceType - def __init__(self, restype, name, - set_config=None, described_configs=None, error=None, - incremental_configs=None): + def __init__(self, restype: Union[str, int, ResourceType], name: str, + set_config: Optional[Dict[str, str]]=None, described_configs: Optional[object]=None, error: Optional[object]=None, incremental_configs: Optional[List[ConfigEntry]]=None): """ :param ConfigResource.Type restype: Resource type. :param str name: The resource name, which depends on restype. @@ -146,18 +146,20 @@ def __init__(self, restype, name, if name is None: raise ValueError("Expected resource name to be a string") - if type(restype) == str: + if isinstance(restype, str): # Allow resource type to be specified as case-insensitive string, for convenience. 
try: - restype = ConfigResource.Type[restype.upper()] + self.restype = ConfigResource.Type[restype.upper()] except KeyError: raise ValueError("Unknown resource type \"%s\": should be a ConfigResource.Type" % restype) - elif type(restype) == int: + elif isinstance(restype, int): # The C-code passes restype as an int, convert to Type. - restype = ConfigResource.Type(restype) + self.restype = ConfigResource.Type(restype) + + else: + self.restype = restype - self.restype = restype self.restype_int = int(self.restype.value) # for the C code self.name = name @@ -166,36 +168,38 @@ def __init__(self, restype, name, else: self.set_config_dict = dict() - self.incremental_configs = list(incremental_configs or []) + self.incremental_configs = incremental_configs or [] self.configs = described_configs self.error = error - def __repr__(self): + def __repr__(self) -> str: if self.error is not None: return "ConfigResource(%s,%s,%r)" % (self.restype, self.name, self.error) else: return "ConfigResource(%s,%s)" % (self.restype, self.name) - def __hash__(self): + def __hash__(self) -> int: return hash((self.restype, self.name)) - def __lt__(self, other): + def __lt__(self, other: object) -> bool: + assert isinstance(other, ConfigResource) if self.restype < other.restype: return True return self.name.__lt__(other.name) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: + assert isinstance(other, ConfigResource) return self.restype == other.restype and self.name == other.name - def __len__(self): + def __len__(self) -> int: """ :rtype: int :returns: number of configuration entries/operations """ return len(self.set_config_dict) - def set_config(self, name, value, overwrite=True): + def set_config(self, name: str, value: str, overwrite: bool=True) -> None: """ Set/overwrite a configuration value. @@ -213,7 +217,7 @@ def set_config(self, name, value, overwrite=True): return self.set_config_dict[name] = value - def add_incremental_config(self, config_entry): + def add_incremental_config(self, config_entry: ConfigEntry) -> None: """ Add a ConfigEntry for incremental alter configs, using the configured incremental_operation. diff --git a/src/confluent_kafka/admin/_group.py b/src/confluent_kafka/admin/_group.py index 1c8d5e6fe..5ac117775 100644 --- a/src/confluent_kafka/admin/_group.py +++ b/src/confluent_kafka/admin/_group.py @@ -13,8 +13,11 @@ # limitations under the License. +from typing import List, Optional + +from confluent_kafka.cimpl import KafkaException, TopicPartition from .._util import ConversionUtil -from .._model import ConsumerGroupState +from .._model import ConsumerGroupState, Node class ConsumerGroupListing: @@ -31,7 +34,7 @@ class ConsumerGroupListing: state : ConsumerGroupState Current state of the consumer group. """ - def __init__(self, group_id, is_simple_consumer_group, state=None): + def __init__(self, group_id: str, is_simple_consumer_group: bool, state: Optional[ConsumerGroupState]=None): self.group_id = group_id self.is_simple_consumer_group = is_simple_consumer_group if state is not None: @@ -50,7 +53,7 @@ class ListConsumerGroupsResult: errors : list(KafkaException) List of errors encountered during the operation, if any. """ - def __init__(self, valid=None, errors=None): + def __init__(self, valid: Optional[List[ConsumerGroupListing]]=None, errors: Optional[List[KafkaException]]=None): self.valid = valid self.errors = errors @@ -65,7 +68,7 @@ class MemberAssignment: topic_partitions : list(TopicPartition) The topic partitions assigned to a group member. 
""" - def __init__(self, topic_partitions=[]): + def __init__(self, topic_partitions: Optional[List[TopicPartition]]=None): self.topic_partitions = topic_partitions if self.topic_partitions is None: self.topic_partitions = [] @@ -89,7 +92,7 @@ class MemberDescription: group_instance_id : str The instance id of the group member. """ - def __init__(self, member_id, client_id, host, assignment, group_instance_id=None): + def __init__(self, member_id: str, client_id: str, host: str, assignment: MemberAssignment, group_instance_id: Optional[str]=None): self.member_id = member_id self.client_id = client_id self.host = host @@ -117,8 +120,8 @@ class ConsumerGroupDescription: coordinator: Node Consumer group coordinator. """ - def __init__(self, group_id, is_simple_consumer_group, members, partition_assignor, state, - coordinator): + def __init__(self, group_id:str , is_simple_consumer_group: bool, members: List[MemberDescription], partition_assignor: str, state: ConsumerGroupState, + coordinator: Node): self.group_id = group_id self.is_simple_consumer_group = is_simple_consumer_group self.members = members diff --git a/src/confluent_kafka/admin/_metadata.py b/src/confluent_kafka/admin/_metadata.py index 201e4534b..0d68c0862 100644 --- a/src/confluent_kafka/admin/_metadata.py +++ b/src/confluent_kafka/admin/_metadata.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Dict, List, Optional + + class ClusterMetadata (object): """ Provides information about the Kafka cluster, brokers, and topics. @@ -20,24 +23,24 @@ class ClusterMetadata (object): This class is typically not user instantiated. """ - def __init__(self): - self.cluster_id = None + def __init__(self) -> None: + self.cluster_id: Optional[str] = None """Cluster id string, if supported by the broker, else None.""" self.controller_id = -1 """Current controller broker id, or -1.""" - self.brokers = {} + self.brokers: Dict[int, BrokerMetadata] = {} """Map of brokers indexed by the broker id (int). Value is a BrokerMetadata object.""" - self.topics = {} + self.topics: Dict[str, TopicMetadata] = {} """Map of topics indexed by the topic name. Value is a TopicMetadata object.""" self.orig_broker_id = -1 """The broker this metadata originated from.""" self.orig_broker_name = None """The broker name/address this metadata originated from.""" - def __repr__(self): + def __repr__(self) -> str: return "ClusterMetadata({})".format(self.cluster_id) - def __str__(self): + def __str__(self) -> str: return str(self.cluster_id) @@ -48,7 +51,7 @@ class BrokerMetadata (object): This class is typically not user instantiated. """ - def __init__(self): + def __init__(self) -> None: self.id = -1 """Broker id""" self.host = None @@ -56,10 +59,10 @@ def __init__(self): self.port = -1 """Broker port""" - def __repr__(self): + def __repr__(self) -> str: return "BrokerMetadata({}, {}:{})".format(self.id, self.host, self.port) - def __str__(self): + def __str__(self) -> str: return "{}:{}/{}".format(self.host, self.port, self.id) @@ -73,22 +76,22 @@ class TopicMetadata (object): # Sphinx issue where it tries to reference the same instance variable # on other classes which raises a warning/error. - def __init__(self): - self.topic = None + def __init__(self) -> None: + self.topic: Optional[str] = None """Topic name""" - self.partitions = {} + self.partitions: Dict[int, PartitionMetadata] = {} """Map of partitions indexed by partition id. 
Value is a PartitionMetadata object.""" self.error = None """Topic error, or None. Value is a KafkaError object.""" - def __repr__(self): + def __repr__(self) -> str: if self.error is not None: return "TopicMetadata({}, {} partitions, {})".format(self.topic, len(self.partitions), self.error) else: return "TopicMetadata({}, {} partitions)".format(self.topic, len(self.partitions)) - def __str__(self): - return self.topic + def __str__(self) -> str: + return str(self.topic) class PartitionMetadata (object): @@ -103,25 +106,25 @@ class PartitionMetadata (object): of a broker id in the brokers dict. """ - def __init__(self): + def __init__(self) -> None: self.id = -1 """Partition id.""" self.leader = -1 """Current leader broker for this partition, or -1.""" - self.replicas = [] + self.replicas: List[int] = [] """List of replica broker ids for this partition.""" - self.isrs = [] + self.isrs: List[int] = [] """List of in-sync-replica broker ids for this partition.""" self.error = None """Partition error, or None. Value is a KafkaError object.""" - def __repr__(self): + def __repr__(self) -> str: if self.error is not None: return "PartitionMetadata({}, {})".format(self.id, self.error) else: return "PartitionMetadata({})".format(self.id) - def __str__(self): + def __str__(self) -> str: return "{}".format(self.id) @@ -134,7 +137,7 @@ class GroupMember(object): This class is typically not user instantiated. """ # noqa: E501 - def __init__(self,): + def __init__(self) -> None: self.id = None """Member id (generated by broker).""" self.client_id = None @@ -153,10 +156,10 @@ class GroupMetadata(object): This class is typically not user instantiated. """ - def __init__(self): + def __init__(self) -> None: self.broker = None """Originating broker metadata.""" - self.id = None + self.id: Optional[str] = None """Group name.""" self.error = None """Broker-originated error, or None. Value is a KafkaError object.""" @@ -166,14 +169,14 @@ def __init__(self): """Group protocol type.""" self.protocol = None """Group protocol.""" - self.members = [] + self.members: List = [] """Group members.""" - def __repr__(self): + def __repr__(self) -> str: if self.error is not None: return "GroupMetadata({}, {})".format(self.id, self.error) else: return "GroupMetadata({})".format(self.id) - def __str__(self): - return self.id + def __str__(self) -> str: + return str(self.id) diff --git a/src/confluent_kafka/admin/_resource.py b/src/confluent_kafka/admin/_resource.py index b786f3a9a..5920fd2a4 100644 --- a/src/confluent_kafka/admin/_resource.py +++ b/src/confluent_kafka/admin/_resource.py @@ -26,9 +26,10 @@ class ResourceType(Enum): GROUP = _cimpl.RESOURCE_GROUP #: Group resource. Resource name is group.id. BROKER = _cimpl.RESOURCE_BROKER #: Broker resource. Resource name is broker id. 
- def __lt__(self, other): + def __lt__(self, other: object) -> bool: if self.__class__ != other.__class__: return NotImplemented + assert isinstance(other, self.__class__) return self.value < other.value @@ -42,7 +43,8 @@ class ResourcePatternType(Enum): LITERAL = _cimpl.RESOURCE_PATTERN_LITERAL #: Literal: A literal resource name PREFIXED = _cimpl.RESOURCE_PATTERN_PREFIXED #: Prefixed: A prefixed resource name - def __lt__(self, other): + def __lt__(self, other: object) -> bool: if self.__class__ != other.__class__: return NotImplemented + assert isinstance(other, self.__class__) return self.value < other.value diff --git a/src/confluent_kafka/admin/_scram.py b/src/confluent_kafka/admin/_scram.py index c20f55bbc..a6a34cf74 100644 --- a/src/confluent_kafka/admin/_scram.py +++ b/src/confluent_kafka/admin/_scram.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import List, Optional from .. import cimpl from enum import Enum @@ -25,8 +26,8 @@ class ScramMechanism(Enum): SCRAM_SHA_256 = cimpl.SCRAM_MECHANISM_SHA_256 #: SCRAM-SHA-256 mechanism SCRAM_SHA_512 = cimpl.SCRAM_MECHANISM_SHA_512 #: SCRAM-SHA-512 mechanism - def __lt__(self, other): - if self.__class__ != other.__class__: + def __lt__(self, other: object) -> bool: + if not isinstance(other, self.__class__): return NotImplemented return self.value < other.value @@ -43,7 +44,7 @@ class ScramCredentialInfo: iterations: int Positive number of iterations used when creating the credential. """ - def __init__(self, mechanism, iterations): + def __init__(self, mechanism: ScramMechanism, iterations: int): self.mechanism = mechanism self.iterations = iterations @@ -60,7 +61,7 @@ class UserScramCredentialsDescription: scram_credential_infos: list(ScramCredentialInfo) SASL/SCRAM credential representations for the user. """ - def __init__(self, user, scram_credential_infos): + def __init__(self, user: str, scram_credential_infos: List[ScramCredentialInfo]): self.user = user self.scram_credential_infos = scram_credential_infos @@ -93,7 +94,7 @@ class UserScramCredentialUpsertion(UserScramCredentialAlteration): salt: bytes Salt to use. Will be generated randomly if None. (optional) """ - def __init__(self, user, scram_credential_info, password, salt=None): + def __init__(self, user: str, scram_credential_info: ScramCredentialInfo, password: bytes, salt: Optional[bytes]=None): super(UserScramCredentialUpsertion, self).__init__(user) self.scram_credential_info = scram_credential_info self.password = password @@ -111,6 +112,6 @@ class UserScramCredentialDeletion(UserScramCredentialAlteration): mechanism: ScramMechanism SASL/SCRAM mechanism. 
""" - def __init__(self, user, mechanism): + def __init__(self, user: str, mechanism: ScramMechanism): super(UserScramCredentialDeletion, self).__init__(user) self.mechanism = mechanism diff --git a/src/confluent_kafka/avro/__init__.py b/src/confluent_kafka/avro/__init__.py index c5475c9ed..f668da7de 100644 --- a/src/confluent_kafka/avro/__init__.py +++ b/src/confluent_kafka/avro/__init__.py @@ -19,16 +19,18 @@ Avro schema registry module: Deals with encoding and decoding of messages with avro schemas """ +from typing import Dict, Optional, cast import warnings from confluent_kafka import Producer, Consumer from confluent_kafka.avro.error import ClientError -from confluent_kafka.avro.load import load, loads # noqa +from confluent_kafka.avro.load import load, loads, schema # noqa from confluent_kafka.avro.cached_schema_registry_client import CachedSchemaRegistryClient from confluent_kafka.avro.serializer import (SerializerError, # noqa KeySerializerError, ValueSerializerError) from confluent_kafka.avro.serializer.message_serializer import MessageSerializer +from confluent_kafka.cimpl import Message class AvroProducer(Producer): @@ -48,8 +50,8 @@ class AvroProducer(Producer): :param str default_value_schema: Optional default avro schema for value """ - def __init__(self, config, default_key_schema=None, - default_value_schema=None, schema_registry=None, **kwargs): + def __init__(self, config: Dict, default_key_schema: Optional[schema.Schema]=None, + default_value_schema: Optional[schema.Schema]=None, schema_registry: Optional[CachedSchemaRegistryClient]=None, **kwargs: object): warnings.warn( "AvroProducer has been deprecated. Use AvroSerializer instead.", category=DeprecationWarning, stacklevel=2) @@ -77,7 +79,7 @@ def __init__(self, config, default_key_schema=None, self._key_schema = default_key_schema self._value_schema = default_value_schema - def produce(self, **kwargs): + def produce(self, topic: str, **kwargs: object) -> None: """ Asynchronously sends message to Kafka by encoding with specified or default avro schema. @@ -93,12 +95,9 @@ def produce(self, **kwargs): :raises BufferError: If producer queue is full. :raises KafkaException: For other produce failures. """ - # get schemas from kwargs if defined - key_schema = kwargs.pop('key_schema', self._key_schema) - value_schema = kwargs.pop('value_schema', self._value_schema) - topic = kwargs.pop('topic', None) - if not topic: - raise ClientError("Topic name not specified.") + # get schemas from kwargs if defined + key_schema = cast(Optional[schema.Schema], kwargs.pop('key_schema', self._key_schema)) + value_schema = cast(Optional[schema.Schema], kwargs.pop('value_schema', self._value_schema)) value = kwargs.pop('value', None) key = kwargs.pop('key', None) @@ -114,7 +113,7 @@ def produce(self, **kwargs): else: raise KeySerializerError("Avro schema required for key") - super(AvroProducer, self).produce(topic, value, key, **kwargs) + super(AvroProducer, self).produce(topic, value=value, key=key, **kwargs) class AvroConsumer(Consumer): @@ -135,7 +134,7 @@ class AvroConsumer(Consumer): :raises ValueError: For invalid configurations """ - def __init__(self, config, schema_registry=None, reader_key_schema=None, reader_value_schema=None, **kwargs): + def __init__(self, config: Dict, schema_registry: Optional[CachedSchemaRegistryClient]=None, reader_key_schema: Optional[schema.Schema]=None, reader_value_schema: Optional[schema.Schema]=None, **kwargs: object): warnings.warn( "AvroConsumer has been deprecated. 
Use AvroDeserializer instead.", category=DeprecationWarning, stacklevel=2) @@ -160,7 +159,7 @@ def __init__(self, config, schema_registry=None, reader_key_schema=None, reader_ super(AvroConsumer, self).__init__(ap_conf, **kwargs) self._serializer = MessageSerializer(schema_registry, reader_key_schema, reader_value_schema) - def poll(self, timeout=None): + def poll(self, timeout: float=-1) -> Optional[Message]: """ This is an overriden method from confluent_kafka.Consumer class. This handles message deserialization using avro schema :param float timeout: Poll timeout in seconds (default: indefinite) :returns: message object with deserialized key and value as dict objects :rtype: Message """ - if timeout is None: - timeout = -1 message = super(AvroConsumer, self).poll(timeout) if message is None: return None diff --git a/src/confluent_kafka/avro/cached_schema_registry_client.py b/src/confluent_kafka/avro/cached_schema_registry_client.py index b0ea6c388..0733f53ef 100644 --- a/src/confluent_kafka/avro/cached_schema_registry_client.py +++ b/src/confluent_kafka/avro/cached_schema_registry_client.py @@ -20,21 +20,22 @@ # derived from https://github.com/verisign/python-confluent-schemaregistry.git # import logging +from typing import Any, Dict, Optional, Sized, Tuple, TypeVar, Union, cast import warnings import urllib3 import json from collections import defaultdict -from requests import Session, utils +from requests import Session, utils, Response + +from confluent_kafka.schema_registry.schema_registry_client import Schema from .error import ClientError -from . import loads +from . import loads, schema + +import six -# Python 2 considers int an instance of str -try: - string_type = basestring # noqa -except NameError: - string_type = str +string_type = six.string_types[0] VALID_LEVELS = ['NONE', 'FULL', 'FORWARD', 'BACKWARD'] VALID_METHODS = ['GET', 'POST', 'PUT', 'DELETE'] @@ -66,10 +68,11 @@ class CachedSchemaRegistryClient(object): :param str key_location: Path to client's private key used for authentication. """ - def __init__(self, url, max_schemas_per_subject=1000, ca_location=None, cert_location=None, key_location=None): - # In order to maintain compatibility the url(conf in future versions) param has been preserved for now. - conf = url - if not isinstance(url, dict): + def __init__(self, url: Union[str, Dict[str, object]], max_schemas_per_subject: int=1000, ca_location: Optional[str]=None, cert_location: Optional[str]=None, key_location: Optional[str]=None): + # In order to maintain compatibility the url(conf in future versions) param has been preserved for now. 
+ if isinstance(url, dict): + conf = url + else: + conf = { 'url': url, 'ssl.ca.location': ca_location, @@ -87,7 +90,7 @@ def __init__(self, url, max_schemas_per_subject=1000, ca_location=None, cert_loc """Construct a Schema Registry client""" # Ensure URL valid scheme is included; http[s] - url = conf.pop('url', '') + url = cast(str, conf.pop('url', '')) if not isinstance(url, string_type): raise TypeError("URL must be of type str") @@ -97,102 +100,113 @@ def __init__(self, url, max_schemas_per_subject=1000, ca_location=None, cert_loc self.url = url.rstrip('/') # subj => { schema => id } - self.subject_to_schema_ids = defaultdict(dict) + self.subject_to_schema_ids: Dict[str, Dict[schema.Schema, int]] = defaultdict(dict) # id => avro_schema - self.id_to_schema = defaultdict(dict) + self.id_to_schema: Dict[int, schema.Schema] = {} # subj => { schema => version } - self.subject_to_schema_versions = defaultdict(dict) + self.subject_to_schema_versions: Dict[str, Dict[schema.Schema, int]] = defaultdict(dict) s = Session() - ca_path = conf.pop('ssl.ca.location', None) + ca_path = cast(Optional[str], conf.pop('ssl.ca.location', None)) if ca_path is not None: s.verify = ca_path - s.cert = self._configure_client_tls(conf) + _conf_cert = self._configure_client_tls(conf) + + # Logic in _configure_client_tls promises both of the output variables are of the same type + if _conf_cert == (None, None): + s.cert = None + else: + s.cert = _conf_cert + s.auth = self._configure_basic_auth(self.url, conf) self.url = utils.urldefragauth(self.url) self._session = s key_password = conf.pop('ssl.key.password', None) self._is_key_password_provided = not key_password - self._https_session = self._make_https_session(s.cert[0], s.cert[1], ca_path, s.auth, key_password) + self._https_session = self._make_https_session(s.cert[0] if s.cert is not None else None, s.cert[1] if s.cert is not None else None, ca_path, s.auth, key_password) self.auto_register_schemas = conf.pop("auto.register.schemas", True) if len(conf) > 0: raise ValueError("Unrecognized configuration properties: {}".format(conf.keys())) - def __del__(self): + def __del__(self) -> None: self.close() - def __enter__(self): + def __enter__(self) -> "CachedSchemaRegistryClient": return self - def __exit__(self, *args): + def __exit__(self, *args: object) -> None: self.close() - def close(self): + def close(self) -> None: # Constructor exceptions may occur prior to _session being set. 
if hasattr(self, '_session'): self._session.close() if hasattr(self, '_https_session'): - self._https_session.clear() + self._https_session.clear() # type: ignore[no-untyped-call] @staticmethod - def _make_https_session(cert_location, key_location, ca_certs_path, auth, key_password): + def _make_https_session(cert_location: Optional[str], key_location: Optional[str], ca_certs_path: Optional[str], auth: Tuple[str, str], key_password: object) -> urllib3.PoolManager: https_session = urllib3.PoolManager(cert_reqs='CERT_REQUIRED', ca_certs=ca_certs_path, cert_file=cert_location, key_file=key_location, key_password=key_password) - https_session.auth = auth + https_session.auth = auth # type: ignore[attr-defined] return https_session - def _send_https_session_request(self, url, method, headers, body): + def _send_https_session_request(self, url: str, method: str, headers: Dict, body: Any) -> urllib3.response.HTTPResponse: request_headers = {'Accept': ACCEPT_HDR} - auth = self._https_session.auth + auth = self._https_session.auth # type: ignore[attr-defined] if body: body = json.dumps(body).encode('UTF-8') request_headers["Content-Length"] = str(len(body)) request_headers["Content-Type"] = "application/vnd.schemaregistry.v1+json" if auth[0] != '' and auth[1] != '': - request_headers.update(urllib3.make_headers(basic_auth=auth[0] + ":" + - auth[1])) + request_headers.update(urllib3.make_headers(basic_auth=auth[0] + ":" + auth[1])) # type:ignore[no-untyped-call] request_headers.update(headers) - response = self._https_session.request(method, url, headers=request_headers, body=body) + response = self._https_session.request(method, url, headers=request_headers, body=body) # type:ignore[no-untyped-call] return response @staticmethod - def _configure_basic_auth(url, conf): + def _configure_basic_auth(url: str, conf: Dict) -> Tuple[str, str]: auth_provider = conf.pop('basic.auth.credentials.source', 'URL').upper() if auth_provider not in VALID_AUTH_PROVIDERS: raise ValueError("schema.registry.basic.auth.credentials.source must be one of {}" .format(VALID_AUTH_PROVIDERS)) + auth: Tuple[str, str] if auth_provider == 'SASL_INHERIT': if conf.pop('sasl.mechanism', '').upper() == 'GSSAPI': raise ValueError("SASL_INHERIT does not support SASL mechanism GSSAPI") - auth = (conf.pop('sasl.username', ''), conf.pop('sasl.password', '')) + auth = (cast(str, conf.pop('sasl.username', '')), cast(str, conf.pop('sasl.password', ''))) elif auth_provider == 'USER_INFO': - auth = tuple(conf.pop('basic.auth.user.info', '').split(':')) + possible_auth = tuple(cast(str, conf.pop('basic.auth.user.info', ':')).split(':')) + assert len(possible_auth) == 2, possible_auth + auth = cast(Tuple[str, str], possible_auth) else: auth = utils.get_auth_from_url(url) return auth + _client_tls_ret = TypeVar("_client_tls_ret", str, None) + @staticmethod - def _configure_client_tls(conf): - cert = conf.pop('ssl.certificate.location', None), conf.pop('ssl.key.location', None) + def _configure_client_tls(conf: Dict) -> Tuple[_client_tls_ret, _client_tls_ret]: + cert = cast(Optional[str], conf.pop('ssl.certificate.location', None)), cast(Optional[str], conf.pop('ssl.key.location', None)) # Both values can be None or no values can be None - if bool(cert[0]) != bool(cert[1]): + if (cert[0] is None) != (cert[1] is None): raise ValueError( "Both schema.registry.ssl.certificate.location and schema.registry.ssl.key.location must be set") - return cert + return cert # type: ignore[return-value] - def _send_request(self, url, method='GET', body=None, 
headers={}): + def _send_request(self, url: str, method: str='GET', body: Optional[Sized]=None, headers: Dict={}) -> Tuple[object, int]: if method not in VALID_METHODS: raise ClientError("Method {} is invalid; valid methods include {}".format(method, VALID_METHODS)) if url.startswith('https') and self._is_key_password_provided: - response = self._send_https_session_request(url, method, headers, body) + http_response = self._send_https_session_request(url, method, headers, body) try: - return json.loads(response.data), response.status + return json.loads(http_response.data), http_response.status except ValueError: - return response.content, response.status + return http_response.data, http_response.status _headers = {'Accept': ACCEPT_HDR} if body: @@ -207,12 +221,16 @@ def _send_request(self, url, method='GET', body=None, headers={}): except ValueError: return response.content, response.status_code + CacheKey = TypeVar("CacheKey") + SubCacheKey = TypeVar("SubCacheKey") + SubCacheValue = TypeVar("SubCacheValue") + @staticmethod - def _add_to_cache(cache, subject, schema, value): + def _add_to_cache(cache: Dict[CacheKey, Dict[SubCacheKey, SubCacheValue]], subject: CacheKey, schema: SubCacheKey, value: SubCacheValue) -> None: sub_cache = cache[subject] sub_cache[schema] = value - def _cache_schema(self, schema, schema_id, subject=None, version=None): + def _cache_schema(self, schema: schema.Schema, schema_id: int, subject: Optional[str]=None, version: Optional[int]=None) -> None: # don't overwrite anything if schema_id in self.id_to_schema: schema = self.id_to_schema[schema_id] @@ -226,7 +244,7 @@ def _cache_schema(self, schema, schema_id, subject=None, version=None): self._add_to_cache(self.subject_to_schema_versions, subject, schema, version) - def register(self, subject, avro_schema): + def register(self, subject: str, avro_schema: schema.Schema) -> int: """ POST /subjects/(string: subject)/versions Register a schema with the registry under the given subject @@ -243,7 +261,7 @@ def register(self, subject, avro_schema): """ schemas_to_id = self.subject_to_schema_ids[subject] - schema_id = schemas_to_id.get(avro_schema, None) + schema_id: Optional[int] = schemas_to_id.get(avro_schema, None) if schema_id is not None: return schema_id # send it up @@ -265,12 +283,12 @@ def register(self, subject, avro_schema): raise ClientError("Unable to register schema. Error code:" + str(code) + " message:" + str(result)) # result is a dict - schema_id = result['id'] + schema_id = cast(Dict, result)['id'] # cache it self._cache_schema(avro_schema, schema_id, subject) return schema_id - def check_registration(self, subject, avro_schema): + def check_registration(self, subject: str, avro_schema: schema.Schema) -> int: """ POST /subjects/(string: subject) Check if a schema has already been registered under the specified subject. @@ -303,12 +321,12 @@ def check_registration(self, subject, avro_schema): elif not 200 <= code <= 299: raise ClientError("Unable to check schema registration. Error code:" + str(code)) # result is a dict - schema_id = result['id'] + schema_id = cast(Dict, result)['id'] # cache it self._cache_schema(avro_schema, schema_id, subject) return schema_id - def delete_subject(self, subject): + def delete_subject(self, subject: str) -> int: """ DELETE /subjects/(string: subject) Deletes the specified subject and its associated compatibility level if registered. 
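The CacheKey/SubCacheKey/SubCacheValue TypeVars above give _add_to_cache a generic signature, so mypy can check the helper against each of the three differently typed caches. A reduced standalone sketch of the same pattern (the names below are illustrative, not part of the client):

from collections import defaultdict
from typing import Dict, TypeVar

K = TypeVar("K")
SK = TypeVar("SK")
SV = TypeVar("SV")


def add_to_cache(cache: Dict[K, Dict[SK, SV]], key: K, sub_key: SK, value: SV) -> None:
    # K, SK and SV are inferred from the concrete cache passed at the call site,
    # so a mismatched value type becomes a static error instead of a silent bug.
    cache[key][sub_key] = value


# subject -> {schema string -> schema id}, mirroring subject_to_schema_ids
subject_to_schema_ids: Dict[str, Dict[str, int]] = defaultdict(dict)
add_to_cache(subject_to_schema_ids, "demo-value", '{"type": "string"}', 1)
# add_to_cache(subject_to_schema_ids, "demo-value", '{"type": "string"}', "1")  # rejected by mypy
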
@@ -323,9 +341,10 @@ def delete_subject(self, subject): result, code = self._send_request(url, method="DELETE") if not (code >= 200 and code <= 299): raise ClientError('Unable to delete subject: {}'.format(result)) + assert isinstance(result, int) return result - def get_by_id(self, schema_id): + def get_by_id(self, schema_id: int) -> Optional[schema.Schema]: """ GET /schemas/ids/{int: id} Retrieve a parsed avro schema by id or None if not found @@ -347,7 +366,9 @@ def get_by_id(self, schema_id): return None else: # need to parse the schema - schema_str = result.get("schema") + assert isinstance(result, Dict) + schema_str = result["schema"] + assert isinstance(schema_str, str) try: result = loads(schema_str) # cache it @@ -357,7 +378,7 @@ def get_by_id(self, schema_id): # bad schema - should not happen raise ClientError("Received bad schema (id %s) from registry: %s" % (schema_id, e)) - def get_latest_schema(self, subject): + def get_latest_schema(self, subject: str) -> Tuple[Optional[int], Optional[schema.Schema], Optional[int]]: """ GET /subjects/(string: subject)/versions/latest @@ -374,7 +395,7 @@ def get_latest_schema(self, subject): """ return self.get_by_version(subject, 'latest') - def get_by_version(self, subject, version): + def get_by_version(self, subject: str, version: object) -> Tuple[Optional[int], Optional[schema.Schema], Optional[int]]: """ GET /subjects/(string: subject)/versions/(versionId: version) @@ -401,8 +422,11 @@ def get_by_version(self, subject, version): return (None, None, None) elif not (code >= 200 and code <= 299): return (None, None, None) + assert isinstance(result, Dict) schema_id = result['id'] + assert isinstance(schema_id, int) version = result['version'] + assert isinstance(version, int) if schema_id in self.id_to_schema: schema = self.id_to_schema[schema_id] else: @@ -415,7 +439,7 @@ def get_by_version(self, subject, version): self._cache_schema(schema, schema_id, subject, version) return (schema_id, schema, version) - def get_version(self, subject, avro_schema): + def get_version(self, subject: str, avro_schema: schema.Schema) -> Optional[int]: """ POST /subjects/(string: subject) @@ -442,12 +466,13 @@ def get_version(self, subject, avro_schema): elif not (code >= 200 and code <= 299): log.error("Unable to get version of a schema:" + str(code)) return None + assert isinstance(result, Dict) schema_id = result['id'] version = result['version'] self._cache_schema(avro_schema, schema_id, subject, version) return version - def test_compatibility(self, subject, avro_schema, version='latest'): + def test_compatibility(self, subject: str, avro_schema: schema.Schema, version: str='latest') -> Optional[bool]: """ POST /compatibility/subjects/(string: subject)/versions/(versionId: version) @@ -471,6 +496,7 @@ def test_compatibility(self, subject, avro_schema, version='latest'): log.error(("Invalid subject or schema:" + str(code))) return False elif code >= 200 and code <= 299: + assert isinstance(result, Dict) return result.get('is_compatible') else: log.error("Unable to check the compatibility: " + str(code)) @@ -479,7 +505,7 @@ def test_compatibility(self, subject, avro_schema, version='latest'): log.error("_send_request() failed: %s", e) return False - def update_compatibility(self, level, subject=None): + def update_compatibility(self, level: str, subject: Optional[str]=None) -> str: """ PUT /config/(string: subject) @@ -497,11 +523,12 @@ def update_compatibility(self, level, subject=None): body = {"compatibility": level} result, code = 
self._send_request(url, method='PUT', body=body) if code >= 200 and code <= 299: + assert isinstance(result, Dict) return result['compatibility'] else: raise ClientError("Unable to update level: %s. Error code: %d" % (str(level), code)) - def get_compatibility(self, subject=None): + def get_compatibility(self, subject: Optional[str]=None) -> bool: """ GET /config Get the current compatibility level for a subject. Result will be one of: @@ -520,6 +547,7 @@ def get_compatibility(self, subject=None): if not is_successful_request: raise ClientError('Unable to fetch compatibility level. Error code: %d' % code) + assert isinstance(result, Dict) compatibility = result.get('compatibilityLevel', None) if compatibility not in VALID_LEVELS: if compatibility is None: @@ -528,4 +556,5 @@ def get_compatibility(self, subject=None): error_msg_suffix = str(compatibility) raise ClientError('Invalid compatibility level received: %s' % error_msg_suffix) - return compatibility + # Can't be None, as that's not in VALID_LEVELS + return cast(bool, compatibility) diff --git a/src/confluent_kafka/avro/error.py b/src/confluent_kafka/avro/error.py index b879c4b7a..2f8516ad8 100644 --- a/src/confluent_kafka/avro/error.py +++ b/src/confluent_kafka/avro/error.py @@ -16,16 +16,19 @@ # +from typing import Optional + + class ClientError(Exception): """ Error thrown by Schema Registry clients """ - def __init__(self, message, http_code=None): + def __init__(self, message: str, http_code: Optional[int]=None): self.message = message self.http_code = http_code super(ClientError, self).__init__(self.__str__()) - def __repr__(self): + def __repr__(self) -> str: return "ClientError(error={error})".format(error=self.message) - def __str__(self): + def __str__(self) -> str: return self.message diff --git a/src/confluent_kafka/avro/load.py b/src/confluent_kafka/avro/load.py index 9db8660e1..c22b2795b 100644 --- a/src/confluent_kafka/avro/load.py +++ b/src/confluent_kafka/avro/load.py @@ -16,30 +16,12 @@ # +from typing import TYPE_CHECKING from confluent_kafka.avro.error import ClientError - -def loads(schema_str): - """ Parse a schema given a schema string """ - try: - return schema.parse(schema_str) - except SchemaParseException as e: - raise ClientError("Schema parse failed: %s" % (str(e))) - - -def load(fp): - """ Parse a schema from a file path """ - with open(fp) as f: - return loads(f.read()) - - -# avro.schema.RecordSchema and avro.schema.PrimitiveSchema classes are not hashable. Hence defining them explicitly as -# a quick fix -def _hash_func(self): - return hash(str(self)) - - try: + # FIXME: Needs https://github.com/apache/avro/pull/1952 + # pip install git+https://github.com/apache/avro#subdirectory=lang/py from avro import schema try: @@ -47,11 +29,38 @@ def _hash_func(self): from avro.errors import SchemaParseException except ImportError: # avro < 1.11.0 - from avro.schema import SchemaParseException + from avro.schema import SchemaParseException # type:ignore[attr-defined, no-redef] + + # avro.schema.RecordSchema and avro.schema.PrimitiveSchema classes are not hashable. 
Hence defining them explicitly as + # a quick fix + def _hash_func(self: object) -> int: + return hash(str(self)) + + schema.RecordSchema.__hash__ = _hash_func # type:ignore[assignment] + schema.PrimitiveSchema.__hash__ = _hash_func # type:ignore[assignment] + schema.UnionSchema.__hash__ = _hash_func # type:ignore[assignment] + + def loads(schema_str: str) -> schema.Schema: + """ Parse a schema given a schema string """ + try: + return schema.parse(schema_str) + except SchemaParseException as e: + raise ClientError("Schema parse failed: %s" % (str(e))) + + + def load(fp: str) -> schema.Schema: + """ Parse a schema from a file path """ + with open(fp) as f: + return loads(f.read()) - schema.RecordSchema.__hash__ = _hash_func - schema.PrimitiveSchema.__hash__ = _hash_func - schema.UnionSchema.__hash__ = _hash_func except ImportError: - schema = None + if TYPE_CHECKING: + # Workaround hack so the type checking for Schema objects still works + class Schema: + pass + class TopLevelSchema: + Schema = Schema + schema = TopLevelSchema # type:ignore[assignment] + else: + schema = None diff --git a/src/confluent_kafka/avro/requirements.txt b/src/confluent_kafka/avro/requirements.txt index e34a65dd8..3380a200c 100644 --- a/src/confluent_kafka/avro/requirements.txt +++ b/src/confluent_kafka/avro/requirements.txt @@ -1,3 +1,3 @@ fastavro>=0.23.0 requests -avro>=1.11.1,<2 +avro>=1.11.2,<2 diff --git a/src/confluent_kafka/avro/serializer/__init__.py b/src/confluent_kafka/avro/serializer/__init__.py index 845f58a84..65579719a 100644 --- a/src/confluent_kafka/avro/serializer/__init__.py +++ b/src/confluent_kafka/avro/serializer/__init__.py @@ -19,16 +19,16 @@ class SerializerError(Exception): """Generic error from serializer package""" - def __init__(self, message): + def __init__(self, message: str): self.message = message - def __repr__(self): + def __repr__(self) -> str: return '{klass}(error={error})'.format( klass=self.__class__.__name__, error=self.message ) - def __str__(self): + def __str__(self) -> str: return self.message diff --git a/src/confluent_kafka/avro/serializer/message_serializer.py b/src/confluent_kafka/avro/serializer/message_serializer.py index d92763e2c..9af7f2896 100644 --- a/src/confluent_kafka/avro/serializer/message_serializer.py +++ b/src/confluent_kafka/avro/serializer/message_serializer.py @@ -25,11 +25,14 @@ import struct import sys import traceback +from typing import Any, Callable, Dict, Optional, Union import avro import avro.io +from confluent_kafka.avro import schema from confluent_kafka.avro import ClientError +from confluent_kafka.avro.cached_schema_registry_client import CachedSchemaRegistryClient from confluent_kafka.avro.serializer import (SerializerError, KeySerializerError, ValueSerializerError) @@ -53,12 +56,11 @@ class ContextStringIO(io.BytesIO): Wrapper to allow use of StringIO via 'with' constructs. """ - def __enter__(self): + def __enter__(self) -> "ContextStringIO": return self - def __exit__(self, *args): + def __exit__(self, *args: Any) -> None: self.close() - return False class MessageSerializer(object): @@ -70,15 +72,15 @@ class MessageSerializer(object): All decode_* methods expect a buffer received from kafka. 
""" - def __init__(self, registry_client, reader_key_schema=None, reader_value_schema=None): + def __init__(self, registry_client: CachedSchemaRegistryClient, reader_key_schema: Optional[schema.Schema]=None, reader_value_schema: Optional[schema.Schema]=None): self.registry_client = registry_client - self.id_to_decoder_func = {} - self.id_to_writers = {} + self.id_to_decoder_func: Dict[int, Callable] = {} + self.id_to_writers: Dict[int, Callable] = {} self.reader_key_schema = reader_key_schema self.reader_value_schema = reader_value_schema # Encoder support - def _get_encoder_func(self, writer_schema): + def _get_encoder_func(self, writer_schema: Optional[schema.Schema]) -> Callable: if HAS_FAST: schema = json.loads(str(writer_schema)) parsed_schema = parse_schema(schema) @@ -86,7 +88,7 @@ def _get_encoder_func(self, writer_schema): writer = avro.io.DatumWriter(writer_schema) return lambda record, fp: writer.write(record, avro.io.BinaryEncoder(fp)) - def encode_record_with_schema(self, topic, schema, record, is_key=False): + def encode_record_with_schema(self, topic: str, schema: schema.Schema, record: object, is_key: bool=False) -> bytes: """ Given a parsed avro schema, encode a record for the given topic. The record is expected to be a dictionary. @@ -119,7 +121,7 @@ def encode_record_with_schema(self, topic, schema, record, is_key=False): return self.encode_record_with_schema_id(schema_id, record, is_key=is_key) - def encode_record_with_schema_id(self, schema_id, record, is_key=False): + def encode_record_with_schema_id(self, schema_id: int, record: object, is_key: bool=False) -> bytes: """ Encode a record with a given schema id. The record must be a python dictionary. @@ -156,7 +158,7 @@ def encode_record_with_schema_id(self, schema_id, record, is_key=False): return outf.getvalue() # Decoder support - def _get_decoder_func(self, schema_id, payload, is_key=False): + def _get_decoder_func(self, schema_id: int, payload: ContextStringIO, is_key: bool=False) -> Callable: if schema_id in self.id_to_decoder_func: return self.id_to_decoder_func[schema_id] @@ -206,14 +208,14 @@ def _get_decoder_func(self, schema_id, payload, is_key=False): # def __init__(self, writer_schema=None, reader_schema=None) avro_reader = avro.io.DatumReader(writer_schema_obj, reader_schema_obj) - def decoder(p): + def decoder(p: io.BytesIO) -> object: bin_decoder = avro.io.BinaryDecoder(p) return avro_reader.read(bin_decoder) self.id_to_decoder_func[schema_id] = decoder return self.id_to_decoder_func[schema_id] - def decode_message(self, message, is_key=False): + def decode_message(self, message: Optional[bytes], is_key: bool=False) -> Optional[Dict]: """ Decode a message from kafka that has been encoded for use with the schema registry. 
diff --git a/src/confluent_kafka/cimpl.pyi b/src/confluent_kafka/cimpl.pyi new file mode 100644 index 000000000..285018cfc --- /dev/null +++ b/src/confluent_kafka/cimpl.pyi @@ -0,0 +1,361 @@ +from concurrent.futures import Future +from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple + +from typing import overload +ACL_OPERATION_ALL: int +ACL_OPERATION_ALTER: int +ACL_OPERATION_ALTER_CONFIGS: int +ACL_OPERATION_ANY: int +ACL_OPERATION_CLUSTER_ACTION: int +ACL_OPERATION_CREATE: int +ACL_OPERATION_DELETE: int +ACL_OPERATION_DESCRIBE: int +ACL_OPERATION_DESCRIBE_CONFIGS: int +ACL_OPERATION_IDEMPOTENT_WRITE: int +ACL_OPERATION_READ: int +ACL_OPERATION_UNKNOWN: int +ACL_OPERATION_WRITE: int +ACL_PERMISSION_TYPE_ALLOW: int +ACL_PERMISSION_TYPE_ANY: int +ACL_PERMISSION_TYPE_DENY: int +ACL_PERMISSION_TYPE_UNKNOWN: int +ALTER_CONFIG_OP_TYPE_APPEND: int +ALTER_CONFIG_OP_TYPE_DELETE: int +ALTER_CONFIG_OP_TYPE_SET: int +ALTER_CONFIG_OP_TYPE_SUBTRACT: int +CONFIG_SOURCE_DEFAULT_CONFIG: int +CONFIG_SOURCE_DYNAMIC_BROKER_CONFIG: int +CONFIG_SOURCE_DYNAMIC_DEFAULT_BROKER_CONFIG: int +CONFIG_SOURCE_DYNAMIC_TOPIC_CONFIG: int +CONFIG_SOURCE_STATIC_BROKER_CONFIG: int +CONFIG_SOURCE_UNKNOWN_CONFIG: int +CONSUMER_GROUP_STATE_COMPLETING_REBALANCE: int +CONSUMER_GROUP_STATE_DEAD: int +CONSUMER_GROUP_STATE_EMPTY: int +CONSUMER_GROUP_STATE_PREPARING_REBALANCE: int +CONSUMER_GROUP_STATE_STABLE: int +CONSUMER_GROUP_STATE_UNKNOWN: int +OFFSET_BEGINNING: int +OFFSET_END: int +OFFSET_INVALID: int +OFFSET_STORED: int +RESOURCE_ANY: int +RESOURCE_BROKER: int +RESOURCE_GROUP: int +RESOURCE_PATTERN_ANY: int +RESOURCE_PATTERN_LITERAL: int +RESOURCE_PATTERN_MATCH: int +RESOURCE_PATTERN_PREFIXED: int +RESOURCE_PATTERN_UNKNOWN: int +RESOURCE_TOPIC: int +RESOURCE_UNKNOWN: int +SCRAM_MECHANISM_SHA_256: int +SCRAM_MECHANISM_SHA_512: int +SCRAM_MECHANISM_UNKNOWN: int +TIMESTAMP_CREATE_TIME: int +TIMESTAMP_LOG_APPEND_TIME: int +TIMESTAMP_NOT_AVAILABLE: int + +class TopicPartition: + error: Any + metadata: Any + offset: Any + partition: int + topic: str + def __init__(self, *args: object, **kwargs: object) -> None: ... + def __eq__(self, other: object) -> Any: ... + def __ge__(self, other: object) -> Any: ... + def __gt__(self, other: object) -> Any: ... + def __hash__(self) -> Any: ... + def __le__(self, other: object) -> Any: ... + def __lt__(self, other: object) -> Any: ... + def __ne__(self, other: object) -> Any: ... + +class Consumer: + def __init__(self, *args: object, **kwargs: object) -> None: ... + def assign(self, partitions: object) -> Any: ... + def assignment(self, *args: object, **kwargs: object) -> Any: ... + def close(self, *args: object, **kwargs: object) -> Any: ... + def commit(self, message: Optional[object] = None, offsets: Optional[object] = None, asynchronous: Optional[bool] = None) -> Any: ... + def committed(self, *args: object, **kwargs: object) -> Any: ... + def consume(self, num_messages: int, timeout: float) -> Any: ... + def consumer_group_metadata(self) -> Any: ... + def get_watermark_offsets(self, *args: object, **kwargs: object) -> Any: ... + def incremental_assign(self, partitions: object) -> Any: ... + def incremental_unassign(self, partitions: object) -> Any: ... + def list_topics(self, *args: object, **kwargs: object) -> Any: ... + def memberid(self) -> Any: ... + def offsets_for_times(self, *args: object, **kwargs: object) -> Any: ... + def pause(self, partitions: object) -> Any: ... + def poll(self, timeout: float) -> Optional[Message]: ... 
+ def position(self, partitions: object) -> Any: ... + def resume(self, partitions: object) -> Any: ... + @overload + def seek(self, partition: object) -> Any: ... + @overload + def seek(self) -> Any: ... + def store_offsets(self, *args: object, **kwargs: object) -> Any: ... + def subscribe(self, topic: str, on_assign: Optional[Callable[[Consumer, List[TopicPartition]], None]] = None, on_revoke:Optional[Callable] = None) -> Any: ... + def unassign(self, *args: object, **kwargs: object) -> Any: ... + def unsubscribe(self, *args: object, **kwargs: object) -> Any: ... + +class KafkaError: + BROKER_NOT_AVAILABLE: ClassVar[int] = ... + CLUSTER_AUTHORIZATION_FAILED: ClassVar[int] = ... + CONCURRENT_TRANSACTIONS: ClassVar[int] = ... + COORDINATOR_LOAD_IN_PROGRESS: ClassVar[int] = ... + COORDINATOR_NOT_AVAILABLE: ClassVar[int] = ... + DELEGATION_TOKEN_AUTHORIZATION_FAILED: ClassVar[int] = ... + DELEGATION_TOKEN_AUTH_DISABLED: ClassVar[int] = ... + DELEGATION_TOKEN_EXPIRED: ClassVar[int] = ... + DELEGATION_TOKEN_NOT_FOUND: ClassVar[int] = ... + DELEGATION_TOKEN_OWNER_MISMATCH: ClassVar[int] = ... + DELEGATION_TOKEN_REQUEST_NOT_ALLOWED: ClassVar[int] = ... + DUPLICATE_RESOURCE: ClassVar[int] = ... + DUPLICATE_SEQUENCE_NUMBER: ClassVar[int] = ... + ELECTION_NOT_NEEDED: ClassVar[int] = ... + ELIGIBLE_LEADERS_NOT_AVAILABLE: ClassVar[int] = ... + FEATURE_UPDATE_FAILED: ClassVar[int] = ... + FENCED_INSTANCE_ID: ClassVar[int] = ... + FENCED_LEADER_EPOCH: ClassVar[int] = ... + FETCH_SESSION_ID_NOT_FOUND: ClassVar[int] = ... + GROUP_AUTHORIZATION_FAILED: ClassVar[int] = ... + GROUP_ID_NOT_FOUND: ClassVar[int] = ... + GROUP_MAX_SIZE_REACHED: ClassVar[int] = ... + GROUP_SUBSCRIBED_TO_TOPIC: ClassVar[int] = ... + ILLEGAL_GENERATION: ClassVar[int] = ... + ILLEGAL_SASL_STATE: ClassVar[int] = ... + INCONSISTENT_GROUP_PROTOCOL: ClassVar[int] = ... + INCONSISTENT_VOTER_SET: ClassVar[int] = ... + INVALID_COMMIT_OFFSET_SIZE: ClassVar[int] = ... + INVALID_CONFIG: ClassVar[int] = ... + INVALID_FETCH_SESSION_EPOCH: ClassVar[int] = ... + INVALID_GROUP_ID: ClassVar[int] = ... + INVALID_MSG: ClassVar[int] = ... + INVALID_MSG_SIZE: ClassVar[int] = ... + INVALID_PARTITIONS: ClassVar[int] = ... + INVALID_PRINCIPAL_TYPE: ClassVar[int] = ... + INVALID_PRODUCER_EPOCH: ClassVar[int] = ... + INVALID_PRODUCER_ID_MAPPING: ClassVar[int] = ... + INVALID_RECORD: ClassVar[int] = ... + INVALID_REPLICATION_FACTOR: ClassVar[int] = ... + INVALID_REPLICA_ASSIGNMENT: ClassVar[int] = ... + INVALID_REQUEST: ClassVar[int] = ... + INVALID_REQUIRED_ACKS: ClassVar[int] = ... + INVALID_SESSION_TIMEOUT: ClassVar[int] = ... + INVALID_TIMESTAMP: ClassVar[int] = ... + INVALID_TRANSACTION_TIMEOUT: ClassVar[int] = ... + INVALID_TXN_STATE: ClassVar[int] = ... + INVALID_UPDATE_VERSION: ClassVar[int] = ... + KAFKA_STORAGE_ERROR: ClassVar[int] = ... + LEADER_NOT_AVAILABLE: ClassVar[int] = ... + LISTENER_NOT_FOUND: ClassVar[int] = ... + LOG_DIR_NOT_FOUND: ClassVar[int] = ... + MEMBER_ID_REQUIRED: ClassVar[int] = ... + MSG_SIZE_TOO_LARGE: ClassVar[int] = ... + NETWORK_EXCEPTION: ClassVar[int] = ... + NON_EMPTY_GROUP: ClassVar[int] = ... + NOT_CONTROLLER: ClassVar[int] = ... + NOT_COORDINATOR: ClassVar[int] = ... + NOT_ENOUGH_REPLICAS: ClassVar[int] = ... + NOT_ENOUGH_REPLICAS_AFTER_APPEND: ClassVar[int] = ... + NOT_LEADER_FOR_PARTITION: ClassVar[int] = ... + NO_ERROR: ClassVar[int] = ... + NO_REASSIGNMENT_IN_PROGRESS: ClassVar[int] = ... + OFFSET_METADATA_TOO_LARGE: ClassVar[int] = ... + OFFSET_NOT_AVAILABLE: ClassVar[int] = ... 
+ OFFSET_OUT_OF_RANGE: ClassVar[int] = ... + OPERATION_NOT_ATTEMPTED: ClassVar[int] = ... + OUT_OF_ORDER_SEQUENCE_NUMBER: ClassVar[int] = ... + POLICY_VIOLATION: ClassVar[int] = ... + PREFERRED_LEADER_NOT_AVAILABLE: ClassVar[int] = ... + PRINCIPAL_DESERIALIZATION_FAILURE: ClassVar[int] = ... + PRODUCER_FENCED: ClassVar[int] = ... + REASSIGNMENT_IN_PROGRESS: ClassVar[int] = ... + REBALANCE_IN_PROGRESS: ClassVar[int] = ... + RECORD_LIST_TOO_LARGE: ClassVar[int] = ... + REPLICA_NOT_AVAILABLE: ClassVar[int] = ... + REQUEST_TIMED_OUT: ClassVar[int] = ... + RESOURCE_NOT_FOUND: ClassVar[int] = ... + SASL_AUTHENTICATION_FAILED: ClassVar[int] = ... + SECURITY_DISABLED: ClassVar[int] = ... + STALE_BROKER_EPOCH: ClassVar[int] = ... + STALE_CTRL_EPOCH: ClassVar[int] = ... + THROTTLING_QUOTA_EXCEEDED: ClassVar[int] = ... + TOPIC_ALREADY_EXISTS: ClassVar[int] = ... + TOPIC_AUTHORIZATION_FAILED: ClassVar[int] = ... + TOPIC_DELETION_DISABLED: ClassVar[int] = ... + TOPIC_EXCEPTION: ClassVar[int] = ... + TRANSACTIONAL_ID_AUTHORIZATION_FAILED: ClassVar[int] = ... + TRANSACTION_COORDINATOR_FENCED: ClassVar[int] = ... + UNACCEPTABLE_CREDENTIAL: ClassVar[int] = ... + UNKNOWN: ClassVar[int] = ... + UNKNOWN_LEADER_EPOCH: ClassVar[int] = ... + UNKNOWN_MEMBER_ID: ClassVar[int] = ... + UNKNOWN_PRODUCER_ID: ClassVar[int] = ... + UNKNOWN_TOPIC_OR_PART: ClassVar[int] = ... + UNSTABLE_OFFSET_COMMIT: ClassVar[int] = ... + UNSUPPORTED_COMPRESSION_TYPE: ClassVar[int] = ... + UNSUPPORTED_FOR_MESSAGE_FORMAT: ClassVar[int] = ... + UNSUPPORTED_SASL_MECHANISM: ClassVar[int] = ... + UNSUPPORTED_VERSION: ClassVar[int] = ... + _ALL_BROKERS_DOWN: ClassVar[int] = ... + _APPLICATION: ClassVar[int] = ... + _ASSIGNMENT_LOST: ClassVar[int] = ... + _ASSIGN_PARTITIONS: ClassVar[int] = ... + _AUTHENTICATION: ClassVar[int] = ... + _AUTO_OFFSET_RESET: ClassVar[int] = ... + _BAD_COMPRESSION: ClassVar[int] = ... + _BAD_MSG: ClassVar[int] = ... + _CONFLICT: ClassVar[int] = ... + _CRIT_SYS_RESOURCE: ClassVar[int] = ... + _DESTROY: ClassVar[int] = ... + _EXISTING_SUBSCRIPTION: ClassVar[int] = ... + _FAIL: ClassVar[int] = ... + _FATAL: ClassVar[int] = ... + _FENCED: ClassVar[int] = ... + _FS: ClassVar[int] = ... + _GAPLESS_GUARANTEE: ClassVar[int] = ... + _INCONSISTENT: ClassVar[int] = ... + _INTR: ClassVar[int] = ... + _INVALID_ARG: ClassVar[int] = ... + _INVALID_TYPE: ClassVar[int] = ... + _IN_PROGRESS: ClassVar[int] = ... + _ISR_INSUFF: ClassVar[int] = ... + _KEY_DESERIALIZATION: ClassVar[int] = ... + _KEY_SERIALIZATION: ClassVar[int] = ... + _MAX_POLL_EXCEEDED: ClassVar[int] = ... + _MSG_TIMED_OUT: ClassVar[int] = ... + _NODE_UPDATE: ClassVar[int] = ... + _NOENT: ClassVar[int] = ... + _NOOP: ClassVar[int] = ... + _NOT_CONFIGURED: ClassVar[int] = ... + _NOT_IMPLEMENTED: ClassVar[int] = ... + _NO_OFFSET: ClassVar[int] = ... + _OUTDATED: ClassVar[int] = ... + _PARTIAL: ClassVar[int] = ... + _PARTITION_EOF: ClassVar[int] = ... + _PREV_IN_PROGRESS: ClassVar[int] = ... + _PURGE_INFLIGHT: ClassVar[int] = ... + _PURGE_QUEUE: ClassVar[int] = ... + _QUEUE_FULL: ClassVar[int] = ... + _READ_ONLY: ClassVar[int] = ... + _RESOLVE: ClassVar[int] = ... + _RETRY: ClassVar[int] = ... + _REVOKE_PARTITIONS: ClassVar[int] = ... + _SSL: ClassVar[int] = ... + _STATE: ClassVar[int] = ... + _TIMED_OUT: ClassVar[int] = ... + _TIMED_OUT_QUEUE: ClassVar[int] = ... + _TRANSPORT: ClassVar[int] = ... + _UNDERFLOW: ClassVar[int] = ... + _UNKNOWN_BROKER: ClassVar[int] = ... + _UNKNOWN_GROUP: ClassVar[int] = ... + _UNKNOWN_PARTITION: ClassVar[int] = ... 
+ _UNKNOWN_PROTOCOL: ClassVar[int] = ... + _UNKNOWN_TOPIC: ClassVar[int] = ... + _UNSUPPORTED_FEATURE: ClassVar[int] = ... + _VALUE_DESERIALIZATION: ClassVar[int] = ... + _VALUE_SERIALIZATION: ClassVar[int] = ... + _WAIT_CACHE: ClassVar[int] = ... + _WAIT_COORD: ClassVar[int] = ... + def __init__(self, *args: object, **kwargs: object) -> None: ... + def code(self, *args: object, **kwargs: object) -> Any: ... + def fatal(self, *args: object, **kwargs: object) -> Any: ... + def name(self, *args: object, **kwargs: object) -> Any: ... + def retriable(self, *args: object, **kwargs: object) -> Any: ... + def str(self, *args: object, **kwargs: object) -> Any: ... + def txn_requires_abort(self, *args: object, **kwargs: object) -> Any: ... + def __eq__(self, other: object) -> Any: ... + def __ge__(self, other: object) -> Any: ... + def __gt__(self, other: object) -> Any: ... + def __hash__(self) -> Any: ... + def __le__(self, other: object) -> Any: ... + def __lt__(self, other: object) -> Any: ... + def __ne__(self, other: object) -> Any: ... + +class KafkaException(Exception): ... + +class Message: + def error(self) -> Any: ... + def headers(self, *args: object, **kwargs: object) -> Any: ... + def key(self, *args: object, **kwargs: object) -> Any: ... + def latency(self, *args: object, **kwargs: object) -> Any: ... + def offset(self, *args: object, **kwargs: object) -> Any: ... + def partition(self, *args: object, **kwargs: object) -> Any: ... + def set_headers(self, *args: object, **kwargs: object) -> Any: ... + def set_key(self, *args: object, **kwargs: object) -> Any: ... + def set_value(self, *args: object, **kwargs: object) -> Any: ... + def timestamp(self, *args: object, **kwargs: object) -> Any: ... + def topic(self, *args: object, **kwargs: object) -> Any: ... + def value(self) -> Any: ... + def __len__(self) -> Any: ... + +class NewPartitions: + new_total_count: Any + replica_assignment: Any + topic: Any + def __init__(self, *args: object, **kwargs: object) -> None: ... + def __eq__(self, other: object) -> Any: ... + def __ge__(self, other: object) -> Any: ... + def __gt__(self, other: object) -> Any: ... + def __hash__(self) -> Any: ... + def __le__(self, other: object) -> Any: ... + def __lt__(self, other: object) -> Any: ... + def __ne__(self, other: object) -> Any: ... + +class NewTopic: + config: Any + num_partitions: Any + replica_assignment: Any + replication_factor: Any + topic: str + def __init__(self, *args: object, **kwargs: object) -> None: ... + def __eq__(self, other: object) -> Any: ... + def __ge__(self, other: object) -> Any: ... + def __gt__(self, other: object) -> Any: ... + def __hash__(self) -> Any: ... + def __le__(self, other: object) -> Any: ... + def __lt__(self, other: object) -> Any: ... + def __ne__(self, other: object) -> Any: ... + +class Producer: + def __init__(self, *args: object, **kwargs: object) -> None: ... + def abort_transaction(self, *args: object, **kwargs: object) -> Any: ... + def begin_transaction(self) -> Any: ... + def commit_transaction(self, *args: object, **kwargs: object) -> Any: ... + def flush(self, *args: object, **kwargs: object) -> Any: ... + def init_transactions(self, *args: object, **kwargs: object) -> Any: ... + def list_topics(self, *args: object, **kwargs: object) -> Any: ... + def poll(self, *args: object, **kwargs: object) -> Any: ... + def produce(self, topic: str, **kwargs: object) -> None: ... + def purge(self, *args: object, **kwargs: object) -> Any: ... 
+ def send_offsets_to_transaction(self, *args: object, **kwargs: object) -> Any: ... + def __len__(self) -> Any: ... + +class _AdminClientImpl: + def __init__(self, *args: object, **kwargs: object) -> None: ... + def alter_configs(self, resources: List, f: Future, **kwargs: object) -> Any: ... + def create_acls(self, acls: List, f: Future, **kwargs: object) -> Any: ... + def create_partitions(self, new_partitions: List[NewPartitions], f: Future) -> Any: ... + def create_topics(self, new_topics: List[NewTopic], f: Future) -> Dict[str, Future]: ... + def delete_acls(self, acl_binding_filters: List, f: Future) -> Any: ... + def delete_topics(self, topics: List[str], f: Future, **kwargs: object) -> Any: ... + def describe_acls(self, acl_binding_filter: List, f: Future, **kwargs: object) -> Any: ... + def describe_configs(self, resources: List, f: Future, **kwargs: object) -> Any: ... + def list_groups(self, *args: object, **kwargs: object) -> Any: ... + def list_topics(self, *args: object, **kwargs: object) -> Any: ... + def poll(self) -> Any: ... + def alter_user_scram_credentials(self, alterations: List, f: Future) -> Any: ... + def describe_consumer_groups(self, group_ids: List[str], f: Future, **kwargs: object) -> Any: ... + def list_consumer_groups(self, f: Future, **kwargs: object) -> Any: ... + def list_consumer_group_offsets(self, offsets: List, f: Future, **kwargs: object) -> Any: ... + def incremental_alter_configs(self, resources: List, f: Future, **kwargs: object) -> Any: ... + def delete_consumer_groups(self, group_ids: List, f: Future, **kwargs: object) -> Any: ... + def set_sasl_credentials(self, username: str, password: str) -> Any: ... + def describe_user_scram_credentials(self, users: List[str], f: Future, **kwargs: object) -> Any: ... + def alter_consumer_group_offsets(self, request: List, f: Future) -> Any: ... + def __len__(self) -> Any: ... + +def libversion(*args: object, **kwargs: object) -> Any: ... +def version(*args: object, **kwargs: object) -> Any: ... diff --git a/src/confluent_kafka/deserializing_consumer.py b/src/confluent_kafka/deserializing_consumer.py index 39f80943d..ad5416297 100644 --- a/src/confluent_kafka/deserializing_consumer.py +++ b/src/confluent_kafka/deserializing_consumer.py @@ -16,7 +16,8 @@ # limitations under the License. # -from confluent_kafka.cimpl import Consumer as _ConsumerImpl +from typing import Dict, Optional +from confluent_kafka.cimpl import Consumer as _ConsumerImpl, Message from .error import (ConsumeError, KeyDeserializationError, ValueDeserializationError) @@ -70,14 +71,14 @@ class DeserializingConsumer(_ConsumerImpl): ValueError: if configuration validation fails """ # noqa: E501 - def __init__(self, conf): + def __init__(self, conf: Dict): conf_copy = conf.copy() self._key_deserializer = conf_copy.pop('key.deserializer', None) self._value_deserializer = conf_copy.pop('value.deserializer', None) super(DeserializingConsumer, self).__init__(conf_copy) - def poll(self, timeout=-1): + def poll(self, timeout: float=-1) -> Optional[Message]: """ Consume messages and calls callbacks. 
@@ -123,7 +124,7 @@ def poll(self, timeout=-1): msg.set_value(value) return msg - def consume(self, num_messages=1, timeout=-1): + def consume(self, num_messages: int =1, timeout: float=-1) -> None: """ :py:func:`Consumer.consume` not implemented, use :py:func:`DeserializingConsumer.poll` instead diff --git a/src/confluent_kafka/error.py b/src/confluent_kafka/error.py index 07c733c23..8aae4de4b 100644 --- a/src/confluent_kafka/error.py +++ b/src/confluent_kafka/error.py @@ -15,7 +15,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from confluent_kafka.cimpl import KafkaException, KafkaError +from typing import Any, Optional, cast +from confluent_kafka.cimpl import KafkaException, KafkaError, Message from confluent_kafka.serialization import SerializationError @@ -32,18 +33,18 @@ class _KafkaClientError(KafkaException): by the broker. """ - def __init__(self, kafka_error, exception=None, kafka_message=None): + def __init__(self, kafka_error: KafkaError, exception: Optional[Exception]=None, kafka_message: Optional[Message]=None): super(_KafkaClientError, self).__init__(kafka_error) self.exception = exception self.kafka_message = kafka_message @property - def code(self): - return self.args[0].code() + def code(self) -> Any: + return cast(KafkaError, self.args[0]).code() @property - def name(self): - return self.args[0].name() + def name(self) -> Any: + return cast(KafkaError, self.args[0]).name() class ConsumeError(_KafkaClientError): @@ -64,7 +65,7 @@ class ConsumeError(_KafkaClientError): """ - def __init__(self, kafka_error, exception=None, kafka_message=None): + def __init__(self, kafka_error: KafkaError, exception: Optional[Exception]=None, kafka_message: Optional[Message]=None): super(ConsumeError, self).__init__(kafka_error, exception, kafka_message) @@ -81,7 +82,7 @@ class KeyDeserializationError(ConsumeError, SerializationError): """ - def __init__(self, exception=None, kafka_message=None): + def __init__(self, exception: Optional[Exception]=None, kafka_message: Optional[Message]=None): super(KeyDeserializationError, self).__init__( KafkaError(KafkaError._KEY_DESERIALIZATION, str(exception)), exception=exception, kafka_message=kafka_message) @@ -100,7 +101,7 @@ class ValueDeserializationError(ConsumeError, SerializationError): """ - def __init__(self, exception=None, kafka_message=None): + def __init__(self, exception: Optional[Exception]=None, kafka_message: Optional[Message]=None): super(ValueDeserializationError, self).__init__( KafkaError(KafkaError._VALUE_DESERIALIZATION, str(exception)), exception=exception, kafka_message=kafka_message) @@ -116,7 +117,7 @@ class ProduceError(_KafkaClientError): exception(Exception, optional): The original exception. """ - def __init__(self, kafka_error, exception=None): + def __init__(self, kafka_error: KafkaError, exception: Optional[Exception]=None): super(ProduceError, self).__init__(kafka_error, exception, None) @@ -128,7 +129,7 @@ class KeySerializationError(ProduceError, SerializationError): exception (Exception): The exception that occurred during serialization. """ - def __init__(self, exception=None): + def __init__(self, exception: Optional[Exception]=None): super(KeySerializationError, self).__init__( KafkaError(KafkaError._KEY_SERIALIZATION, str(exception)), exception=exception) @@ -142,7 +143,7 @@ class ValueSerializationError(ProduceError, SerializationError): exception (Exception): The exception that occurred during serialization. 
""" - def __init__(self, exception=None): + def __init__(self, exception: Optional[Exception]=None): super(ValueSerializationError, self).__init__( KafkaError(KafkaError._VALUE_SERIALIZATION, str(exception)), exception=exception) diff --git a/src/confluent_kafka/kafkatest/verifiable_client.py b/src/confluent_kafka/kafkatest/verifiable_client.py index 56d4383e3..989560734 100644 --- a/src/confluent_kafka/kafkatest/verifiable_client.py +++ b/src/confluent_kafka/kafkatest/verifiable_client.py @@ -21,6 +21,7 @@ import socket import sys import time +from typing import Dict class VerifiableClient(object): @@ -28,8 +29,7 @@ class VerifiableClient(object): Generic base class for a kafkatest verifiable client. Implements the common kafkatest protocol and semantics. """ - - def __init__(self, conf): + def __init__(self, conf: Dict): """ """ super(VerifiableClient, self).__init__() @@ -39,26 +39,26 @@ def __init__(self, conf): signal.signal(signal.SIGTERM, self.sig_term) self.dbg('Pid is %d' % os.getpid()) - def sig_term(self, sig, frame): + def sig_term(self, sig: int, frame: object) -> None: self.dbg('SIGTERM') self.run = False @staticmethod - def _timestamp(): + def _timestamp() -> str: return time.strftime('%H:%M:%S', time.localtime()) - def dbg(self, s): + def dbg(self, s: str) -> None: """ Debugging printout """ sys.stderr.write('%% %s DEBUG: %s\n' % (self._timestamp(), s)) - def err(self, s, term=False): + def err(self, s: str, term: bool =False) -> None: """ Error printout, if term=True the process will terminate immediately. """ sys.stderr.write('%% %s ERROR: %s\n' % (self._timestamp(), s)) if term: sys.stderr.write('%% FATAL ERROR ^\n') sys.exit(1) - def send(self, d): + def send(self, d: Dict) -> None: """ Send dict as JSON to stdout for consumtion by kafkatest handler """ d['_time'] = str(datetime.datetime.now()) self.dbg('SEND: %s' % json.dumps(d)) @@ -66,9 +66,9 @@ def send(self, d): sys.stdout.flush() @staticmethod - def set_config(conf, args): + def set_config(conf: Dict, args: Dict) -> None: """ Set client config properties using args dict. """ - for n, v in args.iteritems(): + for n, v in args.items(): if v is None: continue @@ -95,9 +95,9 @@ def set_config(conf, args): conf[n] = v @staticmethod - def read_config_file(path): + def read_config_file(path: str) -> Dict[str, str]: """Read (java client) config file and return dict with properties""" - conf = {} + conf: Dict[str, str] = {} with open(path, 'r') as f: for line in f: diff --git a/src/confluent_kafka/kafkatest/verifiable_consumer.py b/src/confluent_kafka/kafkatest/verifiable_consumer.py index 94aa48ee2..2ac79d219 100755 --- a/src/confluent_kafka/kafkatest/verifiable_consumer.py +++ b/src/confluent_kafka/kafkatest/verifiable_consumer.py @@ -18,8 +18,11 @@ import argparse import os import time +from typing import Any, Dict, List, Optional, cast +from typing_extensions import Literal, TypedDict from confluent_kafka import Consumer, KafkaError, KafkaException -from verifiable_client import VerifiableClient +from confluent_kafka.cimpl import Message, TopicPartition +from .verifiable_client import VerifiableClient class VerifiableConsumer(VerifiableClient): @@ -27,8 +30,7 @@ class VerifiableConsumer(VerifiableClient): confluent-kafka-python backed VerifiableConsumer class for use with Kafka's kafkatests client tests. 
""" - - def __init__(self, conf): + def __init__(self, conf: Dict): """ conf is a config dict passed to confluent_kafka.Consumer() """ @@ -41,16 +43,17 @@ def __init__(self, conf): self.use_auto_commit = False self.use_async_commit = False self.max_msgs = -1 - self.assignment = [] - self.assignment_dict = dict() + self.assignment: List[AssignedPartition] = [] + self.assignment_dict: Dict[str, AssignedPartition] = dict() + self.verbose: bool = False - def find_assignment(self, topic, partition): + def find_assignment(self, topic: str, partition: int) -> Optional["AssignedPartition"]: """ Find and return existing assignment based on topic and partition, or None on miss. """ skey = '%s %d' % (topic, partition) return self.assignment_dict.get(skey) - def send_records_consumed(self, immediate=False): + def send_records_consumed(self, immediate: bool=False) -> None: """ Send records_consumed, every 100 messages, on timeout, or if immediate is set. """ if self.consumed_msgs <= self.consumed_msgs_last_reported + (0 if immediate else 100): @@ -59,7 +62,8 @@ def send_records_consumed(self, immediate=False): if len(self.assignment) == 0: return - d = {'name': 'records_consumed', + SendDict = TypedDict("SendDict", {"name": Literal['records_consumed'], 'count': int, 'partitions': List[Dict] }) + d: SendDict = {'name': 'records_consumed', 'count': self.consumed_msgs - self.consumed_msgs_last_reported, 'partitions': []} @@ -71,16 +75,16 @@ def send_records_consumed(self, immediate=False): d['partitions'].append(a.to_dict()) a.min_offset = -1 - self.send(d) + self.send(cast(Dict, d)) self.consumed_msgs_last_reported = self.consumed_msgs - def send_assignment(self, evtype, partitions): + def send_assignment(self, evtype: str, partitions: List[TopicPartition]) -> None: """ Send assignment update, evtype is either 'assigned' or 'revoked' """ d = {'name': 'partitions_' + evtype, 'partitions': [{'topic': x.topic, 'partition': x.partition} for x in partitions]} self.send(d) - def on_assign(self, consumer, partitions): + def on_assign(self, consumer: Consumer, partitions: List[TopicPartition]) -> None: """ Rebalance on_assign callback """ old_assignment = self.assignment self.assignment = [AssignedPartition(p.topic, p.partition) for p in partitions] @@ -88,12 +92,13 @@ def on_assign(self, consumer, partitions): # minOffset even after a rebalance loop. for a in old_assignment: b = self.find_assignment(a.topic, a.partition) + assert b is not None b.min_offset = a.min_offset self.assignment_dict = {a.skey: a for a in self.assignment} self.send_assignment('assigned', partitions) - def on_revoke(self, consumer, partitions): + def on_revoke(self, consumer: Consumer, partitions: List[TopicPartition]) -> None: """ Rebalance on_revoke callback """ # Send final consumed records prior to rebalancing to make sure # latest consumed is in par with what is going to be committed. 
@@ -103,7 +108,7 @@ def on_revoke(self, consumer, partitions): self.assignment_dict = dict() self.send_assignment('revoked', partitions) - def on_commit(self, err, partitions): + def on_commit(self, err: Optional[KafkaError], partitions: List[TopicPartition]) -> None: """ Offsets Committed callback """ if err is not None and err.code() == KafkaError._NO_OFFSET: self.dbg('on_commit(): no offsets to commit') @@ -112,7 +117,7 @@ def on_commit(self, err, partitions): # Report consumed messages to make sure consumed position >= committed position self.send_records_consumed(immediate=True) - d = {'name': 'offsets_committed', + d: Dict[str, Any] = {'name': 'offsets_committed', 'offsets': []} if err is not None: @@ -134,7 +139,7 @@ def on_commit(self, err, partitions): self.send(d) - def do_commit(self, immediate=False, asynchronous=None): + def do_commit(self, immediate: bool=False, asynchronous: Optional[bool]=None) -> None: """ Commit every 1000 messages or whenever there is a consume timeout or immediate. """ if (self.use_auto_commit @@ -186,7 +191,7 @@ def do_commit(self, immediate=False, asynchronous=None): self.consumed_msgs_at_last_commit = self.consumed_msgs - def msg_consume(self, msg): + def msg_consume(self, msg: Message) -> None: """ Handle consumed message (or error event) """ if msg.error(): self.err('Consume failed: %s' % msg.error(), term=False) @@ -209,11 +214,12 @@ def msg_consume(self, msg): self.err('Received message on unassigned partition %s [%d] @ %d' % (msg.topic(), msg.partition(), msg.offset()), term=True) - a.consumed_msgs += 1 - if a.min_offset == -1: - a.min_offset = msg.offset() - if a.max_offset < msg.offset(): - a.max_offset = msg.offset() + else: + a.consumed_msgs += 1 + if a.min_offset == -1: + a.min_offset = msg.offset() + if a.max_offset < msg.offset(): + a.max_offset = msg.offset() self.consumed_msgs += 1 @@ -224,8 +230,7 @@ def msg_consume(self, msg): class AssignedPartition(object): """ Local state container for assigned partition. """ - - def __init__(self, topic, partition): + def __init__(self, topic: str, partition: int): super(AssignedPartition, self).__init__() self.topic = topic self.partition = partition @@ -234,7 +239,7 @@ def __init__(self, topic, partition): self.min_offset = -1 self.max_offset = 0 - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: """ Return a dict of this partition's state """ return {'topic': self.topic, 'partition': self.partition, 'minOffset': self.min_offset, 'maxOffset': self.max_offset} diff --git a/src/confluent_kafka/kafkatest/verifiable_producer.py b/src/confluent_kafka/kafkatest/verifiable_producer.py index a543e1d93..8decf1c71 100755 --- a/src/confluent_kafka/kafkatest/verifiable_producer.py +++ b/src/confluent_kafka/kafkatest/verifiable_producer.py @@ -17,8 +17,10 @@ import argparse import time +from typing import Dict, Optional from confluent_kafka import Producer, KafkaException -from verifiable_client import VerifiableClient +from confluent_kafka import KafkaError, Message +from .verifiable_client import VerifiableClient class VerifiableProducer(VerifiableClient): @@ -26,8 +28,7 @@ class VerifiableProducer(VerifiableClient): confluent-kafka-python backed VerifiableProducer class for use with Kafka's kafkatests client tests. 
""" - - def __init__(self, conf): + def __init__(self, conf: Dict): """ conf is a config dict passed to confluent_kafka.Producer() """ @@ -35,10 +36,11 @@ def __init__(self, conf): self.conf['on_delivery'] = self.dr_cb self.producer = Producer(**self.conf) self.num_acked = 0 - self.num_sent = 0 - self.num_err = 0 + self.num_sent: int = 0 + self.num_err: int = 0 + self.max_msgs: int = 0 - def dr_cb(self, err, msg): + def dr_cb(self, err: KafkaError, msg: Message) -> None: """ Per-message Delivery report callback. Called from poll() """ if err: self.num_err += 1 @@ -96,7 +98,7 @@ def dr_cb(self, err, msg): value_fmt = '%d' repeating_keys = args['repeating_keys'] - key_counter = 0 + key_counter: int = 0 if throughput > 0: delay = 1.0/throughput @@ -112,6 +114,7 @@ def dr_cb(self, err, msg): t_end = time.time() + delay while vp.run: + key: Optional[str] if repeating_keys != 0: key = '%d' % key_counter key_counter = (key_counter + 1) % repeating_keys @@ -119,7 +122,7 @@ def dr_cb(self, err, msg): key = None try: - vp.producer.produce(topic, value=(value_fmt % i), key=key, + vp.producer.produce(topic=topic, value=(value_fmt % i), key=key, timestamp=args.get('create_time', 0)) vp.num_sent += 1 except KafkaException as e: diff --git a/src/confluent_kafka/py.typed b/src/confluent_kafka/py.typed new file mode 100644 index 000000000..e69de29bb diff --git a/src/confluent_kafka/schema_registry/__init__.py b/src/confluent_kafka/schema_registry/__init__.py index e9a5a17d4..7fa9d0add 100644 --- a/src/confluent_kafka/schema_registry/__init__.py +++ b/src/confluent_kafka/schema_registry/__init__.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from confluent_kafka.serialization import SerializationContext from .schema_registry_client import (RegisteredSchema, Schema, SchemaRegistryClient, @@ -33,7 +34,7 @@ "record_subject_name_strategy"] -def topic_subject_name_strategy(ctx, record_name): +def topic_subject_name_strategy(ctx: SerializationContext, record_name: str) -> str: """ Constructs a subject name in the form of {topic}-key|value. @@ -44,10 +45,10 @@ def topic_subject_name_strategy(ctx, record_name): record_name (str): Record name. """ - return ctx.topic + "-" + ctx.field + return ctx.topic + "-" + str(ctx.field) -def topic_record_subject_name_strategy(ctx, record_name): +def topic_record_subject_name_strategy(ctx: SerializationContext, record_name: str) -> str: """ Constructs a subject name in the form of {topic}-{record_name}. @@ -61,7 +62,7 @@ def topic_record_subject_name_strategy(ctx, record_name): return ctx.topic + "-" + record_name -def record_subject_name_strategy(ctx, record_name): +def record_subject_name_strategy(ctx: SerializationContext, record_name: str) -> str: """ Constructs a subject name in the form of {record_name}. @@ -75,7 +76,7 @@ def record_subject_name_strategy(ctx, record_name): return record_name -def reference_subject_name_strategy(ctx, schema_ref): +def reference_subject_name_strategy(ctx: SerializationContext, schema_ref: SchemaReference) -> str: """ Constructs a subject reference name in the form of {reference name}. 
diff --git a/src/confluent_kafka/schema_registry/avro.py b/src/confluent_kafka/schema_registry/avro.py index 38ab25d56..b234ddb63 100644 --- a/src/confluent_kafka/schema_registry/avro.py +++ b/src/confluent_kafka/schema_registry/avro.py @@ -18,15 +18,19 @@ from io import BytesIO from json import loads from struct import pack, unpack +from typing import Any, Callable, Dict, Optional, Set, Tuple, Union, cast +from typing_extensions import Literal from fastavro import (parse_schema, schemaless_reader, schemaless_writer) +from confluent_kafka.schema_registry.schema_registry_client import SchemaRegistryClient + from . import (_MAGIC_BYTE, Schema, topic_subject_name_strategy) -from confluent_kafka.serialization import (Deserializer, +from confluent_kafka.serialization import (Deserializer, SerializationContext, SerializationError, Serializer) @@ -36,15 +40,14 @@ class _ContextStringIO(BytesIO): Wrapper to allow use of StringIO via 'with' constructs. """ - def __enter__(self): + def __enter__(self) -> "_ContextStringIO": return self - def __exit__(self, *args): + def __exit__(self, *args: Any) -> None: self.close() - return False -def _schema_loads(schema_str): +def _schema_loads(schema_str: str) -> Schema: """ Instantiate a Schema instance from a declaration string. @@ -67,7 +70,7 @@ def _schema_loads(schema_str): return Schema(schema_str, schema_type='AVRO') -def _resolve_named_schema(schema, schema_registry_client, named_schemas=None): +def _resolve_named_schema(schema: Schema, schema_registry_client: SchemaRegistryClient, named_schemas: Optional[Dict]=None) -> Dict: """ Resolves named schemas referenced by the provided schema recursively. :param schema: Schema to resolve named schemas for. @@ -180,7 +183,8 @@ class AvroSerializer(Serializer): 'use.latest.version': False, 'subject.name.strategy': topic_subject_name_strategy} - def __init__(self, schema_registry_client, schema_str, to_dict=None, conf=None): + def __init__(self, schema_registry_client: SchemaRegistryClient, schema_str: Union[str, Schema], + to_dict: Optional[Callable[[object, SerializationContext], Dict]]=None, conf: Optional[Dict]=None): if isinstance(schema_str, str): schema = _schema_loads(schema_str) elif isinstance(schema_str, Schema): @@ -189,8 +193,8 @@ def __init__(self, schema_registry_client, schema_str, to_dict=None, conf=None): raise TypeError('You must pass either schema string or schema object') self._registry = schema_registry_client - self._schema_id = None - self._known_subjects = set() + self._schema_id: Optional[int] = None + self._known_subjects: Set[str] = set() if to_dict is not None and not callable(to_dict): raise ValueError("to_dict must be callable with the signature " @@ -246,7 +250,7 @@ def __init__(self, schema_registry_client, schema_str, to_dict=None, conf=None): self._schema_name = schema_name self._parsed_schema = parsed_schema - def __call__(self, obj, ctx): + def __call__(self, obj: Any, ctx: SerializationContext) -> Optional[bytes]: """ Serializes an object to Avro binary format, prepending it with Confluent Schema Registry framing. @@ -269,6 +273,7 @@ def __call__(self, obj, ctx): if obj is None: return None + assert callable(self._subject_name_func) subject = self._subject_name_func(ctx, self._schema_name) if subject not in self._known_subjects: @@ -278,6 +283,7 @@ def __call__(self, obj, ctx): else: # Check to ensure this schema has been registered under subject_name. + assert isinstance(self._normalize_schemas, bool) if self._auto_register: # The schema name will always be the same. 
We can't however register # a schema without a subject so we set the schema_id here to handle @@ -343,7 +349,7 @@ class AvroDeserializer(Deserializer): __slots__ = ['_reader_schema', '_registry', '_from_dict', '_writer_schemas', '_return_record_name', '_schema', '_named_schemas'] - def __init__(self, schema_registry_client, schema_str=None, from_dict=None, return_record_name=False): + def __init__(self, schema_registry_client: SchemaRegistryClient, schema_str: Optional[str]=None, from_dict:Optional[Callable]=None, return_record_name: bool=False): schema = None if schema_str is not None: if isinstance(schema_str, str): @@ -355,9 +361,11 @@ def __init__(self, schema_registry_client, schema_str=None, from_dict=None, retu self._schema = schema self._registry = schema_registry_client - self._writer_schemas = {} + self._writer_schemas: Dict[int, Schema] = {} + self._named_schemas: Optional[Dict] + self._reader_schema: Optional[Dict] - if schema: + if self._schema: schema_dict = loads(self._schema.schema_str) self._named_schemas = _resolve_named_schema(self._schema, schema_registry_client) self._reader_schema = parse_schema(schema_dict, @@ -375,7 +383,7 @@ def __init__(self, schema_registry_client, schema_str=None, from_dict=None, retu if not isinstance(self._return_record_name, bool): raise ValueError("return_record_name must be a boolean value") - def __call__(self, data, ctx): + def __call__(self, data: Optional[bytes], ctx: Optional[SerializationContext]=None) -> Any: """ Deserialize Avro binary encoded data with Confluent Schema Registry framing to a dict, or object instance according to from_dict, if specified. @@ -404,9 +412,9 @@ def __call__(self, data, ctx): "Schema Registry serializer".format(len(data))) with _ContextStringIO(data) as payload: - magic, schema_id = unpack('>bI', payload.read(5)) + magic, schema_id = cast(Tuple[bytes, int], unpack('>bI', payload.read(5))) if magic != _MAGIC_BYTE: - raise SerializationError("Unexpected magic byte {}. This message " + raise SerializationError("Unexpected magic byte {!r}. 
This message " "was not produced with a Confluent " "Schema Registry serializer".format(magic)) diff --git a/src/confluent_kafka/schema_registry/error.py b/src/confluent_kafka/schema_registry/error.py index 77feaeee6..181dd0b9e 100644 --- a/src/confluent_kafka/schema_registry/error.py +++ b/src/confluent_kafka/schema_registry/error.py @@ -41,15 +41,15 @@ class SchemaRegistryError(Exception): """ # noqa: E501 UNKNOWN = -1 - def __init__(self, http_status_code, error_code, error_message): + def __init__(self, http_status_code: int, error_code: int, error_message: str): self.http_status_code = http_status_code self.error_code = error_code self.error_message = error_message - def __repr__(self): + def __repr__(self) -> str: return str(self) - def __str__(self): + def __str__(self) -> str: return "{} (HTTP status code {}, SR code {})".format(self.error_message, self.http_status_code, self.error_code) diff --git a/src/confluent_kafka/schema_registry/json_schema.py b/src/confluent_kafka/schema_registry/json_schema.py index 92fefc6f3..8d9baeb19 100644 --- a/src/confluent_kafka/schema_registry/json_schema.py +++ b/src/confluent_kafka/schema_registry/json_schema.py @@ -19,13 +19,15 @@ import json import struct +from typing import Any, Callable, Dict, Optional, Set, Union, cast from jsonschema import validate, ValidationError, RefResolver from confluent_kafka.schema_registry import (_MAGIC_BYTE, Schema, topic_subject_name_strategy) -from confluent_kafka.serialization import (SerializationError, +from .schema_registry_client import SchemaRegistryClient +from confluent_kafka.serialization import (SerializationContext, SerializationError, Deserializer, Serializer) @@ -35,15 +37,14 @@ class _ContextStringIO(BytesIO): Wrapper to allow use of StringIO via 'with' constructs. """ - def __enter__(self): + def __enter__(self) -> "_ContextStringIO": return self - def __exit__(self, *args): + def __exit__(self, *args: Any) -> None: self.close() - return False -def _resolve_named_schema(schema, schema_registry_client, named_schemas=None): +def _resolve_named_schema(schema: Schema, schema_registry_client: SchemaRegistryClient, named_schemas: Optional[Dict]=None) -> Dict: """ Resolves named schemas referenced by the provided schema recursively. :param schema: Schema to resolve named schemas for. 
@@ -160,7 +161,7 @@ class JSONSerializer(Serializer): 'use.latest.version': False, 'subject.name.strategy': topic_subject_name_strategy} - def __init__(self, schema_str, schema_registry_client, to_dict=None, conf=None): + def __init__(self, schema_str: str, schema_registry_client: SchemaRegistryClient, to_dict: Optional[Callable[[object, SerializationContext], Dict]]=None, conf: Optional[Dict]=None): self._are_references_provided = False if isinstance(schema_str, str): self._schema = Schema(schema_str, schema_type="JSON") @@ -171,8 +172,8 @@ def __init__(self, schema_str, schema_registry_client, to_dict=None, conf=None): raise TypeError('You must pass either str or Schema') self._registry = schema_registry_client - self._schema_id = None - self._known_subjects = set() + self._schema_id: Optional[int] = None + self._known_subjects: Set[str] = set() if to_dict is not None and not callable(to_dict): raise ValueError("to_dict must be callable with the signature " @@ -198,7 +199,7 @@ def __init__(self, schema_str, schema_registry_client, to_dict=None, conf=None): if self._use_latest_version and self._auto_register: raise ValueError("cannot enable both use.latest.version and auto.register.schemas") - self._subject_name_func = conf_copy.pop('subject.name.strategy') + self._subject_name_func: Callable = cast(Callable, conf_copy.pop('subject.name.strategy')) if not callable(self._subject_name_func): raise ValueError("subject.name.strategy must be callable") @@ -214,7 +215,7 @@ def __init__(self, schema_str, schema_registry_client, to_dict=None, conf=None): self._schema_name = schema_name self._parsed_schema = schema_dict - def __call__(self, obj, ctx): + def __call__(self, obj: Any, ctx: SerializationContext) -> Optional[bytes]: """ Serializes an object to JSON, prepending it with Confluent Schema Registry framing. @@ -245,6 +246,7 @@ def __call__(self, obj, ctx): else: # Check to ensure this schema has been registered under subject_name. + assert isinstance(self._normalize_schemas, bool) if self._auto_register: # The schema name will always be the same. We can't however register # a schema without a subject so we set the schema_id here to handle @@ -302,7 +304,7 @@ class JSONDeserializer(Deserializer): __slots__ = ['_parsed_schema', '_from_dict', '_registry', '_are_references_provided', '_schema'] - def __init__(self, schema_str, from_dict=None, schema_registry_client=None): + def __init__(self, schema_str: Union[Schema, str], from_dict: Optional[Callable]=None, schema_registry_client: Optional[SchemaRegistryClient]=None): self._are_references_provided = False if isinstance(schema_str, str): schema = Schema(schema_str, schema_type="JSON") @@ -325,7 +327,7 @@ def __init__(self, schema_str, from_dict=None, schema_registry_client=None): self._from_dict = from_dict - def __call__(self, data, ctx): + def __call__(self, data: Optional[bytes], ctx: Optional[SerializationContext]=None) -> Any: """ Deserialize a JSON encoded record with Confluent Schema Registry framing to a dict, or object instance according to from_dict if from_dict is specified. 
@@ -364,7 +366,7 @@ def __call__(self, data, ctx): try: if self._are_references_provided: - named_schemas = _resolve_named_schema(self._schema, self._registry) + named_schemas = _resolve_named_schema(self._schema, cast(SchemaRegistryClient, self._registry)) validate(instance=obj_dict, schema=self._parsed_schema, resolver=RefResolver(self._parsed_schema.get('$id'), self._parsed_schema, diff --git a/src/confluent_kafka/schema_registry/protobuf.py b/src/confluent_kafka/schema_registry/protobuf.py index b1de06799..d1ad234e1 100644 --- a/src/confluent_kafka/schema_registry/protobuf.py +++ b/src/confluent_kafka/schema_registry/protobuf.py @@ -16,28 +16,32 @@ # limitations under the License. import io import sys import base64 import struct +from typing import Any, Callable, Deque, Dict, List, Optional, Set, cast import warnings from collections import deque +import six -from google.protobuf.message import DecodeError +from google.protobuf.message import DecodeError, Message from google.protobuf.message_factory import MessageFactory +from google.protobuf.descriptor import Descriptor, FileDescriptor +from google.protobuf.reflection import GeneratedProtocolMessageType from . import (_MAGIC_BYTE, reference_subject_name_strategy, topic_subject_name_strategy,) from .schema_registry_client import (Schema, - SchemaReference) -from confluent_kafka.serialization import SerializationError + SchemaReference, SchemaRegistryClient) +from confluent_kafka.serialization import SerializationContext, SerializationError # Convert an int to bytes (inverse of ord()) # Python3.chr() -> Unicode # Python2.chr() -> str(alias for bytes) -if sys.version > '3': - def _bytes(v): +if six.PY3: + def _bytes(v: int) -> bytes: """ Convert int to bytes @@ -46,7 +51,7 @@ def _bytes(v): """ return bytes((v,)) else: - def _bytes(v): + def _bytes(v: int) -> str: """ Convert int to bytes @@ -61,15 +66,14 @@ class _ContextStringIO(io.BytesIO): Wrapper to allow use of StringIO via 'with' constructs. """ - def __enter__(self): + def __enter__(self) -> "_ContextStringIO": return self - def __exit__(self, *args): + def __exit__(self, *args: Any) -> None: self.close() - return False -def _create_index_array(msg_desc): +def _create_index_array(msg_desc: Descriptor) -> List[int]: """ Creates an index array specifying the location of msg_desc in the referenced FileDescriptor. @@ -84,14 +88,14 @@ def _create_index_array(msg_desc): ValueError: If the message descriptor is malformed. """ - msg_idx = deque() + msg_idx: Deque[int] = deque() # Walk the nested MessageDescriptor tree up to the root.
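To make the index-array walk above concrete, here is a self-contained toy sketch; the FakeDescriptor class and index_array helper are hypothetical stand-ins for protobuf descriptors, not library code:

from collections import deque
from typing import Deque, List, Optional


class FakeDescriptor:
    """Stand-in for google.protobuf.descriptor.Descriptor (illustration only)."""

    def __init__(self, name: str, containing_type: Optional["FakeDescriptor"] = None) -> None:
        self.name = name
        self.containing_type = containing_type
        self.nested_types: List["FakeDescriptor"] = []


def index_array(msg_desc: FakeDescriptor, top_level: List[FakeDescriptor]) -> List[int]:
    # Walk up the containing_type chain, recording each child's position among its
    # parent's nested_types, then finish with the position among the top-level messages.
    msg_idx: Deque[int] = deque()
    current = msg_desc
    while current.containing_type is not None:
        parent = current.containing_type
        msg_idx.appendleft(parent.nested_types.index(current))
        current = parent
    msg_idx.appendleft(top_level.index(current))
    return list(msg_idx)


# Outer is top-level message 0; Inner is nested type 1 inside Outer, giving [0, 1].
outer = FakeDescriptor("Outer")
first = FakeDescriptor("First", containing_type=outer)
inner = FakeDescriptor("Inner", containing_type=outer)
outer.nested_types = [first, inner]
print(index_array(inner, [outer]))   # [0, 1]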
current = msg_desc found = False while current.containing_type is not None: previous = current - current = previous.containing_type + current = cast(Descriptor, previous.containing_type) # find child's position for idx, node in enumerate(current.nested_types): if node == previous: @@ -114,7 +118,7 @@ def _create_index_array(msg_desc): return list(msg_idx) -def _schema_to_str(file_descriptor): +def _schema_to_str(file_descriptor: FileDescriptor) -> str: """ Base64 encode a FileDescriptor @@ -245,7 +249,7 @@ class ProtobufSerializer(object): 'use.deprecated.format': False, } - def __init__(self, msg_type, schema_registry_client, conf=None): + def __init__(self, msg_type: GeneratedProtocolMessageType, schema_registry_client: SchemaRegistryClient, conf: Optional[Dict]=None): if conf is None or 'use.deprecated.format' not in conf: raise RuntimeError( @@ -303,17 +307,17 @@ def __init__(self, msg_type, schema_registry_client, conf=None): .format(", ".join(conf_copy.keys()))) self._registry = schema_registry_client - self._schema_id = None - self._known_subjects = set() + self._schema_id: Optional[int] = None + self._known_subjects: Set[str] = set() self._msg_class = msg_type - descriptor = msg_type.DESCRIPTOR + descriptor = msg_type.DESCRIPTOR # type:ignore[attr-defined] self._index_array = _create_index_array(descriptor) self._schema = Schema(_schema_to_str(descriptor.file), schema_type='PROTOBUF') @staticmethod - def _write_varint(buf, val, zigzag=True): + def _write_varint(buf: io.BytesIO, val: int, zigzag: bool=True) -> None: """ Writes val to buf, either using zigzag or uvarint encoding. @@ -332,7 +336,7 @@ def _write_varint(buf, val, zigzag=True): buf.write(_bytes(val)) @staticmethod - def _encode_varints(buf, ints, zigzag=True): + def _encode_varints(buf: io.BytesIO, ints: List[int], zigzag: bool=True) -> None: """ Encodes each int as a uvarint onto buf @@ -353,7 +357,7 @@ def _encode_varints(buf, ints, zigzag=True): for value in ints: ProtobufSerializer._write_varint(buf, value, zigzag=zigzag) - def _resolve_dependencies(self, ctx, file_desc): + def _resolve_dependencies(self, ctx: SerializationContext, file_desc: FileDescriptor) -> List[SchemaReference]: """ Resolves and optionally registers schema references recursively. @@ -363,11 +367,12 @@ def _resolve_dependencies(self, ctx, file_desc): file_desc (FileDescriptor): file descriptor to traverse. """ - schema_refs = [] + schema_refs: List[SchemaReference] = [] for dep in file_desc.dependencies: if self._skip_known_types and dep.name.startswith("google/protobuf/"): continue dep_refs = self._resolve_dependencies(ctx, dep) + assert callable(self._ref_reference_subject_func) subject = self._ref_reference_subject_func(ctx, dep) schema = Schema(_schema_to_str(dep), references=dep_refs, @@ -382,7 +387,7 @@ def _resolve_dependencies(self, ctx, file_desc): reference.version)) return schema_refs - def __call__(self, message, ctx): + def __call__(self, message: Message, ctx: SerializationContext) -> Optional[bytes]: """ Serializes an instance of a class derived from Protobuf Message, and prepends it with Confluent Schema Registry framing. 
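The _write_varint/_encode_varints pair above implements standard zigzag/uvarint framing. The following is a standalone sketch of that encoding, not the library's exact code, and it assumes the index array is framed as a count followed by the indexes:

import io
from typing import List


def write_varint(buf: io.BytesIO, val: int, zigzag: bool = True) -> None:
    # 64-bit zigzag maps signed ints to unsigned so small negatives stay small,
    # then 7 bits are emitted per byte with the high bit set while more bytes follow.
    if zigzag:
        val = (val << 1) ^ (val >> 63)
    while (val & ~0x7F) != 0:
        buf.write(bytes(((val & 0x7F) | 0x80,)))
        val >>= 7
    buf.write(bytes((val,)))


def encode_varints(buf: io.BytesIO, ints: List[int], zigzag: bool = True) -> None:
    # Assumed framing: the number of indexes first, then each index.
    write_varint(buf, len(ints), zigzag=zigzag)
    for value in ints:
        write_varint(buf, value, zigzag=zigzag)


buf = io.BytesIO()
encode_varints(buf, [1, 2])
print(buf.getvalue().hex())   # 040204: count 2, then indexes 1 and 2, zigzag encoded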
@@ -408,6 +413,7 @@ def __call__(self, message, ctx): raise ValueError("message must be of type {} not {}" .format(self._msg_class, type(message))) + assert callable(self._subject_name_func) subject = self._subject_name_func(ctx, message.DESCRIPTOR.full_name) @@ -420,6 +426,7 @@ def __call__(self, message, ctx): self._schema.references = self._resolve_dependencies( ctx, message.DESCRIPTOR.file) + assert isinstance(self._normalize_schemas, bool) if self._auto_register: self._schema_id = self._registry.register_schema(subject, self._schema, @@ -479,7 +486,7 @@ class ProtobufDeserializer(object): 'use.deprecated.format': False, } - def __init__(self, message_type, conf=None): + def __init__(self, message_type: Any, conf: Optional[Dict]=None): # Require use.deprecated.format to be explicitly configured # during a transitionary period since old/new format are @@ -513,7 +520,7 @@ def __init__(self, message_type, conf=None): self._msg_class = MessageFactory().GetPrototype(descriptor) @staticmethod - def _decode_varint(buf, zigzag=True): + def _decode_varint(buf: io.BytesIO, zigzag: bool=True) -> int: """ Decodes a single varint from a buffer. @@ -548,7 +555,7 @@ def _decode_varint(buf, zigzag=True): raise EOFError("Unexpected EOF while reading index") @staticmethod - def _read_byte(buf): + def _read_byte(buf: io.BytesIO) -> int: """ Read one byte from buf as an int. @@ -565,7 +572,7 @@ def _read_byte(buf): return ord(i) @staticmethod - def _read_index_array(buf, zigzag=True): + def _read_index_array(buf: io.BytesIO, zigzag: bool=True) -> List[int]: """ Read an index array from buf that specifies the message descriptor of interest in the file descriptor. @@ -591,7 +598,7 @@ def _read_index_array(buf, zigzag=True): return msg_index - def __call__(self, data, ctx): + def __call__(self, data: bytes, ctx: Optional[SerializationContext]) -> Any: """ Deserialize a serialized protobuf message with Confluent Schema Registry framing. diff --git a/src/confluent_kafka/schema_registry/schema_registry_client.py b/src/confluent_kafka/schema_registry/schema_registry_client.py index c414c7ee1..1ececfea0 100644 --- a/src/confluent_kafka/schema_registry/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/schema_registry_client.py @@ -17,34 +17,22 @@ # import json import logging -import urllib +from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast from collections import defaultdict from threading import Lock +from typing_extensions import TypedDict from requests import (Session, utils) from .error import SchemaRegistryError +import six +import six.moves.urllib.parse as urllib +string_type = six.string_types[0] -# TODO: consider adding `six` dependency or employing a compat file -# Python 2.7 is officially EOL so compatibility issue will be come more the norm. -# We need a better way to handle these issues. -# Six is one possibility but the compat file pattern used by requests -# is also quite nice. 
-# -# six: https://pypi.org/project/six/ -# compat file : https://github.com/psf/requests/blob/master/requests/compat.py -try: - string_type = basestring # noqa - - def _urlencode(value): - return urllib.quote(value, safe='') -except NameError: - string_type = str - - def _urlencode(value): - return urllib.parse.quote(value, safe='') +def _urlencode(value: str) -> str: + return urllib.quote(value, safe='') log = logging.getLogger(__name__) VALID_AUTH_PROVIDERS = ['URL', 'USER_INFO'] @@ -60,7 +48,7 @@ class _RestClient(object): conf (dict): Dictionary containing _RestClient configuration """ - def __init__(self, conf): + def __init__(self, conf: Dict): self.session = Session() # copy dict to avoid mutating the original @@ -105,7 +93,7 @@ def __init__(self, conf): " remove basic.auth.user.info from the" " configuration") - userinfo = tuple(conf_copy.pop('basic.auth.user.info', '').split(':')) + userinfo = cast(Tuple[str, str], tuple(conf_copy.pop('basic.auth.user.info', '').split(':'))) if len(userinfo) != 2: raise ValueError("basic.auth.user.info must be in the form" @@ -118,22 +106,22 @@ def __init__(self, conf): raise ValueError("Unrecognized properties: {}" .format(", ".join(conf_copy.keys()))) - def _close(self): + def _close(self) -> None: self.session.close() - def get(self, url, query=None): + def get(self, url: str, query: Optional[Dict]=None) -> Any: return self.send_request(url, method='GET', query=query) - def post(self, url, body, **kwargs): + def post(self, url: str, body: Any) -> Any: return self.send_request(url, method='POST', body=body) - def delete(self, url): + def delete(self, url: str) -> Any: return self.send_request(url, method='DELETE') - def put(self, url, body=None): + def put(self, url: str, body: Any=None) -> Any: return self.send_request(url, method='PUT', body=body) - def send_request(self, url, method, body=None, query=None): + def send_request(self, url: str, method: str, body: Any=None, query: Optional[Dict]=None) -> Any: """ Sends HTTP request to the SchemaRegistry. @@ -148,7 +136,7 @@ def send_request(self, url, method, body=None, query=None): method (str): HTTP method - body (str): Request content + body (object): Request content query (dict): Query params to attach to the URL @@ -191,13 +179,13 @@ class _SchemaCache(object): known subject membership. """ - def __init__(self): + def __init__(self) -> None: self.lock = Lock() - self.schema_id_index = {} - self.schema_index = {} - self.subject_schemas = defaultdict(set) + self.schema_id_index: Dict[int, Schema] = {} + self.schema_index: Dict[Schema, int] = {} + self.subject_schemas: Dict[str, Set[Schema]] = defaultdict(set) - def set(self, schema_id, schema, subject_name=None): + def set(self, schema_id: int, schema: "Schema", subject_name: Optional[str]=None) -> None: """ Add a Schema identified by schema_id to the cache. @@ -218,7 +206,7 @@ def set(self, schema_id, schema, subject_name=None): if subject_name is not None: self.subject_schemas[subject_name].add(schema) - def get_schema(self, schema_id): + def get_schema(self, schema_id: int) -> Optional["Schema"]: """ Get the schema instance associated with schema_id from the cache. @@ -231,7 +219,7 @@ def get_schema(self, schema_id): return self.schema_id_index.get(schema_id, None) - def get_schema_id_by_subject(self, subject, schema): + def get_schema_id_by_subject(self, subject: str, schema: "Schema") -> Optional[int]: """ Get the schema_id associated with this schema registered under subject. 
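Because the six-based _urlencode above quotes with safe='', even '/' in a subject name is escaped before it is embedded in a URL path. A small hedged sketch (the helper name is illustrative):

from six.moves.urllib.parse import quote


def urlencode_subject(value: str) -> str:
    # Mirrors the module-level helper: escape everything, including '/'.
    return quote(value, safe='')


print(urlencode_subject("my-topic/value"))   # my-topic%2Fvalue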
@@ -248,6 +236,8 @@ def get_schema_id_by_subject(self, subject, schema): if schema in self.subject_schemas[subject]: return self.schema_index.get(schema, None) + return None + class SchemaRegistryClient(object): """ @@ -289,18 +279,18 @@ class SchemaRegistryClient(object): `Confluent Schema Registry documentation `_ """ # noqa: E501 - def __init__(self, conf): + def __init__(self, conf: Dict): self._rest_client = _RestClient(conf) self._cache = _SchemaCache() - def __enter__(self): + def __enter__(self) -> "SchemaRegistryClient": return self - def __exit__(self, *args): + def __exit__(self, *args: Any) -> None: if self._rest_client is not None: self._rest_client._close() - def register_schema(self, subject_name, schema, normalize_schemas=False): + def register_schema(self, subject_name: str, schema: "Schema", normalize_schemas: bool=False) -> int: """ Registers a schema under ``subject_name``. @@ -324,7 +314,7 @@ def register_schema(self, subject_name, schema, normalize_schemas=False): if schema_id is not None: return schema_id - request = {'schema': schema.schema_str} + request: Dict[str, Any] = {'schema': schema.schema_str} # CP 5.5 adds new fields (for JSON and Protobuf). if len(schema.references) > 0 or schema.schema_type != 'AVRO': @@ -338,12 +328,12 @@ def register_schema(self, subject_name, schema, normalize_schemas=False): 'subjects/{}/versions?normalize={}'.format(_urlencode(subject_name), normalize_schemas), body=request) - schema_id = response['id'] + schema_id = cast(int, response['id']) self._cache.set(schema_id, schema, subject_name) return schema_id - def get_schema(self, schema_id): + def get_schema(self, schema_id: int) -> "Schema": """ Fetches the schema associated with ``schema_id`` from the Schema Registry. The result is cached so subsequent attempts will not @@ -379,7 +369,7 @@ def get_schema(self, schema_id): return schema - def lookup_schema(self, subject_name, schema, normalize_schemas=False): + def lookup_schema(self, subject_name: str, schema: "Schema", normalize_schemas: bool=False) -> "RegisteredSchema": """ Returns ``schema`` registration information for ``subject``. @@ -398,7 +388,7 @@ def lookup_schema(self, subject_name, schema, normalize_schemas=False): `POST Subject API Reference `_ """ # noqa: E501 - request = {'schema': schema.schema_str} + request: Dict[str, Any] = {'schema': schema.schema_str} # CP 5.5 adds new fields (for JSON and Protobuf). if len(schema.references) > 0 or schema.schema_type != 'AVRO': @@ -426,7 +416,7 @@ def lookup_schema(self, subject_name, schema, normalize_schemas=False): subject=response['subject'], version=response['version']) - def get_subjects(self): + def get_subjects(self) -> List[str]: """ List all subjects registered with the Schema Registry @@ -442,7 +432,7 @@ def get_subjects(self): return self._rest_client.get('subjects') - def delete_subject(self, subject_name, permanent=False): + def delete_subject(self, subject_name: str, permanent: bool=False) -> List[int]: """ Deletes the specified subject and its associated compatibility level if registered. 
It is recommended to use this API only when a topic needs @@ -471,7 +461,7 @@ def delete_subject(self, subject_name, permanent=False): return list - def get_latest_version(self, subject_name): + def get_latest_version(self, subject_name: str) -> "RegisteredSchema": """ Retrieves latest registered version for subject @@ -505,7 +495,7 @@ def get_latest_version(self, subject_name): subject=response['subject'], version=response['version']) - def get_version(self, subject_name, version): + def get_version(self, subject_name: str, version: int) -> "RegisteredSchema": """ Retrieves a specific schema registered under ``subject_name``. @@ -541,7 +531,7 @@ def get_version(self, subject_name, version): subject=response['subject'], version=response['version']) - def get_versions(self, subject_name): + def get_versions(self, subject_name: str) -> List[int]: """ Get a list of all versions registered with this subject. @@ -560,7 +550,7 @@ def get_versions(self, subject_name): return self._rest_client.get('subjects/{}/versions'.format(_urlencode(subject_name))) - def delete_version(self, subject_name, version): + def delete_version(self, subject_name: str, version: int) -> int: """ Deletes a specific version registered to ``subject_name``. @@ -584,7 +574,7 @@ def delete_version(self, subject_name, version): version)) return response - def set_compatibility(self, subject_name=None, level=None): + def set_compatibility(self, subject_name: Optional[str]=None, level: Optional[str]=None) -> str: """ Update global or subject level compatibility level. @@ -616,7 +606,7 @@ def set_compatibility(self, subject_name=None, level=None): .format(_urlencode(subject_name)), body={'compatibility': level.upper()}) - def get_compatibility(self, subject_name=None): + def get_compatibility(self, subject_name: Optional[str]=None) -> str: """ Get the current compatibility level. 
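A hedged end-to-end sketch of the annotated client methods shown in this file; the endpoint and subject are placeholders and a running Schema Registry is assumed:

from confluent_kafka.schema_registry import Schema, SchemaRegistryClient

client = SchemaRegistryClient({'url': 'http://localhost:8081'})   # placeholder endpoint

schema = Schema('{"type": "string"}', schema_type='AVRO')
schema_id = client.register_schema('example-value', schema)       # -> int
latest = client.get_latest_version('example-value')               # -> RegisteredSchema
versions = client.get_versions('example-value')                   # -> List[int]
level = client.get_compatibility()                                # -> str, e.g. 'BACKWARD'
print(schema_id, latest.version, versions, level)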
@@ -642,7 +632,7 @@ def get_compatibility(self, subject_name=None): result = self._rest_client.get(url) return result['compatibilityLevel'] - def test_compatibility(self, subject_name, schema, version="latest"): + def test_compatibility(self, subject_name: str, schema: "Schema", version: Union[int, str]="latest") -> bool: """Test the compatibility of a candidate schema for a given subject and version Args: @@ -662,7 +652,9 @@ def test_compatibility(self, subject_name, schema, version="latest"): `POST Test Compatibility API Reference `_ """ # noqa: E501 - request = {"schema": schema.schema_str} + Reference = TypedDict("Reference", {'name': str, 'subject': str, 'version': int}) + Request = TypedDict('Request', {'schema': str, 'schemaType': str, 'references': List[Reference]}, total=False) + request: Request = {"schema": schema.schema_str} if schema.schema_type != "AVRO": request['schemaType'] = schema.schema_type @@ -692,7 +684,7 @@ class Schema(object): """ __slots__ = ['schema_str', 'schema_type', 'references', '_hash'] - def __init__(self, schema_str, schema_type, references=[]): + def __init__(self, schema_str: str, schema_type: str, references: List["SchemaReference"] =[]): super(Schema, self).__init__() self.schema_str = schema_str @@ -700,11 +692,12 @@ def __init__(self, schema_str, schema_type, references=[]): self.references = references self._hash = hash(schema_str) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: + assert isinstance(other, Schema) return all([self.schema_str == other.schema_str, self.schema_type == other.schema_type]) - def __hash__(self): + def __hash__(self) -> int: return self._hash @@ -725,7 +718,7 @@ class RegisteredSchema(object): version (int): Version of this subject this schema is registered to """ - def __init__(self, schema_id, schema, subject, version): + def __init__(self, schema_id: int, schema: Schema, subject: str, version: int): self.schema_id = schema_id self.schema = schema self.subject = subject @@ -748,7 +741,7 @@ class SchemaReference(object): version (int): This Schema's version """ - def __init__(self, name, subject, version): + def __init__(self, name: str, subject: str, version: int): super(SchemaReference, self).__init__() self.name = name self.subject = subject diff --git a/src/confluent_kafka/serialization/__init__.py b/src/confluent_kafka/serialization/__init__.py index 13cfc1dd6..daead5abe 100644 --- a/src/confluent_kafka/serialization/__init__.py +++ b/src/confluent_kafka/serialization/__init__.py @@ -16,6 +16,7 @@ # limitations under the License. # import struct as _struct +from typing import Any, List, Optional, Tuple from confluent_kafka.error import KafkaException __all__ = ['Deserializer', @@ -60,7 +61,7 @@ class SerializationContext(object): headers (list): List of message header tuples. Defaults to None. """ - def __init__(self, topic, field, headers=None): + def __init__(self, topic: str, field: str, headers: Optional[List[Tuple[str, str]]]=None): self.topic = topic self.field = field self.headers = headers @@ -107,9 +108,9 @@ class Serializer(object): - unicode(encoding) """ - __slots__ = [] + __slots__: List[str] = [] - def __call__(self, obj, ctx=None): + def __call__(self, obj: Any, ctx: SerializationContext) -> Optional[bytes]: """ Converts obj to bytes. 
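The inline TypedDict used in test_compatibility above is worth calling out as a pattern: with total=False every key is optional, so the request payload can be built up incrementally while mypy still checks key names and value types. A minimal sketch of the same pattern, assuming typing_extensions is installed as the import above implies:

from typing import List

from typing_extensions import TypedDict

Reference = TypedDict('Reference', {'name': str, 'subject': str, 'version': int})
Request = TypedDict('Request', {'schema': str, 'schemaType': str, 'references': List[Reference]}, total=False)

request: Request = {'schema': '{"type": "string"}'}
request['schemaType'] = 'JSON'        # OK: declared key with a str value
# request['unknown'] = 'x'            # mypy would flag an undeclared key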
@@ -164,9 +165,9 @@ class Deserializer(object): - unicode(encoding) """ - __slots__ = [] + __slots__: List[str] = [] - def __call__(self, value, ctx=None): + def __call__(self, value: bytes, ctx: Optional[SerializationContext]=None) -> Any: """ Convert bytes to object @@ -194,8 +195,7 @@ class DoubleSerializer(Serializer): `DoubleSerializer Javadoc `_ """ # noqa: E501 - - def __call__(self, obj, ctx=None): + def __call__(self, obj: Optional[float], ctx: Optional[SerializationContext]=None) -> Any: """ Args: obj (object): object to be serialized @@ -230,7 +230,7 @@ class DoubleDeserializer(Deserializer): `DoubleDeserializer Javadoc `_ """ # noqa: E501 - def __call__(self, value, ctx=None): + def __call__(self, value: bytes, ctx: Optional[SerializationContext]=None) -> Optional[float]: """ Deserializes float from IEEE 764 binary64 bytes. @@ -264,7 +264,7 @@ class IntegerSerializer(Serializer): `IntegerSerializer Javadoc `_ """ # noqa: E501 - def __call__(self, obj, ctx=None): + def __call__(self, obj: Any, ctx: Optional[SerializationContext]=None) -> Optional[bytes]: """ Serializes int as int32 bytes. @@ -301,7 +301,7 @@ class IntegerDeserializer(Deserializer): `IntegerDeserializer Javadoc `_ """ # noqa: E501 - def __call__(self, value, ctx=None): + def __call__(self, value: bytes, ctx: Optional[SerializationContext]=None) -> Optional[int]: """ Deserializes int from int32 bytes. @@ -343,10 +343,10 @@ class StringSerializer(Serializer): `StringSerializer Javadoc `_ """ # noqa: E501 - def __init__(self, codec='utf_8'): + def __init__(self, codec: str='utf_8'): self.codec = codec - def __call__(self, obj, ctx=None): + def __call__(self, obj: Any, ctx: Optional[SerializationContext]=None) -> Optional[bytes]: """ Serializes a str(py2:unicode) to bytes. @@ -389,10 +389,10 @@ class StringDeserializer(Deserializer): `StringDeserializer Javadoc `_ """ # noqa: E501 - def __init__(self, codec='utf_8'): + def __init__(self, codec: str='utf_8'): self.codec = codec - def __call__(self, value, ctx=None): + def __call__(self, value: bytes, ctx: Optional[SerializationContext]=None) -> Optional[str]: """ Serializes unicode to bytes per the configured codec. Defaults to ``utf_8``. diff --git a/src/confluent_kafka/serializing_producer.py b/src/confluent_kafka/serializing_producer.py index 3b3ff82b0..a34da0cc5 100644 --- a/src/confluent_kafka/serializing_producer.py +++ b/src/confluent_kafka/serializing_producer.py @@ -16,7 +16,9 @@ # limitations under the License. # +from typing import Callable, Dict, List, Optional, Tuple from confluent_kafka.cimpl import Producer as _ProducerImpl + from .serialization import (MessageField, SerializationContext) from .error import (KeySerializationError, @@ -66,7 +68,7 @@ class SerializingProducer(_ProducerImpl): conf (producer): SerializingProducer configuration. """ # noqa E501 - def __init__(self, conf): + def __init__(self, conf: Dict): conf_copy = conf.copy() self._key_serializer = conf_copy.pop('key.serializer', None) @@ -74,8 +76,8 @@ def __init__(self, conf): super(SerializingProducer, self).__init__(conf_copy) - def produce(self, topic, key=None, value=None, partition=-1, - on_delivery=None, timestamp=0, headers=None): + def produce(self, topic: str, key: Optional[object]=None, value: Optional[object]=None, partition: int=-1, # type:ignore[override] + on_delivery: Optional[Callable]=None, timestamp: float=0, headers: Optional[List[Tuple[str, str]]]=None) -> None: # type:ignore[override] """ Produce a message. 
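A quick hedged round-trip through the primitive serializers whose __call__ signatures are annotated above; ctx can be omitted because these classes keep an Optional SerializationContext default:

from confluent_kafka.serialization import (DoubleDeserializer, DoubleSerializer,
                                           IntegerDeserializer, IntegerSerializer,
                                           StringDeserializer, StringSerializer)

# Each pair packs and unpacks the same wire format, so values round-trip exactly.
assert DoubleDeserializer()(DoubleSerializer()(3.5)) == 3.5
assert IntegerDeserializer()(IntegerSerializer()(42)) == 42
assert StringDeserializer()(StringSerializer('utf_8')('hello')) == 'hello'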
@@ -139,7 +141,7 @@ def produce(self, topic, key=None, value=None, partition=-1, except Exception as se: raise ValueSerializationError(se) - super(SerializingProducer, self).produce(topic, value, key, + super(SerializingProducer, self).produce(topic, value=value, key=key, headers=headers, partition=partition, timestamp=timestamp, diff --git a/tests/avro/mock_registry.py b/tests/avro/mock_registry.py index b7e0ab9a2..dd1e87fc8 100644 --- a/tests/avro/mock_registry.py +++ b/tests/avro/mock_registry.py @@ -51,7 +51,7 @@ def log_message(self, format, *args): class MockServer(HTTPSERVER.HTTPServer, object): - def __init__(self, *args, **kwargs): + def __init__(self, *args: object, **kwargs: object): super(MockServer, self).__init__(*args, **kwargs) self.counts = {} self.registry = MockSchemaRegistryClient() diff --git a/tests/requirements.txt b/tests/requirements.txt index 120daf4ac..af97ea976 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -5,6 +5,13 @@ pytest-timeout requests-mock trivup>=0.8.3 fastavro -avro>=1.11.1,<2 +avro>=1.11.2,<2 jsonschema protobuf + +# Cap the version to avoid issues with newer editions. Should be periodically updated! +mypy<=0.991 +types-protobuf +types-jsonschema +types-requests +types-six \ No newline at end of file diff --git a/tools/smoketest.sh b/tools/smoketest.sh index acfb4ac9a..ea1afa840 100755 --- a/tools/smoketest.sh +++ b/tools/smoketest.sh @@ -93,6 +93,9 @@ for py in 3.8 ; do echo "$0: Running unit tests" pytest + echo "$0: Running type checks" + mypy src/confluent_kafka + fails="" echo "$0: Verifying OpenSSL"
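As a final hedged illustration of what the new `mypy src/confluent_kafka` step can catch, here is a deliberately wrong call in a hypothetical demo_typing.py; the exact mypy message will vary, but the int subject should be rejected against the str annotation on register_schema:

# demo_typing.py: intended for `mypy demo_typing.py`, not for running against a live registry.
from confluent_kafka.schema_registry import Schema, SchemaRegistryClient

client = SchemaRegistryClient({'url': 'http://localhost:8081'})
schema = Schema('{"type": "string"}', schema_type='AVRO')
client.register_schema(123, schema)   # mypy: incompatible type "int"; expected "str"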