Compare commits

..

12 Commits

Author SHA1 Message Date
b440c38e73 feat: export getRPCContext from RPCContext 2025-12-02 15:46:07 +08:00
d936af9793 feat: add RPCContext 2025-12-02 15:03:33 +08:00
2fa7b851ff feat: export AbstractRPCPlugin and related types from RPCPlugin 2025-11-28 14:13:02 +08:00
4eac61144f Merge branch 'feature/rpc-plugin' into dev
* feature/rpc-plugin:
  feat: add session management and authentication tests for RPC plugin
  feat: add error reason to RPCError in callRequest method
  feat: enhance onCallRequest method with request handling and hook execution
  feat: add session and request handling to CallIncomingBeforeCtx and CallIncomingCtx interfaces
  feat: enhance callRequest method with options validation and hook execution
  feat: comment out abstract onInit and onDestroy methods in AbstractRPCPlugin class
  feat: update CallOutgoingBeforeCtx and CallOutgoingCtx interfaces to use unknown type for options and result
  feat: add unit tests for RPCPlugin hook execution and error handling
  feat: remove unused HookChainInterruptedError class from RPCPlugin
  feat: add RPCPlugin interface and abstract class with hook context definitions
  feat: add plugin management methods to RPCHandler
  feat: add createDeferrablePromise utility function
2025-11-28 14:12:53 +08:00
12afcb7a82 feat: add session management and authentication tests for RPC plugin 2025-11-28 14:12:37 +08:00
40d0e79358 feat: add error reason to RPCError in callRequest method 2025-11-28 13:35:03 +08:00
9a8733a77d feat: enhance onCallRequest method with request handling and hook execution 2025-11-28 11:08:19 +08:00
1b5281e0c1 feat: add session and request handling to CallIncomingBeforeCtx and CallIncomingCtx interfaces 2025-11-28 11:08:03 +08:00
3a4a54d37c test: enhance error handling in RPC disconnected test 2025-11-27 22:45:48 +08:00
ebb9ed21ad feat: enhance callRequest method with options validation and hook execution 2025-11-27 22:31:06 +08:00
720a79ca7a feat: comment out abstract onInit and onDestroy methods in AbstractRPCPlugin class 2025-11-27 22:30:29 +08:00
500e7c8fa6 feat: update CallOutgoingBeforeCtx and CallOutgoingCtx interfaces to use unknown type for options and result 2025-11-27 22:30:21 +08:00
8 changed files with 376 additions and 41 deletions

View File

@@ -0,0 +1,41 @@
import { getRPCContext, RPCContextKey } from "@/core/RPCContext";
import { RPCHandler, RPCSession } from "@/index"
import { getRandomAvailablePort } from "@/utils/utils";
describe('rpc-context.test', () => {
test('context.session', async () => {
const accessed = new WeakMap<RPCSession, string>();
const provider = {
login(name: string) {
const context = getRPCContext(this);
if (!context) {
throw new Error('context is null');
}
accessed.set(context.session, name);
return name;
},
me() {
const context = getRPCContext(this);
if (!context) {
throw new Error('context is null');
}
return accessed.get(context.session) || 'unknown user';
},
}
const server = new RPCHandler();
const client = new RPCHandler();
server.setProvider(provider);
const port = await getRandomAvailablePort();
await server.listen({ port });
const session = await client.connect({ url: `http://localhost:${port}` });
const api = session.getAPI<typeof provider>();
const clientname = 'clientname';
await expect(api.login(clientname)).resolves.toBe(clientname);
await expect(api.me()).resolves.toBe(clientname);
})
})

View File

@@ -33,8 +33,11 @@ describe('Rpc disconnected test', () => {
errorCode: RPCErrorCode.CONNECTION_DISCONNECTED
})
);
await expect(callPromise).rejects.toBeInstanceOf(RPCError);
await expect(callPromise).rejects
.toHaveProperty('errorCode', RPCErrorCode.CONNECTION_DISCONNECTED);
await expect(callPromise).rejects.toMatchObject(
expect.objectContaining({
constructor: RPCError,
errorCode: RPCErrorCode.CONNECTION_DISCONNECTED
})
);
})
})

View File

