Skip to content

Commit b9a28f1

Browse files
authored
mcp sampling (#1722)
* Add support for MCP model provider and enhance debugging capabilities * Add MCP Client Sampling configuration to language model providers * Add model specification to emojifier script * Refactor MCP server initialization and enhance client sampling registration * Fix resource handling in MCP server and update resource manager methods * Implement MCP sampling language model and refactor MCP server client registration * Add parent language model support to MCP server and worker * Refactor MCP server and worker to enhance message handling and support sampling language model * Enhance debug logging for chatCompletion messages in MCP server and worker * Refactor message handling in createWorkerLanguageModel for improved clarity and maintainability
1 parent 0c45bdc commit b9a28f1

16 files changed

Lines changed: 316 additions & 63 deletions

File tree

.vscode/settings.json

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
"dbgi",
7272
"dbgp",
7373
"dbgr",
74+
"dbgs",
7475
"dbgt",
7576
"ddir",
7677
"debugify",
@@ -235,6 +236,7 @@
235236
"oldsrc",
236237
"ollama",
237238
"olmo",
239+
"oninitialized",
238240
"onnx",
239241
"onvsc",
240242
"openai",

packages/api/src/api.ts

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,19 @@ export async function run(
5555
/**
5656
* Handles messages
5757
*/
58-
onMessage?: (data: { type: "resourceChange" } & Resource) => Awaitable<void>;
58+
onMessage?: (
59+
data: { type: "resourceChange" } & Resource,
60+
postMessage: (data: any) => void,
61+
) => Awaitable<void>;
62+
/**
63+
* Enable client language model as parent.
64+
*/
65+
parentLanguageModel?: boolean;
5966
},
6067
): Promise<GenerationResult> {
6168
if (!scriptId) throw new Error("scriptId is required");
6269
dbg(`run ${scriptId}`);
70+
// eslint-disable-next-line no-param-reassign
6371
if (typeof files === "string") files = [files];
6472

6573
const { signal, onMessage, ...rest } = options || {};
@@ -77,7 +85,7 @@ export async function run(
7785
dbg(`start ${workerJs}`);
7886
const worker = new Worker(workerJs, { workerData, name: options?.label });
7987
return new Promise((resolve, reject) => {
80-
const abort = () => {
88+
const abort = (): void => {
8189
if (worker) {
8290
dbg(`abort`);
8391
reject(new Error("aborted")); // fail early
@@ -92,12 +100,15 @@ export async function run(
92100
signal?.removeEventListener("abort", abort);
93101
resolve(res.result);
94102
} else if (onMessage) {
95-
await onMessage(res);
103+
await onMessage(res, (data) => {
104+
dbg(`postMessage %O`, data);
105+
worker.postMessage(data);
106+
});
96107
} else {
97108
dbg(`unknown message type ${type}`);
98109
}
99110
});
100-
worker.on("error", (reason) => {
111+
worker.on("error", (reason: string) => {
101112
dbg(`error ${reason}`);
102113
signal?.removeEventListener("abort", abort);
103114
reject(reason);

packages/api/src/worker.ts

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,15 @@ import { delay } from "es-toolkit";
66
import { NodeHost } from "@genaiscript/runtime";
77
import {
88
RESOURCE_CHANGE,
9+
genaiscriptDebug,
910
installGlobals,
1011
overrideStdoutWithStdErr,
1112
runtimeHost,
13+
createWorkerLanguageModel,
1214
} from "@genaiscript/core";
1315
import type { Resource } from "@genaiscript/core";
1416
import { runScriptInternal } from "./run.js";
17+
const dbg = genaiscriptDebug("worker");
1518

1619
/**
1720
* Handles worker thread execution based on the provided data type.
@@ -28,12 +31,13 @@ import { runScriptInternal } from "./run.js";
2831
* - Handles resource change events and communicates them to the parent thread.
2932
* - Ensures compatibility with Windows by setting the SystemRoot environment variable.
3033
*/
31-
export async function worker() {
34+
export async function worker(): Promise<void> {
3235
overrideStdoutWithStdErr();
3336
installGlobals();
3437
const { type, ...data } = workerData as {
3538
type: string;
3639
};
40+
dbg(`worker data: %O`, data);
3741
await NodeHost.install(undefined, undefined); // Install NodeHost with environment options
3842
if (process.platform === "win32") {
3943
// https://github.com/Azure/azure-sdk-for-js/issues/32374
@@ -55,9 +59,13 @@ export async function worker() {
5559
const { scriptId, files, options } = data as {
5660
scriptId: string;
5761
files: string[];
58-
options: object;
62+
options: { parentLanguageModel?: boolean };
5963
};
60-
const { result } = await runScriptInternal(scriptId, files, options);
64+
if (options.parentLanguageModel) {
65+
dbg(`using parent language model`);
66+
runtimeHost.clientLanguageModel = createWorkerLanguageModel();
67+
}
68+
const { result } = await runScriptInternal(scriptId, files, options as any);
6169
await delay(0); // flush streams
6270
parentPort.postMessage({ type: "run", result });
6371
break;

packages/cli/src/mcpserver.ts

Lines changed: 69 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,14 @@ import {
99
deleteUndefinedValues,
1010
ensureDotGenaiscriptPath,
1111
errorMessage,
12+
genaiscriptDebug,
1213
logVerbose,
1314
logWarn,
1415
runtimeHost,
1516
setConsoleColors,
1617
splitMarkdownTextImageParts,
1718
toStrictJSONSchema,
19+
mcpRequestSample,
1820
} from "@genaiscript/core";
1921
import type {
2022
GenerationResult,
@@ -39,8 +41,8 @@ import type {
3941
import { applyRemoteOptions } from "./remote.js";
4042
import type { RemoteOptions } from "./remote.js";
4143
import { startProjectWatcher } from "./watch.js";
42-
import debug from "debug";
43-
const dbg = debug("genaiscript:mcp:server");
44+
import { workerData } from "worker_threads";
45+
const dbg = genaiscriptDebug("mcp:server");
4446

4547
/**
4648
* Starts the MCP server.
@@ -57,13 +59,14 @@ export async function startMcpServer(
5759
RemoteOptions & {
5860
startup?: string;
5961
},
60-
) {
62+
): Promise<void> {
6163
setConsoleColors(false);
6264
logVerbose(`mcp server: starting...`);
6365

6466
await ensureDotGenaiscriptPath();
6567
await applyRemoteOptions(options);
6668
const { startup } = options || {};
69+
let samplingSupported = false;
6770

6871
const watcher = await startProjectWatcher(options);
6972
logVerbose(`mcp server: watching ${watcher.cwd}`);
@@ -89,10 +92,30 @@ export async function startMcpServer(
8992
},
9093
},
9194
);
92-
watcher.addEventListener("change", async () => {
93-
logVerbose(`mcp server: tools changed`);
94-
await server.sendToolListChanged();
95-
});
95+
watcher.addEventListener(
96+
"change",
97+
async () => {
98+
logVerbose(`mcp server: tools changed`);
99+
await server.sendToolListChanged();
100+
},
101+
false,
102+
);
103+
const onMessage = async (data: any, postMessage: (data: any) => void) => {
104+
if (data.type === RESOURCE_CHANGE) {
105+
await runtimeHost.resources.upsertResource(data.reference, data.content);
106+
} else if (data.type === "chatCompletion") {
107+
if (!samplingSupported) throw new Error("Sampling not supported by client");
108+
// Handle chat completion messages if needed
109+
dbg(`chatCompletion message received: %O`, data);
110+
const { request, ...rest } = data;
111+
const response = await mcpRequestSample(server, data.request);
112+
const msg = { ...rest, response };
113+
dbg(`chatCompletion response: %O`, msg);
114+
postMessage(msg);
115+
} else {
116+
dbg(`unknown message type: ${data.type}`);
117+
}
118+
};
96119
server.setRequestHandler(ListToolsRequestSchema, async () => {
97120
dbg(`fetching scripts from watcher`);
98121
const scripts = await watcher.scripts();
@@ -112,14 +135,15 @@ export async function startMcpServer(
112135
properties: {},
113136
};
114137
const outputSchema = responseSchema ? toStrictJSONSchema(responseSchema) : undefined;
115-
if (accept !== "none")
138+
if (accept !== "none") {
116139
scriptSchema.properties.files = {
117140
type: "array",
118141
items: {
119142
type: "string",
120143
description: `Filename or globs relative to the workspace used by the script.${accept ? ` Accepts: ${accept}` : ""}`,
121144
},
122145
};
146+
}
123147
if (!description) logWarn(`script ${id} has no description`);
124148
return deleteUndefinedValues({
125149
name: id,
@@ -146,14 +170,16 @@ export async function startMcpServer(
146170
vars: vars as Record<string, string | number | boolean | object>,
147171
runTrace: false,
148172
outputTrace: false,
173+
parentLanguageModel: samplingSupported,
174+
onMessage,
149175
})) || { status: "error", error: { message: "run failed" } };
150176
dbg(`res: %s`, res.status);
151177
if (res.error) dbg(`error: %O`, res.error);
152178
const isError = res.status !== "success" || !!res.error;
153179
const text = res?.error?.message || (res.json ? JSON.stringify(res.json) : res.text) || "";
154180
dbg(`inlining images`);
155181
const parts = await splitMarkdownTextImageParts(text, {
156-
dir: res.env.runDir,
182+
dir: res.env?.runDir,
157183
convertToDataUri: true,
158184
});
159185
dbg(`parts: %O`, parts);
@@ -193,31 +219,42 @@ export async function startMcpServer(
193219
if (!resource) dbg(`resource not found: ${uri}`);
194220
return resource as ReadResourceResult;
195221
});
196-
runtimeHost.resources.addEventListener(CHANGE, async () => {
197-
await server.sendResourceListChanged();
198-
});
199-
runtimeHost.resources.addEventListener(RESOURCE_CHANGE, async (e) => {
200-
const ev = e as CustomEvent<Resource>;
201-
await server.sendResourceUpdated({
202-
uri: ev.detail.reference.uri,
203-
});
204-
});
222+
runtimeHost.resources.addEventListener(
223+
CHANGE,
224+
async () => {
225+
await server.sendResourceListChanged();
226+
},
227+
false,
228+
);
229+
runtimeHost.resources.addEventListener(
230+
RESOURCE_CHANGE,
231+
async (e) => {
232+
const ev = e as CustomEvent<Resource>;
233+
await server.sendResourceUpdated({
234+
uri: ev.detail.reference.uri,
235+
});
236+
},
237+
false,
238+
);
239+
240+
server.oninitialized = async () => {
241+
dbg(`server/client connection initialized`);
242+
// Check if client supports sampling
243+
const clientCapabilities = server.getClientCapabilities();
244+
dbg(`client capabilities: %O`, clientCapabilities);
245+
samplingSupported = !!clientCapabilities?.sampling;
246+
247+
if (startup) {
248+
logVerbose(`startup script: ${startup}`);
249+
await run(startup, [], {
250+
vars: {},
251+
parentLanguageModel: samplingSupported,
252+
onMessage,
253+
});
254+
}
255+
};
205256

206257
const transport = new StdioServerTransport();
207258
dbg(`connecting server with transport`);
208259
await server.connect(transport);
209-
210-
if (startup) {
211-
logVerbose(`startup script: ${startup}`);
212-
await run(startup, [], {
213-
vars: {},
214-
onMessage: async (data) => {
215-
if (data.type === RESOURCE_CHANGE) {
216-
await runtimeHost.resources.upsetResource(data.reference, data.content);
217-
} else {
218-
dbg(`unknown message type: ${data.type}`);
219-
}
220-
},
221-
});
222-
}
223260
}

packages/cli/src/watch.ts

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
// Copyright (c) Microsoft Corporation.
22
// Licensed under the MIT License.
33

4-
import { FSWatcher, watch } from "chokidar";
4+
import { watch } from "chokidar";
5+
import type { FSWatcher } from "chokidar";
56
import { basename, resolve } from "node:path";
67
import {
78
CHANGE,
@@ -40,11 +41,11 @@ export class ProjectWatcher extends EventTarget {
4041
signal?.addEventListener("abort", this.close.bind(this));
4142
}
4243

43-
get cwd() {
44+
get cwd(): string {
4445
return this.options.cwd;
4546
}
4647

47-
async open() {
48+
async open(): Promise<void> {
4849
if (this._watcher) return;
4950

5051
dbg(`starting`);
@@ -77,7 +78,7 @@ export class ProjectWatcher extends EventTarget {
7778
depth: 30,
7879
cwd,
7980
});
80-
const changed = () => {
81+
const changed = (): void => {
8182
dbg(`changed`);
8283
this.dispatchEvent(new Event(CHANGE));
8384
};
@@ -90,27 +91,27 @@ export class ProjectWatcher extends EventTarget {
9091
this.dispatchEvent(new Event(OPEN));
9192
}
9293

93-
private async refresh() {
94+
private async refresh(): Promise<void> {
9495
this._project = undefined;
9596
}
9697

97-
async project() {
98+
async project(): Promise<Project> {
9899
if (!this._project) {
99100
dbg(`building project`);
100101
this._project = await buildProject();
101102
}
102103
return this._project;
103104
}
104105

105-
async scripts() {
106+
async scripts(): Promise<PromptScript[]> {
106107
if (!this._scripts) {
107108
const project = await this.project();
108109
this._scripts = filterScripts(project.scripts, this.options);
109110
}
110111
return this._scripts?.slice(0);
111112
}
112113

113-
async close() {
114+
async close(): Promise<void> {
114115
dbg(`closing`);
115116
await this._watcher?.close();
116117
this._watcher = undefined;
@@ -133,7 +134,7 @@ export async function startProjectWatcher(
133134
paths?: ElementOrArray<string>;
134135
cwd?: string;
135136
} & CancellationOptions,
136-
) {
137+
): Promise<ProjectWatcher> {
137138
const { paths = ".", cwd = resolve("."), ...rest } = options || {};
138139
const watcher = new ProjectWatcher({ paths, cwd, ...rest });
139140
await watcher.open();

packages/core/src/ast.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ export function collectFolders(
8585
* @returns The script with the matching ID, or undefined if no match is found.
8686
*/
8787
export function resolveScript(prj: Project, system: SystemPromptInstance) {
88-
return prj?.scripts?.find((t) => t.id == system.id); // Find and return the template with the matching ID
88+
return prj?.scripts?.find((t) => t.id === system.id); // Find and return the template with the matching ID
8989
}
9090

9191
export interface ScriptFilterOptions {

packages/core/src/constants.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ export const MODEL_PROVIDER_WINDOWS_AI = "windows";
205205
export const MODEL_PROVIDER_DOCKER_MODEL_RUNNER = "docker";
206206
export const MODEL_PROVIDER_ECHO = "echo";
207207
export const MODEL_PROVIDER_NONE = "none";
208+
export const MODEL_PROVIDER_MCP = "mcp";
208209

209210
export const MODEL_GITHUB_COPILOT_CHAT_CURRENT = MODEL_PROVIDER_GITHUB_COPILOT_CHAT + ":current";
210211

0 commit comments

Comments
 (0)