Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/fork-session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ async function streamText(
): Promise<string> {
let text = '';
for await (const msg of session.stream(prompt)) {
if (msg.type === DroidMessageType.AssistantTextDelta) {
if (msg.type === DroidMessageType.Assistant) {
text += msg.text;
}
}
Expand Down
2 changes: 1 addition & 1 deletion examples/interrupt-session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ try {
for await (const msg of session.stream(
'Write a long history of computing.'
)) {
if (msg.type !== DroidMessageType.AssistantTextDelta) {
if (msg.type !== DroidMessageType.Assistant) {
continue;
}

Expand Down
47 changes: 47 additions & 0 deletions examples/message-id-run.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/**
* Manual smoke test for `run(prompt, { messageId })`.
*
* Usage:
* npx tsx examples/message-id-run.ts
* npx tsx examples/message-id-run.ts "Reply with exactly: RUN_OK"
*/

import { randomUUID } from 'node:crypto';

import { DroidMessageType, run } from '../src/index.js';

async function main(): Promise<void> {
const prompt =
process.argv.slice(2).join(' ') || 'Reply with exactly: RUN_OK';
const messageId = `sdk-run-${randomUUID()}`;

const result = await run(prompt, {
cwd: process.cwd(),
messageId,
});
const userMessage = result.messages.find(
(msg) => msg.type === DroidMessageType.User
);
const observedUserMessageId = userMessage?.message.id;

if (observedUserMessageId !== messageId) {
throw new Error(
`Expected user messageId ${messageId}, got ${observedUserMessageId ?? 'none'}`
);
}

console.log(
JSON.stringify({
api: 'run',
sessionId: result.sessionId,
messageId,
observedUserMessageId,
text: result.text,
})
);
}

main().catch((err: unknown) => {
console.error('Error:', err);
process.exit(1);
});
56 changes: 56 additions & 0 deletions examples/message-id-stream.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/**
* Manual smoke test for `session.stream(prompt, { messageId })`.
*
* Usage:
* npx tsx examples/message-id-stream.ts
* npx tsx examples/message-id-stream.ts "Reply with exactly: STREAM_OK"
*/

import { randomUUID } from 'node:crypto';

import { DroidMessageType, createSession } from '../src/index.js';

async function main(): Promise<void> {
const prompt =
process.argv.slice(2).join(' ') || 'Reply with exactly: STREAM_OK';
const messageId = `sdk-stream-${randomUUID()}`;
const session = await createSession({ cwd: process.cwd() });

try {
let observedUserMessageId: string | undefined;
let text = '';

for await (const msg of session.stream(prompt, { messageId })) {
if (msg.type === DroidMessageType.User) {
observedUserMessageId = msg.message.id;
} else if (msg.type === DroidMessageType.Assistant) {
text += msg.text;
} else if (msg.type === DroidMessageType.Result && text.length === 0) {
text = msg.text;
}
}

if (observedUserMessageId !== messageId) {
throw new Error(
`Expected user messageId ${messageId}, got ${observedUserMessageId ?? 'none'}`
);
}

console.log(
JSON.stringify({
api: 'session.stream',
sessionId: session.sessionId,
messageId,
observedUserMessageId,
text,
})
);
} finally {
await session.close();
}
}

main().catch((err: unknown) => {
console.error('Error:', err);
process.exit(1);
});
2 changes: 1 addition & 1 deletion examples/multi-turn-session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ async function streamText(
): Promise<string> {
let text = '';
for await (const msg of session.stream(prompt)) {
if (msg.type === DroidMessageType.AssistantTextDelta) {
if (msg.type === DroidMessageType.Assistant) {
text += msg.text;
}
}
Expand Down
2 changes: 1 addition & 1 deletion examples/sdk-mcp-tool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ try {
for await (const msg of session.stream(
'Use the favorite_number tool for Ada and tell me the answer.'
)) {
if (msg.type === DroidMessageType.AssistantTextDelta) {
if (msg.type === DroidMessageType.Assistant) {
process.stdout.write(msg.text);
}
}
Expand Down
5 changes: 5 additions & 0 deletions src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ import {
ListToolsResultSchema,
ListSkillsResultSchema,
LoadSessionResultSchema,
MessageIdSchema,
RemoveMcpServerResultSchema,
SubmitBugReportResultSchema,
SubmitMcpAuthCodeResultSchema,
Expand Down Expand Up @@ -230,6 +231,10 @@ export class DroidClient {
>
>
): Promise<AddUserMessageResult> {
if (params.messageId !== undefined) {
MessageIdSchema.parse(params.messageId);
}

return this._sessionRpc(
DroidServerMethod.ADD_USER_MESSAGE,
params,
Expand Down
13 changes: 10 additions & 3 deletions src/schemas/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -289,10 +289,17 @@ export const OutputFormatSchema = z

export type OutputFormat = z.infer<typeof OutputFormatSchema>;

export const MessageIdSchema = z
.string()
.max(512, 'messageId must be at most 512 characters')
.refine((value) => value.trim().length > 0, {
message: 'messageId must be a non-empty string',
});

/** Parameters for droid.add_user_message request. */
export const AddUserMessageRequestParamsSchema = z
.object({
messageId: z.string().optional(),
messageId: MessageIdSchema.optional(),
text: z.string(),
images: z.array(Base64ImageSourceSchema).optional(),
files: z.array(DocumentSourceSchema).optional(),
Expand Down Expand Up @@ -557,7 +564,7 @@ export type RewindEvictedFile = z.infer<typeof RewindEvictedFileSchema>;
/** Parameters for droid.get_rewind_info request. */
export const GetRewindInfoRequestParamsSchema = z
.object({
messageId: z.string(),
messageId: MessageIdSchema,
})
.passthrough();

Expand All @@ -568,7 +575,7 @@ export type GetRewindInfoRequestParams = z.infer<
/** Parameters for droid.execute_rewind request. */
export const ExecuteRewindRequestParamsSchema = z
.object({
messageId: z.string(),
messageId: MessageIdSchema,
filesToRestore: z.array(RewindFileSnapshotSchema),
filesToDelete: z.array(RewindFileCreationSchema),
forkTitle: z.string(),
Expand Down
4 changes: 4 additions & 0 deletions src/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ export interface ResumeSessionOptions extends Pick<
}

export interface MessageOptions {
messageId?: string;
images?: Base64ImageSource[];
files?: DocumentSource[];
outputFormat?: OutputFormat;
Expand Down Expand Up @@ -209,6 +210,9 @@ export class DroidSession {
try {
await Promise.race([
this._client.addUserMessage({
...(options?.messageId !== undefined && {
messageId: options.messageId,
}),
text: prompt,
images: options?.images,
files: options?.files,
Expand Down
19 changes: 19 additions & 0 deletions tests/helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,25 @@ export class InMemoryTransport implements DroidClientTransport {
}
}

export function findSentRequestParams(
transport: InMemoryTransport,
method: string
): Record<string, unknown> {
const message = transport.sentMessages.find((sentMessage) => {
return sentMessage['method'] === method;
});
if (!message) {
throw new Error(`Expected request for method ${method}`);
}

const params = message['params'];
if (!params || typeof params !== 'object' || Array.isArray(params)) {
throw new Error(`Expected params for method ${method}`);
}

return params as Record<string, unknown>;
}

export function makeSuccessResponse(
id: string,
result: JsonRpcTestMessage = {}
Expand Down
13 changes: 7 additions & 6 deletions tests/run.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import {
} from '../src/schemas/index.js';
import {
InMemoryTransport,
findSentRequestParams,
makeErrorResponse,
makeSessionNotification,
makeSuccessResponse,
Expand Down Expand Up @@ -78,6 +79,7 @@ describe('run()', () => {
machineId: 'machine-1',
modelId: 'model-1',
reasoningEffort: ReasoningEffort.High,
messageId: 'run-message-id',
images: [{ type: 'base64', data: 'image-data', mediaType: 'image/png' }],
files: [
{
Expand All @@ -100,12 +102,11 @@ describe('run()', () => {
expect(initParams['modelId']).toBe('model-1');
expect(initParams['reasoningEffort']).toBe(ReasoningEffort.High);

const addMsg = transport.sentMessages.find(
(message) =>
(message as Record<string, unknown>)['method'] ===
DroidServerMethod.ADD_USER_MESSAGE
) as Record<string, unknown>;
const addParams = addMsg['params'] as Record<string, unknown>;
const addParams = findSentRequestParams(
transport,
DroidServerMethod.ADD_USER_MESSAGE
);
expect(addParams['messageId']).toBe('run-message-id');
expect(addParams['text']).toBe('Describe these inputs');
expect(addParams['images']).toEqual([
{ type: 'base64', data: 'image-data', mediaType: 'image/png' },
Expand Down
81 changes: 81 additions & 0 deletions tests/session.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import {
InMemoryTransport,
collectStreamText,
findLastResult,
findSentRequestParams,
makeErrorResponse,
makeSessionNotification,
makeSuccessResponse,
Expand Down Expand Up @@ -362,6 +363,86 @@ describe('DroidSession', () => {
await session.close();
});

it('passes custom messageId in addUserMessage RPC params', async () => {
const transport = new InMemoryTransport();
await transport.connect();

setupFullResponder(transport, 'sess-stream-message-id');

const session = await createSession({ transport });

for await (const _msg of session.stream('Hello', {
messageId: 'caller-message-id',
})) {
void _msg;
}

const addParams = findSentRequestParams(
transport,
DroidServerMethod.ADD_USER_MESSAGE
);
expect(addParams['messageId']).toBe('caller-message-id');
expect(addParams['text']).toBe('Hello');

await session.close();
});

it('omits messageId from addUserMessage RPC params by default', async () => {
const transport = new InMemoryTransport();
await transport.connect();

setupFullResponder(transport, 'sess-stream-default-message-id');

const session = await createSession({ transport });

for await (const _msg of session.stream('Hello')) {
void _msg;
}

const addParams = findSentRequestParams(
transport,
DroidServerMethod.ADD_USER_MESSAGE
);
expect(addParams).not.toHaveProperty('messageId');
expect(addParams['text']).toBe('Hello');

await session.close();
});

it.each([
['empty string', ''],
['whitespace-only string', ' '],
['too-long string', 'x'.repeat(513)],
['non-string value', 123],
])('rejects invalid messageId: %s', async (_label, messageId) => {
const transport = new InMemoryTransport();
await transport.connect();

setupFullResponder(transport, 'sess-stream-invalid-message-id');

const session = await createSession({ transport });

await expect(
(async () => {
for await (const _msg of session.stream('Hello', {
messageId: messageId as string,
})) {
void _msg;
}
})()
).rejects.toThrow();

expect(
transport.sentMessages.some(
(message) =>
(message as Record<string, unknown>)['method'] ===
DroidServerMethod.ADD_USER_MESSAGE
)
).toBe(false);

await session.close();
});

it('defaults to message-level events and opts into partial events', async () => {
const createStreamingSession = async (
sessionId: string
Expand Down
Loading