Skip to content

Commit 686e2ff

Browse files
refactor: use daemon for read
1 parent 6e97a07 commit 686e2ff

File tree

2 files changed

+191
-75
lines changed

2 files changed

+191
-75
lines changed

src/websocket/moon.pkg.json

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
"moonbitlang/async/io",
44
"moonbitlang/async/socket",
55
"moonbitlang/async/internal/time",
6-
"moonbitlang/x/crypto"
6+
"moonbitlang/x/crypto",
7+
"moonbitlang/async",
8+
"moonbitlang/async/aqueue",
9+
"moonbitlang/async/semaphore"
710
]
811
}

src/websocket/server.mbt

Lines changed: 187 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
struct ServerConnection {
1818
conn : @socket.Tcp
1919
mut closed : CloseCode?
20+
out : @async.Queue[Result[Message, Error]]
21+
semaphore : @semaphore.Semaphore
2022
}
2123

2224
///|
@@ -94,7 +96,176 @@ async fn ServerConnection::handshake(conn : @socket.Tcp) -> ServerConnection {
9496
$|\r
9597
$|
9698
conn.write(@encoding/utf8.encode(response))
97-
{ conn, closed: None }
99+
{
100+
conn,
101+
closed: None,
102+
out: @aqueue.new(),
103+
semaphore: @semaphore.Semaphore::new(1),
104+
}
105+
}
106+
107+
///|
108+
/// The main read loop for the WebSocket connection
109+
///
110+
/// This does not raise any errors. Errors are communicated via the out queue.
111+
async fn ServerConnection::serve_read(self : ServerConnection) -> Unit noraise {
112+
let frames : Array[Frame] = []
113+
let mut first_opcode : OpCode? = None
114+
while self.closed is None {
115+
let frame = read_frame(self.conn) catch {
116+
e => {
117+
// On read error, close the connection and communicate the error
118+
if self.closed is None {
119+
self.closed = Some(Abnormal)
120+
}
121+
self.out.put(Err(e))
122+
return
123+
}
124+
}
125+
126+
// Handle control frames immediately
127+
match frame.opcode {
128+
Close => {
129+
// Parse close code and reason
130+
// Ref: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.1
131+
let mut close_code = Normal
132+
if frame.payload.length() >= 2 {
133+
let payload_arr = frame.payload.to_fixedarray()
134+
let code_int = (payload_arr[0].to_int() << 8) |
135+
payload_arr[1].to_int()
136+
close_code = CloseCode::from_int(code_int).unwrap_or(Normal)
137+
}
138+
// If we didn't send close first, respond with close
139+
if self.closed is None {
140+
// Echo the close frame back and close
141+
self.closed = Some(close_code)
142+
self.send_close(code=close_code) catch {
143+
_ => ()
144+
}
145+
self.out.put(Err(ConnectionClosed(close_code)))
146+
}
147+
return
148+
}
149+
Ping => {
150+
// Auto-respond to ping with pong
151+
self.pong(data=frame.payload) catch {
152+
e => {
153+
if self.closed is None {
154+
self.closed = Some(Abnormal)
155+
}
156+
self.out.put(Err(e))
157+
return
158+
}
159+
}
160+
continue
161+
}
162+
Pong =>
163+
// Ignore pong frames
164+
// TODO : track pong responses for ping timeouts
165+
continue
166+
Text =>
167+
if first_opcode is Some(_) {
168+
// Ref https://datatracker.ietf.org/doc/html/rfc6455#section-5.4
169+
// We don't have extensions, so fragments MUST NOT be interleaved
170+
self.closed = Some(ProtocolError)
171+
self.send_close(code=ProtocolError) catch {
172+
_ => ()
173+
}
174+
self.out.put(Err(ConnectionClosed(ProtocolError)))
175+
return
176+
} else if frame.fin {
177+
// Single-frame text message
178+
let text = @encoding/utf8.decode(frame.payload) catch {
179+
_ => {
180+
// Ref https://datatracker.ietf.org/doc/html/rfc6455#section-8.1
181+
// We MUST Fail the WebSocket Connection if the payload is not
182+
// valid UTF-8
183+
self.closed = Some(InvalidFramePayload)
184+
self.send_close(code=InvalidFramePayload) catch {
185+
_ => ()
186+
}
187+
self.out.put(Err(ConnectionClosed(InvalidFramePayload)))
188+
return
189+
}
190+
}
191+
let message = Message::Text(text)
192+
// Handle the complete message
193+
self.out.put(Ok(message))
194+
} else {
195+
first_opcode = Some(Text)
196+
// Start of fragmented text message
197+
frames.push(frame)
198+
}
199+
Binary =>
200+
if first_opcode is Some(_) {
201+
// Ref https://datatracker.ietf.org/doc/html/rfc6455#section-5.4
202+
// We don't have extensions, so fragments MUST NOT be interleaved
203+
self.closed = Some(ProtocolError)
204+
self.send_close(code=ProtocolError) catch {
205+
_ => ()
206+
}
207+
self.out.put(Err(ConnectionClosed(ProtocolError)))
208+
return
209+
} else if frame.fin {
210+
// Single-frame binary message
211+
let message = Message::Binary(frame.payload)
212+
// Handle the complete message
213+
self.out.put(Ok(message))
214+
} else {
215+
first_opcode = Some(Binary)
216+
// Start of fragmented binary message
217+
frames.push(frame)
218+
}
219+
Continuation => {
220+
if first_opcode is None {
221+
// Continuation frame without a starting frame
222+
self.closed = Some(ProtocolError)
223+
self.send_close(code=ProtocolError) catch {
224+
_ => ()
225+
}
226+
self.out.put(Err(ConnectionClosed(ProtocolError)))
227+
}
228+
frames.push(frame)
229+
if frame.fin {
230+
// Final fragment received, assemble message
231+
let total_size = frames.fold(init=0, fn(acc, f) {
232+
acc + f.payload.length()
233+
})
234+
let data = FixedArray::make(total_size, b'\x00')
235+
let mut offset = 0
236+
for f in frames {
237+
data.blit_from_bytes(offset, f.payload, 0, f.payload.length())
238+
offset += f.payload.length()
239+
}
240+
let message_data = data.unsafe_reinterpret_as_bytes()
241+
match first_opcode {
242+
Some(Text) => {
243+
let text = @encoding/utf8.decode(message_data) catch {
244+
_ => {
245+
self.closed = Some(InvalidFramePayload)
246+
self.send_close(code=InvalidFramePayload) catch {
247+
_ => ()
248+
}
249+
self.out.put(Err(ConnectionClosed(InvalidFramePayload)))
250+
return
251+
}
252+
}
253+
let message = Message::Text(text)
254+
self.out.put(Ok(message))
255+
}
256+
Some(Binary) => {
257+
let message = Message::Binary(message_data)
258+
self.out.put(Ok(message))
259+
}
260+
_ => panic()
261+
}
262+
// Reset for next message
263+
frames.clear()
264+
first_opcode = None
265+
}
266+
}
267+
}
268+
}
98269
}
99270