@@ -0,0 +1,49 @@
import { RPCError } from "@/core/RPCError";
import { CallIncomingBeforeCtx, NormalMethodReturn } from "@/core/RPCPlugin";
import { AbstractRPCPlugin, RPCHandler } from "@/index";
import { getRandomAvailablePort, isObject } from "@/utils/utils";
describe('rpc-plugin.ctx.session.test', () => {
const userInfo = 'userinfo';
const users = new WeakMap();
class RPCTestPlugin implements AbstractRPCPlugin {
onCallIncomingBefore(ctx: CallIncomingBeforeCtx): NormalMethodReturn {
if (users.has(ctx.session)) {
return;
}
if (isObject(ctx.request)) {
if (ctx.request.fnPath === 'login') {
users.set(ctx.session, userInfo);
return;
}
}
throw new RPCError({
reason: 'not login'
});
}
}
test('session', async () => {
const server = new RPCHandler();
const loginRes = 'login';
const authRes = 'auth';
const provider = {
login() { return loginRes },
auth() { return authRes }
}
server.setProvider(provider);
server.loadPlugin(new RPCTestPlugin());
const client = new RPCHandler();
const port = await getRandomAvailablePort();
await server.listen({ port });
const session = await client.connect({ url: `http://localhost:${port}` });
const api = session.getAPI<typeof provider>();
await expect(api.auth()).rejects.toMatchObject({ reason: 'not login' });
await expect(api.login()).resolves.toBe(loginRes);
await expect(api.auth()).resolves.toBe(authRes);
});
})

View File

@@ -0,0 +1,42 @@
import { CallIncomingBeforeCtx, CallIncomingCtx, CallOutgoingBeforeCtx, CallOutgoingCtx, NormalMethodReturn } from "@/core/RPCPlugin";
import { AbstractRPCPlugin, RPCHandler } from "@/index";
import { getRandomAvailablePort } from "@/utils/utils";
describe('rpc-plugin.hook.test', () => {
let count = 0;
class RPCTestPlugin implements AbstractRPCPlugin {
onCallIncomingBefore(ctx: CallIncomingBeforeCtx): NormalMethodReturn {
count += 1;
}
onCallIncoming(ctx: CallIncomingCtx): NormalMethodReturn {
count += 2;
}
onCallOutgoingBefore(ctx: CallOutgoingBeforeCtx): NormalMethodReturn {
count += 4;
}
onCallOutgoing(ctx: CallOutgoingCtx): NormalMethodReturn {
count += 8;
}
}
const CountShouldBe = 15;
test('count', async () => {
const server = new RPCHandler();
const provider = {
add(a: number, b: number) { return a + b }
}
server.setProvider(provider);
server.loadPlugin(new RPCTestPlugin());
const client = new RPCHandler();
client.loadPlugin(new RPCTestPlugin());
const port = await getRandomAvailablePort();
await server.listen({ port });
const session = await client.connect({ url: `http://localhost:${port}` });
const api = session.getAPI<typeof provider>();
await expect(api.add(1, 2)).resolves.toBe(3);
expect(count).toBe(CountShouldBe);
});
})

27
src/core/RPCContext.ts Normal file
View File

@@ -0,0 +1,27 @@
import { isObject } from "@/utils/utils";
import { RPCSession } from "./RPCSession";
export interface RPCContext {
session: RPCSession;
}
export const RPCContextFlag = Symbol();
export const RPCContextKey = '__rpc_context';
export class RPCContextConstractor {
public [RPCContextKey] = RPCContextFlag;
constructor(public session: RPCSession) { }
}
export const getRPCContext = (self: unknown): RPCContext | null => {
if (!isObject(self)) {
return null;
}
const rpcContext = self[RPCContextKey];
if (!isObject(rpcContext) || rpcContext[RPCContextKey] !== RPCContextFlag) {
return null;
}
return rpcContext as unknown as RPCContext;
}

View File

