From ad9faf334195bcb8204500d7ad3992e086a3319a Mon Sep 17 00:00:00 2001 From: Emil Madsen Date: Thu, 11 Feb 2021 22:28:37 +0100 Subject: [PATCH 1/3] Initial attempt at caching async functions --- beaker/cache.py | 119 ++++++++++++++++++++++++++++++++------------ beaker/container.py | 61 +++++++++++++++-------- 2 files changed, 128 insertions(+), 52 deletions(-) diff --git a/beaker/cache.py b/beaker/cache.py index 5a1ad6a4..950bfae5 100644 --- a/beaker/cache.py +++ b/beaker/cache.py @@ -6,6 +6,7 @@ :func:`.region_invalidate`. """ +import inspect import warnings from itertools import chain @@ -322,6 +323,11 @@ def get(self, key, **kw): return self._get_value(key, **kw).get_value() get_value = get + async def aget(self, key, **kw): + """Retrieve a cached value from the container""" + return await self._get_value(key, **kw).aget_value() + aget_value = aget + def remove_value(self, key, **kw): mycontainer = self._get_value(key, **kw) mycontainer.clear_value() @@ -547,28 +553,42 @@ def _cache_decorate(deco_args, manager, options, region): cache = [None] - def decorate(func): - namespace = util.func_namespace(func) - skip_self = util.has_self_arg(func) - signature = func_signature(func) + def _get_cache_region(region): + if region is None: + return None + if region not in cache_regions: + raise BeakerException( + 'Cache region not configured: %s' % region + ) + return cache_regions[region] + + def _short_circuit(cache, region): + if not cache and region is not None: + reg = _get_cache_region(region) + if not reg.get('enabled', True): + return True + return False + + def _find_cache(namespace, region, **options): + if region is not None: + reg = _get_cache_region(region) + return Cache._get_cache(namespace, reg) + elif manager: + return manager.get_cache(namespace, **options) + else: + raise Exception("'manager + kwargs' or 'region' " + "argument is required") - @wraps(func) - def cached(*args, **kwargs): - if not cache[0]: - if region is not None: - if region not in cache_regions: - raise BeakerException( - 'Cache region not configured: %s' % region) - reg = cache_regions[region] - if not reg.get('enabled', True): - return func(*args, **kwargs) - cache[0] = Cache._get_cache(namespace, reg) - elif manager: - cache[0] = manager.get_cache(namespace, **options) - else: - raise Exception("'manager + kwargs' or 'region' " - "argument is required") + def _determine_key_length(region, options): + if region: + cachereg = cache_regions[region] + key_length = cachereg.get('key_length', util.DEFAULT_CACHE_KEY_LENGTH) + else: + key_length = options.pop('key_length', util.DEFAULT_CACHE_KEY_LENGTH) + return key_length + def _cache_key_func(namespace, skip_self, signature): + def _inner(key_length, *args, **kwargs): cache_key_kwargs = [] if kwargs: # kwargs provided, merge them in positional args @@ -582,23 +602,58 @@ def cached(*args, **kwargs): cache_key = u_(" ").join(map(u_, chain(deco_args, cache_key_args, cache_key_kwargs))) - if region: - cachereg = cache_regions[region] - key_length = cachereg.get('key_length', util.DEFAULT_CACHE_KEY_LENGTH) - else: - key_length = options.pop('key_length', util.DEFAULT_CACHE_KEY_LENGTH) - - # TODO: This is probably a bug as length is checked before converting to UTF8 + # TODO: This is probably a bug as length is checked before converting to UTF-8 # which will cause cache_key to grow in size. if len(cache_key) + len(namespace) > int(key_length): cache_key = sha1(cache_key.encode('utf-8')).hexdigest() + return cache_key + return _inner + + def decorate(func): + namespace = util.func_namespace(func) + skip_self = util.has_self_arg(func) + signature = func_signature(func) + + _determine_cache_key = _cache_key_func(namespace, skip_self, signature) + + async_func = inspect.iscoroutinefunction(func) + if async_func: + @wraps(func) + async def cached(*args, **kwargs): + if _short_circuit(cache[0], region): + return await func(*args, **kwargs) + + if not cache[0]: + cache[0] = _find_cache(namespace, region, **options) + + key_length = _determine_key_length(region, options) + cache_key = _determine_cache_key(key_length, *args, **kwargs) + + async def go(): + return await func(*args, **kwargs) + # save org function name + go.__name__ = '_cached_%s' % (func.__name__,) + + return await cache[0].aget_value(cache_key, createfunc=go) + else: + @wraps(func) + def cached(*args, **kwargs): + if _short_circuit(cache[0], region): + return func(*args, **kwargs) + + if not cache[0]: + cache[0] = _find_cache(namespace, region, **options) + + key_length = _determine_key_length(region, options) + cache_key = _determine_cache_key(key_length, *args, **kwargs) + + def go(): + return func(*args, **kwargs) + # save org function name + go.__name__ = '_cached_%s' % (func.__name__,) - def go(): - return func(*args, **kwargs) - # save org function name - go.__name__ = '_cached_%s' % (func.__name__,) + return cache[0].get_value(cache_key, createfunc=go) - return cache[0].get_value(cache_key, createfunc=go) cached._arg_namespace = namespace if region is not None: cached._arg_region = region diff --git a/beaker/container.py b/beaker/container.py index f3f5b4f8..4796d000 100644 --- a/beaker/container.py +++ b/beaker/container.py @@ -328,7 +328,7 @@ def _is_expired(self, storedtime, expiretime): ) ) - def get_value(self): + def _check_cache(self): self.namespace.acquire_read_lock() try: has_value = self.has_value() @@ -336,7 +336,7 @@ def get_value(self): try: stored, expired, value = self._get_value() if not self._is_expired(stored, expired): - return value + return None, value except KeyError: # guard against un-mutexed backends raising KeyError has_value = False @@ -345,36 +345,35 @@ def get_value(self): raise KeyError(self.key) finally: self.namespace.release_read_lock() + return has_value, None - has_createlock = False + def _creation_lock_or_value(self, has_value): creation_lock = self.namespace.get_creation_lock(self.key) if has_value: if not creation_lock.acquire(wait=False): debug("get_value returning old value while new one is created") - return value + return None, value else: debug("lock_creatfunc (didnt wait)") - has_createlock = True - - if not has_createlock: + else: debug("lock_createfunc (waiting)") creation_lock.acquire() debug("lock_createfunc (waited)") + return creation_lock, None + + def get_value(self): + has_value, value = self._check_cache() + if has_value is None: + return value + + creation_lock, value = self._creation_lock_or_value(has_value) + if creation_lock is None: + return value try: - # see if someone created the value already - self.namespace.acquire_read_lock() - try: - if self.has_value(): - try: - stored, expired, value = self._get_value() - if not self._is_expired(stored, expired): - return value - except KeyError: - # guard against un-mutexed backends raising KeyError - pass - finally: - self.namespace.release_read_lock() + has_value, value = self._check_cache() + if has_value is None: + return value debug("get_value creating new value") v = self.createfunc() @@ -384,6 +383,28 @@ def get_value(self): creation_lock.release() debug("released create lock") + async def aget_value(self): + has_value, value = self._check_cache() + if has_value is None: + return value + + creation_lock, value = self._creation_lock_or_value(has_value) + if creation_lock is None: + return value + + try: + has_value, value = self._check_cache() + if has_value is None: + return value + + debug("get_value creating new value") + v = await self.createfunc() + self.set_value(v) + return v + finally: + creation_lock.release() + debug("released create lock") + def _get_value(self): value = self.namespace[self.key] try: From 4e185107e80dca9f4ae70b5135d7c60bb839980c Mon Sep 17 00:00:00 2001 From: Emil Madsen Date: Thu, 11 Feb 2021 22:35:18 +0100 Subject: [PATCH 2/3] Fix early options application --- beaker/cache.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/beaker/cache.py b/beaker/cache.py index 950bfae5..8d4cdc2b 100644 --- a/beaker/cache.py +++ b/beaker/cache.py @@ -569,7 +569,7 @@ def _short_circuit(cache, region): return True return False - def _find_cache(namespace, region, **options): + def _find_cache(namespace, region, options): if region is not None: reg = _get_cache_region(region) return Cache._get_cache(namespace, reg) @@ -624,7 +624,7 @@ async def cached(*args, **kwargs): return await func(*args, **kwargs) if not cache[0]: - cache[0] = _find_cache(namespace, region, **options) + cache[0] = _find_cache(namespace, region, options) key_length = _determine_key_length(region, options) cache_key = _determine_cache_key(key_length, *args, **kwargs) @@ -642,7 +642,7 @@ def cached(*args, **kwargs): return func(*args, **kwargs) if not cache[0]: - cache[0] = _find_cache(namespace, region, **options) + cache[0] = _find_cache(namespace, region, options) key_length = _determine_key_length(region, options) cache_key = _determine_cache_key(key_length, *args, **kwargs) From 788814806dbf9cad450530bf8ca30bed7cc5b015 Mon Sep 17 00:00:00 2001 From: Emil Madsen Date: Fri, 12 Feb 2021 11:58:17 +0100 Subject: [PATCH 3/3] Added python version checks for async --- beaker/cache.py | 10 ++++++---- beaker/container.py | 36 +++++++++++++++++++----------------- 2 files changed, 25 insertions(+), 21 deletions(-) diff --git a/beaker/cache.py b/beaker/cache.py index 8d4cdc2b..1cacfa19 100644 --- a/beaker/cache.py +++ b/beaker/cache.py @@ -8,6 +8,7 @@ """ import inspect import warnings +import sys from itertools import chain from beaker._compat import u_, unicode_text, func_signature, bindfuncargs @@ -323,10 +324,11 @@ def get(self, key, **kw): return self._get_value(key, **kw).get_value() get_value = get - async def aget(self, key, **kw): - """Retrieve a cached value from the container""" - return await self._get_value(key, **kw).aget_value() - aget_value = aget + if sys.version_info[0] == 3 and sys.version_info[1] > 4: + async def aget(self, key, **kw): + """Retrieve a cached value from the container""" + return await self._get_value(key, **kw).aget_value() + aget_value = aget def remove_value(self, key, **kw): mycontainer = self._get_value(key, **kw) diff --git a/beaker/container.py b/beaker/container.py index 4796d000..e39b5885 100644 --- a/beaker/container.py +++ b/beaker/container.py @@ -6,6 +6,7 @@ import beaker.util as util import logging import os +import sys import time from beaker.exceptions import CreationAbortedError, MissingCacheParameter @@ -383,27 +384,28 @@ def get_value(self): creation_lock.release() debug("released create lock") - async def aget_value(self): - has_value, value = self._check_cache() - if has_value is None: - return value - - creation_lock, value = self._creation_lock_or_value(has_value) - if creation_lock is None: - return value - - try: + if sys.version_info[0] == 3 and sys.version_info[1] > 4: + async def aget_value(self): has_value, value = self._check_cache() if has_value is None: return value - debug("get_value creating new value") - v = await self.createfunc() - self.set_value(v) - return v - finally: - creation_lock.release() - debug("released create lock") + creation_lock, value = self._creation_lock_or_value(has_value) + if creation_lock is None: + return value + + try: + has_value, value = self._check_cache() + if has_value is None: + return value + + debug("get_value creating new value") + v = await self.createfunc() + self.set_value(v) + return v + finally: + creation_lock.release() + debug("released create lock") def _get_value(self): value = self.namespace[self.key]