|
| 1 | +import asyncio |
| 2 | +from aiohttp import web |
| 3 | +import json |
| 4 | + |
1 | 5 | from ten import ( |
2 | | - Extension, |
3 | | - TenEnv, |
| 6 | + AsyncExtension, |
| 7 | + AsyncTenEnv, |
4 | 8 | Cmd, |
5 | 9 | StatusCode, |
6 | 10 | CmdResult, |
7 | 11 | ) |
8 | | -from http.server import HTTPServer, BaseHTTPRequestHandler |
9 | | -import threading |
10 | | -from functools import partial |
11 | | -import re |
12 | | - |
13 | | - |
14 | | -class HTTPHandler(BaseHTTPRequestHandler): |
15 | | - def __init__(self, ten: TenEnv, *args, directory=None, **kwargs): |
16 | | - ten.log_debug(f"new handler: {directory} {args} {kwargs}") |
17 | | - self.ten = ten |
18 | | - super().__init__(*args, **kwargs) |
19 | | - |
20 | | - def do_POST(self): |
21 | | - self.ten.log_debug(f"post request incoming {self.path}") |
22 | | - |
23 | | - # match path /cmd/<cmd_name> |
24 | | - match = re.match(r"^/cmd/([^/]+)$", self.path) |
25 | | - if match: |
26 | | - cmd_name = match.group(1) |
27 | | - try: |
28 | | - content_length = int(self.headers["Content-Length"]) |
29 | | - input = self.rfile.read(content_length).decode("utf-8") |
30 | | - self.ten.log_info(f"incoming request {self.path} {input}") |
31 | | - |
32 | | - # processing by send_cmd |
33 | | - cmd_result_event = threading.Event() |
34 | | - cmd_result: CmdResult |
35 | | - |
36 | | - def cmd_callback(_, result, ten_error): |
37 | | - nonlocal cmd_result_event |
38 | | - nonlocal cmd_result |
39 | | - cmd_result = result |
40 | | - self.ten.log_info( |
41 | | - "cmd callback result: {}".format( |
42 | | - cmd_result.get_property_to_json("") |
43 | | - ) |
44 | | - ) |
45 | | - cmd_result_event.set() |
46 | | - |
47 | | - cmd = Cmd.create(cmd_name) |
48 | | - cmd.set_property_from_json("", input) |
49 | | - self.ten.send_cmd(cmd, cmd_callback) |
50 | | - event_got = cmd_result_event.wait(timeout=5) |
51 | | - |
52 | | - # return response |
53 | | - if not event_got: # timeout |
54 | | - self.send_response_only(504) |
55 | | - self.end_headers() |
56 | | - return |
57 | | - self.send_response( |
58 | | - 200 if cmd_result.get_status_code() == StatusCode.OK else 502 |
59 | | - ) |
60 | | - self.send_header("Content-Type", "application/json") |
61 | | - self.end_headers() |
62 | | - self.wfile.write( |
63 | | - cmd_result.get_property_to_json("").encode(encoding="utf_8") |
64 | | - ) |
65 | | - except Exception as e: |
66 | | - self.ten.log_warn("failed to handle request, err {}".format(e)) |
67 | | - self.send_response_only(500) |
68 | | - self.end_headers() |
69 | | - else: |
70 | | - self.ten.log_warn(f"invalid path: {self.path}") |
71 | | - self.send_response_only(404) |
72 | | - self.end_headers() |
73 | | - |
74 | | - |
75 | | -class HTTPServerExtension(Extension): |
| 12 | + |
| 13 | + |
| 14 | +class HTTPServerExtension(AsyncExtension): |
76 | 15 | def __init__(self, name: str): |
77 | 16 | super().__init__(name) |
78 | | - self.listen_addr = "127.0.0.1" |
79 | | - self.listen_port = 8888 |
80 | | - self.cmd_white_list = None |
81 | | - self.server = None |
82 | | - self.thread = None |
83 | | - |
84 | | - def on_start(self, ten: TenEnv): |
85 | | - self.listen_addr = ten.get_property_string("listen_addr") |
86 | | - self.listen_port = ten.get_property_int("listen_port") |
87 | | - """ |
88 | | - white_list = ten.get_property_string("cmd_white_list") |
89 | | - if len(white_list) > 0: |
90 | | - self.cmd_white_list = white_list.split(",") |
91 | | - """ |
92 | | - |
93 | | - ten.log_info( |
94 | | - f"on_start {self.listen_addr}:{self.listen_port}, {self.cmd_white_list}" |
95 | | - ) |
96 | | - |
97 | | - self.server = HTTPServer( |
98 | | - (self.listen_addr, self.listen_port), partial(HTTPHandler, ten) |
99 | | - ) |
100 | | - self.thread = threading.Thread(target=self.server.serve_forever) |
101 | | - self.thread.start() |
102 | | - |
103 | | - ten.on_start_done() |
104 | | - |
105 | | - def on_stop(self, ten: TenEnv): |
106 | | - self.server.shutdown() |
107 | | - self.thread.join() |
108 | | - ten.on_stop_done() |
109 | | - |
110 | | - def on_cmd(self, ten: TenEnv, cmd: Cmd): |
| 17 | + self.listen_addr: str = "127.0.0.1" |
| 18 | + self.listen_port: int = 8888 |
| 19 | + |
| 20 | + self.ten_env: AsyncTenEnv = None |
| 21 | + |
| 22 | + # http server instances |
| 23 | + self.app = web.Application() |
| 24 | + self.runner = None |
| 25 | + |
| 26 | + # POST /cmd/{cmd_name} |
| 27 | + async def handle_post_cmd(self, request): |
| 28 | + ten_env = self.ten_env |
| 29 | + |
| 30 | + try: |
| 31 | + cmd_name = request.match_info.get('cmd_name') |
| 32 | + |
| 33 | + req_json = await request.json() |
| 34 | + input = json.dumps(req_json, ensure_ascii=False) |
| 35 | + |
| 36 | + ten_env.log_debug( |
| 37 | + f"process incoming request {request.method} {request.path} {input}") |
| 38 | + |
| 39 | + cmd = Cmd.create(cmd_name) |
| 40 | + cmd.set_property_from_json("", input) |
| 41 | + [cmd_result, _] = await asyncio.wait_for(ten_env.send_cmd(cmd), 5.0) |
| 42 | + |
| 43 | + # return response |
| 44 | + status = 200 if cmd_result.get_status_code() == StatusCode.OK else 502 |
| 45 | + return web.json_response( |
| 46 | + cmd_result.get_property_to_json(""), status=status |
| 47 | + ) |
| 48 | + except json.JSONDecodeError: |
| 49 | + return web.Response(status=400) |
| 50 | + except asyncio.TimeoutError: |
| 51 | + return web.Response(status=504) |
| 52 | + except Exception as e: |
| 53 | + ten_env.log_warn( |
| 54 | + "failed to handle request with unknown exception, err {}".format(e)) |
| 55 | + return web.Response(status=500) |
| 56 | + |
| 57 | + async def on_start(self, ten_env: AsyncTenEnv): |
| 58 | + if await ten_env.is_property_exist("listen_addr"): |
| 59 | + self.listen_addr = await ten_env.get_property_string("listen_addr") |
| 60 | + if await ten_env.is_property_exist("listen_port"): |
| 61 | + self.listen_port = await ten_env.get_property_int("listen_port") |
| 62 | + self.ten_env = ten_env |
| 63 | + |
| 64 | + ten_env.log_info( |
| 65 | + f"http server listening on {self.listen_addr}:{self.listen_port}") |
| 66 | + |
| 67 | + self.app.router.add_post("/cmd/{cmd_name}", self.handle_post_cmd) |
| 68 | + self.runner = web.AppRunner(self.app) |
| 69 | + await self.runner.setup() |
| 70 | + site = web.TCPSite(self.runner, self.listen_addr, self.listen_port) |
| 71 | + await site.start() |
| 72 | + |
| 73 | + async def on_stop(self, ten_env: AsyncTenEnv): |
| 74 | + await self.runner.cleanup() |
| 75 | + self.ten_env = None |
| 76 | + |
| 77 | + async def on_cmd(self, ten_env: AsyncTenEnv, cmd: Cmd): |
111 | 78 | cmd_name = cmd.get_name() |
112 | | - ten.log_info("on_cmd {cmd_name}") |
113 | | - cmd_result = CmdResult.create(StatusCode.OK) |
114 | | - ten.return_result(cmd_result, cmd) |
| 79 | + ten_env.log_debug(f"on_cmd {cmd_name}") |
| 80 | + ten_env.return_result(CmdResult.create(StatusCode.OK), cmd) |
0 commit comments