feat: add tests for RPC protected methods and enhance RPCSession method protection

This commit is contained in:
tone
2025-11-15 18:08:27 +08:00
parent eb0574221e
commit e17a9591e1
2 changed files with 100 additions and 18 deletions

View File

@@ -1,4 +1,4 @@
import { ToDeepPromise } from "@/utils/utils";
import { isPublicMethod, ToDeepPromise } from "@/utils/utils";
import { RPCConnection } from "./RPCConnection";
import { RPCHandler } from "./RPCHandler";
import { RPCProvider } from "./RPCProvider";
@@ -9,7 +9,8 @@ import { makeCallPacket, makeCallResponsePacket, parseCallPacket, parseCallRespo
import { RPCPacket } from "./RPCPacket";
import { EventEmitter } from "@/utils/EventEmitter";
function getProviderFunction(provider: RPCProvider, fnPath: string) {
function getProviderFunction(provider: RPCProvider, fnPath: string):
[(...args: any[]) => Promise<any>, object] | null {
const paths = fnPath.split(':');
let fnThis: any = provider;
let fn: any = provider;
@@ -22,7 +23,7 @@ function getProviderFunction(provider: RPCProvider, fnPath: string) {
}
}
if (typeof fn === 'function') {
return fn.bind(fnThis);
return [fn, fnThis];
}
throw new Error();
@@ -64,7 +65,13 @@ export class RPCSession {
this.callResponseEmitter.removeAllListeners();
});
this.connection.on('call', this.onCallRequest.bind(this));
this.connection.on('call', (packet) => {
this.onCallRequest(packet).then(res => {
this.connection.send(res);
}).catch((e) => {
console.warn(`${e}`);
})
});
}
getAPI<T extends RPCProvider>(): ToDeepPromise<T> {
@@ -162,45 +169,57 @@ export class RPCSession {
});
}
private async onCallRequest(packet: RPCPacket) {
private async onCallRequest(packet: RPCPacket): Promise<RPCPacket> {
const request = parseCallPacket(packet);
if (request === null) {
return this.connection.send(makeCallResponsePacket({
return makeCallResponsePacket({
status: 'error',
requestPacket: packet,
errorCode: RPCErrorCode.CALL_PROTOCOL_ERROR,
})).catch(() => { })
});
}
// call the function
const provider = this.rpcHandler.getProvider();
if (!provider) {
return this.connection.send(makeCallResponsePacket({
return makeCallResponsePacket({
status: 'error',
requestPacket: packet,
errorCode: RPCErrorCode.PROVIDER_NOT_AVAILABLE,
}))
});
}
const { fnPath, args } = request;
const fn = getProviderFunction(provider, fnPath);
if (!fn) {
return this.connection.send(makeCallResponsePacket({
const fnRes = getProviderFunction(provider, fnPath);
if (!fnRes) {
return makeCallResponsePacket({
status: 'error',
requestPacket: packet,
errorCode: RPCErrorCode.METHOD_NOT_FOUND,
}))
})
}
const [fn, fnThis] = fnRes;
const { enableMethodProtection } = this.rpcHandler.getConfig();
if (enableMethodProtection) {
if (!isPublicMethod(fn)) {
return makeCallResponsePacket({
status: 'error',
requestPacket: packet,
errorCode: RPCErrorCode.METHOD_PROTECTED,
})
}
}
try {
const result = await fn(...args);
this.connection.send(makeCallResponsePacket({
const result = await fn.bind(fnThis)(...args);
return makeCallResponsePacket({
status: 'success',
requestPacket: packet,
data: result,
}))
})
} catch (error) {
this.connection.send(makeCallResponsePacket({
return makeCallResponsePacket({
status: 'error',
requestPacket: packet,
errorCode: RPCErrorCode.SERVER_ERROR,
@@ -208,7 +227,7 @@ export class RPCSession {
errorCode: error.errorCode,
reason: error.reason,
} : {})
}))
})
}
}
}