/**
 * RPC server module.
 *
 * Exposes `serve()`, which builds a WebSocket server that decodes incoming
 * packets with a codec and dispatches them to API handlers (request/response)
 * or message handlers (one-way).
 */

import type { IncomingMessage } from 'http'
import { WebSocketServer, WebSocket } from 'ws'
import type {
    ProtocolDef,
    ApiNames,
    MsgNames,
    ApiInput,
    ApiOutput,
    MsgData,
    Packet,
    PacketType,
    Connection,
} from '../types'
import { RpcError, ErrorCode } from '../types'
import { json } from '../codec/json'
import type { Codec } from '../codec/types'
import { ServerConnection } from './connection'

// ============ Types ============

/**
 * API handler function: receives the request input and the calling
 * connection, and returns (or resolves to) the response payload.
 */
type ApiHandler<TInput, TOutput, TConnData> = (
    input: TInput,
    conn: Connection<TConnData>
) => TOutput | Promise<TOutput>

/**
 * Message handler function: receives the message payload and the sending
 * connection. Nothing is sent back to the client.
 */
type MsgHandler<TData, TConnData> = (
    data: TData,
    conn: Connection<TConnData>
) => void | Promise<void>

/**
 * API handlers map: one required handler per API name of the protocol.
 */
type ApiHandlers<P extends ProtocolDef, TConnData> = {
    [K in ApiNames<P>]: ApiHandler<
        ApiInput<P, K>,
        ApiOutput<P, K>,
        TConnData
    >
}

/**
 * Message handlers map: an optional handler per message name of the protocol.
 */
type MsgHandlers<P extends ProtocolDef, TConnData> = {
    [K in MsgNames<P>]?: MsgHandler<MsgData<P, K>, TConnData>
}

/**
 * Handler shape used at the dispatch boundary, where the handler name and
 * payload come from a decoded packet and are therefore not statically typed.
 */
type UntypedHandler<TConnData> = (
    payload: unknown,
    conn: Connection<TConnData>
) => unknown

/**
 * Server options.
 */
export interface ServeOptions<P extends ProtocolDef, TConnData = unknown> {
    /** Listen port. */
    port: number

    /** API handlers. */
    api: ApiHandlers<P, TConnData>

    /** Message handlers. */
    msg?: MsgHandlers<P, TConnData>

    /**
     * Codec used to encode and decode packets.
     * @defaultValue json()
     */
    codec?: Codec

    /**
     * Factory for the initial per-connection data. When omitted, an empty
     * object is used.
     */
    createConnData?: () => TConnData

    /** Called after a connection has been registered. */
    onConnect?: (conn: Connection<TConnData>) => void | Promise<void>

    /** Called after a connection's socket has closed. */
    onDisconnect?: (conn: Connection<TConnData>, reason?: string) => void | Promise<void>

    /**
     * Error callback. Receives socket errors, packet handling errors,
     * non-`RpcError` failures thrown by API handlers, and failures thrown by
     * `onConnect` / `onDisconnect`.
     */
    onError?: (error: Error, conn?: Connection<TConnData>) => void

    /** Called once the server is listening; receives the configured port. */
    onStart?: (port: number) => void
}

/**
 * RPC server instance.
 */
export interface RpcServer<P extends ProtocolDef, TConnData = unknown> {
    /**
     * Start the server. Resolves once the server is listening and rejects if
     * the server reports an error before that (for example a busy port).
     */
    start(): Promise<void>

    /**
     * Stop the server: closes every open connection, then the listener.
     * Resolves immediately when the server is not running.
     */
    stop(): Promise<void>

    /** All currently registered connections. */
    readonly connections: ReadonlyArray<Connection<TConnData>>

    /** Send a message to a single connection. */
    send<K extends MsgNames<P>>(
        conn: Connection<TConnData>,
        name: K,
        data: MsgData<P, K>
    ): void

    /** Broadcast a message to all connections, optionally excluding some. */
    broadcast<K extends MsgNames<P>>(
        name: K,
        data: MsgData<P, K>,
        options?: { exclude?: Connection<TConnData> | Connection<TConnData>[] }
    ): void
}

// ============ Implementation ============

/** Packet type tags: the first element of every encoded packet. */
const PT = {
    ApiRequest: 0,
    ApiResponse: 1,
    ApiError: 2,
    Message: 3,
    Heartbeat: 9,
} as const

