|  | 
|  | 1 | +#!/usr/bin/env python | 
|  | 2 | +# -*- coding: utf-8 -*- | 
|  | 3 | +# | 
|  | 4 | +# Copyright 2024 Confluent Inc. | 
|  | 5 | +# | 
|  | 6 | +# Licensed under the Apache License, Version 2.0 (the "License"); | 
|  | 7 | +# you may not use this file except in compliance with the License. | 
|  | 8 | +# You may obtain a copy of the License at | 
|  | 9 | +# | 
|  | 10 | +# http://www.apache.org/licenses/LICENSE-2.0 | 
|  | 11 | +# | 
|  | 12 | +# Unless required by applicable law or agreed to in writing, software | 
|  | 13 | +# distributed under the License is distributed on an "AS IS" BASIS, | 
|  | 14 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
|  | 15 | +# See the License for the specific language governing permissions and | 
|  | 16 | +# limitations under the License. | 
|  | 17 | +# | 
|  | 18 | +import uuid | 
|  | 19 | +from collections import defaultdict | 
|  | 20 | +from threading import Lock | 
|  | 21 | +from typing import List, Dict, Optional | 
|  | 22 | + | 
|  | 23 | +from .schema_registry_client import AsyncSchemaRegistryClient | 
|  | 24 | +from ..common.schema_registry_client import RegisteredSchema, Schema, ServerConfig | 
|  | 25 | +from ..error import SchemaRegistryError | 
|  | 26 | + | 
|  | 27 | + | 
|  | 28 | +class _SchemaStore(object): | 
|  | 29 | + | 
|  | 30 | +    def __init__(self): | 
|  | 31 | +        self.lock = Lock() | 
|  | 32 | +        self.max_id = 0 | 
|  | 33 | +        self.schema_id_index = {} | 
|  | 34 | +        self.schema_guid_index = {} | 
|  | 35 | +        self.schema_index = {} | 
|  | 36 | +        self.subject_schemas = defaultdict(set) | 
|  | 37 | + | 
|  | 38 | +    def set(self, registered_schema: RegisteredSchema) -> RegisteredSchema: | 
|  | 39 | +        with self.lock: | 
|  | 40 | +            self.max_id += 1 | 
|  | 41 | +            rs = RegisteredSchema( | 
|  | 42 | +                schema_id=self.max_id, | 
|  | 43 | +                guid=registered_schema.guid, | 
|  | 44 | +                schema=registered_schema.schema, | 
|  | 45 | +                subject=registered_schema.subject, | 
|  | 46 | +                version=registered_schema.version | 
|  | 47 | +            ) | 
|  | 48 | +            self.schema_id_index[rs.schema_id] = rs | 
|  | 49 | +            self.schema_guid_index[rs.guid] = rs | 
|  | 50 | +            self.schema_index[rs.schema] = rs.schema_id | 
|  | 51 | +            self.subject_schemas[rs.subject].add(rs) | 
|  | 52 | +            return rs | 
|  | 53 | + | 
|  | 54 | +    def get_schema(self, schema_id: int) -> Optional[Schema]: | 
|  | 55 | +        with self.lock: | 
|  | 56 | +            rs = self.schema_id_index.get(schema_id, None) | 
|  | 57 | +            return rs.schema if rs else None | 
|  | 58 | + | 
|  | 59 | +    def get_schema_by_guid(self, guid: str) -> Optional[Schema]: | 
|  | 60 | +        with self.lock: | 
|  | 61 | +            rs = self.schema_guid_index.get(guid, None) | 
|  | 62 | +            return rs.schema if rs else None | 
|  | 63 | + | 
|  | 64 | +    def get_registered_schema_by_schema( | 
|  | 65 | +        self, | 
|  | 66 | +        subject_name: str, | 
|  | 67 | +        schema: Schema | 
|  | 68 | +    ) -> Optional[RegisteredSchema]: | 
|  | 69 | +        with self.lock: | 
|  | 70 | +            if subject_name in self.subject_schemas: | 
|  | 71 | +                for rs in self.subject_schemas[subject_name]: | 
|  | 72 | +                    if rs.schema == schema: | 
|  | 73 | +                        return rs | 
|  | 74 | +            return None | 
|  | 75 | + | 
|  | 76 | +    def get_version(self, subject_name: str, version: int) -> Optional[RegisteredSchema]: | 
|  | 77 | +        with self.lock: | 
|  | 78 | +            if subject_name in self.subject_schemas: | 
|  | 79 | +                for rs in self.subject_schemas[subject_name]: | 
|  | 80 | +                    if rs.version == version: | 
|  | 81 | +                        return rs | 
|  | 82 | +            return None | 
|  | 83 | + | 
|  | 84 | +    def get_latest_version(self, subject_name: str) -> Optional[RegisteredSchema]: | 
|  | 85 | +        with self.lock: | 
|  | 86 | +            if subject_name in self.subject_schemas: | 
|  | 87 | +                latest_version = 0 | 
|  | 88 | +                latest_schema = None | 
|  | 89 | +                for rs in self.subject_schemas[subject_name]: | 
|  | 90 | +                    if rs.version > latest_version: | 
|  | 91 | +                        latest_version = rs.version | 
|  | 92 | +                        latest_schema = rs | 
|  | 93 | +                return latest_schema | 
|  | 94 | +            return None | 
|  | 95 | + | 
|  | 96 | +    def get_latest_with_metadata( | 
|  | 97 | +        self, subject_name: str, | 
|  | 98 | +        metadata: Dict[str, str] | 
|  | 99 | +    ) -> Optional[RegisteredSchema]: | 
|  | 100 | +        with self.lock: | 
|  | 101 | +            if subject_name in self.subject_schemas: | 
|  | 102 | +                rs: RegisteredSchema | 
|  | 103 | +                for rs in self.subject_schemas[subject_name]: | 
|  | 104 | +                    if (rs.schema | 
|  | 105 | +                            and rs.schema.metadata | 
|  | 106 | +                            and rs.schema.metadata.properties | 
|  | 107 | +                            and metadata.items() <= rs.schema.metadata.properties.properties.items()): | 
|  | 108 | +                        return rs | 
|  | 109 | +            return None | 
|  | 110 | + | 
|  | 111 | +    def get_subjects(self) -> List[str]: | 
|  | 112 | +        with self.lock: | 
|  | 113 | +            return list(self.subject_schemas.keys()) | 
|  | 114 | + | 
|  | 115 | +    def get_versions(self, subject_name: str) -> List[int]: | 
|  | 116 | +        with self.lock: | 
|  | 117 | +            if subject_name in self.subject_schemas: | 
|  | 118 | +                return [rs.version for rs in self.subject_schemas[subject_name]] | 
|  | 119 | +            return [] | 
|  | 120 | + | 
|  | 121 | +    def remove_by_schema(self, registered_schema: RegisteredSchema): | 
|  | 122 | +        with self.lock: | 
|  | 123 | +            subject_name = registered_schema.subject | 
|  | 124 | +            if subject_name in self.subject_schemas: | 
|  | 125 | +                self.subject_schemas[subject_name].remove(registered_schema) | 
|  | 126 | + | 
|  | 127 | +    def remove_by_subject(self, subject_name: str) -> List[int]: | 
|  | 128 | +        with self.lock: | 
|  | 129 | +            versions = [] | 
|  | 130 | +            if subject_name in self.subject_schemas: | 
|  | 131 | +                for rs in self.subject_schemas[subject_name]: | 
|  | 132 | +                    versions.append(rs.version) | 
|  | 133 | +                    schema_id = self.schema_index.pop(rs.schema, None) | 
|  | 134 | +                    if schema_id is not None: | 
|  | 135 | +                        self.schema_id_index.pop(schema_id, None) | 
|  | 136 | + | 
|  | 137 | +                del self.subject_schemas[subject_name] | 
|  | 138 | +            return versions | 
|  | 139 | + | 
|  | 140 | +    def clear(self): | 
|  | 141 | +        with self.lock: | 
|  | 142 | +            self.schema_id_index.clear() | 
|  | 143 | +            self.schema_guid_index.clear() | 
|  | 144 | +            self.schema_index.clear() | 
|  | 145 | +            self.subject_schemas.clear() | 
|  | 146 | + | 
|  | 147 | + | 
|  | 148 | +class AsyncMockSchemaRegistryClient(AsyncSchemaRegistryClient): | 
|  | 149 | + | 
|  | 150 | +    def __init__(self, conf: dict): | 
|  | 151 | +        super().__init__(conf) | 
|  | 152 | +        self._store = _SchemaStore() | 
|  | 153 | + | 
|  | 154 | +    async def register_schema( | 
|  | 155 | +        self, subject_name: str, schema: 'Schema', | 
|  | 156 | +        normalize_schemas: bool = False | 
|  | 157 | +    ) -> int: | 
|  | 158 | +        registered_schema = await self.register_schema_full_response(subject_name, schema, normalize_schemas) | 
|  | 159 | +        return registered_schema.schema_id | 
|  | 160 | + | 
|  | 161 | +    async def register_schema_full_response( | 
|  | 162 | +        self, subject_name: str, schema: 'Schema', | 
|  | 163 | +        normalize_schemas: bool = False | 
|  | 164 | +    ) -> 'RegisteredSchema': | 
|  | 165 | +        registered_schema = self._store.get_registered_schema_by_schema(subject_name, schema) | 
|  | 166 | +        if registered_schema is not None: | 
|  | 167 | +            return registered_schema | 
|  | 168 | + | 
|  | 169 | +        latest_schema = self._store.get_latest_version(subject_name) | 
|  | 170 | +        latest_version = 1 if latest_schema is None else latest_schema.version + 1 | 
|  | 171 | + | 
|  | 172 | +        registered_schema = RegisteredSchema( | 
|  | 173 | +            schema_id=1, | 
|  | 174 | +            guid=str(uuid.uuid4()), | 
|  | 175 | +            schema=schema, | 
|  | 176 | +            subject=subject_name, | 
|  | 177 | +            version=latest_version | 
|  | 178 | +        ) | 
|  | 179 | + | 
|  | 180 | +        registered_schema = self._store.set(registered_schema) | 
|  | 181 | + | 
|  | 182 | +        return registered_schema | 
|  | 183 | + | 
|  | 184 | +    async def get_schema( | 
|  | 185 | +        self, schema_id: int, subject_name: Optional[str] = None, | 
|  | 186 | +        fmt: Optional[str] = None | 
|  | 187 | +    ) -> 'Schema': | 
|  | 188 | +        schema = self._store.get_schema(schema_id) | 
|  | 189 | +        if schema is not None: | 
|  | 190 | +            return schema | 
|  | 191 | + | 
|  | 192 | +        raise SchemaRegistryError(404, 40400, "Schema Not Found") | 
|  | 193 | + | 
|  | 194 | +    async def get_schema_by_guid( | 
|  | 195 | +        self, guid: str, fmt: Optional[str] = None | 
|  | 196 | +    ) -> 'Schema': | 
|  | 197 | +        schema = self._store.get_schema_by_guid(guid) | 
|  | 198 | +        if schema is not None: | 
|  | 199 | +            return schema | 
|  | 200 | + | 
|  | 201 | +        raise SchemaRegistryError(404, 40400, "Schema Not Found") | 
|  | 202 | + | 
|  | 203 | +    async def lookup_schema( | 
|  | 204 | +        self, subject_name: str, schema: 'Schema', | 
|  | 205 | +        normalize_schemas: bool = False, deleted: bool = False | 
|  | 206 | +    ) -> 'RegisteredSchema': | 
|  | 207 | + | 
|  | 208 | +        registered_schema = self._store.get_registered_schema_by_schema(subject_name, schema) | 
|  | 209 | +        if registered_schema is not None: | 
|  | 210 | +            return registered_schema | 
|  | 211 | + | 
|  | 212 | +        raise SchemaRegistryError(404, 40400, "Schema Not Found") | 
|  | 213 | + | 
|  | 214 | +    async def get_subjects(self) -> List[str]: | 
|  | 215 | +        return self._store.get_subjects() | 
|  | 216 | + | 
|  | 217 | +    async def delete_subject(self, subject_name: str, permanent: bool = False) -> List[int]: | 
|  | 218 | +        return self._store.remove_by_subject(subject_name) | 
|  | 219 | + | 
|  | 220 | +    async def get_latest_version(self, subject_name: str, fmt: Optional[str] = None) -> 'RegisteredSchema': | 
|  | 221 | +        registered_schema = self._store.get_latest_version(subject_name) | 
|  | 222 | +        if registered_schema is not None: | 
|  | 223 | +            return registered_schema | 
|  | 224 | + | 
|  | 225 | +        raise SchemaRegistryError(404, 40400, "Schema Not Found") | 
|  | 226 | + | 
|  | 227 | +    async def get_latest_with_metadata( | 
|  | 228 | +        self, subject_name: str, metadata: Dict[str, str], | 
|  | 229 | +        deleted: bool = False, fmt: Optional[str] = None | 
|  | 230 | +    ) -> 'RegisteredSchema': | 
|  | 231 | +        registered_schema = self._store.get_latest_with_metadata(subject_name, metadata) | 
|  | 232 | +        if registered_schema is not None: | 
|  | 233 | +            return registered_schema | 
|  | 234 | + | 
|  | 235 | +        raise SchemaRegistryError(404, 40400, "Schema Not Found") | 
|  | 236 | + | 
|  | 237 | +    async def get_version( | 
|  | 238 | +        self, subject_name: str, version: int, | 
|  | 239 | +        deleted: bool = False, fmt: Optional[str] = None | 
|  | 240 | +    ) -> 'RegisteredSchema': | 
|  | 241 | +        registered_schema = self._store.get_version(subject_name, version) | 
|  | 242 | +        if registered_schema is not None: | 
|  | 243 | +            return registered_schema | 
|  | 244 | + | 
|  | 245 | +        raise SchemaRegistryError(404, 40400, "Schema Not Found") | 
|  | 246 | + | 
|  | 247 | +    async def get_versions(self, subject_name: str) -> List[int]: | 
|  | 248 | +        return self._store.get_versions(subject_name) | 
|  | 249 | + | 
|  | 250 | +    async def delete_version(self, subject_name: str, version: int, permanent: bool = False) -> int: | 
|  | 251 | +        registered_schema = self._store.get_version(subject_name, version) | 
|  | 252 | +        if registered_schema is not None: | 
|  | 253 | +            self._store.remove_by_schema(registered_schema) | 
|  | 254 | +            return registered_schema.schema_id | 
|  | 255 | + | 
|  | 256 | +        raise SchemaRegistryError(404, 40400, "Schema Not Found") | 
|  | 257 | + | 
|  | 258 | +    async def set_config( | 
|  | 259 | +        self, subject_name: Optional[str] = None, config: 'ServerConfig' = None  # noqa F821 | 
|  | 260 | +    ) -> 'ServerConfig':  # noqa F821 | 
|  | 261 | +        return None | 
|  | 262 | + | 
|  | 263 | +    async def get_config(self, subject_name: Optional[str] = None) -> 'ServerConfig':  # noqa F821 | 
|  | 264 | +        return None | 
0 commit comments