Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions requests_unixsocket/__init__.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,25 @@
import os
import requests
import sys

from .adapters import UnixAdapter

DEFAULT_SCHEME = 'http+unix://'
DEFAULT_SCHEMES = os.getenv(
'REQUESTS_UNIXSOCKET_URL_SCHEMES',
'http+unix://,http://sock.local/'
).split(',')


class Session(requests.Session):
def __init__(self, url_scheme=DEFAULT_SCHEME, *args, **kwargs):
def __init__(self, url_schemes=DEFAULT_SCHEMES, *args, **kwargs):
super(Session, self).__init__(*args, **kwargs)
self.mount(url_scheme, UnixAdapter())
for url_scheme in url_schemes:
self.mount(url_scheme, UnixAdapter())


class monkeypatch(object):
def __init__(self, url_scheme=DEFAULT_SCHEME):
self.session = Session()
def __init__(self, url_schemes=DEFAULT_SCHEMES):
self.session = Session(url_schemes=url_schemes)
requests = self._get_global_requests_module()

# Methods to replace
Expand Down
49 changes: 43 additions & 6 deletions requests_unixsocket/adapters.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import socket

from requests.adapters import HTTPAdapter
Expand All @@ -14,6 +15,36 @@
import urllib3


def get_unix_socket(path_or_name, timeout=None, type=socket.SOCK_STREAM):
sock = socket.socket(family=socket.AF_UNIX, type=type)
if timeout:
sock.settimeout(timeout)
sock.connect(path_or_name)
return sock


def get_sock_path_and_req_path(path):
i = 1
while True:
try:
items = path.rsplit('/', i)
sock_path = items[0]
rest = items[1:]
except ValueError:
return None, None

if os.path.exists(sock_path):
return sock_path, '/' + '/'.join(rest)

# Detect abstract namespace socket, starting with `/%00`
if '/' not in sock_path[1:] and sock_path[1:4] == '%00':
return '\x00' + sock_path[4:], '/' + '/'.join(rest)

if sock_path == '':
return None, None
i += 1


# The following was adapted from some code from docker-py
# https://github.com/docker/docker-py/blob/master/docker/transport/unixconn.py
class UnixHTTPConnection(httplib.HTTPConnection, object):
Expand All @@ -35,11 +66,13 @@ def __del__(self): # base class does not have d'tor
self.sock.close()

def connect(self):
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
sock.settimeout(self.timeout)
socket_path = unquote(urlparse(self.unix_socket_url).netloc)
sock.connect(socket_path)
self.sock = sock
path = urlparse(self.unix_socket_url).path
socket_path, req_path = get_sock_path_and_req_path(path)
if not socket_path:
socket_path = urlparse(self.unix_socket_url).path
if '\x00' not in socket_path and not os.path.exists(socket_path):
socket_path = unquote(urlparse(self.unix_socket_url).netloc)
self.sock = get_unix_socket(socket_path, timeout=self.timeout)


class UnixHTTPConnectionPool(urllib3.connectionpool.HTTPConnectionPool):
Expand Down Expand Up @@ -83,7 +116,11 @@ def get_connection(self, url, proxies=None):
return pool

def request_url(self, request, proxies):
return request.path_url
sock_path, req_path = get_sock_path_and_req_path(request.path_url)
if req_path:
return req_path
else:
return request.path_url

def close(self):
self.pools.clear()
99 changes: 99 additions & 0 deletions requests_unixsocket/tests/test_requests_unixsocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,31 @@ def test_unix_domain_adapter_ok():
assert r.text == 'Hello world!'


def test_unix_domain_adapter_ok_alt_scheme():
with UnixSocketServerThread() as usock_thread:
session = requests_unixsocket.Session('http+unix://')
url = 'http+unix://unix.socket%s/path/to/page' % usock_thread.usock

for method in ['get', 'post', 'head', 'patch', 'put', 'delete',
'options']:
logger.debug('Calling session.%s(%r) ...', method, url)
r = getattr(session, method)(url)
logger.debug(
'Received response: %r with text: %r and headers: %r',
r, r.text, r.headers)
assert r.status_code == 200
assert r.headers['server'] == 'waitress'
assert r.headers['X-Transport'] == 'unix domain socket'
assert r.headers['X-Requested-Path'] == '/path/to/page'
assert r.headers['X-Socket-Path'] == usock_thread.usock
assert isinstance(r.connection, requests_unixsocket.UnixAdapter)
assert r.url.lower() == url.lower()
if method == 'head':
assert r.text == ''
else:
assert r.text == 'Hello world!'


