diff --git a/.changeset/blue-kids-cough.md b/.changeset/blue-kids-cough.md new file mode 100644 index 000000000..8c6d4d01a --- /dev/null +++ b/.changeset/blue-kids-cough.md @@ -0,0 +1,5 @@ +--- +"@livekit/agents": patch +--- + +Require session management grant for remote sessions diff --git a/agents/src/voice/agent_session.ts b/agents/src/voice/agent_session.ts index 46901b4c0..4cdb22578 100644 --- a/agents/src/voice/agent_session.ts +++ b/agents/src/voice/agent_session.ts @@ -448,7 +448,7 @@ export class AgentSession< this._roomIO.start(); - const transport = new RoomSessionTransport(room, this._roomIO); + const transport = new RoomSessionTransport(room); this.sessionHost = new SessionHost(transport); this.sessionHost.registerSession(this); if (inputOptions?.textEnabled !== false) { diff --git a/agents/src/voice/index.ts b/agents/src/voice/index.ts index 808ac88c1..a8f826385 100644 --- a/agents/src/voice/index.ts +++ b/agents/src/voice/index.ts @@ -12,6 +12,7 @@ export { export * from './avatar/index.js'; export * from './background_audio.js'; export { + type IncomingMessage, type TextInputCallback, type TextInputEvent, RemoteSession, diff --git a/agents/src/voice/remote_session.ts b/agents/src/voice/remote_session.ts index b8cb359a8..40f5e5379 100644 --- a/agents/src/voice/remote_session.ts +++ b/agents/src/voice/remote_session.ts @@ -83,27 +83,29 @@ export type RemoteSessionCallbacks = { export abstract class SessionTransport { async start(): Promise {} - abstract sendMessage(msg: pb.AgentSessionMessage): Promise; + abstract sendMessage( + msg: pb.AgentSessionMessage, + opts?: { destinationIdentity?: string }, + ): Promise; abstract close(): Promise; - abstract [Symbol.asyncIterator](): AsyncIterator; + abstract [Symbol.asyncIterator](): AsyncIterator; +} + +export interface IncomingMessage { + message: pb.AgentSessionMessage; + senderIdentity?: string; } export class RoomSessionTransport extends SessionTransport { private readonly room: Room; private handlerRegistered = false; private closed = false; - private pendingMessages: pb.AgentSessionMessage[] = []; - private waitingResolve: ((value: IteratorResult) => void) | null = null; - private roomIO: RoomIO; + private pendingMessages: IncomingMessage[] = []; + private waitingResolve: ((value: IteratorResult) => void) | null = null; - constructor(room: Room, roomIO: RoomIO) { + constructor(room: Room, _roomIO?: RoomIO) { super(); this.room = room; - this.roomIO = roomIO; - } - - private getRemoteIdentity() { - return this.roomIO.linkedParticipant?.identity; } override async start(): Promise { @@ -113,15 +115,42 @@ export class RoomSessionTransport extends SessionTransport { } private onByteStream = (reader: ByteStreamReader, participantInfo: { identity: string }) => { - if (this.getRemoteIdentity() && participantInfo.identity !== this.getRemoteIdentity()) { + if (!this.shouldAcceptMessage(participantInfo.identity)) { return; } - this.readStream(reader).catch((e) => { + this.readStream(reader, participantInfo.identity).catch((e) => { log().warn({ error: e }, 'failed to read binary stream message'); }); }; - private async readStream(reader: ByteStreamReader): Promise { + protected shouldAcceptMessage(identity: string): boolean { + if (this.canManage(identity)) { + return true; + } + log().debug( + { participant: identity }, + 'ignoring session message from participant without canManageAgentSession grant', + ); + return false; + } + + private canManage(identity: string): boolean { + return ( + this.room.remoteParticipants.get(identity)?.info.permission?.canManageAgentSession === true + ); + } + + private authorizedIdentities(): string[] { + const identities: string[] = []; + for (const [identity, participant] of this.room.remoteParticipants.entries()) { + if (participant.info.permission?.canManageAgentSession === true) { + identities.push(identity); + } + } + return identities; + } + + private async readStream(reader: ByteStreamReader, senderIdentity: string): Promise { try { const chunks = await reader.readAll(); let totalLength = 0; @@ -135,7 +164,7 @@ export class RoomSessionTransport extends SessionTransport { offset += chunk.length; } const msg = pb.AgentSessionMessage.fromBinary(data); - this.enqueue(msg); + this.enqueue({ message: msg, senderIdentity }); } catch (e) { if (!this.closed) { log().warn({ error: e }, 'failed to parse binary stream message'); @@ -143,20 +172,25 @@ export class RoomSessionTransport extends SessionTransport { } } - override async sendMessage(msg: pb.AgentSessionMessage): Promise { + override async sendMessage( + msg: pb.AgentSessionMessage, + opts: { destinationIdentity?: string } = {}, + ): Promise { if (this.closed || !this.room.isConnected) return; + const destinationIdentities = this.getDestinationIdentities(opts.destinationIdentity); + if (destinationIdentities?.length === 0) return; + try { const data = msg.toBinary(); - const opts: Record = { + const streamOpts: Record = { topic: TOPIC_SESSION_MESSAGES, name: shortuuid('AS_'), }; - const remoteIdentity = this.getRemoteIdentity(); - if (remoteIdentity) { - opts.destinationIdentities = [remoteIdentity]; + if (destinationIdentities) { + streamOpts.destinationIdentities = destinationIdentities; } - const writer = await this.room.localParticipant!.streamBytes(opts); + const writer = await this.room.localParticipant!.streamBytes(streamOpts); await writer.write(new Uint8Array(data)); await writer.close(); } catch (e) { @@ -164,6 +198,13 @@ export class RoomSessionTransport extends SessionTransport { } } + protected getDestinationIdentities(destinationIdentity?: string): string[] | undefined { + if (destinationIdentity) { + return this.canManage(destinationIdentity) ? [destinationIdentity] : []; + } + return this.authorizedIdentities(); + } + override async close(): Promise { if (this.closed) return; this.closed = true; @@ -179,14 +220,14 @@ export class RoomSessionTransport extends SessionTransport { if (this.waitingResolve) { this.waitingResolve({ - value: undefined as unknown as pb.AgentSessionMessage, + value: undefined as unknown as IncomingMessage, done: true, }); this.waitingResolve = null; } } - private enqueue(msg: pb.AgentSessionMessage): void { + private enqueue(msg: IncomingMessage): void { if (this.closed) return; if (this.waitingResolve) { @@ -198,12 +239,12 @@ export class RoomSessionTransport extends SessionTransport { } } - override [Symbol.asyncIterator](): AsyncIterator { + override [Symbol.asyncIterator](): AsyncIterator { return { - next: (): Promise> => { + next: (): Promise> => { if (this.closed && this.pendingMessages.length === 0) { return ThrowsPromise.resolve({ - value: undefined as unknown as pb.AgentSessionMessage, + value: undefined as unknown as IncomingMessage, done: true, }); } @@ -213,14 +254,14 @@ export class RoomSessionTransport extends SessionTransport { return ThrowsPromise.resolve({ value: pending, done: false }); } - return new ThrowsPromise, never>((resolve) => { + return new ThrowsPromise, never>((resolve) => { this.waitingResolve = resolve; }); }, - return: (): Promise> => { + return: (): Promise> => { this.close(); return ThrowsPromise.resolve({ - value: undefined as unknown as pb.AgentSessionMessage, + value: undefined as unknown as IncomingMessage, done: true, }); }, @@ -228,6 +269,29 @@ export class RoomSessionTransport extends SessionTransport { } } +class LinkedParticipantSessionTransport extends RoomSessionTransport { + private readonly roomIO: RoomIO; + + constructor(room: Room, roomIO: RoomIO) { + super(room); + this.roomIO = roomIO; + } + + protected override shouldAcceptMessage(identity: string): boolean { + const remoteIdentity = this.getRemoteIdentity(); + return !remoteIdentity || identity === remoteIdentity; + } + + protected override getDestinationIdentities(destinationIdentity?: string): string[] | undefined { + const remoteIdentity = destinationIdentity ?? this.getRemoteIdentity(); + return remoteIdentity ? [remoteIdentity] : undefined; + } + + private getRemoteIdentity(): string | undefined { + return this.roomIO.linkedParticipant?.identity; + } +} + // =========================================================================== // Enum maps // =========================================================================== @@ -559,12 +623,11 @@ export class SessionHost { private async recvLoop(): Promise { try { - for await (const msg of this.transport) { + for await (const incoming of this.transport) { + const msg = incoming.message; if (msg.message.case === 'request') { if (this.session) { - this.trackTask( - Task.from(async () => this.handleRequestSafe(msg.message.value as pb.SessionRequest)), - ); + this.trackTask(Task.from(async () => this.handleRequestSafe(incoming))); } } } @@ -726,9 +789,10 @@ export class SessionHost { }); } - private async handleRequestSafe(req: pb.SessionRequest): Promise { + private async handleRequestSafe(incoming: IncomingMessage): Promise { + const req = incoming.message.message.value as pb.SessionRequest; try { - await this.handleRequest(req); + await this.handleRequest(req, incoming.senderIdentity); } catch (e) { log().warn({ error: e, requestId: req.requestId }, 'error handling session request'); try { @@ -741,75 +805,104 @@ export class SessionHost { }), }, }); - await this.transport.sendMessage(resp); + await this.transport.sendMessage(resp, { destinationIdentity: incoming.senderIdentity }); } catch (e) { log().debug({ error: e }, 'failed to send error response'); } } } - private async handleRequest(req: pb.SessionRequest): Promise { + private async handleRequest(req: pb.SessionRequest, destinationIdentity?: string): Promise { if (!this.session) return; switch (req.request.case) { case 'ping': - return this.sendResponse(req.requestId, { - case: 'pong', - value: new pb.SessionResponse_Pong(), - }); + return this.sendResponse( + req.requestId, + { + case: 'pong', + value: new pb.SessionResponse_Pong(), + }, + undefined, + destinationIdentity, + ); case 'getChatHistory': - return this.handleGetChatHistory(req.requestId); + return this.handleGetChatHistory(req.requestId, destinationIdentity); case 'getAgentInfo': - return this.handleGetAgentInfo(req.requestId); + return this.handleGetAgentInfo(req.requestId, destinationIdentity); case 'runInput': - return this.handleRunInput(req.requestId, req.request.value); + return this.handleRunInput(req.requestId, req.request.value, destinationIdentity); case 'getSessionState': - return this.handleGetSessionState(req.requestId); + return this.handleGetSessionState(req.requestId, destinationIdentity); case 'getRtcStats': - return this.sendResponse(req.requestId, { - case: 'getRtcStats', - value: new pb.SessionResponse_GetRTCStatsResponse({ - publisherStats: [], - subscriberStats: [], - }), - }); + return this.sendResponse( + req.requestId, + { + case: 'getRtcStats', + value: new pb.SessionResponse_GetRTCStatsResponse({ + publisherStats: [], + subscriberStats: [], + }), + }, + undefined, + destinationIdentity, + ); case 'getSessionUsage': - return this.handleGetSessionUsage(req.requestId); + return this.handleGetSessionUsage(req.requestId, destinationIdentity); case 'getFrameworkInfo': - return this.sendResponse(req.requestId, { - case: 'getFrameworkInfo', - value: new pb.SessionResponse_GetFrameworkInfoResponse({ - sdk: 'js', - sdkVersion: version, - }), - }); + return this.sendResponse( + req.requestId, + { + case: 'getFrameworkInfo', + value: new pb.SessionResponse_GetFrameworkInfoResponse({ + sdk: 'js', + sdkVersion: version, + }), + }, + undefined, + destinationIdentity, + ); } } - private async handleGetChatHistory(requestId: string): Promise { + private async handleGetChatHistory( + requestId: string, + destinationIdentity?: string, + ): Promise { const items = chatItemsToProto(this.session!.history.items); - return this.sendResponse(requestId, { - case: 'getChatHistory', - value: new pb.SessionResponse_GetChatHistoryResponse({ items }), - }); + return this.sendResponse( + requestId, + { + case: 'getChatHistory', + value: new pb.SessionResponse_GetChatHistoryResponse({ items }), + }, + undefined, + destinationIdentity, + ); } - private async handleGetAgentInfo(requestId: string): Promise { + private async handleGetAgentInfo(requestId: string, destinationIdentity?: string): Promise { const agent = this.session!.currentAgent; - return this.sendResponse(requestId, { - case: 'getAgentInfo', - value: new pb.SessionResponse_GetAgentInfoResponse({ - id: agent.id, - instructions: agent.instructions, - tools: toolNames(agent.toolCtx), - chatCtx: chatItemsToProto(agent.chatCtx.items), - }), - }); + return this.sendResponse( + requestId, + { + case: 'getAgentInfo', + value: new pb.SessionResponse_GetAgentInfoResponse({ + id: agent.id, + instructions: agent.instructions, + tools: toolNames(agent.toolCtx), + chatCtx: chatItemsToProto(agent.chatCtx.items), + }), + }, + undefined, + destinationIdentity, + ); } private async handleRunInput( requestId: string, input: pb.SessionRequest_RunInput, + destinationIdentity?: string, ): Promise { const text = input.text; let items: pb.ChatContext_ChatItem[] = []; @@ -845,43 +938,61 @@ export class SessionHost { value: new pb.SessionResponse_RunInputResponse({ items }), }, error, + destinationIdentity, ); } - private async handleGetSessionState(requestId: string): Promise { + private async handleGetSessionState( + requestId: string, + destinationIdentity?: string, + ): Promise { const agent = this.session!.currentAgent; const startedAt = this.session!._startedAt ?? Date.now(); - return this.sendResponse(requestId, { - case: 'getSessionState', - value: new pb.SessionResponse_GetSessionStateResponse({ - agentState: AGENT_STATE_MAP[this.session!.agentState], - userState: USER_STATE_MAP[this.session!.userState], - agentId: agent.id, - options: protoSerializeOptions({ - turnHandling: this.session!.sessionOptions.turnHandling, - maxToolSteps: this.session!.sessionOptions.maxToolSteps, - userAwayTimeout: this.session!.sessionOptions.userAwayTimeout, - useTtsAlignedTranscript: this.session!.sessionOptions.useTtsAlignedTranscript, + return this.sendResponse( + requestId, + { + case: 'getSessionState', + value: new pb.SessionResponse_GetSessionStateResponse({ + agentState: AGENT_STATE_MAP[this.session!.agentState], + userState: USER_STATE_MAP[this.session!.userState], + agentId: agent.id, + options: protoSerializeOptions({ + turnHandling: this.session!.sessionOptions.turnHandling, + maxToolSteps: this.session!.sessionOptions.maxToolSteps, + userAwayTimeout: this.session!.sessionOptions.userAwayTimeout, + useTtsAlignedTranscript: this.session!.sessionOptions.useTtsAlignedTranscript, + }), + createdAt: msToTimestamp(startedAt), }), - createdAt: msToTimestamp(startedAt), - }), - }); + }, + undefined, + destinationIdentity, + ); } - private async handleGetSessionUsage(requestId: string): Promise { - return this.sendResponse(requestId, { - case: 'getSessionUsage', - value: new pb.SessionResponse_GetSessionUsageResponse({ - usage: sessionUsageToProto(this.session!.usage), - createdAt: nowTimestamp(), - }), - }); + private async handleGetSessionUsage( + requestId: string, + destinationIdentity?: string, + ): Promise { + return this.sendResponse( + requestId, + { + case: 'getSessionUsage', + value: new pb.SessionResponse_GetSessionUsageResponse({ + usage: sessionUsageToProto(this.session!.usage), + createdAt: nowTimestamp(), + }), + }, + undefined, + destinationIdentity, + ); } private async sendResponse( requestId: string, response: pb.SessionResponse['response'], error?: string, + destinationIdentity?: string, ): Promise { await this.transport.sendMessage( new pb.AgentSessionMessage({ @@ -890,6 +1001,7 @@ export class SessionHost { value: new pb.SessionResponse({ requestId, response, error }), }, }), + { destinationIdentity }, ); } @@ -921,7 +1033,7 @@ export class RemoteSession extends (EventEmitter as new () => TypedEventEmitter< } static fromRoom(room: Room, roomIO: RoomIO): RemoteSession { - const transport = new RoomSessionTransport(room, roomIO); + const transport = new LinkedParticipantSessionTransport(room, roomIO); return new RemoteSession(transport); } @@ -955,7 +1067,8 @@ export class RemoteSession extends (EventEmitter as new () => TypedEventEmitter< private async recvLoop(): Promise { try { - for await (const msg of this.transport) { + for await (const incoming of this.transport) { + const msg = incoming.message; switch (msg.message.case) { case 'event': this.dispatchEvent(msg.message.value);