feat: add tests for RPC protected methods and enhance RPCSession method protection

This commit is contained in:
tone
2025-11-15 18:08:27 +08:00
parent eb0574221e
commit e17a9591e1
2 changed files with 100 additions and 18 deletions

View File

@@ -0,0 +1,63 @@
import { isPublicMethod, markAsPublicMethod, publicMethod, RPCErrorCode, RPCHandler } from "@/index"
describe('Rpc protected method test', () => {
test('disabled protection', async () => {
class Methods {
@publicMethod
allow() {
return 0;
}
disabled() { }
}
const classMethods = new Methods();
const provider = {
classMethods,
normal: markAsPublicMethod(() => 0),
normal2: markAsPublicMethod(function () { return 0 }),
normalProtected: function () { },
shallowObj: markAsPublicMethod({
fn1: () => 0,
l1: {
fn1: () => 0,
}
}),
deepObj: markAsPublicMethod({
fn1: () => 0,
l1: {
fn1: () => 0,
}
}, { deep: true }),
}
const server = new RPCHandler({
enableMethodProtection: true,
});
server.setProvider(provider);
await server.listen({
port: 5210
});
const client = new RPCHandler();
const session = await client.connect({
url: 'http://localhost:5210'
});
const api = session.getAPI<typeof provider>();
await expect(api.classMethods.allow()).resolves.toBe(0);
await expect(api.classMethods.disabled()).rejects
.toHaveProperty('errorCode', RPCErrorCode.METHOD_PROTECTED);
await expect(api.normal()).resolves.toBe(0);
await expect(api.normal2()).resolves.toBe(0);
await expect(api.normalProtected()).rejects
.toHaveProperty('errorCode', RPCErrorCode.METHOD_PROTECTED);
await expect(api.shallowObj.fn1()).resolves.toBe(0);
await expect(api.shallowObj.l1.fn1()).rejects
.toHaveProperty('errorCode', RPCErrorCode.METHOD_PROTECTED);
await expect(api.deepObj.fn1()).resolves.toBe(0);
await expect(api.deepObj.l1.fn1()).resolves.toBe(0);
})
})

View File

@@ -1,4 +1,4 @@
import { ToDeepPromise } from "@/utils/utils";
import { isPublicMethod, ToDeepPromise } from "@/utils/utils";
import { RPCConnection } from "./RPCConnection";
import { RPCHandler } from "./RPCHandler";
import { RPCProvider } from "./RPCProvider";
@@ -9,7 +9,8 @@ import { makeCallPacket, makeCallResponsePacket, parseCallPacket, parseCallRespo
import { RPCPacket } from "./RPCPacket";
import { EventEmitter } from "@/utils/EventEmitter";
function getProviderFunction(provider: RPCProvider, fnPath: string) {
function getProviderFunction(provider: RPCProvider, fnPath: string):
[(...args: any[]) => Promise<any>, object] | null {
const paths = fnPath.split(':');
let fnThis: any = provider;
let fn: any = provider;
@@ -22,7 +23,7 @@ function getProviderFunction(provider: RPCProvider, fnPath: string) {
}
}
if (typeof fn === 'function') {
return fn.bind(fnThis);
return [fn, fnThis];
}
throw new Error();
@@ -64,7 +65,13 @@ export class RPCSession {
this.callResponseEmitter.removeAllListeners();
});
this.connection.on('call', this.onCallRequest.bind(this));
this.connection.on('call', (packet) => {
this.onCallRequest(packet).then(res => {
this.connection.send(res);
}).catch((e) => {
console.warn(`${e}`);
})
});
}
getAPI<T extends RPCProvider>(): ToDeepPromise<T> {
@@ -162,45 +169,57 @@ export class RPCSession {
});
}
private async onCallRequest(packet: RPCPacket) {
private async onCallRequest(packet: RPCPacket): Promise<RPCPacket> {
const request = parseCallPacket(packet);
if (request === null) {
return this.connection.send(makeCallResponsePacket({
return makeCallResponsePacket({
status: 'error',
requestPacket: packet,
errorCode: RPCErrorCode.CALL_PROTOCOL_ERROR,
})).catch(() => { })
});
}
// call the function
const provider = this.rpcHandler.getProvider();
if (!provider) {
return this.connection.send(makeCallResponsePacket({
return makeCallResponsePacket({
status: 'error',
requestPacket: packet,
errorCode: RPCErrorCode.PROVIDER_NOT_AVAILABLE,
}))
});
}
const { fnPath, args } = request;
const fn = getProviderFunction(provider, fnPath);
if (!fn) {
return this.connection.send(makeCallResponsePacket({
const fnRes = getProviderFunction(provider, fnPath);
if (!fnRes) {
return makeCallResponsePacket({
status: 'error',
requestPacket: packet,
errorCode: RPCErrorCode.METHOD_NOT_FOUND,
}))
})
}
const [fn, fnThis] = fnRes;
const { enableMethodProtection } = this.rpcHandler.getConfig();
if (enableMethodProtection) {
if (!isPublicMethod(fn)) {
return makeCallResponsePacket({
status: 'error',
requestPacket: packet,
errorCode: RPCErrorCode.METHOD_PROTECTED,
})
}
}
try {
const result = await fn(...args);
this.connection.send(makeCallResponsePacket({
const result = await fn.bind(fnThis)(...args);
return makeCallResponsePacket({
status: 'success',
requestPacket: packet,
data: result,
}))
})
} catch (error) {
this.connection.send(makeCallResponsePacket({
return makeCallResponsePacket({
status: 'error',
requestPacket: packet,
errorCode: RPCErrorCode.SERVER_ERROR,
@@ -208,7 +227,7 @@ export class RPCSession {
errorCode: error.errorCode,
reason: error.reason,
} : {})
}))
})
}
}
}