diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 273d4c9..0503808 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -15,7 +15,12 @@ jobs: - uses: actions/checkout@v4 - uses: astral-sh/setup-uv@v5 - name: Run unit tests - run: uv run --group test pytest tee_gateway/test/test_tool_forwarding.py tee_gateway/test/test_tee_core.py tee_gateway/test/test_price_feed.py tests/test_pricing.py -v --import-mode=importlib + # Discover the whole tee_gateway/test/ directory so new test files + # aren't silently excluded from CI — the previous explicit list had + # left test_chat_controller, test_completions_controller, and + # test_ohttp out. test_price_feed_integration.py self-gates on + # RUN_INTEGRATION_TESTS and skips cleanly without it. + run: uv run --group test pytest tee_gateway/test/ tests/test_pricing.py tests/test_opengradient_field.py -v --import-mode=importlib # To also run integration tests (real CoinGecko network calls), add: # env: # RUN_INTEGRATION_TESTS: "1" diff --git a/README.md b/README.md index 95148be..da4e0b6 100644 --- a/README.md +++ b/README.md @@ -121,6 +121,8 @@ The `measurements.txt` checked into this repository reflects the OpenGradient-op | `/signing-key` | GET | TEE public key (PEM format) and tee_id | | `/v1/completions` | POST | Text completion (signed) | | `/v1/chat/completions` | POST | Chat completion (signed) | +| `/v1/ohttp` | POST | Anonymous chat completion (OHTTP-encapsulated, relay-paid) | +| `/v1/ohttp/config` | GET | HPKE key configuration (RFC 9458) for OHTTP clients | ### Request Format @@ -170,6 +172,39 @@ The `tee_*` fields provide cryptographic proof of the response: - **`tee_timestamp`** — Unix timestamp when the response was signed (proves freshness) - **`tee_id`** — keccak256 of the enclave's DER-encoded public key (stable identifier for this enclave instance) +## Anonymous Inference (Oblivious HTTP) + +`/v1/ohttp` is a thin wrapper around `/v1/chat/completions` that adds **client unlinkability** via [RFC 9458 OHTTP](https://www.rfc-editor.org/rfc/rfc9458) + [draft-ietf-ohai-chunked-ohttp-08](https://datatracker.ietf.org/doc/draft-ietf-ohai-chunked-ohttp/). HPKE ciphersuite is fixed: DHKEM(X25519,HKDF-SHA256) / HKDF-SHA256 / ChaCha20-Poly1305. + +**Flow:** + +1. Client fetches `/v1/ohttp/config` (HPKE pubkey, key_id, suite IDs) and verifies it against the Nitro attestation. +2. Client HPKE-encapsulates a normal chat-completion JSON body and POSTs the ciphertext to a **relay**. The client carries no payment material. +3. Relay forwards the ciphertext to `/v1/ohttp` and attaches its own `X-Payment: ` header. **`/v1/ohttp` is the x402-paid boundary** — verification and settlement happen on this outer request, against the relay's payment. +4. Enclave decrypts → re-issues the request in-process to `/v1/chat/completions` against the pre-x402 WSGI app (so connexion routing, validation, TEE signing and the LLM call still run, but x402 does **not** fire a second time and the relay's `X-Payment` is **not** forwarded into the inner dispatch) → response is sealed back to the client. + +**Two response modes** (chosen by the inner `stream` flag): + +| Mode | Outer content-type | Body | +|------|---|---| +| `stream=false` | `message/ohttp-res` | Single-shot sealed body (RFC 9458 §4.5) | +| `stream=true` | `message/ohttp-chunked-res` | `response_nonce \|\| (varint(len) \|\| sealed_ct)+ \|\| varint(0) \|\| sealed_final_ct` — one OHTTP chunk per SSE event, AAD=`b"final"` on the last chunk (chunked-ohttp draft §3) | + +**Billing channel for the relay.** Both modes settle the actual cost via x402 against the relay's `X-Payment` (`upto` scheme); the gateway is the source of truth for the amount. + +- `stream=false`: outer response exposes billing/cost headers — `X-Inference-Cost-OPG`, `X-Inference-Cost-USD`, `X-Inference-Price-OPG-USD` — for the relay's own bookkeeping. Per-token `usage` detail is carried in the sealed body for the client, not in outer `X-Usage-*` headers. +- `stream=true`: **no** per-token detail in outer headers (they're flushed before any body chunk, so we can't know token counts at header-write time) and the sealed chunks are opaque to the relay. The relay reads the actual settled amount from x402 — either by querying the facilitator with its `X-Upto-Session`, or via `X-Payment-Response` on its next call. The client still sees per-token detail in the final SSE event inside the decrypted stream. + +On non-2xx (e.g. 402 payment required) the body is forwarded plaintext so the relay can read x402 payment requirements and retry — those bodies never contain prompts or completions. + +**Trust split:** + +- **Relay** terminates the client's TCP/TLS connection, so it does see the client's IP — that's unavoidable. What it doesn't see is content: only OHTTP ciphertext + its own wallet's `x-payment` material + the outer billing/cost headers used to settle and reconcile charges. +- **Enclave** sees plaintext prompts/completions (it has to run the LLM call) but at the network layer only sees the relay's IP, never the client's. This is the unlinkability claim — the enclave can't tie a plaintext request to a specific end user. +- **Client** decrypts and verifies the TEE signature embedded in the response body against the attested public key. + +Unlinkability between a client identity and a plaintext request holds unless relay and enclave collude (the relay would have to share its client-IP log alongside the enclave's plaintext log). Streaming additionally leaks per-chunk timing and length — clients who can't accept that signal should use `stream=false`. + ## Verification ### 1. Verify Attestation diff --git a/pyproject.toml b/pyproject.toml index 80a9bc6..426ab33 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ dependencies = [ "psutil>=7.2.1", "python-dotenv>=1.2.1", "eth-account>=0.13.0", + "pyhpke>=0.6.0", ] [dependency-groups] diff --git a/scripts/test_ohttp.py b/scripts/test_ohttp.py new file mode 100644 index 0000000..d5fe97b --- /dev/null +++ b/scripts/test_ohttp.py @@ -0,0 +1,396 @@ +""" +Smoke test: exercise the OHTTP anonymous-inference endpoints end-to-end against +a running gateway. Mirrors test_bytedance.py — point it at a local server and +verify a request round-trips through HPKE encap/decap + the chat backend. + +Usage: + # Non-streaming round-trip (default): + uv run python scripts/test_ohttp.py + uv run python scripts/test_ohttp.py --model gpt-4.1 --prompt "what model are you?" + + # Streaming via chunked OHTTP: + uv run python scripts/test_ohttp.py --stream --model claude-haiku-4-5 + + # Against a remote gateway (anything reachable via HTTP/S): + uv run python scripts/test_ohttp.py --url https://my-enclave.example + +The gateway must already have provider keys injected via POST /v1/keys for the +inner chat request to succeed. Run scripts/test_bytedance.py first if unsure — +it shares the same backend. +""" + +from __future__ import annotations + +import argparse +import json +import os +import struct +import sys +from pathlib import Path +from typing import IO, Iterator + +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +import requests # noqa: E402 +from cryptography.hazmat.primitives import hashes, hmac # noqa: E402 +from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305 # noqa: E402 +from cryptography.hazmat.primitives.kdf.hkdf import HKDFExpand # noqa: E402 +from pyhpke import AEADId, CipherSuite, KDFId, KEMId # noqa: E402 + +# Fixed ciphersuite — the gateway only accepts this combination. +_SUITE = CipherSuite.new( + KEMId.DHKEM_X25519_HKDF_SHA256, + KDFId.HKDF_SHA256, + AEADId.CHACHA20_POLY1305, +) +_KEM_ID = 0x0020 +_KDF_ID = 0x0001 +_AEAD_ID = 0x0003 +_NK = 32 +_NN = 12 +_LABEL_REQ = b"message/bhttp request" +_LABEL_RESP = b"message/bhttp response" +_LABEL_RESP_CHUNKED = b"message/bhttp chunked response" + + +# --------------------------------------------------------------------------- +# HPKE encapsulation (client side) +# --------------------------------------------------------------------------- + + +def encapsulate_request( + public_key_raw: bytes, key_id: int, plaintext: bytes +) -> tuple[bytes, object, bytes]: + """Build an OHTTP-encapsulated request. Returns ``(wire, sender, enc)``. + + The sender context is kept by the caller so it can later export the + response secret (single-shot or chunked) and decrypt the reply.""" + hdr = bytes([key_id]) + struct.pack(">HHH", _KEM_ID, _KDF_ID, _AEAD_ID) + info = _LABEL_REQ + b"\x00" + hdr + pkr = _SUITE.kem.deserialize_public_key(public_key_raw) + enc, sender = _SUITE.create_sender_context(pkr, info=info) + ct = sender.seal(plaintext, aad=b"") + return hdr + enc + ct, sender, enc + + +# --------------------------------------------------------------------------- +# Response decryption (single-shot + chunked) +# --------------------------------------------------------------------------- + + +def _derive_response_keys( + response_secret: bytes, enc: bytes, response_nonce: bytes +) -> tuple[bytes, bytes]: + h = hmac.HMAC(enc + response_nonce, hashes.SHA256()) + h.update(response_secret) + prk = h.finalize() + aead_key = HKDFExpand(algorithm=hashes.SHA256(), length=_NK, info=b"key").derive( + prk + ) + aead_nonce = HKDFExpand( + algorithm=hashes.SHA256(), length=_NN, info=b"nonce" + ).derive(prk) + return aead_key, aead_nonce + + +def decrypt_single_shot(sealed: bytes, sender, enc: bytes) -> bytes: + response_secret = sender.export(_LABEL_RESP, max(_NN, _NK)) + nonce_len = max(_NN, _NK) + response_nonce = sealed[:nonce_len] + aead_ct = sealed[nonce_len:] + aead_key, aead_nonce = _derive_response_keys(response_secret, enc, response_nonce) + return ChaCha20Poly1305(aead_key).decrypt(aead_nonce, aead_ct, b"") + + +def _read_varint(stream: IO[bytes]) -> int | None: + """Read one QUIC varint from a byte stream. Returns None at clean EOF.""" + first = stream.read(1) + if not first: + return None + b = first[0] + nbytes = 1 << (b >> 6) # 1, 2, 4, or 8 + rest = stream.read(nbytes - 1) if nbytes > 1 else b"" + if len(rest) != nbytes - 1: + raise ValueError("truncated varint") + head = bytes([b & 0x3F]) + rest + if nbytes == 1: + return b & 0x3F + if nbytes == 2: + return (head[0] << 8) | head[1] + if nbytes == 4: + return struct.unpack(">I", head)[0] + return struct.unpack(">Q", head)[0] + + +def decrypt_chunked_stream( + raw_stream: IO[bytes], sender, enc: bytes +) -> Iterator[bytes]: + """Decrypt a chunked OHTTP response (draft-ietf-ohai-chunked-ohttp-08) + incrementally, yielding plaintext as each sealed chunk arrives. + + Wire: response_nonce || (varint(len) || ct)+ || varint(0) || final_ct + Per-chunk nonce = aead_nonce XOR encode_be(counter), AAD=""/b"final". + """ + response_secret = sender.export(_LABEL_RESP_CHUNKED, max(_NN, _NK)) + response_nonce = raw_stream.read(max(_NN, _NK)) + if len(response_nonce) != max(_NN, _NK): + raise ValueError("truncated response_nonce") + aead_key, aead_nonce = _derive_response_keys(response_secret, enc, response_nonce) + aead = ChaCha20Poly1305(aead_key) + counter = 0 + + while True: + length = _read_varint(raw_stream) + if length is None: + raise ValueError("stream ended without AAD=final chunk") + + chunk_nonce = bytes( + a ^ b for a, b in zip(aead_nonce, counter.to_bytes(_NN, "big")) + ) + counter += 1 + + if length == 0: + final_ct = raw_stream.read() # rest of stream + yield aead.decrypt(chunk_nonce, final_ct, b"final") + return + + ct = raw_stream.read(length) + if len(ct) != length: + raise ValueError("truncated chunk ciphertext") + yield aead.decrypt(chunk_nonce, ct, b"") + + +# --------------------------------------------------------------------------- +# Wire-payload dump (for visual confirmation that nothing plaintext leaves) +# --------------------------------------------------------------------------- + + +def _hexdump(data: bytes, max_bytes: int = 96) -> str: + """xxd-style hex dump, capped at ``max_bytes`` so output stays readable.""" + out: list[str] = [] + for i in range(0, min(len(data), max_bytes), 16): + row = data[i : i + 16] + hex_part = " ".join(f"{b:02x}" for b in row) + ascii_part = "".join(chr(b) if 32 <= b < 127 else "." for b in row) + out.append(f" {i:04x} {hex_part:<48} |{ascii_part}|") + if len(data) > max_bytes: + out.append(f" ... ({len(data) - max_bytes} more bytes of ciphertext)") + return "\n".join(out) + + +def dump_outgoing( + url: str, + headers: dict[str, str], + inner_plaintext: bytes, + wire: bytes, + enc: bytes, + key_id: int, +) -> None: + """Print everything that's about to go on the wire so you can eyeball + that the relay only sees opaque ciphertext.""" + print("\n================ OUTGOING REQUEST ================") + print(f"POST {url}") + for name, value in headers.items(): + print(f" {name}: {value}") + print( + f"\n inner plaintext ({len(inner_plaintext)} bytes, " + f"NEVER goes on the wire — sealed under HPKE):" + ) + try: + print( + " " + + json.dumps(json.loads(inner_plaintext), indent=4).replace("\n", "\n ") + ) + except json.JSONDecodeError: + print(f" {inner_plaintext!r}") + + # The wire body decomposes into: + # header (7 bytes): key_id || kem_id || kdf_id || aead_id + # enc (32 bytes): client's ephemeral X25519 public key + # ct (rest): AEAD ciphertext + 16-byte tag — opaque to the relay + print(f"\n encapsulated body ({len(wire)} bytes total):") + print( + f" OHTTP header = {wire[:7].hex()} " + f"(key_id=0x{key_id:02x}, suite=0x0020/0x0001/0x0003)" + ) + print(f" enc (ephemeral)= {enc.hex()}") + print( + f" ciphertext = {len(wire) - 7 - 32} bytes (HPKE-sealed, no plaintext leaks):" + ) + print(_hexdump(wire[7 + 32 :])) + print("==================================================\n") + + +# --------------------------------------------------------------------------- +# Driver +# --------------------------------------------------------------------------- + + +def fetch_config(base_url: str) -> tuple[bytes, int]: + """GET /v1/ohttp/config; return (public_key_raw_bytes, key_id).""" + r = requests.get(f"{base_url}/v1/ohttp/config", timeout=10) + r.raise_for_status() + cfg = r.json() + print( + f"HPKE config: key_id={cfg['key_id']} suite={cfg['kem_id']}/{cfg['kdf_id']}/{cfg['aead_id']}" + ) + print(f" public_key = {cfg['public_key']}") + print(f" key_config = {cfg['key_config']}") + pk_raw = bytes.fromhex(cfg["public_key"]) + if len(pk_raw) != 32: + raise ValueError(f"expected 32-byte X25519 pubkey, got {len(pk_raw)}") + return pk_raw, cfg["key_id"] + + +def verify_attestation_binding(base_url: str, hpke_pubkey_hex: str) -> None: + """Cross-check that the HPKE pubkey we just got matches what the + attestation document at /signing-key reports. A network attacker + between us and the enclave could otherwise swap in their own + pubkey and decrypt our prompt.""" + r = requests.get(f"{base_url}/signing-key", timeout=10) + r.raise_for_status() + doc = r.json() + attested = (doc.get("hpke") or {}).get("public_key") + if attested is None: + print( + "WARN: /signing-key did not include hpke.public_key; skipping binding check" + ) + return + if attested.lower() != hpke_pubkey_hex.lower(): + raise ValueError( + f"HPKE pubkey mismatch! config={hpke_pubkey_hex} attestation={attested}" + ) + print("HPKE pubkey matches the attestation document") + + +def run_non_streaming( + base_url: str, model: str, prompt: str, payment: str | None +) -> int: + pk_raw, key_id = fetch_config(base_url) + verify_attestation_binding(base_url, pk_raw.hex()) + + inner = { + "model": model, + "messages": [{"role": "user", "content": prompt}], + "max_tokens": 200, + } + inner_bytes = json.dumps(inner, separators=(",", ":")).encode("utf-8") + wire, sender, enc = encapsulate_request(pk_raw, key_id, inner_bytes) + + headers = {"Content-Type": "message/ohttp-req"} + if payment: + headers["X-Payment"] = payment + + dump_outgoing(f"{base_url}/v1/ohttp", headers, inner_bytes, wire, enc, key_id) + r = requests.post(f"{base_url}/v1/ohttp", data=wire, headers=headers, timeout=120) + print(f"HTTP {r.status_code} content-type={r.headers.get('content-type')}") + + if r.status_code >= 400 or "ohttp-res" not in (r.headers.get("content-type") or ""): + # Plaintext error pass-through (e.g. 402 with x402 payment requirements). + print("---- plaintext error body ----") + try: + print(json.dumps(r.json(), indent=2)) + except ValueError: + print(r.text) + return 1 + + print("---- forwarded headers ----") + for k, v in r.headers.items(): + if k.lower().startswith( + ("x-payment", "x-upto", "x-settlement", "x-tee", "x-usage") + ): + print(f" {k}: {v}") + + plaintext = decrypt_single_shot(r.content, sender, enc) + print("---- decrypted response ----") + try: + parsed = json.loads(plaintext) + choices = parsed.get("choices") or [] + if choices: + print(choices[0].get("message", {}).get("content", "")) + print("---- usage ----") + print(json.dumps(parsed.get("usage"), indent=2)) + for field in ("tee_signature", "tee_request_hash", "tee_output_hash"): + if field in parsed: + print(f" {field}: {parsed[field][:40]}...") + except json.JSONDecodeError: + print(plaintext.decode("utf-8", errors="replace")) + return 0 + + +def run_streaming(base_url: str, model: str, prompt: str, payment: str | None) -> int: + pk_raw, key_id = fetch_config(base_url) + verify_attestation_binding(base_url, pk_raw.hex()) + + inner = { + "model": model, + "messages": [{"role": "user", "content": prompt}], + "max_tokens": 200, + "stream": True, + } + inner_bytes = json.dumps(inner, separators=(",", ":")).encode("utf-8") + wire, sender, enc = encapsulate_request(pk_raw, key_id, inner_bytes) + + headers = {"Content-Type": "message/ohttp-req"} + if payment: + headers["X-Payment"] = payment + + dump_outgoing(f"{base_url}/v1/ohttp", headers, inner_bytes, wire, enc, key_id) + r = requests.post( + f"{base_url}/v1/ohttp", + data=wire, + headers=headers, + timeout=120, + stream=True, + ) + print(f"HTTP {r.status_code} content-type={r.headers.get('content-type')}") + + ct = r.headers.get("content-type", "") + if r.status_code >= 400 or "chunked-res" not in ct: + print("---- non-streaming response body ----") + print(r.text) + return 1 + + print("---- decrypted SSE events ----") + # urllib3's HTTPResponse — supports .read(n) without re-decoding chunked transfer. + for plaintext in decrypt_chunked_stream(r.raw, sender, enc): + text = plaintext.decode("utf-8", errors="replace") + sys.stdout.write(text) + sys.stdout.flush() + print("\n---- end of stream ----") + return 0 + + +def main() -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--url", default=os.environ.get("OHTTP_URL", "http://127.0.0.1:8000") + ) + parser.add_argument("--model", default="gpt-4.1") + parser.add_argument( + "--prompt", default="What model are you? Reply in one short sentence." + ) + parser.add_argument( + "--stream", action="store_true", help="Use chunked OHTTP streaming" + ) + parser.add_argument( + "--payment", + default=os.environ.get("X_PAYMENT"), + help="Optional x402 payment payload (base64). Send via the outer X-Payment header.", + ) + args = parser.parse_args() + + try: + if args.stream: + return run_streaming(args.url, args.model, args.prompt, args.payment) + return run_non_streaming(args.url, args.model, args.prompt, args.payment) + except requests.RequestException as exc: + print(f"\nERROR: HTTP failure — {exc}", file=sys.stderr) + return 2 + except Exception as exc: + print(f"\nERROR: {type(exc).__name__}: {exc}", file=sys.stderr) + return 2 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tee_gateway/__main__.py b/tee_gateway/__main__.py index 165d5d1..0d4483a 100644 --- a/tee_gateway/__main__.py +++ b/tee_gateway/__main__.py @@ -23,6 +23,10 @@ ) from tee_gateway.llm_backend import get_provider_config, set_provider_config from tee_gateway.heartbeat import create_heartbeat_service +from tee_gateway.controllers.ohttp_controller import ( + create_anonymous_chat_completion, + get_hpke_config, +) from x402.http import FacilitatorConfig, HTTPFacilitatorClientSync, PaymentOption from x402.http.middleware.flask import payment_middleware @@ -37,9 +41,8 @@ from x402.session import SessionStore import x402.http.middleware.flask as x402_flask -from .util import calculate_session_cost from .model_registry import get_model_config -from .price_feed import OPGPriceFeed +from .price_feed import OPGPriceFeed, set_price_feed from .definitions import ( EVM_PAYMENT_ADDRESS, BASE_MAINNET_NETWORK, @@ -117,6 +120,7 @@ def _shutdown_heartbeat(): # --------------------------------------------------------------------------- _price_feed = OPGPriceFeed() _price_feed.start() +set_price_feed(_price_feed) _started_at = time.time() @@ -156,23 +160,97 @@ def _patched_read_body_bytes(environ): x402_flask._read_body_bytes = _patched_read_body_bytes +def _patched_stream_session_response( + self, + environ, + start_response, + context, + session_id, + payment_payload, + payment_requirements, +): + """Expose x402's per-request cost context to Flask route handlers. + + OHTTP requests arrive as ciphertext and return ciphertext, so the x402 + middleware cannot parse request_json/response_json from the outer HTTP + bodies. The OHTTP controller decrypts the inner request and plaintext + response inside the enclave; this patch gives it a request-local dict where + it can attach those inner JSON objects for dynamic settlement. + """ + self._start_reaper() + + request_body_bytes = x402_flask._read_body_bytes(environ) + request_json = x402_flask._try_parse_json(request_body_bytes) + parsed_request_json = ( + request_json if isinstance(request_json, (dict, list)) else None + ) + + x402_flask.g.payment_payload = payment_payload + x402_flask.g.payment_requirements = payment_requirements + x402_flask.g.x402_session_id = session_id + + cost_context = { + "method": context.method, + "path": context.path, + "request_body_bytes": request_body_bytes, + "request_json": parsed_request_json, + "payment_payload": payment_payload, + "payment_requirements": payment_requirements, + } + environ["x402.cost_context"] = cost_context + + status_capture = x402_flask.StatusCapture(start_response) + status_capture.add_header(x402_flask.UPTO_SESSION_HEADER, session_id) + + upstream_iter = self._original_wsgi(environ, status_capture) + + return x402_flask.StreamingSessionResponse( + upstream_iter, + middleware=self, + session_id=session_id, + cost_context=cost_context, + status_ref=status_capture, + ) + + +setattr( + x402_flask.PaymentMiddleware, + "_stream_session_response", + _patched_stream_session_response, +) + + def _session_cost_calculator(ctx: dict) -> int: - # Post-inference cost calculation — response already sent to client. - # Predictable failures (unknown price, unknown model) are blocked by the - # pre-inference gate; any exception here indicates a provider-side error - # (e.g. missing usage field in the LLM response). The x402 middleware - # swallows the exception in close(), so the client is not charged. - # Log CRITICAL so provider errors are never silently missed. - try: - return calculate_session_cost(ctx, _price_feed.get_price) - except Exception as exc: + # The chat/completions controllers compute cost in-band and embed it on + # the response as a SessionCost model BEFORE returning. We parse it back + # here — single source of truth for cost lives in the controller, not + # split between the controller (which serves clients) and x402's callback + # (which charges them). + # + # If the controller couldn't compute cost (e.g. missing usage from the + # provider), the block is absent — SessionCost.model_validate raises and + # x402's close() swallows it so the client is not charged. The controller + # has already logged CRITICAL in that case. + from .pricing import SessionCost + + if ctx.get("path") == "/v1/ohttp": + response_json = ctx.get("inner_response_json") + else: + response_json = ctx.get("response_json") + if not isinstance(response_json, dict): + raise ValueError("response_json missing or not a dict") + cost_block = response_json.get("opengradient") + if cost_block is None: + # Should never happen on a successful paid response — the controller + # always embeds this when it can compute it, and when it can't, the + # request typically errored out before reaching x402's close hook. + # If we see this, settlement silently skips and a real bug is hiding. logger.critical( - "Post-inference cost calculation failed (provider error) — " - "client was NOT charged: %s", - exc, - exc_info=True, + "opengradient cost block missing on paid response — client will " + "NOT be charged. response_id=%s", + response_json.get("id"), ) - raise + return SessionCost.model_validate(cost_block).cost_opg # --------------------------------------------------------------------------- @@ -242,8 +320,35 @@ def _init_payment_middleware(facilitator_url: str) -> None: mime_type="application/json", description="Completion", ), + "POST /v1/ohttp": RouteConfig( + accepts=[ + PaymentOption( + scheme="upto", + pay_to=EVM_PAYMENT_ADDRESS, + price=AssetAmount( + amount=CHAT_COMPLETIONS_OPG_SESSION_MAX_SPEND, + asset=BASE_MAINNET_OPG_ADDRESS, + extra={ + "name": "OpenGradient", + "version": "1", + "assetTransferMethod": "permit2", + }, + ), + network=BASE_MAINNET_NETWORK, + ), + ], + extensions={ + **declare_erc20_approval_gas_sponsoring_extension(), + }, + mime_type="message/ohttp-req", + description="OHTTP-wrapped chat completion", + ), } + inner_wsgi_app = application.wsgi_app + flask_app = getattr(application, "app", application) + flask_app.config["OHTTP_INNER_WSGI_APP"] = inner_wsgi_app + # Return value intentionally discarded — PaymentMiddleware.__init__ self-wires # by setting application.wsgi_app = self._wsgi_middleware internally. payment_middleware( @@ -437,6 +542,20 @@ def create_app(): "/heartbeat/status", "heartbeat-status", heartbeat_status, methods=["GET"] ) + # Anonymous inference (OHTTP-wrapped chat completions). Deliberately + # mounted via add_url_rule rather than the OpenAPI spec because the body + # is raw binary and connexion's request-validation pipeline would reject + # it as malformed JSON. + app.app.add_url_rule( + "/v1/ohttp", + "anonymous-chat", + create_anonymous_chat_completion, + methods=["POST"], + ) + app.app.add_url_rule( + "/v1/ohttp/config", "ohttp-config", get_hpke_config, methods=["GET"] + ) + # Initialize TEE here so it runs under both Gunicorn and direct execution. # This is the single TEEKeyManager instance — the same key both registers # with nitriding and signs all LLM responses. @@ -470,7 +589,7 @@ def create_app(): @application.before_request def _check_pricing_ready(): - if request.path not in ("/v1/chat/completions", "/v1/completions"): + if request.path not in ("/v1/chat/completions", "/v1/completions", "/v1/ohttp"): return try: _price_feed.get_price() @@ -478,6 +597,9 @@ def _check_pricing_ready(): logger.warning("Rejecting inference request — price feed unavailable: %s", exc) return jsonify({"error": f"Pricing unavailable: {exc}"}), 503 + if request.path == "/v1/ohttp": + return + body = request.get_json(silent=True, cache=True) or {} model = body.get("model") if model: diff --git a/tee_gateway/controllers/chat_controller.py b/tee_gateway/controllers/chat_controller.py index e6d9337..93f6523 100644 --- a/tee_gateway/controllers/chat_controller.py +++ b/tee_gateway/controllers/chat_controller.py @@ -30,6 +30,7 @@ convert_messages, extract_usage, ) +from tee_gateway.pricing import compute_session_cost logger = logging.getLogger(__name__) @@ -224,8 +225,6 @@ def _create_non_streaming_response(chat_request: CreateChatCompletionRequest): for tc in response.tool_calls ] - usage = extract_usage(response) - # For tool-call responses, hash the serialized tool calls so the # signature covers which tools were invoked and with what arguments. if finish_reason == "tool_calls" and message_dict.get("tool_calls"): @@ -263,8 +262,13 @@ def _create_non_streaming_response(chat_request: CreateChatCompletionRequest): f"Response Final\n\tTEE Signature: {signature}\n\tTEE request hash: {input_hash_hex}\n\tTEE output hash: {output_hash_hex}\n\tTEE timestamp: {timestamp}\n\tTEE ID: 0x{tee_keys.get_tee_id()}" ) + # TODO: If no usage is returned, we should compute it here. + usage = extract_usage(response) if usage: openai_response["usage"] = usage + cost = compute_session_cost(chat_request.model, usage) + if cost is not None: + openai_response["opengradient"] = cost.model_dump(mode="json") # Validate schema (the extra tee_* fields are preserved by returning dict directly) CreateChatCompletionResponse.from_dict(openai_response) @@ -584,12 +588,19 @@ def generate(): f"Response Final\n\tTEE Signature: {tee_signature}\n\tTEE request hash: {input_hash_hex}\n\tTEE output hash: {output_hash_hex}\n\tTEE timestamp: {timestamp}\n\tTEE ID: 0x{tee_keys.get_tee_id()}" ) + # TODO: If no usage is returned, we should compute it here. if final_usage: final_data["usage"] = { "prompt_tokens": final_usage.get("input_tokens", 0), "completion_tokens": final_usage.get("output_tokens", 0), "total_tokens": final_usage.get("total_tokens", 0), } + cost = compute_session_cost(chat_request.model, final_data["usage"]) + if cost is not None: + # final_data is hand-serialized to SSE via json.dumps below, + # which doesn't go through Flask's JSONEncoder — so do the + # serialization ourselves here. + final_data["opengradient"] = cost.model_dump(mode="json") logger.info( f"Stream completed — usage: {final_data['usage']}, " f"finish: {finish_reason}, " diff --git a/tee_gateway/controllers/completions_controller.py b/tee_gateway/controllers/completions_controller.py index 78b9314..b98c7b6 100644 --- a/tee_gateway/controllers/completions_controller.py +++ b/tee_gateway/controllers/completions_controller.py @@ -4,12 +4,15 @@ import logging import connexion +from typing import Any + from langchain_core.messages import HumanMessage from tee_gateway.models.create_completion_request import CreateCompletionRequest from tee_gateway.tee_manager import get_tee_keys, compute_tee_msg_hash from tee_gateway.llm_backend import get_chat_model_cached, extract_usage +from tee_gateway.pricing import compute_session_cost logger = logging.getLogger(__name__) @@ -57,7 +60,7 @@ def create_completion(body): tee_keys = get_tee_keys() signature = tee_keys.sign_data(msg_hash) - return { + completion_response: dict[str, Any] = { "id": f"cmpl-{uuid.uuid4()}", "object": "text_completion", "created": timestamp, @@ -76,6 +79,11 @@ def create_completion(body): "tee_timestamp": timestamp, "tee_id": f"0x{tee_keys.get_tee_id()}", } + if usage: + cost = compute_session_cost(body.model, usage) + if cost is not None: + completion_response["opengradient"] = cost.model_dump(mode="json") + return completion_response except Exception as e: logger.error(f"Completion error: {str(e)}", exc_info=True) diff --git a/tee_gateway/controllers/ohttp_controller.py b/tee_gateway/controllers/ohttp_controller.py new file mode 100644 index 0000000..50e5496 --- /dev/null +++ b/tee_gateway/controllers/ohttp_controller.py @@ -0,0 +1,516 @@ +""" +Oblivious HTTP endpoint for anonymous inference (relay-pays model). + +This handler is a thin shell: it HPKE-decapsulates the inner request, re-issues +it as an in-process WSGI sub-request against the enclave's own +``/v1/chat/completions``, then encapsulates the response. All x402 payment, +LangChain routing, cost settlement and TEE response signing reuse the public +chat code paths — there is no duplicated routing or pricing logic here. + +Two response modes are supported, dispatched by the inner ``stream`` flag: + * stream=false → single-shot OHTTP response (RFC 9458 §4.5), + content-type ``message/ohttp-res``. + * stream=true → chunked OHTTP response (draft-ietf-ohai-chunked-ohttp-08), + content-type ``message/ohttp-chunked-res``. Each SSE event from the + inner /v1/chat/completions stream becomes one sealed OHTTP chunk; the + final chunk uses AAD=b"final" so truncation is detectable. + +Billing channel for the relay (both modes settle the real cost via x402 +against the relay's x-payment; the gateway is the source of truth for the +amount): + * stream=false: outer headers expose ONLY the settled cost — + X-Inference-Cost-OPG (smallest units, the integer x402 actually charged) + and X-Inference-Cost-USD (the equivalent USD at the price used). No + model name, no token counts: those would be a fingerprint of the + inner request and have no role in billing. + * stream=true: outer response headers are flushed before any body chunk + is yielded, so cost isn't known at header-write time. The relay reads + the settled amount from x402: either by querying the facilitator with + its X-Upto-Session, or via X-Payment-Response on its next request. The + client still sees cost in the final SSE event inside the encrypted + stream (the ``opengradient`` block written by the chat controller). + +Trust / payment model: + * The CLIENT encrypts only an LLM chat-completion request. It does not see, + sign, or carry x402 payment material. + * A RELAY sits between the client and the enclave. The relay holds the + x402 wallet and forwards the (still-encrypted) inner request to the + enclave, attaching its own ``x-payment`` header on the OUTER HTTP + request. + * The ENCLAVE decrypts the inner payload, forwards the request to its own + chat endpoint with the relay's ``x-payment`` header, and returns. The + relay sees status and settlement headers and, for non-stream, the outer + ``X-Inference-Cost-OPG`` / ``X-Inference-Cost-USD`` billing headers, but + never sees the inner prompt or completion. + +Privacy properties: + * Relay (network position): terminates the client's TCP/TLS connection, + so the relay DOES see the client's IP at the network layer — that's + unavoidable. What the relay does NOT see is the request/response + content: it observes only the OHTTP-encapsulated ciphertext, its own + wallet's x-payment material, and (single-shot only) the settled-cost + outer headers it needs to bill its own customer. Model name and token + counts are deliberately NOT surfaced — they'd fingerprint the inner + request without adding anything billing-relevant. + * Enclave (compute position): sees the plaintext prompt and completion + (they are decrypted inside the enclave to run the LLM call), but at + the network layer it only sees the RELAY's IP — never the client's. + That's the unlinkability property: the enclave cannot tie a request's + plaintext back to a specific end user. + * Unlinkability between a specific client identity and a specific + plaintext request holds unless the relay and the enclave collude + (the relay would have to disclose its client-IP log alongside the + enclave's per-request plaintext log). + * Streaming leaks per-chunk timing and length (the relay sees the + cadence of varint-framed sealed chunks). This is an inherent cost of + server-sent events — clients who can't accept that signal should + use stream=false. +""" + +from __future__ import annotations + +import io +import json +import logging +from typing import Any, Iterator + +from flask import Response, current_app, request as flask_request + +from tee_gateway import ohttp +from tee_gateway.tee_manager import get_tee_keys +from tee_gateway.pricing import SessionCost + +logger = logging.getLogger(__name__) + +OHTTP_RESPONSE_MEDIA_TYPE = "message/ohttp-res" +OHTTP_CHUNKED_RESPONSE_MEDIA_TYPE = "message/ohttp-chunked-res" +_SSE_CONTENT_TYPE = "text/event-stream" + +# Cap on the encapsulated request size. The inner payload is a chat-completion +# JSON body; even with long conversation history this comfortably fits in a few +# hundred KB. Rejecting larger bodies up-front prevents a malicious relay from +# forcing the enclave to allocate and attempt HPKE decapsulation on arbitrarily +# large blobs. +_MAX_ENCAPSULATED_REQUEST_BYTES = 512 * 1024 + +# Fields that can re-identify a client and have no role in inference. We drop +# them before forwarding to the inner handler — keeping them inside the +# encrypted envelope would only protect them from the relay, not from us or +# the upstream LLM provider. +_IDENTIFYING_FIELDS = ("user", "metadata", "x-request-id", "request_id") + +# Response headers we propagate from the inner /v1/chat/completions response +# back through the relay to the client. +_FORWARDED_HEADER_PREFIXES = ("x-payment", "x-upto", "x-settlement", "x-tee") +_FORWARDED_HEADER_NAMES = ("www-authenticate",) + + +def _scrub(payload: dict[str, Any]) -> dict[str, Any]: + return {k: v for k, v in payload.items() if k not in _IDENTIFYING_FIELDS} + + +def _should_forward_header(name: str) -> bool: + lower = name.lower() + return lower in _FORWARDED_HEADER_NAMES or any( + lower.startswith(p) for p in _FORWARDED_HEADER_PREFIXES + ) + + +def create_anonymous_chat_completion(): + """POST /v1/ohttp — decrypt, sub-dispatch to /v1/chat/completions, re-encrypt.""" + declared_length = flask_request.content_length + if ( + declared_length is not None + and declared_length > _MAX_ENCAPSULATED_REQUEST_BYTES + ): + return _error(413, "encapsulated request too large") + + raw_body: bytes = flask_request.get_data(cache=False, parse_form_data=False) + if not raw_body: + return _error(400, "empty body") + # Re-check after read in case Content-Length was absent or lied about the + # actual stream length (chunked transfer, malicious client). + if len(raw_body) > _MAX_ENCAPSULATED_REQUEST_BYTES: + return _error(413, "encapsulated request too large") + + tee = get_tee_keys() + if tee.hpke_private_key is None: + return _error(503, "anonymous inference not initialized") + + try: + decap = ohttp.decapsulate_request(tee.hpke_private_key, raw_body) + except Exception as exc: + # Don't leak which step failed — clients can retry with a fresh + # encapsulation, all observable failures look identical. + logger.warning("OHTTP decapsulation failed: %s", type(exc).__name__) + return _error(400, "malformed encapsulated request") + + try: + chat_body = json.loads(decap.plaintext.decode("utf-8")) + except (UnicodeDecodeError, json.JSONDecodeError): + return _error(400, "inner payload is not valid JSON") + + if not isinstance(chat_body, dict): + return _error(400, "inner payload must be a JSON object") + + chat_body = _scrub(chat_body) + _set_inner_cost_context(flask_request, request_json=chat_body) + body_bytes = json.dumps(chat_body, separators=(",", ":")).encode("utf-8") + + sub_status, sub_headers, sub_iter = _wsgi_subrequest( + path="/v1/chat/completions", + body_bytes=body_bytes, + ) + + inner_content_type = next( + (v for k, v in sub_headers if k.lower() == "content-type"), + "application/json", + ) + is_streaming = ( + 200 <= sub_status < 300 + and inner_content_type.split(";", 1)[0].strip().lower() == _SSE_CONTENT_TYPE + ) + + if is_streaming: + cost_context = flask_request.environ.get("x402.cost_context") + return _build_streaming_response( + decap, + sub_status, + sub_headers, + sub_iter, + cost_context if isinstance(cost_context, dict) else None, + ) + + # Non-streaming: drain into bytes, record the inner plaintext cost block + # for outer /v1/ohttp x402 settlement, then seal the response. + body_bytes_out = _drain(sub_iter) + _set_inner_response_cost_context( + flask_request, body_bytes_out, status_code=sub_status + ) + return _build_outer_response( + decap, sub_status, sub_headers, body_bytes_out, inner_content_type + ) + + +def _build_outer_response( + decap: ohttp.DecapsulatedRequest, + status: int, + headers: list[tuple[str, str]], + body_bytes: bytes, + inner_content_type: str, +) -> Response: + """Single-shot OHTTP response. Seals the body on 2xx (contains user + prompts/completions) and surfaces only extracted cost/price headers as + outer headers for relay billing (including `X-Inference-Price-OPG-USD`); + usage details remain sealed in the encapsulated body. Non-2xx bodies + (x402 payment requirements, validation errors) are forwarded as + plaintext so the relay can act on them.""" + # Keep headers as a list of (name, value) tuples — WSGI gives us a list + # specifically because HTTP allows multi-valued headers (RFC 7230 §3.2.2; + # WWW-Authenticate in particular per RFC 7235 §4.1 can repeat, one per + # challenge scheme). Collapsing to a dict would drop duplicates silently + # and could lose an x402 challenge if future versions emit more than one. + forwarded: list[tuple[str, str]] = [ + (name, value) for name, value in headers if _should_forward_header(name) + ] + + if not (200 <= status < 300): + return Response( + body_bytes, + status=status, + headers=forwarded, + content_type=inner_content_type, + ) + + # Cost headers come from a JSON parse — single-valued by construction — + # so a dict is fine internally; just project to tuples at merge time. + forwarded.extend(_extract_cost_headers(body_bytes).items()) + sealed = ohttp.encapsulate_response(decap.response_key, decap.enc, body_bytes) + return Response( + sealed, + status=status, + headers=forwarded, + mimetype=OHTTP_RESPONSE_MEDIA_TYPE, + ) + + +def _build_streaming_response( + decap: ohttp.DecapsulatedRequest, + status: int, + headers: list[tuple[str, str]], + sub_iter: Iterator[bytes], + cost_context: dict[str, Any] | None, +) -> Response: + """Chunked OHTTP response (draft-ietf-ohai-chunked-ohttp-08). + + Each SSE event from the inner /v1/chat/completions stream is sealed as + one OHTTP chunk and yielded immediately. The final chunk uses + AAD=b"final" with a zero-length varint prefix; emitting it requires + look-ahead by one chunk so we know which one is last, hence the + ``pending`` buffer below. + + Cost can't be exposed as outer headers (those are already flushed + before the body); the relay bills via x402 settlement metadata + (X-Upto-Session header, set up-front). The client reads cost from the + ``opengradient`` block on the final SSE event inside the decrypted + stream. + """ + # See _build_outer_response: keep as a list so duplicate HTTP header + # values (e.g. multiple WWW-Authenticate challenges) survive forwarding. + forwarded: list[tuple[str, str]] = [ + (name, value) for name, value in headers if _should_forward_header(name) + ] + + def _stream() -> Iterator[bytes]: + encrypter = ohttp.ChunkedResponseEncrypter( + decap.response_key_chunked, decap.enc + ) + yield encrypter.header() + + pending: bytes | None = None + plaintext_chunks: list[bytes] = [] + try: + for chunk in sub_iter: + if not chunk: + continue + plaintext_chunks.append(chunk) + if pending is not None: + yield encrypter.encrypt_chunk(pending, is_final=False) + pending = chunk + # Always emit exactly one final chunk so the AAD=b"final" + # marker is present — that's what protects clients from + # undetected truncation. + yield encrypter.encrypt_chunk(pending or b"", is_final=True) + finally: + _set_inner_stream_cost_context( + cost_context, plaintext_chunks, status_code=status + ) + close = getattr(sub_iter, "close", None) + if callable(close): + close() + + return Response( + _stream(), + status=status, + headers=forwarded, + mimetype=OHTTP_CHUNKED_RESPONSE_MEDIA_TYPE, + ) + + +def _drain(sub_iter: Iterator[bytes]) -> bytes: + chunks: list[bytes] = [] + try: + for chunk in sub_iter: + if chunk: + chunks.append(chunk) + finally: + close = getattr(sub_iter, "close", None) + if callable(close): + close() + return b"".join(chunks) + + +def _extract_cost_headers(body_bytes: bytes) -> dict[str, str]: + """Project the response's ``opengradient`` cost block onto outer headers + for the relay's billing pipeline. Cost is the only metadata the relay + needs — model name and token counts stay sealed (they'd fingerprint the + inner request without being billing-relevant). The price is surfaced too + so the relay (and downstream auditors) can verify + ``cost_usd == cost_opg / 10^decimals * opg_price_usd`` without trusting us. + """ + try: + body = json.loads(body_bytes.decode("utf-8")) + except (UnicodeDecodeError, json.JSONDecodeError): + return {} + if not isinstance(body, dict): + return {} + try: + cost = SessionCost.model_validate(body.get("opengradient")) + except Exception: + return {} + wire = cost.model_dump(mode="json") + return { + "X-Inference-Cost-OPG": wire["cost_opg"], + "X-Inference-Cost-USD": wire["cost_usd"], + "X-Inference-Price-OPG-USD": wire["opg_price_usd"], + } + + +def _wsgi_subrequest( + path: str, + body_bytes: bytes, +) -> tuple[int, list[tuple[str, str]], Iterator[bytes]]: + """Issue an in-process WSGI request through the app's full middleware stack. + + Returns ``(status_code, headers, body_iterator)``. The caller is + responsible for draining and closing the iterator. The outer /v1/ohttp + request is the x402-paid boundary, so this inner chat dispatch uses the + pre-x402 WSGI app saved at middleware installation time. That avoids + charging/verifying the same relay payment twice while still running + connexion routing, validation, TEE signing, and provider inference. + """ + outer_env = flask_request.environ + sub_env: dict[str, Any] = { + k: v + for k, v in outer_env.items() + if k.startswith("wsgi.") + or k in ("SERVER_NAME", "SERVER_PORT", "SERVER_PROTOCOL", "HTTP_HOST") + } + sub_env.update( + { + "REQUEST_METHOD": "POST", + "PATH_INFO": path, + "RAW_URI": path, + "REQUEST_URI": path, + "SCRIPT_NAME": "", + "QUERY_STRING": "", + "CONTENT_TYPE": "application/json", + "CONTENT_LENGTH": str(len(body_bytes)), + "wsgi.input": io.BytesIO(body_bytes), + } + ) + # The OpenAPI spec declares a global ApiKeyAuth requirement and connexion + # enforces it before our handler runs (returns 401 "No authorization + # token provided"). The security function (security_controller.py) is an + # intentional passthrough — x402 is the real access control — so any + # value satisfies the schema check. We deliberately do NOT forward the + # outer Authorization header: anything the relay attached there could + # re-identify the client (API keys, JWT subjects, bearer tokens) and + # defeat the whole point of OHTTP. A fixed constant keeps every OHTTP + # request indistinguishable to the chat backend at this layer. + sub_env["HTTP_AUTHORIZATION"] = "Bearer ohttp" + + captured: dict[str, Any] = {"status": "500 Internal Server Error", "headers": []} + + def _start_response(status: str, headers: list, exc_info: Any = None): + captured["status"] = status + captured["headers"] = headers + return lambda _chunk: None + + inner_wsgi = current_app.config.get("OHTTP_INNER_WSGI_APP") or current_app.wsgi_app + iterator = inner_wsgi(sub_env, _start_response) + status_code = int(captured["status"].split(" ", 1)[0]) + # Don't wrap in iter() — that would strip the iterable's close() method, + # which the caller relies on to trigger x402's post-response settlement. + return status_code, captured["headers"], iterator # type: ignore[return-value] + + +def get_hpke_config(): + """GET /v1/ohttp/config — return the HPKE key configuration. + + Returns both an OHTTP-compliant binary key_config (base64) and the + individual fields for clients that prefer to parse JSON. The same data is + embedded inside the attestation document at /signing-key for clients that + want to verify the binding to the enclave's PCRs in one step. + """ + try: + tee = get_tee_keys() + return tee.get_hpke_config(), 200 + except Exception as exc: + logger.error("HPKE config error: %s", exc, exc_info=True) + return {"error": "Failed to retrieve HPKE config"}, 500 + + +def _error(status: int, message: str) -> tuple[dict, int]: + """Plaintext error for cases where we never have a recipient context + (empty body, malformed encapsulation). Post-decap input errors are + also returned plaintext so the relay can surface them to the client — + they never contain user prompts.""" + return {"error": message}, status + + +def _set_inner_cost_context( + req_or_context, + *, + request_json: dict[str, Any] | None = None, + response_json: dict[str, Any] | None = None, + status_code: int | None = None, +) -> None: + if isinstance(req_or_context, dict): + cost_context = req_or_context + elif req_or_context is None: + return + else: + cost_context = req_or_context.environ.get("x402.cost_context") + if not isinstance(cost_context, dict): + return + if request_json is not None: + cost_context["inner_request_json"] = request_json + if response_json is not None: + cost_context["inner_response_json"] = response_json + if status_code is not None: + cost_context["inner_status_code"] = status_code + + +def _set_inner_response_cost_context( + req, + body_bytes: bytes, + *, + status_code: int, +) -> None: + try: + response_json = json.loads(body_bytes.decode("utf-8")) + except (UnicodeDecodeError, json.JSONDecodeError): + response_json = None + _set_inner_cost_context( + req, + response_json=response_json if isinstance(response_json, dict) else None, + status_code=status_code, + ) + + +def _set_inner_stream_cost_context( + req, + chunks: list[bytes], + *, + status_code: int, +) -> None: + body = b"".join(chunks) + response_json = _parse_final_sse_json(body) + _set_inner_cost_context( + req, + response_json=response_json, + status_code=status_code, + ) + + +def _parse_final_sse_json(body: bytes) -> dict[str, Any] | None: + last_json: dict[str, Any] | None = None + for line in body.decode("utf-8", errors="replace").splitlines(): + line = line.strip() + if not line.startswith("data:"): + continue + payload = line[len("data:") :].strip() + if not payload or payload == "[DONE]": + continue + try: + parsed = json.loads(payload) + except json.JSONDecodeError: + continue + if isinstance(parsed, dict): + last_json = parsed + return last_json + + +def _sealed_error(req, decap: ohttp.DecapsulatedRequest, status: int, message: str): + body = {"error": message} + _set_inner_cost_context(req, response_json=body, status_code=status) + return _sealed_json_response(decap, status, body) + + +def _sealed_json_response( + decap: ohttp.DecapsulatedRequest, + status: int, + body_obj: Any, +) -> Response: + inner_json = json.dumps( + {"status": status, "body": body_obj}, + separators=(",", ":"), + ).encode("utf-8") + + sealed = ohttp.encapsulate_response(decap.response_key, decap.enc, inner_json) + return Response( + sealed, + status=200, + mimetype=OHTTP_RESPONSE_MEDIA_TYPE, + ) diff --git a/tee_gateway/definitions.py b/tee_gateway/definitions.py index 932f9eb..1cb77be 100644 --- a/tee_gateway/definitions.py +++ b/tee_gateway/definitions.py @@ -63,19 +63,19 @@ # # These are the *maximum* amounts shown during the x402 payment pre-check. # Actual per-request costs are calculated dynamically from real token usage -# by dynamic_session_cost_calculator() in util.py. +# by compute_session_cost() in pricing.py. # --------------------------------------------------------------------------- # /v1/chat/completions — maximum OPG spend per session (18 decimals: 100000000000000000 = 0.1 OPG). # This is the upper-bound amount presented to the client during the x402 pre-check handshake. # The x402 "upto" scheme allows the actual charge to be any value up to this cap; -# the real per-request cost is settled dynamically by dynamic_session_cost_calculator() in util.py +# the real per-request cost is settled dynamically by compute_session_cost() in pricing.py # based on actual token usage, so clients are never overcharged beyond what they consumed. CHAT_COMPLETIONS_OPG_SESSION_MAX_SPEND: str = "100000000000000000" # /v1/completions — maximum OPG spend per session (18 decimals: 100000000000000000 = 0.1 OPG). # This is the upper-bound amount presented to the client during the x402 pre-check handshake. # The x402 "upto" scheme allows the actual charge to be any value up to this cap; -# the real per-request cost is settled dynamically by dynamic_session_cost_calculator() in util.py +# the real per-request cost is settled dynamically by compute_session_cost() in pricing.py # based on actual token usage, so clients are never overcharged beyond what they consumed. COMPLETIONS_OPG_SESSION_MAX_SPEND: str = "100000000000000000" diff --git a/tee_gateway/encoder.py b/tee_gateway/encoder.py index 05a073e..0963b50 100644 --- a/tee_gateway/encoder.py +++ b/tee_gateway/encoder.py @@ -1,4 +1,5 @@ from connexion.apps.flask_app import FlaskJSONEncoder +from pydantic import BaseModel from tee_gateway.models.base_model import Model @@ -7,6 +8,10 @@ class JSONEncoder(FlaskJSONEncoder): include_nulls = False def default(self, o): + if isinstance(o, BaseModel): + # mode="json" runs the model's field_serializers — e.g. SessionCost + # emits int/Decimal as JS-safe strings declared on the model. + return o.model_dump(mode="json") if isinstance(o, Model): dikt = {} for attr in o.openapi_types: diff --git a/tee_gateway/ohttp.py b/tee_gateway/ohttp.py new file mode 100644 index 0000000..0f9fa7f --- /dev/null +++ b/tee_gateway/ohttp.py @@ -0,0 +1,298 @@ +""" +Oblivious HTTP encapsulation for anonymous inference. + +Implements request/response encapsulation per RFC 9458 (Oblivious HTTP) +with a fixed HPKE ciphersuite: + - KEM: DHKEM(X25519, HKDF-SHA256) (0x0020) + - KDF: HKDF-SHA256 (0x0001) + - AEAD: ChaCha20-Poly1305 (0x0003) + +Also implements the chunked-response extension from +draft-ietf-ohai-chunked-ohttp-08 for streaming responses (SSE inference): + - Same HPKE context, separate export label "message/bhttp chunked response". + - Wire: response_nonce || (varint(sealed_len) || sealed_ct)+ + || varint(0) || sealed_final_ct + - AEAD AAD: "" for non-final chunks, "final" for the last chunk. + - Per-chunk nonce: aead_nonce XOR encode_be(counter), counter from 0. + +The inner payload is application/json (or text/event-stream for chunked) — +we do not BHTTP-wrap the inference request, since the enclave is the terminal +endpoint and not a generic HTTP proxy. This is a documented divergence from +strict RFC 9458; the cryptographic construction is identical. + +Trust model: the relay sees ciphertext + client IP; the enclave sees plaintext ++ relay IP. Unlinkability holds unless relay and enclave collude. +""" + +from __future__ import annotations + +import os +import struct +from dataclasses import dataclass + +from cryptography.hazmat.primitives import hashes, hmac +from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305 +from cryptography.hazmat.primitives.kdf.hkdf import HKDFExpand +from pyhpke import AEADId, CipherSuite, KDFId, KEMId +from pyhpke.kem_key_interface import KEMKeyInterface + + +# RFC 9180 / 9458 algorithm identifiers +KEM_ID_X25519 = 0x0020 +KDF_ID_HKDF_SHA256 = 0x0001 +AEAD_ID_CHACHA20_POLY1305 = 0x0003 + +# Single, stable key configuration ID. Bump when the keypair or suite changes +# so clients can refuse stale configs. +KEY_CONFIG_ID = 0x01 + +# AEAD parameters for ChaCha20-Poly1305 +_NK = 32 # key length +_NN = 12 # nonce length + +# Per RFC 9458 §4.1/4.5 and draft-ietf-ohai-chunked-ohttp §3.1 — "info" labels. +_LABEL_REQUEST = b"message/bhttp request" +_LABEL_RESPONSE = b"message/bhttp response" +_LABEL_CHUNKED_RESPONSE = b"message/bhttp chunked response" + +_SUITE = CipherSuite.new( + KEMId.DHKEM_X25519_HKDF_SHA256, + KDFId.HKDF_SHA256, + AEADId.CHACHA20_POLY1305, +) + + +def encode_varint(value: int) -> bytes: + """QUIC variable-length integer encoding (RFC 9000 §16). + + The top two bits of the first byte encode the length (00=1B, 01=2B, + 10=4B, 11=8B); the remaining bits hold the big-endian value. Used by + draft-ietf-ohai-chunked-ohttp to frame each response chunk on the wire. + """ + if value < 0: + raise ValueError("varint must be non-negative") + if value < (1 << 6): + return bytes([value]) + if value < (1 << 14): + return bytes([0x40 | (value >> 8), value & 0xFF]) + if value < (1 << 30): + return struct.pack(">I", 0x80000000 | value) + if value < (1 << 62): + return struct.pack(">Q", 0xC000000000000000 | value) + raise ValueError("varint value exceeds 2^62-1") + + +def decode_varint(buf: bytes, offset: int = 0) -> tuple[int, int]: + """Parse one QUIC varint from ``buf`` starting at ``offset``. Returns + ``(value, new_offset)``. Used by clients and tests.""" + if offset >= len(buf): + raise ValueError("varint truncated") + first = buf[offset] + length_bits = first >> 6 + if length_bits == 0: + return first, offset + 1 + if length_bits == 1: + if offset + 2 > len(buf): + raise ValueError("varint truncated") + return ((first & 0x3F) << 8) | buf[offset + 1], offset + 2 + if length_bits == 2: + if offset + 4 > len(buf): + raise ValueError("varint truncated") + head = bytes([first & 0x3F]) + buf[offset + 1 : offset + 4] + return struct.unpack(">I", head)[0], offset + 4 + if offset + 8 > len(buf): + raise ValueError("varint truncated") + head = bytes([first & 0x3F]) + buf[offset + 1 : offset + 8] + return struct.unpack(">Q", head)[0], offset + 8 + + +def _header_bytes(key_id: int = KEY_CONFIG_ID) -> bytes: + return bytes([key_id]) + struct.pack( + ">HHH", + KEM_ID_X25519, + KDF_ID_HKDF_SHA256, + AEAD_ID_CHACHA20_POLY1305, + ) + + +def key_config(public_key_raw: bytes, key_id: int = KEY_CONFIG_ID) -> bytes: + """Build an OHTTP key configuration blob (RFC 9458 §3). + + Format: + key_id(1) || kem_id(2) || public_key(Npk=32) || + symmetric_algorithms_length(2) || (kdf_id(2) || aead_id(2))+ + """ + if len(public_key_raw) != 32: + raise ValueError("X25519 public key must be 32 bytes") + symmetric = struct.pack(">HH", KDF_ID_HKDF_SHA256, AEAD_ID_CHACHA20_POLY1305) + return ( + bytes([key_id]) + + struct.pack(">H", KEM_ID_X25519) + + public_key_raw + + struct.pack(">H", len(symmetric)) + + symmetric + ) + + +@dataclass +class DecapsulatedRequest: + """Result of decapsulating an OHTTP-wrapped request. + + Two response secrets are exported up-front so the caller can switch + between single-shot (``response_key``) and chunked + (``response_key_chunked``) encapsulation after inspecting the inner + body — the recipient context can only be created once per request. + """ + + plaintext: bytes + response_key: bytes # exported with label "message/bhttp response" + response_key_chunked: bytes # exported with label "message/bhttp chunked response" + enc: bytes # client's ephemeral public key, used as salt for response keying + + +def decapsulate_request( + private_key: KEMKeyInterface, encapsulated_request: bytes +) -> DecapsulatedRequest: + """Decrypt an HPKE-wrapped request inside the enclave. + + Raises ``ValueError`` on any malformed or unauthenticated input — + structural errors (wrong length, unsupported ciphersuite) and + cryptographic errors (AEAD tag failure, bad ephemeral key, etc.) are + all surfaced as ``ValueError`` with a generic message. The original + pyhpke/cryptography exception is deliberately NOT chained, because + those error strings can encode oracle information about which check + failed (tag verification vs. KDF vs. length). + """ + # Header(7) + ephemeral pubkey(32) + AEAD tag(16) is the absolute + # minimum. Reject shorter inputs up front so every short-input failure + # mode is a ValueError rather than whatever pyhpke would raise on a + # truncated open(). + if len(encapsulated_request) < 7 + 32 + 16: + raise ValueError("encapsulated request too short") + + key_id = encapsulated_request[0] + kem_id, kdf_id, aead_id = struct.unpack(">HHH", encapsulated_request[1:7]) + if (key_id, kem_id, kdf_id, aead_id) != ( + KEY_CONFIG_ID, + KEM_ID_X25519, + KDF_ID_HKDF_SHA256, + AEAD_ID_CHACHA20_POLY1305, + ): + raise ValueError("unsupported HPKE configuration") + + enc = encapsulated_request[7 : 7 + 32] + aead_ct = encapsulated_request[7 + 32 :] + + info = _LABEL_REQUEST + b"\x00" + _header_bytes(key_id) + try: + recipient = _SUITE.create_recipient_context(enc, private_key, info=info) + plaintext = recipient.open(aead_ct, aad=b"") + except Exception: + # Map AEAD / HPKE failures (invalid tag, bad ephemeral pubkey, etc.) + # to ValueError to match the documented contract. ``from None`` + # suppresses the chain so pyhpke / cryptography error strings — which + # can distinguish which check failed — never reach a caller's log. + raise ValueError("HPKE decapsulation failed") from None + + # Two exports, one per response mode. RFC 9458 §4.5 and the chunked + # draft §3.1 specify max(Nn, Nk) as the export length. HKDF-Expand is + # deterministic and can't fail on valid inputs, so we don't wrap it. + export_len = max(_NN, _NK) + response_secret = recipient.export(_LABEL_RESPONSE, export_len) + response_secret_chunked = recipient.export(_LABEL_CHUNKED_RESPONSE, export_len) + return DecapsulatedRequest( + plaintext=plaintext, + response_key=response_secret, + response_key_chunked=response_secret_chunked, + enc=enc, + ) + + +def encapsulate_response(response_secret: bytes, enc: bytes, plaintext: bytes) -> bytes: + """Seal a response under the per-request derived key (RFC 9458 §4.5). + + Wire format: response_nonce(max(Nn, Nk)=Nk=32) || AEAD ciphertext + """ + response_nonce = os.urandom(max(_NN, _NK)) + aead_key, aead_nonce = _derive_response_keys(response_secret, enc, response_nonce) + ct = ChaCha20Poly1305(aead_key).encrypt(aead_nonce, plaintext, b"") + return response_nonce + ct + + +def _derive_response_keys( + response_secret: bytes, enc: bytes, response_nonce: bytes +) -> tuple[bytes, bytes]: + """HKDF-Extract(enc || response_nonce, response_secret) then Expand for + ``aead_key`` (Nk bytes, info=b"key") and ``aead_nonce`` (Nn bytes, + info=b"nonce"). Shared by single-shot and chunked response paths.""" + h = hmac.HMAC(enc + response_nonce, hashes.SHA256()) + h.update(response_secret) + prk = h.finalize() + aead_key = HKDFExpand(algorithm=hashes.SHA256(), length=_NK, info=b"key").derive( + prk + ) + aead_nonce = HKDFExpand( + algorithm=hashes.SHA256(), length=_NN, info=b"nonce" + ).derive(prk) + return aead_key, aead_nonce + + +class ChunkedResponseEncrypter: + """Stream a chunked OHTTP response per draft-ietf-ohai-chunked-ohttp-08. + + Usage: + enc = ChunkedResponseEncrypter(decap.response_key_chunked, decap.enc) + yield enc.header() # response_nonce + for plaintext in non_final_chunks: + yield enc.encrypt_chunk(plaintext, is_final=False) + yield enc.encrypt_chunk(last_plaintext, is_final=True) + + The final chunk uses AAD=b"final" with a zero-length varint prefix — + that pair is what prevents an attacker from truncating the stream + undetectably, so callers MUST always emit exactly one is_final=True + chunk to terminate the response (even if its plaintext is empty). + """ + + def __init__(self, response_secret: bytes, enc: bytes): + self._response_nonce = os.urandom(max(_NN, _NK)) + self._aead_key, self._aead_nonce = _derive_response_keys( + response_secret, enc, self._response_nonce + ) + self._aead = ChaCha20Poly1305(self._aead_key) + self._counter = 0 + self._finalized = False + + def header(self) -> bytes: + """Wire bytes that prefix the chunk stream.""" + return self._response_nonce + + def encrypt_chunk(self, plaintext: bytes, is_final: bool) -> bytes: + if self._finalized: + raise RuntimeError("ChunkedResponseEncrypter already finalized") + ctr_bytes = self._counter.to_bytes(_NN, "big") + chunk_nonce = bytes(a ^ b for a, b in zip(self._aead_nonce, ctr_bytes)) + aad = b"final" if is_final else b"" + sealed = self._aead.encrypt(chunk_nonce, plaintext, aad) + self._counter += 1 + length_prefix = encode_varint(0) if is_final else encode_varint(len(sealed)) + if is_final: + self._finalized = True + return length_prefix + sealed + + +def generate_keypair() -> tuple[KEMKeyInterface, bytes]: + """Generate a fresh X25519 keypair for HPKE. + + The HPKE keypair is intentionally independent of the RSA TEE signing + key: deriving one from the other would create a single point of + compromise (a leak of the RSA private key would also leak the OHTTP + private key, and vice versa). Both public keys are still covered by + the same nitriding attestation transcript, so verifiers get binding + without sharing key material. + + pyhpke 0.6 derives the keypair from random IKM via + ``kem.derive_key_pair(ikm)``; we feed it ``os.urandom(32)`` so each + enclave boot produces an independent keypair. + """ + pair = _SUITE.kem.derive_key_pair(os.urandom(32)) + return pair.private_key, pair.public_key.to_public_bytes() diff --git a/tee_gateway/price_feed/__init__.py b/tee_gateway/price_feed/__init__.py index 1349825..aba9902 100644 --- a/tee_gateway/price_feed/__init__.py +++ b/tee_gateway/price_feed/__init__.py @@ -4,4 +4,25 @@ __all__ = [ "OPGPriceFeed", "PriceFeedConfig", + "get_price_feed", + "set_price_feed", ] + + +_price_feed: OPGPriceFeed | None = None + + +def set_price_feed(feed: OPGPriceFeed) -> None: + """Register the process-wide OPG price feed. Called once from app startup.""" + global _price_feed + _price_feed = feed + + +def get_price_feed() -> OPGPriceFeed: + """Return the registered process-wide price feed. Raises if not initialized.""" + if _price_feed is None: + raise RuntimeError( + "OPG price feed not initialized — set_price_feed() must be called " + "during app startup before pricing-dependent code runs." + ) + return _price_feed diff --git a/tee_gateway/price_feed/feed.py b/tee_gateway/price_feed/feed.py index bbfb86b..cc0025a 100644 --- a/tee_gateway/price_feed/feed.py +++ b/tee_gateway/price_feed/feed.py @@ -9,7 +9,7 @@ ----- Create an ``OPGPriceFeed`` instance in the application entry point, call ``start()``, then pass it explicitly to wherever the price is needed (e.g. -``calculate_session_cost(...)`` in ``util.py``). +``compute_session_cost(...)`` in ``pricing.py``). """ import logging diff --git a/tee_gateway/pricing.py b/tee_gateway/pricing.py new file mode 100644 index 0000000..ce00a62 --- /dev/null +++ b/tee_gateway/pricing.py @@ -0,0 +1,120 @@ +"""Per-request session-cost calculation for x402 settlement. + +Converts realized token usage from an LLM response into an OPG smallest-units +integer (what x402 actually charges) plus the equivalent USD figure and the +OPG/USD price used for the conversion. All three are bundled in +:class:`SessionCost`, embedded on the response under the ``opengradient`` key, +and read back by both the x402 settlement calculator and the OHTTP outer-header +extractor. This module is the single source of truth for that math. +""" + +import logging +from decimal import Decimal, ROUND_CEILING + +from pydantic import BaseModel, ConfigDict, field_serializer + +from tee_gateway.definitions import ( + ASSET_DECIMALS_BY_ADDRESS, + BASE_MAINNET_OPG_ADDRESS, +) +from tee_gateway.model_registry import get_model_config + +logger = logging.getLogger("llm_server.dynamic_pricing") + +_OPG_DECIMALS = ASSET_DECIMALS_BY_ADDRESS[BASE_MAINNET_OPG_ADDRESS.lower()] + + +class SessionCost(BaseModel): + """Settled per-request cost. ``cost_opg`` is what x402 actually charges; + ``cost_usd`` and ``opg_price_usd`` are reported so clients and relays can + audit the conversion without re-fetching the price feed. + + All three fields serialize to JSON strings: the OPG integer can exceed JS + safe-int (2^53) for any non-trivial cost at 18 decimals, and the Decimals + would lose precision through a float round-trip. + """ + + model_config = ConfigDict(frozen=True) + + cost_opg: int + cost_usd: Decimal + opg_price_usd: Decimal + + @field_serializer("cost_opg") + def _serialize_opg(self, value: int) -> str: + return str(value) + + @field_serializer("cost_usd", "opg_price_usd") + def _serialize_decimal(self, value: Decimal) -> str: + return format(value, "f") + + +def compute_session_cost(model: str, usage: dict) -> SessionCost | None: + """Compute the settled cost for a completed inference request. + + Returns the SessionCost pydantic model, or ``None`` on failure. Predictable + failures (unknown model, unavailable price) are blocked by the pre-inference + gate, so anything reaching here is a provider-side error (e.g. missing + usage) or a transient price-feed outage. Logging as CRITICAL matches + ``__main__._session_cost_calculator``'s contract: when this returns None, + x402's downstream reader will skip settlement and the client is not charged. + + Returns both the OPG integer (what x402 charges) and the equivalent USD — + derived from the SAME rounded OPG value, not the raw USD math — so the two + numbers always reconcile via ``cost_usd = cost_opg / 10^decimals * price``. + """ + # Imported lazily to keep this module free of process-level state at import + # time (price_feed singleton is registered during app startup, after this + # module has already been imported transitively via other paths). + from tee_gateway.price_feed import get_price_feed + + try: + cfg = get_model_config(model.strip().lower()) + in_tok = max(0, int(usage["prompt_tokens"])) + out_tok = max(0, int(usage["completion_tokens"])) + + raw_usd = (Decimal(in_tok) * cfg.input_price_usd) + ( + Decimal(out_tok) * cfg.output_price_usd + ) + token_price_usd = get_price_feed().get_price() + if token_price_usd <= 0: + raise ValueError(f"Token price is non-positive: {token_price_usd}") + + scale = Decimal(10) ** _OPG_DECIMALS + cost_smallest_units = max( + 0, + int( + ((raw_usd / token_price_usd) * scale).to_integral_value( + rounding=ROUND_CEILING + ) + ), + ) + # Reconcile USD from the rounded OPG integer so the two surfaced figures + # are exactly consistent (clients verify: usd == opg / 10^decimals * price). + settled_usd = (Decimal(cost_smallest_units) / scale) * token_price_usd + + logger.info( + "DYNAMIC_SESSION_COST model=%s input_tokens=%d output_tokens=%d " + "raw_usd=%s settled_usd=%s token_price_usd=%s decimals=%d cost=%d", + model, + in_tok, + out_tok, + str(raw_usd), + str(settled_usd), + str(token_price_usd), + _OPG_DECIMALS, + cost_smallest_units, + ) + return SessionCost( + cost_opg=cost_smallest_units, + cost_usd=settled_usd, + opg_price_usd=token_price_usd, + ) + except Exception as exc: + logger.critical( + "Post-inference cost calculation failed (provider error) — " + "client will NOT be charged: %s", + exc, + exc_info=True, + ) + return None diff --git a/tee_gateway/tee_manager.py b/tee_gateway/tee_manager.py index 3d98b12..98deaa6 100644 --- a/tee_gateway/tee_manager.py +++ b/tee_gateway/tee_manager.py @@ -17,6 +17,8 @@ from eth_account import Account from eth_hash.auto import keccak +from tee_gateway import ohttp + logger = logging.getLogger(__name__) NITRIDING_BASE_URL = "http://127.0.0.1:8080" @@ -36,6 +38,11 @@ def __init__(self, register=True): self.public_key_pem = None self.tee_id = None self.wallet_address = None + # HPKE keypair for OHTTP-style anonymous inference. Generated in the + # same enclave boot so the X25519 public key is covered by the same + # attestation that covers the RSA signing key. + self.hpke_private_key = None + self.hpke_public_key_raw: bytes | None = None self._generate_keys() if register: self.register_with_nitriding() @@ -70,19 +77,58 @@ def _generate_keys(self): wallet_account = Account.from_key(wallet_key_bytes) self.wallet_address = wallet_account.address + # HPKE X25519 keypair — independent random material from the RSA + # signing key. Both public keys are bound to the enclave by the + # nitriding attestation transcript (register_with_nitriding below), + # so verifiers still get a single attested fingerprint covering both, + # without sharing private-key material between them. + self.hpke_private_key, self.hpke_public_key_raw = ohttp.generate_keypair() + logger.info("TEE key pair generated successfully") logger.info(f"tee_id: 0x{self.tee_id}") logger.info(f"wallet_address: {self.wallet_address}") + logger.info( + f"hpke_public_key (X25519, raw, hex): {self.hpke_public_key_raw.hex()}" + ) def register_with_nitriding(self): - """Register public key hash with nitriding.""" + """Register public key hash with nitriding. + + The hash covers both the RSA signing key (DER-encoded SPKI) and the + raw X25519 HPKE public key. Including both in a single attested digest + means a verifier who validates the attestation document automatically + gets binding for the HPKE config used for anonymous inference — no + separate trust anchor required. + """ + # Defensive check: the v2 transcript labels both keys, so refusing to + # register is the only safe behavior when one is missing. Falling back + # to b"" would produce a digest that's nominally v2 but only covers + # RSA, and a verifier trusting the label would accept an enclave whose + # HPKE key was never attested. Raise outside the broad try/except + # below so a real misconfiguration isn't masked as a non-TEE + # environment. + if not self.hpke_public_key_raw or len(self.hpke_public_key_raw) != 32: + raise RuntimeError( + "Refusing to register with nitriding: HPKE X25519 public key " + "is missing or wrong length; the v2 attestation transcript " + "requires both RSA and HPKE keys." + ) + try: public_key_der = self.public_key.public_bytes( encoding=serialization.Encoding.DER, format=serialization.PublicFormat.SubjectPublicKeyInfo, ) - key_hash = hashlib.sha256(public_key_der).digest() + # Domain-separated transcript so a future addition of more keys + # can't be confused with the existing layout. + transcript = ( + b"og-tee-keys|v2|rsa-spki=" + + public_key_der + + b"|hpke-x25519=" + + self.hpke_public_key_raw + ) + key_hash = hashlib.sha256(transcript).digest() key_hash_b64 = base64.b64encode(key_hash).decode("utf-8") logger.info(f"Public key DER length: {len(public_key_der)} bytes") @@ -149,12 +195,34 @@ def get_wallet_address(self) -> str: """Return the TEE-generated Ethereum wallet address (checksum).""" return self.wallet_address + def get_hpke_config(self) -> dict: + """Return the HPKE key configuration for anonymous inference. + + ``key_config`` is the RFC 9458 §3 binary key-config blob, base64-encoded. + Clients should treat this as authoritative only when fetched alongside + the Nitro attestation document (which commits to the same key hash via + nitriding registration). + """ + if self.hpke_public_key_raw is None: + raise RuntimeError("HPKE keypair not initialized") + return { + "key_id": ohttp.KEY_CONFIG_ID, + "kem_id": ohttp.KEM_ID_X25519, + "kdf_id": ohttp.KDF_ID_HKDF_SHA256, + "aead_id": ohttp.AEAD_ID_CHACHA20_POLY1305, + "public_key": self.hpke_public_key_raw.hex(), + "key_config": base64.b64encode( + ohttp.key_config(self.hpke_public_key_raw) + ).decode("ascii"), + } + def get_attestation_document(self) -> dict: """Return TEE attestation document.""" return { "public_key": self.public_key_pem, "tee_id": f"0x{self.tee_id}", "wallet_address": self.wallet_address, + "hpke": self.get_hpke_config() if self.hpke_public_key_raw else None, "timestamp": datetime.now(UTC).isoformat(), "enclave_info": { "platform": "aws-nitro", diff --git a/tee_gateway/test/test_ohttp.py b/tee_gateway/test/test_ohttp.py new file mode 100644 index 0000000..fa7c26f --- /dev/null +++ b/tee_gateway/test/test_ohttp.py @@ -0,0 +1,219 @@ +"""Tests for the OHTTP encapsulation module.""" + +from __future__ import annotations + +import json + +import pytest + +from tee_gateway import ohttp + + +def test_round_trip_request_and_response(): + sk, pk_raw = ohttp.generate_keypair() + + plaintext = json.dumps({"model": "gpt-4.1", "n": 1}).encode() + # Encapsulate using the same code paths a client would, since pyhpke is + # symmetric — we wire the request manually. + config = ohttp.key_config(pk_raw) + assert config[0] == ohttp.KEY_CONFIG_ID + + # Build a wire payload exactly as the SDK does. + import struct + + hdr = bytes([ohttp.KEY_CONFIG_ID]) + struct.pack( + ">HHH", + ohttp.KEM_ID_X25519, + ohttp.KDF_ID_HKDF_SHA256, + ohttp.AEAD_ID_CHACHA20_POLY1305, + ) + info = b"message/bhttp request" + b"\x00" + hdr + pkr = ohttp._SUITE.kem.deserialize_public_key(pk_raw) + enc, sender = ohttp._SUITE.create_sender_context(pkr, info=info) + ct = sender.seal(plaintext, aad=b"") + wire = hdr + enc + ct + + decap = ohttp.decapsulate_request(sk, wire) + assert decap.plaintext == plaintext + + response_secret = sender.export(b"message/bhttp response", 32) + assert decap.response_key == response_secret + assert decap.enc == enc + + response = b'{"ok":true}' + sealed = ohttp.encapsulate_response(decap.response_key, decap.enc, response) + + from cryptography.hazmat.primitives import hashes, hmac + from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305 + from cryptography.hazmat.primitives.kdf.hkdf import HKDFExpand + + response_nonce = sealed[:32] + aead_ct = sealed[32:] + salt = enc + response_nonce + h = hmac.HMAC(salt, hashes.SHA256()) + h.update(response_secret) + prk = h.finalize() + key = HKDFExpand(algorithm=hashes.SHA256(), length=32, info=b"key").derive(prk) + nonce = HKDFExpand(algorithm=hashes.SHA256(), length=12, info=b"nonce").derive(prk) + assert ChaCha20Poly1305(key).decrypt(nonce, aead_ct, b"") == response + + +def test_rejects_wrong_suite(): + sk, pk_raw = ohttp.generate_keypair() + # Build a wire with the wrong AEAD ID + import struct + + hdr = bytes([ohttp.KEY_CONFIG_ID]) + struct.pack( + ">HHH", + ohttp.KEM_ID_X25519, + ohttp.KDF_ID_HKDF_SHA256, + 0x0001, # AES-128-GCM + ) + fake_wire = hdr + b"\x00" * 32 + b"\x00" * 16 + with pytest.raises(ValueError, match="unsupported"): + ohttp.decapsulate_request(sk, fake_wire) + + +def test_rejects_short_input(): + sk, _ = ohttp.generate_keypair() + with pytest.raises(ValueError, match="too short"): + ohttp.decapsulate_request(sk, b"\x01") + + +def test_generate_keypair_is_independent(): + """Each invocation must produce an independent keypair — the HPKE key + is intentionally not derived from any shared seed.""" + _, pk_a = ohttp.generate_keypair() + _, pk_b = ohttp.generate_keypair() + assert pk_a != pk_b + + +def test_varint_round_trip(): + """QUIC varint encoder/decoder matches across all 4 length classes.""" + for value in (0, 63, 64, 16383, 16384, 1073741823, 1073741824, (1 << 62) - 1): + encoded = ohttp.encode_varint(value) + decoded, off = ohttp.decode_varint(encoded) + assert decoded == value, value + assert off == len(encoded) + + with pytest.raises(ValueError): + ohttp.encode_varint(-1) + with pytest.raises(ValueError): + ohttp.encode_varint(1 << 62) + + +def test_chunked_response_round_trip(): + """Client encrypts a chunked response stream; client-side decrypter + must recover every chunk and detect the AAD=b'final' terminator.""" + from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305 + + sk, pk_raw = ohttp.generate_keypair() + import struct as _struct + + hdr = bytes([ohttp.KEY_CONFIG_ID]) + _struct.pack( + ">HHH", + ohttp.KEM_ID_X25519, + ohttp.KDF_ID_HKDF_SHA256, + ohttp.AEAD_ID_CHACHA20_POLY1305, + ) + info = b"message/bhttp request" + b"\x00" + hdr + pkr = ohttp._SUITE.kem.deserialize_public_key(pk_raw) + enc, sender = ohttp._SUITE.create_sender_context(pkr, info=info) + wire = hdr + enc + sender.seal(b'{"stream": true}', aad=b"") + + decap = ohttp.decapsulate_request(sk, wire) + assert decap.response_key_chunked == sender.export( + b"message/bhttp chunked response", 32 + ) + + # Server side: stream three chunks plus an empty final marker. + plaintexts = [b"data: {chunk1}\n\n", b"data: {chunk2}\n\n", b"data: [DONE]\n\n"] + encrypter = ohttp.ChunkedResponseEncrypter(decap.response_key_chunked, decap.enc) + wire_chunks = [encrypter.header()] + for pt in plaintexts[:-1]: + wire_chunks.append(encrypter.encrypt_chunk(pt, is_final=False)) + wire_chunks.append(encrypter.encrypt_chunk(plaintexts[-1], is_final=True)) + stream = b"".join(wire_chunks) + + # Client side: re-derive keys, walk the varint-framed chunks. + response_nonce = stream[:32] + off = 32 + response_secret = sender.export(b"message/bhttp chunked response", 32) + aead_key, aead_nonce = ohttp._derive_response_keys( + response_secret, enc, response_nonce + ) + aead = ChaCha20Poly1305(aead_key) + + recovered: list[bytes] = [] + counter = 0 + while off < len(stream): + length, off = ohttp.decode_varint(stream, off) + is_final = length == 0 + # On the final chunk the prefix is zero; the actual sealed length is + # plaintext_len + 16 (Poly1305 tag). The chunk consumes the rest. + seg_len = len(stream) - off if is_final else length + ct = stream[off : off + seg_len] + off += seg_len + chunk_nonce = bytes( + a ^ b for a, b in zip(aead_nonce, counter.to_bytes(12, "big")) + ) + aad = b"final" if is_final else b"" + recovered.append(aead.decrypt(chunk_nonce, ct, aad)) + counter += 1 + if is_final: + break + + assert recovered == plaintexts + # The decrypter MUST reject the same stream with the final-AAD swapped + # — protects against undetected truncation at the boundary. + with pytest.raises(Exception): + aead.decrypt( + bytes(a ^ b for a, b in zip(aead_nonce, (counter - 1).to_bytes(12, "big"))), + stream[-len(ct) :], + b"", + ) + + +def test_chunked_encrypter_rejects_double_finalize(): + sk, pk_raw = ohttp.generate_keypair() + import struct as _struct + + hdr = bytes([ohttp.KEY_CONFIG_ID]) + _struct.pack( + ">HHH", + ohttp.KEM_ID_X25519, + ohttp.KDF_ID_HKDF_SHA256, + ohttp.AEAD_ID_CHACHA20_POLY1305, + ) + info = b"message/bhttp request" + b"\x00" + hdr + pkr = ohttp._SUITE.kem.deserialize_public_key(pk_raw) + enc, sender = ohttp._SUITE.create_sender_context(pkr, info=info) + decap = ohttp.decapsulate_request(sk, hdr + enc + sender.seal(b"hi", aad=b"")) + + encrypter = ohttp.ChunkedResponseEncrypter(decap.response_key_chunked, decap.enc) + encrypter.header() + encrypter.encrypt_chunk(b"only", is_final=True) + with pytest.raises(RuntimeError, match="already finalized"): + encrypter.encrypt_chunk(b"extra", is_final=False) + + +def test_rejects_tampered_ciphertext(): + sk, pk_raw = ohttp.generate_keypair() + import struct + + hdr = bytes([ohttp.KEY_CONFIG_ID]) + struct.pack( + ">HHH", + ohttp.KEM_ID_X25519, + ohttp.KDF_ID_HKDF_SHA256, + ohttp.AEAD_ID_CHACHA20_POLY1305, + ) + info = b"message/bhttp request" + b"\x00" + hdr + pkr = ohttp._SUITE.kem.deserialize_public_key(pk_raw) + enc, sender = ohttp._SUITE.create_sender_context(pkr, info=info) + ct = sender.seal(b"hello", aad=b"") + wire = bytearray(hdr + enc + ct) + wire[-1] ^= 0xFF + # decapsulate_request normalises every crypto failure to ValueError with + # a generic message — pyhpke / cryptography exception types must NOT leak + # through, since their strings encode oracle info about which check failed. + with pytest.raises(ValueError, match="HPKE decapsulation failed"): + ohttp.decapsulate_request(sk, bytes(wire)) diff --git a/tee_gateway/test/test_ohttp_controller.py b/tee_gateway/test/test_ohttp_controller.py new file mode 100644 index 0000000..ebba831 --- /dev/null +++ b/tee_gateway/test/test_ohttp_controller.py @@ -0,0 +1,516 @@ +"""Unit tests for the /v1/ohttp handler in +``tee_gateway.controllers.ohttp_controller``. + +The controller is a relatively thin shell that sits between HPKE +decapsulation and an in-process WSGI sub-request to /v1/chat/completions. +These tests pin down the shell's behaviour without bringing up the real +chat backend by stubbing ``get_tee_keys`` (to plant a known HPKE keypair) +and ``current_app.wsgi_app`` (to fake the inner /v1/chat/completions +response). + +What we verify: + * 413 on an oversized encapsulated body + * 400 on a malformed encapsulation + * 400 when the decrypted inner payload is not valid JSON + * Non-2xx inner responses are forwarded as plaintext (no HPKE seal), + with the inner content-type and the x402-related headers preserved + * 2xx inner responses are sealed and the ``opengradient`` cost block is + projected onto outer X-Inference-Cost-* headers + * Streaming inner responses emit exactly one AAD=b"final" chunk and + close the inner WSGI iterator (settles x402) +""" + +from __future__ import annotations + +import json +import struct +from decimal import Decimal +from typing import Iterator + +from flask import Flask + +from tee_gateway import ohttp +from tee_gateway.controllers import ohttp_controller + + +# --------------------------------------------------------------------------- +# helpers + + +def _encapsulate(plaintext: bytes): + """Build a real HPKE-encapsulated request and return ``(sk, wire, sender, enc)``. + + The sender context is returned so individual tests that need to decrypt + the outer response can re-derive the response key the same way an SDK + client would. + """ + sk, pk_raw = ohttp.generate_keypair() + hdr = bytes([ohttp.KEY_CONFIG_ID]) + struct.pack( + ">HHH", + ohttp.KEM_ID_X25519, + ohttp.KDF_ID_HKDF_SHA256, + ohttp.AEAD_ID_CHACHA20_POLY1305, + ) + info = b"message/bhttp request" + b"\x00" + hdr + pkr = ohttp._SUITE.kem.deserialize_public_key(pk_raw) + enc, sender = ohttp._SUITE.create_sender_context(pkr, info=info) + wire = hdr + enc + sender.seal(plaintext, aad=b"") + return sk, wire, sender, enc + + +class _FakeTee: + def __init__(self, sk): + self.hpke_private_key = sk + + +class _CloseTrackingIter: + """Iterable that records when ``close()`` is called. + + Mirrors how the real WSGI iterator from the chat handler behaves — + closing it is what triggers x402's post-response settlement, so we + need the controller to call it on both the streaming and non-streaming + paths. + """ + + def __init__(self, chunks): + self._chunks = list(chunks) + self.closed = False + + def __iter__(self) -> Iterator[bytes]: + for c in self._chunks: + yield c + + def close(self): + self.closed = True + + +def _make_app(inner_responder, sk): + """Build a tiny Flask app with the controller mounted at /v1/ohttp. + + ``inner_responder`` is invoked when the controller does its WSGI + sub-dispatch to ``/v1/chat/completions``; it must return + ``(status_line, headers, body_iter)``. Anything else passes through + to the real Flask routing so the outer POST still lands on our + handler. + """ + app = Flask(__name__) + app.add_url_rule( + "/v1/ohttp", + view_func=ohttp_controller.create_anonymous_chat_completion, + methods=["POST"], + ) + + original_wsgi = app.wsgi_app + captured = {"env": None, "called": 0, "iter": None} + + def fake_wsgi(env, start_response): + if env.get("PATH_INFO") == "/v1/chat/completions": + captured["env"] = env + captured["called"] += 1 + status, headers, body_iter = inner_responder() + captured["iter"] = body_iter + start_response(status, headers) + return body_iter + return original_wsgi(env, start_response) + + app.wsgi_app = fake_wsgi + + def fake_get_tee_keys(): + return _FakeTee(sk) + + return app, captured, fake_get_tee_keys + + +# --------------------------------------------------------------------------- +# 413 / 400 cases — no decap happens, so we don't even need a key + + +def test_oversized_body_returns_413(monkeypatch): + sk, _ = ohttp.generate_keypair() + # Body is well past _MAX_ENCAPSULATED_REQUEST_BYTES (512 KiB). Werkzeug + # will set Content-Length from the data length, so the up-front check + # fires before any HPKE work. + app, captured, fake_keys = _make_app(lambda: ("200 OK", [], iter([])), sk) + monkeypatch.setattr(ohttp_controller, "get_tee_keys", fake_keys) + + too_big = b"x" * (ohttp_controller._MAX_ENCAPSULATED_REQUEST_BYTES + 1) + client = app.test_client() + resp = client.post("/v1/ohttp", data=too_big, content_type="message/ohttp-req") + + assert resp.status_code == 413 + assert captured["called"] == 0 + + +def test_empty_body_returns_400(monkeypatch): + sk, _ = ohttp.generate_keypair() + app, captured, fake_keys = _make_app(lambda: ("200 OK", [], iter([])), sk) + monkeypatch.setattr(ohttp_controller, "get_tee_keys", fake_keys) + + client = app.test_client() + resp = client.post("/v1/ohttp", data=b"", content_type="message/ohttp-req") + assert resp.status_code == 400 + assert captured["called"] == 0 + + +def test_malformed_encapsulation_returns_400(monkeypatch): + sk, _ = ohttp.generate_keypair() + app, captured, fake_keys = _make_app(lambda: ("200 OK", [], iter([])), sk) + monkeypatch.setattr(ohttp_controller, "get_tee_keys", fake_keys) + + # Random short payload — passes the size gate but ohttp.decapsulate_request + # will reject it as malformed. The controller MUST normalise that into a + # generic 400 so it doesn't expose an oracle on which decap step failed. + client = app.test_client() + resp = client.post("/v1/ohttp", data=b"\x00" * 64, content_type="message/ohttp-req") + assert resp.status_code == 400 + assert captured["called"] == 0 + + +def test_invalid_inner_json_returns_400(monkeypatch): + """Plaintext decapsulates fine but isn't valid JSON — controller must + surface that as a 400 rather than fall through to the chat handler.""" + sk, wire, _, _ = _encapsulate(b"not-json{{{") + app, captured, fake_keys = _make_app(lambda: ("200 OK", [], iter([])), sk) + monkeypatch.setattr(ohttp_controller, "get_tee_keys", fake_keys) + + client = app.test_client() + resp = client.post("/v1/ohttp", data=wire, content_type="message/ohttp-req") + assert resp.status_code == 400 + assert captured["called"] == 0 + + +def test_non_object_inner_payload_returns_400(monkeypatch): + sk, wire, _, _ = _encapsulate(b"[1, 2, 3]") + app, captured, fake_keys = _make_app(lambda: ("200 OK", [], iter([])), sk) + monkeypatch.setattr(ohttp_controller, "get_tee_keys", fake_keys) + + client = app.test_client() + resp = client.post("/v1/ohttp", data=wire, content_type="message/ohttp-req") + assert resp.status_code == 400 + assert captured["called"] == 0 + + +# --------------------------------------------------------------------------- +# Non-2xx inner: plaintext passthrough, forwarded headers + + +def test_non_2xx_inner_response_is_plaintext_passthrough(monkeypatch): + """An x402 402 (or any non-2xx) from the chat handler must NOT be + HPKE-sealed — the relay needs to read the payment challenge and act + on it. The inner content-type and x402 control headers must survive.""" + sk, wire, _, _ = _encapsulate(b'{"model":"gpt-4.1","messages":[]}') + + inner_body = b'{"error":"payment required"}' + + def inner(): + return ( + "402 Payment Required", + [ + ("Content-Type", "application/json"), + ("WWW-Authenticate", "x402 ..."), + ("X-Payment-Required", "true"), + ("X-Tee-Signature", "sig"), + ("Set-Cookie", "tracker=abc"), # must NOT be forwarded + ], + _CloseTrackingIter([inner_body]), + ) + + app, captured, fake_keys = _make_app(inner, sk) + monkeypatch.setattr(ohttp_controller, "get_tee_keys", fake_keys) + + client = app.test_client() + resp = client.post( + "/v1/ohttp", + data=wire, + content_type="message/ohttp-req", + headers={"X-Payment": "client-payment-blob"}, + ) + + assert resp.status_code == 402 + # Body must be plaintext, NOT message/ohttp-res. + assert resp.content_type.startswith("application/json") + assert resp.data == inner_body + # x402 / WWW-Authenticate forwarded; arbitrary headers not. + assert resp.headers.get("WWW-Authenticate") == "x402 ..." + assert resp.headers.get("X-Payment-Required") == "true" + assert resp.headers.get("X-Tee-Signature") == "sig" + assert "Set-Cookie" not in resp.headers + # The outer /v1/ohttp request is the paid x402 boundary. The decrypted + # in-process chat subrequest bypasses x402, so the payment blob must not + # be forwarded into the inner env. + assert "HTTP_X_PAYMENT" not in captured["env"] + # WSGI iterator was drained AND closed (drains x402 settlement). + assert captured["iter"].closed is True + + +# --------------------------------------------------------------------------- +# 2xx inner: response sealed, cost headers surfaced + + +def _client_decrypt_response(sealed: bytes, sender, enc: bytes) -> bytes: + """Recover the plaintext from a single-shot OHTTP response, the same + way an external client SDK would.""" + from cryptography.hazmat.primitives import hashes, hmac + from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305 + from cryptography.hazmat.primitives.kdf.hkdf import HKDFExpand + + response_secret = sender.export(b"message/bhttp response", 32) + response_nonce = sealed[:32] + aead_ct = sealed[32:] + salt = enc + response_nonce + h = hmac.HMAC(salt, hashes.SHA256()) + h.update(response_secret) + prk = h.finalize() + key = HKDFExpand(algorithm=hashes.SHA256(), length=32, info=b"key").derive(prk) + nonce = HKDFExpand(algorithm=hashes.SHA256(), length=12, info=b"nonce").derive(prk) + return ChaCha20Poly1305(key).decrypt(nonce, aead_ct, b"") + + +def test_2xx_inner_response_is_sealed_with_cost_headers(monkeypatch): + sk, wire, sender, enc = _encapsulate(b'{"model":"gpt-4.1","messages":[]}') + + inner_payload = { + "id": "chatcmpl-xyz", + "choices": [{"message": {"role": "assistant", "content": "hi"}}], + "opengradient": { + "cost_opg": 1234567890, + "cost_usd": "0.001234", + "opg_price_usd": "1.234", + }, + } + inner_body = json.dumps(inner_payload).encode() + + def inner(): + return ( + "200 OK", + [ + ("Content-Type", "application/json"), + ("X-Tee-Signature", "sig"), + ("X-Payment-Response", "settled"), + ], + _CloseTrackingIter([inner_body]), + ) + + app, captured, fake_keys = _make_app(inner, sk) + monkeypatch.setattr(ohttp_controller, "get_tee_keys", fake_keys) + + client = app.test_client() + resp = client.post("/v1/ohttp", data=wire, content_type="message/ohttp-req") + + assert resp.status_code == 200 + assert resp.content_type.startswith(ohttp_controller.OHTTP_RESPONSE_MEDIA_TYPE) + # Cost is projected onto outer headers — relay-billable, no model name + # and no token counts leak. + assert resp.headers["X-Inference-Cost-OPG"] == "1234567890" + assert Decimal(resp.headers["X-Inference-Cost-USD"]) == Decimal("0.001234") + assert Decimal(resp.headers["X-Inference-Price-OPG-USD"]) == Decimal("1.234") + # x402 / TEE headers still pass through. + assert resp.headers.get("X-Tee-Signature") == "sig" + assert resp.headers.get("X-Payment-Response") == "settled" + + decrypted = _client_decrypt_response(resp.data, sender, enc) + assert json.loads(decrypted) == inner_payload + assert captured["iter"].closed is True + + +def test_2xx_inner_without_cost_block_omits_cost_headers(monkeypatch): + """Missing or unparseable opengradient block must not 500 — we just + skip the cost projection. Belt-and-braces guard around _extract_cost_headers.""" + sk, wire, sender, enc = _encapsulate(b'{"model":"gpt-4.1","messages":[]}') + + inner_body = b'{"id":"chatcmpl-xyz","choices":[]}' + + def inner(): + return ( + "200 OK", + [("Content-Type", "application/json")], + _CloseTrackingIter([inner_body]), + ) + + app, _, fake_keys = _make_app(inner, sk) + monkeypatch.setattr(ohttp_controller, "get_tee_keys", fake_keys) + + client = app.test_client() + resp = client.post("/v1/ohttp", data=wire, content_type="message/ohttp-req") + + assert resp.status_code == 200 + assert "X-Inference-Cost-OPG" not in resp.headers + assert "X-Inference-Cost-USD" not in resp.headers + + +# --------------------------------------------------------------------------- +# Streaming path + + +def test_streaming_inner_response_emits_one_final_chunk_and_closes(monkeypatch): + sk, wire, sender, enc = _encapsulate(b'{"model":"gpt-4.1","stream":true}') + + sse_chunks = [ + b'data: {"a":1}\n\n', + b'data: {"b":2}\n\n', + b"data: [DONE]\n\n", + ] + + def inner(): + return ( + "200 OK", + [("Content-Type", "text/event-stream")], + _CloseTrackingIter(sse_chunks), + ) + + app, captured, fake_keys = _make_app(inner, sk) + monkeypatch.setattr(ohttp_controller, "get_tee_keys", fake_keys) + + client = app.test_client() + resp = client.post("/v1/ohttp", data=wire, content_type="message/ohttp-req") + + assert resp.status_code == 200 + assert resp.content_type.startswith( + ohttp_controller.OHTTP_CHUNKED_RESPONSE_MEDIA_TYPE + ) + + # Decode the chunked-OHTTP wire frame the same way a client SDK would. + body = resp.get_data() + response_nonce = body[:32] + off = 32 + + response_secret = sender.export(b"message/bhttp chunked response", 32) + aead_key, aead_nonce = ohttp._derive_response_keys( + response_secret, enc, response_nonce + ) + from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305 + + aead = ChaCha20Poly1305(aead_key) + + recovered = [] + final_count = 0 + counter = 0 + while off < len(body): + length, off = ohttp.decode_varint(body, off) + is_final = length == 0 + seg_len = len(body) - off if is_final else length + ct = body[off : off + seg_len] + off += seg_len + chunk_nonce = bytes( + a ^ b for a, b in zip(aead_nonce, counter.to_bytes(12, "big")) + ) + aad = b"final" if is_final else b"" + recovered.append(aead.decrypt(chunk_nonce, ct, aad)) + counter += 1 + if is_final: + final_count += 1 + break + + # Every SSE event must be recovered, and the last chunk MUST be the + # AAD=b"final" terminator — that's what protects clients from + # undetected truncation. There must be exactly one. + assert recovered == sse_chunks + assert final_count == 1 + # And the inner iterator must have been closed so x402 streaming + # settlement runs. + assert captured["iter"].closed is True + + +def test_streaming_path_uses_chunked_when_only_one_inner_chunk(monkeypatch): + """Even with a single SSE event, the controller must still emit + exactly one AAD=b"final" chunk (the pending/lookahead logic in + _build_streaming_response is the part being pinned down).""" + sk, wire, sender, enc = _encapsulate(b'{"model":"gpt-4.1","stream":true}') + + only = b"data: [DONE]\n\n" + + def inner(): + return ( + "200 OK", + [("Content-Type", "text/event-stream")], + _CloseTrackingIter([only]), + ) + + app, captured, fake_keys = _make_app(inner, sk) + monkeypatch.setattr(ohttp_controller, "get_tee_keys", fake_keys) + + client = app.test_client() + resp = client.post("/v1/ohttp", data=wire, content_type="message/ohttp-req") + + body = resp.get_data() + response_nonce = body[:32] + off = 32 + response_secret = sender.export(b"message/bhttp chunked response", 32) + aead_key, aead_nonce = ohttp._derive_response_keys( + response_secret, enc, response_nonce + ) + from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305 + + aead = ChaCha20Poly1305(aead_key) + + length, off = ohttp.decode_varint(body, off) + # With exactly one inner chunk, the controller buffers it and emits + # it as the final marker — so the first (and only) framed chunk is + # the zero-length-prefixed final. + assert length == 0 + ct = body[off:] + chunk_nonce = bytes(a ^ b for a, b in zip(aead_nonce, (0).to_bytes(12, "big"))) + assert aead.decrypt(chunk_nonce, ct, b"final") == only + assert captured["iter"].closed is True + + +# --------------------------------------------------------------------------- +# 503 when HPKE key isn't initialized + + +def test_returns_503_when_hpke_key_missing(monkeypatch): + app, captured, _ = _make_app(lambda: ("200 OK", [], iter([])), sk=None) + + class _Tee: + hpke_private_key = None + + monkeypatch.setattr(ohttp_controller, "get_tee_keys", lambda: _Tee()) + + client = app.test_client() + resp = client.post("/v1/ohttp", data=b"abc", content_type="message/ohttp-req") + assert resp.status_code == 503 + assert captured["called"] == 0 + + +# --------------------------------------------------------------------------- +# Identifying field scrubbing + + +def test_identifying_fields_are_scrubbed_before_inner_dispatch(monkeypatch): + payload = { + "model": "gpt-4.1", + "messages": [{"role": "user", "content": "hi"}], + "user": "alice@example.com", + "metadata": {"req": "123"}, + "x-request-id": "abc", + "request_id": "def", + } + sk, wire, _, _ = _encapsulate(json.dumps(payload).encode()) + + inner_body = b'{"id":"chatcmpl-1","choices":[]}' + + def inner(): + return ( + "200 OK", + [("Content-Type", "application/json")], + _CloseTrackingIter([inner_body]), + ) + + app, captured, fake_keys = _make_app(inner, sk) + monkeypatch.setattr(ohttp_controller, "get_tee_keys", fake_keys) + + client = app.test_client() + resp = client.post("/v1/ohttp", data=wire, content_type="message/ohttp-req") + assert resp.status_code == 200 + + # Reconstruct what the inner handler received. + inner_env = captured["env"] + inner_body_received = inner_env["wsgi.input"].read(int(inner_env["CONTENT_LENGTH"])) + forwarded = json.loads(inner_body_received) + assert forwarded == { + "model": "gpt-4.1", + "messages": [{"role": "user", "content": "hi"}], + } + # Authorization is overwritten with the OHTTP-fixed value so it can't + # re-identify the client. + assert inner_env["HTTP_AUTHORIZATION"] == "Bearer ohttp" diff --git a/tee_gateway/test/test_price_feed.py b/tee_gateway/test/test_price_feed.py index fe9bb8a..17af102 100644 --- a/tee_gateway/test/test_price_feed.py +++ b/tee_gateway/test/test_price_feed.py @@ -1,5 +1,5 @@ """ -Unit tests for tee_gateway.price_feed and tee_gateway.util.calculate_session_cost. +Unit tests for tee_gateway.price_feed and tee_gateway.pricing.compute_session_cost. All external HTTP calls are mocked — no network access required. @@ -9,7 +9,7 @@ TestOPGPriceFeedRefresh — OPGPriceFeed._refresh_price() (retry, rate-limit, stats) TestOPGPriceFeedGetPrice — OPGPriceFeed.get_price() (stale warning, ValueError before fetch) TestOPGPriceFeedStatus — OPGPriceFeed.get_status() snapshots -TestCalculateSessionCost — calculate_session_cost(context, get_price) in util.py +TestCalculateSessionCost — compute_session_cost(model, usage) in pricing.py """ import time @@ -23,7 +23,7 @@ from tee_gateway.definitions import BASE_MAINNET_OPG_ADDRESS from tee_gateway.price_feed import OPGPriceFeed from tee_gateway.price_feed.feed import fetch_opg_price -from tee_gateway.util import calculate_session_cost +from tee_gateway.pricing import SessionCost, compute_session_cost # --------------------------------------------------------------------------- # Helpers @@ -367,58 +367,33 @@ def test_status_accumulates_multiple_error_cycles(self, mock_fetch): # --------------------------------------------------------------------------- -# TestMakeCostCalculator +# TestCalculateSessionCost # --------------------------------------------------------------------------- -_ASSET_ADDR = "0xdeadbeef" -_ASSET_ADDR_LOWER = _ASSET_ADDR.lower() +# pricing.compute_session_cost reads OPG decimals from the asset registry — +# OPG is 18 decimals on Base mainnet. _ASSET_DECIMALS = 18 -def _make_payment_requirements(asset: str = _ASSET_ADDR) -> dict: - return {"asset": asset, "price": {"amount": "1000000000000000000", "asset": asset}} - - -def _make_context( - model: str = "gpt-4.1-mini", - input_tokens: int = 100, - output_tokens: int = 50, - price_usd: Decimal = Decimal("0.10"), - asset: str = _ASSET_ADDR, -) -> dict: - return { - "request_json": {"model": model}, - "response_json": { - "model": model, - "usage": { - "prompt_tokens": input_tokens, - "completion_tokens": output_tokens, - }, - }, - "payment_requirements": _make_payment_requirements(asset), - "method": "POST", - "path": "/v1/chat/completions", - "status_code": 200, - "is_streaming": False, - "request_body_bytes": b"", - "response_body_bytes": b"", - "default_cost": 10**18, - } +def _usage(input_tokens: int = 100, output_tokens: int = 50) -> dict: + return {"prompt_tokens": input_tokens, "completion_tokens": output_tokens} def _make_get_price(price_usd: Decimal = Decimal("0.10")) -> MagicMock: - mock = MagicMock(return_value=price_usd) - return mock + return MagicMock(return_value=price_usd) -class TestCalculateSessionCost(unittest.TestCase): - """Tests for calculate_session_cost(context, get_price).""" +def _call(usage: dict, get_price, model: str = "gpt-4.1-mini"): + """Run compute_session_cost with a stubbed price feed.""" + from types import SimpleNamespace + + feed = SimpleNamespace(get_price=get_price) + with patch("tee_gateway.price_feed.get_price_feed", return_value=feed): + return compute_session_cost(model, usage) - def _patch_definitions(self): - return patch( - "tee_gateway.util.ASSET_DECIMALS_BY_ADDRESS", - {_ASSET_ADDR_LOWER: _ASSET_DECIMALS}, - ) + +class TestCalculateSessionCost(unittest.TestCase): + """Tests for compute_session_cost(model, usage).""" def _patch_model( self, input_price: str = "0.000001", output_price: str = "0.000002" @@ -426,90 +401,95 @@ def _patch_model( cfg = MagicMock() cfg.input_price_usd = Decimal(input_price) cfg.output_price_usd = Decimal(output_price) - return patch("tee_gateway.util.get_model_config", return_value=cfg) + return patch("tee_gateway.pricing.get_model_config", return_value=cfg) def test_calls_get_price(self): get_price = _make_get_price() - with self._patch_definitions(), self._patch_model(): - calculate_session_cost(_make_context(), get_price) + with self._patch_model(): + _call(_usage(), get_price) get_price.assert_called_once() - def test_returns_positive_int(self): - with self._patch_definitions(), self._patch_model(): - result = calculate_session_cost(_make_context(), _make_get_price()) - self.assertIsInstance(result, int) - self.assertGreaterEqual(result, 0) + def test_returns_session_cost(self): + with self._patch_model(): + result = _call(_usage(), _make_get_price()) + self.assertIsInstance(result, SessionCost) + self.assertIsInstance(result.cost_opg, int) + self.assertGreaterEqual(result.cost_opg, 0) + self.assertEqual(result.opg_price_usd, Decimal("0.10")) + + def test_reported_usd_reconciles_with_opg(self): + """cost_usd must equal cost_opg / 10^decimals * price exactly — otherwise + clients verifying the conversion will reject the response.""" + with self._patch_model(): + result = _call(_usage(), _make_get_price(Decimal("0.10"))) + scale = Decimal(10) ** _ASSET_DECIMALS + expected = (Decimal(result.cost_opg) / scale) * Decimal("0.10") + self.assertEqual(result.cost_usd, expected) def test_zero_tokens_returns_zero(self): - with self._patch_definitions(), self._patch_model(): - result = calculate_session_cost( - _make_context(input_tokens=0, output_tokens=0), _make_get_price() - ) - self.assertEqual(result, 0) + with self._patch_model(): + result = _call(_usage(0, 0), _make_get_price()) + self.assertEqual(result.cost_opg, 0) + self.assertEqual(result.cost_usd, Decimal(0)) - def test_raises_when_get_price_raises(self): + def test_returns_none_when_get_price_raises(self): get_price = MagicMock(side_effect=ValueError("price not available")) - with self._patch_definitions(), self._patch_model(): - with self.assertRaises(ValueError): - calculate_session_cost(_make_context(), get_price) - - def test_raises_when_non_positive_price(self): - with self._patch_definitions(), self._patch_model(): - with self.assertRaises(ValueError): - calculate_session_cost(_make_context(), _make_get_price(Decimal("0"))) - - def test_raises_when_request_json_missing(self): - ctx = _make_context() - ctx["request_json"] = None - with self._patch_definitions(), self._patch_model(): - with self.assertRaises(ValueError): - calculate_session_cost(ctx, _make_get_price()) - - def test_raises_when_usage_missing(self): - ctx = _make_context() - ctx["response_json"] = {"model": "gpt-4.1-mini"} - with self._patch_definitions(), self._patch_model(): - with self.assertRaises(ValueError): - calculate_session_cost(ctx, _make_get_price()) - - def test_raises_when_asset_unknown(self): - ctx = _make_context(asset="0xunknown") - with ( - patch("tee_gateway.util.ASSET_DECIMALS_BY_ADDRESS", {}), - self._patch_model(), - ): - with self.assertRaises(ValueError): - calculate_session_cost(ctx, _make_get_price()) + with self._patch_model(): + self.assertIsNone(_call(_usage(), get_price)) + + def test_returns_none_when_non_positive_price(self): + with self._patch_model(): + self.assertIsNone(_call(_usage(), _make_get_price(Decimal("0")))) + + def test_returns_none_when_usage_missing_keys(self): + with self._patch_model(): + self.assertIsNone(_call({"prompt_tokens": 100}, _make_get_price())) def test_cost_scales_with_token_count(self): - with self._patch_definitions(), self._patch_model(): - cost_small = calculate_session_cost( - _make_context(input_tokens=10, output_tokens=5), _make_get_price() - ) - cost_large = calculate_session_cost( - _make_context(input_tokens=1000, output_tokens=500), _make_get_price() - ) - self.assertGreater(cost_large, cost_small) + with self._patch_model(): + cost_small = _call(_usage(10, 5), _make_get_price()) + cost_large = _call(_usage(1000, 500), _make_get_price()) + self.assertGreater(cost_large.cost_opg, cost_small.cost_opg) def test_higher_token_price_yields_lower_cost(self): - with self._patch_definitions(), self._patch_model(): - cost_cheap = calculate_session_cost( - _make_context(), _make_get_price(Decimal("0.10")) - ) - cost_expensive = calculate_session_cost( - _make_context(), _make_get_price(Decimal("0.20")) - ) - self.assertGreater(cost_cheap, cost_expensive) + with self._patch_model(): + cost_cheap = _call(_usage(), _make_get_price(Decimal("0.10"))) + cost_expensive = _call(_usage(), _make_get_price(Decimal("0.20"))) + self.assertGreater(cost_cheap.cost_opg, cost_expensive.cost_opg) def test_uses_current_price_on_each_call(self): """get_price is called fresh every invocation — price changes are picked up.""" get_price = MagicMock(side_effect=[Decimal("0.10"), Decimal("0.20")]) - with self._patch_definitions(), self._patch_model(): - cost_first = calculate_session_cost(_make_context(), get_price) - cost_second = calculate_session_cost(_make_context(), get_price) + with self._patch_model(): + cost_first = _call(_usage(), get_price) + cost_second = _call(_usage(), get_price) self.assertEqual(get_price.call_count, 2) # Price doubled → cost should halve (same USD spend, twice the token price). - self.assertGreater(cost_first, cost_second) + self.assertGreater(cost_first.cost_opg, cost_second.cost_opg) + + +class TestSessionCostWireRoundTrip(unittest.TestCase): + """SessionCost must round-trip through its JSON wire form so the chat + handler (writer) and x402's _session_cost_calculator (reader) agree.""" + + def test_round_trip_preserves_values(self): + import json + + cost = SessionCost( + # Above 2^53 — the whole reason wire form is strings, not ints. + cost_opg=12345678901234567890, + cost_usd=Decimal("0.00342100"), + opg_price_usd=Decimal("0.123456"), + ) + wire = json.loads(json.dumps(cost.model_dump(mode="json"))) + parsed = SessionCost.model_validate(wire) + self.assertEqual(parsed.cost_opg, cost.cost_opg) + self.assertEqual(parsed.cost_usd, cost.cost_usd) + self.assertEqual(parsed.opg_price_usd, cost.opg_price_usd) + + def test_validate_rejects_missing_block(self): + with self.assertRaises(Exception): + SessionCost.model_validate(None) if __name__ == "__main__": diff --git a/tee_gateway/util.py b/tee_gateway/util.py index 1c9047e..1264a1f 100644 --- a/tee_gateway/util.py +++ b/tee_gateway/util.py @@ -1,11 +1,6 @@ import datetime from tee_gateway import typing_utils -import logging -from decimal import Decimal, InvalidOperation, ROUND_CEILING -from typing import Any, Callable - -logger = logging.getLogger("llm_server.dynamic_pricing") def _deserialize(data, klass): @@ -151,169 +146,3 @@ def _deserialize_dict(data, boxed_type): :rtype: dict """ return {k: _deserialize(v, boxed_type) for k, v in data.items()} - - -from tee_gateway.definitions import ( # noqa: E402 - ASSET_DECIMALS_BY_ADDRESS, -) -from tee_gateway.model_registry import get_model_config # noqa: E402 - - -def _as_dict(value: Any) -> dict[str, Any] | None: - if value is None: - return None - if isinstance(value, dict): - return value - if hasattr(value, "model_dump"): - try: - dumped = value.model_dump(by_alias=True, exclude_none=True) - if isinstance(dumped, dict): - return dumped - except Exception: - pass - if hasattr(value, "to_dict"): - try: - dumped = value.to_dict() - if isinstance(dumped, dict): - return dumped - except Exception: - pass - return None - - -def _to_decimal(value: Any) -> Decimal | None: - if value is None: - return None - try: - return Decimal(str(value)) - except (InvalidOperation, ValueError, TypeError): - return None - - -def _normalize_model_name(model: str | None) -> str | None: - if not model: - return None - return str(model).strip().lower() - - -def _extract_usage_tokens( - response_json: dict[str, Any] | None, -) -> tuple[int, int]: - """Extract (input_tokens, output_tokens) from response JSON. - - Raises ValueError if usage data is missing or malformed — no silent fallback. - """ - if not isinstance(response_json, dict): - raise ValueError("response_json is not a dict; cannot extract usage tokens") - usage = response_json.get("usage") - if not isinstance(usage, dict): - raise ValueError( - "response_json has no 'usage' dict; cannot extract usage tokens" - ) - - prompt_tokens = usage.get("prompt_tokens", usage.get("input_tokens")) - completion_tokens = usage.get("completion_tokens", usage.get("output_tokens")) - if prompt_tokens is None or completion_tokens is None: - raise ValueError(f"usage dict is missing token counts: {usage!r}") - - try: - return max(0, int(prompt_tokens)), max(0, int(completion_tokens)) - except (TypeError, ValueError) as exc: - raise ValueError(f"Could not parse token counts from usage: {usage!r}") from exc - - -def _extract_model_from_context( - request_json: dict[str, Any] | None, - response_json: dict[str, Any] | None, -) -> str: - """Extract and normalize model name from request JSON. - - Uses only the request model name — the response model field is ignored - because providers may return a versioned alias that differs from the - user-facing name. Raises ValueError if the model name is absent. - """ - if not isinstance(request_json, dict): - raise ValueError("request_json is not a dict; cannot extract model name") - req_model = request_json.get("model") - if not req_model: - raise ValueError("request_json has no 'model' field") - normalized = _normalize_model_name(req_model) - if not normalized: - raise ValueError(f"model name normalizes to empty string: {req_model!r}") - return normalized - - -def _extract_asset_decimals_from_requirements(payment_requirements: Any) -> int: - req = _as_dict(payment_requirements) or {} - - asset = req.get("asset") - if not asset and isinstance(req.get("price"), dict): - asset = req["price"].get("asset") - - if not isinstance(asset, str) or not asset: - raise ValueError( - f"payment_requirements has no recognizable asset address; " - f"cannot determine token decimals: {req!r}" - ) - - asset_lower = asset.lower() - if asset_lower not in ASSET_DECIMALS_BY_ADDRESS: - raise ValueError( - f"Unknown asset address {asset!r}; not in ASSET_DECIMALS_BY_ADDRESS. " - f"Add it to definitions.py before accepting payments with this token." - ) - return ASSET_DECIMALS_BY_ADDRESS[asset_lower] - - -def calculate_session_cost( - context: dict[str, Any], get_price: Callable[[], Decimal] -) -> int: - """Calculate the x402 session cost in token smallest units for a completed request. - - ``get_price`` is called on every invocation to fetch the current OPG/USD - price — pass ``price_feed.get_price`` so the latest cached value is used. - Raises ``ValueError`` on any missing/invalid data. Predictable failures - (unavailable price, unknown model) are blocked before inference by the - pre-inference gate in ``__main__.py``; post-inference failures are logged - as CRITICAL by the caller and the client is not charged. - """ - request_json = context.get("request_json") - response_json = context.get("response_json") - - if not isinstance(request_json, dict) or not isinstance(response_json, dict): - raise ValueError( - "calculate_session_cost requires both request_json and response_json" - ) - - model = _extract_model_from_context(request_json, response_json) - cfg = get_model_config(model) - input_tokens, output_tokens = _extract_usage_tokens(response_json) - - total_usd = (Decimal(input_tokens) * cfg.input_price_usd) + ( - Decimal(output_tokens) * cfg.output_price_usd - ) - token_price_usd = get_price() - if token_price_usd <= 0: - raise ValueError(f"Token price is non-positive: {token_price_usd}") - - token_amount = total_usd / token_price_usd - decimals = _extract_asset_decimals_from_requirements( - context.get("payment_requirements") - ) - scale = Decimal(10) ** decimals - cost_smallest_units = int( - (token_amount * scale).to_integral_value(rounding=ROUND_CEILING) - ) - - logger.info( - "DYNAMIC_SESSION_COST model=%s input_tokens=%d output_tokens=%d " - "total_usd=%s token_price_usd=%s decimals=%d cost=%d", - model, - input_tokens, - output_tokens, - str(total_usd), - str(token_price_usd), - decimals, - cost_smallest_units, - ) - return max(0, cost_smallest_units) diff --git a/tests/test_opengradient_field.py b/tests/test_opengradient_field.py new file mode 100644 index 0000000..a26c88c --- /dev/null +++ b/tests/test_opengradient_field.py @@ -0,0 +1,229 @@ +"""Verify the `opengradient` cost block is embedded on responses. + +These tests are the only thing keeping `compute_session_cost`'s result from +silently going missing on a controller response — if that block is absent, +x402's `_session_cost_calculator` swallows the error and the client is never +charged. The runtime CRITICAL log is the safety net; this is the unit-test +catch. +""" + +import json +import unittest +from decimal import Decimal +from unittest.mock import MagicMock, patch + +from tee_gateway.models.create_chat_completion_request import ( + CreateChatCompletionRequest, +) +from tee_gateway.models import ChatCompletionRequestUserMessage +from tee_gateway.models.create_completion_request import CreateCompletionRequest +from tee_gateway.pricing import SessionCost + + +_FAKE_USAGE = {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30} + + +def _fake_cost() -> SessionCost: + return SessionCost( + cost_opg=12345, + cost_usd=Decimal("0.001"), + opg_price_usd=Decimal("0.5"), + ) + + +def _fake_tee_keys() -> MagicMock: + keys = MagicMock() + keys.sign_data.return_value = "0xsig" + keys.get_tee_id.return_value = "deadbeef" + return keys + + +def _chat_request() -> CreateChatCompletionRequest: + return CreateChatCompletionRequest( + model="gpt-4.1-mini", + messages=[ChatCompletionRequestUserMessage(role="user", content="hi")], + stream=False, + ) + + +class TestChatNonStreamingOpengradient(unittest.TestCase): + def test_opengradient_block_embedded_when_cost_computed(self): + from tee_gateway.controllers import chat_controller + + fake_response = MagicMock() + fake_response.content = "hello" + fake_response.tool_calls = None + fake_model = MagicMock() + fake_model.invoke.return_value = fake_response + + with ( + patch.object( + chat_controller, "get_chat_model_cached", return_value=fake_model + ), + patch.object(chat_controller, "extract_usage", return_value=_FAKE_USAGE), + patch.object( + chat_controller, "compute_session_cost", return_value=_fake_cost() + ), + patch.object( + chat_controller, "get_tee_keys", return_value=_fake_tee_keys() + ), + patch.object( + chat_controller, + "compute_tee_msg_hash", + return_value=(b"h", "ih", "oh"), + ), + ): + resp = chat_controller._create_non_streaming_response(_chat_request()) + + self.assertIsInstance(resp, dict) + self.assertIn("opengradient", resp) + self.assertEqual(resp["opengradient"]["cost_opg"], "12345") + + def test_opengradient_block_absent_when_compute_returns_none(self): + from tee_gateway.controllers import chat_controller + + fake_response = MagicMock() + fake_response.content = "hello" + fake_response.tool_calls = None + fake_model = MagicMock() + fake_model.invoke.return_value = fake_response + + with ( + patch.object( + chat_controller, "get_chat_model_cached", return_value=fake_model + ), + patch.object(chat_controller, "extract_usage", return_value=_FAKE_USAGE), + patch.object(chat_controller, "compute_session_cost", return_value=None), + patch.object( + chat_controller, "get_tee_keys", return_value=_fake_tee_keys() + ), + patch.object( + chat_controller, + "compute_tee_msg_hash", + return_value=(b"h", "ih", "oh"), + ), + ): + resp = chat_controller._create_non_streaming_response(_chat_request()) + + self.assertNotIn("opengradient", resp) + + +class TestChatStreamingOpengradient(unittest.TestCase): + def test_final_sse_event_carries_opengradient(self): + from tee_gateway.controllers import chat_controller + + chunk = MagicMock() + chunk.content = "hello" + chunk.tool_call_chunks = [] + chunk.usage_metadata = { + "input_tokens": 10, + "output_tokens": 20, + "total_tokens": 30, + } + fake_model = MagicMock() + fake_model.stream.return_value = iter([chunk]) + + with ( + patch.object( + chat_controller, "get_chat_model_cached", return_value=fake_model + ), + patch.object( + chat_controller, "get_provider_from_model", return_value="openai" + ), + patch.object( + chat_controller, "compute_session_cost", return_value=_fake_cost() + ), + patch.object( + chat_controller, "get_tee_keys", return_value=_fake_tee_keys() + ), + patch.object( + chat_controller, + "compute_tee_msg_hash", + return_value=(b"h", "ih", "oh"), + ), + ): + req = _chat_request() + req.stream = True + response = chat_controller._create_streaming_response(req) + chunks = [ + c.decode() if isinstance(c, bytes) else c for c in response.response + ] + + # The final SSE data event before [DONE] carries the opengradient block. + data_events = [ + c for c in chunks if c.startswith("data: ") and "[DONE]" not in c + ] + final = json.loads(data_events[-1][len("data: ") :].strip()) + self.assertIn("opengradient", final) + self.assertEqual(final["opengradient"]["cost_opg"], "12345") + + +class TestCompletionsOpengradient(unittest.TestCase): + def test_opengradient_block_embedded_on_completion(self): + from tee_gateway.controllers import completions_controller + + fake_response = MagicMock() + fake_response.content = "world" + fake_model = MagicMock() + fake_model.invoke.return_value = fake_response + + fake_request = MagicMock() + fake_request.is_json = True + fake_request.get_json.return_value = { + "model": "gpt-4.1-mini", + "prompt": "hi", + } + + body = CreateCompletionRequest(model="gpt-4.1-mini", prompt="hi") + + with ( + patch("connexion.request", fake_request), + patch.object( + completions_controller, + "get_chat_model_cached", + return_value=fake_model, + ), + patch.object( + completions_controller, "extract_usage", return_value=_FAKE_USAGE + ), + patch.object( + completions_controller, + "compute_session_cost", + return_value=_fake_cost(), + ), + patch.object( + completions_controller, + "get_tee_keys", + return_value=_fake_tee_keys(), + ), + patch.object( + completions_controller, + "compute_tee_msg_hash", + return_value=(b"h", "ih", "oh"), + ), + ): + resp = completions_controller.create_completion(body) + + self.assertIsInstance(resp, dict) + self.assertIn("opengradient", resp) + self.assertEqual(resp["opengradient"]["cost_opg"], "12345") + + +class TestSessionCostCalculatorMissingBlock(unittest.TestCase): + def test_critical_log_when_opengradient_missing(self): + from tee_gateway import __main__ as gateway_main + + with self.assertLogs(gateway_main.logger, level="CRITICAL") as cm: + with self.assertRaises(Exception): + gateway_main._session_cost_calculator( + {"response_json": {"id": "chatcmpl-x"}} + ) + + self.assertTrue( + any("opengradient cost block missing" in msg for msg in cm.output), + f"expected CRITICAL about missing opengradient, got: {cm.output}", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_pricing.py b/tests/test_pricing.py index 0b0c57d..99fa4b2 100644 --- a/tests/test_pricing.py +++ b/tests/test_pricing.py @@ -3,50 +3,40 @@ Tests verify that: - Every user-facing model name resolves to the correct ModelConfig - - calculate_session_cost produces the right amount in OPG token + - compute_session_cost produces the right amount in OPG token smallest-units for supported models - - Edge cases (no usage, unknown model, bad context) are handled correctly + - Edge cases (unknown model) are handled correctly """ import unittest from decimal import Decimal +from types import SimpleNamespace +from unittest.mock import patch -from tee_gateway.definitions import BASE_MAINNET_OPG_ADDRESS from tee_gateway.model_registry import ( _MODEL_LOOKUP, get_model_config, ) -from tee_gateway.util import calculate_session_cost +from tee_gateway.pricing import compute_session_cost + # All pricing tests assume OPG = $1.00 so USD cost == OPG token amount. _OPG_PRICE_USD = Decimal("1") -_get_price = lambda: _OPG_PRICE_USD # noqa: E731 - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- -def _opg_requirements() -> dict: - """Fake PaymentRequirements dict for OPG (18 decimals).""" - return {"asset": BASE_MAINNET_OPG_ADDRESS, "amount": "50000000000000000"} - - -def _ctx(model: str, input_tokens: int, output_tokens: int, requirements=None) -> dict: - """Build a minimal calculator context.""" - return { - "request_json": {"model": model, "messages": []}, - "response_json": { - "model": model, - "usage": { - "prompt_tokens": input_tokens, - "completion_tokens": output_tokens, - "total_tokens": input_tokens + output_tokens, - }, - }, - "payment_requirements": requirements or _opg_requirements(), +def _calc_opg(model: str, input_tokens: int, output_tokens: int) -> int: + """Call compute_session_cost with the test price feed and return the OPG + integer. Returns -1 when the function returns None so tests can assert on + failure paths without raising.""" + usage = { + "prompt_tokens": input_tokens, + "completion_tokens": output_tokens, + "total_tokens": input_tokens + output_tokens, } + fake_feed = SimpleNamespace(get_price=lambda: _OPG_PRICE_USD) + with patch("tee_gateway.price_feed.get_price_feed", return_value=fake_feed): + result = compute_session_cost(model, usage) + return -1 if result is None else result.cost_opg def _expected_cost_opg(model: str, input_tokens: int, output_tokens: int) -> int: @@ -326,12 +316,10 @@ def test_unknown_sonnet_variant_raises(self): class TestCalculateSessionCostOPG(unittest.TestCase): - """calculate_session_cost with OPG (18 decimals).""" + """compute_session_cost with OPG (18 decimals).""" def _calc(self, model, input_tokens, output_tokens): - return calculate_session_cost( - _ctx(model, input_tokens, output_tokens, _opg_requirements()), _get_price - ) + return _calc_opg(model, input_tokens, output_tokens) # ── OpenAI ────────────────────────────────────────────────────────────── @@ -578,84 +566,28 @@ def test_grok_4_fast_cheaper_than_grok_4(self): class TestCalculateSessionCostEdgeCases(unittest.TestCase): - """Edge cases for calculate_session_cost.""" + """Edge cases for compute_session_cost.""" def test_zero_tokens_returns_zero(self): - cost = calculate_session_cost(_ctx("claude-sonnet-4-5", 0, 0), _get_price) - self.assertEqual(cost, 0) - - def test_missing_usage_raises(self): - ctx = { - "request_json": {"model": "claude-sonnet-4-5"}, - "response_json": {"model": "claude-sonnet-4-5"}, # no usage - "payment_requirements": _opg_requirements(), - } - with self.assertRaises(ValueError): - calculate_session_cost(ctx, _get_price) + self.assertEqual(_calc_opg("claude-sonnet-4-5", 0, 0), 0) - def test_unknown_asset_raises(self): - ctx = _ctx("claude-sonnet-4-5", 100, 100) - ctx["payment_requirements"] = {"asset": "0xdeadbeef", "amount": "1000"} - with self.assertRaises(ValueError): - calculate_session_cost(ctx, _get_price) - - def test_missing_asset_raises(self): - ctx = _ctx("claude-sonnet-4-5", 100, 100) - ctx["payment_requirements"] = {"amount": "1000"} # no asset - with self.assertRaises(ValueError): - calculate_session_cost(ctx, _get_price) - - def test_unknown_model_raises_value_error(self): - ctx = _ctx("gpt-4o", 100, 100) - with self.assertRaises(ValueError): - calculate_session_cost(ctx, _get_price) - - def test_missing_request_json_raises_value_error(self): - ctx = { - "request_json": None, - "response_json": { - "model": "claude-sonnet-4-5", - "usage": {"prompt_tokens": 100, "completion_tokens": 100}, - }, - "payment_requirements": _opg_requirements(), - } - with self.assertRaises(ValueError): - calculate_session_cost(ctx, _get_price) - - def test_model_from_request_takes_priority(self): - """request_json model name is used even if response_json has a different model.""" - ctx = { - "request_json": {"model": "claude-haiku-4-5"}, - "response_json": { - "model": "claude-sonnet-4-5", # response says Sonnet - "usage": {"prompt_tokens": 1000, "completion_tokens": 500}, - }, - "payment_requirements": _opg_requirements(), - } - cost = calculate_session_cost(ctx, _get_price) - # Should be priced as Haiku (from request), not Sonnet - haiku_cost = _expected_cost_opg("claude-haiku-4-5", 1000, 500) - self.assertEqual(cost, haiku_cost) + def test_unknown_model_returns_none(self): + # gpt-4o is not in the registry — get_model_config raises, caught and + # returned as None. + self.assertEqual(_calc_opg("gpt-4o", 100, 100), -1) def test_rounding_ceiling(self): """Fractional token costs are always rounded UP.""" - # 1 output token of Haiku: 0.000005 USD = 5e12 wei — exact, no rounding needed - cost = calculate_session_cost(_ctx("claude-haiku-4-5", 0, 1), _get_price) - self.assertEqual(cost, 5_000_000_000_000) - + # 1 output token of Haiku: 0.000005 USD = 5e12 wei — exact + self.assertEqual(_calc_opg("claude-haiku-4-5", 0, 1), 5_000_000_000_000) # 1 input token of Gemini Flash Lite: 0.0000001 USD = 1e11 wei — exact - cost = calculate_session_cost(_ctx("gemini-2.5-flash-lite", 1, 0), _get_price) - self.assertEqual(cost, 100_000_000_000) + self.assertEqual(_calc_opg("gemini-2.5-flash-lite", 1, 0), 100_000_000_000) def test_model_name_case_insensitive(self): - """Model names are normalized to lowercase before lookup.""" - cost_lower = calculate_session_cost( - _ctx("claude-sonnet-4-5", 100, 100), _get_price - ) - cost_upper = calculate_session_cost( - _ctx("CLAUDE-SONNET-4-5", 100, 100), _get_price + self.assertEqual( + _calc_opg("claude-sonnet-4-5", 100, 100), + _calc_opg("CLAUDE-SONNET-4-5", 100, 100), ) - self.assertEqual(cost_lower, cost_upper) if __name__ == "__main__": diff --git a/uv.lock b/uv.lock index 4d8b6f1..d51f8a1 100644 --- a/uv.lock +++ b/uv.lock @@ -1540,6 +1540,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" }, ] +[[package]] +name = "pyhpke" +version = "0.6.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cryptography" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/27/37/1acb2cee5afd3dcf45b425b0d984a9cba8917fd935106ef278b42062ecfa/pyhpke-0.6.4.tar.gz", hash = "sha256:1402c6c41a0605941d2d2a589774d346c0e7a0dc7f745e84c6f0a06c2fd335c9", size = 1638147, upload-time = "2025-12-21T10:38:07.556Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/f6/ff7df9e21b38ec1c827efd90c28b3bc76eddbfdf5a44aaf2fadb59a17cb9/pyhpke-0.6.4-py3-none-any.whl", hash = "sha256:abd0b2fec1424858399ffbed0d236fb7e9740dece9907f59ca40bd567d7fef78", size = 23792, upload-time = "2025-12-21T10:38:06.172Z" }, +] + [[package]] name = "pyparsing" version = "3.3.2" @@ -1849,6 +1861,7 @@ dependencies = [ { name = "openai" }, { name = "psutil" }, { name = "pydantic" }, + { name = "pyhpke" }, { name = "python-dateutil" }, { name = "python-dotenv" }, { name = "requests" }, @@ -1895,6 +1908,7 @@ requires-dist = [ { name = "openai", specifier = ">=2.15.0" }, { name = "psutil", specifier = ">=7.2.1" }, { name = "pydantic", specifier = ">=2.12.5" }, + { name = "pyhpke", specifier = ">=0.6.0" }, { name = "python-dateutil", specifier = ">=2.6.0" }, { name = "python-dotenv", specifier = ">=1.2.1" }, { name = "requests", specifier = ">=2.32.5" },