Skip to content

Commit

Permalink
Add timeout to PrivContext and entrypoint_with_timeout decorator
Browse files Browse the repository at this point in the history
entrypoint_with_timeout decorator can be used with a timeout parameter,
if the timeout is reached PrivsepTimeout is raised.
The PrivContext has timeout variable, which will be used for all
functions decorated with entrypoint, and PrivsepTimeout is raised if
timeout is reached.

Co-authored-by: Rodolfo Alonso <[email protected]>
Change-Id: Ie3b1fc255c0c05fd5403b90ef49b954fe397fb77
Related-Bug: #1930401
  • Loading branch information
elajkat and ralonsoh committed Jun 23, 2021
1 parent fa47d53 commit f7f3349
Show file tree
Hide file tree
Showing 8 changed files with 202 additions and 50 deletions.
55 changes: 55 additions & 0 deletions doc/source/user/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,31 @@ defines a sys_admin_pctxt with ``CAP_CHOWN``, ``CAP_DAC_OVERRIDE``,
capabilities.CAP_SYS_ADMIN],
)

Defining a context with timeout
-------------------------------

It is possible to initialize PrivContext with timeout::

from oslo_privsep import capabilities
from oslo_privsep import priv_context

dhcp_release_cmd = priv_context.PrivContext(
__name__,
cfg_section='privsep_dhcp_release',
pypath=__name__ + '.dhcp_release_cmd',
capabilities=[caps.CAP_SYS_ADMIN,
caps.CAP_NET_ADMIN],
timeout=5
)

``PrivsepTimeout`` is raised if timeout is reached.

.. warning::

The daemon (the root process) task won't stop when timeout
is reached. That means we'll have less available threads if the related
thread never finishes.

Defining a privileged function
==============================

Expand All @@ -51,6 +76,36 @@ generic ``update_file(filename, content)`` was created, it could be used to
overwrite any file in the filesystem, allowing easy escalation to root
rights. That would defeat the whole purpose of oslo.privsep.

Defining a privileged function with timeout
-------------------------------------------

It is possible to use ``entrypoint_with_timeout`` decorator::

from oslo_privsep import daemon

from neutron import privileged

@privileged.default.entrypoint_with_timeout(timeout=5)
def get_link_devices(namespace, **kwargs):
try:
with get_iproute(namespace) as ip:
return make_serializable(ip.get_links(**kwargs))
except OSError as e:
if e.errno == errno.ENOENT:
raise NetworkNamespaceNotFound(netns_name=namespace)
raise
except daemon.FailedToDropPrivileges:
raise
except daemon.PrivsepTimeout:
raise

``PrivsepTimeout`` is raised if timeout is reached.

.. warning::

The daemon (the root process) task won't stop when timeout
is reached. That means we'll have less available threads if the related
thread never finishes.