100271
///|
@@ -115,6 +286,8 @@ pub async fn ServerConnection::send_text(
115286
if self.closed is Some(code) {
116287
raise ConnectionClosed(code)
117288
}
289+
self.semaphore.acquire()
290+
defer self.semaphore.release()
118291
let payload = @encoding/utf8.encode(text)
119292
write_frame(self.conn, true, OpCode::Text, payload, false)
120293
}
@@ -128,6 +301,8 @@ pub async fn ServerConnection::send_binary(
128301
if self.closed is Some(code) {
129302
raise ConnectionClosed(code)
130303
}
304+
self.semaphore.acquire()
305+
defer self.semaphore.release()
131306
write_frame(self.conn, true, OpCode::Binary, data, false)
132307
}
133308

@@ -143,6 +318,8 @@ async fn ServerConnection::_ping(
143318
if self.closed is Some(code) {
144319
raise ConnectionClosed(code)
145320
}
321+
self.semaphore.acquire()
322+
defer self.semaphore.release()
146323
write_frame(self.conn, true, OpCode::Ping, data, false)
147324
}
148325

@@ -157,6 +334,8 @@ async fn ServerConnection::pong(
157334
if self.closed is Some(code) {
158335
raise ConnectionClosed(code)
159336
}
337+
self.semaphore.acquire()
338+
defer self.semaphore.release()
160339
write_frame(self.conn, true, OpCode::Pong, data, false)
161340
}
162341

