Compare commits

..

27 Commits

Author SHA1 Message Date
b440c38e73 feat: export getRPCContext from RPCContext 2025-12-02 15:46:07 +08:00
d936af9793 feat: add RPCContext 2025-12-02 15:03:33 +08:00
2fa7b851ff feat: export AbstractRPCPlugin and related types from RPCPlugin 2025-11-28 14:13:02 +08:00
4eac61144f Merge branch 'feature/rpc-plugin' into dev
* feature/rpc-plugin:
  feat: add session management and authentication tests for RPC plugin
  feat: add error reason to RPCError in callRequest method
  feat: enhance onCallRequest method with request handling and hook execution
  feat: add session and request handling to CallIncomingBeforeCtx and CallIncomingCtx interfaces
  feat: enhance callRequest method with options validation and hook execution
  feat: comment out abstract onInit and onDestroy methods in AbstractRPCPlugin class
  feat: update CallOutgoingBeforeCtx and CallOutgoingCtx interfaces to use unknown type for options and result
  feat: add unit tests for RPCPlugin hook execution and error handling
  feat: remove unused HookChainInterruptedError class from RPCPlugin
  feat: add RPCPlugin interface and abstract class with hook context definitions
  feat: add plugin management methods to RPCHandler
  feat: add createDeferrablePromise utility function
2025-11-28 14:12:53 +08:00
12afcb7a82 feat: add session management and authentication tests for RPC plugin 2025-11-28 14:12:37 +08:00
40d0e79358 feat: add error reason to RPCError in callRequest method 2025-11-28 13:35:03 +08:00
9a8733a77d feat: enhance onCallRequest method with request handling and hook execution 2025-11-28 11:08:19 +08:00
1b5281e0c1 feat: add session and request handling to CallIncomingBeforeCtx and CallIncomingCtx interfaces 2025-11-28 11:08:03 +08:00
3a4a54d37c test: enhance error handling in RPC disconnected test 2025-11-27 22:45:48 +08:00
ebb9ed21ad feat: enhance callRequest method with options validation and hook execution 2025-11-27 22:31:06 +08:00
720a79ca7a feat: comment out abstract onInit and onDestroy methods in AbstractRPCPlugin class 2025-11-27 22:30:29 +08:00
500e7c8fa6 feat: update CallOutgoingBeforeCtx and CallOutgoingCtx interfaces to use unknown type for options and result 2025-11-27 22:30:21 +08:00
809a3759d9 feat: add unit tests for RPCPlugin hook execution and error handling 2025-11-27 21:36:41 +08:00
bc1445d3a5 feat: remove unused HookChainInterruptedError class from RPCPlugin 2025-11-27 21:31:49 +08:00
24a14a8e1c feat: add RPCPlugin interface and abstract class with hook context definitions 2025-11-27 21:25:27 +08:00
99c673a8dc feat: add plugin management methods to RPCHandler 2025-11-27 20:43:07 +08:00
b0bcd64b41 feat: add createDeferrablePromise utility function 2025-11-27 20:35:14 +08:00
tone
e17a9591e1 feat: add tests for RPC protected methods and enhance RPCSession method protection 2025-11-15 18:08:27 +08:00
tone
eb0574221e feat: add RPCConfig interface and integrate it into RPCHandler constructor 2025-11-15 18:07:57 +08:00
tone
ab289186f1 feat: add public method utilities and corresponding tests 2025-11-15 18:07:21 +08:00
tone
36e0c17ad7 refactor: move call handlers to rpcSession 2025-11-15 13:03:09 +08:00
tone
821379f44f feat: pass provider instance to RPCSession constructor in RPCClient and RPCServer 2025-11-15 12:15:50 +08:00
tone
bb85abd77e refactor: remove commented-out example usage in RPCHandler 2025-11-15 12:04:11 +08:00
tone
b89187e4ac 1.0.2 2025-10-15 17:24:03 +08:00
tone
2666388df7 test: add RPC disconnection handling test case 2025-10-15 17:21:04 +08:00
tone
3b9cc6dffb fix: attach finally cleanup to returned promise to avoid unhandled rejections 2025-10-15 17:12:57 +08:00
tone
387cbc0b25 fix: update ToDeepPromise type to use Awaited for promise resolution 2025-10-15 16:32:09 +08:00
18 changed files with 899 additions and 205 deletions

View File

@@ -0,0 +1,41 @@
import { getRPCContext, RPCContextKey } from "@/core/RPCContext";
import { RPCHandler, RPCSession } from "@/index"
import { getRandomAvailablePort } from "@/utils/utils";
describe('rpc-context.test', () => {
test('context.session', async () => {
const accessed = new WeakMap<RPCSession, string>();
const provider = {
login(name: string) {
const context = getRPCContext(this);
if (!context) {
throw new Error('context is null');
}
accessed.set(context.session, name);
return name;
},
me() {
const context = getRPCContext(this);
if (!context) {
throw new Error('context is null');
}
return accessed.get(context.session) || 'unknown user';
},
}
const server = new RPCHandler();
const client = new RPCHandler();
server.setProvider(provider);
const port = await getRandomAvailablePort();
await server.listen({ port });
const session = await client.connect({ url: `http://localhost:${port}` });
const api = session.getAPI<typeof provider>();
const clientname = 'clientname';
await expect(api.login(clientname)).resolves.toBe(clientname);
await expect(api.me()).resolves.toBe(clientname);
})
})