/** Normalizes a caught value so error callbacks always receive an `Error`. */
const toError = (err: unknown): Error =>
    err instanceof Error ? err : new Error(String(err))

/**
 * Looks up a handler by a name taken from a decoded (untrusted) packet.
 *
 * Returns `undefined` unless `name` is a string that resolves to a function.
 * Names that only resolve through `Object.prototype` (`constructor`,
 * `toString`, `__proto__`, ...) are rejected so a client cannot invoke
 * built-in object members as if they were handlers.
 */
const findHandler = <TConnData>(
    table: object | undefined,
    name: unknown
): UntypedHandler<TConnData> | undefined => {
    if (table === undefined || typeof name !== 'string') return undefined

    const definedOnTable = Object.prototype.hasOwnProperty.call(table, name)
    if (name in Object.prototype && !definedOnTable) return undefined

    const candidate = (table as Record<string, unknown>)[name]
    return typeof candidate === 'function'
        ? (candidate as UntypedHandler<TConnData>)
        : undefined
}

/**
 * Create an RPC server.
 *
 * @param _protocol - Protocol definition; used only to infer `P`.
 * @param options - Port, handlers, codec and lifecycle callbacks.
 * @returns A server object; nothing listens until `start()` is called.
 *
 * @example
 * ```typescript
 * const server = serve(protocol, {
 *     port: 3000,
 *     api: {
 *         join: async (input, conn) => {
 *             return { id: conn.id }
 *         },
 *     },
 * })
 * await server.start()
 * ```
 */
