diff --git a/packages/socket.io/lib/adapter.ts b/packages/socket.io/lib/adapter.ts index 921b97e..48ec108 100644 --- a/packages/socket.io/lib/adapter.ts +++ b/packages/socket.io/lib/adapter.ts @@ -4,11 +4,17 @@ import { type Namespace } from "./namespace.ts"; import { type Packet } from "../../socket.io-parser/mod.ts"; import { generateId } from "../../engine.io/mod.ts"; import { getLogger } from "../../../deps.ts"; +import { yeast } from "./contrib/yeast.ts"; const DEFAULT_TIMEOUT_MS = 5000; export type SocketId = string; export type Room = string | number; +/** + * A private ID, sent by the server at the beginning of the Socket.IO session and used for connection state recovery + * upon reconnection + */ +export type PrivateSessionId = string | undefined; export interface BroadcastOptions { rooms: Set; @@ -23,6 +29,16 @@ export interface BroadcastFlags { timeout?: number; } + +interface SessionToPersist { + sid: SocketId; + pid: PrivateSessionId; + rooms: Room[]; + data: unknown; +} + +export type Session = SessionToPersist & { missedPackets: unknown[][] }; + interface AdapterEvents { "create-room": (room: Room) => void; "delete-room": (room: Room) => void; @@ -31,7 +47,7 @@ interface AdapterEvents { "error": (err: Error) => void; } -export class InMemoryAdapter extends EventEmitter< +export abstract class InMemoryAdapter extends EventEmitter< Record, Record, AdapterEvents @@ -375,7 +391,7 @@ function serializeSocket(socket: Socket) { }; } -export abstract class Adapter extends InMemoryAdapter { +export class Adapter extends InMemoryAdapter { protected readonly uid: string; #pendingRequests = new Map< @@ -388,7 +404,7 @@ export abstract class Adapter extends InMemoryAdapter { AckRequest >(); - protected constructor(nsp: Namespace) { + constructor(nsp: Namespace) { super(nsp); this.uid = generateId(); } @@ -399,12 +415,12 @@ export abstract class Adapter extends InMemoryAdapter { * @param request * @protected */ - protected abstract publishRequest(request: ClusterRequest): void; + protected publishRequest(_request: ClusterRequest): void { } - protected abstract publishResponse( - requesterUid: string, - response: ClusterResponse, - ): void; + protected publishResponse( + _requesterUid: string, + _response: ClusterResponse, + ): void { } override addSockets(opts: BroadcastOptions, rooms: Room[]) { super.addSockets(opts, rooms); @@ -867,4 +883,133 @@ export abstract class Adapter extends InMemoryAdapter { break; } } + + /** + * Save the client session in order to restore it upon reconnection. + */ + public persistSession(_session: SessionToPersist) { } + + /** + * Restore the session and find the packets that were missed by the client. + * @param pid + * @param offset + */ + public restoreSession( + _pid: PrivateSessionId, + _offset: string + ): Promise { + return Promise.resolve(null); + } +} + +interface PersistedPacket { + id: string; + emittedAt: number; + data: unknown[]; + opts: BroadcastOptions; } + +type SessionWithTimestamp = SessionToPersist & { disconnectedAt: number }; + +export class SessionAwareAdapter extends Adapter { + private readonly maxDisconnectionDuration: number; + + private sessions: Map = new Map(); + private packets: PersistedPacket[] = []; + + constructor(readonly nsp: Namespace) { + super(nsp); + // FIXME: Add conditional typing for server options + this.maxDisconnectionDuration = nsp._server.opts.connectionStateRecovery?.maxDisconnectionDuration || 2 * 60 * 1000; + getLogger("socket.io").debug(`[adapter] Create a session persist adapter`); + const timerId = setInterval(() => { + const threshold = Date.now() - this.maxDisconnectionDuration; + this.sessions.forEach((session, sessionId) => { + const hasExpired = session.disconnectedAt < threshold; + if (hasExpired) { + this.sessions.delete(sessionId); + } + }); + for (let i = this.packets.length - 1; i >= 0; i--) { + const hasExpired = this.packets[i].emittedAt < threshold; + if (hasExpired) { + this.packets.splice(0, i + 1); + break; + } + } + }, 60 * 1000); + // prevents the timer from keeping the process alive + clearTimeout(timerId) + } + + override persistSession(session: SessionToPersist) { + (session as SessionWithTimestamp).disconnectedAt = Date.now(); + this.sessions.set(session.pid, session as SessionWithTimestamp); + } + + override restoreSession( + pid: PrivateSessionId, + offset: string + ): Promise { + const session = this.sessions.get(pid); + if (!session) { + // the session may have expired + return Promise.resolve(null); + } + const hasExpired = + session.disconnectedAt + this.maxDisconnectionDuration < Date.now(); + if (hasExpired) { + // the session has expired + this.sessions.delete(pid); + return Promise.resolve(null); + } + const index = this.packets.findIndex((packet) => packet.id === offset); + if (index === -1) { + // the offset may be too old + return Promise.resolve(null); + } + const missedPackets = []; + for (let i = index + 1; i < this.packets.length; i++) { + const packet = this.packets[i]; + if (shouldIncludePacket(session.rooms, packet.opts)) { + missedPackets.push(packet.data); + } + } + return Promise.resolve({ + ...session, + missedPackets, + }); + } + + override broadcast(packet: any, opts: BroadcastOptions) { + const isEventPacket = packet.type === 2; + // packets with acknowledgement are not stored because the acknowledgement function cannot be serialized and + // restored on another server upon reconnection + const withoutAcknowledgement = packet.id === undefined; + const notVolatile = opts.flags?.volatile === undefined; + if (isEventPacket && withoutAcknowledgement && notVolatile) { + const id = yeast(); + // the offset is stored at the end of the data array, so the client knows the ID of the last packet it has + // processed (and the format is backward-compatible) + packet.data.push(id); + this.packets.push({ + id, + opts, + data: packet.data, + emittedAt: Date.now(), + }); + } + super.broadcast(packet, opts); + } +} + + +function shouldIncludePacket( + sessionRooms: Room[], + opts: BroadcastOptions +): boolean { + const included = + opts.rooms.size === 0 || sessionRooms.some((room) => opts.rooms.has(room)); + const notExcluded = sessionRooms.every((room) => !opts.except.has(room)); + return included && notExcluded; +} \ No newline at end of file diff --git a/packages/socket.io/lib/broadcast-operator.ts b/packages/socket.io/lib/broadcast-operator.ts index 9801147..6ffb2f6 100644 --- a/packages/socket.io/lib/broadcast-operator.ts +++ b/packages/socket.io/lib/broadcast-operator.ts @@ -287,7 +287,7 @@ export class BroadcastOperator * socket.disconnect(); * } */ - public fetchSockets(): Promise< + public fetchSockets>(): Promise< RemoteSocket[] > { return this.adapter diff --git a/packages/socket.io/lib/client.ts b/packages/socket.io/lib/client.ts index 15cdebb..8d437d0 100644 --- a/packages/socket.io/lib/client.ts +++ b/packages/socket.io/lib/client.ts @@ -18,7 +18,7 @@ export class Client< ListenEvents extends EventsMap, EmitEvents extends EventsMap, ServerSideEvents extends EventsMap, - SocketData = unknown, + SocketData = Record, > { public readonly conn: RawSocket; diff --git a/packages/socket.io/lib/contrib/yeast.ts b/packages/socket.io/lib/contrib/yeast.ts new file mode 100644 index 0000000..6cc5817 --- /dev/null +++ b/packages/socket.io/lib/contrib/yeast.ts @@ -0,0 +1,65 @@ +// imported from https://github.com/unshiftio/yeast +"use strict"; + +const alphabet = + "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz-_".split( + "" + ), + length = 64, + map: Record = {}; +let seed = 0, + i = 0, + prev: string; + +/** + * Return a string representing the specified number. + * + * @param {Number} num The number to convert. + * @returns {String} The string representation of the number. + * @api public + */ +export function encode(num: number) { + let encoded = ""; + + do { + encoded = alphabet[num % length] + encoded; + num = Math.floor(num / length); + } while (num > 0); + + return encoded; +} + +/** + * Return the integer value specified by the given string. + * + * @param {String} str The string to convert. + * @returns {Number} The integer value represented by the string. + * @api public + */ +export function decode(str: string) { + let decoded = 0; + + for (i = 0; i < str.length; i++) { + decoded = decoded * length + map[str.charAt(i)]; + } + + return decoded; +} + +/** + * Yeast: A tiny growing id generator. + * + * @returns {String} A unique id. + * @api public + */ +export function yeast() { + const now = encode(+new Date()); + + if (now !== prev) return (seed = 0), (prev = now); + return now + "." + encode(seed++); +} + +// +// Map each character to its index. +// +for (; i < length; i++) map[alphabet[i]] = i; diff --git a/packages/socket.io/lib/namespace.ts b/packages/socket.io/lib/namespace.ts index 6005c3c..8fa425f 100644 --- a/packages/socket.io/lib/namespace.ts +++ b/packages/socket.io/lib/namespace.ts @@ -17,15 +17,18 @@ export interface NamespaceReservedEvents< EmitEvents extends EventsMap, ServerSideEvents extends EventsMap, SocketData, -> { + > { + connect: ( + socket: Socket + ) => void; connection: ( - socket: Socket, + socket: Socket ) => void; } export const RESERVED_EVENTS: ReadonlySet = new Set< keyof ServerReservedEvents ->(["connection", "new_namespace"] as const); + >(["connect", "connection", "new_namespace"] as const); /** * A Namespace is a communication channel that allows you to split the logic of your application over a single shared @@ -84,7 +87,7 @@ export class Namespace< ListenEvents extends EventsMap = DefaultEventsMap, EmitEvents extends EventsMap = DefaultEventsMap, ServerSideEvents extends EventsMap = DefaultEventsMap, - SocketData = unknown, + SocketData = Record, > extends EventEmitter< ServerSideEvents, EmitEvents, @@ -248,12 +251,15 @@ export class Namespace< getLogger("socket.io").debug( `[namespace] adding socket to nsp ${this.name}`, ); - const socket = new Socket< - ListenEvents, - EmitEvents, - ServerSideEvents, - SocketData - >(this, client, handshake); + const socket = await this._createSocket(client, handshake); + + if ( + this._server.opts.connectionStateRecovery?.skipMiddlewares && + socket.recovered && + client.conn.readyState === "open" + ) { + return this._doConnect(socket, callback); + } try { await this.run(socket); @@ -288,9 +294,68 @@ export class Namespace< callback(socket); // fire user-set events + this.emitReserved("connect", socket); this.emitReserved("connection", socket); } + private async _createSocket ( + client: Client, + handshake: Handshake + ): Promise> { + const sessionId = handshake.auth.pid; + const offset = handshake.auth.offset; + let session; + if ( + this._server.opts.connectionStateRecovery && + typeof sessionId === "string" && + typeof offset === "string" + ) { + try { + session = await this.adapter.restoreSession(sessionId, offset); + } catch (e) { + getLogger("socket.io").debug("error while restoring session: %s", e); + } + } + if (session) { + getLogger("socket.io").debug("connection state recovered for sid %s", session.sid); + return new Socket(this, client, handshake, session); + } else return new Socket(this, client, handshake); + } + + private async _doConnect( + socket: Socket, + fn: ( + socket: Socket + ) => void + ) { + try { + await this.run(socket); + } catch (err) { + getLogger("socket.io").debug( + "[namespace] middleware error, sending CONNECT_ERROR packet to the client", + ); + socket._cleanup(); + return socket._error({ + message: err.message || err, + data: err.data, + }); + } + + // track socket + this.sockets.set(socket.id, socket); + + // it's paramount that the internal `onconnect` logic + // fires before user-set events to prevent state order + // violations (such as a disconnection before the connection + // logic is complete) + socket._onconnect(); + if (fn) fn(socket); + + // fire user-set events + this.emitReserved("connect", socket); + this.emitReserved("connection", socket); + } + /** * Removes a client. Called by each `Socket`. * diff --git a/packages/socket.io/lib/parent-namespace.ts b/packages/socket.io/lib/parent-namespace.ts index 7517dfc..2efe50e 100644 --- a/packages/socket.io/lib/parent-namespace.ts +++ b/packages/socket.io/lib/parent-namespace.ts @@ -11,7 +11,7 @@ export class ParentNamespace< ListenEvents extends EventsMap = DefaultEventsMap, EmitEvents extends EventsMap = DefaultEventsMap, ServerSideEvents extends EventsMap = DefaultEventsMap, - SocketData = unknown, + SocketData = Record, > extends Namespace { private static count = 0; diff --git a/packages/socket.io/lib/server.ts b/packages/socket.io/lib/server.ts index b4da843..5d6f91a 100644 --- a/packages/socket.io/lib/server.ts +++ b/packages/socket.io/lib/server.ts @@ -15,7 +15,7 @@ import { Decoder, Encoder } from "../../socket.io-parser/mod.ts"; import { Namespace, NamespaceReservedEvents } from "./namespace.ts"; import { ParentNamespace } from "./parent-namespace.ts"; import { Socket } from "./socket.ts"; -import { Adapter, InMemoryAdapter, Room } from "./adapter.ts"; +import { Adapter, SessionAwareAdapter, Room } from "./adapter.ts"; import { BroadcastOperator, RemoteSocket } from "./broadcast-operator.ts"; export interface ServerOptions { @@ -29,6 +29,25 @@ export interface ServerOptions { * @default 45000 */ connectTimeout: number; + /** + * Whether to enable the recovery of connection state when a client temporarily disconnects. + * + * The connection state includes the missed packets, the rooms the socket was in and the `data` attribute. + */ + connectionStateRecovery?: { + /** + * The backup duration of the sessions and the packets. + * + * @default 120000 (2 minutes) + */ + maxDisconnectionDuration?: number; + /** + * Whether to skip middlewares upon successful connection state recovery. + * + * @default true + */ + skipMiddlewares?: boolean; + }; /** * The parser to use to encode and decode packets */ @@ -45,9 +64,9 @@ export interface ServerOptions { } export interface ServerReservedEvents< - ListenEvents, - EmitEvents, - ServerSideEvents, + ListenEvents extends EventsMap, + EmitEvents extends EventsMap, + ServerSideEvents extends EventsMap, SocketData, > extends NamespaceReservedEvents< @@ -104,7 +123,7 @@ export class Server< ListenEvents extends EventsMap = DefaultEventsMap, EmitEvents extends EventsMap = ListenEvents, ServerSideEvents extends EventsMap = DefaultEventsMap, - SocketData = unknown, + SocketData = Record, > extends EventEmitter< ListenEvents, EmitEvents, @@ -139,7 +158,7 @@ export class Server< constructor(opts: Partial = {}) { super(); - this.opts = Object.assign({ + opts = Object.assign({ path: "/socket.io/", connectTimeout: 45_000, parser: { @@ -149,12 +168,28 @@ export class Server< createDecoder() { return new Decoder(); }, - }, - adapter: ( - nsp: Namespace, - ) => new InMemoryAdapter(nsp), + } }, opts); + if (opts.connectionStateRecovery != null) { + opts.connectionStateRecovery = Object.assign( + { + maxDisconnectionDuration: 2 * 60 * 1000, + skipMiddlewares: true, + }, + opts.connectionStateRecovery + ) + opts.adapter = ( + nsp: Namespace, + ) => new SessionAwareAdapter(nsp) + } else { + opts.adapter = ( + nsp: Namespace, + ) => new Adapter(nsp) + } + + this.opts = opts as ServerOptions + this.engine = new Engine(this.opts); this.engine.on("connection", (conn, req, connInfo) => { diff --git a/packages/socket.io/lib/socket.ts b/packages/socket.io/lib/socket.ts index 0dd4fb6..b665939 100644 --- a/packages/socket.io/lib/socket.ts +++ b/packages/socket.io/lib/socket.ts @@ -7,7 +7,12 @@ import { EventParams, EventsMap, } from "../../event-emitter/mod.ts"; -import { Adapter, BroadcastFlags, Room, SocketId } from "./adapter.ts"; +import { Adapter, + BroadcastFlags, + Room, + SocketId, + Session, + PrivateSessionId } from "./adapter.ts"; import { generateId } from "../../engine.io/mod.ts"; import { Namespace } from "./namespace.ts"; import { Client } from "./client.ts"; @@ -23,9 +28,20 @@ type DisconnectReason = | "ping timeout" | "parse error" // Socket.IO disconnect reasons + | "server shutting down" + | "forced server close" | "client namespace disconnect" | "server namespace disconnect"; +const RECOVERABLE_DISCONNECT_REASONS: ReadonlySet = new Set([ + "transport error", + "transport close", + "forced close", + "ping timeout", + "server shutting down", + "forced server close", +]); + export interface SocketReservedEvents { disconnect: (reason: DisconnectReason) => void; disconnecting: (reason: DisconnectReason) => void; @@ -143,12 +159,13 @@ export type Event = [string, ...unknown[]]; * }); * }); */ + export class Socket< ListenEvents extends EventsMap = DefaultEventsMap, - EmitEvents extends EventsMap = DefaultEventsMap, + EmitEvents extends EventsMap = ListenEvents, ServerSideEvents extends EventsMap = DefaultEventsMap, - SocketData = unknown, -> extends EventEmitter< + SocketData = Record + > extends EventEmitter< ListenEvents, EmitEvents, SocketReservedEvents @@ -157,6 +174,11 @@ export class Socket< * An unique identifier for the session. */ public readonly id: SocketId; + /** + * Whether the connection state was recovered after a temporary disconnection. In that case, any missed packets will + * be transmitted to the client, the data attribute and the rooms will be restored. + */ + public recovered = false; /** * The handshake details. */ @@ -166,7 +188,6 @@ export class Socket< * {@link Server.fetchSockets()} method. */ public data: Partial = {}; - /** * Whether the socket is currently connected or not. * @@ -180,8 +201,24 @@ export class Socket< * }); */ public connected = false; - - private readonly nsp: Namespace< + /** + * The session ID, which must not be shared (unlike {@link id}). + * + * @private + */ + private readonly pid: PrivateSessionId; + /** + * Namespace identifier . + * + * @example + * const dynamicNsp = io.of(/^\/dynamic-\d+$/).on("connection", (socket) => { + * const newNamespace = socket.nsp; // newNamespace.name === "/dynamic-101" + * + * // broadcast to all clients in the given sub-namespace + * newNamespace.emit("hello"); + * }); + */ + readonly nsp: Namespace< ListenEvents, EmitEvents, ServerSideEvents, @@ -207,12 +244,30 @@ export class Socket< nsp: Namespace, client: Client, handshake: Handshake, + previousSession?: Session ) { super(); this.nsp = nsp; - this.id = generateId(); this.client = client; this.adapter = nsp.adapter; + if (previousSession) { + this.id = previousSession.sid; + this.pid = previousSession.pid; + previousSession.rooms.forEach((room) => this.join(room)); + this.data = previousSession.data as Partial; + previousSession.missedPackets.forEach((packet) => { + this.packet({ + type: PacketType.EVENT, + data: packet, + }); + }); + this.recovered = true; + } else { + this.id = generateId(); // don't reuse the Engine.IO id because it's sensitive information + if (this.nsp._server.opts.connectionStateRecovery) { + this.pid = generateId(); + } + } this.handshake = handshake; } @@ -261,12 +316,20 @@ export class Socket< const flags = Object.assign({}, this.flags); this.flags = {}; - - if (this.connected) { - this._notifyOutgoingListeners(packet.data); - this.packet(packet, flags); + if (this.nsp._server.opts.connectionStateRecovery) { + // this ensures the packet is stored and can be transmitted upon reconnection + this.adapter.broadcast(packet, { + rooms: new Set([this.id]), + except: new Set(), + flags, + }); } else { - this.#preConnectBuffer.push(packet); + if (this.connected) { + this._notifyOutgoingListeners(packet.data); + this.packet(packet, flags); + } else { + this.#preConnectBuffer.push(packet); + } } return true; @@ -475,6 +538,19 @@ export class Socket< if (!this.connected) return this; getLogger("socket.io").debug(`[socket] closing socket - reason ${reason}`); this.emitReserved("disconnecting", reason); + + if (RECOVERABLE_DISCONNECT_REASONS.has(reason)) { + getLogger("socket.io").debug(`connection state recovery is enabled for sid ${this.id}`); + getLogger("socket.io").debug(`SID: ${this.id}`); + getLogger("socket.io").debug(`PID: ${this.pid}`); + this.adapter.persistSession({ + sid: this.id, + pid: this.pid, + rooms: [...this.rooms], + data: this.data, + }); + } + this._cleanup(); this.nsp._remove(this); this.client._remove(this); @@ -611,7 +687,7 @@ export class Socket< getLogger("socket.io").debug("[socket] socket connected - writing packet"); this.connected = true; this.join(this.id); - this.packet({ type: PacketType.CONNECT, data: { sid: this.id } }); + this.packet({ type: PacketType.CONNECT, data: { sid: this.id, pid: this.pid } }); this.#preConnectBuffer.forEach((packet) => { this._notifyOutgoingListeners(packet.data); this.packet(packet);