@@ -182,6 +361,8 @@ pub async fn ServerConnection::send_close(
182361
if reason != "" {
183362
payload.blit_from_bytesview(2, reason)
184363
}
364+
self.semaphore.acquire()
365+
defer self.semaphore.release()
185366
write_frame(
186367
self.conn,
187368
true,
@@ -199,78 +380,7 @@ pub async fn ServerConnection::receive(self : ServerConnection) -> Message {
199380
if self.closed is Some(code) {
200381
raise ConnectionClosed(code)
201382
}
202-
let frames : Array[Frame] = []
203-
let mut first_opcode : OpCode? = None
204-
for {
205-
let frame = read_frame(self.conn)
206-
207-
// Handle control frames immediately
208-
match frame.opcode {
209-
OpCode::Close => {
210-
// Parse close code and reason
211-
let mut close_code = Normal
212-
if frame.payload.length() >= 2 {
213-
let payload_arr = frame.payload.to_fixedarray()
214-
let code_int = (payload_arr[0].to_int() << 8) |
215-
payload_arr[1].to_int()
216-
close_code = CloseCode::from_int(code_int).unwrap_or(Normal)
217-
if frame.payload.length() > 2 {
218-
// As per spec https://datatracker.ietf.org/doc/html/rfc6455#autoid-27
219-
// The data is not guaranteed to be human readable
220-
// So we do not decode it here
221-
// And we are not using it further
222-
let _reason_bytes = payload_arr.unsafe_reinterpret_as_bytes()[2:]
223-
224-
}
225-
}
226-
// If we didn't send close first, respond with close
227-
if self.closed is None {
228-
// Echo the close frame back and close
229-
self.send_close(code=close_code)
230-
self.closed = Some(close_code)
231-
}
232-
raise ConnectionClosed(close_code)
233-
}
234-
OpCode::Ping => {
235-
// Auto-respond to ping with pong
236-
self.pong(data=frame.payload)
237-
continue
238-
}
239-
OpCode::Pong =>
240-
// Ignore pong frames
241-
continue
242-
_ => ()
243-
}
244-
245-
// Track the first opcode for message type
246-
if first_opcode is None {
247-
first_opcode = Some(frame.opcode)
248-
}
249-
frames.push(frame)
250-
251-
// If this is the final frame, assemble the message
252-
if frame.fin {
253-
break
254-
}
255-
}
256-
257-
// Assemble message from frames
258-
let total_size = frames.fold(init=0, fn(acc, f) { acc + f.payload.length() })
259-
let data = FixedArray::make(total_size, b'\x00')
260-
let mut offset = 0
261-
for frame in frames {
262-
let payload_arr = frame.payload.to_fixedarray()
263-
for i = 0; i < payload_arr.length(); i = i + 1 {
264-
data[offset + i] = payload_arr[i]
265-
}
266-
offset += payload_arr.length()
267-
}
268-
let message_data = data.unsafe_reinterpret_as_bytes()
269-
match first_opcode {
270-
Some(OpCode::Text) => Text(@encoding/utf8.decode_lossy(message_data))
271-
Some(OpCode::Binary) => Binary(message_data)
272-
_ => Binary(message_data)
273-
} // Default to binary
383+
self.out.get().unwrap_or_error()
274384
}
275385

276386
///|
@@ -311,7 +421,10 @@ pub async fn run_server(
311421
raise e
312422
}
313423
}
314-
f(ws_conn, client_addr)
424+
@async.with_task_group(taskgroup => {
425+
taskgroup.spawn_bg(() => f(ws_conn, client_addr))
426+
taskgroup.spawn_bg(() => ws_conn.serve_read())
427+
})
315428
},
316429
allow_failure?,
317430
max_connections?,

0 commit comments

Comments
 (0)