Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 50 additions & 2 deletions graphile/graphile-llm/src/config-cache.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,24 @@ export interface BillingConfig {
publicSchema: string;
}

/**
* Inference log table metadata resolved from the inference_log_module.
*
* Both values are read from the metaschema by resolveInferenceLogConfig and are
* later interpolated as the schema-qualified target of the INSERT statement
* built by logInferenceUsage (metering.ts).
*/
export interface InferenceLogConfig {
/** Schema containing the usage_log_inference table */
schema: string;
/** Name of the inference log table */
tableName: string;
}

/**
* Per-database cached configuration for the LLM billing integration.
*
* getLlmBillingConfig builds one entry per database ID, stores it in the
* module cache, and returns the cached entry on later calls. Each member is
* resolved independently, so either one may be null while the other is set.
*/
export interface LlmBillingCacheEntry {
/** Billing function references (null if billing_module not provisioned) */
billing: BillingConfig | null;
/** Inference log table references (null if inference_log_module not provisioned) */
inferenceLog: InferenceLogConfig | null;
}

// ─── SQL Queries ────────────────────────────────────────────────────────────
Expand All @@ -71,6 +83,18 @@ const BILLING_MODULE_SQL = `
LIMIT 1
`;

/**
* Resolve the inference log module's schema and table name.
*
* Parameter $1 is the database ID. The inference_log_module row for that
* database is joined to metaschema_public.schema through schema_id, and at
* most one row is returned with two columns: `schema` (the schema name) and
* `table_name` (the inference log table name).
*/
const INFERENCE_LOG_MODULE_SQL = `
SELECT
s.schema_name AS schema,
ilm.inference_log_table_name AS table_name
FROM metaschema_modules_public.inference_log_module ilm
JOIN metaschema_public.schema s ON ilm.schema_id = s.id
WHERE ilm.database_id = $1
LIMIT 1
`;
// ─── Cache ──────────────────────────────────────────────────────────────────

