diff --git a/src/core/RPCSession.ts b/src/core/RPCSession.ts index 5dced55..a099c80 100644 --- a/src/core/RPCSession.ts +++ b/src/core/RPCSession.ts @@ -257,9 +257,81 @@ export class RPCSession { } private async onCallRequest(packet: RPCPacket): Promise { - const request = parseCallPacket(packet); - if (request === null) { + let request = parseCallPacket(packet); + const instance = this; + + function setRequest(req: unknown) { + if (!isObject(req)) { + return; + } + const { fnPath, args } = req; + if (typeof fnPath === 'string' && isArray(args)) { + request = { + ...request, + fnPath, + args, + }; + } + } + + const hookRunnerContext = { + session: this, + request, + setRequest, + } + + + + async function makeResponse(o: Parameters[0]) { + const hookRunner = createHookRunner(instance.rpcHandler.getPlugins(), 'onCallIncoming'); + try { + await hookRunner({ + ...hookRunnerContext, + response: o, + setResponse: (res: unknown) => { + /** TODO: Implement stricter validation of the response object in the future */ + o = res as Parameters[0]; + }, + }) + } catch (error) { + return makeCallResponsePacket({ + ...o, + ...(error instanceof RPCError ? { + errorCode: error.errorCode, + reason: error.reason, + } : { + errorCode: RPCErrorCode.SERVER_ERROR, + reason: `${error}` + }) + }) + } + return makeCallResponsePacket({ + ...o, + }) + } + + + + const hookRunner = createHookRunner(this.rpcHandler.getPlugins(), 'onCallIncomingBefore'); + try { + await hookRunner(hookRunnerContext); + } catch (error) { + return makeResponse({ + status: 'error', + requestPacket: packet, + ...(error instanceof RPCError ? { + errorCode: error.errorCode, + reason: error.reason + } : { + errorCode: RPCErrorCode.SERVER_ERROR, + reason: `${error}`, + }) + }) + } + + if (request === null) { + return makeResponse({ status: 'error', requestPacket: packet, errorCode: RPCErrorCode.CALL_PROTOCOL_ERROR, @@ -269,7 +341,7 @@ export class RPCSession { // call the function const provider = this.rpcHandler.getProvider(); if (!provider) { - return makeCallResponsePacket({ + return makeResponse({ status: 'error', requestPacket: packet, errorCode: RPCErrorCode.PROVIDER_NOT_AVAILABLE, @@ -279,7 +351,7 @@ export class RPCSession { const { fnPath, args } = request; const fnRes = getProviderFunction(provider, fnPath); if (!fnRes) { - return makeCallResponsePacket({ + return makeResponse({ status: 'error', requestPacket: packet, errorCode: RPCErrorCode.METHOD_NOT_FOUND, @@ -290,7 +362,7 @@ export class RPCSession { const { enableMethodProtection } = this.rpcHandler.getConfig(); if (enableMethodProtection) { if (!isPublicMethod(fn)) { - return makeCallResponsePacket({ + return makeResponse({ status: 'error', requestPacket: packet, errorCode: RPCErrorCode.METHOD_PROTECTED, @@ -300,13 +372,13 @@ export class RPCSession { try { const result = await fn.bind(fnThis)(...args); - return makeCallResponsePacket({ + return makeResponse({ status: 'success', requestPacket: packet, data: result, }) } catch (error) { - return makeCallResponsePacket({ + return makeResponse({ status: 'error', requestPacket: packet, errorCode: RPCErrorCode.SERVER_ERROR,