export function serve<P extends ProtocolDef, TConnData = unknown>(
    _protocol: P,
    options: ServeOptions<P, TConnData>
): RpcServer<P, TConnData> {
    const codec = options.codec ?? json()
    const connections: ServerConnection<TConnData>[] = []
    let wss: WebSocketServer | null = null
    let connIdCounter = 0

    /**
     * Resolves the client address: first entry of `x-forwarded-for`, then the
     * socket's remote address, then `'unknown'`.
     *
     * NOTE(review): `x-forwarded-for` is supplied by the client unless a
     * trusted proxy overwrites it, so this value must not be used for
     * authorization decisions.
     */
    const getClientIp = (req: IncomingMessage | undefined): string => {
        const forwarded = req?.headers?.['x-forwarded-for']
        const headerValue = Array.isArray(forwarded) ? forwarded[0] : forwarded
        const firstHop = headerValue?.split(',')[0]?.trim()
        // `||` on purpose: an empty header value falls through to the socket address.
        return firstHop || req?.socket?.remoteAddress || 'unknown'
    }

    /** Drops a connection from the registry; a no-op if it is already gone. */
    const removeConnection = (conn: ServerConnection<TConnData>): void => {
        const idx = connections.indexOf(conn)
        if (idx !== -1) connections.splice(idx, 1)
    }

    /**
     * Runs a lifecycle callback and routes a thrown error or rejected promise
     * to `onError`, so it never surfaces as an unhandled rejection.
     */
    const runHook = async (
        hook: () => unknown,
        conn: ServerConnection<TConnData>
    ): Promise<void> => {
        try {
            await hook()
        } catch (err) {
            options.onError?.(toError(err), conn)
        }
    }

    /**
     * Answers one API request: sends ApiResponse on success, ApiError with
     * NOT_FOUND when no handler matches, the `RpcError`'s own code/message
     * when the handler throws one, and INTERNAL_ERROR (also reported through
     * `onError`) for anything else.
     */
    const handleApiRequest = async (
        conn: ServerConnection<TConnData>,
        id: number,
        path: string,
        input: unknown
    ): Promise<void> => {
        const handler = findHandler<TConnData>(options.api, path)
        if (!handler) {
            const errPacket: Packet = [PT.ApiError, id, ErrorCode.NOT_FOUND, `API not found: ${path}`]
            conn.send(codec.encode(errPacket))
            return
        }

        try {
            const result = await handler(input, conn)
            const resPacket: Packet = [PT.ApiResponse, id, result]
            conn.send(codec.encode(resPacket))
        } catch (err) {
            if (err instanceof RpcError) {
                const errPacket: Packet = [PT.ApiError, id, err.code, err.message]
                conn.send(codec.encode(errPacket))
            } else {
                // Do not leak internal failure details to the client.
                const errPacket: Packet = [PT.ApiError, id, ErrorCode.INTERNAL_ERROR, 'Internal server error']
                conn.send(codec.encode(errPacket))
                options.onError?.(toError(err), conn)
            }
        }
    }

    /** Dispatches a one-way message; names without a handler are ignored. */
    const handleMsg = async (
        conn: ServerConnection<TConnData>,
        path: string,
        data: unknown
    ): Promise<void> => {
        const handler = findHandler<TConnData>(options.msg, path)
        if (handler) {
            await handler(data, conn)
        }
    }

    /**
     * Decodes one raw frame and dispatches it by packet type. Heartbeats are
     * echoed back, unknown packet types are ignored, and any failure is
     * reported through `onError`.
     */
    const handleMessage = async (
        conn: ServerConnection<TConnData>,
        data: string | Buffer
    ): Promise<void> => {
        try {
            const packet = codec.decode(
                typeof data === 'string' ? data : new Uint8Array(data)
            )
            const type = packet[0]

            if (type === PT.ApiRequest) {
                const [, id, path, input] = packet as [number, number, string, unknown]
                await handleApiRequest(conn, id, path, input)
            } else if (type === PT.Message) {
                const [, path, msgData] = packet as [number, string, unknown]
                await handleMsg(conn, path, msgData)
            } else if (type === PT.Heartbeat) {
                conn.send(codec.encode([PT.Heartbeat]))
            }
        } catch (err) {
            options.onError?.(toError(err), conn)
        }
    }

    /** Registers a newly accepted socket and wires up its event handlers. */
    const acceptConnection = (ws: WebSocket, req: IncomingMessage): void => {
        const id = String(++connIdCounter)
        const ip = getClientIp(req)
        const initialData = options.createConnData?.() ?? ({} as TConnData)

        const conn = new ServerConnection<TConnData>({
            id,
            ip,
            socket: ws,
            initialData,
            onClose: () => removeConnection(conn),
        })
        connections.push(conn)

        ws.on('message', (data) => {
            void handleMessage(conn, data as string | Buffer)
        })

        ws.on('close', (_code, reason) => {
            conn._markClosed()
            removeConnection(conn)
            void runHook(() => options.onDisconnect?.(conn, reason?.toString()), conn)
        })

        ws.on('error', (err) => {
            options.onError?.(err, conn)
        })

        void runHook(() => options.onConnect?.(conn), conn)
    }

    const server: RpcServer<P, TConnData> = {
        get connections() {
            return connections as ReadonlyArray<Connection<TConnData>>
        },

        start() {
            return new Promise<void>((resolve, reject) => {
                const listener = new WebSocketServer({ port: options.port })
                wss = listener

                // A failure before 'listening' (e.g. port already in use) must
                // reject instead of leaving the promise pending forever.
                const failStartup = (err: Error): void => {
                    wss = null
                    reject(err)
                }
                listener.once('error', failStartup)

                listener.on('connection', acceptConnection)

                listener.once('listening', () => {
                    // Startup is over: stop treating server errors as startup failures.
                    listener.off('error', failStartup)
                    options.onStart?.(options.port)
                    resolve()
                })
            })
        },

        stop() {
            return new Promise<void>((resolve, reject) => {
                const listener = wss
                if (!listener) {
                    resolve()
                    return
                }
                wss = null

                // Iterate over a snapshot: closing a connection may remove it
                // from `connections`, which would make a live iteration skip
                // entries.
                for (const conn of [...connections]) {
                    conn.close('Server shutting down')
                }

                listener.close((err) => {
                    if (err) reject(err)
                    else resolve()
                })
            })
        },

        send(conn, name, data) {
            const packet: Packet = [PT.Message, name as string, data]
            const target = conn as ServerConnection<TConnData>
            target.send(codec.encode(packet))
        },

        broadcast(name, data, opts) {
            const packet: Packet = [PT.Message, name as string, data]
            // Encode once and reuse the same payload for every recipient.
            const encoded = codec.encode(packet)

            const exclude = opts?.exclude
            const excluded = new Set<Connection<TConnData>>(
                exclude == null ? [] : Array.isArray(exclude) ? exclude : [exclude]
            )

            for (const conn of connections) {
                if (!excluded.has(conn)) {
                    conn.send(encoded)
                }
            }
        },
    }

    return server
}