From 91e887e42bd81d9d5c6854a6ebe9e79173a28088 Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 13 May 2026 01:32:34 +0000 Subject: [PATCH 01/39] Add OHTTP-style anonymous inference endpoint MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements RFC 9458 Oblivious HTTP encapsulation so clients can submit chat completions through an independent relay without exposing their IP to the enclave or their prompt to the relay. The HPKE X25519 keypair is generated alongside the existing RSA signing key and bound to the same nitriding registration digest, so the Nitro attestation document commits to both. - tee_gateway/ohttp.py: HPKE wrap/unwrap helpers (DHKEM(X25519)/HKDF-SHA256/ ChaCha20-Poly1305). Response keying derived per-context per RFC 9458 §4.2. - tee_gateway/tee_manager.py: HPKE keypair, key-config blob, attestation document now includes the HPKE public key. - tee_gateway/controllers/ohttp_controller.py: /v1/ohttp dispatches the decrypted request to the existing chat handler, scrubs identifying fields before forwarding upstream, refuses stream=true. - /v1/ohttp/config exposes the HPKE key config for client discovery. - Test coverage: round-trip, wrong-suite, truncated input, tampered ciphertext. Known limitation: payment gating is not yet wired for this endpoint; a blind-token layer will follow in a separate change. https://claude.ai/code/session_01WyddtSz2rtiP61LtVJbsJy --- pyproject.toml | 1 + tee_gateway/__main__.py | 18 +++ tee_gateway/controllers/ohttp_controller.py | 141 +++++++++++++++++ tee_gateway/ohttp.py | 166 ++++++++++++++++++++ tee_gateway/tee_manager.py | 53 ++++++- tee_gateway/test/test_ohttp.py | 99 ++++++++++++ 6 files changed, 476 insertions(+), 2 deletions(-) create mode 100644 tee_gateway/controllers/ohttp_controller.py create mode 100644 tee_gateway/ohttp.py create mode 100644 tee_gateway/test/test_ohttp.py diff --git a/pyproject.toml b/pyproject.toml index 80a9bc6..426ab33 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ dependencies = [ "psutil>=7.2.1", "python-dotenv>=1.2.1", "eth-account>=0.13.0", + "pyhpke>=0.6.0", ] [dependency-groups] diff --git a/tee_gateway/__main__.py b/tee_gateway/__main__.py index 165d5d1..ad087af 100644 --- a/tee_gateway/__main__.py +++ b/tee_gateway/__main__.py @@ -23,6 +23,10 @@ ) from tee_gateway.llm_backend import get_provider_config, set_provider_config from tee_gateway.heartbeat import create_heartbeat_service +from tee_gateway.controllers.ohttp_controller import ( + create_anonymous_chat_completion, + get_hpke_config, +) from x402.http import FacilitatorConfig, HTTPFacilitatorClientSync, PaymentOption from x402.http.middleware.flask import payment_middleware @@ -437,6 +441,20 @@ def create_app(): "/heartbeat/status", "heartbeat-status", heartbeat_status, methods=["GET"] ) + # Anonymous inference (OHTTP-wrapped chat completions). Deliberately + # mounted via add_url_rule rather than the OpenAPI spec because the body + # is raw binary and connexion's request-validation pipeline would reject + # it as malformed JSON. + app.app.add_url_rule( + "/v1/ohttp", + "anonymous-chat", + create_anonymous_chat_completion, + methods=["POST"], + ) + app.app.add_url_rule( + "/v1/ohttp/config", "ohttp-config", get_hpke_config, methods=["GET"] + ) + # Initialize TEE here so it runs under both Gunicorn and direct execution. # This is the single TEEKeyManager instance — the same key both registers # with nitriding and signs all LLM responses. diff --git a/tee_gateway/controllers/ohttp_controller.py b/tee_gateway/controllers/ohttp_controller.py new file mode 100644 index 0000000..688e862 --- /dev/null +++ b/tee_gateway/controllers/ohttp_controller.py @@ -0,0 +1,141 @@ +""" +Oblivious HTTP endpoint for anonymous inference. + +This handler is intentionally minimal: it does HPKE decapsulation, dispatches +the inner request to the existing chat-completions handler in-process (no +network hop), and HPKE-encapsulates the response. The inner JSON request is +identical to the standard /v1/chat/completions body. + +Threat model nuances: + * The relay in front of this endpoint sees the encapsulated ciphertext and + the client IP, but no request content. + * The enclave sees plaintext and the relay's IP, never the client's. + * If the client's payload contains identifiers (cookies, ``user`` field, + custom request IDs), unlinkability is broken at the application layer — + we strip the obvious ones below. + * Streaming is intentionally not supported on this endpoint. SSE would + create per-chunk side channels (timing, length) that defeat the point of + bundling everything into a single sealed response. +""" + +from __future__ import annotations + +import json +import logging +from typing import Any + +import connexion +from flask import Response + +from tee_gateway import ohttp +from tee_gateway.tee_manager import get_tee_keys + +logger = logging.getLogger(__name__) + +OHTTP_MEDIA_TYPE = "message/ohttp-req" +OHTTP_RESPONSE_MEDIA_TYPE = "message/ohttp-res" + +# Fields that can re-identify a client and have no role in inference. We drop +# them before forwarding to the inner handler — keeping them inside the +# encrypted envelope would only protect them from the relay, not from us or +# the upstream LLM provider. +_IDENTIFYING_FIELDS = ("user", "metadata", "x-request-id", "request_id") + + +def _scrub(payload: dict[str, Any]) -> dict[str, Any]: + return {k: v for k, v in payload.items() if k not in _IDENTIFYING_FIELDS} + + +def create_anonymous_chat_completion(): + """POST /v1/ohttp — decrypt, dispatch, re-encrypt. + + Body: raw bytes (OHTTP-encapsulated request). + Returns: raw bytes (OHTTP-encapsulated response) with Content-Type + ``message/ohttp-res``. + """ + req = connexion.request + # Tolerate both Connexion's Flask request and a bare Flask request. + raw_body: bytes = req.get_data(cache=False) + if not raw_body: + return _error(400, "empty body") + + tee = get_tee_keys() + if tee.hpke_private_key is None: + return _error(503, "anonymous inference not initialized") + + try: + decap = ohttp.decapsulate_request(tee.hpke_private_key, raw_body) + except Exception as exc: + # Don't leak which step failed — clients can retry with a fresh + # encapsulation, all observable failures look identical. + logger.warning("OHTTP decapsulation failed: %s", type(exc).__name__) + return _error(400, "malformed encapsulated request") + + try: + inner_body = json.loads(decap.plaintext.decode("utf-8")) + except (UnicodeDecodeError, json.JSONDecodeError): + return _error(400, "inner payload is not valid JSON") + + if not isinstance(inner_body, dict): + return _error(400, "inner payload must be a JSON object") + + if inner_body.get("stream"): + # Streaming is rejected on principle (see module docstring). Clients + # who want low TTFT under anonymity should use a shorter max_tokens. + return _error(400, "stream=true is not supported over OHTTP") + + inner_body = _scrub(inner_body) + + # Late import to avoid a circular dependency at module load (the chat + # controller pulls in models that import this package). + from tee_gateway.controllers.chat_controller import ( + _create_non_streaming_response, + _parse_chat_request, + ) + + try: + chat_request = _parse_chat_request(inner_body) + inner_result = _create_non_streaming_response(chat_request) + except Exception as exc: + logger.error("inner inference failed under OHTTP: %s", exc, exc_info=True) + inner_result = ({"error": "inference failed"}, 500) + + # _create_non_streaming_response returns either a dict or (body, status) + if isinstance(inner_result, tuple): + body_obj, status = inner_result + else: + body_obj, status = inner_result, 200 + + inner_json = json.dumps( + {"status": status, "body": body_obj}, + separators=(",", ":"), + ).encode("utf-8") + + sealed = ohttp.encapsulate_response( + decap.response_key, decap.enc, inner_json + ) + return Response( + sealed, + status=200, + mimetype=OHTTP_RESPONSE_MEDIA_TYPE, + ) + + +def get_hpke_config(): + """GET /v1/ohttp/config — return the HPKE key configuration. + + Returns both an OHTTP-compliant binary key_config (base64) and the + individual fields for clients that prefer to parse JSON. The same data is + embedded inside the attestation document at /signing-key for clients that + want to verify the binding to the enclave's PCRs in one step. + """ + try: + tee = get_tee_keys() + return tee.get_hpke_config(), 200 + except Exception as exc: + logger.error("HPKE config error: %s", exc, exc_info=True) + return {"error": str(exc)}, 500 + + +def _error(status: int, message: str) -> tuple[dict, int]: + return {"error": message}, status diff --git a/tee_gateway/ohttp.py b/tee_gateway/ohttp.py new file mode 100644 index 0000000..d5369aa --- /dev/null +++ b/tee_gateway/ohttp.py @@ -0,0 +1,166 @@ +""" +Oblivious HTTP encapsulation for anonymous inference. + +Implements request/response encapsulation per RFC 9458 (Oblivious HTTP) +with a fixed HPKE ciphersuite: + - KEM: DHKEM(X25519, HKDF-SHA256) (0x0020) + - KDF: HKDF-SHA256 (0x0001) + - AEAD: ChaCha20-Poly1305 (0x0003) + +The inner payload is application/json — we do not BHTTP-wrap the inference +request, since the enclave is the terminal endpoint and not a generic HTTP +proxy. This is a documented divergence from strict RFC 9458; the cryptographic +construction (HPKE base + exported response keying) is identical. + +Trust model: the relay sees ciphertext + client IP; the enclave sees plaintext ++ relay IP. Unlinkability holds unless relay and enclave collude. +""" + +from __future__ import annotations + +import os +import struct +from dataclasses import dataclass + +from cryptography.hazmat.primitives import hashes, hmac +from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305 +from cryptography.hazmat.primitives.kdf.hkdf import HKDFExpand +from pyhpke import AEADId, CipherSuite, KDFId, KEMId +from pyhpke.kem_key_interface import KEMKeyInterface + + +# RFC 9180 / 9458 algorithm identifiers +KEM_ID_X25519 = 0x0020 +KDF_ID_HKDF_SHA256 = 0x0001 +AEAD_ID_CHACHA20_POLY1305 = 0x0003 + +# Single, stable key configuration ID. Bump when the keypair or suite changes +# so clients can refuse stale configs. +KEY_CONFIG_ID = 0x01 + +# AEAD parameters for ChaCha20-Poly1305 +_NK = 32 # key length +_NN = 12 # nonce length + +# Per RFC 9458 §4.1/4.2 — "info" labels for the HPKE context. +_LABEL_REQUEST = b"message/bhttp request" +_LABEL_RESPONSE = b"message/bhttp response" + +_SUITE = CipherSuite.new( + KEMId.DHKEM_X25519_HKDF_SHA256, + KDFId.HKDF_SHA256, + AEADId.CHACHA20_POLY1305, +) + + +def _header_bytes(key_id: int = KEY_CONFIG_ID) -> bytes: + return bytes([key_id]) + struct.pack( + ">HHH", + KEM_ID_X25519, + KDF_ID_HKDF_SHA256, + AEAD_ID_CHACHA20_POLY1305, + ) + + +def key_config(public_key_raw: bytes, key_id: int = KEY_CONFIG_ID) -> bytes: + """Build an OHTTP key configuration blob (RFC 9458 §3). + + Format: + key_id(1) || kem_id(2) || public_key(Npk=32) || + symmetric_algorithms_length(2) || (kdf_id(2) || aead_id(2))+ + """ + if len(public_key_raw) != 32: + raise ValueError("X25519 public key must be 32 bytes") + symmetric = struct.pack(">HH", KDF_ID_HKDF_SHA256, AEAD_ID_CHACHA20_POLY1305) + return ( + bytes([key_id]) + + struct.pack(">H", KEM_ID_X25519) + + public_key_raw + + struct.pack(">H", len(symmetric)) + + symmetric + ) + + +@dataclass +class DecapsulatedRequest: + """Result of decapsulating an OHTTP-wrapped request.""" + + plaintext: bytes + response_key: bytes # 32 bytes exported from the HPKE context + enc: bytes # client's ephemeral public key, used as salt for the response + + +def decapsulate_request( + private_key: KEMKeyInterface, encapsulated_request: bytes +) -> DecapsulatedRequest: + """Decrypt an HPKE-wrapped request inside the enclave. + + Raises ValueError on malformed input or unsupported ciphersuite. We + never echo the underlying exception text to clients — it can leak + timing/oracle info. + """ + if len(encapsulated_request) < 7 + 32: + raise ValueError("encapsulated request too short") + + key_id = encapsulated_request[0] + kem_id, kdf_id, aead_id = struct.unpack(">HHH", encapsulated_request[1:7]) + if (key_id, kem_id, kdf_id, aead_id) != ( + KEY_CONFIG_ID, + KEM_ID_X25519, + KDF_ID_HKDF_SHA256, + AEAD_ID_CHACHA20_POLY1305, + ): + raise ValueError("unsupported HPKE configuration") + + enc = encapsulated_request[7 : 7 + 32] + aead_ct = encapsulated_request[7 + 32 :] + + info = _LABEL_REQUEST + b"\x00" + _header_bytes(key_id) + recipient = _SUITE.create_recipient_context(enc, private_key, info=info) + plaintext = recipient.open(aead_ct, aad=b"") + + # Export a fresh secret bound to this HPKE context, used to derive the + # response AEAD key. This is the OHTTP-defined separation between the + # request and response halves of the same exchange. + response_secret = recipient.export(_LABEL_RESPONSE, _NK) + return DecapsulatedRequest( + plaintext=plaintext, response_key=response_secret, enc=enc + ) + + +def encapsulate_response( + response_secret: bytes, enc: bytes, plaintext: bytes +) -> bytes: + """Seal a response under the per-request derived key (RFC 9458 §4.2). + + Wire format: response_nonce(max(Nn, Nk)=Nk=32) || AEAD ciphertext + """ + response_nonce = os.urandom(max(_NN, _NK)) + salt = enc + response_nonce + + h = hmac.HMAC(salt, hashes.SHA256()) + h.update(response_secret) + prk = h.finalize() + + aead_key = HKDFExpand( + algorithm=hashes.SHA256(), length=_NK, info=b"key" + ).derive(prk) + aead_nonce = HKDFExpand( + algorithm=hashes.SHA256(), length=_NN, info=b"nonce" + ).derive(prk) + + ct = ChaCha20Poly1305(aead_key).encrypt(aead_nonce, plaintext, b"") + return response_nonce + ct + + +def generate_keypair() -> tuple[KEMKeyInterface, bytes]: + """Generate an X25519 keypair for HPKE. Returns (private_key, public_key_raw). + + pyhpke 0.6 derives keys from random IKM via ``kem.derive_key_pair(ikm)``, + which returns a ``KEMKeyPair`` wrapper. We hold onto the private side for + decapsulation and serialize the public side to raw 32-byte form for the + key configuration blob. + """ + pair = _SUITE.kem.derive_key_pair(os.urandom(32)) + pk_raw = pair.public_key.to_public_bytes() + return pair.private_key, pk_raw diff --git a/tee_gateway/tee_manager.py b/tee_gateway/tee_manager.py index 3d98b12..402438e 100644 --- a/tee_gateway/tee_manager.py +++ b/tee_gateway/tee_manager.py @@ -17,6 +17,8 @@ from eth_account import Account from eth_hash.auto import keccak +from tee_gateway import ohttp + logger = logging.getLogger(__name__) NITRIDING_BASE_URL = "http://127.0.0.1:8080" @@ -36,6 +38,11 @@ def __init__(self, register=True): self.public_key_pem = None self.tee_id = None self.wallet_address = None + # HPKE keypair for OHTTP-style anonymous inference. Generated in the + # same enclave boot so the X25519 public key is covered by the same + # attestation that covers the RSA signing key. + self.hpke_private_key = None + self.hpke_public_key_raw: bytes | None = None self._generate_keys() if register: self.register_with_nitriding() @@ -70,19 +77,39 @@ def _generate_keys(self): wallet_account = Account.from_key(wallet_key_bytes) self.wallet_address = wallet_account.address + # HPKE X25519 keypair — never leaves the enclave; clients address it + # via the public-key fingerprint published with the attestation. + self.hpke_private_key, self.hpke_public_key_raw = ohttp.generate_keypair() + logger.info("TEE key pair generated successfully") logger.info(f"tee_id: 0x{self.tee_id}") logger.info(f"wallet_address: {self.wallet_address}") + logger.info( + f"hpke_public_key (X25519, raw, hex): {self.hpke_public_key_raw.hex()}" + ) def register_with_nitriding(self): - """Register public key hash with nitriding.""" + """Register public key hash with nitriding. + + The hash covers both the RSA signing key (DER-encoded SPKI) and the + raw X25519 HPKE public key. Including both in a single attested digest + means a verifier who validates the attestation document automatically + gets binding for the HPKE config used for anonymous inference — no + separate trust anchor required. + """ try: public_key_der = self.public_key.public_bytes( encoding=serialization.Encoding.DER, format=serialization.PublicFormat.SubjectPublicKeyInfo, ) - key_hash = hashlib.sha256(public_key_der).digest() + # Domain-separated transcript so a future addition of more keys + # can't be confused with the existing layout. + transcript = ( + b"og-tee-keys|v2|rsa-spki=" + public_key_der + + b"|hpke-x25519=" + (self.hpke_public_key_raw or b"") + ) + key_hash = hashlib.sha256(transcript).digest() key_hash_b64 = base64.b64encode(key_hash).decode("utf-8") logger.info(f"Public key DER length: {len(public_key_der)} bytes") @@ -149,12 +176,34 @@ def get_wallet_address(self) -> str: """Return the TEE-generated Ethereum wallet address (checksum).""" return self.wallet_address + def get_hpke_config(self) -> dict: + """Return the HPKE key configuration for anonymous inference. + + ``key_config`` is the RFC 9458 §3 binary key-config blob, base64-encoded. + Clients should treat this as authoritative only when fetched alongside + the Nitro attestation document (which commits to the same key hash via + nitriding registration). + """ + if self.hpke_public_key_raw is None: + raise RuntimeError("HPKE keypair not initialized") + return { + "key_id": ohttp.KEY_CONFIG_ID, + "kem_id": ohttp.KEM_ID_X25519, + "kdf_id": ohttp.KDF_ID_HKDF_SHA256, + "aead_id": ohttp.AEAD_ID_CHACHA20_POLY1305, + "public_key": self.hpke_public_key_raw.hex(), + "key_config": base64.b64encode( + ohttp.key_config(self.hpke_public_key_raw) + ).decode("ascii"), + } + def get_attestation_document(self) -> dict: """Return TEE attestation document.""" return { "public_key": self.public_key_pem, "tee_id": f"0x{self.tee_id}", "wallet_address": self.wallet_address, + "hpke": self.get_hpke_config() if self.hpke_public_key_raw else None, "timestamp": datetime.now(UTC).isoformat(), "enclave_info": { "platform": "aws-nitro", diff --git a/tee_gateway/test/test_ohttp.py b/tee_gateway/test/test_ohttp.py new file mode 100644 index 0000000..44342ce --- /dev/null +++ b/tee_gateway/test/test_ohttp.py @@ -0,0 +1,99 @@ +"""Tests for the OHTTP encapsulation module.""" + +from __future__ import annotations + +import json + +import pytest + +from tee_gateway import ohttp + + +def test_round_trip_request_and_response(): + sk, pk_raw = ohttp.generate_keypair() + + plaintext = json.dumps({"model": "gpt-4.1", "n": 1}).encode() + # Encapsulate using the same code paths a client would, since pyhpke is + # symmetric — we wire the request manually. + config = ohttp.key_config(pk_raw) + assert config[0] == ohttp.KEY_CONFIG_ID + + # Build a wire payload exactly as the SDK does. + import struct + hdr = ( + bytes([ohttp.KEY_CONFIG_ID]) + + struct.pack( + ">HHH", + ohttp.KEM_ID_X25519, + ohttp.KDF_ID_HKDF_SHA256, + ohttp.AEAD_ID_CHACHA20_POLY1305, + ) + ) + info = b"message/bhttp request" + b"\x00" + hdr + pkr = ohttp._SUITE.kem.deserialize_public_key(pk_raw) + enc, sender = ohttp._SUITE.create_sender_context(pkr, info=info) + ct = sender.seal(plaintext, aad=b"") + wire = hdr + enc + ct + + decap = ohttp.decapsulate_request(sk, wire) + assert decap.plaintext == plaintext + + response_secret = sender.export(b"message/bhttp response", 32) + assert decap.response_key == response_secret + assert decap.enc == enc + + response = b'{"ok":true}' + sealed = ohttp.encapsulate_response(decap.response_key, decap.enc, response) + + # Round-trip the response on the "client" side using the same primitives. + import os + from cryptography.hazmat.primitives import hashes, hmac + from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305 + from cryptography.hazmat.primitives.kdf.hkdf import HKDFExpand + + response_nonce = sealed[:32] + aead_ct = sealed[32:] + salt = enc + response_nonce + h = hmac.HMAC(salt, hashes.SHA256()) + h.update(response_secret) + prk = h.finalize() + key = HKDFExpand(algorithm=hashes.SHA256(), length=32, info=b"key").derive(prk) + nonce = HKDFExpand(algorithm=hashes.SHA256(), length=12, info=b"nonce").derive(prk) + assert ChaCha20Poly1305(key).decrypt(nonce, aead_ct, b"") == response + + +def test_rejects_wrong_suite(): + sk, pk_raw = ohttp.generate_keypair() + # Build a wire with the wrong AEAD ID + import struct + hdr = bytes([ohttp.KEY_CONFIG_ID]) + struct.pack( + ">HHH", ohttp.KEM_ID_X25519, ohttp.KDF_ID_HKDF_SHA256, 0x0001 # AES-128-GCM + ) + fake_wire = hdr + b"\x00" * 32 + b"\x00" * 16 + with pytest.raises(ValueError, match="unsupported"): + ohttp.decapsulate_request(sk, fake_wire) + + +def test_rejects_short_input(): + sk, _ = ohttp.generate_keypair() + with pytest.raises(ValueError, match="too short"): + ohttp.decapsulate_request(sk, b"\x01") + + +def test_rejects_tampered_ciphertext(): + sk, pk_raw = ohttp.generate_keypair() + import struct + hdr = bytes([ohttp.KEY_CONFIG_ID]) + struct.pack( + ">HHH", + ohttp.KEM_ID_X25519, + ohttp.KDF_ID_HKDF_SHA256, + ohttp.AEAD_ID_CHACHA20_POLY1305, + ) + info = b"message/bhttp request" + b"\x00" + hdr + pkr = ohttp._SUITE.kem.deserialize_public_key(pk_raw) + enc, sender = ohttp._SUITE.create_sender_context(pkr, info=info) + ct = sender.seal(b"hello", aad=b"") + wire = bytearray(hdr + enc + ct) + wire[-1] ^= 0xFF + with pytest.raises(Exception): + ohttp.decapsulate_request(sk, bytes(wire)) From 5dcbdc8d1643d262da689c16b42bfb3597eea3f8 Mon Sep 17 00:00:00 2001 From: kukac Date: Fri, 15 May 2026 17:59:57 -0400 Subject: [PATCH 02/39] Update test_ohttp.py --- tee_gateway/test/test_ohttp.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tee_gateway/test/test_ohttp.py b/tee_gateway/test/test_ohttp.py index 44342ce..db4fe62 100644 --- a/tee_gateway/test/test_ohttp.py +++ b/tee_gateway/test/test_ohttp.py @@ -45,8 +45,6 @@ def test_round_trip_request_and_response(): response = b'{"ok":true}' sealed = ohttp.encapsulate_response(decap.response_key, decap.enc, response) - # Round-trip the response on the "client" side using the same primitives. - import os from cryptography.hazmat.primitives import hashes, hmac from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305 from cryptography.hazmat.primitives.kdf.hkdf import HKDFExpand From a8c5c89ee8059f01573a2f834683da6cd86e2729 Mon Sep 17 00:00:00 2001 From: "balogh.adam@icloud.com" Date: Fri, 15 May 2026 18:01:49 -0400 Subject: [PATCH 03/39] lint --- tee_gateway/controllers/ohttp_controller.py | 4 +--- tee_gateway/ohttp.py | 10 ++++------ tee_gateway/tee_manager.py | 6 ++++-- tee_gateway/test/test_ohttp.py | 21 ++++++++++++--------- uv.lock | 14 ++++++++++++++ 5 files changed, 35 insertions(+), 20 deletions(-) diff --git a/tee_gateway/controllers/ohttp_controller.py b/tee_gateway/controllers/ohttp_controller.py index 688e862..9afa914 100644 --- a/tee_gateway/controllers/ohttp_controller.py +++ b/tee_gateway/controllers/ohttp_controller.py @@ -111,9 +111,7 @@ def create_anonymous_chat_completion(): separators=(",", ":"), ).encode("utf-8") - sealed = ohttp.encapsulate_response( - decap.response_key, decap.enc, inner_json - ) + sealed = ohttp.encapsulate_response(decap.response_key, decap.enc, inner_json) return Response( sealed, status=200, diff --git a/tee_gateway/ohttp.py b/tee_gateway/ohttp.py index d5369aa..ba5c7c4 100644 --- a/tee_gateway/ohttp.py +++ b/tee_gateway/ohttp.py @@ -128,9 +128,7 @@ def decapsulate_request( ) -def encapsulate_response( - response_secret: bytes, enc: bytes, plaintext: bytes -) -> bytes: +def encapsulate_response(response_secret: bytes, enc: bytes, plaintext: bytes) -> bytes: """Seal a response under the per-request derived key (RFC 9458 §4.2). Wire format: response_nonce(max(Nn, Nk)=Nk=32) || AEAD ciphertext @@ -142,9 +140,9 @@ def encapsulate_response( h.update(response_secret) prk = h.finalize() - aead_key = HKDFExpand( - algorithm=hashes.SHA256(), length=_NK, info=b"key" - ).derive(prk) + aead_key = HKDFExpand(algorithm=hashes.SHA256(), length=_NK, info=b"key").derive( + prk + ) aead_nonce = HKDFExpand( algorithm=hashes.SHA256(), length=_NN, info=b"nonce" ).derive(prk) diff --git a/tee_gateway/tee_manager.py b/tee_gateway/tee_manager.py index 402438e..193e6b4 100644 --- a/tee_gateway/tee_manager.py +++ b/tee_gateway/tee_manager.py @@ -106,8 +106,10 @@ def register_with_nitriding(self): # Domain-separated transcript so a future addition of more keys # can't be confused with the existing layout. transcript = ( - b"og-tee-keys|v2|rsa-spki=" + public_key_der - + b"|hpke-x25519=" + (self.hpke_public_key_raw or b"") + b"og-tee-keys|v2|rsa-spki=" + + public_key_der + + b"|hpke-x25519=" + + (self.hpke_public_key_raw or b"") ) key_hash = hashlib.sha256(transcript).digest() key_hash_b64 = base64.b64encode(key_hash).decode("utf-8") diff --git a/tee_gateway/test/test_ohttp.py b/tee_gateway/test/test_ohttp.py index db4fe62..acb42ee 100644 --- a/tee_gateway/test/test_ohttp.py +++ b/tee_gateway/test/test_ohttp.py @@ -20,14 +20,12 @@ def test_round_trip_request_and_response(): # Build a wire payload exactly as the SDK does. import struct - hdr = ( - bytes([ohttp.KEY_CONFIG_ID]) - + struct.pack( - ">HHH", - ohttp.KEM_ID_X25519, - ohttp.KDF_ID_HKDF_SHA256, - ohttp.AEAD_ID_CHACHA20_POLY1305, - ) + + hdr = bytes([ohttp.KEY_CONFIG_ID]) + struct.pack( + ">HHH", + ohttp.KEM_ID_X25519, + ohttp.KDF_ID_HKDF_SHA256, + ohttp.AEAD_ID_CHACHA20_POLY1305, ) info = b"message/bhttp request" + b"\x00" + hdr pkr = ohttp._SUITE.kem.deserialize_public_key(pk_raw) @@ -64,8 +62,12 @@ def test_rejects_wrong_suite(): sk, pk_raw = ohttp.generate_keypair() # Build a wire with the wrong AEAD ID import struct + hdr = bytes([ohttp.KEY_CONFIG_ID]) + struct.pack( - ">HHH", ohttp.KEM_ID_X25519, ohttp.KDF_ID_HKDF_SHA256, 0x0001 # AES-128-GCM + ">HHH", + ohttp.KEM_ID_X25519, + ohttp.KDF_ID_HKDF_SHA256, + 0x0001, # AES-128-GCM ) fake_wire = hdr + b"\x00" * 32 + b"\x00" * 16 with pytest.raises(ValueError, match="unsupported"): @@ -81,6 +83,7 @@ def test_rejects_short_input(): def test_rejects_tampered_ciphertext(): sk, pk_raw = ohttp.generate_keypair() import struct + hdr = bytes([ohttp.KEY_CONFIG_ID]) + struct.pack( ">HHH", ohttp.KEM_ID_X25519, diff --git a/uv.lock b/uv.lock index 4d8b6f1..d51f8a1 100644 --- a/uv.lock +++ b/uv.lock @@ -1540,6 +1540,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" }, ] +[[package]] +name = "pyhpke" +version = "0.6.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cryptography" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/27/37/1acb2cee5afd3dcf45b425b0d984a9cba8917fd935106ef278b42062ecfa/pyhpke-0.6.4.tar.gz", hash = "sha256:1402c6c41a0605941d2d2a589774d346c0e7a0dc7f745e84c6f0a06c2fd335c9", size = 1638147, upload-time = "2025-12-21T10:38:07.556Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/f6/ff7df9e21b38ec1c827efd90c28b3bc76eddbfdf5a44aaf2fadb59a17cb9/pyhpke-0.6.4-py3-none-any.whl", hash = "sha256:abd0b2fec1424858399ffbed0d236fb7e9740dece9907f59ca40bd567d7fef78", size = 23792, upload-time = "2025-12-21T10:38:06.172Z" }, +] + [[package]] name = "pyparsing" version = "3.3.2" @@ -1849,6 +1861,7 @@ dependencies = [ { name = "openai" }, { name = "psutil" }, { name = "pydantic" }, + { name = "pyhpke" }, { name = "python-dateutil" }, { name = "python-dotenv" }, { name = "requests" }, @@ -1895,6 +1908,7 @@ requires-dist = [ { name = "openai", specifier = ">=2.15.0" }, { name = "psutil", specifier = ">=7.2.1" }, { name = "pydantic", specifier = ">=2.12.5" }, + { name = "pyhpke", specifier = ">=0.6.0" }, { name = "python-dateutil", specifier = ">=2.6.0" }, { name = "python-dotenv", specifier = ">=1.2.1" }, { name = "requests", specifier = ">=2.32.5" }, From 6425c31b86e03719e50fff5eb0572f75e3fdf471 Mon Sep 17 00:00:00 2001 From: kukac Date: Fri, 15 May 2026 18:36:50 -0400 Subject: [PATCH 04/39] Add OHTTP anonymous chat completions with x402 payment integration (#71) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * OHTTP: derive HPKE from TEE RSA key + gate /v1/ohttp behind x402 * Replace the random os.urandom() seed for the HPKE keypair with an HKDF derivation from the RSA TEE private key (PKCS8 DER) salted with the RSA public DER. The HPKE keypair is now a deterministic function of the attested RSA key — anything that attests the RSA signing key implicitly covers the X25519 OHTTP key, with no separate randomness source to attest. Domain-separated info "og-tee-hpke-x25519-v1" pins the derivation to this use. * ohttp.generate_keypair() -> ohttp.derive_keypair(seed), with explicit >=32-byte seed validation. Tests cover deterministic output for the same seed and rejection of short seeds. * Add /v1/ohttp to the x402 payment middleware routes with the same CHAT_COMPLETIONS_OPG_SESSION_MAX_SPEND cap and upto scheme used by /v1/chat/completions. Anonymous inference is now metered identically to the public chat endpoint. * Bridge the encrypted request/response back to the token-based cost calculator via a thread-local set in the OHTTP controller. The calculator detects path=/v1/ohttp and uses the stashed plaintext inner request/response instead of the (unparseable) ciphertext bytes the middleware would otherwise see. * Fix the response-export length to max(Nn, Nk) per RFC 9458 §4.5; the prior _NK was equal here for ChaCha20-Poly1305 but would silently break under a different AEAD. * Refactor /v1/ohttp as a thin WSGI wrapper around /v1/chat/completions Replace the parallel routing/pricing logic with an in-process WSGI sub- request: the OHTTP handler decrypts, dispatches the inner request as a POST /v1/chat/completions through the app's own wsgi_app, captures the status/headers/body, then encrypts and returns. Everything that already existed for the public chat endpoint — x402 payment verification, the pre-inference pricing gate, LangChain routing, post-inference cost settlement, TEE response signing — runs unchanged for OHTTP requests. * /v1/ohttp is no longer in the x402 RouteConfig table. Gating happens naturally when the sub-request hits /v1/chat/completions; the payment header travels inside the sealed envelope as `x-payment` so the relay never sees it. * The thread-local side channel and the OHTTP-specific branch in _session_cost_calculator are removed — there is now only one cost calculator path for the whole gateway. * Inner request envelope: `{"x-payment": "...", "body": {...}}`. Inner response envelope: `{"status": int, "headers": {...}, "body": ...}`, forwarding only x402/TEE settlement headers back to the client. * Pre-decap errors stay plaintext; post-decap errors are sealed so the relay can't distinguish failure modes by response shape. * Revert HPKE key derivation; keep random HPKE keypair independent of RSA Reverts deriving the OHTTP X25519 keypair from the RSA TEE private key. The HPKE keypair is now freshly random per enclave boot (os.urandom(32) fed to pyhpke's DeriveKeyPair). The attestation binding still works because nitriding's transcript covers both public keys, but the two private keys no longer share a derivation surface: a compromise of one cannot be used to recover the other. * ohttp.generate_keypair() restored; ohttp.derive_keypair() removed. * tee_manager.TEEKeyManager no longer pulls HKDF; HPKE keypair is generated independently right after the RSA keypair. * Test for deterministic derivation replaced with an independence test that asserts two generate_keypair() calls return different pubkeys. --------- Co-authored-by: Claude --- tee_gateway/controllers/ohttp_controller.py | 228 +++++++++++++++----- tee_gateway/ohttp.py | 26 ++- tee_gateway/tee_manager.py | 7 +- tee_gateway/test/test_ohttp.py | 8 + 4 files changed, 199 insertions(+), 70 deletions(-) diff --git a/tee_gateway/controllers/ohttp_controller.py b/tee_gateway/controllers/ohttp_controller.py index 9afa914..73c507e 100644 --- a/tee_gateway/controllers/ohttp_controller.py +++ b/tee_gateway/controllers/ohttp_controller.py @@ -1,31 +1,49 @@ """ Oblivious HTTP endpoint for anonymous inference. -This handler is intentionally minimal: it does HPKE decapsulation, dispatches -the inner request to the existing chat-completions handler in-process (no -network hop), and HPKE-encapsulates the response. The inner JSON request is -identical to the standard /v1/chat/completions body. +This handler is intentionally a thin shell: it does HPKE decapsulation and +then re-issues the inner request as a real WSGI sub-request against the +enclave's own ``/v1/chat/completions``. That means x402 payment handling, +the pre-inference pricing gate, LangChain routing, the post-inference cost +calculator and TEE response signing all execute via the same code paths as +the public chat endpoint — no duplicate routing tables, no thread-local +side channels, no parallel pricing logic. ``/v1/ohttp`` itself is NOT +gated by x402: payment travels inside the sealed envelope as an +``x-payment`` header on the inner request, and the gating happens +naturally when the sub-request hits the chat endpoint. + +Wire format of the (HPKE-decrypted) inner payload — a JSON object: + { + "x-payment": "", + "body": { ... standard /v1/chat/completions JSON body ... } + } + +Wire format of the (pre-HPKE) inner response: + { + "status": , + "headers": { "x-payment-response": "...", "x-upto-session": "..." }, + "body": + } Threat model nuances: - * The relay in front of this endpoint sees the encapsulated ciphertext and - the client IP, but no request content. + * The relay in front of this endpoint sees the encapsulated ciphertext + and the client IP, but no request content or payment header. * The enclave sees plaintext and the relay's IP, never the client's. - * If the client's payload contains identifiers (cookies, ``user`` field, - custom request IDs), unlinkability is broken at the application layer — - we strip the obvious ones below. - * Streaming is intentionally not supported on this endpoint. SSE would - create per-chunk side channels (timing, length) that defeat the point of - bundling everything into a single sealed response. + * If the inner JSON body contains identifiers (``user``, cookies, + custom request IDs), unlinkability is broken at the application + layer — we strip the obvious ones below. + * Streaming is intentionally rejected; the inner sub-request must + return a single sealed response. """ from __future__ import annotations +import io import json import logging from typing import Any -import connexion -from flask import Response +from flask import Response, current_app, request as flask_request from tee_gateway import ohttp from tee_gateway.tee_manager import get_tee_keys @@ -41,21 +59,27 @@ # the upstream LLM provider. _IDENTIFYING_FIELDS = ("user", "metadata", "x-request-id", "request_id") +# Response headers we propagate from the inner /v1/chat/completions response +# back to the client (encrypted). Includes x402 settlement metadata and +# anything the standard chat endpoint exposes via the TEE-signed response. +_FORWARDED_HEADER_PREFIXES = ("x-payment", "x-upto", "x-settlement", "x-tee") +_FORWARDED_HEADER_NAMES = ("www-authenticate",) + def _scrub(payload: dict[str, Any]) -> dict[str, Any]: return {k: v for k, v in payload.items() if k not in _IDENTIFYING_FIELDS} -def create_anonymous_chat_completion(): - """POST /v1/ohttp — decrypt, dispatch, re-encrypt. +def _should_forward_header(name: str) -> bool: + lower = name.lower() + return lower in _FORWARDED_HEADER_NAMES or any( + lower.startswith(p) for p in _FORWARDED_HEADER_PREFIXES + ) - Body: raw bytes (OHTTP-encapsulated request). - Returns: raw bytes (OHTTP-encapsulated response) with Content-Type - ``message/ohttp-res``. - """ - req = connexion.request - # Tolerate both Connexion's Flask request and a bare Flask request. - raw_body: bytes = req.get_data(cache=False) + +def create_anonymous_chat_completion(): + """POST /v1/ohttp — decrypt, sub-dispatch to /v1/chat/completions, re-encrypt.""" + raw_body: bytes = flask_request.get_data(cache=False) if not raw_body: return _error(400, "empty body") @@ -72,51 +96,138 @@ def create_anonymous_chat_completion(): return _error(400, "malformed encapsulated request") try: - inner_body = json.loads(decap.plaintext.decode("utf-8")) + envelope = json.loads(decap.plaintext.decode("utf-8")) except (UnicodeDecodeError, json.JSONDecodeError): - return _error(400, "inner payload is not valid JSON") + return _seal_inner(decap, 400, {}, {"error": "inner payload is not valid JSON"}) + + if not isinstance(envelope, dict): + return _seal_inner( + decap, 400, {}, {"error": "inner payload must be a JSON object"} + ) + + body_obj = envelope.get("body") + if not isinstance(body_obj, dict): + return _seal_inner( + decap, 400, {}, {"error": "inner 'body' must be a JSON object"} + ) + + if body_obj.get("stream"): + # Streaming is rejected on principle: SSE re-introduces per-chunk + # timing/length side channels that defeat the point of sealing + # everything into one response. + return _seal_inner( + decap, 400, {}, {"error": "stream=true is not supported over OHTTP"} + ) + + body_obj = _scrub(body_obj) + body_bytes = json.dumps(body_obj, separators=(",", ":")).encode("utf-8") + + payment_header = envelope.get("x-payment") + if payment_header is not None and not isinstance(payment_header, str): + return _seal_inner( + decap, 400, {}, {"error": "'x-payment' must be a string if present"} + ) + + status_code, response_headers, response_body = _wsgi_subrequest( + path="/v1/chat/completions", + body_bytes=body_bytes, + payment_header=payment_header, + ) - if not isinstance(inner_body, dict): - return _error(400, "inner payload must be a JSON object") + return _seal_inner(decap, status_code, response_headers, response_body) - if inner_body.get("stream"): - # Streaming is rejected on principle (see module docstring). Clients - # who want low TTFT under anonymity should use a shorter max_tokens. - return _error(400, "stream=true is not supported over OHTTP") - inner_body = _scrub(inner_body) +def _wsgi_subrequest( + path: str, + body_bytes: bytes, + payment_header: str | None, +) -> tuple[int, dict[str, str], Any]: + """Issue an in-process WSGI request through the app's full middleware stack. - # Late import to avoid a circular dependency at module load (the chat - # controller pulls in models that import this package). - from tee_gateway.controllers.chat_controller import ( - _create_non_streaming_response, - _parse_chat_request, + Returns ``(status_code, forwarded_headers, parsed_body_or_text)``. The + parsed body is the decoded JSON object on JSON responses, otherwise the + raw response text. We invoke ``current_app.wsgi_app`` directly so the + x402 payment middleware (which wraps ``wsgi_app`` at injection time) + runs the same way it would for an external HTTP request to the same + path — including the pre-inference pricing gate, payment verification, + cost settlement and TEE response signing. + """ + outer_env = flask_request.environ + sub_env: dict[str, Any] = { + k: v + for k, v in outer_env.items() + if k.startswith("wsgi.") + or k in ("SERVER_NAME", "SERVER_PORT", "SERVER_PROTOCOL", "HTTP_HOST") + } + sub_env.update( + { + "REQUEST_METHOD": "POST", + "PATH_INFO": path, + "RAW_URI": path, + "REQUEST_URI": path, + "SCRIPT_NAME": "", + "QUERY_STRING": "", + "CONTENT_TYPE": "application/json", + "CONTENT_LENGTH": str(len(body_bytes)), + "wsgi.input": io.BytesIO(body_bytes), + } ) + if payment_header: + sub_env["HTTP_X_PAYMENT"] = payment_header - try: - chat_request = _parse_chat_request(inner_body) - inner_result = _create_non_streaming_response(chat_request) - except Exception as exc: - logger.error("inner inference failed under OHTTP: %s", exc, exc_info=True) - inner_result = ({"error": "inference failed"}, 500) + captured: dict[str, Any] = {"status": "500 Internal Server Error", "headers": []} - # _create_non_streaming_response returns either a dict or (body, status) - if isinstance(inner_result, tuple): - body_obj, status = inner_result - else: - body_obj, status = inner_result, 200 + def _start_response(status: str, headers: list, exc_info: Any = None): + captured["status"] = status + captured["headers"] = headers + return lambda _chunk: None - inner_json = json.dumps( - {"status": status, "body": body_obj}, + iterator = current_app.wsgi_app(sub_env, _start_response) + body_chunks: list[bytes] = [] + try: + for chunk in iterator: + if chunk: + body_chunks.append(chunk) + finally: + close = getattr(iterator, "close", None) + if callable(close): + # Triggers x402's post-response settlement (StreamingSessionResponse.close). + close() + + status_code = int(captured["status"].split(" ", 1)[0]) + forwarded_headers = { + name: value + for name, value in captured["headers"] + if _should_forward_header(name) + } + + raw_body = b"".join(body_chunks) + parsed_body: Any + if not raw_body: + parsed_body = "" + else: + try: + parsed_body = json.loads(raw_body.decode("utf-8")) + except (UnicodeDecodeError, json.JSONDecodeError): + parsed_body = raw_body.decode("utf-8", errors="replace") + + return status_code, forwarded_headers, parsed_body + + +def _seal_inner( + decap: ohttp.DecapsulatedRequest, + status_code: int, + headers: dict[str, str], + body: Any, +) -> Response: + """Encapsulate a ``{status, headers, body}`` triple as an OHTTP response.""" + plaintext = json.dumps( + {"status": status_code, "headers": headers, "body": body}, separators=(",", ":"), + ensure_ascii=False, ).encode("utf-8") - - sealed = ohttp.encapsulate_response(decap.response_key, decap.enc, inner_json) - return Response( - sealed, - status=200, - mimetype=OHTTP_RESPONSE_MEDIA_TYPE, - ) + sealed = ohttp.encapsulate_response(decap.response_key, decap.enc, plaintext) + return Response(sealed, status=200, mimetype=OHTTP_RESPONSE_MEDIA_TYPE) def get_hpke_config(): @@ -136,4 +247,7 @@ def get_hpke_config(): def _error(status: int, message: str) -> tuple[dict, int]: + """Plaintext error response (not sealed) — only used before HPKE decap + succeeds. Once we have a recipient context we always seal errors so the + relay can't distinguish them from real failures.""" return {"error": message}, status diff --git a/tee_gateway/ohttp.py b/tee_gateway/ohttp.py index ba5c7c4..ff47c47 100644 --- a/tee_gateway/ohttp.py +++ b/tee_gateway/ohttp.py @@ -120,9 +120,8 @@ def decapsulate_request( plaintext = recipient.open(aead_ct, aad=b"") # Export a fresh secret bound to this HPKE context, used to derive the - # response AEAD key. This is the OHTTP-defined separation between the - # request and response halves of the same exchange. - response_secret = recipient.export(_LABEL_RESPONSE, _NK) + # response AEAD key. RFC 9458 §4.5 specifies export length max(Nn, Nk). + response_secret = recipient.export(_LABEL_RESPONSE, max(_NN, _NK)) return DecapsulatedRequest( plaintext=plaintext, response_key=response_secret, enc=enc ) @@ -152,13 +151,18 @@ def encapsulate_response(response_secret: bytes, enc: bytes, plaintext: bytes) - def generate_keypair() -> tuple[KEMKeyInterface, bytes]: - """Generate an X25519 keypair for HPKE. Returns (private_key, public_key_raw). - - pyhpke 0.6 derives keys from random IKM via ``kem.derive_key_pair(ikm)``, - which returns a ``KEMKeyPair`` wrapper. We hold onto the private side for - decapsulation and serialize the public side to raw 32-byte form for the - key configuration blob. + """Generate a fresh X25519 keypair for HPKE. + + The HPKE keypair is intentionally independent of the RSA TEE signing + key: deriving one from the other would create a single point of + compromise (a leak of the RSA private key would also leak the OHTTP + private key, and vice versa). Both public keys are still covered by + the same nitriding attestation transcript, so verifiers get binding + without sharing key material. + + pyhpke 0.6 derives the keypair from random IKM via + ``kem.derive_key_pair(ikm)``; we feed it ``os.urandom(32)`` so each + enclave boot produces an independent keypair. """ pair = _SUITE.kem.derive_key_pair(os.urandom(32)) - pk_raw = pair.public_key.to_public_bytes() - return pair.private_key, pk_raw + return pair.private_key, pair.public_key.to_public_bytes() diff --git a/tee_gateway/tee_manager.py b/tee_gateway/tee_manager.py index 193e6b4..16ac97c 100644 --- a/tee_gateway/tee_manager.py +++ b/tee_gateway/tee_manager.py @@ -77,8 +77,11 @@ def _generate_keys(self): wallet_account = Account.from_key(wallet_key_bytes) self.wallet_address = wallet_account.address - # HPKE X25519 keypair — never leaves the enclave; clients address it - # via the public-key fingerprint published with the attestation. + # HPKE X25519 keypair — independent random material from the RSA + # signing key. Both public keys are bound to the enclave by the + # nitriding attestation transcript (register_with_nitriding below), + # so verifiers still get a single attested fingerprint covering both, + # without sharing private-key material between them. self.hpke_private_key, self.hpke_public_key_raw = ohttp.generate_keypair() logger.info("TEE key pair generated successfully") diff --git a/tee_gateway/test/test_ohttp.py b/tee_gateway/test/test_ohttp.py index acb42ee..faa732c 100644 --- a/tee_gateway/test/test_ohttp.py +++ b/tee_gateway/test/test_ohttp.py @@ -80,6 +80,14 @@ def test_rejects_short_input(): ohttp.decapsulate_request(sk, b"\x01") +def test_generate_keypair_is_independent(): + """Each invocation must produce an independent keypair — the HPKE key + is intentionally not derived from any shared seed.""" + _, pk_a = ohttp.generate_keypair() + _, pk_b = ohttp.generate_keypair() + assert pk_a != pk_b + + def test_rejects_tampered_ciphertext(): sk, pk_raw = ohttp.generate_keypair() import struct From e1a720490388d614ffca694bdb25bc5679083687 Mon Sep 17 00:00:00 2001 From: Claude Date: Fri, 15 May 2026 23:09:39 +0000 Subject: [PATCH 05/39] Relay-pays OHTTP: x-payment from outer header, surface usage to relay MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Switches /v1/ohttp to the relay-pays model. The client encrypts only a chat-completion request — no payment material — and a relay between the client and the enclave supplies the x402 payment as a standard outer-request header. The enclave reads x-payment from the outer request, attaches it to the in-process sub-request to /v1/chat/completions, and lets the existing x402 middleware verify and settle exactly as it would for a public call. * Inner plaintext is now bare chat-completion JSON; the {x-payment, body} envelope is gone since payment travels outside the seal. * On 2xx the response body is still HPKE-sealed (it contains user prompts/completions), but the outer response surfaces token usage as headers so the relay can bill: X-Usage-Prompt-Tokens, X-Usage-Completion-Tokens, X-Usage-Total-Tokens, X-Usage-Model. x402 settlement and TEE signature headers are also forwarded. * On non-2xx (402 payment required, validation errors) the body is forwarded as plaintext so the relay can read x402 payment requirements, retry with a larger payment, or surface errors. These bodies never contain user prompts/completions. * Privacy: relay sees ciphertext + usage + settlement + relay-side wallet; never sees prompts, completions, or the client's IP. Unlinkability holds unless relay and enclave collude. --- tee_gateway/controllers/ohttp_controller.py | 255 +++++++++++--------- 1 file changed, 146 insertions(+), 109 deletions(-) diff --git a/tee_gateway/controllers/ohttp_controller.py b/tee_gateway/controllers/ohttp_controller.py index 73c507e..e541991 100644 --- a/tee_gateway/controllers/ohttp_controller.py +++ b/tee_gateway/controllers/ohttp_controller.py @@ -1,39 +1,54 @@ """ -Oblivious HTTP endpoint for anonymous inference. - -This handler is intentionally a thin shell: it does HPKE decapsulation and -then re-issues the inner request as a real WSGI sub-request against the -enclave's own ``/v1/chat/completions``. That means x402 payment handling, -the pre-inference pricing gate, LangChain routing, the post-inference cost -calculator and TEE response signing all execute via the same code paths as -the public chat endpoint — no duplicate routing tables, no thread-local -side channels, no parallel pricing logic. ``/v1/ohttp`` itself is NOT -gated by x402: payment travels inside the sealed envelope as an -``x-payment`` header on the inner request, and the gating happens -naturally when the sub-request hits the chat endpoint. - -Wire format of the (HPKE-decrypted) inner payload — a JSON object: - { - "x-payment": "", - "body": { ... standard /v1/chat/completions JSON body ... } - } - -Wire format of the (pre-HPKE) inner response: - { - "status": , - "headers": { "x-payment-response": "...", "x-upto-session": "..." }, - "body": - } - -Threat model nuances: - * The relay in front of this endpoint sees the encapsulated ciphertext - and the client IP, but no request content or payment header. - * The enclave sees plaintext and the relay's IP, never the client's. - * If the inner JSON body contains identifiers (``user``, cookies, - custom request IDs), unlinkability is broken at the application - layer — we strip the obvious ones below. - * Streaming is intentionally rejected; the inner sub-request must - return a single sealed response. +Oblivious HTTP endpoint for anonymous inference (relay-pays model). + +This handler is a thin shell: it HPKE-decapsulates the inner request, re-issues +it as an in-process WSGI sub-request against the enclave's own +``/v1/chat/completions``, then encapsulates the response. All x402 payment, +LangChain routing, cost settlement and TEE response signing reuse the public +chat code paths — there is no duplicated routing or pricing logic here. + +Trust / payment model: + * The CLIENT encrypts only an LLM chat-completion request. It does not see, + sign, or carry x402 payment material. + * A RELAY sits between the client and the enclave. The relay holds the + x402 wallet and forwards the (still-encrypted) inner request to the + enclave, attaching its own ``x-payment`` header on the OUTER HTTP + request. + * The ENCLAVE decrypts the inner payload, forwards the request to its own + chat endpoint with the relay's ``x-payment`` header, and returns. The + relay sees status, settlement headers, and token-usage headers, but + never sees the inner prompt or completion. + +Wire format — outer HTTP request to /v1/ohttp: + Headers: + x-payment: + content-type: message/ohttp-req + Body: HPKE-encapsulated chat-completion JSON (just the inner body, + no JSON envelope — the inner payload IS the chat request). + +Wire format — outer HTTP response from /v1/ohttp: + On 2xx (inference succeeded): + Headers: + content-type: message/ohttp-res + x-payment-response, x-upto-session, ... (forwarded from x402) + x-usage-prompt-tokens, x-usage-completion-tokens, + x-usage-total-tokens, x-usage-model (so the relay can bill) + Body: HPKE-encapsulated chat-completion response JSON. The relay + cannot decrypt; only the client (who has the HPKE response + key from its sender context) can read prompts/completions. + On non-2xx (402 payment required, validation errors, etc.): + Body forwarded as plaintext so the relay can act on it (read x402 + payment requirements, retry with a larger payment, surface errors + to the user). These bodies never contain user prompts/completions. + +Privacy properties: + * Relay sees ciphertext + relay-side wallet + token usage + relay's IP. + Never sees prompts, completions, or the client's IP. + * Enclave sees plaintext prompts/completions + relay's IP. Never sees + the client's IP. + * Unlinkability holds unless the relay and the enclave collude. + * Streaming is rejected; SSE re-introduces per-chunk timing/length side + channels that would defeat sealing. """ from __future__ import annotations @@ -60,8 +75,7 @@ _IDENTIFYING_FIELDS = ("user", "metadata", "x-request-id", "request_id") # Response headers we propagate from the inner /v1/chat/completions response -# back to the client (encrypted). Includes x402 settlement metadata and -# anything the standard chat endpoint exposes via the TEE-signed response. +# back through the relay to the client. _FORWARDED_HEADER_PREFIXES = ("x-payment", "x-upto", "x-settlement", "x-tee") _FORWARDED_HEADER_NAMES = ("www-authenticate",) @@ -96,61 +110,115 @@ def create_anonymous_chat_completion(): return _error(400, "malformed encapsulated request") try: - envelope = json.loads(decap.plaintext.decode("utf-8")) + chat_body = json.loads(decap.plaintext.decode("utf-8")) except (UnicodeDecodeError, json.JSONDecodeError): - return _seal_inner(decap, 400, {}, {"error": "inner payload is not valid JSON"}) + return _error(400, "inner payload is not valid JSON") - if not isinstance(envelope, dict): - return _seal_inner( - decap, 400, {}, {"error": "inner payload must be a JSON object"} - ) + if not isinstance(chat_body, dict): + return _error(400, "inner payload must be a JSON object") - body_obj = envelope.get("body") - if not isinstance(body_obj, dict): - return _seal_inner( - decap, 400, {}, {"error": "inner 'body' must be a JSON object"} - ) - - if body_obj.get("stream"): - # Streaming is rejected on principle: SSE re-introduces per-chunk - # timing/length side channels that defeat the point of sealing - # everything into one response. - return _seal_inner( - decap, 400, {}, {"error": "stream=true is not supported over OHTTP"} - ) + if chat_body.get("stream"): + return _error(400, "stream=true is not supported over OHTTP") - body_obj = _scrub(body_obj) - body_bytes = json.dumps(body_obj, separators=(",", ":")).encode("utf-8") + chat_body = _scrub(chat_body) + body_bytes = json.dumps(chat_body, separators=(",", ":")).encode("utf-8") - payment_header = envelope.get("x-payment") - if payment_header is not None and not isinstance(payment_header, str): - return _seal_inner( - decap, 400, {}, {"error": "'x-payment' must be a string if present"} - ) + # The relay pays — x-payment is a standard outer-request header, not + # inside the encrypted envelope. Pass it through to the inner endpoint + # so x402 verifies and settles exactly as it does for a normal call. + payment_header = flask_request.headers.get("X-Payment") - status_code, response_headers, response_body = _wsgi_subrequest( + sub_status, sub_headers, sub_body = _wsgi_subrequest( path="/v1/chat/completions", body_bytes=body_bytes, payment_header=payment_header, ) - return _seal_inner(decap, status_code, response_headers, response_body) + return _build_outer_response(decap, sub_status, sub_headers, sub_body) + + +def _build_outer_response( + decap: ohttp.DecapsulatedRequest, + status: int, + headers: list[tuple[str, str]], + body_bytes: bytes, +) -> Response: + """Translate the inner sub-response into the outer OHTTP response. + + On 2xx we seal the body (which contains user prompts/completions) and + surface token-usage headers so the relay can bill. On non-2xx we pass + through the inner body verbatim — those responses carry x402 payment + requirements or error messages that the relay needs to read, and never + contain user content. + """ + forwarded = {name: value for name, value in headers if _should_forward_header(name)} + inner_content_type = next( + (v for k, v in headers if k.lower() == "content-type"), + "application/json", + ) + + if not (200 <= status < 300): + return Response( + body_bytes, + status=status, + headers=forwarded, + content_type=inner_content_type, + ) + + forwarded.update(_extract_usage_headers(body_bytes)) + sealed = ohttp.encapsulate_response(decap.response_key, decap.enc, body_bytes) + return Response( + sealed, + status=status, + headers=forwarded, + mimetype=OHTTP_RESPONSE_MEDIA_TYPE, + ) + + +def _extract_usage_headers(body_bytes: bytes) -> dict[str, str]: + """Pull token-usage + model name out of a chat-completion response and + project them onto outer HTTP headers for the relay's billing pipeline. + These are the ONLY pieces of metadata the relay needs to charge; the + prompt and completion themselves stay sealed.""" + try: + body = json.loads(body_bytes.decode("utf-8")) + except (UnicodeDecodeError, json.JSONDecodeError): + return {} + if not isinstance(body, dict): + return {} + + headers: dict[str, str] = {} + usage = body.get("usage") + if isinstance(usage, dict): + prompt_tokens = usage.get("prompt_tokens", usage.get("input_tokens")) + completion_tokens = usage.get("completion_tokens", usage.get("output_tokens")) + total_tokens = usage.get("total_tokens") + if prompt_tokens is not None: + headers["X-Usage-Prompt-Tokens"] = str(prompt_tokens) + if completion_tokens is not None: + headers["X-Usage-Completion-Tokens"] = str(completion_tokens) + if total_tokens is not None: + headers["X-Usage-Total-Tokens"] = str(total_tokens) + + model = body.get("model") + if isinstance(model, str): + headers["X-Usage-Model"] = model + return headers def _wsgi_subrequest( path: str, body_bytes: bytes, payment_header: str | None, -) -> tuple[int, dict[str, str], Any]: +) -> tuple[int, list[tuple[str, str]], bytes]: """Issue an in-process WSGI request through the app's full middleware stack. - Returns ``(status_code, forwarded_headers, parsed_body_or_text)``. The - parsed body is the decoded JSON object on JSON responses, otherwise the - raw response text. We invoke ``current_app.wsgi_app`` directly so the - x402 payment middleware (which wraps ``wsgi_app`` at injection time) - runs the same way it would for an external HTTP request to the same - path — including the pre-inference pricing gate, payment verification, - cost settlement and TEE response signing. + Returns ``(status_code, headers, body_bytes)``. We invoke + ``current_app.wsgi_app`` directly so the x402 payment middleware (which + wraps ``wsgi_app`` at injection time) runs the same way it would for an + external HTTP request to the same path — including the pre-inference + pricing gate, payment verification, cost settlement and TEE response + signing. """ outer_env = flask_request.environ sub_env: dict[str, Any] = { @@ -195,39 +263,7 @@ def _start_response(status: str, headers: list, exc_info: Any = None): close() status_code = int(captured["status"].split(" ", 1)[0]) - forwarded_headers = { - name: value - for name, value in captured["headers"] - if _should_forward_header(name) - } - - raw_body = b"".join(body_chunks) - parsed_body: Any - if not raw_body: - parsed_body = "" - else: - try: - parsed_body = json.loads(raw_body.decode("utf-8")) - except (UnicodeDecodeError, json.JSONDecodeError): - parsed_body = raw_body.decode("utf-8", errors="replace") - - return status_code, forwarded_headers, parsed_body - - -def _seal_inner( - decap: ohttp.DecapsulatedRequest, - status_code: int, - headers: dict[str, str], - body: Any, -) -> Response: - """Encapsulate a ``{status, headers, body}`` triple as an OHTTP response.""" - plaintext = json.dumps( - {"status": status_code, "headers": headers, "body": body}, - separators=(",", ":"), - ensure_ascii=False, - ).encode("utf-8") - sealed = ohttp.encapsulate_response(decap.response_key, decap.enc, plaintext) - return Response(sealed, status=200, mimetype=OHTTP_RESPONSE_MEDIA_TYPE) + return status_code, captured["headers"], b"".join(body_chunks) def get_hpke_config(): @@ -247,7 +283,8 @@ def get_hpke_config(): def _error(status: int, message: str) -> tuple[dict, int]: - """Plaintext error response (not sealed) — only used before HPKE decap - succeeds. Once we have a recipient context we always seal errors so the - relay can't distinguish them from real failures.""" + """Plaintext error for cases where we never have a recipient context + (empty body, malformed encapsulation). Post-decap input errors are + also returned plaintext so the relay can surface them to the client — + they never contain user prompts.""" return {"error": message}, status From 9d55a8e79536c2d7884dd21c7763391ae1471acb Mon Sep 17 00:00:00 2001 From: kukac Date: Fri, 15 May 2026 19:16:37 -0400 Subject: [PATCH 06/39] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- tee_gateway/controllers/ohttp_controller.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tee_gateway/controllers/ohttp_controller.py b/tee_gateway/controllers/ohttp_controller.py index e541991..1c09e5e 100644 --- a/tee_gateway/controllers/ohttp_controller.py +++ b/tee_gateway/controllers/ohttp_controller.py @@ -279,7 +279,7 @@ def get_hpke_config(): return tee.get_hpke_config(), 200 except Exception as exc: logger.error("HPKE config error: %s", exc, exc_info=True) - return {"error": str(exc)}, 500 + return {"error": "Failed to retrieve HPKE config"}, 500 def _error(status: int, message: str) -> tuple[dict, int]: From 38419fd08d9d977c5eb6f2c793096ac577f80a09 Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 16 May 2026 13:04:00 +0000 Subject: [PATCH 07/39] Chunked OHTTP: stream SSE inference responses end-to-end MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds streaming support per draft-ietf-ohai-chunked-ohttp-08. When the inner chat-completion request has stream=true, /v1/ohttp pipes the sub-request's SSE events through a chunked OHTTP encrypter and yields them as they arrive, instead of buffering. Non-streaming requests continue to use the existing single-shot RFC 9458 §4.5 path. ohttp.py: * QUIC varint encode/decode helpers (RFC 9000 §16). * New _LABEL_CHUNKED_RESPONSE = "message/bhttp chunked response" and a second secret export at decap time; DecapsulatedRequest now carries response_key + response_key_chunked so the controller can decide which mode to use AFTER inspecting the decrypted body. * ChunkedResponseEncrypter: response_nonce header, varint(len)||ct per chunk (AAD=""), zero-prefix final chunk (AAD=b"final") so truncation is detectable, per-chunk nonce = aead_nonce XOR encode_be(counter). * Extracted _derive_response_keys() shared between single-shot and chunked paths (HKDF-Extract on enc||response_nonce, then Expand twice for "key" and "nonce"). ohttp_controller.py: * Drop the stream=true rejection. Pass stream through to the inner sub-request and detect text/event-stream in the captured headers. * _wsgi_subrequest now returns the raw iterator instead of draining, so the streaming path can pipe chunks through Flask without buffering. close() still invoked downstream to trigger x402 settlement. * _build_streaming_response: look-ahead-by-one over the inner SSE iterator so the last event is sealed with AAD=b"final"; content-type message/ohttp-chunked-res; x402/TEE settlement headers forwarded. Usage stats stay inside the encrypted stream (final SSE event); the relay bills via X-Upto-Session as usual. Tests: varint round-trip across all 4 length classes, chunked response round-trip with a hand-rolled client-side decrypter that walks the varint frames and verifies AAD=b"final", double-finalize rejection. 96 unit tests total now passing. --- tee_gateway/controllers/ohttp_controller.py | 188 +++++++++++++------- tee_gateway/ohttp.py | 150 ++++++++++++++-- tee_gateway/test/test_ohttp.py | 108 +++++++++++ 3 files changed, 363 insertions(+), 83 deletions(-) diff --git a/tee_gateway/controllers/ohttp_controller.py b/tee_gateway/controllers/ohttp_controller.py index 1c09e5e..97eb296 100644 --- a/tee_gateway/controllers/ohttp_controller.py +++ b/tee_gateway/controllers/ohttp_controller.py @@ -7,6 +7,17 @@ LangChain routing, cost settlement and TEE response signing reuse the public chat code paths — there is no duplicated routing or pricing logic here. +Two response modes are supported, dispatched by the inner ``stream`` flag: + * stream=false → single-shot OHTTP response (RFC 9458 §4.5), + content-type ``message/ohttp-res``. Usage stats surface in outer headers. + * stream=true → chunked OHTTP response (draft-ietf-ohai-chunked-ohttp-08), + content-type ``message/ohttp-chunked-res``. Each SSE event from the + inner /v1/chat/completions stream becomes one sealed OHTTP chunk; the + final chunk uses AAD=b"final" so truncation is detectable. Usage stats + can't appear in outer headers (sent before body) — clients read them + from the final SSE event inside the decrypted stream; the relay relies + on x402 settlement metadata (X-Upto-Session) for billing. + Trust / payment model: * The CLIENT encrypts only an LLM chat-completion request. It does not see, sign, or carry x402 payment material. @@ -16,39 +27,19 @@ request. * The ENCLAVE decrypts the inner payload, forwards the request to its own chat endpoint with the relay's ``x-payment`` header, and returns. The - relay sees status, settlement headers, and token-usage headers, but - never sees the inner prompt or completion. - -Wire format — outer HTTP request to /v1/ohttp: - Headers: - x-payment: - content-type: message/ohttp-req - Body: HPKE-encapsulated chat-completion JSON (just the inner body, - no JSON envelope — the inner payload IS the chat request). - -Wire format — outer HTTP response from /v1/ohttp: - On 2xx (inference succeeded): - Headers: - content-type: message/ohttp-res - x-payment-response, x-upto-session, ... (forwarded from x402) - x-usage-prompt-tokens, x-usage-completion-tokens, - x-usage-total-tokens, x-usage-model (so the relay can bill) - Body: HPKE-encapsulated chat-completion response JSON. The relay - cannot decrypt; only the client (who has the HPKE response - key from its sender context) can read prompts/completions. - On non-2xx (402 payment required, validation errors, etc.): - Body forwarded as plaintext so the relay can act on it (read x402 - payment requirements, retry with a larger payment, surface errors - to the user). These bodies never contain user prompts/completions. + relay sees status, settlement headers, and (for non-stream) token-usage + headers, but never sees the inner prompt or completion. Privacy properties: - * Relay sees ciphertext + relay-side wallet + token usage + relay's IP. - Never sees prompts, completions, or the client's IP. + * Relay sees ciphertext + relay-side wallet + (non-stream) token usage + + relay's IP. Never sees prompts, completions, or the client's IP. * Enclave sees plaintext prompts/completions + relay's IP. Never sees the client's IP. * Unlinkability holds unless the relay and the enclave collude. - * Streaming is rejected; SSE re-introduces per-chunk timing/length side - channels that would defeat sealing. + * Streaming leaks per-chunk timing and length (the relay sees the + cadence of varint-framed sealed chunks). This is an inherent cost of + server-sent events — clients who can't accept that signal should + use stream=false. """ from __future__ import annotations @@ -56,7 +47,7 @@ import io import json import logging -from typing import Any +from typing import Any, Iterator from flask import Response, current_app, request as flask_request @@ -67,6 +58,8 @@ OHTTP_MEDIA_TYPE = "message/ohttp-req" OHTTP_RESPONSE_MEDIA_TYPE = "message/ohttp-res" +OHTTP_CHUNKED_RESPONSE_MEDIA_TYPE = "message/ohttp-chunked-res" +_SSE_CONTENT_TYPE = "text/event-stream" # Fields that can re-identify a client and have no role in inference. We drop # them before forwarding to the inner handler — keeping them inside the @@ -117,9 +110,6 @@ def create_anonymous_chat_completion(): if not isinstance(chat_body, dict): return _error(400, "inner payload must be a JSON object") - if chat_body.get("stream"): - return _error(400, "stream=true is not supported over OHTTP") - chat_body = _scrub(chat_body) body_bytes = json.dumps(chat_body, separators=(",", ":")).encode("utf-8") @@ -128,13 +118,30 @@ def create_anonymous_chat_completion(): # so x402 verifies and settles exactly as it does for a normal call. payment_header = flask_request.headers.get("X-Payment") - sub_status, sub_headers, sub_body = _wsgi_subrequest( + sub_status, sub_headers, sub_iter = _wsgi_subrequest( path="/v1/chat/completions", body_bytes=body_bytes, payment_header=payment_header, ) - return _build_outer_response(decap, sub_status, sub_headers, sub_body) + inner_content_type = next( + (v for k, v in sub_headers if k.lower() == "content-type"), + "application/json", + ) + is_streaming = ( + 200 <= sub_status < 300 + and inner_content_type.split(";", 1)[0].strip().lower() == _SSE_CONTENT_TYPE + ) + + if is_streaming: + return _build_streaming_response(decap, sub_status, sub_headers, sub_iter) + + # Non-streaming: drain into bytes (this also triggers x402's + # post-response settlement via the WSGI iterator's close()). + body_bytes_out = _drain(sub_iter) + return _build_outer_response( + decap, sub_status, sub_headers, body_bytes_out, inner_content_type + ) def _build_outer_response( @@ -142,20 +149,13 @@ def _build_outer_response( status: int, headers: list[tuple[str, str]], body_bytes: bytes, + inner_content_type: str, ) -> Response: - """Translate the inner sub-response into the outer OHTTP response. - - On 2xx we seal the body (which contains user prompts/completions) and - surface token-usage headers so the relay can bill. On non-2xx we pass - through the inner body verbatim — those responses carry x402 payment - requirements or error messages that the relay needs to read, and never - contain user content. - """ + """Single-shot OHTTP response. Seals the body on 2xx (contains user + prompts/completions) and surfaces token usage as outer headers so the + relay can bill. Non-2xx bodies (x402 payment requirements, validation + errors) are forwarded as plaintext so the relay can act on them.""" forwarded = {name: value for name, value in headers if _should_forward_header(name)} - inner_content_type = next( - (v for k, v in headers if k.lower() == "content-type"), - "application/json", - ) if not (200 <= status < 300): return Response( @@ -175,6 +175,72 @@ def _build_outer_response( ) +def _build_streaming_response( + decap: ohttp.DecapsulatedRequest, + status: int, + headers: list[tuple[str, str]], + sub_iter: Iterator[bytes], +) -> Response: + """Chunked OHTTP response (draft-ietf-ohai-chunked-ohttp-08). + + Each SSE event from the inner /v1/chat/completions stream is sealed as + one OHTTP chunk and yielded immediately. The final chunk uses + AAD=b"final" with a zero-length varint prefix; emitting it requires + look-ahead by one chunk so we know which one is last, hence the + ``pending`` buffer below. + + Usage stats can't be exposed as outer headers (those are already sent + before the body); the relay bills via x402 settlement metadata + (X-Upto-Session header, set up-front). The client reads usage from the + final SSE event inside the decrypted stream. + """ + forwarded = {name: value for name, value in headers if _should_forward_header(name)} + + def _stream() -> Iterator[bytes]: + encrypter = ohttp.ChunkedResponseEncrypter( + decap.response_key_chunked, decap.enc + ) + yield encrypter.header() + + pending: bytes | None = None + try: + for chunk in sub_iter: + if not chunk: + continue + if pending is not None: + yield encrypter.encrypt_chunk(pending, is_final=False) + pending = chunk + # Always emit exactly one final chunk so the AAD=b"final" + # marker is present — that's what protects clients from + # undetected truncation. + yield encrypter.encrypt_chunk(pending or b"", is_final=True) + finally: + close = getattr(sub_iter, "close", None) + if callable(close): + # Triggers x402's streaming-session settlement. + close() + + return Response( + _stream(), + status=status, + headers=forwarded, + mimetype=OHTTP_CHUNKED_RESPONSE_MEDIA_TYPE, + ) + + +def _drain(sub_iter: Iterator[bytes]) -> bytes: + chunks: list[bytes] = [] + try: + for chunk in sub_iter: + if chunk: + chunks.append(chunk) + finally: + close = getattr(sub_iter, "close", None) + if callable(close): + close() + return b"".join(chunks) + + def _extract_usage_headers(body_bytes: bytes) -> dict[str, str]: """Pull token-usage + model name out of a chat-completion response and project them onto outer HTTP headers for the relay's billing pipeline. @@ -210,15 +276,16 @@ def _wsgi_subrequest( path: str, body_bytes: bytes, payment_header: str | None, -) -> tuple[int, list[tuple[str, str]], bytes]: +) -> tuple[int, list[tuple[str, str]], Iterator[bytes]]: """Issue an in-process WSGI request through the app's full middleware stack. - Returns ``(status_code, headers, body_bytes)``. We invoke - ``current_app.wsgi_app`` directly so the x402 payment middleware (which - wraps ``wsgi_app`` at injection time) runs the same way it would for an - external HTTP request to the same path — including the pre-inference - pricing gate, payment verification, cost settlement and TEE response - signing. + Returns ``(status_code, headers, body_iterator)``. The caller is + responsible for draining and closing the iterator (close() triggers + x402's post-response settlement). We invoke ``current_app.wsgi_app`` + directly so the x402 payment middleware (which wraps ``wsgi_app`` at + injection time) runs the same way it would for an external HTTP + request to the same path — including the pre-inference pricing gate, + payment verification, cost settlement and TEE response signing. """ outer_env = flask_request.environ sub_env: dict[str, Any] = { @@ -251,19 +318,10 @@ def _start_response(status: str, headers: list, exc_info: Any = None): return lambda _chunk: None iterator = current_app.wsgi_app(sub_env, _start_response) - body_chunks: list[bytes] = [] - try: - for chunk in iterator: - if chunk: - body_chunks.append(chunk) - finally: - close = getattr(iterator, "close", None) - if callable(close): - # Triggers x402's post-response settlement (StreamingSessionResponse.close). - close() - status_code = int(captured["status"].split(" ", 1)[0]) - return status_code, captured["headers"], b"".join(body_chunks) + # Don't wrap in iter() — that would strip the iterable's close() method, + # which the caller relies on to trigger x402's post-response settlement. + return status_code, captured["headers"], iterator # type: ignore[return-value] def get_hpke_config(): diff --git a/tee_gateway/ohttp.py b/tee_gateway/ohttp.py index ff47c47..93f993d 100644 --- a/tee_gateway/ohttp.py +++ b/tee_gateway/ohttp.py @@ -7,10 +7,18 @@ - KDF: HKDF-SHA256 (0x0001) - AEAD: ChaCha20-Poly1305 (0x0003) -The inner payload is application/json — we do not BHTTP-wrap the inference -request, since the enclave is the terminal endpoint and not a generic HTTP -proxy. This is a documented divergence from strict RFC 9458; the cryptographic -construction (HPKE base + exported response keying) is identical. +Also implements the chunked-response extension from +draft-ietf-ohai-chunked-ohttp-08 for streaming responses (SSE inference): + - Same HPKE context, separate export label "message/bhttp chunked response". + - Wire: response_nonce || (varint(sealed_len) || sealed_ct)+ + || varint(0) || sealed_final_ct + - AEAD AAD: "" for non-final chunks, "final" for the last chunk. + - Per-chunk nonce: aead_nonce XOR encode_be(counter), counter from 0. + +The inner payload is application/json (or text/event-stream for chunked) — +we do not BHTTP-wrap the inference request, since the enclave is the terminal +endpoint and not a generic HTTP proxy. This is a documented divergence from +strict RFC 9458; the cryptographic construction is identical. Trust model: the relay sees ciphertext + client IP; the enclave sees plaintext + relay IP. Unlinkability holds unless relay and enclave collude. @@ -42,9 +50,10 @@ _NK = 32 # key length _NN = 12 # nonce length -# Per RFC 9458 §4.1/4.2 — "info" labels for the HPKE context. +# Per RFC 9458 §4.1/4.5 and draft-ietf-ohai-chunked-ohttp §3.1 — "info" labels. _LABEL_REQUEST = b"message/bhttp request" _LABEL_RESPONSE = b"message/bhttp response" +_LABEL_CHUNKED_RESPONSE = b"message/bhttp chunked response" _SUITE = CipherSuite.new( KEMId.DHKEM_X25519_HKDF_SHA256, @@ -53,6 +62,50 @@ ) +def encode_varint(value: int) -> bytes: + """QUIC variable-length integer encoding (RFC 9000 §16). + + The top two bits of the first byte encode the length (00=1B, 01=2B, + 10=4B, 11=8B); the remaining bits hold the big-endian value. Used by + draft-ietf-ohai-chunked-ohttp to frame each response chunk on the wire. + """ + if value < 0: + raise ValueError("varint must be non-negative") + if value < (1 << 6): + return bytes([value]) + if value < (1 << 14): + return bytes([0x40 | (value >> 8), value & 0xFF]) + if value < (1 << 30): + return struct.pack(">I", 0x80000000 | value) + if value < (1 << 62): + return struct.pack(">Q", 0xC000000000000000 | value) + raise ValueError("varint value exceeds 2^62-1") + + +def decode_varint(buf: bytes, offset: int = 0) -> tuple[int, int]: + """Parse one QUIC varint from ``buf`` starting at ``offset``. Returns + ``(value, new_offset)``. Used by clients and tests.""" + if offset >= len(buf): + raise ValueError("varint truncated") + first = buf[offset] + length_bits = first >> 6 + if length_bits == 0: + return first, offset + 1 + if length_bits == 1: + if offset + 2 > len(buf): + raise ValueError("varint truncated") + return ((first & 0x3F) << 8) | buf[offset + 1], offset + 2 + if length_bits == 2: + if offset + 4 > len(buf): + raise ValueError("varint truncated") + head = bytes([first & 0x3F]) + buf[offset + 1 : offset + 4] + return struct.unpack(">I", head)[0], offset + 4 + if offset + 8 > len(buf): + raise ValueError("varint truncated") + head = bytes([first & 0x3F]) + buf[offset + 1 : offset + 8] + return struct.unpack(">Q", head)[0], offset + 8 + + def _header_bytes(key_id: int = KEY_CONFIG_ID) -> bytes: return bytes([key_id]) + struct.pack( ">HHH", @@ -83,11 +136,18 @@ def key_config(public_key_raw: bytes, key_id: int = KEY_CONFIG_ID) -> bytes: @dataclass class DecapsulatedRequest: - """Result of decapsulating an OHTTP-wrapped request.""" + """Result of decapsulating an OHTTP-wrapped request. + + Two response secrets are exported up-front so the caller can switch + between single-shot (``response_key``) and chunked + (``response_key_chunked``) encapsulation after inspecting the inner + body — the recipient context can only be created once per request. + """ plaintext: bytes - response_key: bytes # 32 bytes exported from the HPKE context - enc: bytes # client's ephemeral public key, used as salt for the response + response_key: bytes # exported with label "message/bhttp response" + response_key_chunked: bytes # exported with label "message/bhttp chunked response" + enc: bytes # client's ephemeral public key, used as salt for response keying def decapsulate_request( @@ -119,35 +179,89 @@ def decapsulate_request( recipient = _SUITE.create_recipient_context(enc, private_key, info=info) plaintext = recipient.open(aead_ct, aad=b"") - # Export a fresh secret bound to this HPKE context, used to derive the - # response AEAD key. RFC 9458 §4.5 specifies export length max(Nn, Nk). - response_secret = recipient.export(_LABEL_RESPONSE, max(_NN, _NK)) + # Two exports, one per response mode. RFC 9458 §4.5 and the chunked + # draft §3.1 specify max(Nn, Nk) as the export length. + export_len = max(_NN, _NK) + response_secret = recipient.export(_LABEL_RESPONSE, export_len) + response_secret_chunked = recipient.export(_LABEL_CHUNKED_RESPONSE, export_len) return DecapsulatedRequest( - plaintext=plaintext, response_key=response_secret, enc=enc + plaintext=plaintext, + response_key=response_secret, + response_key_chunked=response_secret_chunked, + enc=enc, ) def encapsulate_response(response_secret: bytes, enc: bytes, plaintext: bytes) -> bytes: - """Seal a response under the per-request derived key (RFC 9458 §4.2). + """Seal a response under the per-request derived key (RFC 9458 §4.5). Wire format: response_nonce(max(Nn, Nk)=Nk=32) || AEAD ciphertext """ response_nonce = os.urandom(max(_NN, _NK)) - salt = enc + response_nonce + aead_key, aead_nonce = _derive_response_keys(response_secret, enc, response_nonce) + ct = ChaCha20Poly1305(aead_key).encrypt(aead_nonce, plaintext, b"") + return response_nonce + ct + - h = hmac.HMAC(salt, hashes.SHA256()) +def _derive_response_keys( + response_secret: bytes, enc: bytes, response_nonce: bytes +) -> tuple[bytes, bytes]: + """HKDF-Extract(enc || response_nonce, response_secret) then Expand for + ``aead_key`` (Nk bytes, info=b"key") and ``aead_nonce`` (Nn bytes, + info=b"nonce"). Shared by single-shot and chunked response paths.""" + h = hmac.HMAC(enc + response_nonce, hashes.SHA256()) h.update(response_secret) prk = h.finalize() - aead_key = HKDFExpand(algorithm=hashes.SHA256(), length=_NK, info=b"key").derive( prk ) aead_nonce = HKDFExpand( algorithm=hashes.SHA256(), length=_NN, info=b"nonce" ).derive(prk) + return aead_key, aead_nonce - ct = ChaCha20Poly1305(aead_key).encrypt(aead_nonce, plaintext, b"") - return response_nonce + ct + +class ChunkedResponseEncrypter: + """Stream a chunked OHTTP response per draft-ietf-ohai-chunked-ohttp-08. + + Usage: + enc = ChunkedResponseEncrypter(decap.response_key_chunked, decap.enc) + yield enc.header() # response_nonce + for plaintext in non_final_chunks: + yield enc.encrypt_chunk(plaintext, is_final=False) + yield enc.encrypt_chunk(last_plaintext, is_final=True) + + The final chunk uses AAD=b"final" with a zero-length varint prefix — + that pair is what prevents an attacker from truncating the stream + undetectably, so callers MUST always emit exactly one is_final=True + chunk to terminate the response (even if its plaintext is empty). + """ + + def __init__(self, response_secret: bytes, enc: bytes): + self._response_nonce = os.urandom(max(_NN, _NK)) + self._aead_key, self._aead_nonce = _derive_response_keys( + response_secret, enc, self._response_nonce + ) + self._aead = ChaCha20Poly1305(self._aead_key) + self._counter = 0 + self._finalized = False + + def header(self) -> bytes: + """Wire bytes that prefix the chunk stream.""" + return self._response_nonce + + def encrypt_chunk(self, plaintext: bytes, is_final: bool) -> bytes: + if self._finalized: + raise RuntimeError("ChunkedResponseEncrypter already finalized") + ctr_bytes = self._counter.to_bytes(_NN, "big") + chunk_nonce = bytes(a ^ b for a, b in zip(self._aead_nonce, ctr_bytes)) + aad = b"final" if is_final else b"" + sealed = self._aead.encrypt(chunk_nonce, plaintext, aad) + self._counter += 1 + length_prefix = encode_varint(0) if is_final else encode_varint(len(sealed)) + if is_final: + self._finalized = True + return length_prefix + sealed def generate_keypair() -> tuple[KEMKeyInterface, bytes]: diff --git a/tee_gateway/test/test_ohttp.py b/tee_gateway/test/test_ohttp.py index faa732c..48c256c 100644 --- a/tee_gateway/test/test_ohttp.py +++ b/tee_gateway/test/test_ohttp.py @@ -88,6 +88,114 @@ def test_generate_keypair_is_independent(): assert pk_a != pk_b +def test_varint_round_trip(): + """QUIC varint encoder/decoder matches across all 4 length classes.""" + for value in (0, 63, 64, 16383, 16384, 1073741823, 1073741824, (1 << 62) - 1): + encoded = ohttp.encode_varint(value) + decoded, off = ohttp.decode_varint(encoded) + assert decoded == value, value + assert off == len(encoded) + + with pytest.raises(ValueError): + ohttp.encode_varint(-1) + with pytest.raises(ValueError): + ohttp.encode_varint(1 << 62) + + +def test_chunked_response_round_trip(): + """Client encrypts a chunked response stream; client-side decrypter + must recover every chunk and detect the AAD=b'final' terminator.""" + from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305 + + sk, pk_raw = ohttp.generate_keypair() + import struct as _struct + + hdr = bytes([ohttp.KEY_CONFIG_ID]) + _struct.pack( + ">HHH", + ohttp.KEM_ID_X25519, + ohttp.KDF_ID_HKDF_SHA256, + ohttp.AEAD_ID_CHACHA20_POLY1305, + ) + info = b"message/bhttp request" + b"\x00" + hdr + pkr = ohttp._SUITE.kem.deserialize_public_key(pk_raw) + enc, sender = ohttp._SUITE.create_sender_context(pkr, info=info) + wire = hdr + enc + sender.seal(b'{"stream": true}', aad=b"") + + decap = ohttp.decapsulate_request(sk, wire) + assert decap.response_key_chunked == sender.export( + b"message/bhttp chunked response", 32 + ) + + # Server side: stream three chunks plus an empty final marker. + plaintexts = [b"data: {chunk1}\n\n", b"data: {chunk2}\n\n", b"data: [DONE]\n\n"] + encrypter = ohttp.ChunkedResponseEncrypter(decap.response_key_chunked, decap.enc) + wire_chunks = [encrypter.header()] + for pt in plaintexts[:-1]: + wire_chunks.append(encrypter.encrypt_chunk(pt, is_final=False)) + wire_chunks.append(encrypter.encrypt_chunk(plaintexts[-1], is_final=True)) + stream = b"".join(wire_chunks) + + # Client side: re-derive keys, walk the varint-framed chunks. + response_nonce = stream[:32] + off = 32 + response_secret = sender.export(b"message/bhttp chunked response", 32) + aead_key, aead_nonce = ohttp._derive_response_keys( + response_secret, enc, response_nonce + ) + aead = ChaCha20Poly1305(aead_key) + + recovered: list[bytes] = [] + counter = 0 + while off < len(stream): + length, off = ohttp.decode_varint(stream, off) + is_final = length == 0 + # On the final chunk the prefix is zero; the actual sealed length is + # plaintext_len + 16 (Poly1305 tag). The chunk consumes the rest. + seg_len = len(stream) - off if is_final else length + ct = stream[off : off + seg_len] + off += seg_len + chunk_nonce = bytes( + a ^ b for a, b in zip(aead_nonce, counter.to_bytes(12, "big")) + ) + aad = b"final" if is_final else b"" + recovered.append(aead.decrypt(chunk_nonce, ct, aad)) + counter += 1 + if is_final: + break + + assert recovered == plaintexts + # The decrypter MUST reject the same stream with the final-AAD swapped + # — protects against undetected truncation at the boundary. + with pytest.raises(Exception): + aead.decrypt( + bytes(a ^ b for a, b in zip(aead_nonce, (counter - 1).to_bytes(12, "big"))), + stream[-len(ct) :], + b"", + ) + + +def test_chunked_encrypter_rejects_double_finalize(): + sk, pk_raw = ohttp.generate_keypair() + import struct as _struct + + hdr = bytes([ohttp.KEY_CONFIG_ID]) + _struct.pack( + ">HHH", + ohttp.KEM_ID_X25519, + ohttp.KDF_ID_HKDF_SHA256, + ohttp.AEAD_ID_CHACHA20_POLY1305, + ) + info = b"message/bhttp request" + b"\x00" + hdr + pkr = ohttp._SUITE.kem.deserialize_public_key(pk_raw) + enc, sender = ohttp._SUITE.create_sender_context(pkr, info=info) + decap = ohttp.decapsulate_request(sk, hdr + enc + sender.seal(b"hi", aad=b"")) + + encrypter = ohttp.ChunkedResponseEncrypter(decap.response_key_chunked, decap.enc) + encrypter.header() + encrypter.encrypt_chunk(b"only", is_final=True) + with pytest.raises(RuntimeError, match="already finalized"): + encrypter.encrypt_chunk(b"extra", is_final=False) + + def test_rejects_tampered_ciphertext(): sk, pk_raw = ohttp.generate_keypair() import struct From b34baa7257d9b0546aeb1bf7405f931473999062 Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 16 May 2026 13:18:34 +0000 Subject: [PATCH 08/39] README: document /v1/ohttp anonymous inference + chunked streaming Adds the two OHTTP endpoints to the API table and a concise section covering the relay-pays flow, the single-shot vs chunked response modes, billing channel for each mode, and the relay/enclave/client trust split. Refs RFC 9458 and draft-ietf-ohai-chunked-ohttp-08. --- README.md | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/README.md b/README.md index 95148be..a00fd4a 100644 --- a/README.md +++ b/README.md @@ -121,6 +121,8 @@ The `measurements.txt` checked into this repository reflects the OpenGradient-op | `/signing-key` | GET | TEE public key (PEM format) and tee_id | | `/v1/completions` | POST | Text completion (signed) | | `/v1/chat/completions` | POST | Chat completion (signed) | +| `/v1/ohttp` | POST | Anonymous chat completion (OHTTP-encapsulated, relay-paid) | +| `/v1/ohttp/config` | GET | HPKE key configuration (RFC 9458) for OHTTP clients | ### Request Format @@ -170,6 +172,34 @@ The `tee_*` fields provide cryptographic proof of the response: - **`tee_timestamp`** — Unix timestamp when the response was signed (proves freshness) - **`tee_id`** — keccak256 of the enclave's DER-encoded public key (stable identifier for this enclave instance) +## Anonymous Inference (Oblivious HTTP) + +`/v1/ohttp` is a thin wrapper around `/v1/chat/completions` that adds **client unlinkability** via [RFC 9458 OHTTP](https://www.rfc-editor.org/rfc/rfc9458) + [draft-ietf-ohai-chunked-ohttp-08](https://datatracker.ietf.org/doc/draft-ietf-ohai-chunked-ohttp/). HPKE ciphersuite is fixed: DHKEM(X25519,HKDF-SHA256) / HKDF-SHA256 / ChaCha20-Poly1305. + +**Flow:** + +1. Client fetches `/v1/ohttp/config` (HPKE pubkey, key_id, suite IDs) and verifies it against the Nitro attestation. +2. Client HPKE-encapsulates a normal chat-completion JSON body and POSTs the ciphertext to a **relay**. The client carries no payment material. +3. Relay forwards the ciphertext to `/v1/ohttp` and attaches its own `X-Payment: ` header. +4. Enclave decrypts → re-issues the request internally to `/v1/chat/completions` with the relay's `X-Payment` header → x402 verifies and settles → response is sealed back to the client. + +**Two response modes** (chosen by the inner `stream` flag): + +| Mode | Outer content-type | Body | Usage to relay | +|------|---|---|---| +| `stream=false` | `message/ohttp-res` | Single-shot sealed body (RFC 9458 §4.5) | `X-Usage-{Prompt,Completion,Total}-Tokens`, `X-Usage-Model` headers | +| `stream=true` | `message/ohttp-chunked-res` | `response_nonce \|\| (varint(len) \|\| sealed_ct)+ \|\| varint(0) \|\| sealed_final_ct` — one OHTTP chunk per SSE event, AAD=`b"final"` on the last chunk (chunked-ohttp draft §3) | Inside the encrypted stream's final SSE event; relay bills via `X-Upto-Session` | + +On non-2xx (e.g. 402 payment required) the body is forwarded plaintext so the relay can read x402 payment requirements and retry — those bodies never contain prompts or completions. + +**Trust split:** + +- **Relay** sees ciphertext + token counts + settlement metadata + its own wallet. Never sees prompts, completions, or the client's IP. +- **Enclave** sees plaintext + relay's IP. Never sees the client's IP. +- **Client** decrypts and verifies the TEE signature embedded in the response body against the attested public key. + +Unlinkability holds unless relay and enclave collude. Streaming leaks per-chunk timing and length — clients who can't accept that signal should use `stream=false`. + ## Verification ### 1. Verify Attestation From 9fc7134c49034fc74461cf8046d5f3eb74f8543a Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 16 May 2026 13:41:23 +0000 Subject: [PATCH 09/39] Add scripts/test_ohttp.py local smoke-test client Mirrors scripts/test_bytedance.py but exercises /v1/ohttp end-to-end: fetches /v1/ohttp/config, cross-checks the HPKE pubkey against the /signing-key attestation document, HPKE-encapsulates a chat request, POSTs to /v1/ohttp, and decrypts the response. Supports both single- shot and chunked OHTTP (--stream); the chunked path decrypts the varint-framed sealed stream incrementally so you can see SSE events arrive in real time. Includes a hand-rolled QUIC varint reader so the script stays usable as a standalone client SDK reference. Usage examples in the module docstring. --- scripts/test_ohttp.py | 335 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 335 insertions(+) create mode 100644 scripts/test_ohttp.py diff --git a/scripts/test_ohttp.py b/scripts/test_ohttp.py new file mode 100644 index 0000000..a4896d2 --- /dev/null +++ b/scripts/test_ohttp.py @@ -0,0 +1,335 @@ +""" +Smoke test: exercise the OHTTP anonymous-inference endpoints end-to-end against +a running gateway. Mirrors test_bytedance.py — point it at a local server and +verify a request round-trips through HPKE encap/decap + the chat backend. + +Usage: + # Non-streaming round-trip (default): + uv run python scripts/test_ohttp.py + uv run python scripts/test_ohttp.py --model gpt-4.1 --prompt "what model are you?" + + # Streaming via chunked OHTTP: + uv run python scripts/test_ohttp.py --stream --model claude-haiku-4-5 + + # Against a remote gateway (anything reachable via HTTP/S): + uv run python scripts/test_ohttp.py --url https://my-enclave.example + +The gateway must already have provider keys injected via POST /v1/keys for the +inner chat request to succeed. Run scripts/test_bytedance.py first if unsure — +it shares the same backend. +""" + +from __future__ import annotations + +import argparse +import json +import os +import struct +import sys +from pathlib import Path +from typing import IO, Iterator + +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +import requests # noqa: E402 +from cryptography.hazmat.primitives import hashes, hmac # noqa: E402 +from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305 # noqa: E402 +from cryptography.hazmat.primitives.kdf.hkdf import HKDFExpand # noqa: E402 +from pyhpke import AEADId, CipherSuite, KDFId, KEMId # noqa: E402 + +# Fixed ciphersuite — the gateway only accepts this combination. +_SUITE = CipherSuite.new( + KEMId.DHKEM_X25519_HKDF_SHA256, + KDFId.HKDF_SHA256, + AEADId.CHACHA20_POLY1305, +) +_KEM_ID = 0x0020 +_KDF_ID = 0x0001 +_AEAD_ID = 0x0003 +_NK = 32 +_NN = 12 +_LABEL_REQ = b"message/bhttp request" +_LABEL_RESP = b"message/bhttp response" +_LABEL_RESP_CHUNKED = b"message/bhttp chunked response" + + +# --------------------------------------------------------------------------- +# HPKE encapsulation (client side) +# --------------------------------------------------------------------------- + + +def encapsulate_request( + public_key_raw: bytes, key_id: int, plaintext: bytes +) -> tuple[bytes, object, bytes]: + """Build an OHTTP-encapsulated request. Returns ``(wire, sender, enc)``. + + The sender context is kept by the caller so it can later export the + response secret (single-shot or chunked) and decrypt the reply.""" + hdr = bytes([key_id]) + struct.pack(">HHH", _KEM_ID, _KDF_ID, _AEAD_ID) + info = _LABEL_REQ + b"\x00" + hdr + pkr = _SUITE.kem.deserialize_public_key(public_key_raw) + enc, sender = _SUITE.create_sender_context(pkr, info=info) + ct = sender.seal(plaintext, aad=b"") + return hdr + enc + ct, sender, enc + + +# --------------------------------------------------------------------------- +# Response decryption (single-shot + chunked) +# --------------------------------------------------------------------------- + + +def _derive_response_keys( + response_secret: bytes, enc: bytes, response_nonce: bytes +) -> tuple[bytes, bytes]: + h = hmac.HMAC(enc + response_nonce, hashes.SHA256()) + h.update(response_secret) + prk = h.finalize() + aead_key = HKDFExpand(algorithm=hashes.SHA256(), length=_NK, info=b"key").derive( + prk + ) + aead_nonce = HKDFExpand( + algorithm=hashes.SHA256(), length=_NN, info=b"nonce" + ).derive(prk) + return aead_key, aead_nonce + + +def decrypt_single_shot(sealed: bytes, sender, enc: bytes) -> bytes: + response_secret = sender.export(_LABEL_RESP, max(_NN, _NK)) + nonce_len = max(_NN, _NK) + response_nonce = sealed[:nonce_len] + aead_ct = sealed[nonce_len:] + aead_key, aead_nonce = _derive_response_keys(response_secret, enc, response_nonce) + return ChaCha20Poly1305(aead_key).decrypt(aead_nonce, aead_ct, b"") + + +def _read_varint(stream: IO[bytes]) -> int | None: + """Read one QUIC varint from a byte stream. Returns None at clean EOF.""" + first = stream.read(1) + if not first: + return None + b = first[0] + nbytes = 1 << (b >> 6) # 1, 2, 4, or 8 + rest = stream.read(nbytes - 1) if nbytes > 1 else b"" + if len(rest) != nbytes - 1: + raise ValueError("truncated varint") + head = bytes([b & 0x3F]) + rest + if nbytes == 1: + return b & 0x3F + if nbytes == 2: + return (head[0] << 8) | head[1] + if nbytes == 4: + return struct.unpack(">I", head)[0] + return struct.unpack(">Q", head)[0] + + +def decrypt_chunked_stream( + raw_stream: IO[bytes], sender, enc: bytes +) -> Iterator[bytes]: + """Decrypt a chunked OHTTP response (draft-ietf-ohai-chunked-ohttp-08) + incrementally, yielding plaintext as each sealed chunk arrives. + + Wire: response_nonce || (varint(len) || ct)+ || varint(0) || final_ct + Per-chunk nonce = aead_nonce XOR encode_be(counter), AAD=""/b"final". + """ + response_secret = sender.export(_LABEL_RESP_CHUNKED, max(_NN, _NK)) + response_nonce = raw_stream.read(max(_NN, _NK)) + if len(response_nonce) != max(_NN, _NK): + raise ValueError("truncated response_nonce") + aead_key, aead_nonce = _derive_response_keys(response_secret, enc, response_nonce) + aead = ChaCha20Poly1305(aead_key) + counter = 0 + + while True: + length = _read_varint(raw_stream) + if length is None: + raise ValueError("stream ended without AAD=final chunk") + + chunk_nonce = bytes( + a ^ b for a, b in zip(aead_nonce, counter.to_bytes(_NN, "big")) + ) + counter += 1 + + if length == 0: + final_ct = raw_stream.read() # rest of stream + yield aead.decrypt(chunk_nonce, final_ct, b"final") + return + + ct = raw_stream.read(length) + if len(ct) != length: + raise ValueError("truncated chunk ciphertext") + yield aead.decrypt(chunk_nonce, ct, b"") + + +# --------------------------------------------------------------------------- +# Driver +# --------------------------------------------------------------------------- + + +def fetch_config(base_url: str) -> tuple[bytes, int]: + """GET /v1/ohttp/config; return (public_key_raw_bytes, key_id).""" + r = requests.get(f"{base_url}/v1/ohttp/config", timeout=10) + r.raise_for_status() + cfg = r.json() + print( + f"HPKE config: key_id={cfg['key_id']} suite={cfg['kem_id']}/{cfg['kdf_id']}/{cfg['aead_id']}" + ) + print(f" public_key = {cfg['public_key']}") + print(f" key_config = {cfg['key_config']}") + pk_raw = bytes.fromhex(cfg["public_key"]) + if len(pk_raw) != 32: + raise ValueError(f"expected 32-byte X25519 pubkey, got {len(pk_raw)}") + return pk_raw, cfg["key_id"] + + +def verify_attestation_binding(base_url: str, hpke_pubkey_hex: str) -> None: + """Cross-check that the HPKE pubkey we just got matches what the + attestation document at /signing-key reports. A network attacker + between us and the enclave could otherwise swap in their own + pubkey and decrypt our prompt.""" + r = requests.get(f"{base_url}/signing-key", timeout=10) + r.raise_for_status() + doc = r.json() + attested = (doc.get("hpke") or {}).get("public_key") + if attested is None: + print( + "WARN: /signing-key did not include hpke.public_key; skipping binding check" + ) + return + if attested.lower() != hpke_pubkey_hex.lower(): + raise ValueError( + f"HPKE pubkey mismatch! config={hpke_pubkey_hex} attestation={attested}" + ) + print("HPKE pubkey matches the attestation document") + + +def run_non_streaming( + base_url: str, model: str, prompt: str, payment: str | None +) -> int: + pk_raw, key_id = fetch_config(base_url) + verify_attestation_binding(base_url, pk_raw.hex()) + + inner = { + "model": model, + "messages": [{"role": "user", "content": prompt}], + "max_tokens": 200, + } + inner_bytes = json.dumps(inner, separators=(",", ":")).encode("utf-8") + wire, sender, enc = encapsulate_request(pk_raw, key_id, inner_bytes) + + headers = {"Content-Type": "message/ohttp-req"} + if payment: + headers["X-Payment"] = payment + + print(f"\nPOST {base_url}/v1/ohttp ({len(wire)} encapsulated bytes)") + r = requests.post(f"{base_url}/v1/ohttp", data=wire, headers=headers, timeout=120) + print(f"HTTP {r.status_code} content-type={r.headers.get('content-type')}") + + if r.status_code >= 400 or "ohttp-res" not in (r.headers.get("content-type") or ""): + # Plaintext error pass-through (e.g. 402 with x402 payment requirements). + print("---- plaintext error body ----") + try: + print(json.dumps(r.json(), indent=2)) + except ValueError: + print(r.text) + return 1 + + print("---- forwarded headers ----") + for k, v in r.headers.items(): + if k.lower().startswith( + ("x-payment", "x-upto", "x-settlement", "x-tee", "x-usage") + ): + print(f" {k}: {v}") + + plaintext = decrypt_single_shot(r.content, sender, enc) + print("---- decrypted response ----") + try: + parsed = json.loads(plaintext) + choices = parsed.get("choices") or [] + if choices: + print(choices[0].get("message", {}).get("content", "")) + print("---- usage ----") + print(json.dumps(parsed.get("usage"), indent=2)) + for field in ("tee_signature", "tee_request_hash", "tee_output_hash"): + if field in parsed: + print(f" {field}: {parsed[field][:40]}...") + except json.JSONDecodeError: + print(plaintext.decode("utf-8", errors="replace")) + return 0 + + +def run_streaming(base_url: str, model: str, prompt: str, payment: str | None) -> int: + pk_raw, key_id = fetch_config(base_url) + verify_attestation_binding(base_url, pk_raw.hex()) + + inner = { + "model": model, + "messages": [{"role": "user", "content": prompt}], + "max_tokens": 200, + "stream": True, + } + inner_bytes = json.dumps(inner, separators=(",", ":")).encode("utf-8") + wire, sender, enc = encapsulate_request(pk_raw, key_id, inner_bytes) + + headers = {"Content-Type": "message/ohttp-req"} + if payment: + headers["X-Payment"] = payment + + print(f"\nPOST {base_url}/v1/ohttp (stream=true, {len(wire)} encapsulated bytes)") + r = requests.post( + f"{base_url}/v1/ohttp", + data=wire, + headers=headers, + timeout=120, + stream=True, + ) + print(f"HTTP {r.status_code} content-type={r.headers.get('content-type')}") + + ct = r.headers.get("content-type", "") + if r.status_code >= 400 or "chunked-res" not in ct: + print("---- non-streaming response body ----") + print(r.text) + return 1 + + print("---- decrypted SSE events ----") + # urllib3's HTTPResponse — supports .read(n) without re-decoding chunked transfer. + for plaintext in decrypt_chunked_stream(r.raw, sender, enc): + text = plaintext.decode("utf-8", errors="replace") + sys.stdout.write(text) + sys.stdout.flush() + print("\n---- end of stream ----") + return 0 + + +def main() -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--url", default=os.environ.get("OHTTP_URL", "http://127.0.0.1:8000") + ) + parser.add_argument("--model", default="gpt-4.1") + parser.add_argument( + "--prompt", default="What model are you? Reply in one short sentence." + ) + parser.add_argument( + "--stream", action="store_true", help="Use chunked OHTTP streaming" + ) + parser.add_argument( + "--payment", + default=os.environ.get("X_PAYMENT"), + help="Optional x402 payment payload (base64). Send via the outer X-Payment header.", + ) + args = parser.parse_args() + + try: + if args.stream: + return run_streaming(args.url, args.model, args.prompt, args.payment) + return run_non_streaming(args.url, args.model, args.prompt, args.payment) + except requests.RequestException as exc: + print(f"\nERROR: HTTP failure — {exc}", file=sys.stderr) + return 2 + except Exception as exc: + print(f"\nERROR: {type(exc).__name__}: {exc}", file=sys.stderr) + return 2 + + +if __name__ == "__main__": + raise SystemExit(main()) From fea4d013f84ef85d231799eb1ef6c79aa12e585c Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 16 May 2026 13:46:28 +0000 Subject: [PATCH 10/39] Forward Authorization header on OHTTP sub-request The OpenAPI spec declares a global ApiKeyAuth requirement; connexion enforces it on /v1/chat/completions before any handler runs and returns 401 "No authorization token provided" when missing. Our WSGI sub-request from /v1/ohttp arrived without an Authorization header, so OHTTP requests bounced with 401 before reaching the chat backend. security_controller.info_from_ApiKeyAuth is an intentional passthrough (x402 is the real access control) so any token value satisfies the schema check. Forward the outer Authorization header to the sub-request when the relay supplied one, else inject a placeholder bearer token. --- tee_gateway/controllers/ohttp_controller.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tee_gateway/controllers/ohttp_controller.py b/tee_gateway/controllers/ohttp_controller.py index 97eb296..f812fd3 100644 --- a/tee_gateway/controllers/ohttp_controller.py +++ b/tee_gateway/controllers/ohttp_controller.py @@ -310,6 +310,15 @@ def _wsgi_subrequest( if payment_header: sub_env["HTTP_X_PAYMENT"] = payment_header + # The OpenAPI spec declares a global ApiKeyAuth requirement and connexion + # enforces it before our handler runs (returns 401 "No authorization + # token provided"). The security function (security_controller.py) is an + # intentional passthrough — x402 is the real access control — so any + # value satisfies the schema check. Forward the outer header if the + # relay supplied one, otherwise inject a placeholder. + outer_auth = flask_request.headers.get("Authorization") + sub_env["HTTP_AUTHORIZATION"] = outer_auth or "Bearer ohttp-relay" + captured: dict[str, Any] = {"status": "500 Internal Server Error", "headers": []} def _start_response(status: str, headers: list, exc_info: Any = None): From b12e70f9bee5716e7f1526f75cd74e61388e0391 Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 16 May 2026 13:47:31 +0000 Subject: [PATCH 11/39] OHTTP: use a fixed dummy Authorization on the inner sub-request MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Don't forward the outer Authorization header to the chat sub-request — anything the relay attached there (API keys, JWT subjects, bearer tokens, ...) could re-identify the client and defeat unlinkability. A constant "Bearer ohttp" placeholder satisfies connexion's ApiKeyAuth schema check (security_controller is a passthrough; x402 is the real access control) and keeps every OHTTP request indistinguishable at this layer. --- tee_gateway/controllers/ohttp_controller.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tee_gateway/controllers/ohttp_controller.py b/tee_gateway/controllers/ohttp_controller.py index f812fd3..101d34b 100644 --- a/tee_gateway/controllers/ohttp_controller.py +++ b/tee_gateway/controllers/ohttp_controller.py @@ -314,10 +314,12 @@ def _wsgi_subrequest( # enforces it before our handler runs (returns 401 "No authorization # token provided"). The security function (security_controller.py) is an # intentional passthrough — x402 is the real access control — so any - # value satisfies the schema check. Forward the outer header if the - # relay supplied one, otherwise inject a placeholder. - outer_auth = flask_request.headers.get("Authorization") - sub_env["HTTP_AUTHORIZATION"] = outer_auth or "Bearer ohttp-relay" + # value satisfies the schema check. We deliberately do NOT forward the + # outer Authorization header: anything the relay attached there could + # re-identify the client (API keys, JWT subjects, bearer tokens) and + # defeat the whole point of OHTTP. A fixed constant keeps every OHTTP + # request indistinguishable to the chat backend at this layer. + sub_env["HTTP_AUTHORIZATION"] = "Bearer ohttp" captured: dict[str, Any] = {"status": "500 Internal Server Error", "headers": []} From 58908aaf7219a82eab51a19886a2afcbf8269602 Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 16 May 2026 13:50:40 +0000 Subject: [PATCH 12/39] Add TEE_GATEWAY_DEV_SKIP_X402 dev escape hatch MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Set the env var to "1" before /v1/keys is POSTed and the gateway will skip attaching the x402 payment middleware. Lets developers smoke-test /v1/chat/completions and /v1/ohttp locally without a reachable facilitator URL — without it, the middleware's first-request initialize() blows up on facilitator DNS lookups. Logs a WARNING when active and is explicitly NOT for production use. --- tee_gateway/__main__.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/tee_gateway/__main__.py b/tee_gateway/__main__.py index ad087af..df0a772 100644 --- a/tee_gateway/__main__.py +++ b/tee_gateway/__main__.py @@ -369,7 +369,18 @@ def _set(val: str | None) -> str: except Exception as e: logger.warning(f"Heartbeat initialization failed: {e}") - _init_payment_middleware(facilitator_url) + # Local-dev escape hatch: set TEE_GATEWAY_DEV_SKIP_X402=1 to run the + # chat backend without attaching the x402 payment middleware. Lets + # you smoke-test /v1/chat/completions and /v1/ohttp without a + # reachable facilitator. NEVER set in production — this disables + # all payment-based access control. + if os.getenv("TEE_GATEWAY_DEV_SKIP_X402") == "1": + logger.warning( + "TEE_GATEWAY_DEV_SKIP_X402=1 — skipping x402 payment middleware. " + "Do NOT use this in production." + ) + else: + _init_payment_middleware(facilitator_url) _active_facilitator_url = facilitator_url _keys_initialized = True From d18be753c438e718158fa8b94c33a09fd82cf37a Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 16 May 2026 14:08:00 +0000 Subject: [PATCH 13/39] test_ohttp.py: dump the outgoing OHTTP request so you can eyeball it MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Prints the request line, headers, the inner plaintext (clearly labeled as never-on-the-wire), then a breakdown of the encapsulated body: the 7-byte OHTTP header, the 32-byte ephemeral X25519 enc, and an xxd-style hex dump of the AEAD ciphertext. Makes it visually obvious that the relay only sees opaque sealed bytes — no prompt content, no model name, no API key, nothing. --- scripts/test_ohttp.py | 65 +++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 63 insertions(+), 2 deletions(-) diff --git a/scripts/test_ohttp.py b/scripts/test_ohttp.py index a4896d2..d5fe97b 100644 --- a/scripts/test_ohttp.py +++ b/scripts/test_ohttp.py @@ -160,6 +160,67 @@ def decrypt_chunked_stream( yield aead.decrypt(chunk_nonce, ct, b"") +# --------------------------------------------------------------------------- +# Wire-payload dump (for visual confirmation that nothing plaintext leaves) +# --------------------------------------------------------------------------- + + +def _hexdump(data: bytes, max_bytes: int = 96) -> str: + """xxd-style hex dump, capped at ``max_bytes`` so output stays readable.""" + out: list[str] = [] + for i in range(0, min(len(data), max_bytes), 16): + row = data[i : i + 16] + hex_part = " ".join(f"{b:02x}" for b in row) + ascii_part = "".join(chr(b) if 32 <= b < 127 else "." for b in row) + out.append(f" {i:04x} {hex_part:<48} |{ascii_part}|") + if len(data) > max_bytes: + out.append(f" ... ({len(data) - max_bytes} more bytes of ciphertext)") + return "\n".join(out) + + +def dump_outgoing( + url: str, + headers: dict[str, str], + inner_plaintext: bytes, + wire: bytes, + enc: bytes, + key_id: int, +) -> None: + """Print everything that's about to go on the wire so you can eyeball + that the relay only sees opaque ciphertext.""" + print("\n================ OUTGOING REQUEST ================") + print(f"POST {url}") + for name, value in headers.items(): + print(f" {name}: {value}") + print( + f"\n inner plaintext ({len(inner_plaintext)} bytes, " + f"NEVER goes on the wire — sealed under HPKE):" + ) + try: + print( + " " + + json.dumps(json.loads(inner_plaintext), indent=4).replace("\n", "\n ") + ) + except json.JSONDecodeError: + print(f" {inner_plaintext!r}") + + # The wire body decomposes into: + # header (7 bytes): key_id || kem_id || kdf_id || aead_id + # enc (32 bytes): client's ephemeral X25519 public key + # ct (rest): AEAD ciphertext + 16-byte tag — opaque to the relay + print(f"\n encapsulated body ({len(wire)} bytes total):") + print( + f" OHTTP header = {wire[:7].hex()} " + f"(key_id=0x{key_id:02x}, suite=0x0020/0x0001/0x0003)" + ) + print(f" enc (ephemeral)= {enc.hex()}") + print( + f" ciphertext = {len(wire) - 7 - 32} bytes (HPKE-sealed, no plaintext leaks):" + ) + print(_hexdump(wire[7 + 32 :])) + print("==================================================\n") + + # --------------------------------------------------------------------------- # Driver # --------------------------------------------------------------------------- @@ -220,7 +281,7 @@ def run_non_streaming( if payment: headers["X-Payment"] = payment - print(f"\nPOST {base_url}/v1/ohttp ({len(wire)} encapsulated bytes)") + dump_outgoing(f"{base_url}/v1/ohttp", headers, inner_bytes, wire, enc, key_id) r = requests.post(f"{base_url}/v1/ohttp", data=wire, headers=headers, timeout=120) print(f"HTTP {r.status_code} content-type={r.headers.get('content-type')}") @@ -274,7 +335,7 @@ def run_streaming(base_url: str, model: str, prompt: str, payment: str | None) - if payment: headers["X-Payment"] = payment - print(f"\nPOST {base_url}/v1/ohttp (stream=true, {len(wire)} encapsulated bytes)") + dump_outgoing(f"{base_url}/v1/ohttp", headers, inner_bytes, wire, enc, key_id) r = requests.post( f"{base_url}/v1/ohttp", data=wire, From 366d2f51557cf608b5d6acf85ef3cd579f22ca84 Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 16 May 2026 14:09:33 +0000 Subject: [PATCH 14/39] Revert "Add TEE_GATEWAY_DEV_SKIP_X402 dev escape hatch" This reverts commit 58908aaf7219a82eab51a19886a2afcbf8269602. --- tee_gateway/__main__.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/tee_gateway/__main__.py b/tee_gateway/__main__.py index df0a772..ad087af 100644 --- a/tee_gateway/__main__.py +++ b/tee_gateway/__main__.py @@ -369,18 +369,7 @@ def _set(val: str | None) -> str: except Exception as e: logger.warning(f"Heartbeat initialization failed: {e}") - # Local-dev escape hatch: set TEE_GATEWAY_DEV_SKIP_X402=1 to run the - # chat backend without attaching the x402 payment middleware. Lets - # you smoke-test /v1/chat/completions and /v1/ohttp without a - # reachable facilitator. NEVER set in production — this disables - # all payment-based access control. - if os.getenv("TEE_GATEWAY_DEV_SKIP_X402") == "1": - logger.warning( - "TEE_GATEWAY_DEV_SKIP_X402=1 — skipping x402 payment middleware. " - "Do NOT use this in production." - ) - else: - _init_payment_middleware(facilitator_url) + _init_payment_middleware(facilitator_url) _active_facilitator_url = facilitator_url _keys_initialized = True From 87eec3ceccbb78339f3b0980df167021a136d6e2 Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 16 May 2026 14:36:34 +0000 Subject: [PATCH 15/39] tee_manager: refuse to register if HPKE pubkey is missing The v2 attestation transcript labels both the RSA SPKI and the X25519 HPKE pubkey, but the previous (self.hpke_public_key_raw or b"") fallback would silently produce a "v2"-labeled digest that actually only covers RSA whenever hpke_public_key_raw was None or empty. A verifier trusting the label would then accept an enclave whose HPKE key was never bound to attestation. Add an explicit length check (must be exactly 32 bytes) outside the broad try/except, so a real misconfiguration raises clearly instead of being masked as the "Could not register with nitriding (may not be in TEE)" warning. Today _generate_keys() always sets both keys so this is a defense-in-depth guard against future partial-init regressions. --- tee_gateway/tee_manager.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/tee_gateway/tee_manager.py b/tee_gateway/tee_manager.py index 16ac97c..98deaa6 100644 --- a/tee_gateway/tee_manager.py +++ b/tee_gateway/tee_manager.py @@ -100,6 +100,20 @@ def register_with_nitriding(self): gets binding for the HPKE config used for anonymous inference — no separate trust anchor required. """ + # Defensive check: the v2 transcript labels both keys, so refusing to + # register is the only safe behavior when one is missing. Falling back + # to b"" would produce a digest that's nominally v2 but only covers + # RSA, and a verifier trusting the label would accept an enclave whose + # HPKE key was never attested. Raise outside the broad try/except + # below so a real misconfiguration isn't masked as a non-TEE + # environment. + if not self.hpke_public_key_raw or len(self.hpke_public_key_raw) != 32: + raise RuntimeError( + "Refusing to register with nitriding: HPKE X25519 public key " + "is missing or wrong length; the v2 attestation transcript " + "requires both RSA and HPKE keys." + ) + try: public_key_der = self.public_key.public_bytes( encoding=serialization.Encoding.DER, @@ -112,7 +126,7 @@ def register_with_nitriding(self): b"og-tee-keys|v2|rsa-spki=" + public_key_der + b"|hpke-x25519=" - + (self.hpke_public_key_raw or b"") + + self.hpke_public_key_raw ) key_hash = hashlib.sha256(transcript).digest() key_hash_b64 = base64.b64encode(key_hash).decode("utf-8") From 4dbd35eff94fc97679983a28e0117291c29dd9b1 Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 16 May 2026 14:38:47 +0000 Subject: [PATCH 16/39] ohttp: normalize decap failures to ValueError with a generic message MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit decapsulate_request's docstring promised ValueError on malformed input, but recipient.open() raises pyhpke / cryptography exception types on AEAD tag failure, bad ephemeral keys, etc., so the contract was a lie. The error strings from those libraries can encode oracle information about which specific check failed (tag verification vs. length vs. KDF), which would turn the function into a padding-oracle-style side channel if any caller logged with exc_info=True. * Wrap the crypto path (create_recipient_context + open) and re-raise as ValueError("HPKE decapsulation failed") with `from None` so the underlying exception chain is suppressed entirely. Don't wrap the HKDF exports — those are deterministic and can't fail on valid input. * Bump the minimum input length to 7 + 32 + 16 so truncated inputs hit our own "too short" ValueError instead of whatever pyhpke would raise. * Tighten test_rejects_tampered_ciphertext from pytest.raises(Exception) to pytest.raises(ValueError, match="HPKE decapsulation failed") so the contract is enforced by tests, not just documented. --- tee_gateway/ohttp.py | 30 +++++++++++++++++++++++------- tee_gateway/test/test_ohttp.py | 5 ++++- 2 files changed, 27 insertions(+), 8 deletions(-) diff --git a/tee_gateway/ohttp.py b/tee_gateway/ohttp.py index 93f993d..0f9fa7f 100644 --- a/tee_gateway/ohttp.py +++ b/tee_gateway/ohttp.py @@ -155,11 +155,19 @@ def decapsulate_request( ) -> DecapsulatedRequest: """Decrypt an HPKE-wrapped request inside the enclave. - Raises ValueError on malformed input or unsupported ciphersuite. We - never echo the underlying exception text to clients — it can leak - timing/oracle info. + Raises ``ValueError`` on any malformed or unauthenticated input — + structural errors (wrong length, unsupported ciphersuite) and + cryptographic errors (AEAD tag failure, bad ephemeral key, etc.) are + all surfaced as ``ValueError`` with a generic message. The original + pyhpke/cryptography exception is deliberately NOT chained, because + those error strings can encode oracle information about which check + failed (tag verification vs. KDF vs. length). """ - if len(encapsulated_request) < 7 + 32: + # Header(7) + ephemeral pubkey(32) + AEAD tag(16) is the absolute + # minimum. Reject shorter inputs up front so every short-input failure + # mode is a ValueError rather than whatever pyhpke would raise on a + # truncated open(). + if len(encapsulated_request) < 7 + 32 + 16: raise ValueError("encapsulated request too short") key_id = encapsulated_request[0] @@ -176,11 +184,19 @@ def decapsulate_request( aead_ct = encapsulated_request[7 + 32 :] info = _LABEL_REQUEST + b"\x00" + _header_bytes(key_id) - recipient = _SUITE.create_recipient_context(enc, private_key, info=info) - plaintext = recipient.open(aead_ct, aad=b"") + try: + recipient = _SUITE.create_recipient_context(enc, private_key, info=info) + plaintext = recipient.open(aead_ct, aad=b"") + except Exception: + # Map AEAD / HPKE failures (invalid tag, bad ephemeral pubkey, etc.) + # to ValueError to match the documented contract. ``from None`` + # suppresses the chain so pyhpke / cryptography error strings — which + # can distinguish which check failed — never reach a caller's log. + raise ValueError("HPKE decapsulation failed") from None # Two exports, one per response mode. RFC 9458 §4.5 and the chunked - # draft §3.1 specify max(Nn, Nk) as the export length. + # draft §3.1 specify max(Nn, Nk) as the export length. HKDF-Expand is + # deterministic and can't fail on valid inputs, so we don't wrap it. export_len = max(_NN, _NK) response_secret = recipient.export(_LABEL_RESPONSE, export_len) response_secret_chunked = recipient.export(_LABEL_CHUNKED_RESPONSE, export_len) diff --git a/tee_gateway/test/test_ohttp.py b/tee_gateway/test/test_ohttp.py index 48c256c..fa7c26f 100644 --- a/tee_gateway/test/test_ohttp.py +++ b/tee_gateway/test/test_ohttp.py @@ -212,5 +212,8 @@ def test_rejects_tampered_ciphertext(): ct = sender.seal(b"hello", aad=b"") wire = bytearray(hdr + enc + ct) wire[-1] ^= 0xFF - with pytest.raises(Exception): + # decapsulate_request normalises every crypto failure to ValueError with + # a generic message — pyhpke / cryptography exception types must NOT leak + # through, since their strings encode oracle info about which check failed. + with pytest.raises(ValueError, match="HPKE decapsulation failed"): ohttp.decapsulate_request(sk, bytes(wire)) From 30e873790fb9850fbf76acc649a730d6629af3d7 Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 16 May 2026 14:40:12 +0000 Subject: [PATCH 17/39] ohttp_controller: fix inaccurate privacy claim in docstring MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous wording said the relay "never sees the client's IP", which is wrong — in the relay-pays model the client connects directly to the relay, so the relay necessarily sees the client's IP at the network layer. The actual privacy property is that the ENCLAVE never sees the client's IP (it only sees the relay's), and the relay sees only the encapsulated ciphertext (plus billing metadata it needs), not the prompt or completion. Reword to spell out network position vs. compute position for each party and the precise unlinkability claim (and the collusion caveat). --- tee_gateway/controllers/ohttp_controller.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/tee_gateway/controllers/ohttp_controller.py b/tee_gateway/controllers/ohttp_controller.py index 101d34b..ec763c5 100644 --- a/tee_gateway/controllers/ohttp_controller.py +++ b/tee_gateway/controllers/ohttp_controller.py @@ -31,11 +31,21 @@ headers, but never sees the inner prompt or completion. Privacy properties: - * Relay sees ciphertext + relay-side wallet + (non-stream) token usage + - relay's IP. Never sees prompts, completions, or the client's IP. - * Enclave sees plaintext prompts/completions + relay's IP. Never sees - the client's IP. - * Unlinkability holds unless the relay and the enclave collude. + * Relay (network position): terminates the client's TCP/TLS connection, + so the relay DOES see the client's IP at the network layer — that's + unavoidable. What the relay does NOT see is the request/response + content: it observes only the OHTTP-encapsulated ciphertext, its own + wallet's x-payment material, and (single-shot only) the token-usage + outer headers it needs to bill. + * Enclave (compute position): sees the plaintext prompt and completion + (they are decrypted inside the enclave to run the LLM call), but at + the network layer it only sees the RELAY's IP — never the client's. + That's the unlinkability property: the enclave cannot tie a request's + plaintext back to a specific end user. + * Unlinkability between a specific client identity and a specific + plaintext request holds unless the relay and the enclave collude + (the relay would have to disclose its client-IP log alongside the + enclave's per-request plaintext log). * Streaming leaks per-chunk timing and length (the relay sees the cadence of varint-framed sealed chunks). This is an inherent cost of server-sent events — clients who can't accept that signal should From 2eae9e93ecff7e1863b8239da6cba3c5edc527ac Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 16 May 2026 14:40:50 +0000 Subject: [PATCH 18/39] README: fix the Anonymous Inference trust-split wording Mirror the docstring correction in ohttp_controller.py: the relay does see the client's IP at the network layer (it terminates the TCP/TLS connection). What it doesn't see is request/response content. The unlinkability claim is that the ENCLAVE never sees the client's IP and therefore can't tie a plaintext request to a specific end user. --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index a00fd4a..07a4258 100644 --- a/README.md +++ b/README.md @@ -194,11 +194,11 @@ On non-2xx (e.g. 402 payment required) the body is forwarded plaintext so the re **Trust split:** -- **Relay** sees ciphertext + token counts + settlement metadata + its own wallet. Never sees prompts, completions, or the client's IP. -- **Enclave** sees plaintext + relay's IP. Never sees the client's IP. +- **Relay** terminates the client's TCP/TLS connection, so it does see the client's IP — that's unavoidable. What it doesn't see is content: only OHTTP ciphertext + its own wallet's `x-payment` material + (single-shot only) the token-usage outer headers it needs to bill. +- **Enclave** sees plaintext prompts/completions (it has to run the LLM call) but at the network layer only sees the relay's IP, never the client's. This is the unlinkability claim — the enclave can't tie a plaintext request to a specific end user. - **Client** decrypts and verifies the TEE signature embedded in the response body against the attested public key. -Unlinkability holds unless relay and enclave collude. Streaming leaks per-chunk timing and length — clients who can't accept that signal should use `stream=false`. +Unlinkability between a client identity and a plaintext request holds unless relay and enclave collude (the relay would have to share its client-IP log alongside the enclave's plaintext log). Streaming additionally leaks per-chunk timing and length — clients who can't accept that signal should use `stream=false`. ## Verification From deae24336dbbfc3f786bcb2f962d769e5c1d2794 Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 16 May 2026 14:55:41 +0000 Subject: [PATCH 19/39] docs: clarify how usage stats reach the relay vs. how x402 settles MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previous wording on streaming was wrong / muddled: it implied the relay could extract usage "inside the encrypted stream's final SSE event", which is nonsense — the relay can't decrypt. Rewrite both the controller docstring and the README Anonymous Inference section to state the actual behavior: * Source of truth for billing is x402 settlement against the relay's x-payment under the `upto` scheme. The gateway settles the real cost in both modes. * stream=false: outer response ALSO exposes X-Usage-* headers so the relay can do its own per-token bookkeeping. * stream=true: NO per-token detail in outer headers — they ship before any body chunk, so token counts aren't known yet, and the sealed body is opaque. The relay learns the settled amount by querying the facilitator with X-Upto-Session or via the next X-Payment-Response. Only the client sees per-token detail (from the final SSE event inside the decrypted stream). Drop the "Usage to relay" column from the response-modes table since the billing channel is now described in its own paragraph. --- README.md | 13 +++++++--- tee_gateway/controllers/ohttp_controller.py | 27 +++++++++++++++------ 2 files changed, 29 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 07a4258..741b7e4 100644 --- a/README.md +++ b/README.md @@ -185,10 +185,15 @@ The `tee_*` fields provide cryptographic proof of the response: **Two response modes** (chosen by the inner `stream` flag): -| Mode | Outer content-type | Body | Usage to relay | -|------|---|---|---| -| `stream=false` | `message/ohttp-res` | Single-shot sealed body (RFC 9458 §4.5) | `X-Usage-{Prompt,Completion,Total}-Tokens`, `X-Usage-Model` headers | -| `stream=true` | `message/ohttp-chunked-res` | `response_nonce \|\| (varint(len) \|\| sealed_ct)+ \|\| varint(0) \|\| sealed_final_ct` — one OHTTP chunk per SSE event, AAD=`b"final"` on the last chunk (chunked-ohttp draft §3) | Inside the encrypted stream's final SSE event; relay bills via `X-Upto-Session` | +| Mode | Outer content-type | Body | +|------|---|---| +| `stream=false` | `message/ohttp-res` | Single-shot sealed body (RFC 9458 §4.5) | +| `stream=true` | `message/ohttp-chunked-res` | `response_nonce \|\| (varint(len) \|\| sealed_ct)+ \|\| varint(0) \|\| sealed_final_ct` — one OHTTP chunk per SSE event, AAD=`b"final"` on the last chunk (chunked-ohttp draft §3) | + +**Billing channel for the relay.** Both modes settle the actual cost via x402 against the relay's `X-Payment` (`upto` scheme); the gateway is the source of truth for the amount. + +- `stream=false`: outer response *also* exposes per-token detail — `X-Usage-Prompt-Tokens`, `X-Usage-Completion-Tokens`, `X-Usage-Total-Tokens`, `X-Usage-Model` — for the relay's own bookkeeping. The sealed body carries the same `usage` block for the client. +- `stream=true`: **no** per-token detail in outer headers (they're flushed before any body chunk, so we can't know token counts at header-write time) and the sealed chunks are opaque to the relay. The relay reads the actual settled amount from x402 — either by querying the facilitator with its `X-Upto-Session`, or via `X-Payment-Response` on its next call. The client still sees per-token detail in the final SSE event inside the decrypted stream. On non-2xx (e.g. 402 payment required) the body is forwarded plaintext so the relay can read x402 payment requirements and retry — those bodies never contain prompts or completions. diff --git a/tee_gateway/controllers/ohttp_controller.py b/tee_gateway/controllers/ohttp_controller.py index ec763c5..5bdf84b 100644 --- a/tee_gateway/controllers/ohttp_controller.py +++ b/tee_gateway/controllers/ohttp_controller.py @@ -9,14 +9,26 @@ Two response modes are supported, dispatched by the inner ``stream`` flag: * stream=false → single-shot OHTTP response (RFC 9458 §4.5), - content-type ``message/ohttp-res``. Usage stats surface in outer headers. + content-type ``message/ohttp-res``. * stream=true → chunked OHTTP response (draft-ietf-ohai-chunked-ohttp-08), content-type ``message/ohttp-chunked-res``. Each SSE event from the inner /v1/chat/completions stream becomes one sealed OHTTP chunk; the - final chunk uses AAD=b"final" so truncation is detectable. Usage stats - can't appear in outer headers (sent before body) — clients read them - from the final SSE event inside the decrypted stream; the relay relies - on x402 settlement metadata (X-Upto-Session) for billing. + final chunk uses AAD=b"final" so truncation is detectable. + +Billing channel for the relay (both modes settle the real cost via x402 +against the relay's x-payment; the gateway is the source of truth for the +amount): + * stream=false: outer headers ALSO expose per-token detail — + X-Usage-Prompt-Tokens, X-Usage-Completion-Tokens, X-Usage-Total-Tokens, + X-Usage-Model — for the relay's own bookkeeping. The sealed body + contains the same usage info for the client. + * stream=true: NO per-token detail in outer headers. Outer response + headers are flushed before any body chunk is yielded, so we cannot + know token counts at header-write time; the sealed chunks are opaque + to the relay. The relay learns the actual settled amount from x402: + by querying the facilitator with its X-Upto-Session, or via + X-Payment-Response on its next request. The client still gets + per-token detail from the final SSE event inside the decrypted stream. Trust / payment model: * The CLIENT encrypts only an LLM chat-completion request. It does not see, @@ -27,8 +39,9 @@ request. * The ENCLAVE decrypts the inner payload, forwards the request to its own chat endpoint with the relay's ``x-payment`` header, and returns. The - relay sees status, settlement headers, and (for non-stream) token-usage - headers, but never sees the inner prompt or completion. + relay sees status and settlement headers (and, for non-stream, the + per-token usage headers above), but never sees the inner prompt or + completion. Privacy properties: * Relay (network position): terminates the client's TCP/TLS connection, From de38f89410dead3ec28a7ac1b2febb472b25d24c Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 16 May 2026 15:42:04 +0000 Subject: [PATCH 20/39] ci: add tee_gateway/test/test_ohttp.py to the CI test list MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit .github/workflows/test.yml runs an explicit list of test files; the new OHTTP test module wasn't in it, so the HPKE/varint/chunked-OHTTP cryptography code was passing locally but not exercised in CI on PRs or pushes to main. Add it to the list. Followups to flag separately (out of scope for this PR): two other local-passing unit-test files are similarly excluded from CI — test_chat_controller.py and test_completions_controller.py — so the chat and completions controllers don't get continuous coverage either. --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 273d4c9..1540f00 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -15,7 +15,7 @@ jobs: - uses: actions/checkout@v4 - uses: astral-sh/setup-uv@v5 - name: Run unit tests - run: uv run --group test pytest tee_gateway/test/test_tool_forwarding.py tee_gateway/test/test_tee_core.py tee_gateway/test/test_price_feed.py tests/test_pricing.py -v --import-mode=importlib + run: uv run --group test pytest tee_gateway/test/test_tool_forwarding.py tee_gateway/test/test_tee_core.py tee_gateway/test/test_price_feed.py tee_gateway/test/test_ohttp.py tests/test_pricing.py -v --import-mode=importlib # To also run integration tests (real CoinGecko network calls), add: # env: # RUN_INTEGRATION_TESTS: "1" From 879e953fce60627fcc3c62a3eafd9d335993b219 Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 16 May 2026 15:43:59 +0000 Subject: [PATCH 21/39] ci: discover tee_gateway/test/ as a directory MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Switch the unit-tests step from an explicit per-file list to directory discovery so newly added test modules aren't silently excluded the way test_ohttp (this PR) was — and the way test_chat_controller and test_completions_controller still are on main. test_price_feed_integration.py raises unittest.SkipTest at module import when RUN_INTEGRATION_TESTS isn't set, so directory discovery picks it up but skips cleanly: 188 passed / 3 skipped, same as the previous explicit invocation plus the two never-covered controller test files. tests/test_pricing.py lives outside the package test dir so it stays listed explicitly. --- .github/workflows/test.yml | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 1540f00..1bef00c 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -15,7 +15,12 @@ jobs: - uses: actions/checkout@v4 - uses: astral-sh/setup-uv@v5 - name: Run unit tests - run: uv run --group test pytest tee_gateway/test/test_tool_forwarding.py tee_gateway/test/test_tee_core.py tee_gateway/test/test_price_feed.py tee_gateway/test/test_ohttp.py tests/test_pricing.py -v --import-mode=importlib + # Discover the whole tee_gateway/test/ directory so new test files + # aren't silently excluded from CI — the previous explicit list had + # left test_chat_controller, test_completions_controller, and + # test_ohttp out. test_price_feed_integration.py self-gates on + # RUN_INTEGRATION_TESTS and skips cleanly without it. + run: uv run --group test pytest tee_gateway/test/ tests/test_pricing.py -v --import-mode=importlib # To also run integration tests (real CoinGecko network calls), add: # env: # RUN_INTEGRATION_TESTS: "1" From 2a0c4b5322139f719452311886d0da4ee0a6b2f9 Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 16 May 2026 15:48:36 +0000 Subject: [PATCH 22/39] ohttp_controller: preserve duplicate forwarded headers (no dict collapse) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit WSGI passes response headers as a list of (name, value) tuples precisely because HTTP allows multi-valued headers. Two cases relevant to us: RFC 7230 §3.2.2 (duplicates merge by comma but order matters) and RFC 7235 §4.1 (WWW-Authenticate may legally repeat — one entry per challenge scheme). The previous dict comprehension flattened the list and would silently drop a duplicate if x402 ever emitted multiple payment challenges or any other multi-valued header in our allowlist. Keep `forwarded` as a list of tuples in both _build_outer_response and _build_streaming_response; convert the dict-by-construction _extract_usage_headers values to tuples at merge time. Werkzeug's Response(headers=...) accepts either form and preserves duplicates when given a list. --- tee_gateway/controllers/ohttp_controller.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/tee_gateway/controllers/ohttp_controller.py b/tee_gateway/controllers/ohttp_controller.py index 5bdf84b..2ab441e 100644 --- a/tee_gateway/controllers/ohttp_controller.py +++ b/tee_gateway/controllers/ohttp_controller.py @@ -178,7 +178,14 @@ def _build_outer_response( prompts/completions) and surfaces token usage as outer headers so the relay can bill. Non-2xx bodies (x402 payment requirements, validation errors) are forwarded as plaintext so the relay can act on them.""" - forwarded = {name: value for name, value in headers if _should_forward_header(name)} + # Keep headers as a list of (name, value) tuples — WSGI gives us a list + # specifically because HTTP allows multi-valued headers (RFC 7230 §3.2.2; + # WWW-Authenticate in particular per RFC 7235 §4.1 can repeat, one per + # challenge scheme). Collapsing to a dict would drop duplicates silently + # and could lose an x402 challenge if future versions emit more than one. + forwarded: list[tuple[str, str]] = [ + (name, value) for name, value in headers if _should_forward_header(name) + ] if not (200 <= status < 300): return Response( @@ -188,7 +195,9 @@ def _build_outer_response( content_type=inner_content_type, ) - forwarded.update(_extract_usage_headers(body_bytes)) + # Usage headers come from a JSON parse — single-valued by construction — + # so a dict is fine internally; just project to tuples at merge time. + forwarded.extend(_extract_usage_headers(body_bytes).items()) sealed = ohttp.encapsulate_response(decap.response_key, decap.enc, body_bytes) return Response( sealed, @@ -217,7 +226,11 @@ def _build_streaming_response( (X-Upto-Session header, set up-front). The client reads usage from the final SSE event inside the decrypted stream. """ - forwarded = {name: value for name, value in headers if _should_forward_header(name)} + # See _build_outer_response: keep as a list so duplicate HTTP header + # values (e.g. multiple WWW-Authenticate challenges) survive forwarding. + forwarded: list[tuple[str, str]] = [ + (name, value) for name, value in headers if _should_forward_header(name) + ] def _stream() -> Iterator[bytes]: encrypter = ohttp.ChunkedResponseEncrypter( From ac51a3571c09c818e61ef1981e580ca00e26c912 Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 16 May 2026 16:03:57 +0000 Subject: [PATCH 23/39] ohttp_controller: drop unused OHTTP_MEDIA_TYPE constant The request media type was defined for symmetry with the response constants but never read. Decapsulation itself is the security gate; the unauthenticated Content-Type header gives us nothing to enforce. The response constants (OHTTP_RESPONSE_MEDIA_TYPE, OHTTP_CHUNKED_RESPONSE_MEDIA_TYPE) are still in use. --- tee_gateway/controllers/ohttp_controller.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tee_gateway/controllers/ohttp_controller.py b/tee_gateway/controllers/ohttp_controller.py index 2ab441e..81ea582 100644 --- a/tee_gateway/controllers/ohttp_controller.py +++ b/tee_gateway/controllers/ohttp_controller.py @@ -79,7 +79,6 @@ logger = logging.getLogger(__name__) -OHTTP_MEDIA_TYPE = "message/ohttp-req" OHTTP_RESPONSE_MEDIA_TYPE = "message/ohttp-res" OHTTP_CHUNKED_RESPONSE_MEDIA_TYPE = "message/ohttp-chunked-res" _SSE_CONTENT_TYPE = "text/event-stream" From fe780d01f55c71d78934904e9af671c7cca76e81 Mon Sep 17 00:00:00 2001 From: "balogh.adam@icloud.com" Date: Sat, 16 May 2026 12:50:19 -0400 Subject: [PATCH 24/39] pricing --- tee_gateway/__main__.py | 36 +-- tee_gateway/controllers/chat_controller.py | 10 + .../controllers/completions_controller.py | 10 +- tee_gateway/controllers/ohttp_controller.py | 81 +++--- tee_gateway/definitions.py | 6 +- tee_gateway/encoder.py | 5 + tee_gateway/price_feed/__init__.py | 21 ++ tee_gateway/price_feed/feed.py | 2 +- tee_gateway/pricing.py | 256 ++++++++++++++++++ tee_gateway/test/test_price_feed.py | 141 ++++++---- tee_gateway/util.py | 171 ------------ tests/test_pricing.py | 27 +- 12 files changed, 476 insertions(+), 290 deletions(-) create mode 100644 tee_gateway/pricing.py diff --git a/tee_gateway/__main__.py b/tee_gateway/__main__.py index ad087af..3edc0cf 100644 --- a/tee_gateway/__main__.py +++ b/tee_gateway/__main__.py @@ -41,9 +41,8 @@ from x402.session import SessionStore import x402.http.middleware.flask as x402_flask -from .util import calculate_session_cost from .model_registry import get_model_config -from .price_feed import OPGPriceFeed +from .price_feed import OPGPriceFeed, set_price_feed from .definitions import ( EVM_PAYMENT_ADDRESS, BASE_MAINNET_NETWORK, @@ -121,6 +120,7 @@ def _shutdown_heartbeat(): # --------------------------------------------------------------------------- _price_feed = OPGPriceFeed() _price_feed.start() +set_price_feed(_price_feed) _started_at = time.time() @@ -161,22 +161,22 @@ def _patched_read_body_bytes(environ): def _session_cost_calculator(ctx: dict) -> int: - # Post-inference cost calculation — response already sent to client. - # Predictable failures (unknown price, unknown model) are blocked by the - # pre-inference gate; any exception here indicates a provider-side error - # (e.g. missing usage field in the LLM response). The x402 middleware - # swallows the exception in close(), so the client is not charged. - # Log CRITICAL so provider errors are never silently missed. - try: - return calculate_session_cost(ctx, _price_feed.get_price) - except Exception as exc: - logger.critical( - "Post-inference cost calculation failed (provider error) — " - "client was NOT charged: %s", - exc, - exc_info=True, - ) - raise + # The chat/completions controllers compute cost in-band and embed it on + # the response as a SessionCost model BEFORE returning. We parse it back + # here — single source of truth for cost lives in the controller, not + # split between the controller (which serves clients) and x402's callback + # (which charges them). + # + # If the controller couldn't compute cost (e.g. missing usage from the + # provider), the block is absent — SessionCost.model_validate raises and + # x402's close() swallows it so the client is not charged. The controller + # has already logged CRITICAL in that case. + from .pricing import SessionCost + + response_json = ctx.get("response_json") + if not isinstance(response_json, dict): + raise ValueError("response_json missing or not a dict") + return SessionCost.model_validate(response_json.get("opengradient")).cost_opg # --------------------------------------------------------------------------- diff --git a/tee_gateway/controllers/chat_controller.py b/tee_gateway/controllers/chat_controller.py index e6d9337..2fee235 100644 --- a/tee_gateway/controllers/chat_controller.py +++ b/tee_gateway/controllers/chat_controller.py @@ -30,6 +30,7 @@ convert_messages, extract_usage, ) +from tee_gateway.pricing import compute_session_cost logger = logging.getLogger(__name__) @@ -265,6 +266,9 @@ def _create_non_streaming_response(chat_request: CreateChatCompletionRequest): if usage: openai_response["usage"] = usage + cost = compute_session_cost(request_dict, openai_response) + if cost is not None: + openai_response["opengradient"] = cost # Validate schema (the extra tee_* fields are preserved by returning dict directly) CreateChatCompletionResponse.from_dict(openai_response) @@ -590,6 +594,12 @@ def generate(): "completion_tokens": final_usage.get("output_tokens", 0), "total_tokens": final_usage.get("total_tokens", 0), } + cost = compute_session_cost(request_dict, final_data) + if cost is not None: + # final_data is hand-serialized to SSE via json.dumps below, + # which doesn't go through Flask's JSONEncoder — so do the + # serialization ourselves here. + final_data["opengradient"] = cost.model_dump(mode="json") logger.info( f"Stream completed — usage: {final_data['usage']}, " f"finish: {finish_reason}, " diff --git a/tee_gateway/controllers/completions_controller.py b/tee_gateway/controllers/completions_controller.py index 78b9314..68b7b58 100644 --- a/tee_gateway/controllers/completions_controller.py +++ b/tee_gateway/controllers/completions_controller.py @@ -4,12 +4,15 @@ import logging import connexion +from typing import Any + from langchain_core.messages import HumanMessage from tee_gateway.models.create_completion_request import CreateCompletionRequest from tee_gateway.tee_manager import get_tee_keys, compute_tee_msg_hash from tee_gateway.llm_backend import get_chat_model_cached, extract_usage +from tee_gateway.pricing import compute_session_cost logger = logging.getLogger(__name__) @@ -57,7 +60,7 @@ def create_completion(body): tee_keys = get_tee_keys() signature = tee_keys.sign_data(msg_hash) - return { + completion_response: dict[str, Any] = { "id": f"cmpl-{uuid.uuid4()}", "object": "text_completion", "created": timestamp, @@ -76,6 +79,11 @@ def create_completion(body): "tee_timestamp": timestamp, "tee_id": f"0x{tee_keys.get_tee_id()}", } + if usage: + cost = compute_session_cost(request_dict, completion_response) + if cost is not None: + completion_response["opengradient"] = cost + return completion_response except Exception as e: logger.error(f"Completion error: {str(e)}", exc_info=True) diff --git a/tee_gateway/controllers/ohttp_controller.py b/tee_gateway/controllers/ohttp_controller.py index 81ea582..1f7fbf8 100644 --- a/tee_gateway/controllers/ohttp_controller.py +++ b/tee_gateway/controllers/ohttp_controller.py @@ -18,17 +18,17 @@ Billing channel for the relay (both modes settle the real cost via x402 against the relay's x-payment; the gateway is the source of truth for the amount): - * stream=false: outer headers ALSO expose per-token detail — - X-Usage-Prompt-Tokens, X-Usage-Completion-Tokens, X-Usage-Total-Tokens, - X-Usage-Model — for the relay's own bookkeeping. The sealed body - contains the same usage info for the client. - * stream=true: NO per-token detail in outer headers. Outer response - headers are flushed before any body chunk is yielded, so we cannot - know token counts at header-write time; the sealed chunks are opaque - to the relay. The relay learns the actual settled amount from x402: - by querying the facilitator with its X-Upto-Session, or via - X-Payment-Response on its next request. The client still gets - per-token detail from the final SSE event inside the decrypted stream. + * stream=false: outer headers expose ONLY the settled cost — + X-Inference-Cost-OPG (smallest units, the integer x402 actually charged) + and X-Inference-Cost-USD (the equivalent USD at the price used). No + model name, no token counts: those would be a fingerprint of the + inner request and have no role in billing. + * stream=true: outer response headers are flushed before any body chunk + is yielded, so cost isn't known at header-write time. The relay reads + the settled amount from x402: either by querying the facilitator with + its X-Upto-Session, or via X-Payment-Response on its next request. The + client still sees cost in the final SSE event inside the encrypted + stream (the ``opengradient`` block written by the chat controller). Trust / payment model: * The CLIENT encrypts only an LLM chat-completion request. It does not see, @@ -48,8 +48,10 @@ so the relay DOES see the client's IP at the network layer — that's unavoidable. What the relay does NOT see is the request/response content: it observes only the OHTTP-encapsulated ciphertext, its own - wallet's x-payment material, and (single-shot only) the token-usage - outer headers it needs to bill. + wallet's x-payment material, and (single-shot only) the settled-cost + outer headers it needs to bill its own customer. Model name and token + counts are deliberately NOT surfaced — they'd fingerprint the inner + request without adding anything billing-relevant. * Enclave (compute position): sees the plaintext prompt and completion (they are decrypted inside the enclave to run the LLM call), but at the network layer it only sees the RELAY's IP — never the client's. @@ -76,6 +78,7 @@ from tee_gateway import ohttp from tee_gateway.tee_manager import get_tee_keys +from tee_gateway.pricing import SessionCost logger = logging.getLogger(__name__) @@ -194,9 +197,9 @@ def _build_outer_response( content_type=inner_content_type, ) - # Usage headers come from a JSON parse — single-valued by construction — + # Cost headers come from a JSON parse — single-valued by construction — # so a dict is fine internally; just project to tuples at merge time. - forwarded.extend(_extract_usage_headers(body_bytes).items()) + forwarded.extend(_extract_cost_headers(body_bytes).items()) sealed = ohttp.encapsulate_response(decap.response_key, decap.enc, body_bytes) return Response( sealed, @@ -220,10 +223,11 @@ def _build_streaming_response( look-ahead by one chunk so we know which one is last, hence the ``pending`` buffer below. - Usage stats can't be exposed as outer headers (those are already sent + Cost can't be exposed as outer headers (those are already flushed before the body); the relay bills via x402 settlement metadata - (X-Upto-Session header, set up-front). The client reads usage from the - final SSE event inside the decrypted stream. + (X-Upto-Session header, set up-front). The client reads cost from the + ``opengradient`` block on the final SSE event inside the decrypted + stream. """ # See _build_outer_response: keep as a list so duplicate HTTP header # values (e.g. multiple WWW-Authenticate challenges) survive forwarding. @@ -276,35 +280,30 @@ def _drain(sub_iter: Iterator[bytes]) -> bytes: return b"".join(chunks) -def _extract_usage_headers(body_bytes: bytes) -> dict[str, str]: - """Pull token-usage + model name out of a chat-completion response and - project them onto outer HTTP headers for the relay's billing pipeline. - These are the ONLY pieces of metadata the relay needs to charge; the - prompt and completion themselves stay sealed.""" +def _extract_cost_headers(body_bytes: bytes) -> dict[str, str]: + """Project the response's ``opengradient`` cost block onto outer headers + for the relay's billing pipeline. Cost is the only metadata the relay + needs — model name and token counts stay sealed (they'd fingerprint the + inner request without being billing-relevant). The price is surfaced too + so the relay (and downstream auditors) can verify + ``cost_usd == cost_opg / 10^decimals * opg_price_usd`` without trusting us. + """ try: body = json.loads(body_bytes.decode("utf-8")) except (UnicodeDecodeError, json.JSONDecodeError): return {} if not isinstance(body, dict): return {} - - headers: dict[str, str] = {} - usage = body.get("usage") - if isinstance(usage, dict): - prompt_tokens = usage.get("prompt_tokens", usage.get("input_tokens")) - completion_tokens = usage.get("completion_tokens", usage.get("output_tokens")) - total_tokens = usage.get("total_tokens") - if prompt_tokens is not None: - headers["X-Usage-Prompt-Tokens"] = str(prompt_tokens) - if completion_tokens is not None: - headers["X-Usage-Completion-Tokens"] = str(completion_tokens) - if total_tokens is not None: - headers["X-Usage-Total-Tokens"] = str(total_tokens) - - model = body.get("model") - if isinstance(model, str): - headers["X-Usage-Model"] = model - return headers + try: + cost = SessionCost.model_validate(body.get("opengradient")) + except Exception: + return {} + wire = cost.model_dump(mode="json") + return { + "X-Inference-Cost-OPG": wire["cost_opg"], + "X-Inference-Cost-USD": wire["cost_usd"], + "X-Inference-Price-OPG-USD": wire["opg_price_usd"], + } def _wsgi_subrequest( diff --git a/tee_gateway/definitions.py b/tee_gateway/definitions.py index 932f9eb..ce33541 100644 --- a/tee_gateway/definitions.py +++ b/tee_gateway/definitions.py @@ -63,19 +63,19 @@ # # These are the *maximum* amounts shown during the x402 payment pre-check. # Actual per-request costs are calculated dynamically from real token usage -# by dynamic_session_cost_calculator() in util.py. +# by calculate_session_cost() in pricing.py. # --------------------------------------------------------------------------- # /v1/chat/completions — maximum OPG spend per session (18 decimals: 100000000000000000 = 0.1 OPG). # This is the upper-bound amount presented to the client during the x402 pre-check handshake. # The x402 "upto" scheme allows the actual charge to be any value up to this cap; -# the real per-request cost is settled dynamically by dynamic_session_cost_calculator() in util.py +# the real per-request cost is settled dynamically by calculate_session_cost() in pricing.py # based on actual token usage, so clients are never overcharged beyond what they consumed. CHAT_COMPLETIONS_OPG_SESSION_MAX_SPEND: str = "100000000000000000" # /v1/completions — maximum OPG spend per session (18 decimals: 100000000000000000 = 0.1 OPG). # This is the upper-bound amount presented to the client during the x402 pre-check handshake. # The x402 "upto" scheme allows the actual charge to be any value up to this cap; -# the real per-request cost is settled dynamically by dynamic_session_cost_calculator() in util.py +# the real per-request cost is settled dynamically by calculate_session_cost() in pricing.py # based on actual token usage, so clients are never overcharged beyond what they consumed. COMPLETIONS_OPG_SESSION_MAX_SPEND: str = "100000000000000000" diff --git a/tee_gateway/encoder.py b/tee_gateway/encoder.py index 05a073e..0963b50 100644 --- a/tee_gateway/encoder.py +++ b/tee_gateway/encoder.py @@ -1,4 +1,5 @@ from connexion.apps.flask_app import FlaskJSONEncoder +from pydantic import BaseModel from tee_gateway.models.base_model import Model @@ -7,6 +8,10 @@ class JSONEncoder(FlaskJSONEncoder): include_nulls = False def default(self, o): + if isinstance(o, BaseModel): + # mode="json" runs the model's field_serializers — e.g. SessionCost + # emits int/Decimal as JS-safe strings declared on the model. + return o.model_dump(mode="json") if isinstance(o, Model): dikt = {} for attr in o.openapi_types: diff --git a/tee_gateway/price_feed/__init__.py b/tee_gateway/price_feed/__init__.py index 1349825..aba9902 100644 --- a/tee_gateway/price_feed/__init__.py +++ b/tee_gateway/price_feed/__init__.py @@ -4,4 +4,25 @@ __all__ = [ "OPGPriceFeed", "PriceFeedConfig", + "get_price_feed", + "set_price_feed", ] + + +_price_feed: OPGPriceFeed | None = None + + +def set_price_feed(feed: OPGPriceFeed) -> None: + """Register the process-wide OPG price feed. Called once from app startup.""" + global _price_feed + _price_feed = feed + + +def get_price_feed() -> OPGPriceFeed: + """Return the registered process-wide price feed. Raises if not initialized.""" + if _price_feed is None: + raise RuntimeError( + "OPG price feed not initialized — set_price_feed() must be called " + "during app startup before pricing-dependent code runs." + ) + return _price_feed diff --git a/tee_gateway/price_feed/feed.py b/tee_gateway/price_feed/feed.py index bbfb86b..e32f94a 100644 --- a/tee_gateway/price_feed/feed.py +++ b/tee_gateway/price_feed/feed.py @@ -9,7 +9,7 @@ ----- Create an ``OPGPriceFeed`` instance in the application entry point, call ``start()``, then pass it explicitly to wherever the price is needed (e.g. -``calculate_session_cost(...)`` in ``util.py``). +``calculate_session_cost(...)`` in ``pricing.py``). """ import logging diff --git a/tee_gateway/pricing.py b/tee_gateway/pricing.py new file mode 100644 index 0000000..e332677 --- /dev/null +++ b/tee_gateway/pricing.py @@ -0,0 +1,256 @@ +"""Per-request session-cost calculation for x402 settlement. + +Converts realized token usage from an LLM response into an OPG smallest-units +integer (what x402 actually charges) plus the equivalent USD figure and the +OPG/USD price used for the conversion. All three are bundled in +:class:`SessionCost`, embedded on the response under the ``opengradient`` key, +and read back by both the x402 settlement calculator and the OHTTP outer-header +extractor. This module is the single source of truth for that math. +""" + +import logging +from decimal import Decimal, InvalidOperation, ROUND_CEILING +from typing import Any, Callable + +from pydantic import BaseModel, ConfigDict, field_serializer + +from tee_gateway.definitions import ( + ASSET_DECIMALS_BY_ADDRESS, + BASE_MAINNET_OPG_ADDRESS, +) +from tee_gateway.model_registry import get_model_config + +logger = logging.getLogger("llm_server.dynamic_pricing") + +_OPG_DECIMALS = ASSET_DECIMALS_BY_ADDRESS[BASE_MAINNET_OPG_ADDRESS.lower()] + + +def _as_dict(value: Any) -> dict[str, Any] | None: + if value is None: + return None + if isinstance(value, dict): + return value + if hasattr(value, "model_dump"): + try: + dumped = value.model_dump(by_alias=True, exclude_none=True) + if isinstance(dumped, dict): + return dumped + except Exception: + pass + if hasattr(value, "to_dict"): + try: + dumped = value.to_dict() + if isinstance(dumped, dict): + return dumped + except Exception: + pass + return None + + +def _to_decimal(value: Any) -> Decimal | None: + if value is None: + return None + try: + return Decimal(str(value)) + except (InvalidOperation, ValueError, TypeError): + return None + + +def _normalize_model_name(model: str | None) -> str | None: + if not model: + return None + return str(model).strip().lower() + + +def _extract_usage_tokens( + response_json: dict[str, Any] | None, +) -> tuple[int, int]: + """Extract (input_tokens, output_tokens) from response JSON. + + Raises ValueError if usage data is missing or malformed — no silent fallback. + """ + if not isinstance(response_json, dict): + raise ValueError("response_json is not a dict; cannot extract usage tokens") + usage = response_json.get("usage") + if not isinstance(usage, dict): + raise ValueError( + "response_json has no 'usage' dict; cannot extract usage tokens" + ) + + prompt_tokens = usage.get("prompt_tokens", usage.get("input_tokens")) + completion_tokens = usage.get("completion_tokens", usage.get("output_tokens")) + if prompt_tokens is None or completion_tokens is None: + raise ValueError(f"usage dict is missing token counts: {usage!r}") + + try: + return max(0, int(prompt_tokens)), max(0, int(completion_tokens)) + except (TypeError, ValueError) as exc: + raise ValueError(f"Could not parse token counts from usage: {usage!r}") from exc + + +def _extract_model_from_context( + request_json: dict[str, Any] | None, + response_json: dict[str, Any] | None, +) -> str: + """Extract and normalize model name from request JSON. + + Uses only the request model name — the response model field is ignored + because providers may return a versioned alias that differs from the + user-facing name. Raises ValueError if the model name is absent. + """ + if not isinstance(request_json, dict): + raise ValueError("request_json is not a dict; cannot extract model name") + req_model = request_json.get("model") + if not req_model: + raise ValueError("request_json has no 'model' field") + normalized = _normalize_model_name(req_model) + if not normalized: + raise ValueError(f"model name normalizes to empty string: {req_model!r}") + return normalized + + +def _extract_asset_decimals_from_requirements(payment_requirements: Any) -> int: + req = _as_dict(payment_requirements) or {} + + asset = req.get("asset") + if not asset and isinstance(req.get("price"), dict): + asset = req["price"].get("asset") + + if not isinstance(asset, str) or not asset: + raise ValueError( + f"payment_requirements has no recognizable asset address; " + f"cannot determine token decimals: {req!r}" + ) + + asset_lower = asset.lower() + if asset_lower not in ASSET_DECIMALS_BY_ADDRESS: + raise ValueError( + f"Unknown asset address {asset!r}; not in ASSET_DECIMALS_BY_ADDRESS. " + f"Add it to definitions.py before accepting payments with this token." + ) + return ASSET_DECIMALS_BY_ADDRESS[asset_lower] + + +class SessionCost(BaseModel): + """Settled per-request cost. ``cost_opg`` is what x402 actually charges; + ``cost_usd`` and ``opg_price_usd`` are reported so clients and relays can + audit the conversion without re-fetching the price feed. + + All three fields serialize to JSON strings: the OPG integer can exceed JS + safe-int (2^53) for any non-trivial cost at 18 decimals, and the Decimals + would lose precision through a float round-trip. + """ + + model_config = ConfigDict(frozen=True) + + cost_opg: int + cost_usd: Decimal + opg_price_usd: Decimal + + @field_serializer("cost_opg") + def _serialize_opg(self, value: int) -> str: + return str(value) + + @field_serializer("cost_usd", "opg_price_usd") + def _serialize_decimal(self, value: Decimal) -> str: + return format(value, "f") + + +def calculate_session_cost( + request_json: dict[str, Any], + response_json: dict[str, Any], + asset_decimals: int, + get_price: Callable[[], Decimal], +) -> SessionCost: + """Compute the settled cost for a completed inference request. + + ``get_price`` is called on every invocation to fetch the current OPG/USD + price — pass ``price_feed.get_price`` so the latest cached value is used. + Raises ``ValueError`` on any missing/invalid data. Predictable failures + (unavailable price, unknown model) are blocked before inference by the + pre-inference gate in ``__main__.py``; post-inference failures are logged + as CRITICAL by the caller and the client is not charged. + + Returns both the OPG integer (what x402 charges) and the equivalent USD — + derived from the SAME rounded OPG value, not the raw USD math — so the two + numbers always reconcile via ``cost_usd = cost_opg / 10^decimals * price``. + """ + if not isinstance(request_json, dict) or not isinstance(response_json, dict): + raise ValueError( + "calculate_session_cost requires both request_json and response_json" + ) + + model = _extract_model_from_context(request_json, response_json) + cfg = get_model_config(model) + input_tokens, output_tokens = _extract_usage_tokens(response_json) + + raw_usd = (Decimal(input_tokens) * cfg.input_price_usd) + ( + Decimal(output_tokens) * cfg.output_price_usd + ) + token_price_usd = get_price() + if token_price_usd <= 0: + raise ValueError(f"Token price is non-positive: {token_price_usd}") + + scale = Decimal(10) ** asset_decimals + cost_smallest_units = max( + 0, + int( + ((raw_usd / token_price_usd) * scale).to_integral_value( + rounding=ROUND_CEILING + ) + ), + ) + # Reconcile USD from the rounded OPG integer so the two surfaced figures + # are exactly consistent (clients verify: usd == opg / 10^decimals * price). + settled_usd = (Decimal(cost_smallest_units) / scale) * token_price_usd + + logger.info( + "DYNAMIC_SESSION_COST model=%s input_tokens=%d output_tokens=%d " + "raw_usd=%s settled_usd=%s token_price_usd=%s decimals=%d cost=%d", + model, + input_tokens, + output_tokens, + str(raw_usd), + str(settled_usd), + str(token_price_usd), + asset_decimals, + cost_smallest_units, + ) + return SessionCost( + cost_opg=cost_smallest_units, + cost_usd=settled_usd, + opg_price_usd=token_price_usd, + ) + + +def compute_session_cost( + request_json: dict[str, Any], response_with_usage: dict[str, Any] +) -> SessionCost | None: + """Wrap calculate_session_cost for controllers: returns the SessionCost + pydantic model, or ``None`` on failure. Predictable failures (unknown + price/model) are blocked by the pre-inference gate, so anything reaching + here is a provider-side error (e.g. missing usage). Logging as CRITICAL + matches __main__._session_cost_calculator's contract: when this returns + None, x402's downstream reader will skip settlement and the client is not + charged. + """ + # Imported lazily to keep this module free of process-level state at import + # time (price_feed singleton is registered during app startup, after this + # module has already been imported transitively via other paths). + from tee_gateway.price_feed import get_price_feed + + try: + return calculate_session_cost( + request_json=request_json, + response_json=response_with_usage, + asset_decimals=_OPG_DECIMALS, + get_price=get_price_feed().get_price, + ) + except Exception as exc: + logger.critical( + "Post-inference cost calculation failed (provider error) — " + "client will NOT be charged: %s", + exc, + exc_info=True, + ) + return None diff --git a/tee_gateway/test/test_price_feed.py b/tee_gateway/test/test_price_feed.py index fe9bb8a..970aac8 100644 --- a/tee_gateway/test/test_price_feed.py +++ b/tee_gateway/test/test_price_feed.py @@ -1,5 +1,5 @@ """ -Unit tests for tee_gateway.price_feed and tee_gateway.util.calculate_session_cost. +Unit tests for tee_gateway.price_feed and tee_gateway.pricing.calculate_session_cost. All external HTTP calls are mocked — no network access required. @@ -9,7 +9,7 @@ TestOPGPriceFeedRefresh — OPGPriceFeed._refresh_price() (retry, rate-limit, stats) TestOPGPriceFeedGetPrice — OPGPriceFeed.get_price() (stale warning, ValueError before fetch) TestOPGPriceFeedStatus — OPGPriceFeed.get_status() snapshots -TestCalculateSessionCost — calculate_session_cost(context, get_price) in util.py +TestCalculateSessionCost — calculate_session_cost(context, get_price) in pricing.py """ import time @@ -23,7 +23,7 @@ from tee_gateway.definitions import BASE_MAINNET_OPG_ADDRESS from tee_gateway.price_feed import OPGPriceFeed from tee_gateway.price_feed.feed import fetch_opg_price -from tee_gateway.util import calculate_session_cost +from tee_gateway.pricing import SessionCost, calculate_session_cost # --------------------------------------------------------------------------- # Helpers @@ -411,14 +411,21 @@ def _make_get_price(price_usd: Decimal = Decimal("0.10")) -> MagicMock: return mock -class TestCalculateSessionCost(unittest.TestCase): - """Tests for calculate_session_cost(context, get_price).""" +def _call( + ctx: dict, + get_price, + asset_decimals: int = _ASSET_DECIMALS, +) -> SessionCost: + return calculate_session_cost( + request_json=ctx["request_json"], + response_json=ctx["response_json"], + asset_decimals=asset_decimals, + get_price=get_price, + ) - def _patch_definitions(self): - return patch( - "tee_gateway.util.ASSET_DECIMALS_BY_ADDRESS", - {_ASSET_ADDR_LOWER: _ASSET_DECIMALS}, - ) + +class TestCalculateSessionCost(unittest.TestCase): + """Tests for calculate_session_cost(request_json, response_json, asset_decimals, get_price).""" def _patch_model( self, input_price: str = "0.000001", output_price: str = "0.000002" @@ -426,90 +433,116 @@ def _patch_model( cfg = MagicMock() cfg.input_price_usd = Decimal(input_price) cfg.output_price_usd = Decimal(output_price) - return patch("tee_gateway.util.get_model_config", return_value=cfg) + return patch("tee_gateway.pricing.get_model_config", return_value=cfg) def test_calls_get_price(self): get_price = _make_get_price() - with self._patch_definitions(), self._patch_model(): - calculate_session_cost(_make_context(), get_price) + with self._patch_model(): + _call(_make_context(), get_price) get_price.assert_called_once() - def test_returns_positive_int(self): - with self._patch_definitions(), self._patch_model(): - result = calculate_session_cost(_make_context(), _make_get_price()) - self.assertIsInstance(result, int) - self.assertGreaterEqual(result, 0) + def test_returns_session_cost(self): + with self._patch_model(): + result = _call(_make_context(), _make_get_price()) + self.assertIsInstance(result, SessionCost) + self.assertIsInstance(result.cost_opg, int) + self.assertGreaterEqual(result.cost_opg, 0) + self.assertEqual(result.opg_price_usd, Decimal("0.10")) + + def test_reported_usd_reconciles_with_opg(self): + """cost_usd must equal cost_opg / 10^decimals * price exactly — otherwise + clients verifying the conversion will reject the response.""" + with self._patch_model(): + result = _call(_make_context(), _make_get_price(Decimal("0.10"))) + scale = Decimal(10) ** _ASSET_DECIMALS + expected = (Decimal(result.cost_opg) / scale) * Decimal("0.10") + self.assertEqual(result.cost_usd, expected) def test_zero_tokens_returns_zero(self): - with self._patch_definitions(), self._patch_model(): - result = calculate_session_cost( + with self._patch_model(): + result = _call( _make_context(input_tokens=0, output_tokens=0), _make_get_price() ) - self.assertEqual(result, 0) + self.assertEqual(result.cost_opg, 0) + self.assertEqual(result.cost_usd, Decimal(0)) def test_raises_when_get_price_raises(self): get_price = MagicMock(side_effect=ValueError("price not available")) - with self._patch_definitions(), self._patch_model(): + with self._patch_model(): with self.assertRaises(ValueError): - calculate_session_cost(_make_context(), get_price) + _call(_make_context(), get_price) def test_raises_when_non_positive_price(self): - with self._patch_definitions(), self._patch_model(): + with self._patch_model(): with self.assertRaises(ValueError): - calculate_session_cost(_make_context(), _make_get_price(Decimal("0"))) + _call(_make_context(), _make_get_price(Decimal("0"))) def test_raises_when_request_json_missing(self): ctx = _make_context() ctx["request_json"] = None - with self._patch_definitions(), self._patch_model(): + with self._patch_model(): with self.assertRaises(ValueError): - calculate_session_cost(ctx, _make_get_price()) + _call(ctx, _make_get_price()) def test_raises_when_usage_missing(self): ctx = _make_context() ctx["response_json"] = {"model": "gpt-4.1-mini"} - with self._patch_definitions(), self._patch_model(): + with self._patch_model(): with self.assertRaises(ValueError): - calculate_session_cost(ctx, _make_get_price()) - - def test_raises_when_asset_unknown(self): - ctx = _make_context(asset="0xunknown") - with ( - patch("tee_gateway.util.ASSET_DECIMALS_BY_ADDRESS", {}), - self._patch_model(), - ): - with self.assertRaises(ValueError): - calculate_session_cost(ctx, _make_get_price()) + _call(ctx, _make_get_price()) def test_cost_scales_with_token_count(self): - with self._patch_definitions(), self._patch_model(): - cost_small = calculate_session_cost( + with self._patch_model(): + cost_small = _call( _make_context(input_tokens=10, output_tokens=5), _make_get_price() ) - cost_large = calculate_session_cost( + cost_large = _call( _make_context(input_tokens=1000, output_tokens=500), _make_get_price() ) - self.assertGreater(cost_large, cost_small) + self.assertGreater(cost_large.cost_opg, cost_small.cost_opg) def test_higher_token_price_yields_lower_cost(self): - with self._patch_definitions(), self._patch_model(): - cost_cheap = calculate_session_cost( - _make_context(), _make_get_price(Decimal("0.10")) - ) - cost_expensive = calculate_session_cost( - _make_context(), _make_get_price(Decimal("0.20")) - ) - self.assertGreater(cost_cheap, cost_expensive) + with self._patch_model(): + cost_cheap = _call(_make_context(), _make_get_price(Decimal("0.10"))) + cost_expensive = _call(_make_context(), _make_get_price(Decimal("0.20"))) + self.assertGreater( + cost_cheap.cost_opg, + cost_expensive.cost_opg, + ) def test_uses_current_price_on_each_call(self): """get_price is called fresh every invocation — price changes are picked up.""" get_price = MagicMock(side_effect=[Decimal("0.10"), Decimal("0.20")]) - with self._patch_definitions(), self._patch_model(): - cost_first = calculate_session_cost(_make_context(), get_price) - cost_second = calculate_session_cost(_make_context(), get_price) + with self._patch_model(): + cost_first = _call(_make_context(), get_price) + cost_second = _call(_make_context(), get_price) self.assertEqual(get_price.call_count, 2) # Price doubled → cost should halve (same USD spend, twice the token price). - self.assertGreater(cost_first, cost_second) + self.assertGreater(cost_first.cost_opg, cost_second.cost_opg) + + +class TestSessionCostWireRoundTrip(unittest.TestCase): + """SessionCost must round-trip through its JSON wire form so the chat + handler (writer) and x402's _session_cost_calculator (reader) agree.""" + + def test_round_trip_preserves_values(self): + import json + + cost = SessionCost( + # Above 2^53 — the whole reason wire form is strings, not ints. + cost_opg=12345678901234567890, + cost_usd=Decimal("0.00342100"), + opg_price_usd=Decimal("0.123456"), + ) + wire = json.loads(json.dumps(cost.model_dump(mode="json"))) + parsed = SessionCost.model_validate(wire) + self.assertEqual(parsed.cost_opg, cost.cost_opg) + self.assertEqual(parsed.cost_usd, cost.cost_usd) + self.assertEqual(parsed.opg_price_usd, cost.opg_price_usd) + + def test_validate_rejects_missing_block(self): + with self.assertRaises(Exception): + SessionCost.model_validate(None) if __name__ == "__main__": diff --git a/tee_gateway/util.py b/tee_gateway/util.py index 1c9047e..1264a1f 100644 --- a/tee_gateway/util.py +++ b/tee_gateway/util.py @@ -1,11 +1,6 @@ import datetime from tee_gateway import typing_utils -import logging -from decimal import Decimal, InvalidOperation, ROUND_CEILING -from typing import Any, Callable - -logger = logging.getLogger("llm_server.dynamic_pricing") def _deserialize(data, klass): @@ -151,169 +146,3 @@ def _deserialize_dict(data, boxed_type): :rtype: dict """ return {k: _deserialize(v, boxed_type) for k, v in data.items()} - - -from tee_gateway.definitions import ( # noqa: E402 - ASSET_DECIMALS_BY_ADDRESS, -) -from tee_gateway.model_registry import get_model_config # noqa: E402 - - -def _as_dict(value: Any) -> dict[str, Any] | None: - if value is None: - return None - if isinstance(value, dict): - return value - if hasattr(value, "model_dump"): - try: - dumped = value.model_dump(by_alias=True, exclude_none=True) - if isinstance(dumped, dict): - return dumped - except Exception: - pass - if hasattr(value, "to_dict"): - try: - dumped = value.to_dict() - if isinstance(dumped, dict): - return dumped - except Exception: - pass - return None - - -def _to_decimal(value: Any) -> Decimal | None: - if value is None: - return None - try: - return Decimal(str(value)) - except (InvalidOperation, ValueError, TypeError): - return None - - -def _normalize_model_name(model: str | None) -> str | None: - if not model: - return None - return str(model).strip().lower() - - -def _extract_usage_tokens( - response_json: dict[str, Any] | None, -) -> tuple[int, int]: - """Extract (input_tokens, output_tokens) from response JSON. - - Raises ValueError if usage data is missing or malformed — no silent fallback. - """ - if not isinstance(response_json, dict): - raise ValueError("response_json is not a dict; cannot extract usage tokens") - usage = response_json.get("usage") - if not isinstance(usage, dict): - raise ValueError( - "response_json has no 'usage' dict; cannot extract usage tokens" - ) - - prompt_tokens = usage.get("prompt_tokens", usage.get("input_tokens")) - completion_tokens = usage.get("completion_tokens", usage.get("output_tokens")) - if prompt_tokens is None or completion_tokens is None: - raise ValueError(f"usage dict is missing token counts: {usage!r}") - - try: - return max(0, int(prompt_tokens)), max(0, int(completion_tokens)) - except (TypeError, ValueError) as exc: - raise ValueError(f"Could not parse token counts from usage: {usage!r}") from exc - - -def _extract_model_from_context( - request_json: dict[str, Any] | None, - response_json: dict[str, Any] | None, -) -> str: - """Extract and normalize model name from request JSON. - - Uses only the request model name — the response model field is ignored - because providers may return a versioned alias that differs from the - user-facing name. Raises ValueError if the model name is absent. - """ - if not isinstance(request_json, dict): - raise ValueError("request_json is not a dict; cannot extract model name") - req_model = request_json.get("model") - if not req_model: - raise ValueError("request_json has no 'model' field") - normalized = _normalize_model_name(req_model) - if not normalized: - raise ValueError(f"model name normalizes to empty string: {req_model!r}") - return normalized - - -def _extract_asset_decimals_from_requirements(payment_requirements: Any) -> int: - req = _as_dict(payment_requirements) or {} - - asset = req.get("asset") - if not asset and isinstance(req.get("price"), dict): - asset = req["price"].get("asset") - - if not isinstance(asset, str) or not asset: - raise ValueError( - f"payment_requirements has no recognizable asset address; " - f"cannot determine token decimals: {req!r}" - ) - - asset_lower = asset.lower() - if asset_lower not in ASSET_DECIMALS_BY_ADDRESS: - raise ValueError( - f"Unknown asset address {asset!r}; not in ASSET_DECIMALS_BY_ADDRESS. " - f"Add it to definitions.py before accepting payments with this token." - ) - return ASSET_DECIMALS_BY_ADDRESS[asset_lower] - - -def calculate_session_cost( - context: dict[str, Any], get_price: Callable[[], Decimal] -) -> int: - """Calculate the x402 session cost in token smallest units for a completed request. - - ``get_price`` is called on every invocation to fetch the current OPG/USD - price — pass ``price_feed.get_price`` so the latest cached value is used. - Raises ``ValueError`` on any missing/invalid data. Predictable failures - (unavailable price, unknown model) are blocked before inference by the - pre-inference gate in ``__main__.py``; post-inference failures are logged - as CRITICAL by the caller and the client is not charged. - """ - request_json = context.get("request_json") - response_json = context.get("response_json") - - if not isinstance(request_json, dict) or not isinstance(response_json, dict): - raise ValueError( - "calculate_session_cost requires both request_json and response_json" - ) - - model = _extract_model_from_context(request_json, response_json) - cfg = get_model_config(model) - input_tokens, output_tokens = _extract_usage_tokens(response_json) - - total_usd = (Decimal(input_tokens) * cfg.input_price_usd) + ( - Decimal(output_tokens) * cfg.output_price_usd - ) - token_price_usd = get_price() - if token_price_usd <= 0: - raise ValueError(f"Token price is non-positive: {token_price_usd}") - - token_amount = total_usd / token_price_usd - decimals = _extract_asset_decimals_from_requirements( - context.get("payment_requirements") - ) - scale = Decimal(10) ** decimals - cost_smallest_units = int( - (token_amount * scale).to_integral_value(rounding=ROUND_CEILING) - ) - - logger.info( - "DYNAMIC_SESSION_COST model=%s input_tokens=%d output_tokens=%d " - "total_usd=%s token_price_usd=%s decimals=%d cost=%d", - model, - input_tokens, - output_tokens, - str(total_usd), - str(token_price_usd), - decimals, - cost_smallest_units, - ) - return max(0, cost_smallest_units) diff --git a/tests/test_pricing.py b/tests/test_pricing.py index 0b0c57d..51d030e 100644 --- a/tests/test_pricing.py +++ b/tests/test_pricing.py @@ -16,7 +16,32 @@ _MODEL_LOOKUP, get_model_config, ) -from tee_gateway.util import calculate_session_cost +from tee_gateway.pricing import ( + _extract_asset_decimals_from_requirements, + calculate_session_cost as _calculate_session_cost_raw, +) + + +def calculate_session_cost(ctx, get_price): + """Adapter: legacy bundled-ctx call form -> new (request, response, + asset_decimals, get_price) signature, returning the OPG integer. + + request_json may be None (validated by the underlying function); guard so the + error path is preserved. + """ + request_json = ctx.get("request_json") + if not isinstance(request_json, dict): + # Match the underlying ValueError so .assertRaises(ValueError) still fires. + raise ValueError("request_json missing or not a dict") + return _calculate_session_cost_raw( + request_json=request_json, + response_json=ctx["response_json"], + asset_decimals=_extract_asset_decimals_from_requirements( + ctx["payment_requirements"] + ), + get_price=get_price, + ).cost_opg + # All pricing tests assume OPG = $1.00 so USD cost == OPG token amount. _OPG_PRICE_USD = Decimal("1") From 678ab0039ead97a285649098b7e314b9a06ce1f3 Mon Sep 17 00:00:00 2001 From: "balogh.adam@icloud.com" Date: Sat, 16 May 2026 12:56:54 -0400 Subject: [PATCH 25/39] size limit --- tee_gateway/controllers/ohttp_controller.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/tee_gateway/controllers/ohttp_controller.py b/tee_gateway/controllers/ohttp_controller.py index 1f7fbf8..62193b3 100644 --- a/tee_gateway/controllers/ohttp_controller.py +++ b/tee_gateway/controllers/ohttp_controller.py @@ -86,6 +86,13 @@ OHTTP_CHUNKED_RESPONSE_MEDIA_TYPE = "message/ohttp-chunked-res" _SSE_CONTENT_TYPE = "text/event-stream" +# Cap on the encapsulated request size. The inner payload is a chat-completion +# JSON body; even with long conversation history this comfortably fits in a few +# hundred KB. Rejecting larger bodies up-front prevents a malicious relay from +# forcing the enclave to allocate and attempt HPKE decapsulation on arbitrarily +# large blobs. +_MAX_ENCAPSULATED_REQUEST_BYTES = 512 * 1024 + # Fields that can re-identify a client and have no role in inference. We drop # them before forwarding to the inner handler — keeping them inside the # encrypted envelope would only protect them from the relay, not from us or @@ -111,9 +118,19 @@ def _should_forward_header(name: str) -> bool: def create_anonymous_chat_completion(): """POST /v1/ohttp — decrypt, sub-dispatch to /v1/chat/completions, re-encrypt.""" - raw_body: bytes = flask_request.get_data(cache=False) + declared_length = flask_request.content_length + if declared_length is not None and declared_length > _MAX_ENCAPSULATED_REQUEST_BYTES: + return _error(413, "encapsulated request too large") + + raw_body: bytes = flask_request.get_data( + cache=False, parse_form_data=False + ) if not raw_body: return _error(400, "empty body") + # Re-check after read in case Content-Length was absent or lied about the + # actual stream length (chunked transfer, malicious client). + if len(raw_body) > _MAX_ENCAPSULATED_REQUEST_BYTES: + return _error(413, "encapsulated request too large") tee = get_tee_keys() if tee.hpke_private_key is None: From c2be6e134f787e58aa1f8dac96217e6fe970b2bf Mon Sep 17 00:00:00 2001 From: "balogh.adam@icloud.com" Date: Sat, 16 May 2026 12:57:08 -0400 Subject: [PATCH 26/39] lint --- tee_gateway/controllers/ohttp_controller.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tee_gateway/controllers/ohttp_controller.py b/tee_gateway/controllers/ohttp_controller.py index 62193b3..dbc90e8 100644 --- a/tee_gateway/controllers/ohttp_controller.py +++ b/tee_gateway/controllers/ohttp_controller.py @@ -119,12 +119,13 @@ def _should_forward_header(name: str) -> bool: def create_anonymous_chat_completion(): """POST /v1/ohttp — decrypt, sub-dispatch to /v1/chat/completions, re-encrypt.""" declared_length = flask_request.content_length - if declared_length is not None and declared_length > _MAX_ENCAPSULATED_REQUEST_BYTES: + if ( + declared_length is not None + and declared_length > _MAX_ENCAPSULATED_REQUEST_BYTES + ): return _error(413, "encapsulated request too large") - raw_body: bytes = flask_request.get_data( - cache=False, parse_form_data=False - ) + raw_body: bytes = flask_request.get_data(cache=False, parse_form_data=False) if not raw_body: return _error(400, "empty body") # Re-check after read in case Content-Length was absent or lied about the From c81fe6ef5a8f67a1aed290e2c07e708bc3746a89 Mon Sep 17 00:00:00 2001 From: kukac Date: Sat, 16 May 2026 12:59:01 -0400 Subject: [PATCH 27/39] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- tee_gateway/controllers/ohttp_controller.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tee_gateway/controllers/ohttp_controller.py b/tee_gateway/controllers/ohttp_controller.py index dbc90e8..8bc0c70 100644 --- a/tee_gateway/controllers/ohttp_controller.py +++ b/tee_gateway/controllers/ohttp_controller.py @@ -39,9 +39,9 @@ request. * The ENCLAVE decrypts the inner payload, forwards the request to its own chat endpoint with the relay's ``x-payment`` header, and returns. The - relay sees status and settlement headers (and, for non-stream, the - per-token usage headers above), but never sees the inner prompt or - completion. + relay sees status and settlement headers and, for non-stream, the outer + ``X-Inference-Cost-OPG`` / ``X-Inference-Cost-USD`` billing headers, but + never sees the inner prompt or completion. Privacy properties: * Relay (network position): terminates the client's TCP/TLS connection, From 46ffb568865327d1c09c01b32c702482a719bd1c Mon Sep 17 00:00:00 2001 From: kukac Date: Sat, 16 May 2026 12:59:17 -0400 Subject: [PATCH 28/39] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 741b7e4..cca5f5b 100644 --- a/README.md +++ b/README.md @@ -192,14 +192,14 @@ The `tee_*` fields provide cryptographic proof of the response: **Billing channel for the relay.** Both modes settle the actual cost via x402 against the relay's `X-Payment` (`upto` scheme); the gateway is the source of truth for the amount. -- `stream=false`: outer response *also* exposes per-token detail — `X-Usage-Prompt-Tokens`, `X-Usage-Completion-Tokens`, `X-Usage-Total-Tokens`, `X-Usage-Model` — for the relay's own bookkeeping. The sealed body carries the same `usage` block for the client. +- `stream=false`: outer response exposes billing/cost headers — `X-Inference-Cost-OPG`, `X-Inference-Cost-USD`, `X-Inference-Price-OPG-USD` — for the relay's own bookkeeping. Per-token `usage` detail is carried in the sealed body for the client, not in outer `X-Usage-*` headers. - `stream=true`: **no** per-token detail in outer headers (they're flushed before any body chunk, so we can't know token counts at header-write time) and the sealed chunks are opaque to the relay. The relay reads the actual settled amount from x402 — either by querying the facilitator with its `X-Upto-Session`, or via `X-Payment-Response` on its next call. The client still sees per-token detail in the final SSE event inside the decrypted stream. On non-2xx (e.g. 402 payment required) the body is forwarded plaintext so the relay can read x402 payment requirements and retry — those bodies never contain prompts or completions. **Trust split:** -- **Relay** terminates the client's TCP/TLS connection, so it does see the client's IP — that's unavoidable. What it doesn't see is content: only OHTTP ciphertext + its own wallet's `x-payment` material + (single-shot only) the token-usage outer headers it needs to bill. +- **Relay** terminates the client's TCP/TLS connection, so it does see the client's IP — that's unavoidable. What it doesn't see is content: only OHTTP ciphertext + its own wallet's `x-payment` material + the outer billing/cost headers used to settle and reconcile charges. - **Enclave** sees plaintext prompts/completions (it has to run the LLM call) but at the network layer only sees the relay's IP, never the client's. This is the unlinkability claim — the enclave can't tie a plaintext request to a specific end user. - **Client** decrypts and verifies the TEE signature embedded in the response body against the attested public key. From f211297242069c356789b20659365b1d0cb0e22d Mon Sep 17 00:00:00 2001 From: "balogh.adam@icloud.com" Date: Sat, 16 May 2026 13:14:51 -0400 Subject: [PATCH 29/39] usage --- tee_gateway/controllers/chat_controller.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tee_gateway/controllers/chat_controller.py b/tee_gateway/controllers/chat_controller.py index 2fee235..5ef9597 100644 --- a/tee_gateway/controllers/chat_controller.py +++ b/tee_gateway/controllers/chat_controller.py @@ -264,6 +264,7 @@ def _create_non_streaming_response(chat_request: CreateChatCompletionRequest): f"Response Final\n\tTEE Signature: {signature}\n\tTEE request hash: {input_hash_hex}\n\tTEE output hash: {output_hash_hex}\n\tTEE timestamp: {timestamp}\n\tTEE ID: 0x{tee_keys.get_tee_id()}" ) + # TODO: If no usage is returned, we should compute it here. if usage: openai_response["usage"] = usage cost = compute_session_cost(request_dict, openai_response) From 399c0bba5dbd1e0eb13f11690899d84de8bb6522 Mon Sep 17 00:00:00 2001 From: "balogh.adam@icloud.com" Date: Sat, 16 May 2026 13:15:49 -0400 Subject: [PATCH 30/39] todo --- tee_gateway/controllers/chat_controller.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tee_gateway/controllers/chat_controller.py b/tee_gateway/controllers/chat_controller.py index 5ef9597..456e81b 100644 --- a/tee_gateway/controllers/chat_controller.py +++ b/tee_gateway/controllers/chat_controller.py @@ -589,6 +589,7 @@ def generate(): f"Response Final\n\tTEE Signature: {tee_signature}\n\tTEE request hash: {input_hash_hex}\n\tTEE output hash: {output_hash_hex}\n\tTEE timestamp: {timestamp}\n\tTEE ID: 0x{tee_keys.get_tee_id()}" ) + # TODO: If no usage is returned, we should compute it here. if final_usage: final_data["usage"] = { "prompt_tokens": final_usage.get("input_tokens", 0), From 7cf86da01b9b1fcb27e067642e0f75c77c002b7e Mon Sep 17 00:00:00 2001 From: "balogh.adam@icloud.com" Date: Sat, 16 May 2026 13:34:04 -0400 Subject: [PATCH 31/39] simplify pricing --- tee_gateway/controllers/chat_controller.py | 7 +- .../controllers/completions_controller.py | 2 +- tee_gateway/definitions.py | 6 +- tee_gateway/price_feed/feed.py | 2 +- tee_gateway/pricing.py | 232 ++++-------------- tee_gateway/test/test_price_feed.py | 121 +++------ tests/test_pricing.py | 155 +++--------- 7 files changed, 121 insertions(+), 404 deletions(-) diff --git a/tee_gateway/controllers/chat_controller.py b/tee_gateway/controllers/chat_controller.py index 456e81b..4553264 100644 --- a/tee_gateway/controllers/chat_controller.py +++ b/tee_gateway/controllers/chat_controller.py @@ -225,8 +225,6 @@ def _create_non_streaming_response(chat_request: CreateChatCompletionRequest): for tc in response.tool_calls ] - usage = extract_usage(response) - # For tool-call responses, hash the serialized tool calls so the # signature covers which tools were invoked and with what arguments. if finish_reason == "tool_calls" and message_dict.get("tool_calls"): @@ -265,9 +263,10 @@ def _create_non_streaming_response(chat_request: CreateChatCompletionRequest): ) # TODO: If no usage is returned, we should compute it here. + usage = extract_usage(response) if usage: openai_response["usage"] = usage - cost = compute_session_cost(request_dict, openai_response) + cost = compute_session_cost(chat_request.model, usage) if cost is not None: openai_response["opengradient"] = cost @@ -596,7 +595,7 @@ def generate(): "completion_tokens": final_usage.get("output_tokens", 0), "total_tokens": final_usage.get("total_tokens", 0), } - cost = compute_session_cost(request_dict, final_data) + cost = compute_session_cost(chat_request.model, final_data["usage"]) if cost is not None: # final_data is hand-serialized to SSE via json.dumps below, # which doesn't go through Flask's JSONEncoder — so do the diff --git a/tee_gateway/controllers/completions_controller.py b/tee_gateway/controllers/completions_controller.py index 68b7b58..f492e0e 100644 --- a/tee_gateway/controllers/completions_controller.py +++ b/tee_gateway/controllers/completions_controller.py @@ -80,7 +80,7 @@ def create_completion(body): "tee_id": f"0x{tee_keys.get_tee_id()}", } if usage: - cost = compute_session_cost(request_dict, completion_response) + cost = compute_session_cost(body.model, usage) if cost is not None: completion_response["opengradient"] = cost return completion_response diff --git a/tee_gateway/definitions.py b/tee_gateway/definitions.py index ce33541..1cb77be 100644 --- a/tee_gateway/definitions.py +++ b/tee_gateway/definitions.py @@ -63,19 +63,19 @@ # # These are the *maximum* amounts shown during the x402 payment pre-check. # Actual per-request costs are calculated dynamically from real token usage -# by calculate_session_cost() in pricing.py. +# by compute_session_cost() in pricing.py. # --------------------------------------------------------------------------- # /v1/chat/completions — maximum OPG spend per session (18 decimals: 100000000000000000 = 0.1 OPG). # This is the upper-bound amount presented to the client during the x402 pre-check handshake. # The x402 "upto" scheme allows the actual charge to be any value up to this cap; -# the real per-request cost is settled dynamically by calculate_session_cost() in pricing.py +# the real per-request cost is settled dynamically by compute_session_cost() in pricing.py # based on actual token usage, so clients are never overcharged beyond what they consumed. CHAT_COMPLETIONS_OPG_SESSION_MAX_SPEND: str = "100000000000000000" # /v1/completions — maximum OPG spend per session (18 decimals: 100000000000000000 = 0.1 OPG). # This is the upper-bound amount presented to the client during the x402 pre-check handshake. # The x402 "upto" scheme allows the actual charge to be any value up to this cap; -# the real per-request cost is settled dynamically by calculate_session_cost() in pricing.py +# the real per-request cost is settled dynamically by compute_session_cost() in pricing.py # based on actual token usage, so clients are never overcharged beyond what they consumed. COMPLETIONS_OPG_SESSION_MAX_SPEND: str = "100000000000000000" diff --git a/tee_gateway/price_feed/feed.py b/tee_gateway/price_feed/feed.py index e32f94a..cc0025a 100644 --- a/tee_gateway/price_feed/feed.py +++ b/tee_gateway/price_feed/feed.py @@ -9,7 +9,7 @@ ----- Create an ``OPGPriceFeed`` instance in the application entry point, call ``start()``, then pass it explicitly to wherever the price is needed (e.g. -``calculate_session_cost(...)`` in ``pricing.py``). +``compute_session_cost(...)`` in ``pricing.py``). """ import logging diff --git a/tee_gateway/pricing.py b/tee_gateway/pricing.py index e332677..ce00a62 100644 --- a/tee_gateway/pricing.py +++ b/tee_gateway/pricing.py @@ -9,8 +9,7 @@ """ import logging -from decimal import Decimal, InvalidOperation, ROUND_CEILING -from typing import Any, Callable +from decimal import Decimal, ROUND_CEILING from pydantic import BaseModel, ConfigDict, field_serializer @@ -25,112 +24,6 @@ _OPG_DECIMALS = ASSET_DECIMALS_BY_ADDRESS[BASE_MAINNET_OPG_ADDRESS.lower()] -def _as_dict(value: Any) -> dict[str, Any] | None: - if value is None: - return None - if isinstance(value, dict): - return value - if hasattr(value, "model_dump"): - try: - dumped = value.model_dump(by_alias=True, exclude_none=True) - if isinstance(dumped, dict): - return dumped - except Exception: - pass - if hasattr(value, "to_dict"): - try: - dumped = value.to_dict() - if isinstance(dumped, dict): - return dumped - except Exception: - pass - return None - - -def _to_decimal(value: Any) -> Decimal | None: - if value is None: - return None - try: - return Decimal(str(value)) - except (InvalidOperation, ValueError, TypeError): - return None - - -def _normalize_model_name(model: str | None) -> str | None: - if not model: - return None - return str(model).strip().lower() - - -def _extract_usage_tokens( - response_json: dict[str, Any] | None, -) -> tuple[int, int]: - """Extract (input_tokens, output_tokens) from response JSON. - - Raises ValueError if usage data is missing or malformed — no silent fallback. - """ - if not isinstance(response_json, dict): - raise ValueError("response_json is not a dict; cannot extract usage tokens") - usage = response_json.get("usage") - if not isinstance(usage, dict): - raise ValueError( - "response_json has no 'usage' dict; cannot extract usage tokens" - ) - - prompt_tokens = usage.get("prompt_tokens", usage.get("input_tokens")) - completion_tokens = usage.get("completion_tokens", usage.get("output_tokens")) - if prompt_tokens is None or completion_tokens is None: - raise ValueError(f"usage dict is missing token counts: {usage!r}") - - try: - return max(0, int(prompt_tokens)), max(0, int(completion_tokens)) - except (TypeError, ValueError) as exc: - raise ValueError(f"Could not parse token counts from usage: {usage!r}") from exc - - -def _extract_model_from_context( - request_json: dict[str, Any] | None, - response_json: dict[str, Any] | None, -) -> str: - """Extract and normalize model name from request JSON. - - Uses only the request model name — the response model field is ignored - because providers may return a versioned alias that differs from the - user-facing name. Raises ValueError if the model name is absent. - """ - if not isinstance(request_json, dict): - raise ValueError("request_json is not a dict; cannot extract model name") - req_model = request_json.get("model") - if not req_model: - raise ValueError("request_json has no 'model' field") - normalized = _normalize_model_name(req_model) - if not normalized: - raise ValueError(f"model name normalizes to empty string: {req_model!r}") - return normalized - - -def _extract_asset_decimals_from_requirements(payment_requirements: Any) -> int: - req = _as_dict(payment_requirements) or {} - - asset = req.get("asset") - if not asset and isinstance(req.get("price"), dict): - asset = req["price"].get("asset") - - if not isinstance(asset, str) or not asset: - raise ValueError( - f"payment_requirements has no recognizable asset address; " - f"cannot determine token decimals: {req!r}" - ) - - asset_lower = asset.lower() - if asset_lower not in ASSET_DECIMALS_BY_ADDRESS: - raise ValueError( - f"Unknown asset address {asset!r}; not in ASSET_DECIMALS_BY_ADDRESS. " - f"Add it to definitions.py before accepting payments with this token." - ) - return ASSET_DECIMALS_BY_ADDRESS[asset_lower] - - class SessionCost(BaseModel): """Settled per-request cost. ``cost_opg`` is what x402 actually charges; ``cost_usd`` and ``opg_price_usd`` are reported so clients and relays can @@ -156,95 +49,66 @@ def _serialize_decimal(self, value: Decimal) -> str: return format(value, "f") -def calculate_session_cost( - request_json: dict[str, Any], - response_json: dict[str, Any], - asset_decimals: int, - get_price: Callable[[], Decimal], -) -> SessionCost: +def compute_session_cost(model: str, usage: dict) -> SessionCost | None: """Compute the settled cost for a completed inference request. - ``get_price`` is called on every invocation to fetch the current OPG/USD - price — pass ``price_feed.get_price`` so the latest cached value is used. - Raises ``ValueError`` on any missing/invalid data. Predictable failures - (unavailable price, unknown model) are blocked before inference by the - pre-inference gate in ``__main__.py``; post-inference failures are logged - as CRITICAL by the caller and the client is not charged. + Returns the SessionCost pydantic model, or ``None`` on failure. Predictable + failures (unknown model, unavailable price) are blocked by the pre-inference + gate, so anything reaching here is a provider-side error (e.g. missing + usage) or a transient price-feed outage. Logging as CRITICAL matches + ``__main__._session_cost_calculator``'s contract: when this returns None, + x402's downstream reader will skip settlement and the client is not charged. Returns both the OPG integer (what x402 charges) and the equivalent USD — derived from the SAME rounded OPG value, not the raw USD math — so the two numbers always reconcile via ``cost_usd = cost_opg / 10^decimals * price``. """ - if not isinstance(request_json, dict) or not isinstance(response_json, dict): - raise ValueError( - "calculate_session_cost requires both request_json and response_json" - ) - - model = _extract_model_from_context(request_json, response_json) - cfg = get_model_config(model) - input_tokens, output_tokens = _extract_usage_tokens(response_json) - - raw_usd = (Decimal(input_tokens) * cfg.input_price_usd) + ( - Decimal(output_tokens) * cfg.output_price_usd - ) - token_price_usd = get_price() - if token_price_usd <= 0: - raise ValueError(f"Token price is non-positive: {token_price_usd}") - - scale = Decimal(10) ** asset_decimals - cost_smallest_units = max( - 0, - int( - ((raw_usd / token_price_usd) * scale).to_integral_value( - rounding=ROUND_CEILING - ) - ), - ) - # Reconcile USD from the rounded OPG integer so the two surfaced figures - # are exactly consistent (clients verify: usd == opg / 10^decimals * price). - settled_usd = (Decimal(cost_smallest_units) / scale) * token_price_usd - - logger.info( - "DYNAMIC_SESSION_COST model=%s input_tokens=%d output_tokens=%d " - "raw_usd=%s settled_usd=%s token_price_usd=%s decimals=%d cost=%d", - model, - input_tokens, - output_tokens, - str(raw_usd), - str(settled_usd), - str(token_price_usd), - asset_decimals, - cost_smallest_units, - ) - return SessionCost( - cost_opg=cost_smallest_units, - cost_usd=settled_usd, - opg_price_usd=token_price_usd, - ) - - -def compute_session_cost( - request_json: dict[str, Any], response_with_usage: dict[str, Any] -) -> SessionCost | None: - """Wrap calculate_session_cost for controllers: returns the SessionCost - pydantic model, or ``None`` on failure. Predictable failures (unknown - price/model) are blocked by the pre-inference gate, so anything reaching - here is a provider-side error (e.g. missing usage). Logging as CRITICAL - matches __main__._session_cost_calculator's contract: when this returns - None, x402's downstream reader will skip settlement and the client is not - charged. - """ # Imported lazily to keep this module free of process-level state at import # time (price_feed singleton is registered during app startup, after this # module has already been imported transitively via other paths). from tee_gateway.price_feed import get_price_feed try: - return calculate_session_cost( - request_json=request_json, - response_json=response_with_usage, - asset_decimals=_OPG_DECIMALS, - get_price=get_price_feed().get_price, + cfg = get_model_config(model.strip().lower()) + in_tok = max(0, int(usage["prompt_tokens"])) + out_tok = max(0, int(usage["completion_tokens"])) + + raw_usd = (Decimal(in_tok) * cfg.input_price_usd) + ( + Decimal(out_tok) * cfg.output_price_usd + ) + token_price_usd = get_price_feed().get_price() + if token_price_usd <= 0: + raise ValueError(f"Token price is non-positive: {token_price_usd}") + + scale = Decimal(10) ** _OPG_DECIMALS + cost_smallest_units = max( + 0, + int( + ((raw_usd / token_price_usd) * scale).to_integral_value( + rounding=ROUND_CEILING + ) + ), + ) + # Reconcile USD from the rounded OPG integer so the two surfaced figures + # are exactly consistent (clients verify: usd == opg / 10^decimals * price). + settled_usd = (Decimal(cost_smallest_units) / scale) * token_price_usd + + logger.info( + "DYNAMIC_SESSION_COST model=%s input_tokens=%d output_tokens=%d " + "raw_usd=%s settled_usd=%s token_price_usd=%s decimals=%d cost=%d", + model, + in_tok, + out_tok, + str(raw_usd), + str(settled_usd), + str(token_price_usd), + _OPG_DECIMALS, + cost_smallest_units, + ) + return SessionCost( + cost_opg=cost_smallest_units, + cost_usd=settled_usd, + opg_price_usd=token_price_usd, ) except Exception as exc: logger.critical( diff --git a/tee_gateway/test/test_price_feed.py b/tee_gateway/test/test_price_feed.py index 970aac8..17af102 100644 --- a/tee_gateway/test/test_price_feed.py +++ b/tee_gateway/test/test_price_feed.py @@ -1,5 +1,5 @@ """ -Unit tests for tee_gateway.price_feed and tee_gateway.pricing.calculate_session_cost. +Unit tests for tee_gateway.price_feed and tee_gateway.pricing.compute_session_cost. All external HTTP calls are mocked — no network access required. @@ -9,7 +9,7 @@ TestOPGPriceFeedRefresh — OPGPriceFeed._refresh_price() (retry, rate-limit, stats) TestOPGPriceFeedGetPrice — OPGPriceFeed.get_price() (stale warning, ValueError before fetch) TestOPGPriceFeedStatus — OPGPriceFeed.get_status() snapshots -TestCalculateSessionCost — calculate_session_cost(context, get_price) in pricing.py +TestCalculateSessionCost — compute_session_cost(model, usage) in pricing.py """ import time @@ -23,7 +23,7 @@ from tee_gateway.definitions import BASE_MAINNET_OPG_ADDRESS from tee_gateway.price_feed import OPGPriceFeed from tee_gateway.price_feed.feed import fetch_opg_price -from tee_gateway.pricing import SessionCost, calculate_session_cost +from tee_gateway.pricing import SessionCost, compute_session_cost # --------------------------------------------------------------------------- # Helpers @@ -367,65 +367,33 @@ def test_status_accumulates_multiple_error_cycles(self, mock_fetch): # --------------------------------------------------------------------------- -# TestMakeCostCalculator +# TestCalculateSessionCost # --------------------------------------------------------------------------- -_ASSET_ADDR = "0xdeadbeef" -_ASSET_ADDR_LOWER = _ASSET_ADDR.lower() +# pricing.compute_session_cost reads OPG decimals from the asset registry — +# OPG is 18 decimals on Base mainnet. _ASSET_DECIMALS = 18 -def _make_payment_requirements(asset: str = _ASSET_ADDR) -> dict: - return {"asset": asset, "price": {"amount": "1000000000000000000", "asset": asset}} - - -def _make_context( - model: str = "gpt-4.1-mini", - input_tokens: int = 100, - output_tokens: int = 50, - price_usd: Decimal = Decimal("0.10"), - asset: str = _ASSET_ADDR, -) -> dict: - return { - "request_json": {"model": model}, - "response_json": { - "model": model, - "usage": { - "prompt_tokens": input_tokens, - "completion_tokens": output_tokens, - }, - }, - "payment_requirements": _make_payment_requirements(asset), - "method": "POST", - "path": "/v1/chat/completions", - "status_code": 200, - "is_streaming": False, - "request_body_bytes": b"", - "response_body_bytes": b"", - "default_cost": 10**18, - } +def _usage(input_tokens: int = 100, output_tokens: int = 50) -> dict: + return {"prompt_tokens": input_tokens, "completion_tokens": output_tokens} def _make_get_price(price_usd: Decimal = Decimal("0.10")) -> MagicMock: - mock = MagicMock(return_value=price_usd) - return mock + return MagicMock(return_value=price_usd) + +def _call(usage: dict, get_price, model: str = "gpt-4.1-mini"): + """Run compute_session_cost with a stubbed price feed.""" + from types import SimpleNamespace -def _call( - ctx: dict, - get_price, - asset_decimals: int = _ASSET_DECIMALS, -) -> SessionCost: - return calculate_session_cost( - request_json=ctx["request_json"], - response_json=ctx["response_json"], - asset_decimals=asset_decimals, - get_price=get_price, - ) + feed = SimpleNamespace(get_price=get_price) + with patch("tee_gateway.price_feed.get_price_feed", return_value=feed): + return compute_session_cost(model, usage) class TestCalculateSessionCost(unittest.TestCase): - """Tests for calculate_session_cost(request_json, response_json, asset_decimals, get_price).""" + """Tests for compute_session_cost(model, usage).""" def _patch_model( self, input_price: str = "0.000001", output_price: str = "0.000002" @@ -438,12 +406,12 @@ def _patch_model( def test_calls_get_price(self): get_price = _make_get_price() with self._patch_model(): - _call(_make_context(), get_price) + _call(_usage(), get_price) get_price.assert_called_once() def test_returns_session_cost(self): with self._patch_model(): - result = _call(_make_context(), _make_get_price()) + result = _call(_usage(), _make_get_price()) self.assertIsInstance(result, SessionCost) self.assertIsInstance(result.cost_opg, int) self.assertGreaterEqual(result.cost_opg, 0) @@ -453,69 +421,48 @@ def test_reported_usd_reconciles_with_opg(self): """cost_usd must equal cost_opg / 10^decimals * price exactly — otherwise clients verifying the conversion will reject the response.""" with self._patch_model(): - result = _call(_make_context(), _make_get_price(Decimal("0.10"))) + result = _call(_usage(), _make_get_price(Decimal("0.10"))) scale = Decimal(10) ** _ASSET_DECIMALS expected = (Decimal(result.cost_opg) / scale) * Decimal("0.10") self.assertEqual(result.cost_usd, expected) def test_zero_tokens_returns_zero(self): with self._patch_model(): - result = _call( - _make_context(input_tokens=0, output_tokens=0), _make_get_price() - ) + result = _call(_usage(0, 0), _make_get_price()) self.assertEqual(result.cost_opg, 0) self.assertEqual(result.cost_usd, Decimal(0)) - def test_raises_when_get_price_raises(self): + def test_returns_none_when_get_price_raises(self): get_price = MagicMock(side_effect=ValueError("price not available")) with self._patch_model(): - with self.assertRaises(ValueError): - _call(_make_context(), get_price) + self.assertIsNone(_call(_usage(), get_price)) - def test_raises_when_non_positive_price(self): + def test_returns_none_when_non_positive_price(self): with self._patch_model(): - with self.assertRaises(ValueError): - _call(_make_context(), _make_get_price(Decimal("0"))) + self.assertIsNone(_call(_usage(), _make_get_price(Decimal("0")))) - def test_raises_when_request_json_missing(self): - ctx = _make_context() - ctx["request_json"] = None + def test_returns_none_when_usage_missing_keys(self): with self._patch_model(): - with self.assertRaises(ValueError): - _call(ctx, _make_get_price()) - - def test_raises_when_usage_missing(self): - ctx = _make_context() - ctx["response_json"] = {"model": "gpt-4.1-mini"} - with self._patch_model(): - with self.assertRaises(ValueError): - _call(ctx, _make_get_price()) + self.assertIsNone(_call({"prompt_tokens": 100}, _make_get_price())) def test_cost_scales_with_token_count(self): with self._patch_model(): - cost_small = _call( - _make_context(input_tokens=10, output_tokens=5), _make_get_price() - ) - cost_large = _call( - _make_context(input_tokens=1000, output_tokens=500), _make_get_price() - ) + cost_small = _call(_usage(10, 5), _make_get_price()) + cost_large = _call(_usage(1000, 500), _make_get_price()) self.assertGreater(cost_large.cost_opg, cost_small.cost_opg) def test_higher_token_price_yields_lower_cost(self): with self._patch_model(): - cost_cheap = _call(_make_context(), _make_get_price(Decimal("0.10"))) - cost_expensive = _call(_make_context(), _make_get_price(Decimal("0.20"))) - self.assertGreater( - cost_cheap.cost_opg, - cost_expensive.cost_opg, - ) + cost_cheap = _call(_usage(), _make_get_price(Decimal("0.10"))) + cost_expensive = _call(_usage(), _make_get_price(Decimal("0.20"))) + self.assertGreater(cost_cheap.cost_opg, cost_expensive.cost_opg) def test_uses_current_price_on_each_call(self): """get_price is called fresh every invocation — price changes are picked up.""" get_price = MagicMock(side_effect=[Decimal("0.10"), Decimal("0.20")]) with self._patch_model(): - cost_first = _call(_make_context(), get_price) - cost_second = _call(_make_context(), get_price) + cost_first = _call(_usage(), get_price) + cost_second = _call(_usage(), get_price) self.assertEqual(get_price.call_count, 2) # Price doubled → cost should halve (same USD spend, twice the token price). self.assertGreater(cost_first.cost_opg, cost_second.cost_opg) diff --git a/tests/test_pricing.py b/tests/test_pricing.py index 51d030e..9df1dbc 100644 --- a/tests/test_pricing.py +++ b/tests/test_pricing.py @@ -3,75 +3,40 @@ Tests verify that: - Every user-facing model name resolves to the correct ModelConfig - - calculate_session_cost produces the right amount in OPG token + - compute_session_cost produces the right amount in OPG token smallest-units for supported models - - Edge cases (no usage, unknown model, bad context) are handled correctly + - Edge cases (no usage, unknown model) are handled correctly """ import unittest from decimal import Decimal +from types import SimpleNamespace +from unittest.mock import patch -from tee_gateway.definitions import BASE_MAINNET_OPG_ADDRESS from tee_gateway.model_registry import ( _MODEL_LOOKUP, get_model_config, ) -from tee_gateway.pricing import ( - _extract_asset_decimals_from_requirements, - calculate_session_cost as _calculate_session_cost_raw, -) - - -def calculate_session_cost(ctx, get_price): - """Adapter: legacy bundled-ctx call form -> new (request, response, - asset_decimals, get_price) signature, returning the OPG integer. - - request_json may be None (validated by the underlying function); guard so the - error path is preserved. - """ - request_json = ctx.get("request_json") - if not isinstance(request_json, dict): - # Match the underlying ValueError so .assertRaises(ValueError) still fires. - raise ValueError("request_json missing or not a dict") - return _calculate_session_cost_raw( - request_json=request_json, - response_json=ctx["response_json"], - asset_decimals=_extract_asset_decimals_from_requirements( - ctx["payment_requirements"] - ), - get_price=get_price, - ).cost_opg +from tee_gateway.pricing import compute_session_cost # All pricing tests assume OPG = $1.00 so USD cost == OPG token amount. _OPG_PRICE_USD = Decimal("1") -_get_price = lambda: _OPG_PRICE_USD # noqa: E731 -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - -def _opg_requirements() -> dict: - """Fake PaymentRequirements dict for OPG (18 decimals).""" - return {"asset": BASE_MAINNET_OPG_ADDRESS, "amount": "50000000000000000"} - - -def _ctx(model: str, input_tokens: int, output_tokens: int, requirements=None) -> dict: - """Build a minimal calculator context.""" - return { - "request_json": {"model": model, "messages": []}, - "response_json": { - "model": model, - "usage": { - "prompt_tokens": input_tokens, - "completion_tokens": output_tokens, - "total_tokens": input_tokens + output_tokens, - }, - }, - "payment_requirements": requirements or _opg_requirements(), +def _calc_opg(model: str, input_tokens: int, output_tokens: int) -> int: + """Call compute_session_cost with the test price feed and return the OPG + integer. Returns -1 when the function returns None so tests can assert on + failure paths without raising.""" + usage = { + "prompt_tokens": input_tokens, + "completion_tokens": output_tokens, + "total_tokens": input_tokens + output_tokens, } + fake_feed = SimpleNamespace(get_price=lambda: _OPG_PRICE_USD) + with patch("tee_gateway.price_feed.get_price_feed", return_value=fake_feed): + result = compute_session_cost(model, usage) + return -1 if result is None else result.cost_opg def _expected_cost_opg(model: str, input_tokens: int, output_tokens: int) -> int: @@ -351,12 +316,10 @@ def test_unknown_sonnet_variant_raises(self): class TestCalculateSessionCostOPG(unittest.TestCase): - """calculate_session_cost with OPG (18 decimals).""" + """compute_session_cost with OPG (18 decimals).""" def _calc(self, model, input_tokens, output_tokens): - return calculate_session_cost( - _ctx(model, input_tokens, output_tokens, _opg_requirements()), _get_price - ) + return _calc_opg(model, input_tokens, output_tokens) # ── OpenAI ────────────────────────────────────────────────────────────── @@ -603,84 +566,28 @@ def test_grok_4_fast_cheaper_than_grok_4(self): class TestCalculateSessionCostEdgeCases(unittest.TestCase): - """Edge cases for calculate_session_cost.""" + """Edge cases for compute_session_cost.""" def test_zero_tokens_returns_zero(self): - cost = calculate_session_cost(_ctx("claude-sonnet-4-5", 0, 0), _get_price) - self.assertEqual(cost, 0) - - def test_missing_usage_raises(self): - ctx = { - "request_json": {"model": "claude-sonnet-4-5"}, - "response_json": {"model": "claude-sonnet-4-5"}, # no usage - "payment_requirements": _opg_requirements(), - } - with self.assertRaises(ValueError): - calculate_session_cost(ctx, _get_price) + self.assertEqual(_calc_opg("claude-sonnet-4-5", 0, 0), 0) - def test_unknown_asset_raises(self): - ctx = _ctx("claude-sonnet-4-5", 100, 100) - ctx["payment_requirements"] = {"asset": "0xdeadbeef", "amount": "1000"} - with self.assertRaises(ValueError): - calculate_session_cost(ctx, _get_price) - - def test_missing_asset_raises(self): - ctx = _ctx("claude-sonnet-4-5", 100, 100) - ctx["payment_requirements"] = {"amount": "1000"} # no asset - with self.assertRaises(ValueError): - calculate_session_cost(ctx, _get_price) - - def test_unknown_model_raises_value_error(self): - ctx = _ctx("gpt-4o", 100, 100) - with self.assertRaises(ValueError): - calculate_session_cost(ctx, _get_price) - - def test_missing_request_json_raises_value_error(self): - ctx = { - "request_json": None, - "response_json": { - "model": "claude-sonnet-4-5", - "usage": {"prompt_tokens": 100, "completion_tokens": 100}, - }, - "payment_requirements": _opg_requirements(), - } - with self.assertRaises(ValueError): - calculate_session_cost(ctx, _get_price) - - def test_model_from_request_takes_priority(self): - """request_json model name is used even if response_json has a different model.""" - ctx = { - "request_json": {"model": "claude-haiku-4-5"}, - "response_json": { - "model": "claude-sonnet-4-5", # response says Sonnet - "usage": {"prompt_tokens": 1000, "completion_tokens": 500}, - }, - "payment_requirements": _opg_requirements(), - } - cost = calculate_session_cost(ctx, _get_price) - # Should be priced as Haiku (from request), not Sonnet - haiku_cost = _expected_cost_opg("claude-haiku-4-5", 1000, 500) - self.assertEqual(cost, haiku_cost) + def test_unknown_model_returns_none(self): + # gpt-4o is not in the registry — get_model_config raises, caught and + # returned as None. + self.assertEqual(_calc_opg("gpt-4o", 100, 100), -1) def test_rounding_ceiling(self): """Fractional token costs are always rounded UP.""" - # 1 output token of Haiku: 0.000005 USD = 5e12 wei — exact, no rounding needed - cost = calculate_session_cost(_ctx("claude-haiku-4-5", 0, 1), _get_price) - self.assertEqual(cost, 5_000_000_000_000) - + # 1 output token of Haiku: 0.000005 USD = 5e12 wei — exact + self.assertEqual(_calc_opg("claude-haiku-4-5", 0, 1), 5_000_000_000_000) # 1 input token of Gemini Flash Lite: 0.0000001 USD = 1e11 wei — exact - cost = calculate_session_cost(_ctx("gemini-2.5-flash-lite", 1, 0), _get_price) - self.assertEqual(cost, 100_000_000_000) + self.assertEqual(_calc_opg("gemini-2.5-flash-lite", 1, 0), 100_000_000_000) def test_model_name_case_insensitive(self): - """Model names are normalized to lowercase before lookup.""" - cost_lower = calculate_session_cost( - _ctx("claude-sonnet-4-5", 100, 100), _get_price - ) - cost_upper = calculate_session_cost( - _ctx("CLAUDE-SONNET-4-5", 100, 100), _get_price + self.assertEqual( + _calc_opg("claude-sonnet-4-5", 100, 100), + _calc_opg("CLAUDE-SONNET-4-5", 100, 100), ) - self.assertEqual(cost_lower, cost_upper) if __name__ == "__main__": From 21306507fcbeb6db90878fe998c8454dfd2577d1 Mon Sep 17 00:00:00 2001 From: "balogh.adam@icloud.com" Date: Sat, 16 May 2026 13:54:58 -0400 Subject: [PATCH 32/39] cost --- tee_gateway/__main__.py | 13 +- tests/test_opengradient_field.py | 229 +++++++++++++++++++++++++++++++ 2 files changed, 241 insertions(+), 1 deletion(-) create mode 100644 tests/test_opengradient_field.py diff --git a/tee_gateway/__main__.py b/tee_gateway/__main__.py index 3edc0cf..d7237c4 100644 --- a/tee_gateway/__main__.py +++ b/tee_gateway/__main__.py @@ -176,7 +176,18 @@ def _session_cost_calculator(ctx: dict) -> int: response_json = ctx.get("response_json") if not isinstance(response_json, dict): raise ValueError("response_json missing or not a dict") - return SessionCost.model_validate(response_json.get("opengradient")).cost_opg + cost_block = response_json.get("opengradient") + if cost_block is None: + # Should never happen on a successful paid response — the controller + # always embeds this when it can compute it, and when it can't, the + # request typically errored out before reaching x402's close hook. + # If we see this, settlement silently skips and a real bug is hiding. + logger.critical( + "opengradient cost block missing on paid response — client will " + "NOT be charged. response_id=%s", + response_json.get("id"), + ) + return SessionCost.model_validate(cost_block).cost_opg # --------------------------------------------------------------------------- diff --git a/tests/test_opengradient_field.py b/tests/test_opengradient_field.py new file mode 100644 index 0000000..257700f --- /dev/null +++ b/tests/test_opengradient_field.py @@ -0,0 +1,229 @@ +"""Verify the `opengradient` cost block is embedded on responses. + +These tests are the only thing keeping `compute_session_cost`'s result from +silently going missing on a controller response — if that block is absent, +x402's `_session_cost_calculator` swallows the error and the client is never +charged. The runtime CRITICAL log is the safety net; this is the unit-test +catch. +""" + +import json +import unittest +from decimal import Decimal +from unittest.mock import MagicMock, patch + +from tee_gateway.models.create_chat_completion_request import ( + CreateChatCompletionRequest, +) +from tee_gateway.models import ChatCompletionRequestUserMessage +from tee_gateway.models.create_completion_request import CreateCompletionRequest +from tee_gateway.pricing import SessionCost + + +_FAKE_USAGE = {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30} + + +def _fake_cost() -> SessionCost: + return SessionCost( + cost_opg=12345, + cost_usd=Decimal("0.001"), + opg_price_usd=Decimal("0.5"), + ) + + +def _fake_tee_keys() -> MagicMock: + keys = MagicMock() + keys.sign_data.return_value = "0xsig" + keys.get_tee_id.return_value = "deadbeef" + return keys + + +def _chat_request() -> CreateChatCompletionRequest: + return CreateChatCompletionRequest( + model="gpt-4.1-mini", + messages=[ChatCompletionRequestUserMessage(role="user", content="hi")], + stream=False, + ) + + +class TestChatNonStreamingOpengradient(unittest.TestCase): + def test_opengradient_block_embedded_when_cost_computed(self): + from tee_gateway.controllers import chat_controller + + fake_response = MagicMock() + fake_response.content = "hello" + fake_response.tool_calls = None + fake_model = MagicMock() + fake_model.invoke.return_value = fake_response + + with ( + patch.object( + chat_controller, "get_chat_model_cached", return_value=fake_model + ), + patch.object(chat_controller, "extract_usage", return_value=_FAKE_USAGE), + patch.object( + chat_controller, "compute_session_cost", return_value=_fake_cost() + ), + patch.object( + chat_controller, "get_tee_keys", return_value=_fake_tee_keys() + ), + patch.object( + chat_controller, + "compute_tee_msg_hash", + return_value=(b"h", "ih", "oh"), + ), + ): + resp = chat_controller._create_non_streaming_response(_chat_request()) + + self.assertIsInstance(resp, dict) + self.assertIn("opengradient", resp) + self.assertEqual(resp["opengradient"].cost_opg, 12345) + + def test_opengradient_block_absent_when_compute_returns_none(self): + from tee_gateway.controllers import chat_controller + + fake_response = MagicMock() + fake_response.content = "hello" + fake_response.tool_calls = None + fake_model = MagicMock() + fake_model.invoke.return_value = fake_response + + with ( + patch.object( + chat_controller, "get_chat_model_cached", return_value=fake_model + ), + patch.object(chat_controller, "extract_usage", return_value=_FAKE_USAGE), + patch.object(chat_controller, "compute_session_cost", return_value=None), + patch.object( + chat_controller, "get_tee_keys", return_value=_fake_tee_keys() + ), + patch.object( + chat_controller, + "compute_tee_msg_hash", + return_value=(b"h", "ih", "oh"), + ), + ): + resp = chat_controller._create_non_streaming_response(_chat_request()) + + self.assertNotIn("opengradient", resp) + + +class TestChatStreamingOpengradient(unittest.TestCase): + def test_final_sse_event_carries_opengradient(self): + from tee_gateway.controllers import chat_controller + + chunk = MagicMock() + chunk.content = "hello" + chunk.tool_call_chunks = [] + chunk.usage_metadata = { + "input_tokens": 10, + "output_tokens": 20, + "total_tokens": 30, + } + fake_model = MagicMock() + fake_model.stream.return_value = iter([chunk]) + + with ( + patch.object( + chat_controller, "get_chat_model_cached", return_value=fake_model + ), + patch.object( + chat_controller, "get_provider_from_model", return_value="openai" + ), + patch.object( + chat_controller, "compute_session_cost", return_value=_fake_cost() + ), + patch.object( + chat_controller, "get_tee_keys", return_value=_fake_tee_keys() + ), + patch.object( + chat_controller, + "compute_tee_msg_hash", + return_value=(b"h", "ih", "oh"), + ), + ): + req = _chat_request() + req.stream = True + response = chat_controller._create_streaming_response(req) + chunks = [ + c.decode() if isinstance(c, bytes) else c for c in response.response + ] + + # The final SSE data event before [DONE] carries the opengradient block. + data_events = [ + c for c in chunks if c.startswith("data: ") and "[DONE]" not in c + ] + final = json.loads(data_events[-1][len("data: ") :].strip()) + self.assertIn("opengradient", final) + self.assertEqual(final["opengradient"]["cost_opg"], "12345") + + +class TestCompletionsOpengradient(unittest.TestCase): + def test_opengradient_block_embedded_on_completion(self): + from tee_gateway.controllers import completions_controller + + fake_response = MagicMock() + fake_response.content = "world" + fake_model = MagicMock() + fake_model.invoke.return_value = fake_response + + fake_request = MagicMock() + fake_request.is_json = True + fake_request.get_json.return_value = { + "model": "gpt-4.1-mini", + "prompt": "hi", + } + + body = CreateCompletionRequest(model="gpt-4.1-mini", prompt="hi") + + with ( + patch("connexion.request", fake_request), + patch.object( + completions_controller, + "get_chat_model_cached", + return_value=fake_model, + ), + patch.object( + completions_controller, "extract_usage", return_value=_FAKE_USAGE + ), + patch.object( + completions_controller, + "compute_session_cost", + return_value=_fake_cost(), + ), + patch.object( + completions_controller, + "get_tee_keys", + return_value=_fake_tee_keys(), + ), + patch.object( + completions_controller, + "compute_tee_msg_hash", + return_value=(b"h", "ih", "oh"), + ), + ): + resp = completions_controller.create_completion(body) + + self.assertIsInstance(resp, dict) + self.assertIn("opengradient", resp) + self.assertEqual(resp["opengradient"].cost_opg, 12345) + + +class TestSessionCostCalculatorMissingBlock(unittest.TestCase): + def test_critical_log_when_opengradient_missing(self): + from tee_gateway import __main__ as gateway_main + + with self.assertLogs(gateway_main.logger, level="CRITICAL") as cm: + with self.assertRaises(Exception): + gateway_main._session_cost_calculator( + {"response_json": {"id": "chatcmpl-x"}} + ) + + self.assertTrue( + any("opengradient cost block missing" in msg for msg in cm.output), + f"expected CRITICAL about missing opengradient, got: {cm.output}", + ) + + +if __name__ == "__main__": + unittest.main() From 5f7942e29a6ea8d8fb7b135290cd959f01acfe28 Mon Sep 17 00:00:00 2001 From: kukac Date: Sat, 16 May 2026 13:55:54 -0400 Subject: [PATCH 33/39] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- tee_gateway/controllers/ohttp_controller.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tee_gateway/controllers/ohttp_controller.py b/tee_gateway/controllers/ohttp_controller.py index 8bc0c70..62b82bf 100644 --- a/tee_gateway/controllers/ohttp_controller.py +++ b/tee_gateway/controllers/ohttp_controller.py @@ -195,9 +195,11 @@ def _build_outer_response( inner_content_type: str, ) -> Response: """Single-shot OHTTP response. Seals the body on 2xx (contains user - prompts/completions) and surfaces token usage as outer headers so the - relay can bill. Non-2xx bodies (x402 payment requirements, validation - errors) are forwarded as plaintext so the relay can act on them.""" + prompts/completions) and surfaces only extracted cost/price headers as + outer headers for relay billing (including `X-Inference-Price-OPG-USD`); + usage details remain sealed in the encapsulated body. Non-2xx bodies + (x402 payment requirements, validation errors) are forwarded as + plaintext so the relay can act on them.""" # Keep headers as a list of (name, value) tuples — WSGI gives us a list # specifically because HTTP allows multi-valued headers (RFC 7230 §3.2.2; # WWW-Authenticate in particular per RFC 7235 §4.1 can repeat, one per From 625354c8a15c15cdf97f7a9b8cd9cefb4608a77b Mon Sep 17 00:00:00 2001 From: "balogh.adam@icloud.com" Date: Sat, 16 May 2026 14:00:32 -0400 Subject: [PATCH 34/39] controller test --- tee_gateway/test/test_ohttp_controller.py | 520 ++++++++++++++++++++++ 1 file changed, 520 insertions(+) create mode 100644 tee_gateway/test/test_ohttp_controller.py diff --git a/tee_gateway/test/test_ohttp_controller.py b/tee_gateway/test/test_ohttp_controller.py new file mode 100644 index 0000000..b7b4d7d --- /dev/null +++ b/tee_gateway/test/test_ohttp_controller.py @@ -0,0 +1,520 @@ +"""Unit tests for the /v1/ohttp handler in +``tee_gateway.controllers.ohttp_controller``. + +The controller is a relatively thin shell that sits between HPKE +decapsulation and an in-process WSGI sub-request to /v1/chat/completions. +These tests pin down the shell's behaviour without bringing up the real +chat backend by stubbing ``get_tee_keys`` (to plant a known HPKE keypair) +and ``current_app.wsgi_app`` (to fake the inner /v1/chat/completions +response). + +What we verify: + * 413 on an oversized encapsulated body + * 400 on a malformed encapsulation + * 400 when the decrypted inner payload is not valid JSON + * Non-2xx inner responses are forwarded as plaintext (no HPKE seal), + with the inner content-type and the x402-related headers preserved + * 2xx inner responses are sealed and the ``opengradient`` cost block is + projected onto outer X-Inference-Cost-* headers + * Streaming inner responses emit exactly one AAD=b"final" chunk and + close the inner WSGI iterator (settles x402) +""" + +from __future__ import annotations + +import json +import struct +from decimal import Decimal +from typing import Iterator + +import pytest +from flask import Flask + +from tee_gateway import ohttp +from tee_gateway.controllers import ohttp_controller + + +# --------------------------------------------------------------------------- +# helpers + + +def _encapsulate(plaintext: bytes): + """Build a real HPKE-encapsulated request and return ``(sk, wire, sender, enc)``. + + The sender context is returned so individual tests that need to decrypt + the outer response can re-derive the response key the same way an SDK + client would. + """ + sk, pk_raw = ohttp.generate_keypair() + hdr = bytes([ohttp.KEY_CONFIG_ID]) + struct.pack( + ">HHH", + ohttp.KEM_ID_X25519, + ohttp.KDF_ID_HKDF_SHA256, + ohttp.AEAD_ID_CHACHA20_POLY1305, + ) + info = b"message/bhttp request" + b"\x00" + hdr + pkr = ohttp._SUITE.kem.deserialize_public_key(pk_raw) + enc, sender = ohttp._SUITE.create_sender_context(pkr, info=info) + wire = hdr + enc + sender.seal(plaintext, aad=b"") + return sk, wire, sender, enc + + +class _FakeTee: + def __init__(self, sk): + self.hpke_private_key = sk + + +class _CloseTrackingIter: + """Iterable that records when ``close()`` is called. + + Mirrors how the real WSGI iterator from the chat handler behaves — + closing it is what triggers x402's post-response settlement, so we + need the controller to call it on both the streaming and non-streaming + paths. + """ + + def __init__(self, chunks): + self._chunks = list(chunks) + self.closed = False + + def __iter__(self) -> Iterator[bytes]: + for c in self._chunks: + yield c + + def close(self): + self.closed = True + + +def _make_app(inner_responder, sk): + """Build a tiny Flask app with the controller mounted at /v1/ohttp. + + ``inner_responder`` is invoked when the controller does its WSGI + sub-dispatch to ``/v1/chat/completions``; it must return + ``(status_line, headers, body_iter)``. Anything else passes through + to the real Flask routing so the outer POST still lands on our + handler. + """ + app = Flask(__name__) + app.add_url_rule( + "/v1/ohttp", + view_func=ohttp_controller.create_anonymous_chat_completion, + methods=["POST"], + ) + + original_wsgi = app.wsgi_app + captured = {"env": None, "called": 0, "iter": None} + + def fake_wsgi(env, start_response): + if env.get("PATH_INFO") == "/v1/chat/completions": + captured["env"] = env + captured["called"] += 1 + status, headers, body_iter = inner_responder() + captured["iter"] = body_iter + start_response(status, headers) + return body_iter + return original_wsgi(env, start_response) + + app.wsgi_app = fake_wsgi + + def fake_get_tee_keys(): + return _FakeTee(sk) + + return app, captured, fake_get_tee_keys + + +# --------------------------------------------------------------------------- +# 413 / 400 cases — no decap happens, so we don't even need a key + + +def test_oversized_body_returns_413(monkeypatch): + sk, _ = ohttp.generate_keypair() + # Body is well past _MAX_ENCAPSULATED_REQUEST_BYTES (512 KiB). Werkzeug + # will set Content-Length from the data length, so the up-front check + # fires before any HPKE work. + app, captured, fake_keys = _make_app(lambda: ("200 OK", [], iter([])), sk) + monkeypatch.setattr(ohttp_controller, "get_tee_keys", fake_keys) + + too_big = b"x" * (ohttp_controller._MAX_ENCAPSULATED_REQUEST_BYTES + 1) + client = app.test_client() + resp = client.post("/v1/ohttp", data=too_big, content_type="message/ohttp-req") + + assert resp.status_code == 413 + assert captured["called"] == 0 + + +def test_empty_body_returns_400(monkeypatch): + sk, _ = ohttp.generate_keypair() + app, captured, fake_keys = _make_app(lambda: ("200 OK", [], iter([])), sk) + monkeypatch.setattr(ohttp_controller, "get_tee_keys", fake_keys) + + client = app.test_client() + resp = client.post("/v1/ohttp", data=b"", content_type="message/ohttp-req") + assert resp.status_code == 400 + assert captured["called"] == 0 + + +def test_malformed_encapsulation_returns_400(monkeypatch): + sk, _ = ohttp.generate_keypair() + app, captured, fake_keys = _make_app(lambda: ("200 OK", [], iter([])), sk) + monkeypatch.setattr(ohttp_controller, "get_tee_keys", fake_keys) + + # Random short payload — passes the size gate but ohttp.decapsulate_request + # will reject it as malformed. The controller MUST normalise that into a + # generic 400 so it doesn't expose an oracle on which decap step failed. + client = app.test_client() + resp = client.post( + "/v1/ohttp", data=b"\x00" * 64, content_type="message/ohttp-req" + ) + assert resp.status_code == 400 + assert captured["called"] == 0 + + +def test_invalid_inner_json_returns_400(monkeypatch): + """Plaintext decapsulates fine but isn't valid JSON — controller must + surface that as a 400 rather than fall through to the chat handler.""" + sk, wire, _, _ = _encapsulate(b"not-json{{{") + app, captured, fake_keys = _make_app(lambda: ("200 OK", [], iter([])), sk) + monkeypatch.setattr(ohttp_controller, "get_tee_keys", fake_keys) + + client = app.test_client() + resp = client.post("/v1/ohttp", data=wire, content_type="message/ohttp-req") + assert resp.status_code == 400 + assert captured["called"] == 0 + + +def test_non_object_inner_payload_returns_400(monkeypatch): + sk, wire, _, _ = _encapsulate(b"[1, 2, 3]") + app, captured, fake_keys = _make_app(lambda: ("200 OK", [], iter([])), sk) + monkeypatch.setattr(ohttp_controller, "get_tee_keys", fake_keys) + + client = app.test_client() + resp = client.post("/v1/ohttp", data=wire, content_type="message/ohttp-req") + assert resp.status_code == 400 + assert captured["called"] == 0 + + +# --------------------------------------------------------------------------- +# Non-2xx inner: plaintext passthrough, forwarded headers + + +def test_non_2xx_inner_response_is_plaintext_passthrough(monkeypatch): + """An x402 402 (or any non-2xx) from the chat handler must NOT be + HPKE-sealed — the relay needs to read the payment challenge and act + on it. The inner content-type and x402 control headers must survive.""" + sk, wire, _, _ = _encapsulate(b'{"model":"gpt-4.1","messages":[]}') + + inner_body = b'{"error":"payment required"}' + + def inner(): + return ( + "402 Payment Required", + [ + ("Content-Type", "application/json"), + ("WWW-Authenticate", "x402 ..."), + ("X-Payment-Required", "true"), + ("X-Tee-Signature", "sig"), + ("Set-Cookie", "tracker=abc"), # must NOT be forwarded + ], + _CloseTrackingIter([inner_body]), + ) + + app, captured, fake_keys = _make_app(inner, sk) + monkeypatch.setattr(ohttp_controller, "get_tee_keys", fake_keys) + + client = app.test_client() + resp = client.post( + "/v1/ohttp", + data=wire, + content_type="message/ohttp-req", + headers={"X-Payment": "client-payment-blob"}, + ) + + assert resp.status_code == 402 + # Body must be plaintext, NOT message/ohttp-res. + assert resp.content_type.startswith("application/json") + assert resp.data == inner_body + # x402 / WWW-Authenticate forwarded; arbitrary headers not. + assert resp.headers.get("WWW-Authenticate") == "x402 ..." + assert resp.headers.get("X-Payment-Required") == "true" + assert resp.headers.get("X-Tee-Signature") == "sig" + assert "Set-Cookie" not in resp.headers + # The relay's X-Payment was forwarded into the inner env so the + # x402 middleware can verify it. + assert captured["env"]["HTTP_X_PAYMENT"] == "client-payment-blob" + # WSGI iterator was drained AND closed (drains x402 settlement). + assert captured["iter"].closed is True + + +# --------------------------------------------------------------------------- +# 2xx inner: response sealed, cost headers surfaced + + +def _client_decrypt_response(sealed: bytes, sender, enc: bytes) -> bytes: + """Recover the plaintext from a single-shot OHTTP response, the same + way an external client SDK would.""" + from cryptography.hazmat.primitives import hashes, hmac + from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305 + from cryptography.hazmat.primitives.kdf.hkdf import HKDFExpand + + response_secret = sender.export(b"message/bhttp response", 32) + response_nonce = sealed[:32] + aead_ct = sealed[32:] + salt = enc + response_nonce + h = hmac.HMAC(salt, hashes.SHA256()) + h.update(response_secret) + prk = h.finalize() + key = HKDFExpand(algorithm=hashes.SHA256(), length=32, info=b"key").derive(prk) + nonce = HKDFExpand(algorithm=hashes.SHA256(), length=12, info=b"nonce").derive(prk) + return ChaCha20Poly1305(key).decrypt(nonce, aead_ct, b"") + + +def test_2xx_inner_response_is_sealed_with_cost_headers(monkeypatch): + sk, wire, sender, enc = _encapsulate(b'{"model":"gpt-4.1","messages":[]}') + + inner_payload = { + "id": "chatcmpl-xyz", + "choices": [{"message": {"role": "assistant", "content": "hi"}}], + "opengradient": { + "cost_opg": 1234567890, + "cost_usd": "0.001234", + "opg_price_usd": "1.234", + }, + } + inner_body = json.dumps(inner_payload).encode() + + def inner(): + return ( + "200 OK", + [ + ("Content-Type", "application/json"), + ("X-Tee-Signature", "sig"), + ("X-Payment-Response", "settled"), + ], + _CloseTrackingIter([inner_body]), + ) + + app, captured, fake_keys = _make_app(inner, sk) + monkeypatch.setattr(ohttp_controller, "get_tee_keys", fake_keys) + + client = app.test_client() + resp = client.post("/v1/ohttp", data=wire, content_type="message/ohttp-req") + + assert resp.status_code == 200 + assert resp.content_type.startswith(ohttp_controller.OHTTP_RESPONSE_MEDIA_TYPE) + # Cost is projected onto outer headers — relay-billable, no model name + # and no token counts leak. + assert resp.headers["X-Inference-Cost-OPG"] == "1234567890" + assert Decimal(resp.headers["X-Inference-Cost-USD"]) == Decimal("0.001234") + assert Decimal(resp.headers["X-Inference-Price-OPG-USD"]) == Decimal("1.234") + # x402 / TEE headers still pass through. + assert resp.headers.get("X-Tee-Signature") == "sig" + assert resp.headers.get("X-Payment-Response") == "settled" + + decrypted = _client_decrypt_response(resp.data, sender, enc) + assert json.loads(decrypted) == inner_payload + assert captured["iter"].closed is True + + +def test_2xx_inner_without_cost_block_omits_cost_headers(monkeypatch): + """Missing or unparseable opengradient block must not 500 — we just + skip the cost projection. Belt-and-braces guard around _extract_cost_headers.""" + sk, wire, sender, enc = _encapsulate(b'{"model":"gpt-4.1","messages":[]}') + + inner_body = b'{"id":"chatcmpl-xyz","choices":[]}' + + def inner(): + return ( + "200 OK", + [("Content-Type", "application/json")], + _CloseTrackingIter([inner_body]), + ) + + app, _, fake_keys = _make_app(inner, sk) + monkeypatch.setattr(ohttp_controller, "get_tee_keys", fake_keys) + + client = app.test_client() + resp = client.post("/v1/ohttp", data=wire, content_type="message/ohttp-req") + + assert resp.status_code == 200 + assert "X-Inference-Cost-OPG" not in resp.headers + assert "X-Inference-Cost-USD" not in resp.headers + + +# --------------------------------------------------------------------------- +# Streaming path + + +def test_streaming_inner_response_emits_one_final_chunk_and_closes(monkeypatch): + sk, wire, sender, enc = _encapsulate(b'{"model":"gpt-4.1","stream":true}') + + sse_chunks = [ + b"data: {\"a\":1}\n\n", + b"data: {\"b\":2}\n\n", + b"data: [DONE]\n\n", + ] + + def inner(): + return ( + "200 OK", + [("Content-Type", "text/event-stream")], + _CloseTrackingIter(sse_chunks), + ) + + app, captured, fake_keys = _make_app(inner, sk) + monkeypatch.setattr(ohttp_controller, "get_tee_keys", fake_keys) + + client = app.test_client() + resp = client.post("/v1/ohttp", data=wire, content_type="message/ohttp-req") + + assert resp.status_code == 200 + assert resp.content_type.startswith( + ohttp_controller.OHTTP_CHUNKED_RESPONSE_MEDIA_TYPE + ) + + # Decode the chunked-OHTTP wire frame the same way a client SDK would. + body = resp.get_data() + response_nonce = body[:32] + off = 32 + + response_secret = sender.export(b"message/bhttp chunked response", 32) + aead_key, aead_nonce = ohttp._derive_response_keys( + response_secret, enc, response_nonce + ) + from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305 + + aead = ChaCha20Poly1305(aead_key) + + recovered = [] + final_count = 0 + counter = 0 + while off < len(body): + length, off = ohttp.decode_varint(body, off) + is_final = length == 0 + seg_len = len(body) - off if is_final else length + ct = body[off : off + seg_len] + off += seg_len + chunk_nonce = bytes( + a ^ b for a, b in zip(aead_nonce, counter.to_bytes(12, "big")) + ) + aad = b"final" if is_final else b"" + recovered.append(aead.decrypt(chunk_nonce, ct, aad)) + counter += 1 + if is_final: + final_count += 1 + break + + # Every SSE event must be recovered, and the last chunk MUST be the + # AAD=b"final" terminator — that's what protects clients from + # undetected truncation. There must be exactly one. + assert recovered == sse_chunks + assert final_count == 1 + # And the inner iterator must have been closed so x402 streaming + # settlement runs. + assert captured["iter"].closed is True + + +def test_streaming_path_uses_chunked_when_only_one_inner_chunk(monkeypatch): + """Even with a single SSE event, the controller must still emit + exactly one AAD=b"final" chunk (the pending/lookahead logic in + _build_streaming_response is the part being pinned down).""" + sk, wire, sender, enc = _encapsulate(b'{"model":"gpt-4.1","stream":true}') + + only = b"data: [DONE]\n\n" + + def inner(): + return ( + "200 OK", + [("Content-Type", "text/event-stream")], + _CloseTrackingIter([only]), + ) + + app, captured, fake_keys = _make_app(inner, sk) + monkeypatch.setattr(ohttp_controller, "get_tee_keys", fake_keys) + + client = app.test_client() + resp = client.post("/v1/ohttp", data=wire, content_type="message/ohttp-req") + + body = resp.get_data() + response_nonce = body[:32] + off = 32 + response_secret = sender.export(b"message/bhttp chunked response", 32) + aead_key, aead_nonce = ohttp._derive_response_keys( + response_secret, enc, response_nonce + ) + from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305 + + aead = ChaCha20Poly1305(aead_key) + + length, off = ohttp.decode_varint(body, off) + # With exactly one inner chunk, the controller buffers it and emits + # it as the final marker — so the first (and only) framed chunk is + # the zero-length-prefixed final. + assert length == 0 + ct = body[off:] + chunk_nonce = bytes(a ^ b for a, b in zip(aead_nonce, (0).to_bytes(12, "big"))) + assert aead.decrypt(chunk_nonce, ct, b"final") == only + assert captured["iter"].closed is True + + +# --------------------------------------------------------------------------- +# 503 when HPKE key isn't initialized + + +def test_returns_503_when_hpke_key_missing(monkeypatch): + app, captured, _ = _make_app(lambda: ("200 OK", [], iter([])), sk=None) + + class _Tee: + hpke_private_key = None + + monkeypatch.setattr(ohttp_controller, "get_tee_keys", lambda: _Tee()) + + client = app.test_client() + resp = client.post("/v1/ohttp", data=b"abc", content_type="message/ohttp-req") + assert resp.status_code == 503 + assert captured["called"] == 0 + + +# --------------------------------------------------------------------------- +# Identifying field scrubbing + + +def test_identifying_fields_are_scrubbed_before_inner_dispatch(monkeypatch): + payload = { + "model": "gpt-4.1", + "messages": [{"role": "user", "content": "hi"}], + "user": "alice@example.com", + "metadata": {"req": "123"}, + "x-request-id": "abc", + "request_id": "def", + } + sk, wire, _, _ = _encapsulate(json.dumps(payload).encode()) + + inner_body = b'{"id":"chatcmpl-1","choices":[]}' + + def inner(): + return ( + "200 OK", + [("Content-Type", "application/json")], + _CloseTrackingIter([inner_body]), + ) + + app, captured, fake_keys = _make_app(inner, sk) + monkeypatch.setattr(ohttp_controller, "get_tee_keys", fake_keys) + + client = app.test_client() + resp = client.post("/v1/ohttp", data=wire, content_type="message/ohttp-req") + assert resp.status_code == 200 + + # Reconstruct what the inner handler received. + inner_env = captured["env"] + inner_body_received = inner_env["wsgi.input"].read( + int(inner_env["CONTENT_LENGTH"]) + ) + forwarded = json.loads(inner_body_received) + assert forwarded == { + "model": "gpt-4.1", + "messages": [{"role": "user", "content": "hi"}], + } + # Authorization is overwritten with the OHTTP-fixed value so it can't + # re-identify the client. + assert inner_env["HTTP_AUTHORIZATION"] == "Bearer ohttp" From dd19905e5b7887228d53fdb40db2340486e43e34 Mon Sep 17 00:00:00 2001 From: "balogh.adam@icloud.com" Date: Sat, 16 May 2026 14:01:21 -0400 Subject: [PATCH 35/39] lint --- tee_gateway/test/test_ohttp_controller.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/tee_gateway/test/test_ohttp_controller.py b/tee_gateway/test/test_ohttp_controller.py index b7b4d7d..8ef6298 100644 --- a/tee_gateway/test/test_ohttp_controller.py +++ b/tee_gateway/test/test_ohttp_controller.py @@ -27,7 +27,6 @@ from decimal import Decimal from typing import Iterator -import pytest from flask import Flask from tee_gateway import ohttp @@ -162,9 +161,7 @@ def test_malformed_encapsulation_returns_400(monkeypatch): # will reject it as malformed. The controller MUST normalise that into a # generic 400 so it doesn't expose an oracle on which decap step failed. client = app.test_client() - resp = client.post( - "/v1/ohttp", data=b"\x00" * 64, content_type="message/ohttp-req" - ) + resp = client.post("/v1/ohttp", data=b"\x00" * 64, content_type="message/ohttp-req") assert resp.status_code == 400 assert captured["called"] == 0 @@ -348,8 +345,8 @@ def test_streaming_inner_response_emits_one_final_chunk_and_closes(monkeypatch): sk, wire, sender, enc = _encapsulate(b'{"model":"gpt-4.1","stream":true}') sse_chunks = [ - b"data: {\"a\":1}\n\n", - b"data: {\"b\":2}\n\n", + b'data: {"a":1}\n\n', + b'data: {"b":2}\n\n', b"data: [DONE]\n\n", ] @@ -507,9 +504,7 @@ def inner(): # Reconstruct what the inner handler received. inner_env = captured["env"] - inner_body_received = inner_env["wsgi.input"].read( - int(inner_env["CONTENT_LENGTH"]) - ) + inner_body_received = inner_env["wsgi.input"].read(int(inner_env["CONTENT_LENGTH"])) forwarded = json.loads(inner_body_received) assert forwarded == { "model": "gpt-4.1", From 61471cde60f5610bdfe361965399e3ab9d026395 Mon Sep 17 00:00:00 2001 From: kukac Date: Sat, 16 May 2026 14:09:14 -0400 Subject: [PATCH 36/39] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 1bef00c..0503808 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -20,7 +20,7 @@ jobs: # left test_chat_controller, test_completions_controller, and # test_ohttp out. test_price_feed_integration.py self-gates on # RUN_INTEGRATION_TESTS and skips cleanly without it. - run: uv run --group test pytest tee_gateway/test/ tests/test_pricing.py -v --import-mode=importlib + run: uv run --group test pytest tee_gateway/test/ tests/test_pricing.py tests/test_opengradient_field.py -v --import-mode=importlib # To also run integration tests (real CoinGecko network calls), add: # env: # RUN_INTEGRATION_TESTS: "1" From d44ab8a695c5850e6ee5e1808d3022485a24863c Mon Sep 17 00:00:00 2001 From: kukac Date: Sat, 16 May 2026 14:09:44 -0400 Subject: [PATCH 37/39] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- tests/test_pricing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_pricing.py b/tests/test_pricing.py index 9df1dbc..99fa4b2 100644 --- a/tests/test_pricing.py +++ b/tests/test_pricing.py @@ -5,7 +5,7 @@ - Every user-facing model name resolves to the correct ModelConfig - compute_session_cost produces the right amount in OPG token smallest-units for supported models - - Edge cases (no usage, unknown model) are handled correctly + - Edge cases (unknown model) are handled correctly """ import unittest From 95a17eed9d837ae1fde2ebc324998cceab7ecb9f Mon Sep 17 00:00:00 2001 From: Aniket Dixit <47004499+dixitaniket@users.noreply.github.com> Date: Mon, 18 May 2026 22:53:36 +0530 Subject: [PATCH 38/39] OHTTP updates (#74) * updates test * updates * updates --- tee_gateway/__main__.py | 97 +++++++++++- tee_gateway/controllers/chat_controller.py | 2 +- .../controllers/completions_controller.py | 2 +- tee_gateway/controllers/ohttp_controller.py | 145 +++++++++++++++--- tee_gateway/test/test_ohttp_controller.py | 7 +- tests/test_opengradient_field.py | 4 +- 6 files changed, 227 insertions(+), 30 deletions(-) diff --git a/tee_gateway/__main__.py b/tee_gateway/__main__.py index d7237c4..0d4483a 100644 --- a/tee_gateway/__main__.py +++ b/tee_gateway/__main__.py @@ -160,6 +160,66 @@ def _patched_read_body_bytes(environ): x402_flask._read_body_bytes = _patched_read_body_bytes +def _patched_stream_session_response( + self, + environ, + start_response, + context, + session_id, + payment_payload, + payment_requirements, +): + """Expose x402's per-request cost context to Flask route handlers. + + OHTTP requests arrive as ciphertext and return ciphertext, so the x402 + middleware cannot parse request_json/response_json from the outer HTTP + bodies. The OHTTP controller decrypts the inner request and plaintext + response inside the enclave; this patch gives it a request-local dict where + it can attach those inner JSON objects for dynamic settlement. + """ + self._start_reaper() + + request_body_bytes = x402_flask._read_body_bytes(environ) + request_json = x402_flask._try_parse_json(request_body_bytes) + parsed_request_json = ( + request_json if isinstance(request_json, (dict, list)) else None + ) + + x402_flask.g.payment_payload = payment_payload + x402_flask.g.payment_requirements = payment_requirements + x402_flask.g.x402_session_id = session_id + + cost_context = { + "method": context.method, + "path": context.path, + "request_body_bytes": request_body_bytes, + "request_json": parsed_request_json, + "payment_payload": payment_payload, + "payment_requirements": payment_requirements, + } + environ["x402.cost_context"] = cost_context + + status_capture = x402_flask.StatusCapture(start_response) + status_capture.add_header(x402_flask.UPTO_SESSION_HEADER, session_id) + + upstream_iter = self._original_wsgi(environ, status_capture) + + return x402_flask.StreamingSessionResponse( + upstream_iter, + middleware=self, + session_id=session_id, + cost_context=cost_context, + status_ref=status_capture, + ) + + +setattr( + x402_flask.PaymentMiddleware, + "_stream_session_response", + _patched_stream_session_response, +) + + def _session_cost_calculator(ctx: dict) -> int: # The chat/completions controllers compute cost in-band and embed it on # the response as a SessionCost model BEFORE returning. We parse it back @@ -173,7 +233,10 @@ def _session_cost_calculator(ctx: dict) -> int: # has already logged CRITICAL in that case. from .pricing import SessionCost - response_json = ctx.get("response_json") + if ctx.get("path") == "/v1/ohttp": + response_json = ctx.get("inner_response_json") + else: + response_json = ctx.get("response_json") if not isinstance(response_json, dict): raise ValueError("response_json missing or not a dict") cost_block = response_json.get("opengradient") @@ -257,8 +320,35 @@ def _init_payment_middleware(facilitator_url: str) -> None: mime_type="application/json", description="Completion", ), + "POST /v1/ohttp": RouteConfig( + accepts=[ + PaymentOption( + scheme="upto", + pay_to=EVM_PAYMENT_ADDRESS, + price=AssetAmount( + amount=CHAT_COMPLETIONS_OPG_SESSION_MAX_SPEND, + asset=BASE_MAINNET_OPG_ADDRESS, + extra={ + "name": "OpenGradient", + "version": "1", + "assetTransferMethod": "permit2", + }, + ), + network=BASE_MAINNET_NETWORK, + ), + ], + extensions={ + **declare_erc20_approval_gas_sponsoring_extension(), + }, + mime_type="message/ohttp-req", + description="OHTTP-wrapped chat completion", + ), } + inner_wsgi_app = application.wsgi_app + flask_app = getattr(application, "app", application) + flask_app.config["OHTTP_INNER_WSGI_APP"] = inner_wsgi_app + # Return value intentionally discarded — PaymentMiddleware.__init__ self-wires # by setting application.wsgi_app = self._wsgi_middleware internally. payment_middleware( @@ -499,7 +589,7 @@ def create_app(): @application.before_request def _check_pricing_ready(): - if request.path not in ("/v1/chat/completions", "/v1/completions"): + if request.path not in ("/v1/chat/completions", "/v1/completions", "/v1/ohttp"): return try: _price_feed.get_price() @@ -507,6 +597,9 @@ def _check_pricing_ready(): logger.warning("Rejecting inference request — price feed unavailable: %s", exc) return jsonify({"error": f"Pricing unavailable: {exc}"}), 503 + if request.path == "/v1/ohttp": + return + body = request.get_json(silent=True, cache=True) or {} model = body.get("model") if model: diff --git a/tee_gateway/controllers/chat_controller.py b/tee_gateway/controllers/chat_controller.py index 4553264..93f6523 100644 --- a/tee_gateway/controllers/chat_controller.py +++ b/tee_gateway/controllers/chat_controller.py @@ -268,7 +268,7 @@ def _create_non_streaming_response(chat_request: CreateChatCompletionRequest): openai_response["usage"] = usage cost = compute_session_cost(chat_request.model, usage) if cost is not None: - openai_response["opengradient"] = cost + openai_response["opengradient"] = cost.model_dump(mode="json") # Validate schema (the extra tee_* fields are preserved by returning dict directly) CreateChatCompletionResponse.from_dict(openai_response) diff --git a/tee_gateway/controllers/completions_controller.py b/tee_gateway/controllers/completions_controller.py index f492e0e..b98c7b6 100644 --- a/tee_gateway/controllers/completions_controller.py +++ b/tee_gateway/controllers/completions_controller.py @@ -82,7 +82,7 @@ def create_completion(body): if usage: cost = compute_session_cost(body.model, usage) if cost is not None: - completion_response["opengradient"] = cost + completion_response["opengradient"] = cost.model_dump(mode="json") return completion_response except Exception as e: diff --git a/tee_gateway/controllers/ohttp_controller.py b/tee_gateway/controllers/ohttp_controller.py index 62b82bf..50e5496 100644 --- a/tee_gateway/controllers/ohttp_controller.py +++ b/tee_gateway/controllers/ohttp_controller.py @@ -154,17 +154,12 @@ def create_anonymous_chat_completion(): return _error(400, "inner payload must be a JSON object") chat_body = _scrub(chat_body) + _set_inner_cost_context(flask_request, request_json=chat_body) body_bytes = json.dumps(chat_body, separators=(",", ":")).encode("utf-8") - # The relay pays — x-payment is a standard outer-request header, not - # inside the encrypted envelope. Pass it through to the inner endpoint - # so x402 verifies and settles exactly as it does for a normal call. - payment_header = flask_request.headers.get("X-Payment") - sub_status, sub_headers, sub_iter = _wsgi_subrequest( path="/v1/chat/completions", body_bytes=body_bytes, - payment_header=payment_header, ) inner_content_type = next( @@ -177,11 +172,21 @@ def create_anonymous_chat_completion(): ) if is_streaming: - return _build_streaming_response(decap, sub_status, sub_headers, sub_iter) + cost_context = flask_request.environ.get("x402.cost_context") + return _build_streaming_response( + decap, + sub_status, + sub_headers, + sub_iter, + cost_context if isinstance(cost_context, dict) else None, + ) - # Non-streaming: drain into bytes (this also triggers x402's - # post-response settlement via the WSGI iterator's close()). + # Non-streaming: drain into bytes, record the inner plaintext cost block + # for outer /v1/ohttp x402 settlement, then seal the response. body_bytes_out = _drain(sub_iter) + _set_inner_response_cost_context( + flask_request, body_bytes_out, status_code=sub_status + ) return _build_outer_response( decap, sub_status, sub_headers, body_bytes_out, inner_content_type ) @@ -234,6 +239,7 @@ def _build_streaming_response( status: int, headers: list[tuple[str, str]], sub_iter: Iterator[bytes], + cost_context: dict[str, Any] | None, ) -> Response: """Chunked OHTTP response (draft-ietf-ohai-chunked-ohttp-08). @@ -262,10 +268,12 @@ def _stream() -> Iterator[bytes]: yield encrypter.header() pending: bytes | None = None + plaintext_chunks: list[bytes] = [] try: for chunk in sub_iter: if not chunk: continue + plaintext_chunks.append(chunk) if pending is not None: yield encrypter.encrypt_chunk(pending, is_final=False) pending = chunk @@ -274,9 +282,11 @@ def _stream() -> Iterator[bytes]: # undetected truncation. yield encrypter.encrypt_chunk(pending or b"", is_final=True) finally: + _set_inner_stream_cost_context( + cost_context, plaintext_chunks, status_code=status + ) close = getattr(sub_iter, "close", None) if callable(close): - # Triggers x402's streaming-session settlement. close() return Response( @@ -329,17 +339,15 @@ def _extract_cost_headers(body_bytes: bytes) -> dict[str, str]: def _wsgi_subrequest( path: str, body_bytes: bytes, - payment_header: str | None, ) -> tuple[int, list[tuple[str, str]], Iterator[bytes]]: """Issue an in-process WSGI request through the app's full middleware stack. Returns ``(status_code, headers, body_iterator)``. The caller is - responsible for draining and closing the iterator (close() triggers - x402's post-response settlement). We invoke ``current_app.wsgi_app`` - directly so the x402 payment middleware (which wraps ``wsgi_app`` at - injection time) runs the same way it would for an external HTTP - request to the same path — including the pre-inference pricing gate, - payment verification, cost settlement and TEE response signing. + responsible for draining and closing the iterator. The outer /v1/ohttp + request is the x402-paid boundary, so this inner chat dispatch uses the + pre-x402 WSGI app saved at middleware installation time. That avoids + charging/verifying the same relay payment twice while still running + connexion routing, validation, TEE signing, and provider inference. """ outer_env = flask_request.environ sub_env: dict[str, Any] = { @@ -361,9 +369,6 @@ def _wsgi_subrequest( "wsgi.input": io.BytesIO(body_bytes), } ) - if payment_header: - sub_env["HTTP_X_PAYMENT"] = payment_header - # The OpenAPI spec declares a global ApiKeyAuth requirement and connexion # enforces it before our handler runs (returns 401 "No authorization # token provided"). The security function (security_controller.py) is an @@ -382,7 +387,8 @@ def _start_response(status: str, headers: list, exc_info: Any = None): captured["headers"] = headers return lambda _chunk: None - iterator = current_app.wsgi_app(sub_env, _start_response) + inner_wsgi = current_app.config.get("OHTTP_INNER_WSGI_APP") or current_app.wsgi_app + iterator = inner_wsgi(sub_env, _start_response) status_code = int(captured["status"].split(" ", 1)[0]) # Don't wrap in iter() — that would strip the iterable's close() method, # which the caller relies on to trigger x402's post-response settlement. @@ -411,3 +417,100 @@ def _error(status: int, message: str) -> tuple[dict, int]: also returned plaintext so the relay can surface them to the client — they never contain user prompts.""" return {"error": message}, status + + +def _set_inner_cost_context( + req_or_context, + *, + request_json: dict[str, Any] | None = None, + response_json: dict[str, Any] | None = None, + status_code: int | None = None, +) -> None: + if isinstance(req_or_context, dict): + cost_context = req_or_context + elif req_or_context is None: + return + else: + cost_context = req_or_context.environ.get("x402.cost_context") + if not isinstance(cost_context, dict): + return + if request_json is not None: + cost_context["inner_request_json"] = request_json + if response_json is not None: + cost_context["inner_response_json"] = response_json + if status_code is not None: + cost_context["inner_status_code"] = status_code + + +def _set_inner_response_cost_context( + req, + body_bytes: bytes, + *, + status_code: int, +) -> None: + try: + response_json = json.loads(body_bytes.decode("utf-8")) + except (UnicodeDecodeError, json.JSONDecodeError): + response_json = None + _set_inner_cost_context( + req, + response_json=response_json if isinstance(response_json, dict) else None, + status_code=status_code, + ) + + +def _set_inner_stream_cost_context( + req, + chunks: list[bytes], + *, + status_code: int, +) -> None: + body = b"".join(chunks) + response_json = _parse_final_sse_json(body) + _set_inner_cost_context( + req, + response_json=response_json, + status_code=status_code, + ) + + +def _parse_final_sse_json(body: bytes) -> dict[str, Any] | None: + last_json: dict[str, Any] | None = None + for line in body.decode("utf-8", errors="replace").splitlines(): + line = line.strip() + if not line.startswith("data:"): + continue + payload = line[len("data:") :].strip() + if not payload or payload == "[DONE]": + continue + try: + parsed = json.loads(payload) + except json.JSONDecodeError: + continue + if isinstance(parsed, dict): + last_json = parsed + return last_json + + +def _sealed_error(req, decap: ohttp.DecapsulatedRequest, status: int, message: str): + body = {"error": message} + _set_inner_cost_context(req, response_json=body, status_code=status) + return _sealed_json_response(decap, status, body) + + +def _sealed_json_response( + decap: ohttp.DecapsulatedRequest, + status: int, + body_obj: Any, +) -> Response: + inner_json = json.dumps( + {"status": status, "body": body_obj}, + separators=(",", ":"), + ).encode("utf-8") + + sealed = ohttp.encapsulate_response(decap.response_key, decap.enc, inner_json) + return Response( + sealed, + status=200, + mimetype=OHTTP_RESPONSE_MEDIA_TYPE, + ) diff --git a/tee_gateway/test/test_ohttp_controller.py b/tee_gateway/test/test_ohttp_controller.py index 8ef6298..ebba831 100644 --- a/tee_gateway/test/test_ohttp_controller.py +++ b/tee_gateway/test/test_ohttp_controller.py @@ -235,9 +235,10 @@ def inner(): assert resp.headers.get("X-Payment-Required") == "true" assert resp.headers.get("X-Tee-Signature") == "sig" assert "Set-Cookie" not in resp.headers - # The relay's X-Payment was forwarded into the inner env so the - # x402 middleware can verify it. - assert captured["env"]["HTTP_X_PAYMENT"] == "client-payment-blob" + # The outer /v1/ohttp request is the paid x402 boundary. The decrypted + # in-process chat subrequest bypasses x402, so the payment blob must not + # be forwarded into the inner env. + assert "HTTP_X_PAYMENT" not in captured["env"] # WSGI iterator was drained AND closed (drains x402 settlement). assert captured["iter"].closed is True diff --git a/tests/test_opengradient_field.py b/tests/test_opengradient_field.py index 257700f..a26c88c 100644 --- a/tests/test_opengradient_field.py +++ b/tests/test_opengradient_field.py @@ -77,7 +77,7 @@ def test_opengradient_block_embedded_when_cost_computed(self): self.assertIsInstance(resp, dict) self.assertIn("opengradient", resp) - self.assertEqual(resp["opengradient"].cost_opg, 12345) + self.assertEqual(resp["opengradient"]["cost_opg"], "12345") def test_opengradient_block_absent_when_compute_returns_none(self): from tee_gateway.controllers import chat_controller @@ -206,7 +206,7 @@ def test_opengradient_block_embedded_on_completion(self): self.assertIsInstance(resp, dict) self.assertIn("opengradient", resp) - self.assertEqual(resp["opengradient"].cost_opg, 12345) + self.assertEqual(resp["opengradient"]["cost_opg"], "12345") class TestSessionCostCalculatorMissingBlock(unittest.TestCase): From 715b075cf8c2445c52e7eaf3719e6938b163e384 Mon Sep 17 00:00:00 2001 From: "balogh.adam@icloud.com" Date: Mon, 18 May 2026 15:34:58 -0400 Subject: [PATCH 39/39] readme --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index cca5f5b..da4e0b6 100644 --- a/README.md +++ b/README.md @@ -180,8 +180,8 @@ The `tee_*` fields provide cryptographic proof of the response: 1. Client fetches `/v1/ohttp/config` (HPKE pubkey, key_id, suite IDs) and verifies it against the Nitro attestation. 2. Client HPKE-encapsulates a normal chat-completion JSON body and POSTs the ciphertext to a **relay**. The client carries no payment material. -3. Relay forwards the ciphertext to `/v1/ohttp` and attaches its own `X-Payment: ` header. -4. Enclave decrypts → re-issues the request internally to `/v1/chat/completions` with the relay's `X-Payment` header → x402 verifies and settles → response is sealed back to the client. +3. Relay forwards the ciphertext to `/v1/ohttp` and attaches its own `X-Payment: ` header. **`/v1/ohttp` is the x402-paid boundary** — verification and settlement happen on this outer request, against the relay's payment. +4. Enclave decrypts → re-issues the request in-process to `/v1/chat/completions` against the pre-x402 WSGI app (so connexion routing, validation, TEE signing and the LLM call still run, but x402 does **not** fire a second time and the relay's `X-Payment` is **not** forwarded into the inner dispatch) → response is sealed back to the client. **Two response modes** (chosen by the inner `stream` flag):