View File

@@ -0,0 +1,43 @@
import { RPCError, RPCErrorCode } from "@/core/RPCError";
import { RPCHandler } from "@/index"
import { getRandomAvailablePort } from "@/utils/utils";
type serverProvider = { test: () => Promise<string> };
describe('Rpc disconnected test', () => {
test('main', async () => {
const port = await getRandomAvailablePort();
const server = new RPCHandler();
server.setProvider<serverProvider>({
test() {
return new Promise<string>((resolve) => setTimeout(() => resolve('ok'), 1000))
},
})
await server.listen({
port,
});
const client = new RPCHandler();
const session = await client.connect({
url: `http://localhost:${port}`,
});
const api = session.getAPI<serverProvider>();
const callPromise = api.test()
await session.connection.close();
await expect(api.test()).rejects.toMatchObject(
expect.objectContaining({
constructor: RPCError,
errorCode: RPCErrorCode.CONNECTION_DISCONNECTED
})
);
await expect(callPromise).rejects.toMatchObject(
expect.objectContaining({
constructor: RPCError,
errorCode: RPCErrorCode.CONNECTION_DISCONNECTED
})
);
})
})

View File

@@ -0,0 +1,49 @@
import { RPCError } from "@/core/RPCError";
import { CallIncomingBeforeCtx, NormalMethodReturn } from "@/core/RPCPlugin";
import { AbstractRPCPlugin, RPCHandler } from "@/index";
import { getRandomAvailablePort, isObject } from "@/utils/utils";
describe('rpc-plugin.ctx.session.test', () => {
const userInfo = 'userinfo';
const users = new WeakMap();
class RPCTestPlugin implements AbstractRPCPlugin {
onCallIncomingBefore(ctx: CallIncomingBeforeCtx): NormalMethodReturn {
if (users.has(ctx.session)) {
return;
}
if (isObject(ctx.request)) {
if (ctx.request.fnPath === 'login') {
users.set(ctx.session, userInfo);
return;
}
}
throw new RPCError({
reason: 'not login'
});
}
}
test('session', async () => {
const server = new RPCHandler();
const loginRes = 'login';
const authRes = 'auth';
const provider = {
login() { return loginRes },
auth() { return authRes }
}
server.setProvider(provider);
server.loadPlugin(new RPCTestPlugin());
const client = new RPCHandler();
const port = await getRandomAvailablePort();
await server.listen({ port });
const session = await client.connect({ url: `http://localhost:${port}` });
const api = session.getAPI<typeof provider>();
await expect(api.auth()).rejects.toMatchObject({ reason: 'not login' });
await expect(api.login()).resolves.toBe(loginRes);
await expect(api.auth()).resolves.toBe(authRes);
});
})

View File

@@ -0,0 +1,42 @@
import { CallIncomingBeforeCtx, CallIncomingCtx, CallOutgoingBeforeCtx, CallOutgoingCtx, NormalMethodReturn } from "@/core/RPCPlugin";
import { AbstractRPCPlugin, RPCHandler } from "@/index";
import { getRandomAvailablePort } from "@/utils/utils";
describe('rpc-plugin.hook.test', () => {
let count = 0;
class RPCTestPlugin implements AbstractRPCPlugin {
onCallIncomingBefore(ctx: CallIncomingBeforeCtx): NormalMethodReturn {
count += 1;
}
onCallIncoming(ctx: CallIncomingCtx): NormalMethodReturn {
count += 2;
}
onCallOutgoingBefore(ctx: CallOutgoingBeforeCtx): NormalMethodReturn {
count += 4;
}
onCallOutgoing(ctx: CallOutgoingCtx): NormalMethodReturn {
count += 8;
}
}
const CountShouldBe = 15;
test('count', async () => {
const server = new RPCHandler();
const provider = {
add(a: number, b: number) { return a + b }
}
server.setProvider(provider);
server.loadPlugin(new RPCTestPlugin());
const client = new RPCHandler();
client.loadPlugin(new RPCTestPlugin());
const port = await getRandomAvailablePort();
await server.listen({ port });
const session = await client.connect({ url: `http://localhost:${port}` });
const api = session.getAPI<typeof provider>();
await expect(api.add(1, 2)).resolves.toBe(3);
expect(count).toBe(CountShouldBe);
});
})

View File