@@ -11,23 +11,24 @@ export interface BaseHookCtx {
export interface CallOutgoingBeforeCtx extends BaseHookCtx {
session: RPCSession;
options: {
fnPath: string;
args: any[];
};
options: unknown;
setOptions: (opt: unknown) => void;
}
export interface CallOutgoingCtx extends CallOutgoingBeforeCtx {
result: any;
setResult: (data: any) => void;
result: unknown;
setResult: (res: unknown) => void;
}
export interface CallIncomingBeforeCtx extends BaseHookCtx {
session: RPCSession;
request: unknown;
setRequest: (req: unknown) => void;
}
export interface CallIncomingCtx extends CallIncomingBeforeCtx {
response: unknown;
setResponse: (res: unknown) => void;
}
export type NormalMethodReturn = Promise<void> | void;
@@ -51,8 +52,8 @@ export interface RPCPlugin extends RPCPluginHooks {
}
export abstract class AbstractRPCPlugin implements RPCPlugin {
abstract onInit?(): void;
abstract onDestroy?(): void;
// abstract onInit?(): void;
// abstract onDestroy?(): void;
abstract onCallOutgoingBefore?(ctx: CallOutgoingBeforeCtx): NormalMethodReturn;
abstract onCallOutgoing?(ctx: CallOutgoingCtx): NormalMethodReturn;
abstract onCallIncomingBefore?(ctx: CallIncomingBeforeCtx): NormalMethodReturn;

View File

@@ -1,4 +1,4 @@
import { isPublicMethod, ToDeepPromise } from "@/utils/utils";
import { isArray, isObject, isPublicMethod, ToDeepPromise } from "@/utils/utils";
import { RPCConnection } from "./RPCConnection";
import { RPCHandler } from "./RPCHandler";
import { RPCProvider } from "./RPCProvider";
@@ -8,6 +8,8 @@ import { RPCError, RPCErrorCode } from "./RPCError";
import { makeCallPacket, makeCallResponsePacket, parseCallPacket, parseCallResponsePacket } from "./RPCCommon";
import { RPCPacket } from "./RPCPacket";
import { EventEmitter } from "@/utils/EventEmitter";
import { createHookRunner } from "./RPCPlugin";
import { RPCContextConstractor, RPCContextKey } from "./RPCContext";
function getProviderFunction(provider: RPCProvider, fnPath: string):
[(...args: any[]) => Promise<any>, object] | null {
@@ -111,6 +113,45 @@ export class RPCSession {
});
}
function setOptions(opt: unknown) {
if (!isObject(opt)) {
return;
}
if ('fnPath' in opt && 'args' in opt && 'timeout' in opt) {
const { fnPath, args, timeout } = opt;
if (typeof fnPath !== 'string') {
return;
}
if (!isArray(args)) {
return;
}
if (typeof timeout !== 'number') {
return;
}
options = {
...options,
fnPath,
args,
timeout,
}
}
}
const hookRunner = createHookRunner(this.rpcHandler.getPlugins(), 'onCallOutgoingBefore');
await hookRunner({
session: this,
options: { ...options },
setOptions,
});
/** due to `await hookRunner` */
if (this.connection.closed) {
throw new RPCError({
errorCode: RPCErrorCode.CONNECTION_DISCONNECTED,
});
}
const { fnPath, args } = options;
const packet = makeCallPacket({
fnPath,
@@ -135,28 +176,76 @@ export class RPCSession {
})();
const handleCallResponsePacket = (packet: RPCPacket) => {
const result = parseCallResponsePacket(packet);
if (result === null) {
return reject(new RPCError({
let result = parseCallResponsePacket(packet);
function setResult(res: unknown) {
if (!isObject(res)) {
return;
}
const { success, error } = res;
if (typeof success === 'object' && typeof error === 'object') {
if (success && !error) {
if ('data' in success) {
result = {
success: { data: success.data },
error: null,
}
}
} else if (!success && error) {
const { errorCode, reason } = error;
if (typeof errorCode === 'number' && typeof reason === 'string') {
result = {
success: null,
error: {
errorCode,
reason,
},
}
}
}
}
}
const hookRunner = createHookRunner(this.rpcHandler.getPlugins(), 'onCallOutgoing');
hookRunner({
session: this,
options: { ...options },
setOptions,
result,
setResult,
}).then(() => {
if (result === null) {
return reject(new RPCError({
errorCode: RPCErrorCode.UNKNOWN_ERROR,
}));;
}
const { success, error } = result;
if (success) {
return resolve(success.data);
}
if (error) {
return reject(new RPCError({
errorCode: error.errorCode,
reason: error.reason
}));
}
reject(new RPCError({
errorCode: RPCErrorCode.UNKNOWN_ERROR,
}));;
}
}).catch((e) => {
if (e instanceof RPCError) {
return reject(e);
}
const { success, error } = result;
if (success) {
return resolve(success.data);
}
if (error) {
return reject(new RPCError({
errorCode: error.errorCode,
reason: error.reason
}));
}
return reject(new RPCError({
errorCode: RPCErrorCode.UNKNOWN_ERROR,
}));;
reject(new RPCError({
errorCode: RPCErrorCode.UNKNOWN_ERROR,
reason: e instanceof Error ? e.message : `${e}`
}))
})
}
this.callResponseEmitter.once(packet.id, handleCallResponsePacket);
@@ -170,9 +259,81 @@ export class RPCSession {
}
private async onCallRequest(packet: RPCPacket): Promise<RPCPacket> {
const request = parseCallPacket(packet);
if (request === null) {
let request = parseCallPacket(packet);
const instance = this;
function setRequest(req: unknown) {
if (!isObject(req)) {
return;
}
const { fnPath, args } = req;
if (typeof fnPath === 'string' && isArray(args)) {
request = {
...request,
fnPath,
args,
};
}
}
const hookRunnerContext = {
session: this,
request,
setRequest,
}
async function makeResponse(o: Parameters<typeof makeCallResponsePacket>[0]) {
const hookRunner = createHookRunner(instance.rpcHandler.getPlugins(), 'onCallIncoming');
try {
await hookRunner({
...hookRunnerContext,
response: o,
setResponse: (res: unknown) => {
/** TODO: Implement stricter validation of the response object in the future */
o = res as Parameters<typeof makeCallResponsePacket>[0];
},
})
} catch (error) {
return makeCallResponsePacket({
...o,
...(error instanceof RPCError ? {
errorCode: error.errorCode,
reason: error.reason,
} : {
errorCode: RPCErrorCode.SERVER_ERROR,
reason: `${error}`
})
})
}
return makeCallResponsePacket({
...o,
})
}
const hookRunner = createHookRunner(this.rpcHandler.getPlugins(), 'onCallIncomingBefore');
try {
await hookRunner(hookRunnerContext);
} catch (error) {
return makeResponse({
status: 'error',
requestPacket: packet,
...(error instanceof RPCError ? {
errorCode: error.errorCode,
reason: error.reason
} : {
errorCode: RPCErrorCode.SERVER_ERROR,
reason: `${error}`,
})
})
}
if (request === null) {
return makeResponse({
status: 'error',
requestPacket: packet,
errorCode: RPCErrorCode.CALL_PROTOCOL_ERROR,
@@ -182,7 +343,7 @@ export class RPCSession {
// call the function
const provider = this.rpcHandler.getProvider();
if (!provider) {
return makeCallResponsePacket({
return makeResponse({
status: 'error',
requestPacket: packet,
errorCode: RPCErrorCode.PROVIDER_NOT_AVAILABLE,
@@ -192,7 +353,7 @@ export class RPCSession {
const { fnPath, args } = request;
const fnRes = getProviderFunction(provider, fnPath);
if (!fnRes) {
return makeCallResponsePacket({
return makeResponse({
status: 'error',
requestPacket: packet,
errorCode: RPCErrorCode.METHOD_NOT_FOUND,
@@ -203,7 +364,7 @@ export class RPCSession {
const { enableMethodProtection } = this.rpcHandler.getConfig();
if (enableMethodProtection) {
if (!isPublicMethod(fn)) {
return makeCallResponsePacket({
return makeResponse({
status: 'error',
requestPacket: packet,
errorCode: RPCErrorCode.METHOD_PROTECTED,
@@ -212,14 +373,21 @@ export class RPCSession {
}
try {
const result = await fn.bind(fnThis)(...args);
return makeCallResponsePacket({
const result = await fn.bind(new Proxy(fnThis, {
get(target, p, receiver) {
if (p === RPCContextKey) {
return new RPCContextConstractor(instance);
}
return Reflect.get(target, p, receiver);
},
}))(...args);
return makeResponse({
status: 'success',
requestPacket: packet,
data: result,
})
} catch (error) {
return makeCallResponsePacket({
return makeResponse({
status: 'error',
requestPacket: packet,
errorCode: RPCErrorCode.SERVER_ERROR,

View File

@@ -12,6 +12,10 @@ export { SocketConnection } from "./core/SocketConnection";
export { SocketServer } from "./core/SocketServer";
export { EventEmitter } from "./utils/EventEmitter";
export { AbstractRPCPlugin, createHookRunner } from './core/RPCPlugin';
export type { RPCPlugin, RPCPluginHooks, RPCPluginHooksCtx } from './core/RPCPlugin';
export { getRPCContext } from './core/RPCContext';
export { injectSocketClient } from "./core/SocketClient";
export { injectSocketServer } from "./core/SocketServer";
import { injectSocketIOImplements } from "./implements/socket.io";