Using a privileged function
===========================
Expand Down
54 changes: 35 additions & 19 deletions oslo_privsep/comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
converted to tuples during serialization/deserialization.
"""

import datetime
import enum
import logging
import socket
import threading
Expand All @@ -28,22 +30,24 @@
import six

from oslo_privsep._i18n import _

from oslo_utils import uuidutils

LOG = logging.getLogger(__name__)


try:
import greenlet
@enum.unique
class Message(enum.IntEnum):
"""Types of messages sent across the communication channel"""
PING = 1
PONG = 2
CALL = 3
RET = 4
ERR = 5
LOG = 6

def _get_thread_ident():
# This returns something sensible, even if the current thread
# isn't a greenthread
return id(greenlet.getcurrent())

except ImportError:
def _get_thread_ident():
return threading.current_thread().ident
class PrivsepTimeout(Exception):
pass


class Serializer(object):
Expand Down Expand Up @@ -89,10 +93,11 @@ def __next__(self):
class Future(object):
"""A very simple object to track the return of a function call"""

def __init__(self, lock):
def __init__(self, lock, timeout=None):
self.condvar = threading.Condition(lock)
self.error = None
self.data = None
self.timeout = timeout

def set_result(self, data):
"""Must already be holding lock used in constructor"""
Expand All @@ -106,7 +111,16 @@ def set_exception(self, exc):

def result(self):
"""Must already be holding lock used in constructor"""
self.condvar.wait()
before = datetime.datetime.now()
if not self.condvar.wait(timeout=self.timeout):
now = datetime.datetime.now()
LOG.warning('Timeout while executing a command, timeout: %s, '
'time elapsed: %s', self.timeout,
(now - before).total_seconds())
return (Message.ERR.value,
'%s.%s' % (PrivsepTimeout.__module__,
PrivsepTimeout.__name__),
'')
if self.error is not None:
raise self.error
return self.data
Expand Down Expand Up @@ -138,8 +152,9 @@ def _reader_main(self, reader):
else:
with self.lock:
if msgid not in self.outstanding_msgs:
raise AssertionError("msgid should in "
"outstanding_msgs.")
LOG.warning("msgid should be in oustanding_msgs, it is"
"possible that timeout is reached!")
continue
self.outstanding_msgs[msgid].set_result(data)

# EOF. Perhaps the privileged process exited?
Expand All @@ -158,13 +173,14 @@ def out_of_band(self, msg):
"""Received OOB message. Subclasses might want to override this."""
pass

def send_recv(self, msg):
myid = _get_thread_ident()
future = Future(self.lock)
def send_recv(self, msg, timeout=None):
myid = uuidutils.generate_uuid()
while myid in self.outstanding_msgs:
LOG.warning("myid shoudn't be in outstanding_msgs.")
myid = uuidutils.generate_uuid()
future = Future(self.lock, timeout)

with self.lock:
if myid in self.outstanding_msgs:
raise AssertionError("myid shoudn't be in outstanding_msgs.")
self.outstanding_msgs[myid] = future
try:
self.writer.send((myid, msg))
Expand Down
40 changes: 15 additions & 25 deletions oslo_privsep/daemon.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,17 +109,6 @@ class StdioFd(enum.IntEnum):
STDERR = 2


@enum.unique
class Message(enum.IntEnum):
"""Types of messages sent across the communication channel"""
PING = 1
PONG = 2
CALL = 3
RET = 4
ERR = 5
LOG = 6


class FailedToDropPrivileges(Exception):
pass

Expand Down Expand Up @@ -187,7 +176,7 @@ def emit(self, record):
data['msg'] = record.getMessage()
data['args'] = ()

self.channel.send((None, (Message.LOG, data)))
self.channel.send((None, (comm.Message.LOG, data)))


class _ClientChannel(comm.ClientChannel):
Expand All @@ -201,8 +190,8 @@ def __init__(self, sock, context):
def exchange_ping(self):
try:
# exchange "ready" messages
reply = self.send_recv((Message.PING.value,))
success = reply[0] == Message.PONG
reply = self.send_recv((comm.Message.PING.value,))
success = reply[0] == comm.Message.PONG
except Exception as e:
self.log.exception('Error while sending initial PING to privsep: '
'%s', e)
Expand All @@ -212,12 +201,13 @@ def exchange_ping(self):
self.log.critical(msg)
raise FailedToDropPrivileges(msg)

def remote_call(self, name, args, kwargs):
result = self.send_recv((Message.CALL.value, name, args, kwargs))
if result[0] == Message.RET:
def remote_call(self, name, args, kwargs, timeout):
result = self.send_recv((comm.Message.CALL.value, name, args, kwargs),
timeout)
if result[0] == comm.Message.RET:
# (RET, return value)
return result[1]
elif result[0] == Message.ERR:
elif result[0] == comm.Message.ERR:
# (ERR, exc_type, args)
#
# TODO(gus): see what can be done to preserve traceback
Expand All @@ -228,7 +218,7 @@ def remote_call(self, name, args, kwargs):
raise ProtocolError(_('Unexpected response: %r') % result)

def out_of_band(self, msg):
if msg[0] == Message.LOG:
if msg[0] == comm.Message.LOG:
# (LOG, LogRecord __dict__)
message = {encodeutils.safe_decode(k): v
for k, v in msg[1].items()}
Expand Down Expand Up @@ -470,11 +460,11 @@ def _process_cmd(self, msgid, cmd, *args):
:return: A tuple of the return status, optional call output, and
optional error information.
"""
if cmd == Message.PING:
return (Message.PONG.value,)
if cmd == comm.Message.PING:
return (comm.Message.PONG.value,)