def test_unix_domain_adapter_url_with_query_params():
with UnixSocketServerThread() as usock_thread:
session = requests_unixsocket.Session('http+unix://')
Expand Down Expand Up @@ -69,6 +94,33 @@ def test_unix_domain_adapter_url_with_query_params():
assert r.text == 'Hello world!'


def test_unix_domain_adapter_url_with_query_params_alt_scheme():
with UnixSocketServerThread() as usock_thread:
session = requests_unixsocket.Session('http+unix://')
url = ('http+unix://unix.socket%s'
'/containers/nginx/logs?timestamp=true' % usock_thread.usock)

for method in ['get', 'post', 'head', 'patch', 'put', 'delete',
'options']:
logger.debug('Calling session.%s(%r) ...', method, url)
r = getattr(session, method)(url)
logger.debug(
'Received response: %r with text: %r and headers: %r',
r, r.text, r.headers)
assert r.status_code == 200
assert r.headers['server'] == 'waitress'
assert r.headers['X-Transport'] == 'unix domain socket'
assert r.headers['X-Requested-Path'] == '/containers/nginx/logs'
assert r.headers['X-Requested-Query-String'] == 'timestamp=true'
assert r.headers['X-Socket-Path'] == usock_thread.usock
assert isinstance(r.connection, requests_unixsocket.UnixAdapter)
assert r.url.lower() == url.lower()
if method == 'head':
assert r.text == ''
else:
assert r.text == 'Hello world!'


def test_unix_domain_adapter_connection_error():
session = requests_unixsocket.Session('http+unix://')

Expand All @@ -78,6 +130,15 @@ def test_unix_domain_adapter_connection_error():
'http+unix://socket_does_not_exist/path/to/page')


def test_unix_domain_adapter_connection_error_alt_scheme():
session = requests_unixsocket.Session('http+unix://')

for method in ['get', 'post', 'head', 'patch', 'put', 'delete', 'options']:
with pytest.raises(requests.ConnectionError):
getattr(session, method)(
'http+unix://unix.socket/socket_does_not_exist/path/to/page')


def test_unix_domain_adapter_connection_proxies_error():
session = requests_unixsocket.Session('http+unix://')

Expand All @@ -90,6 +151,18 @@ def test_unix_domain_adapter_connection_proxies_error():
in str(excinfo.value))


def test_unix_domain_adapter_connection_proxies_error_alt_scheme():
session = requests_unixsocket.Session('http+unix://')

for method in ['get', 'post', 'head', 'patch', 'put', 'delete', 'options']:
with pytest.raises(ValueError) as excinfo:
getattr(session, method)(
'http+unix://unix.socket/socket_does_not_exist/path/to/page',
proxies={"http+unix": "http://10.10.1.10:1080"})
assert ('UnixAdapter does not support specifying proxies'
in str(excinfo.value))


def test_unix_domain_adapter_monkeypatch():
with UnixSocketServerThread() as usock_thread:
with requests_unixsocket.monkeypatch('http+unix://'):
Expand Down Expand Up @@ -119,3 +192,29 @@ def test_unix_domain_adapter_monkeypatch():
for method in ['get', 'post', 'head', 'patch', 'put', 'delete', 'options']:
with pytest.raises(requests.exceptions.InvalidSchema):
getattr(requests, method)(url)


def test_unix_domain_adapter_monkeypatch_alt_scheme():
with UnixSocketServerThread() as usock_thread:
with requests_unixsocket.monkeypatch():
url = 'http://sock.local/%s/path/to/page' % usock_thread.usock

for method in ['get', 'post', 'head', 'patch', 'put', 'delete',
'options']:
logger.debug('Calling session.%s(%r) ...', method, url)
r = getattr(requests, method)(url)
logger.debug(
'Received response: %r with text: %r and headers: %r',
r, r.text, r.headers)
assert r.status_code == 200
assert r.headers['server'] == 'waitress'
assert r.headers['X-Transport'] == 'unix domain socket'
assert r.headers['X-Requested-Path'] == '/path/to/page'
assert r.headers['X-Socket-Path'] == usock_thread.usock
assert isinstance(r.connection,
requests_unixsocket.UnixAdapter)
assert r.url.lower() == url.lower()
if method == 'head':
assert r.text == ''
else:
assert r.text == 'Hello world!'