1
1
import asyncio
2
2
import socket
3
3
from functools import partial
4
+ from logging import getLogger
4
5
5
6
from oscpy .server import OSCBaseServer
6
7
7
8
9
+ logger = getLogger (__name__ )
10
+
11
+
8
12
class OSCAsyncioServer (OSCBaseServer ):
9
13
def __init__ (self , * args , ** kwargs ):
10
14
super ().__init__ (* args , ** kwargs )
@@ -13,17 +17,20 @@ def __init__(self, *args, **kwargs):
13
17
14
18
def listen (self , address = 'localhost' , port = 0 , default = False , family = 'inet' , ** kwargs ):
15
19
loop = asyncio .get_event_loop ()
16
- addr = (address , port )
20
+ if family == 'unix' :
21
+ addr = address
22
+ else :
23
+ addr = (address , port )
17
24
sock = self .get_socket (
18
25
family = socket .AF_UNIX if family == 'unix' else socket .AF_INET ,
19
26
addr = addr ,
20
27
)
21
- self .listeners [addr ] = awaitable = loop .create_datagram_endpoint (
28
+ self .listeners [( address , port or sock . getsockname ()[ 1 ])] = loop .create_datagram_endpoint (
22
29
partial (OSCProtocol , self .handle_message , sock ),
23
30
sock = sock ,
24
31
)
25
32
self .add_socket (sock , default )
26
- return awaitable
33
+ return sock
27
34
28
35
async def process (self ):
29
36
return await asyncio .gather (
@@ -36,7 +43,6 @@ async def handle_message(self, data, sender, sender_socket):
36
43
await self ._execute_callbacks (callbacks , address , values )
37
44
38
45
async def _execute_callbacks (self , callbacks_list , address , values ):
39
- print (locals ())
40
46
for cb , get_address in callbacks_list :
41
47
try :
42
48
if get_address :
@@ -45,10 +51,29 @@ async def _execute_callbacks(self, callbacks_list, address, values):
45
51
await cb (* values )
46
52
except Exception :
47
53
if self .intercept_errors :
48
- logger .error ("Unhandled exception caught in oscpy server" , exc_info = True )
54
+ logger .exception ("Unhandled exception caught in oscpy server" )
49
55
else :
50
56
raise
51
57
58
+ def stop (self , sock = None ):
59
+ """Close and remove a socket from the server's sockets.
60
+
61
+ If `sock` is None, uses the default socket for the server.
62
+
63
+ """
64
+ if not sock and self .default_socket :
65
+ sock = self .default_socket
66
+
67
+ if sock in self .sockets :
68
+ sock .close ()
69
+ self .sockets .remove (sock )
70
+ else :
71
+ raise RuntimeError ('{} is not one of my sockets!' .format (sock ))
72
+
73
+ def stop_all (self ):
74
+ for sock in self .sockets [:]:
75
+ self .stop (sock )
76
+
52
77
async def join_server (self , timeout = None ):
53
78
"""Wait for the server to exit (`terminate_server()` must have been called before).
54
79
@@ -60,13 +85,16 @@ class OSCProtocol(asyncio.DatagramProtocol):
60
85
def __init__ (self , message_handler , sock , ** kwargs ):
61
86
super ().__init__ (** kwargs )
62
87
self .message_handler = message_handler
63
- self ._socket = sock
88
+ self .socket = sock
64
89
self .loop = asyncio .get_event_loop ()
65
90
66
91
def connection_made (self , transport ):
67
92
self .transport = transport
68
93
69
94
def datagram_received (self , data , addr ):
70
95
self .loop .call_soon (
71
- lambda : asyncio .ensure_future (self .message_handler (data , addr , self ._socket ))
96
+ lambda : asyncio .ensure_future (self .message_handler (data , addr , self .socket ))
72
97
)
98
+
99
+ def getsockname (self ):
100
+ return self .socket .getsockname ()
0 commit comments