
Commit 16179e3

feat(flagd-rpc): add caching (#110)
add caching with tests --- Signed-off-by: Simon Schrottner <[email protected]>
1 parent b62d3d1 commit 16179e3
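
For orientation, a minimal usage sketch of the caching options this commit introduces (the flag key and client setup are illustrative; the values shown are the defaults added below):

    # Sketch only: wires the new cache options through FlagdProvider.
    from openfeature import api
    from openfeature.contrib.provider.flagd import FlagdProvider
    from openfeature.contrib.provider.flagd.config import CacheType

    api.set_provider(
        FlagdProvider(
            cache_type=CacheType.LRU,  # default cache type added in this commit
            max_cache_size=1000,       # default LRU capacity
        )
    )
    client = api.get_client()
    print(client.get_boolean_value("my-flag", False))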

13 files changed: +244 -25 lines

providers/openfeature-provider-flagd/pyproject.toml

Lines changed: 2 additions & 0 deletions

@@ -24,6 +24,7 @@ dependencies = [
   "panzi-json-logic>=1.0.1",
   "semver>=3,<4",
   "pyyaml>=6.0.1",
+  "cachebox"
 ]
 requires-python = ">=3.8"

@@ -59,6 +60,7 @@ cov = [
   "cov-report",
 ]

+
 [tool.hatch.envs.mypy]
 dependencies = [
   "mypy[faster-cache]>=1.13.0",
Lines changed: 10 additions & 0 deletions

@@ -0,0 +1,10 @@
+[pytest]
+markers =
+    rpc: tests for rpc mode.
+    in-process: tests for in-process mode.
+    customCert: Supports custom certs.
+    unixsocket: Supports unixsockets.
+    events: Supports events.
+    sync: Supports sync.
+    caching: Supports caching.
+    offline: Supports offline.

providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/config.py

Lines changed: 39 additions & 6 deletions

@@ -8,6 +8,13 @@ class ResolverType(Enum):
     IN_PROCESS = "in-process"


+class CacheType(Enum):
+    LRU = "lru"
+    DISABLED = "disabled"
+
+
+DEFAULT_CACHE = CacheType.LRU
+DEFAULT_CACHE_SIZE = 1000
 DEFAULT_DEADLINE = 500
 DEFAULT_HOST = "localhost"
 DEFAULT_KEEP_ALIVE = 0

@@ -19,12 +26,14 @@ class ResolverType(Enum):
 DEFAULT_STREAM_DEADLINE = 600000
 DEFAULT_TLS = False

+ENV_VAR_CACHE_SIZE = "FLAGD_MAX_CACHE_SIZE"
+ENV_VAR_CACHE_TYPE = "FLAGD_CACHE"
 ENV_VAR_DEADLINE_MS = "FLAGD_DEADLINE_MS"
 ENV_VAR_HOST = "FLAGD_HOST"
 ENV_VAR_KEEP_ALIVE_TIME_MS = "FLAGD_KEEP_ALIVE_TIME_MS"
 ENV_VAR_OFFLINE_FLAG_SOURCE_PATH = "FLAGD_OFFLINE_FLAG_SOURCE_PATH"
 ENV_VAR_PORT = "FLAGD_PORT"
-ENV_VAR_RESOLVER_TYPE = "FLAGD_RESOLVER_TYPE"
+ENV_VAR_RESOLVER_TYPE = "FLAGD_RESOLVER"
 ENV_VAR_RETRY_BACKOFF_MS = "FLAGD_RETRY_BACKOFF_MS"
 ENV_VAR_STREAM_DEADLINE_MS = "FLAGD_STREAM_DEADLINE_MS"
 ENV_VAR_TLS = "FLAGD_TLS"

@@ -36,6 +45,14 @@ def str_to_bool(val: str) -> bool:
     return val.lower() == "true"


+def convert_resolver_type(val: typing.Union[str, ResolverType]) -> ResolverType:
+    if isinstance(val, str):
+        v = val.lower()
+        return ResolverType(v)
+    else:
+        return ResolverType(val)
+
+
 def env_or_default(
     env_var: str, default: T, cast: typing.Optional[typing.Callable[[str], T]] = None
 ) -> typing.Union[str, T]:

@@ -56,7 +73,9 @@ def __init__( # noqa: PLR0913
         retry_backoff_ms: typing.Optional[int] = None,
         deadline: typing.Optional[int] = None,
         stream_deadline_ms: typing.Optional[int] = None,
-        keep_alive_time: typing.Optional[int] = None,
+        keep_alive: typing.Optional[int] = None,
+        cache_type: typing.Optional[CacheType] = None,
+        max_cache_size: typing.Optional[int] = None,
     ):
         self.host = env_or_default(ENV_VAR_HOST, DEFAULT_HOST) if host is None else host

@@ -77,7 +96,9 @@ def __init__( # noqa: PLR0913
         )

         self.resolver_type = (
-            ResolverType(env_or_default(ENV_VAR_RESOLVER_TYPE, DEFAULT_RESOLVER_TYPE))
+            env_or_default(
+                ENV_VAR_RESOLVER_TYPE, DEFAULT_RESOLVER_TYPE, cast=convert_resolver_type
+            )
             if resolver_type is None
             else resolver_type
         )

@@ -118,10 +139,22 @@ def __init__( # noqa: PLR0913
             else stream_deadline_ms
         )

-        self.keep_alive_time: int = (
+        self.keep_alive: int = (
             int(
                 env_or_default(ENV_VAR_KEEP_ALIVE_TIME_MS, DEFAULT_KEEP_ALIVE, cast=int)
             )
-            if keep_alive_time is None
-            else keep_alive_time
+            if keep_alive is None
+            else keep_alive
+        )
+
+        self.cache_type = (
+            CacheType(env_or_default(ENV_VAR_CACHE_TYPE, DEFAULT_CACHE))
+            if cache_type is None
+            else cache_type
+        )
+
+        self.max_cache_size: int = (
+            int(env_or_default(ENV_VAR_CACHE_SIZE, DEFAULT_CACHE_SIZE, cast=int))
+            if max_cache_size is None
+            else max_cache_size
         )
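
A short sketch of how the new environment variables feed these defaults; constructing Config directly like this is only for illustration, and explicit keyword arguments still win over the environment:

    import os
    from openfeature.contrib.provider.flagd.config import CacheType, Config

    os.environ["FLAGD_CACHE"] = "disabled"
    os.environ["FLAGD_MAX_CACHE_SIZE"] = "500"

    config = Config()
    assert config.cache_type == CacheType.DISABLED
    assert config.max_cache_size == 500

    # keyword arguments take precedence over FLAGD_* variables
    config = Config(cache_type=CacheType.LRU, max_cache_size=16)
    assert config.max_cache_size == 16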

providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/provider.py

Lines changed: 6 additions & 2 deletions

@@ -29,7 +29,7 @@
 from openfeature.provider.metadata import Metadata
 from openfeature.provider.provider import AbstractProvider

-from .config import Config, ResolverType
+from .config import CacheType, Config, ResolverType
 from .resolvers import AbstractResolver, GrpcResolver, InProcessResolver

 T = typing.TypeVar("T")

@@ -50,6 +50,8 @@ def __init__( # noqa: PLR0913
         offline_flag_source_path: typing.Optional[str] = None,
         stream_deadline_ms: typing.Optional[int] = None,
         keep_alive_time: typing.Optional[int] = None,
+        cache_type: typing.Optional[CacheType] = None,
+        max_cache_size: typing.Optional[int] = None,
     ):
         """
         Create an instance of the FlagdProvider

@@ -82,7 +84,9 @@ def __init__( # noqa: PLR0913
             resolver_type=resolver_type,
             offline_flag_source_path=offline_flag_source_path,
             stream_deadline_ms=stream_deadline_ms,
-            keep_alive_time=keep_alive_time,
+            keep_alive=keep_alive_time,
+            cache_type=cache_type,
+            max_cache_size=max_cache_size,
         )

         self.resolver = self.setup_resolver()
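
Note that the provider keeps its public keep_alive_time keyword and maps it onto the renamed Config field, while the two cache keywords are passed straight through. For example, caching can be disabled per provider instance (a sketch; all other settings keep their defaults):

    from openfeature import api
    from openfeature.contrib.provider.flagd import FlagdProvider
    from openfeature.contrib.provider.flagd.config import CacheType

    api.set_provider(FlagdProvider(cache_type=CacheType.DISABLED))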

providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/resolvers/grpc.py

Lines changed: 31 additions & 6 deletions

@@ -4,6 +4,7 @@
 import typing

 import grpc
+from cachebox import BaseCacheImpl, LRUCache
 from google.protobuf.json_format import MessageToDict
 from google.protobuf.struct_pb2 import Struct


@@ -18,13 +19,13 @@
     ProviderNotReadyError,
     TypeMismatchError,
 )
-from openfeature.flag_evaluation import FlagResolutionDetails
+from openfeature.flag_evaluation import FlagResolutionDetails, Reason
 from openfeature.schemas.protobuf.flagd.evaluation.v1 import (
     evaluation_pb2,
     evaluation_pb2_grpc,
 )

-from ..config import Config
+from ..config import CacheType, Config
 from ..flag_type import FlagType

 if typing.TYPE_CHECKING:

@@ -51,6 +52,11 @@ def __init__(
         self.emit_provider_ready = emit_provider_ready
         self.emit_provider_error = emit_provider_error
         self.emit_provider_configuration_changed = emit_provider_configuration_changed
+        self.cache: typing.Optional[BaseCacheImpl] = (
+            LRUCache(maxsize=self.config.max_cache_size)
+            if self.config.cache_type == CacheType.LRU
+            else None
+        )
         self.stub, self.channel = self._create_stub()
         self.retry_backoff_seconds = config.retry_backoff_ms * 0.001
         self.streamline_deadline_seconds = config.stream_deadline_ms * 0.001

@@ -64,9 +70,13 @@ def _create_stub(
         channel_factory = grpc.secure_channel if config.tls else grpc.insecure_channel
         channel = channel_factory(
             f"{config.host}:{config.port}",
-            options=(("grpc.keepalive_time_ms", config.keep_alive_time),),
+            options=(("grpc.keepalive_time_ms", config.keep_alive),),
         )
         stub = evaluation_pb2_grpc.ServiceStub(channel)
+
+        if self.cache:
+            self.cache.clear()
+
         return stub, channel

     def initialize(self, evaluation_context: EvaluationContext) -> None:

@@ -75,6 +85,8 @@ def initialize(self, evaluation_context: EvaluationContext) -> None:
     def shutdown(self) -> None:
         self.active = False
         self.channel.close()
+        if self.cache:
+            self.cache.clear()

     def connect(self) -> None:
         self.active = True

@@ -96,7 +108,6 @@ def connect(self) -> None:

     def listen(self) -> None:
         retry_delay = self.retry_backoff_seconds
-
         call_args = (
             {"timeout": self.streamline_deadline_seconds}
             if self.streamline_deadline_seconds > 0

@@ -148,6 +159,10 @@ def listen(self) -> None:
     def handle_changed_flags(self, data: typing.Any) -> None:
         changed_flags = list(data["flags"].keys())

+        if self.cache:
+            for flag in changed_flags:
+                self.cache.pop(flag)
+
         self.emit_provider_configuration_changed(ProviderEventDetails(changed_flags))

     def resolve_boolean_details(

@@ -190,13 +205,18 @@ def resolve_object_details(
     ) -> FlagResolutionDetails[typing.Union[dict, list]]:
         return self._resolve(key, FlagType.OBJECT, default_value, evaluation_context)

-    def _resolve(  # noqa: PLR0915
+    def _resolve(  # noqa: PLR0915 C901
         self,
         flag_key: str,
         flag_type: FlagType,
         default_value: T,
         evaluation_context: typing.Optional[EvaluationContext],
     ) -> FlagResolutionDetails[T]:
+        if self.cache is not None and flag_key in self.cache:
+            cached_flag: FlagResolutionDetails[T] = self.cache[flag_key]
+            cached_flag.reason = Reason.CACHED
+            return cached_flag
+
         context = self._convert_context(evaluation_context)
         call_args = {"timeout": self.deadline}
         try:

@@ -249,12 +269,17 @@ def _resolve( # noqa: PLR0915
             raise GeneralError(message) from e

         # Got a valid flag and valid type. Return it.
-        return FlagResolutionDetails(
+        result = FlagResolutionDetails(
             value=value,
             reason=response.reason,
             variant=response.variant,
         )

+        if response.reason == Reason.STATIC and self.cache is not None:
+            self.cache.insert(flag_key, result)
+
+        return result
+
     def _convert_context(
         self, evaluation_context: typing.Optional[EvaluationContext]
     ) -> Struct:
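
Taken together, the resolver now caches successful STATIC resolutions and serves repeat lookups locally until a change event evicts the key. A sketch of the expected behaviour against a running flagd instance (the flag "title-flag" and its STATIC configuration are assumptions for illustration):

    from openfeature import api

    client = api.get_client()

    first = client.get_string_details("title-flag", "fallback")
    second = client.get_string_details("title-flag", "fallback")

    # The first evaluation reaches flagd and, being STATIC, is inserted into the
    # LRU cache; the second is answered from the cache with reason CACHED until
    # a configuration-change event for this key pops it again.
    print(first.reason, second.reason)  # e.g. STATIC, CACHED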

providers/openfeature-provider-flagd/tests/e2e/conftest.py

Lines changed: 9 additions & 0 deletions

@@ -11,6 +11,15 @@
 SPEC_PATH = "../../openfeature/spec"


+# run all gherkin tests, except the ones not implemented
+def pytest_collection_modifyitems(config):
+    marker = "not customCert and not unixsocket and not sync"
+
+    # this seems to not work with python 3.8
+    if hasattr(config.option, "markexpr") and config.option.markexpr == "":
+        config.option.markexpr = marker
+
+
 @pytest.fixture(autouse=True, scope="module")
 def setup(request, port, image):
     container: DockerContainer = FlagdContainer(
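
The hook only fills in a default marker expression, so an explicit -m still overrides it. For instance (paths are illustrative):

    import pytest

    # same expression the hook injects when none is given
    pytest.main(["-m", "not customCert and not unixsocket and not sync", "tests/e2e"])

    # run only the caching scenarios added by this commit
    pytest.main(["-m", "caching", "tests/e2e"])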

providers/openfeature-provider-flagd/tests/e2e/flagd_container.py

Lines changed: 1 addition & 1 deletion

@@ -11,7 +11,7 @@
 class FlagdContainer(DockerContainer):
     def __init__(
         self,
-        image: str = "ghcr.io/open-feature/flagd-testbed:v0.5.13",
+        image: str = "ghcr.io/open-feature/flagd-testbed:v0.5.15",
         port: int = 8013,
         **kwargs,
     ) -> None:

providers/openfeature-provider-flagd/tests/e2e/steps.py

Lines changed: 10 additions & 0 deletions

@@ -82,6 +82,16 @@ def setup_key_and_default(
     return (key, default)


+@when(
+    parsers.cfparse(
+        'a string flag with key "{key}" is evaluated with details',
+    ),
+    target_fixture="key_and_default",
+)
+def setup_key_without_default(key: str) -> typing.Tuple[str, JsonPrimitive]:
+    return setup_key_and_default(key, "")
+
+
 @when(
     parsers.cfparse(
         'an object flag with key "{key}" is evaluated with a null default value',