try:
if cmd != Message.CALL:
if cmd != comm.Message.CALL:
raise ProtocolError(_('Unknown privsep cmd: %s') % cmd)

# Extract the callable and arguments
Expand All @@ -485,14 +475,14 @@ def _process_cmd(self, msgid, cmd, *args):
raise NameError(msg)

ret = func(*f_args, **f_kwargs)
return (Message.RET.value, ret)
return (comm.Message.RET.value, ret)
except Exception as e:
LOG.debug(
'privsep: Exception during request[%(msgid)s]: '
'%(err)s', {'msgid': msgid, 'err': e}, exc_info=True)
cls = e.__class__
cls_name = '%s.%s' % (cls.__module__, cls.__name__)
return (Message.ERR.value, cls_name, e.args)
return (comm.Message.ERR.value, cls_name, e.args)

def _create_done_callback(self, msgid):
"""Creates a future callback to receive command execution results.
Expand Down Expand Up @@ -520,7 +510,7 @@ def _call_back(result):
'%(err)s', {'msgid': msgid, 'err': e}, exc_info=True)
cls = e.__class__
cls_name = '%s.%s' % (cls.__module__, cls.__name__)
reply = (Message.ERR.value, cls_name, e.args)
reply = (comm.Message.ERR.value, cls_name, e.args)
try:
channel.send((msgid, reply))
except IOError:
Expand Down
43 changes: 43 additions & 0 deletions oslo_privsep/functional/test_daemon.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from oslo_config import fixture as config_fixture
from oslotest import base

from oslo_privsep import comm
from oslo_privsep import priv_context


Expand All @@ -30,13 +31,33 @@
capabilities=[],
)

test_context_with_timeout = priv_context.PrivContext(
__name__,
cfg_section='privsep',
pypath=__name__ + '.test_context_with_timeout',
capabilities=[],
timeout=0.03
)


@test_context.entrypoint
def sleep():
# We don't want the daemon to be able to handle these calls too fast.
time.sleep(.001)


@test_context.entrypoint_with_timeout(0.03)
def sleep_with_timeout(long_timeout=0.04):
time.sleep(long_timeout)
return 42


@test_context_with_timeout.entrypoint
def sleep_with_t_context(long_timeout=0.04):
time.sleep(long_timeout)
return 42


@test_context.entrypoint
def one():
return 1
Expand Down Expand Up @@ -65,6 +86,28 @@ def test_concurrency(self):
# Make sure the daemon is still working
self.assertEqual(1, one())

def test_entrypoint_with_timeout(self):
thread_pool_size = self.cfg_fixture.conf.privsep.thread_pool_size
for _ in range(thread_pool_size + 1):
self.assertRaises(comm.PrivsepTimeout, sleep_with_timeout)

def test_entrypoint_with_timeout_pass(self):
thread_pool_size = self.cfg_fixture.conf.privsep.thread_pool_size
for _ in range(thread_pool_size + 1):
res = sleep_with_timeout(0.01)
self.assertEqual(42, res)

def test_context_with_timeout(self):
thread_pool_size = self.cfg_fixture.conf.privsep.thread_pool_size
for _ in range(thread_pool_size + 1):
self.assertRaises(comm.PrivsepTimeout, sleep_with_t_context)

def test_context_with_timeout_pass(self):
thread_pool_size = self.cfg_fixture.conf.privsep.thread_pool_size
for _ in range(thread_pool_size + 1):
res = sleep_with_t_context(0.01)
self.assertEqual(42, res)

def test_logging(self):
logs()
self.assertIn('foo', self.log_fixture.logger.output)
Expand Down
Loading

0 comments on commit f7f3349

Please sign in to comment.