@@ -0,0 +1,63 @@
import { isPublicMethod, markAsPublicMethod, publicMethod, RPCErrorCode, RPCHandler } from "@/index"
describe('Rpc protected method test', () => {
test('disabled protection', async () => {
class Methods {
@publicMethod
allow() {
return 0;
}
disabled() { }
}
const classMethods = new Methods();
const provider = {
classMethods,
normal: markAsPublicMethod(() => 0),
normal2: markAsPublicMethod(function () { return 0 }),
normalProtected: function () { },
shallowObj: markAsPublicMethod({
fn1: () => 0,
l1: {
fn1: () => 0,
}
}),
deepObj: markAsPublicMethod({
fn1: () => 0,
l1: {
fn1: () => 0,
}
}, { deep: true }),
}
const server = new RPCHandler({
enableMethodProtection: true,
});
server.setProvider(provider);
await server.listen({
port: 5210
});
const client = new RPCHandler();
const session = await client.connect({
url: 'http://localhost:5210'
});
const api = session.getAPI<typeof provider>();
await expect(api.classMethods.allow()).resolves.toBe(0);
await expect(api.classMethods.disabled()).rejects
.toHaveProperty('errorCode', RPCErrorCode.METHOD_PROTECTED);
await expect(api.normal()).resolves.toBe(0);
await expect(api.normal2()).resolves.toBe(0);
await expect(api.normalProtected()).rejects
.toHaveProperty('errorCode', RPCErrorCode.METHOD_PROTECTED);
await expect(api.shallowObj.fn1()).resolves.toBe(0);
await expect(api.shallowObj.l1.fn1()).rejects
.toHaveProperty('errorCode', RPCErrorCode.METHOD_PROTECTED);
await expect(api.deepObj.fn1()).resolves.toBe(0);
await expect(api.deepObj.l1.fn1()).resolves.toBe(0);
})
})

View File

@@ -0,0 +1,43 @@
import { createHookRunner, RPCPlugin } from "@/core/RPCPlugin"
import type { CallIncomingCtx } from "@/core/RPCPlugin"
describe('RPCPlugin.test', () => {
const plugin3 = {
onCallIncoming(ctx: CallIncomingCtx) { throw new Error() }
} as RPCPlugin;
const plugin4 = {
async onCallIncoming(ctx: CallIncomingCtx) { throw new Error() }
} as RPCPlugin;
test('should be resolved', async () => {
const plugin1 = {
onCallIncoming: jest.fn(),
} as RPCPlugin;
const plugin2 = {
async onCallIncoming(ctx: CallIncomingCtx) { }
} as RPCPlugin;
const plugins = [plugin1, plugin2];
const hookRunner = createHookRunner(plugins, 'onCallIncoming');
await hookRunner({} as any);
expect(plugin1.onCallIncoming).toHaveBeenCalled()
})
test('should be resolved2', async () => {
const plugins = [] as RPCPlugin[];
const hookRunner = createHookRunner(plugins, 'onCallIncoming');
await hookRunner({} as any);
expect.assertions(0);
})
test('should be rejected1', async () => {
const plugins = [plugin3];
const hookRunner = createHookRunner(plugins, 'onCallIncoming');
await expect(hookRunner({} as any)).rejects.toThrow(Error)
})
test('should be rejected2', async () => {
const plugins = [plugin4];
const hookRunner = createHookRunner(plugins, 'onCallIncoming');
await expect(hookRunner({} as any)).rejects.toThrow(Error)
})
})

View File

