Skip to content
This repository was archived by the owner on Jan 17, 2025. It is now read-only.

Commit 6a3016d

Browse files
authored
fix concurrent access to self.cache in nosqldict (#121)
* fix concurrent access to self.cache in nosqldict * bump version to 0.4.2 * add -d to docker run in CI
1 parent 654aff0 commit 6a3016d

File tree

8 files changed

+182
-49
lines changed

8 files changed

+182
-49
lines changed

.github/workflows/pythonapp.yml

+1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ jobs:
2929
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
3030
- name: pytest
3131
run: |
32+
docker run -d -p 27017:27017 mongo
3233
pip install -r dev-requirements.txt
3334
make test
3435
- name: lint

Makefile

+5-1
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,9 @@ format:
1818

1919
test:
2020
pip3 install .
21-
coverage run -m pytest
21+
ifdef GITHUB_ACTIONS
22+
coverage run -m pytest -v --with_nosqldict
23+
else
24+
coverage run -m pytest -v
25+
endif
2226
coverage report -m

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
setuptools.setup(
1111
name="enochecker",
12-
version="0.4.1",
12+
version="0.4.2",
1313
author="domenukk",
1414
author_email="[email protected]",
1515
description="Library to build checker scripts for EnoEngine A/D CTF Framework in Python",

src/enochecker/enochecker.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def __init__(
178178
self.storage_dir = storage_dir
179179

180180
self._setup_logger()
181-
if use_db_cache:
181+
if use_db_cache and not os.getenv("MONGO_ENABLED"):
182182
self._active_dbs: Dict[str, Union[NoSqlDict, StoredDict]] = global_db_cache
183183
else:
184184
self._active_dbs = {}

src/enochecker/nosqldict.py

+59-46
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import logging
55
from collections.abc import MutableMapping
66
from functools import wraps
7-
from threading import RLock, current_thread
7+
from threading import Lock, RLock, current_thread
88
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Optional, Union
99

1010
from . import utils
@@ -135,16 +135,16 @@ def __init__(
135135
self.checker_name = checker_name
136136
self.cache: Dict[Any, Any] = {}
137137
self.hash_cache: Dict[Any, Any] = {}
138+
self._lock: Lock = Lock()
138139
host_: str = host or DB_DEFAULT_HOST
139140
if isinstance(port, int):
140141
port_: int = port
141142
else:
142143
port_ = int(port or DB_DEFAULT_PORT)
143144
username_: Optional[str] = username or DB_DEFAULT_USER
144145
password_: Optional[str] = password or DB_DEFAULT_PASS
145-
self.db = self.get_client(host_, port_, username_, password_, self.logger)[
146-
checker_name
147-
][self.dict_name]
146+
self.client = self.get_client(host_, port_, username_, password_, self.logger)
147+
self.db = self.client[checker_name][self.dict_name]
148148
try:
149149
self.db.index_information()["checker_key"]
150150
except KeyError:
@@ -162,14 +162,17 @@ def __setitem__(self, key: str, value: Any) -> None:
162162
:param key: key in the dictionary
163163
:param value: value in the dictionary
164164
"""
165-
key = str(key)
165+
with self._lock:
166+
key = str(key)
166167

167-
self.cache[key] = value
168-
hash_ = value_to_hash(value)
169-
if hash_:
170-
self.hash_cache[key] = hash_
168+
self.cache[key] = value
169+
hash_ = value_to_hash(value)
170+
if hash_:
171+
self.hash_cache[key] = hash_
172+
elif key in self.hash_cache:
173+
del self.hash_cache[key]
171174

172-
self._upsert(key, value)
175+
self._upsert(key, value)
173176

174177
def _upsert(self, key: Any, value: Any) -> None:
175178
query_dict = {
@@ -198,28 +201,31 @@ def __getitem__(self, key: str, print_result: bool = False) -> Any:
198201
:param print_result: TODO
199202
:return: retrieved value
200203
"""
201-
key = str(key)
202-
if key in self.cache.items():
203-
return self.cache[key]
204+
with self._lock:
205+
key = str(key)
204206

205-
to_extract = {
206-
"key": key,
207-
"checker": self.checker_name,
208-
"name": self.dict_name,
209-
}
207+
if key in self.cache:
208+
return self.cache[key]
210209

211-
result = self.db.find_one(to_extract)
210+
to_extract = {
211+
"key": key,
212+
"checker": self.checker_name,
213+
"name": self.dict_name,
214+
}
212215

213-
if print_result:
214-
self.logger.debug(result)
216+
result = self.db.find_one(to_extract)
215217

216-
if result:
217-
self.cache[key] = result["value"]
218-
hash_ = value_to_hash(result)
219-
if hash_:
220-
self.hash_cache[key] = hash_
221-
return result["value"]
222-
raise KeyError("Could not find {} in {}".format(key, self))
218+
if print_result:
219+
self.logger.debug(result)
220+
221+
if result:
222+
val = result["value"]
223+
self.cache[key] = val
224+
hash_ = value_to_hash(val)
225+
if hash_:
226+
self.hash_cache[key] = hash_
227+
return val
228+
raise KeyError("Could not find {} in {}".format(key, self))
223229

224230
@_try_n_times
225231
def __delitem__(self, key: str) -> None:
@@ -230,16 +236,19 @@ def __delitem__(self, key: str) -> None:
230236
231237
:param key: key to delete
232238
"""
233-
key = str(key)
234-
if key in self.cache:
235-
del self.cache[key]
236-
237-
to_extract = {
238-
"key": key,
239-
"checker": self.checker_name,
240-
"name": self.dict_name,
241-
}
242-
self.db.delete_one(to_extract)
239+
with self._lock:
240+
key = str(key)
241+
if key in self.cache:
242+
del self.cache[key]
243+
if key in self.hash_cache:
244+
del self.hash_cache[key]
245+
246+
to_extract = {
247+
"key": key,
248+
"checker": self.checker_name,
249+
"name": self.dict_name,
250+
}
251+
self.db.delete_one(to_extract)
243252

244253
@_try_n_times
245254
def __len__(self) -> int:
@@ -267,14 +276,18 @@ def persist(self) -> None:
267276
"""
268277
Persist the changes in the backend.
269278
"""
270-
for (key, value) in self.cache.items():
271-
hash_ = value_to_hash(value)
272-
if (
273-
(not hash_)
274-
or (key not in self.hash_cache)
275-
or (self.hash_cache[key] != hash_)
276-
):
277-
self._upsert(key, value)
279+
with self._lock:
280+
for (key, value) in list(self.cache.items()):
281+
hash_ = value_to_hash(value)
282+
if (
283+
(not hash_)
284+
or (key not in self.hash_cache)
285+
or (self.hash_cache[key] != hash_)
286+
):
287+
self._upsert(key, value)
288+
del self.cache[key]
289+
if key in self.hash_cache:
290+
del self.hash_cache[key]
278291

279292
def __del__(self) -> None:
280293
"""

tests/conftest.py

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import pytest
2+
3+
4+
def pytest_addoption(parser):
5+
parser.addoption(
6+
"--with_nosqldict", action="store_true", help="Run the tests with the nosqldict"
7+
)
8+
9+
10+
def pytest_configure(config):
11+
config.addinivalue_line("markers", "nosqldict: mark test as requiring MongoDB")
12+
13+
14+
def pytest_collection_modifyitems(config, items):
15+
if config.getoption("--with_nosqldict"):
16+
return
17+
skip_nosqldict = pytest.mark.skip(reason="need --with_nosqldict option to run")
18+
for item in items:
19+
if "nosqldict" in item.keywords:
20+
item.add_marker(skip_nosqldict)

tests/test_enochecker.py

+40
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
#!/usr/bin/env python3
22
import functools
33
import hashlib
4+
import secrets
45
import sys
56
import tempfile
67
from logging import DEBUG
8+
from unittest import mock
79

810
import pytest
911
from enochecker_core import CheckerMethod, CheckerTaskMessage, CheckerTaskResult
@@ -310,6 +312,44 @@ def putflagfn(self: CheckerExampleImpl):
310312
assert result.attack_info == attack_info
311313

312314

315+
@pytest.mark.nosqldict
316+
def test_nested_change_enochecker():
317+
import os
318+
319+
with mock.patch.dict(
320+
os.environ,
321+
{
322+
"MONGO_ENABLED": "1",
323+
},
324+
):
325+
dict_name = secrets.token_hex(8)
326+
327+
def putflagfn(self: CheckerExampleImpl):
328+
db = self.db(dict_name)
329+
x = {
330+
"asd": 123,
331+
}
332+
db["test"] = x
333+
334+
x["asd"] = 456
335+
336+
def getflagfn(self: CheckerExampleImpl):
337+
db = self.db(dict_name)
338+
assert db["test"]["asd"] == 456
339+
340+
setattr(CheckerExampleImpl, "putflag", putflagfn)
341+
checker = CheckerExampleImpl(method="putflag")
342+
343+
result = checker.run()
344+
assert result.result == CheckerTaskResult.OK
345+
346+
setattr(CheckerExampleImpl, "getflag", getflagfn)
347+
checker = CheckerExampleImpl(method="getflag")
348+
349+
result = checker.run()
350+
assert result.result == CheckerTaskResult.OK
351+
352+
313353
def main():
314354
pytest.main(sys.argv)
315355

tests/test_nosqldict.py

+55
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import secrets
2+
3+
import pytest
4+
5+
from enochecker.nosqldict import NoSqlDict
6+
7+
8+
@pytest.fixture
9+
def nosqldict():
10+
dict_name = secrets.token_hex(8)
11+
checker_name = secrets.token_hex(8)
12+
return NoSqlDict(dict_name, checker_name)
13+
14+
15+
@pytest.mark.nosqldict
16+
def test_basic(nosqldict):
17+
nosqldict["abc"] = "xyz"
18+
assert nosqldict["abc"] == "xyz"
19+
20+
with pytest.raises(KeyError):
21+
_ = nosqldict["xyz"]
22+
23+
nosqldict["abc"] = {"stuff": b"asd"}
24+
assert nosqldict["abc"] == {"stuff": b"asd"}
25+
26+
del nosqldict["abc"]
27+
with pytest.raises(KeyError):
28+
_ = nosqldict["abc"]
29+
30+
31+
@pytest.mark.nosqldict
32+
def test_nested_change():
33+
dict_name = secrets.token_hex(8)
34+
checker_name = secrets.token_hex(8)
35+
36+
def scoped_access(dict_name, checker_name):
37+
nosqldict = NoSqlDict(dict_name, checker_name)
38+
39+
x = {
40+
"asd": 123,
41+
}
42+
nosqldict["test"] = x
43+
x["asd"] = 456
44+
45+
assert nosqldict["test"] == {
46+
"asd": 456,
47+
}
48+
49+
scoped_access(dict_name, checker_name)
50+
51+
nosqldict_new = NoSqlDict(dict_name, checker_name)
52+
53+
assert nosqldict_new["test"] == {
54+
"asd": 456,
55+
}

0 commit comments

Comments
 (0)