diff --git a/__tests__/e2e/rpc-context.test.ts b/__tests__/e2e/rpc-context.test.ts new file mode 100644 index 0000000..98f2a99 --- /dev/null +++ b/__tests__/e2e/rpc-context.test.ts @@ -0,0 +1,41 @@ +import { getRPCContext, RPCContextKey } from "@/core/RPCContext"; +import { RPCHandler, RPCSession } from "@/index" +import { getRandomAvailablePort } from "@/utils/utils"; + +describe('rpc-context.test', () => { + test('context.session', async () => { + const accessed = new WeakMap(); + const provider = { + login(name: string) { + const context = getRPCContext(this); + if (!context) { + throw new Error('context is null'); + } + + accessed.set(context.session, name); + return name; + }, + me() { + const context = getRPCContext(this); + if (!context) { + throw new Error('context is null'); + } + + return accessed.get(context.session) || 'unknown user'; + }, + } + + const server = new RPCHandler(); + const client = new RPCHandler(); + + server.setProvider(provider); + const port = await getRandomAvailablePort(); + await server.listen({ port }); + + const session = await client.connect({ url: `http://localhost:${port}` }); + const api = session.getAPI(); + const clientname = 'clientname'; + await expect(api.login(clientname)).resolves.toBe(clientname); + await expect(api.me()).resolves.toBe(clientname); + }) +}) \ No newline at end of file diff --git a/__tests__/e2e/rpc-disconnected.test.ts b/__tests__/e2e/rpc-disconnected.test.ts index e94ea73..e20fc5f 100644 --- a/__tests__/e2e/rpc-disconnected.test.ts +++ b/__tests__/e2e/rpc-disconnected.test.ts @@ -33,8 +33,11 @@ describe('Rpc disconnected test', () => { errorCode: RPCErrorCode.CONNECTION_DISCONNECTED }) ); - await expect(callPromise).rejects.toBeInstanceOf(RPCError); - await expect(callPromise).rejects - .toHaveProperty('errorCode', RPCErrorCode.CONNECTION_DISCONNECTED); + await expect(callPromise).rejects.toMatchObject( + expect.objectContaining({ + constructor: RPCError, + errorCode: RPCErrorCode.CONNECTION_DISCONNECTED + }) + ); }) }) diff --git a/__tests__/e2e/rpc-plugin/rpc-plugin.ctx.session.test.ts b/__tests__/e2e/rpc-plugin/rpc-plugin.ctx.session.test.ts new file mode 100644 index 0000000..5ea4736 --- /dev/null +++ b/__tests__/e2e/rpc-plugin/rpc-plugin.ctx.session.test.ts @@ -0,0 +1,49 @@ +import { RPCError } from "@/core/RPCError"; +import { CallIncomingBeforeCtx, NormalMethodReturn } from "@/core/RPCPlugin"; +import { AbstractRPCPlugin, RPCHandler } from "@/index"; +import { getRandomAvailablePort, isObject } from "@/utils/utils"; + +describe('rpc-plugin.ctx.session.test', () => { + const userInfo = 'userinfo'; + const users = new WeakMap(); + class RPCTestPlugin implements AbstractRPCPlugin { + onCallIncomingBefore(ctx: CallIncomingBeforeCtx): NormalMethodReturn { + if (users.has(ctx.session)) { + return; + } + + if (isObject(ctx.request)) { + if (ctx.request.fnPath === 'login') { + users.set(ctx.session, userInfo); + return; + } + } + + throw new RPCError({ + reason: 'not login' + }); + } + } + + test('session', async () => { + const server = new RPCHandler(); + const loginRes = 'login'; + const authRes = 'auth'; + const provider = { + login() { return loginRes }, + auth() { return authRes } + } + server.setProvider(provider); + server.loadPlugin(new RPCTestPlugin()); + const client = new RPCHandler(); + + const port = await getRandomAvailablePort(); + await server.listen({ port }); + + const session = await client.connect({ url: `http://localhost:${port}` }); + const api = session.getAPI(); + await expect(api.auth()).rejects.toMatchObject({ reason: 'not login' }); + await expect(api.login()).resolves.toBe(loginRes); + await expect(api.auth()).resolves.toBe(authRes); + }); +}) \ No newline at end of file diff --git a/__tests__/e2e/rpc-plugin/rpc-plugin.hook.test.ts b/__tests__/e2e/rpc-plugin/rpc-plugin.hook.test.ts new file mode 100644 index 0000000..b54ee49 --- /dev/null +++ b/__tests__/e2e/rpc-plugin/rpc-plugin.hook.test.ts @@ -0,0 +1,42 @@ +import { CallIncomingBeforeCtx, CallIncomingCtx, CallOutgoingBeforeCtx, CallOutgoingCtx, NormalMethodReturn } from "@/core/RPCPlugin"; +import { AbstractRPCPlugin, RPCHandler } from "@/index"; +import { getRandomAvailablePort } from "@/utils/utils"; + +describe('rpc-plugin.hook.test', () => { + let count = 0; + class RPCTestPlugin implements AbstractRPCPlugin { + onCallIncomingBefore(ctx: CallIncomingBeforeCtx): NormalMethodReturn { + count += 1; + } + onCallIncoming(ctx: CallIncomingCtx): NormalMethodReturn { + count += 2; + } + onCallOutgoingBefore(ctx: CallOutgoingBeforeCtx): NormalMethodReturn { + count += 4; + } + onCallOutgoing(ctx: CallOutgoingCtx): NormalMethodReturn { + count += 8; + } + } + const CountShouldBe = 15; + + test('count', async () => { + const server = new RPCHandler(); + const provider = { + add(a: number, b: number) { return a + b } + } + server.setProvider(provider); + server.loadPlugin(new RPCTestPlugin()); + const client = new RPCHandler(); + client.loadPlugin(new RPCTestPlugin()); + + const port = await getRandomAvailablePort(); + await server.listen({ port }); + + + const session = await client.connect({ url: `http://localhost:${port}` }); + const api = session.getAPI(); + await expect(api.add(1, 2)).resolves.toBe(3); + expect(count).toBe(CountShouldBe); + }); +}) \ No newline at end of file diff --git a/__tests__/unit/core/RPCPlugin.test.ts b/__tests__/unit/core/RPCPlugin.test.ts new file mode 100644 index 0000000..5fb797d --- /dev/null +++ b/__tests__/unit/core/RPCPlugin.test.ts @@ -0,0 +1,43 @@ +import { createHookRunner, RPCPlugin } from "@/core/RPCPlugin" +import type { CallIncomingCtx } from "@/core/RPCPlugin" + +describe('RPCPlugin.test', () => { + const plugin3 = { + onCallIncoming(ctx: CallIncomingCtx) { throw new Error() } + } as RPCPlugin; + const plugin4 = { + async onCallIncoming(ctx: CallIncomingCtx) { throw new Error() } + } as RPCPlugin; + + test('should be resolved', async () => { + const plugin1 = { + onCallIncoming: jest.fn(), + } as RPCPlugin; + const plugin2 = { + async onCallIncoming(ctx: CallIncomingCtx) { } + } as RPCPlugin; + + const plugins = [plugin1, plugin2]; + const hookRunner = createHookRunner(plugins, 'onCallIncoming'); + await hookRunner({} as any); + expect(plugin1.onCallIncoming).toHaveBeenCalled() + }) + test('should be resolved2', async () => { + const plugins = [] as RPCPlugin[]; + const hookRunner = createHookRunner(plugins, 'onCallIncoming'); + await hookRunner({} as any); + expect.assertions(0); + }) + + test('should be rejected1', async () => { + const plugins = [plugin3]; + const hookRunner = createHookRunner(plugins, 'onCallIncoming'); + await expect(hookRunner({} as any)).rejects.toThrow(Error) + }) + + test('should be rejected2', async () => { + const plugins = [plugin4]; + const hookRunner = createHookRunner(plugins, 'onCallIncoming'); + await expect(hookRunner({} as any)).rejects.toThrow(Error) + }) +}) \ No newline at end of file diff --git a/src/core/RPCContext.ts b/src/core/RPCContext.ts new file mode 100644 index 0000000..89ede25 --- /dev/null +++ b/src/core/RPCContext.ts @@ -0,0 +1,27 @@ +import { isObject } from "@/utils/utils"; +import { RPCSession } from "./RPCSession"; + +export interface RPCContext { + session: RPCSession; +} + +export const RPCContextFlag = Symbol(); +export const RPCContextKey = '__rpc_context'; + +export class RPCContextConstractor { + public [RPCContextKey] = RPCContextFlag; + constructor(public session: RPCSession) { } +} + +export const getRPCContext = (self: unknown): RPCContext | null => { + if (!isObject(self)) { + return null; + } + + const rpcContext = self[RPCContextKey]; + if (!isObject(rpcContext) || rpcContext[RPCContextKey] !== RPCContextFlag) { + return null; + } + + return rpcContext as unknown as RPCContext; +} \ No newline at end of file diff --git a/src/core/RPCHandler.ts b/src/core/RPCHandler.ts index 1dba807..03959a0 100644 --- a/src/core/RPCHandler.ts +++ b/src/core/RPCHandler.ts @@ -3,6 +3,7 @@ import { RPCClient } from "./RPCClient"; import { RPCServer } from "./RPCServer"; import { RPCProvider } from "./RPCProvider"; import { RPCSession } from "./RPCSession"; +import { RPCPlugin } from "./RPCPlugin"; const DefaultListenOptions = { port: 5201, @@ -34,6 +35,7 @@ export class RPCHandler extends EventEmitter { private provider?: RPCProvider; private accessKey?: string; private config: RPCConfig; + private plugins: RPCPlugin[] = []; constructor( args?: { @@ -88,6 +90,29 @@ export class RPCHandler extends EventEmitter { } } + loadPlugin(plugin: RPCPlugin): boolean { + const plugins = this.plugins; + if (plugins.includes(plugin)) { + return false; + } + plugins.push(plugin); + return true; + } + + unloadPlugin(plugin: RPCPlugin): boolean { + const plugins = this.plugins; + const idx = plugins.indexOf(plugin); + if (idx === -1) { + return false; + } + plugins.splice(idx, 1); + return true; + } + + getPlugins(): RPCPlugin[] { + return [...this.plugins]; + } + async connect(options: { url?: string; accessKey?: string; diff --git a/src/core/RPCPlugin.ts b/src/core/RPCPlugin.ts new file mode 100644 index 0000000..96207d9 --- /dev/null +++ b/src/core/RPCPlugin.ts @@ -0,0 +1,75 @@ +import { RPCSession } from "./RPCSession"; + +// interface BaseHookRuntimeCtx { +// nexts: RPCPlugin[]; +// setNextPlugins: (plugins: RPCPlugin[]) => void +// } + +export interface BaseHookCtx { + +} + +export interface CallOutgoingBeforeCtx extends BaseHookCtx { + session: RPCSession; + options: unknown; + setOptions: (opt: unknown) => void; +} + +export interface CallOutgoingCtx extends CallOutgoingBeforeCtx { + result: unknown; + setResult: (res: unknown) => void; +} + +export interface CallIncomingBeforeCtx extends BaseHookCtx { + session: RPCSession; + request: unknown; + setRequest: (req: unknown) => void; +} + +export interface CallIncomingCtx extends CallIncomingBeforeCtx { + response: unknown; + setResponse: (res: unknown) => void; +} + +export type NormalMethodReturn = Promise | void; +export type HookFn = (ctx: Ctx) => NormalMethodReturn; + +export interface RPCPluginHooksCtx { + onCallOutgoingBefore: CallOutgoingBeforeCtx; + onCallOutgoing: CallOutgoingCtx; + onCallIncomingBefore: CallIncomingBeforeCtx; + onCallIncoming: CallIncomingCtx; +} + +export type RPCPluginHooks = { + [K in keyof RPCPluginHooksCtx]?: HookFn; +}; + + +export interface RPCPlugin extends RPCPluginHooks { + onInit?(): void; + onDestroy?(): void; +} + +export abstract class AbstractRPCPlugin implements RPCPlugin { + // abstract onInit?(): void; + // abstract onDestroy?(): void; + abstract onCallOutgoingBefore?(ctx: CallOutgoingBeforeCtx): NormalMethodReturn; + abstract onCallOutgoing?(ctx: CallOutgoingCtx): NormalMethodReturn; + abstract onCallIncomingBefore?(ctx: CallIncomingBeforeCtx): NormalMethodReturn; + abstract onCallIncoming?(ctx: CallIncomingCtx): NormalMethodReturn; +} + +type HookName = keyof RPCPluginHooksCtx; + +type HookRunner = (ctx: Ctx) => Promise; +export function createHookRunner( + plugins: RPCPlugin[], + hookName: K, +): HookRunner { + return async (ctx: RPCPluginHooksCtx[K]) => { + for (const plugin of plugins) { + await (plugin[hookName] as HookFn | undefined)?.(ctx); + } + }; +} \ No newline at end of file diff --git a/src/core/RPCSession.ts b/src/core/RPCSession.ts index d76e799..9285fb0 100644 --- a/src/core/RPCSession.ts +++ b/src/core/RPCSession.ts @@ -1,4 +1,4 @@ -import { isPublicMethod, ToDeepPromise } from "@/utils/utils"; +import { isArray, isObject, isPublicMethod, ToDeepPromise } from "@/utils/utils"; import { RPCConnection } from "./RPCConnection"; import { RPCHandler } from "./RPCHandler"; import { RPCProvider } from "./RPCProvider"; @@ -8,6 +8,8 @@ import { RPCError, RPCErrorCode } from "./RPCError"; import { makeCallPacket, makeCallResponsePacket, parseCallPacket, parseCallResponsePacket } from "./RPCCommon"; import { RPCPacket } from "./RPCPacket"; import { EventEmitter } from "@/utils/EventEmitter"; +import { createHookRunner } from "./RPCPlugin"; +import { RPCContextConstractor, RPCContextKey } from "./RPCContext"; function getProviderFunction(provider: RPCProvider, fnPath: string): [(...args: any[]) => Promise, object] | null { @@ -111,6 +113,45 @@ export class RPCSession { }); } + function setOptions(opt: unknown) { + if (!isObject(opt)) { + return; + } + + if ('fnPath' in opt && 'args' in opt && 'timeout' in opt) { + const { fnPath, args, timeout } = opt; + if (typeof fnPath !== 'string') { + return; + } + if (!isArray(args)) { + return; + } + if (typeof timeout !== 'number') { + return; + } + options = { + ...options, + fnPath, + args, + timeout, + } + } + } + + const hookRunner = createHookRunner(this.rpcHandler.getPlugins(), 'onCallOutgoingBefore'); + await hookRunner({ + session: this, + options: { ...options }, + setOptions, + }); + + /** due to `await hookRunner` */ + if (this.connection.closed) { + throw new RPCError({ + errorCode: RPCErrorCode.CONNECTION_DISCONNECTED, + }); + } + const { fnPath, args } = options; const packet = makeCallPacket({ fnPath, @@ -135,28 +176,76 @@ export class RPCSession { })(); const handleCallResponsePacket = (packet: RPCPacket) => { - const result = parseCallResponsePacket(packet); - if (result === null) { - return reject(new RPCError({ + let result = parseCallResponsePacket(packet); + function setResult(res: unknown) { + if (!isObject(res)) { + return; + } + + const { success, error } = res; + + if (typeof success === 'object' && typeof error === 'object') { + if (success && !error) { + if ('data' in success) { + result = { + success: { data: success.data }, + error: null, + } + } + } else if (!success && error) { + const { errorCode, reason } = error; + if (typeof errorCode === 'number' && typeof reason === 'string') { + result = { + success: null, + error: { + errorCode, + reason, + }, + } + } + } + } + } + + const hookRunner = createHookRunner(this.rpcHandler.getPlugins(), 'onCallOutgoing'); + hookRunner({ + session: this, + options: { ...options }, + setOptions, + result, + setResult, + }).then(() => { + if (result === null) { + return reject(new RPCError({ + errorCode: RPCErrorCode.UNKNOWN_ERROR, + }));; + } + + const { success, error } = result; + if (success) { + return resolve(success.data); + } + + if (error) { + return reject(new RPCError({ + errorCode: error.errorCode, + reason: error.reason + })); + } + + reject(new RPCError({ errorCode: RPCErrorCode.UNKNOWN_ERROR, }));; - } + }).catch((e) => { + if (e instanceof RPCError) { + return reject(e); + } - const { success, error } = result; - if (success) { - return resolve(success.data); - } - - if (error) { - return reject(new RPCError({ - errorCode: error.errorCode, - reason: error.reason - })); - } - - return reject(new RPCError({ - errorCode: RPCErrorCode.UNKNOWN_ERROR, - }));; + reject(new RPCError({ + errorCode: RPCErrorCode.UNKNOWN_ERROR, + reason: e instanceof Error ? e.message : `${e}` + })) + }) } this.callResponseEmitter.once(packet.id, handleCallResponsePacket); @@ -170,9 +259,81 @@ export class RPCSession { } private async onCallRequest(packet: RPCPacket): Promise { - const request = parseCallPacket(packet); - if (request === null) { + let request = parseCallPacket(packet); + const instance = this; + + function setRequest(req: unknown) { + if (!isObject(req)) { + return; + } + const { fnPath, args } = req; + if (typeof fnPath === 'string' && isArray(args)) { + request = { + ...request, + fnPath, + args, + }; + } + } + + const hookRunnerContext = { + session: this, + request, + setRequest, + } + + + + async function makeResponse(o: Parameters[0]) { + const hookRunner = createHookRunner(instance.rpcHandler.getPlugins(), 'onCallIncoming'); + try { + await hookRunner({ + ...hookRunnerContext, + response: o, + setResponse: (res: unknown) => { + /** TODO: Implement stricter validation of the response object in the future */ + o = res as Parameters[0]; + }, + }) + } catch (error) { + return makeCallResponsePacket({ + ...o, + ...(error instanceof RPCError ? { + errorCode: error.errorCode, + reason: error.reason, + } : { + errorCode: RPCErrorCode.SERVER_ERROR, + reason: `${error}` + }) + }) + } + return makeCallResponsePacket({ + ...o, + }) + } + + + + const hookRunner = createHookRunner(this.rpcHandler.getPlugins(), 'onCallIncomingBefore'); + try { + await hookRunner(hookRunnerContext); + } catch (error) { + return makeResponse({ + status: 'error', + requestPacket: packet, + ...(error instanceof RPCError ? { + errorCode: error.errorCode, + reason: error.reason + } : { + errorCode: RPCErrorCode.SERVER_ERROR, + reason: `${error}`, + }) + }) + } + + if (request === null) { + return makeResponse({ status: 'error', requestPacket: packet, errorCode: RPCErrorCode.CALL_PROTOCOL_ERROR, @@ -182,7 +343,7 @@ export class RPCSession { // call the function const provider = this.rpcHandler.getProvider(); if (!provider) { - return makeCallResponsePacket({ + return makeResponse({ status: 'error', requestPacket: packet, errorCode: RPCErrorCode.PROVIDER_NOT_AVAILABLE, @@ -192,7 +353,7 @@ export class RPCSession { const { fnPath, args } = request; const fnRes = getProviderFunction(provider, fnPath); if (!fnRes) { - return makeCallResponsePacket({ + return makeResponse({ status: 'error', requestPacket: packet, errorCode: RPCErrorCode.METHOD_NOT_FOUND, @@ -203,7 +364,7 @@ export class RPCSession { const { enableMethodProtection } = this.rpcHandler.getConfig(); if (enableMethodProtection) { if (!isPublicMethod(fn)) { - return makeCallResponsePacket({ + return makeResponse({ status: 'error', requestPacket: packet, errorCode: RPCErrorCode.METHOD_PROTECTED, @@ -212,14 +373,21 @@ export class RPCSession { } try { - const result = await fn.bind(fnThis)(...args); - return makeCallResponsePacket({ + const result = await fn.bind(new Proxy(fnThis, { + get(target, p, receiver) { + if (p === RPCContextKey) { + return new RPCContextConstractor(instance); + } + return Reflect.get(target, p, receiver); + }, + }))(...args); + return makeResponse({ status: 'success', requestPacket: packet, data: result, }) } catch (error) { - return makeCallResponsePacket({ + return makeResponse({ status: 'error', requestPacket: packet, errorCode: RPCErrorCode.SERVER_ERROR, diff --git a/src/index.ts b/src/index.ts index 5781f33..e461501 100644 --- a/src/index.ts +++ b/src/index.ts @@ -12,6 +12,10 @@ export { SocketConnection } from "./core/SocketConnection"; export { SocketServer } from "./core/SocketServer"; export { EventEmitter } from "./utils/EventEmitter"; +export { AbstractRPCPlugin, createHookRunner } from './core/RPCPlugin'; +export type { RPCPlugin, RPCPluginHooks, RPCPluginHooksCtx } from './core/RPCPlugin'; +export { getRPCContext } from './core/RPCContext'; + export { injectSocketClient } from "./core/SocketClient"; export { injectSocketServer } from "./core/SocketServer"; import { injectSocketIOImplements } from "./implements/socket.io"; diff --git a/src/utils/utils.ts b/src/utils/utils.ts index b164b99..2fbda51 100644 --- a/src/utils/utils.ts +++ b/src/utils/utils.ts @@ -81,4 +81,20 @@ export function markAsPublicMethod | u markAs(obj); return obj; +} + +export function createDeferrablePromise() { + let resolve!: (value: T | PromiseLike) => void; + let reject!: (reason?: unknown) => void; + + const promise = new Promise((res, rej) => { + resolve = res; + reject = rej; + }); + + return { + promise, + resolve, + reject + }; } \ No newline at end of file