const billingCache = new ModuleConfigCache<LlmBillingCacheEntry>({
Expand All @@ -89,6 +113,27 @@ const SCHEMA_EXISTS_SQL = `
SELECT 1 FROM information_schema.schemata WHERE schema_name = $1 LIMIT 1
`;

/**
 * Look up the schema and table name of the inference log table for a database.
 *
 * Resolution is best-effort and yields `null` when:
 *  - the `metaschema_modules_public` schema does not exist,
 *  - no module row is found, or the row lacks a schema or table name,
 *  - any query throws (the error is swallowed on purpose).
 *
 * @param pgClient   Client used to run the two lookup queries.
 * @param databaseId Database ID bound to `$1` of the module query.
 * @returns The resolved table location, or `null` when it cannot be resolved.
 */
async function resolveInferenceLogConfig(
  pgClient: PgClient,
  databaseId: string,
): Promise<InferenceLogConfig | null> {
  try {
    // Guard: the module query would fail outright if the metaschema is absent.
    const metaschemaLookup = await pgClient.query(SCHEMA_EXISTS_SQL, ['metaschema_modules_public']);
    if (metaschemaLookup.rows.length === 0) {
      return null;
    }

    const moduleLookup = await pgClient.query(INFERENCE_LOG_MODULE_SQL, [databaseId]);
    const moduleRow = moduleLookup.rows[0];

    // Both columns must be present and non-empty for the config to be usable.
    const isUsable = Boolean(moduleRow?.schema && moduleRow?.table_name);
    if (!isUsable) {
      return null;
    }

    const resolved: InferenceLogConfig = {
      schema: moduleRow.schema as string,
      tableName: moduleRow.table_name as string,
    };
    return resolved;
  } catch {
    // Deliberate best-effort: any failure is treated as "not provisioned".
    return null;
  }
}

async function resolveBillingConfig(
pgClient: PgClient,
databaseId: string,
Expand Down Expand Up @@ -133,9 +178,12 @@ export async function getLlmBillingConfig(
const cached = billingCache.get(databaseId);
if (cached) return cached;

const billing = await resolveBillingConfig(pgClient, databaseId);
const [billing, inferenceLog] = await Promise.all([
resolveBillingConfig(pgClient, databaseId),
resolveInferenceLogConfig(pgClient, databaseId),
]);

const entry: LlmBillingCacheEntry = { billing };
const entry: LlmBillingCacheEntry = { billing, inferenceLog };
billingCache.set(databaseId, entry);
return entry;
}
Expand Down
6 changes: 3 additions & 3 deletions graphile/graphile-llm/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,16 @@ export {
} from './chat';

// Metering utilities (for custom integration)
export { meteredEmbed, meteredChat, QuotaExceededError } from './metering';
export type { MeteringContext, MeteringOptions, MeterResult, WithPgClient } from './metering';
export { meteredEmbed, meteredChat, logInferenceUsage, QuotaExceededError } from './metering';
export type { MeteringContext, MeteringOptions, MeterResult, WithPgClient, InferenceLogEntry } from './metering';

// Config cache (for custom integration)
export {
getLlmBillingConfig,
invalidateLlmBillingConfig,
getLlmBillingCacheStats,
} from './config-cache';
export type { BillingConfig, LlmBillingCacheEntry, PgClient } from './config-cache';
export type { BillingConfig, LlmBillingCacheEntry, InferenceLogConfig, PgClient } from './config-cache';

// Types
export type {
Expand Down
171 changes: 168 additions & 3 deletions graphile/graphile-llm/src/metering.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
* resolved from `billing_module` metaschema and cached by `config-cache.ts`.
*/

import type { PgClient, BillingConfig } from './config-cache';
import type { PgClient, BillingConfig, InferenceLogConfig } from './config-cache';
import type { EmbedderFunction, ChatFunction, ChatMessage, ChatOptions } from './types';

// ─── Types ──────────────────────────────────────────────────────────────────
Expand All @@ -43,6 +43,12 @@ export interface MeteringContext {
entityId: string;
/** Per-request correlation ID (from request.id pgSetting) */
requestId: string | null;
/** Database UUID from JWT claims */
databaseId: string;
/** Actor (user) ID from JWT claims */
actorId: string | null;
/** Inference log table config (null if inference_log_module not provisioned) */
inferenceLog: InferenceLogConfig | null;
}

export interface MeteringOptions {
Expand All @@ -52,6 +58,12 @@ export interface MeteringOptions {
chatMeterSlug?: string;
/** Whether to skip metering entirely (e.g. for local dev). Default: false */
skipMetering?: boolean;
/** Embedding model name (for inference log) */
embeddingModel?: string;
/** Chat model name (for inference log) */
chatModel?: string;
/** Provider name (for inference log) */
provider?: string;
}

export interface MeterResult<T> {
Expand Down Expand Up @@ -113,6 +125,73 @@ async function recordUsage(
}
}

// ─── Inference Usage Log ────────────────────────────────────────────────────

/**
* One row of the inference usage log, as written by logInferenceUsage.
* Each field is bound to the like-named snake_case column of the INSERT
* (e.g. `inputTokens` -> `input_tokens`).
*/
export interface InferenceLogEntry {
/** Database the request belongs to (database_id column) */
databaseId: string;
/** Entity the usage is attributed to (entity_id column) */
entityId: string;
/** Acting user, or null when none is known (actor_id column) */
actorId: string | null;
/** Model name; callers in this file fall back to the meter slug when no model name is configured */
model: string;
/** Provider name, or null when not supplied */
provider: string | null;
/** Kind of inference request being logged */
requestType: 'embedding' | 'chat' | 'rag';
/** Input token count; callers in this file currently pass a placeholder estimate (characters / 4) */
inputTokens: number;
/** Output token count (0 for embeddings and for rejected requests in this file's callers) */
outputTokens: number;
/** Sum of input and output tokens as computed by the caller */
totalTokens: number;
/** Elapsed time of the request in milliseconds, measured by the caller */
latencyMs: number;
/** Whether retrieval-augmented generation was used for this request */
ragEnabled: boolean;
/** Number of retrieved chunks, or null when not applicable */
chunksRetrieved: number | null;
/** Embedding model name, or null when not applicable */
embeddingModel: string | null;
/** Embedding latency in milliseconds, or null when not applicable */
embeddingLatencyMs: number | null;
/** Outcome of the request */
status: 'success' | 'quota_exceeded' | 'provider_error' | 'timeout';
/** Error classification, or null when there is none */
errorType: string | null;
}

/**
 * Write a row to the usage_log_inference table.
 * Gracefully skips if the inference_log_module is not provisioned.
 *
 * The schema and table name come from `ctx.inferenceLog` and cannot be bound
 * as query parameters, so they are quoted as PostgreSQL identifiers (embedded
 * double quotes are doubled) before being placed in the statement. All row
 * values are passed as bound parameters.
 *
 * Failures of the INSERT are logged with `console.warn` and never rethrown,
 * so logging cannot break the inference request itself.
 *
 * TODO: Also write to child (generated) database when dual-write is needed.
 *
 * @param ctx   Metering context providing `withPgClient`, `pgSettings` and the
 *              resolved inference log table (`inferenceLog`, may be null).
 * @param entry Values for the row to insert.
 * @returns A promise that resolves once the insert was attempted (or skipped).
 */
export async function logInferenceUsage(
  ctx: MeteringContext,
  entry: InferenceLogEntry,
): Promise<void> {
  if (!ctx.inferenceLog) return;

  // Quote as a PostgreSQL identifier: wrap in double quotes and double any
  // embedded double quote so the name cannot terminate the identifier early.
  const quoteIdent = (ident: string): string => `"${String(ident).replace(/"/g, '""')}"`;

  const { schema, tableName } = ctx.inferenceLog;
  const target = `${quoteIdent(schema)}.${quoteIdent(tableName)}`;

  const sql = `INSERT INTO ${target} (
database_id, entity_id, actor_id,
model, provider, request_type,
input_tokens, output_tokens, total_tokens,
latency_ms, rag_enabled, chunks_retrieved,
embedding_model, embedding_latency_ms,
status, error_type
) VALUES (
$1, $2, $3,
$4, $5, $6,
$7, $8, $9,
$10, $11, $12,
$13, $14,
$15, $16
)`;

  try {
    await ctx.withPgClient(ctx.pgSettings, async (pgClient) => {
      await pgClient.query(sql, [
        entry.databaseId, entry.entityId, entry.actorId,
        entry.model, entry.provider, entry.requestType,
        entry.inputTokens, entry.outputTokens, entry.totalTokens,
        entry.latencyMs, entry.ragEnabled, entry.chunksRetrieved,
        entry.embeddingModel, entry.embeddingLatencyMs,
        entry.status, entry.errorType,
      ]);
    });
  } catch (e: unknown) {
    // Non-fatal by design: usage logging must not fail the caller's request.
    const message = e instanceof Error ? e.message : String(e);
    console.warn(`[graphile-llm] inference log INSERT failed (non-fatal): ${message}`);
  }
}

// ─── Metered Embedder ───────────────────────────────────────────────────────

/**
Expand Down Expand Up @@ -172,6 +251,27 @@ export async function meteredEmbed(
}

if (!allowed) {
// Placeholder: replace with actual provider token counts once generateWithUsage() is approved
const placeholderAmountTokens = Math.ceil(text.length / 4);
logInferenceUsage(ctx, {
databaseId: ctx.databaseId,
entityId: ctx.entityId,
actorId: ctx.actorId,
model: options.embeddingModel ?? meterSlug,
provider: options.provider ?? null,
requestType: 'embedding',
inputTokens: placeholderAmountTokens,
outputTokens: 0,
totalTokens: placeholderAmountTokens,
latencyMs: Date.now() - startTime,
ragEnabled: false,
chunksRetrieved: null,
embeddingModel: options.embeddingModel ?? null,
embeddingLatencyMs: null,
status: 'quota_exceeded',
errorType: null,
}).catch(() => {});

return {
result: null,
metered: true,
Expand All @@ -184,7 +284,8 @@ export async function meteredEmbed(
const result = await embedder(text);
const latencyMs = Date.now() - startTime;

// Record actual usage (input_chars as the metered amount)
// Placeholder: replace with actual provider token counts once generateWithUsage() is approved
const placeholderAmountTokens = Math.ceil(text.length / 4);
ctx.withPgClient(ctx.pgSettings, async (pgClient) => {
await recordUsage(pgClient, ctx.billing, ctx.entityId, meterSlug, text.length, {
request_id: ctx.requestId,
Expand All @@ -194,6 +295,26 @@ export async function meteredEmbed(
});
}).catch(() => {});

// Log to inference usage table
logInferenceUsage(ctx, {
databaseId: ctx.databaseId,
entityId: ctx.entityId,
actorId: ctx.actorId,
model: options.embeddingModel ?? meterSlug,
provider: options.provider ?? null,
requestType: 'embedding',
inputTokens: placeholderAmountTokens,
outputTokens: 0,
totalTokens: placeholderAmountTokens,
latencyMs,
ragEnabled: false,
chunksRetrieved: null,
embeddingModel: options.embeddingModel ?? null,
embeddingLatencyMs: latencyMs,
status: 'success',
errorType: null,
}).catch(() => {});

return {
result,
metered: true,
Expand Down Expand Up @@ -258,6 +379,27 @@ export async function meteredChat(
}

if (!allowed) {
// Placeholder: replace with actual provider token counts once generateWithUsage() is approved
const placeholderInputTokens = Math.ceil(messages.reduce((sum, m) => sum + m.content.length, 0) / 4);
logInferenceUsage(ctx, {
databaseId: ctx.databaseId,
entityId: ctx.entityId,
actorId: ctx.actorId,
model: meteringOptions.chatModel ?? meterSlug,
provider: meteringOptions.provider ?? null,
requestType: 'chat',
inputTokens: placeholderInputTokens,
outputTokens: 0,
totalTokens: placeholderInputTokens,
latencyMs: Date.now() - startTime,
ragEnabled: false,
chunksRetrieved: null,
embeddingModel: null,
embeddingLatencyMs: null,
status: 'quota_exceeded',
errorType: null,
}).catch(() => {});

return {
result: null,
metered: true,
Expand All @@ -270,8 +412,11 @@ export async function meteredChat(
const result = await chat(messages, chatOptions);
const latencyMs = Date.now() - startTime;

// Record actual usage (input + output chars as the metered amount)
// Placeholder: replace with actual provider token counts once generateWithUsage() is approved
const inputChars = messages.reduce((sum, m) => sum + m.content.length, 0);
const placeholderInputTokens = Math.ceil(inputChars / 4);
const placeholderOutputTokens = Math.ceil(result.length / 4);
const placeholderTotalTokens = placeholderInputTokens + placeholderOutputTokens;
ctx.withPgClient(ctx.pgSettings, async (pgClient) => {
await recordUsage(pgClient, ctx.billing, ctx.entityId, meterSlug, inputChars + result.length, {
request_id: ctx.requestId,
Expand All @@ -282,6 +427,26 @@ export async function meteredChat(
});
}).catch(() => {});

// Log to inference usage table
logInferenceUsage(ctx, {
databaseId: ctx.databaseId,
entityId: ctx.entityId,
actorId: ctx.actorId,
model: meteringOptions.chatModel ?? meterSlug,
provider: meteringOptions.provider ?? null,
requestType: 'chat',
inputTokens: placeholderInputTokens,
outputTokens: placeholderOutputTokens,
totalTokens: placeholderTotalTokens,
latencyMs,
ragEnabled: false,
chunksRetrieved: null,
embeddingModel: null,
embeddingLatencyMs: null,
status: 'success',
errorType: null,
}).catch(() => {});

return {
result,
metered: true,
Expand Down
8 changes: 8 additions & 0 deletions graphile/graphile-llm/src/plugins/metering-plugin.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,16 +67,19 @@ async function buildMeteringContext(
const entityId = resolveEntityId(pgSettings);
const databaseId = pgSettings['jwt.claims.database_id'] ?? null;
const requestId = pgSettings['request.id'] ?? null;
const actorId = pgSettings['jwt.claims.user_id'] ?? null;
if (!entityId || !databaseId) return null;

const withPgClient: WithPgClient | undefined = graphqlContext?.withPgClient;
if (!withPgClient) return null;

let billingConfig = null;
let inferenceLogConfig = null;
try {
await withPgClient(pgSettings, async (pgClient: PgClient) => {
const entry = await getLlmBillingConfig(pgClient, databaseId);
billingConfig = entry.billing;
inferenceLogConfig = entry.inferenceLog;
});
} catch {
return null;
Expand All @@ -90,6 +93,9 @@ async function buildMeteringContext(
billing: billingConfig,
entityId,
requestId,
databaseId,
actorId,
inferenceLog: inferenceLogConfig,
};
}

Expand Down Expand Up @@ -173,6 +179,8 @@ export function createLlmMeteringPlugin(
embeddingMeterSlug: embeddingSlug,
chatMeterSlug: chatSlug,
skipMetering,
embeddingModel: embeddingModel ?? undefined,
chatModel: chatModel ?? undefined,
};

// Replace the embedder with a metered version.
Expand Down
Loading
Loading