diff --git a/__tests__/e2e/rpc-protected-method.test.ts b/__tests__/e2e/rpc-protected-method.test.ts new file mode 100644 index 0000000..a0315ee --- /dev/null +++ b/__tests__/e2e/rpc-protected-method.test.ts @@ -0,0 +1,63 @@ +import { isPublicMethod, markAsPublicMethod, publicMethod, RPCErrorCode, RPCHandler } from "@/index" + +describe('Rpc protected method test', () => { + test('disabled protection', async () => { + class Methods { + @publicMethod + allow() { + return 0; + } + disabled() { } + } + + const classMethods = new Methods(); + const provider = { + classMethods, + normal: markAsPublicMethod(() => 0), + normal2: markAsPublicMethod(function () { return 0 }), + normalProtected: function () { }, + shallowObj: markAsPublicMethod({ + fn1: () => 0, + l1: { + fn1: () => 0, + } + }), + deepObj: markAsPublicMethod({ + fn1: () => 0, + l1: { + fn1: () => 0, + } + }, { deep: true }), + } + + const server = new RPCHandler({ + enableMethodProtection: true, + }); + server.setProvider(provider); + await server.listen({ + port: 5210 + }); + + + const client = new RPCHandler(); + const session = await client.connect({ + url: 'http://localhost:5210' + }); + const api = session.getAPI(); + await expect(api.classMethods.allow()).resolves.toBe(0); + await expect(api.classMethods.disabled()).rejects + .toHaveProperty('errorCode', RPCErrorCode.METHOD_PROTECTED); + await expect(api.normal()).resolves.toBe(0); + await expect(api.normal2()).resolves.toBe(0); + await expect(api.normalProtected()).rejects + .toHaveProperty('errorCode', RPCErrorCode.METHOD_PROTECTED); + + await expect(api.shallowObj.fn1()).resolves.toBe(0); + await expect(api.shallowObj.l1.fn1()).rejects + .toHaveProperty('errorCode', RPCErrorCode.METHOD_PROTECTED); + + await expect(api.deepObj.fn1()).resolves.toBe(0); + await expect(api.deepObj.l1.fn1()).resolves.toBe(0); + }) + +}) diff --git a/src/core/RPCSession.ts b/src/core/RPCSession.ts index c231a93..d76e799 100644 --- a/src/core/RPCSession.ts +++ b/src/core/RPCSession.ts @@ -1,4 +1,4 @@ -import { ToDeepPromise } from "@/utils/utils"; +import { isPublicMethod, ToDeepPromise } from "@/utils/utils"; import { RPCConnection } from "./RPCConnection"; import { RPCHandler } from "./RPCHandler"; import { RPCProvider } from "./RPCProvider"; @@ -9,7 +9,8 @@ import { makeCallPacket, makeCallResponsePacket, parseCallPacket, parseCallRespo import { RPCPacket } from "./RPCPacket"; import { EventEmitter } from "@/utils/EventEmitter"; -function getProviderFunction(provider: RPCProvider, fnPath: string) { +function getProviderFunction(provider: RPCProvider, fnPath: string): + [(...args: any[]) => Promise, object] | null { const paths = fnPath.split(':'); let fnThis: any = provider; let fn: any = provider; @@ -22,7 +23,7 @@ function getProviderFunction(provider: RPCProvider, fnPath: string) { } } if (typeof fn === 'function') { - return fn.bind(fnThis); + return [fn, fnThis]; } throw new Error(); @@ -64,7 +65,13 @@ export class RPCSession { this.callResponseEmitter.removeAllListeners(); }); - this.connection.on('call', this.onCallRequest.bind(this)); + this.connection.on('call', (packet) => { + this.onCallRequest(packet).then(res => { + this.connection.send(res); + }).catch((e) => { + console.warn(`${e}`); + }) + }); } getAPI(): ToDeepPromise { @@ -162,45 +169,57 @@ export class RPCSession { }); } - private async onCallRequest(packet: RPCPacket) { + private async onCallRequest(packet: RPCPacket): Promise { const request = parseCallPacket(packet); if (request === null) { - return this.connection.send(makeCallResponsePacket({ + return makeCallResponsePacket({ status: 'error', requestPacket: packet, errorCode: RPCErrorCode.CALL_PROTOCOL_ERROR, - })).catch(() => { }) + }); } // call the function const provider = this.rpcHandler.getProvider(); if (!provider) { - return this.connection.send(makeCallResponsePacket({ + return makeCallResponsePacket({ status: 'error', requestPacket: packet, errorCode: RPCErrorCode.PROVIDER_NOT_AVAILABLE, - })) + }); } const { fnPath, args } = request; - const fn = getProviderFunction(provider, fnPath); - if (!fn) { - return this.connection.send(makeCallResponsePacket({ + const fnRes = getProviderFunction(provider, fnPath); + if (!fnRes) { + return makeCallResponsePacket({ status: 'error', requestPacket: packet, errorCode: RPCErrorCode.METHOD_NOT_FOUND, - })) + }) + } + const [fn, fnThis] = fnRes; + + const { enableMethodProtection } = this.rpcHandler.getConfig(); + if (enableMethodProtection) { + if (!isPublicMethod(fn)) { + return makeCallResponsePacket({ + status: 'error', + requestPacket: packet, + errorCode: RPCErrorCode.METHOD_PROTECTED, + }) + } } try { - const result = await fn(...args); - this.connection.send(makeCallResponsePacket({ + const result = await fn.bind(fnThis)(...args); + return makeCallResponsePacket({ status: 'success', requestPacket: packet, data: result, - })) + }) } catch (error) { - this.connection.send(makeCallResponsePacket({ + return makeCallResponsePacket({ status: 'error', requestPacket: packet, errorCode: RPCErrorCode.SERVER_ERROR, @@ -208,7 +227,7 @@ export class RPCSession { errorCode: error.errorCode, reason: error.reason, } : {}) - })) + }) } } } \ No newline at end of file