@@ -1,4 +1,4 @@
import { getRandomAvailablePort, isObject, isString, makeId } from "@/utils/utils"
import { getRandomAvailablePort, isObject, isPublicMethod, isString, makeId, markAsPublicMethod } from "@/utils/utils"
test('makeId', () => {
const id = makeId();
@@ -26,4 +26,29 @@ test('getRandomAvailablePort', async () => {
const port = await getRandomAvailablePort();
expect(port).toBeGreaterThanOrEqual(1);
expect(port).toBeLessThanOrEqual(65535);
})
test('markAsPublick', () => {
const shallowObj = {
fn1() { },
l1: {
fn1() { },
}
};
const deepObj = {
fn1() { },
l1: {
fn1() { },
}
};
markAsPublicMethod(shallowObj);
markAsPublicMethod(deepObj, { deep: true });
expect(isPublicMethod(shallowObj.fn1)).toBeTruthy();
expect(isPublicMethod(shallowObj.l1.fn1)).toBeUndefined();
expect(isPublicMethod(deepObj.fn1)).toBeTruthy();
expect(isPublicMethod(deepObj.l1.fn1)).toBeTruthy();
})

View File

@@ -1,6 +1,6 @@
{
"name": "@tonecn/typesrpc",
"version": "1.0.1",
"version": "1.0.2",
"description": "A lightweight, type-safe RPC framework for TypeScript with deep nested API support",
"main": "dist/index.js",
"types": "dist/index.d.ts",

View File

@@ -85,6 +85,7 @@ export class RPCClient {
resolve(new RPCSession(
new RPCConnection(connection!),
this.rpcHandler,
this,
));
} else {
reject(new Error('Server rejected handshake request'));

View File

@@ -1,9 +1,6 @@
import { EventEmitter } from "@/utils/EventEmitter";
import { SocketConnection } from "./SocketConnection";
import { RPCPacket } from "./RPCPacket";
import { makeCallPacket, makeCallResponsePacket, parseCallPacket, parseCallResponsePacket } from "./RPCCommon";
import { RPCProvider } from "./RPCProvider";
import { RPCError, RPCErrorCode } from "./RPCError";
interface RPCConnectionEvents {
call: RPCPacket;
@@ -14,33 +11,15 @@ interface RPCConnectionEvents {
closed: void;
}
class CallResponseEmitter extends EventEmitter<{
[id: string]: RPCPacket;
}> {
emitAll(packet: RPCPacket) {
this.events.forEach(subscribers => {
subscribers.forEach(fn => fn(packet));
})
}
}
export class RPCConnection extends EventEmitter<RPCConnectionEvents> {
closed: boolean = false;
private callResponseEmitter = new CallResponseEmitter();
constructor(public socket: SocketConnection) {
super();
socket.on('closed', () => {
this.emit('closed');
this.callResponseEmitter.emitAll(makeCallResponsePacket({
status: 'error',
requestPacketId: 'connection error',
errorCode: RPCErrorCode.CONNECTION_DISCONNECTED,
}));
this.callResponseEmitter.removeAllListeners();
this.closed = true;
this.emit('closed');
});
socket.on('msg', (msg) => {
@@ -68,156 +47,17 @@ export class RPCConnection extends EventEmitter<RPCConnectionEvents> {
this.emit('unknownPacket', packet);
});
/** route by packet.id */
this.on('callResponse', (packet) => {
this.callResponseEmitter.emit(packet.id, packet);
})
}
/** @throws */
public async callRequest(options: {
fnPath: string;
args: any[];
timeout: number;
}): Promise<any> {
public async close() {
return this.socket.close();
}
public async send(data: RPCPacket) {
if (this.closed) {
throw new RPCError({
errorCode: RPCErrorCode.CONNECTION_DISCONNECTED,
});
return;
}
const { fnPath, args } = options;
const packet = makeCallPacket({
fnPath,
args
});
let resolve: (data: any) => void;
let reject: (data: any) => void;
const promise = new Promise((res, rej) => {
resolve = res;
reject = rej;
});
const cancelTimeoutTimer = (() => {
const t = setTimeout(() => {
reject(new RPCError({
errorCode: RPCErrorCode.TIMEOUT_ERROR,
}))
}, options.timeout);
return () => clearTimeout(t);
})();
promise.finally(() => {
this.callResponseEmitter.removeAllListeners(packet.id);
cancelTimeoutTimer();
})
const handleCallResponsePacket = (packet: RPCPacket) => {
const result = parseCallResponsePacket(packet);
if (result === null) {
return reject(new RPCError({
errorCode: RPCErrorCode.UNKNOWN_ERROR,
}));;
}
const { success, error } = result;
if (success) {
return resolve(success.data);
}
if (error) {
return reject(new RPCError({
errorCode: error.errorCode,
reason: error.reason
}));
}
return reject(new RPCError({
errorCode: RPCErrorCode.UNKNOWN_ERROR,
}));;
}
this.callResponseEmitter.on(packet.id, handleCallResponsePacket);
/** send call request */
this.socket.send(packet);
return promise;
}
public onCallRequest(getProvider: () => RPCProvider | undefined) {
this.on('call', async (packet) => {
const request = parseCallPacket(packet);
if (request === null) {
return this.socket.send(makeCallResponsePacket({
status: 'error',
requestPacket: packet,
errorCode: RPCErrorCode.CALL_PROTOCOL_ERROR,
})).catch(() => { })
}
// call the function
const provider = getProvider();
if (!provider) {
return this.socket.send(makeCallResponsePacket({
status: 'error',
requestPacket: packet,
errorCode: RPCErrorCode.PROVIDER_NOT_AVAILABLE,
}))
}
const { fnPath, args } = request;
const fn = this.getProviderFunction(provider, fnPath);
if (!fn) {
return this.socket.send(makeCallResponsePacket({
status: 'error',
requestPacket: packet,
errorCode: RPCErrorCode.METHOD_NOT_FOUND,
}))
}
try {
const result = await fn(...args);
this.socket.send(makeCallResponsePacket({
status: 'success',
requestPacket: packet,
data: result,
}))
} catch (error) {
this.socket.send(makeCallResponsePacket({
status: 'error',
requestPacket: packet,
errorCode: RPCErrorCode.SERVER_ERROR,
...(error instanceof RPCError ? {
errorCode: error.errorCode,
reason: error.reason,
} : {})
}))
}
})
}
private getProviderFunction(provider: RPCProvider, fnPath: string) {
const paths = fnPath.split(':');
let fnThis: any = provider;
let fn: any = provider;
try {
while (paths.length) {
const path = paths.shift()!;
fn = fn[path];
if (paths.length !== 0) {
fnThis = fn;
}
}
if (typeof fn === 'function') {
return fn.bind(fnThis);
}
throw new Error();
} catch (error) {
return null;
}
return this.socket.send(data);
}
}

27
src/core/RPCContext.ts Normal file
View File

@@ -0,0 +1,27 @@
import { isObject } from "@/utils/utils";
import { RPCSession } from "./RPCSession";
export interface RPCContext {
session: RPCSession;
}
export const RPCContextFlag = Symbol();
export const RPCContextKey = '__rpc_context';
export class RPCContextConstractor {
public [RPCContextKey] = RPCContextFlag;
constructor(public session: RPCSession) { }
}
export const getRPCContext = (self: unknown): RPCContext | null => {
if (!isObject(self)) {
return null;
}
const rpcContext = self[RPCContextKey];
if (!isObject(rpcContext) || rpcContext[RPCContextKey] !== RPCContextFlag) {
return null;
}
return rpcContext as unknown as RPCContext;
}

View File

@@ -3,6 +3,7 @@ import { RPCClient } from "./RPCClient";
import { RPCServer } from "./RPCServer";
import { RPCProvider } from "./RPCProvider";
import { RPCSession } from "./RPCSession";
import { RPCPlugin } from "./RPCPlugin";
const DefaultListenOptions = {
port: 5201,
@@ -19,28 +20,48 @@ interface RPCHandlerEvents {
connect: RPCSession;
}
interface RPCConfig {
enableMethodProtection: boolean;
}
let DefaultRPCConfig: RPCConfig = {
enableMethodProtection: false,
}
export class RPCHandler extends EventEmitter<RPCHandlerEvents> {
private rpcClient?: RPCClient;
private rpcServer?: RPCServer;
private provider?: RPCProvider;
private accessKey?: string;
private config: RPCConfig;
private plugins: RPCPlugin[] = [];
constructor(
args?: {
rpcClient?: RPCClient;
rpcServer?: RPCServer;
}
} & Partial<RPCConfig>
) {
super();
const { rpcClient, rpcServer, ...config } = args ?? {};
if (args?.rpcClient) {
this.setRPCProvider(args.rpcClient);
if (rpcClient) {
this.setRPCProvider(rpcClient);
}
if (args?.rpcServer) {
this.setRPCProvider(args.rpcServer);
if (rpcServer) {
this.setRPCProvider(rpcServer);
}
this.config = {
...DefaultRPCConfig,
...config,
}
}
getConfig() {
return this.config;
}
setProvider<T extends RPCProvider>(provider: T) {
@@ -69,6 +90,29 @@ export class RPCHandler extends EventEmitter<RPCHandlerEvents> {
}
}
loadPlugin(plugin: RPCPlugin): boolean {
const plugins = this.plugins;
if (plugins.includes(plugin)) {
return false;
}
plugins.push(plugin);
return true;
}
unloadPlugin(plugin: RPCPlugin): boolean {
const plugins = this.plugins;
const idx = plugins.indexOf(plugin);
if (idx === -1) {
return false;
}
plugins.splice(idx, 1);
return true;
}
getPlugins(): RPCPlugin[] {
return [...this.plugins];
}
async connect(options: {
url?: string;
accessKey?: string;
@@ -116,27 +160,4 @@ export class RPCHandler extends EventEmitter<RPCHandlerEvents> {
throw new Error();
}
}
}
// const h = new RPCHandler();
// h.setProvider<{
// plus: (a: number, b: number) => number;
// math: {
// minus: (a: number, b: number) => number;
// multiply: (a: number, b: number) => number;
// }
// }>({
// plus(a, b) {
// return a + b
// },
// math: {
// minus(a, b) {
// return a - b;
// },
// multiply(a, b) {
// return a * b;
// },
// }
// })
}

75
src/core/RPCPlugin.ts Normal file
View File

@@ -0,0 +1,75 @@
import { RPCSession } from "./RPCSession";
// interface BaseHookRuntimeCtx {
// nexts: RPCPlugin[];
// setNextPlugins: (plugins: RPCPlugin[]) => void
// }
export interface BaseHookCtx {
}
export interface CallOutgoingBeforeCtx extends BaseHookCtx {
session: RPCSession;
options: unknown;
setOptions: (opt: unknown) => void;
}
export interface CallOutgoingCtx extends CallOutgoingBeforeCtx {
result: unknown;
setResult: (res: unknown) => void;
}
export interface CallIncomingBeforeCtx extends BaseHookCtx {
session: RPCSession;
request: unknown;
setRequest: (req: unknown) => void;
}
export interface CallIncomingCtx extends CallIncomingBeforeCtx {
response: unknown;
setResponse: (res: unknown) => void;
}
export type NormalMethodReturn = Promise<void> | void;
export type HookFn<Ctx> = (ctx: Ctx) => NormalMethodReturn;
export interface RPCPluginHooksCtx {
onCallOutgoingBefore: CallOutgoingBeforeCtx;
onCallOutgoing: CallOutgoingCtx;
onCallIncomingBefore: CallIncomingBeforeCtx;
onCallIncoming: CallIncomingCtx;
}
export type RPCPluginHooks = {
[K in keyof RPCPluginHooksCtx]?: HookFn<RPCPluginHooksCtx[K]>;
};
export interface RPCPlugin extends RPCPluginHooks {
onInit?(): void;
onDestroy?(): void;
}
export abstract class AbstractRPCPlugin implements RPCPlugin {
// abstract onInit?(): void;
// abstract onDestroy?(): void;
abstract onCallOutgoingBefore?(ctx: CallOutgoingBeforeCtx): NormalMethodReturn;
abstract onCallOutgoing?(ctx: CallOutgoingCtx): NormalMethodReturn;
abstract onCallIncomingBefore?(ctx: CallIncomingBeforeCtx): NormalMethodReturn;
abstract onCallIncoming?(ctx: CallIncomingCtx): NormalMethodReturn;
}
type HookName = keyof RPCPluginHooksCtx;
type HookRunner<Ctx> = (ctx: Ctx) => Promise<void>;
export function createHookRunner<K extends HookName>(
plugins: RPCPlugin[],
hookName: K,
): HookRunner<RPCPluginHooksCtx[K]> {
return async (ctx: RPCPluginHooksCtx[K]) => {
for (const plugin of plugins) {
await (plugin[hookName] as HookFn<RPCPluginHooksCtx[K]> | undefined)?.(ctx);
}
};
}

View File

@@ -78,6 +78,7 @@ export class RPCServer extends EventEmitter<RPCServerEvents> {
this.emit('connect', new RPCSession(
new RPCConnection(socketConnection),
this.rpcHandler,
this,
));
}

View File

@@ -1,16 +1,79 @@
import { ToDeepPromise } from "@/utils/utils";
import { isArray, isObject, isPublicMethod, ToDeepPromise } from "@/utils/utils";
import { RPCConnection } from "./RPCConnection";
import { RPCHandler } from "./RPCHandler";
import { RPCProvider } from "./RPCProvider";
import { RPCClient } from "./RPCClient";
import { RPCServer } from "./RPCServer";
import { RPCError, RPCErrorCode } from "./RPCError";
import { makeCallPacket, makeCallResponsePacket, parseCallPacket, parseCallResponsePacket } from "./RPCCommon";
import { RPCPacket } from "./RPCPacket";
import { EventEmitter } from "@/utils/EventEmitter";
import { createHookRunner } from "./RPCPlugin";
import { RPCContextConstractor, RPCContextKey } from "./RPCContext";
function getProviderFunction(provider: RPCProvider, fnPath: string):
[(...args: any[]) => Promise<any>, object] | null {
const paths = fnPath.split(':');
let fnThis: any = provider;
let fn: any = provider;
try {
while (paths.length) {
const path = paths.shift()!;
fn = fn[path];
if (paths.length !== 0) {
fnThis = fn;
}
}
if (typeof fn === 'function') {
return [fn, fnThis];
}
throw new Error();
} catch (error) {
return null;
}
}
class CallResponseEmitter extends EventEmitter<{
[id: string]: RPCPacket;
}> {
emitAll(packet: RPCPacket) {
this.events.forEach(subscribers => {
subscribers.forEach(fn => fn(packet));
})
}
}
export class RPCSession {
public callResponseEmitter = new CallResponseEmitter();
constructor(
public readonly connection: RPCConnection,
public readonly rpcHandler: RPCHandler,
public readonly rpcProvider: RPCClient | RPCServer,
) {
connection.onCallRequest(rpcHandler.getProvider.bind(rpcHandler));
/** route by packet.id */
this.connection.on('callResponse', (packet) => {
this.callResponseEmitter.emit(packet.id, packet);
});
this.connection.on('closed', () => {
this.callResponseEmitter.emitAll(makeCallResponsePacket({
status: 'error',
requestPacketId: 'connection error',
errorCode: RPCErrorCode.CONNECTION_DISCONNECTED,
}));
this.callResponseEmitter.removeAllListeners();
});
this.connection.on('call', (packet) => {
this.onCallRequest(packet).then(res => {
this.connection.send(res);
}).catch((e) => {
console.warn(`${e}`);
})
});
}
getAPI<T extends RPCProvider>(): ToDeepPromise<T> {
@@ -23,7 +86,7 @@ export class RPCSession {
return createProxy(newPath);
},
apply: (target, thisArg, args) => {
return this.connection.callRequest({
return this.callRequest({
fnPath: path.join(':'),
args: args,
/** @todo accept from caller */
@@ -37,4 +100,302 @@ export class RPCSession {
return createProxy() as unknown as ToDeepPromise<T>;
}
/** @throws */
public async callRequest(options: {
fnPath: string;
args: any[];
timeout: number;
}): Promise<any> {
if (this.connection.closed) {
throw new RPCError({
errorCode: RPCErrorCode.CONNECTION_DISCONNECTED,
});
}
function setOptions(opt: unknown) {
if (!isObject(opt)) {
return;
}
if ('fnPath' in opt && 'args' in opt && 'timeout' in opt) {
const { fnPath, args, timeout } = opt;
if (typeof fnPath !== 'string') {
return;
}
if (!isArray(args)) {
return;
}
if (typeof timeout !== 'number') {
return;
}
options = {
...options,
fnPath,
args,
timeout,
}
}
}
const hookRunner = createHookRunner(this.rpcHandler.getPlugins(), 'onCallOutgoingBefore');
await hookRunner({
session: this,
options: { ...options },
setOptions,
});
/** due to `await hookRunner` */
if (this.connection.closed) {
throw new RPCError({
errorCode: RPCErrorCode.CONNECTION_DISCONNECTED,
});
}
const { fnPath, args } = options;
const packet = makeCallPacket({
fnPath,
args
});
let resolve: (data: any) => void;
let reject: (data: any) => void;
const promise = new Promise((res, rej) => {
resolve = res;
reject = rej;
});
const cancelTimeoutTimer = (() => {
const t = setTimeout(() => {
reject(new RPCError({
errorCode: RPCErrorCode.TIMEOUT_ERROR,
}))
}, options.timeout);
return () => clearTimeout(t);
})();
const handleCallResponsePacket = (packet: RPCPacket) => {
let result = parseCallResponsePacket(packet);
function setResult(res: unknown) {
if (!isObject(res)) {
return;
}
const { success, error } = res;
if (typeof success === 'object' && typeof error === 'object') {
if (success && !error) {
if ('data' in success) {
result = {
success: { data: success.data },
error: null,
}
}
} else if (!success && error) {
const { errorCode, reason } = error;
if (typeof errorCode === 'number' && typeof reason === 'string') {
result = {
success: null,
error: {
errorCode,
reason,
},
}
}
}
}
}
const hookRunner = createHookRunner(this.rpcHandler.getPlugins(), 'onCallOutgoing');
hookRunner({
session: this,
options: { ...options },
setOptions,
result,
setResult,
}).then(() => {
if (result === null) {
return reject(new RPCError({
errorCode: RPCErrorCode.UNKNOWN_ERROR,
}));;
}
const { success, error } = result;
if (success) {
return resolve(success.data);
}
if (error) {
return reject(new RPCError({
errorCode: error.errorCode,
reason: error.reason
}));
}
reject(new RPCError({
errorCode: RPCErrorCode.UNKNOWN_ERROR,
}));;
}).catch((e) => {
if (e instanceof RPCError) {
return reject(e);
}
reject(new RPCError({
errorCode: RPCErrorCode.UNKNOWN_ERROR,
reason: e instanceof Error ? e.message : `${e}`
}))
})
}
this.callResponseEmitter.once(packet.id, handleCallResponsePacket);
/** send call request */
this.connection.send(packet);
return promise.finally(() => {
this.callResponseEmitter.removeAllListeners(packet.id);
cancelTimeoutTimer();
});
}
private async onCallRequest(packet: RPCPacket): Promise<RPCPacket> {
let request = parseCallPacket(packet);
const instance = this;
function setRequest(req: unknown) {
if (!isObject(req)) {
return;
}
const { fnPath, args } = req;
if (typeof fnPath === 'string' && isArray(args)) {
request = {
...request,
fnPath,
args,
};
}
}
const hookRunnerContext = {
session: this,
request,
setRequest,
}
async function makeResponse(o: Parameters<typeof makeCallResponsePacket>[0]) {
const hookRunner = createHookRunner(instance.rpcHandler.getPlugins(), 'onCallIncoming');
try {
await hookRunner({
...hookRunnerContext,
response: o,
setResponse: (res: unknown) => {
/** TODO: Implement stricter validation of the response object in the future */
o = res as Parameters<typeof makeCallResponsePacket>[0];
},
})
} catch (error) {
return makeCallResponsePacket({
...o,
...(error instanceof RPCError ? {
errorCode: error.errorCode,
reason: error.reason,
} : {
errorCode: RPCErrorCode.SERVER_ERROR,
reason: `${error}`
})
})
}
return makeCallResponsePacket({
...o,
})
}
const hookRunner = createHookRunner(this.rpcHandler.getPlugins(), 'onCallIncomingBefore');
try {
await hookRunner(hookRunnerContext);
} catch (error) {
return makeResponse({
status: 'error',
requestPacket: packet,
...(error instanceof RPCError ? {
errorCode: error.errorCode,
reason: error.reason
} : {
errorCode: RPCErrorCode.SERVER_ERROR,
reason: `${error}`,
})
})
}
if (request === null) {
return makeResponse({
status: 'error',
requestPacket: packet,
errorCode: RPCErrorCode.CALL_PROTOCOL_ERROR,
});
}
// call the function
const provider = this.rpcHandler.getProvider();
if (!provider) {
return makeResponse({
status: 'error',
requestPacket: packet,
errorCode: RPCErrorCode.PROVIDER_NOT_AVAILABLE,
});
}
const { fnPath, args } = request;
const fnRes = getProviderFunction(provider, fnPath);
if (!fnRes) {
return makeResponse({
status: 'error',
requestPacket: packet,
errorCode: RPCErrorCode.METHOD_NOT_FOUND,
})
}
const [fn, fnThis] = fnRes;
const { enableMethodProtection } = this.rpcHandler.getConfig();
if (enableMethodProtection) {
if (!isPublicMethod(fn)) {
return makeResponse({
status: 'error',
requestPacket: packet,
errorCode: RPCErrorCode.METHOD_PROTECTED,
})
}
}
try {
const result = await fn.bind(new Proxy(fnThis, {
get(target, p, receiver) {
if (p === RPCContextKey) {
return new RPCContextConstractor(instance);
}
return Reflect.get(target, p, receiver);
},
}))(...args);
return makeResponse({
status: 'success',
requestPacket: packet,
data: result,
})
} catch (error) {
return makeResponse({
status: 'error',
requestPacket: packet,
errorCode: RPCErrorCode.SERVER_ERROR,
...(error instanceof RPCError ? {
errorCode: error.errorCode,
reason: error.reason,
} : {})
})
}
}
}

View File

@@ -12,10 +12,16 @@ export { SocketConnection } from "./core/SocketConnection";
export { SocketServer } from "./core/SocketServer";
export { EventEmitter } from "./utils/EventEmitter";
export { AbstractRPCPlugin, createHookRunner } from './core/RPCPlugin';
export type { RPCPlugin, RPCPluginHooks, RPCPluginHooksCtx } from './core/RPCPlugin';
export { getRPCContext } from './core/RPCContext';
export { injectSocketClient } from "./core/SocketClient";
export { injectSocketServer } from "./core/SocketServer";
import { injectSocketIOImplements } from "./implements/socket.io";
export { publicMethod, isPublicMethod, markAsPublicMethod } from './utils/utils';
injectSocketIOImplements();
export {

View File

@@ -4,13 +4,17 @@ export const makeId = () => md5(`${Date.now()}${Math.random()}`);
export const isObject = (v: unknown): v is Record<string, any> => typeof v === 'object' && v !== null;
export const isArray = (v: unknown): v is Array<unknown> => Array.isArray(v);
export const isString = (v: unknown): v is string => typeof v === 'string';
export const isFunction = (v: unknown): v is Function => typeof v === 'function';
export type ObjectType = Record<string, any>;
export type ToDeepPromise<T> = {
[K in keyof T]: T[K] extends (...args: infer P) => infer R
? (...args: P) => Promise<R>
? (...args: P) => Promise<Awaited<R>>
: T[K] extends object
? ToDeepPromise<T[K]>
: T[K]
@@ -41,4 +45,56 @@ export async function getRandomAvailablePort() {
server.listen(0);
})
}
const publicMethodMap = new WeakMap<Function, boolean>();
export function publicMethod(target: any, propertyKey: string, descriptor: PropertyDescriptor) {
publicMethodMap.set(descriptor.value, true);
};
export function isPublicMethod(target: Function) {
return publicMethodMap.get(target);
};
export function markAsPublicMethod<T extends Function | Record<any, unknown> | unknown>(obj: T, options?: {
deep?: boolean
}): T {
const accessed = new Set();
function markAs(obj: Function | Record<any, unknown> | unknown) {
if (accessed.has(obj)) {
return;
}
if (isFunction(obj)) {
publicMethodMap.set(obj, true);
} else if (isObject(obj)) {
accessed.add(obj);
Object.values(obj).forEach(subObj => {
if (isFunction(subObj)) {
publicMethodMap.set(subObj, true);
}
if (options?.deep && isObject(subObj)) {
markAs(subObj);
}
});
}
}
markAs(obj);
return obj;
}
export function createDeferrablePromise<T = unknown>() {
let resolve!: (value: T | PromiseLike<T>) => void;
let reject!: (reason?: unknown) => void;
const promise = new Promise<T>((res, rej) => {
resolve = res;
reject = rej;
});
return {
promise,
resolve,
reject
};
}

View File

@@ -12,8 +12,8 @@
"target": "es2016", /* Set the JavaScript language version for emitted JavaScript and include compatible library declarations. */
// "lib": [], /* Specify a set of bundled library declaration files that describe the target runtime environment. */
// "jsx": "preserve", /* Specify what JSX code is generated. */
// "experimentalDecorators": true, /* Enable experimental support for legacy experimental decorators. */
// "emitDecoratorMetadata": true, /* Emit design-type metadata for decorated declarations in source files. */
"experimentalDecorators": true, /* Enable experimental support for legacy experimental decorators. */
"emitDecoratorMetadata": true, /* Emit design-type metadata for decorated declarations in source files. */
// "jsxFactory": "", /* Specify the JSX factory function used when targeting React JSX emit, e.g. 'React.createElement' or 'h'. */
// "jsxFragmentFactory": "", /* Specify the JSX Fragment reference used for fragments when targeting React JSX emit e.g. 'React.Fragment' or 'Fragment'. */
// "jsxImportSource": "", /* Specify module specifier used to import the JSX factory functions when using 'jsx: react-jsx*'. */