1+ (ns eca.llm-providers.aws-bedrock
2+ " AWS Bedrock provider implementation using Converse/ConverseStream APIs.
3+
4+ AUTHENTICATION:
5+ This implementation uses Bearer token authentication, which requires
6+ an external proxy/gateway that handles AWS SigV4 signing.
7+
8+ Set BEDROCK_API_KEY environment variable or configure :key in config.clj
9+ with a token provided by your authentication proxy.
10+
11+ ENDPOINTS:
12+ - Standard: https://your-proxy.com/model/{modelId}/converse
13+ - Streaming: https://your-proxy.com/model/{modelId}/converse-stream
14+
15+ Configure the :url in your provider config to point to your proxy endpoint."
16+ (:require
17+ [cheshire.core :as json]
18+ [clojure.string :as str]
19+ [eca.logger :as logger]
20+ [hato.client :as http])
21+ (:import (java.io DataInputStream BufferedInputStream ByteArrayInputStream)))
22+
23+ ; ; --- Helper Functions ---
24+
25+ (defn resolve-model-id
26+ " Resolve model ID from configuration."
27+ [model-alias config]
28+ (let [keyword-alias (keyword model-alias)
29+ model-config (get-in config [:models keyword-alias])]
30+ (or (:modelName model-config)
31+ (name model-alias))))
32+
33+ (defn format-tool-spec [tool]
34+ (let [f (:function tool)]
35+ {:toolSpec {:name (:name f)
36+ :description (:description f)
37+ ; ; AWS requires inputSchema wrapped in "json" key
38+ :inputSchema {:json (:parameters f)}}}))
39+
40+ (defn format-tool-config [tools]
41+ (let [tools-seq (if (sequential? tools) tools [tools])]
42+ (when (seq tools-seq)
43+ {:tools (mapv format-tool-spec tools-seq)})))
44+
45+ (defn parse-tool-result [content tool-call-id is-error?]
46+ (let [inner-content (try
47+ (if is-error?
48+ [{:text (str content)}]
49+ [{:json (json/parse-string content true )}])
50+ (catch Exception _
51+ [{:text (str content)}]))]
52+ {:toolUseId tool-call-id
53+ :content inner-content
54+ :status (if is-error? " error" " success" )}))
55+
56+ (defn message->bedrock [msg]
57+ (case (:role msg)
58+ " tool"
59+ {:role " user"
60+ :content [(parse-tool-result (:content msg)
61+ (:tool_call_id msg)
62+ (:error msg))]}
63+
64+ " assistant"
65+ {:role " assistant"
66+ :content (if (:tool_calls msg)
67+ (mapv (fn [tc]
68+ {:toolUse {:toolUseId (:id tc)
69+ :name (get-in tc [:function :name ])
70+ :input (json/parse-string
71+ (get-in tc [:function :arguments ]) keyword)}})
72+ (:tool_calls msg))
73+ [{:text (:content msg)}])}
74+
75+ ; ; Default/User
76+ {:role " user"
77+ :content [{:text (:content msg)}]}))
78+
79+ (defn build-payload [messages options]
80+ (let [system-prompts (filter #(= (:role %) " system" ) messages)
81+ conversation (->> messages
82+ (remove #(= (:role %) " system" ))
83+ (mapv message->bedrock))
84+ system-blocks (mapv (fn [m] {:text (:content m)}) system-prompts)
85+
86+ ; ; Base inference config
87+ base-config {:maxTokens (or (:max_tokens options) (:maxTokens options) 1024 )
88+ :temperature (or (:temperature options) 0.7 )
89+ :topP (or (:top_p options) (:topP options) 1.0 )}
90+
91+ ; ; Additional model-specific fields (e.g., top_k for Claude)
92+ additional-fields (select-keys options [:top_k :topK ])]
93+
94+ (cond-> {:messages conversation
95+ :inferenceConfig (merge base-config
96+ (select-keys options [:stopSequences ]))}
97+ (seq system-blocks)
98+ (assoc :system system-blocks)
99+
100+ (:tools options)
101+ (assoc :toolConfig (format-tool-config (:tools options)))
102+
103+ ; ; Add additionalModelRequestFields if present
104+ (seq additional-fields)
105+ (assoc :additionalModelRequestFields
106+ (into {} (map (fn [[k v]] [(name k) v]) additional-fields))))))
107+
108+ (defn parse-bedrock-response [body]
109+ (let [response (json/parse-string body true )
110+ output-msg (get-in response [:output :message ])
111+ stop-reason (:stopReason response)
112+ content (:content output-msg)
113+ usage (:usage response)]
114+
115+ ; ; Log token usage if present
116+ (when usage
117+ (logger/debug " Token usage" {:input (:inputTokens usage)
118+ :output (:outputTokens usage)
119+ :total (:totalTokens usage)}))
120+
121+ (if (= stop-reason " tool_use" )
122+ (let [tool-blocks (filter :toolUse content)
123+ tool-calls (mapv (fn [b]
124+ (let [t (:toolUse b)]
125+ {:id (:toolUseId t)
126+ :type " function"
127+ :function {:name (:name t)
128+ :arguments (json/generate-string (:input t))}}))
129+ tool-blocks)]
130+ {:role " assistant" :content nil :tool_calls tool-calls})
131+
132+ (let [text (-> (filter :text content) first :text )]
133+ {:role " assistant" :content text}))))
134+
135+ ; ; --- Binary Stream Parser ---
136+
137+ (defn parse-event-stream
138+ " Parses AWS Event Stream (Binary format) from a raw InputStream.
139+
140+ AWS Event Stream Protocol:
141+ - Prelude: Total Length (4) + Headers Length (4)
142+ - Headers: Variable length
143+ - Headers CRC: 4 bytes
144+ - Payload: Variable length
145+ - Message CRC: 4 bytes"
146+ [^java.io.InputStream input-stream]
147+ (let [dis (DataInputStream. (BufferedInputStream. input-stream))]
148+ (lazy-seq
149+ (try
150+ ; ; 1. Read Prelude (8 bytes, Big Endian)
151+ (let [total-len (.readInt dis)
152+ headers-len (.readInt dis)]
153+
154+ ; ; 2. Read and skip headers
155+ (when (> headers-len 0 )
156+ (let [header-bytes (byte-array headers-len)]
157+ (.readFully dis header-bytes)))
158+
159+ ; ; 3. Skip headers CRC (4 bytes)
160+ (.skipBytes dis 4 )
161+
162+ ; ; 4. Calculate and read payload
163+ ; ; total-len = prelude(8) + headers + headers-crc(4) + payload + message-crc(4)
164+ (let [payload-len (- total-len 8 headers-len 4 4 )
165+ payload-bytes (byte-array payload-len)]
166+
167+ (when (> payload-len 0 )
168+ (.readFully dis payload-bytes))
169+
170+ ; ; 5. Skip message CRC (4 bytes)
171+ (.skipBytes dis 4 )
172+
173+ ; ; 6. Parse JSON payload if present
174+ (if (> payload-len 0 )
175+ (let [payload-str (String. payload-bytes " UTF-8" )
176+ event (json/parse-string payload-str true )]
177+ (cons event (parse-event-stream dis)))
178+ ; ; Empty payload (heartbeat), continue to next event
179+ (parse-event-stream dis))))
180+
181+ (catch java.io.EOFException _ nil )
182+ (catch Exception e
183+ (logger/debug " Stream parsing error" e)
184+ nil )))))
185+
186+ (defn extract-text-deltas
187+ " Takes the sequence of parsed JSON events and extracts text content.
188+ Handles empty events (heartbeats) gracefully."
189+ [events]
190+ (vec (keep (fn [event]
191+ (when-let [delta (get-in event [:contentBlockDelta :delta ])]
192+ (:text delta)))
193+ events)))
194+
195+ ; ; --- Endpoint Construction ---
196+
197+ (defn- build-endpoint
198+ " Constructs the API endpoint URL with model ID interpolation."
199+ [config model-id stream?]
200+ (let [raw-url (:url config)
201+ region (or (:region config) " us-east-1" )
202+ suffix (if stream? " converse-stream" " converse" )]
203+ (if raw-url
204+ ; ; Interpolate {modelId} in custom proxy URLs
205+ (str/replace raw-url " {modelId}" model-id)
206+ ; ; Construct standard AWS URL
207+ (format " https://bedrock-runtime.%s.amazonaws.com/model/%s/%s"
208+ region model-id suffix))))
209+
210+ ; ; --- Public API Functions ---
211+
212+ (defn chat! [config callbacks]
213+ (let [token (or (:key config) (System/getenv " BEDROCK_API_KEY" ))
214+ model-id (resolve-model-id (:model config) config)
215+ endpoint (build-endpoint config model-id false )
216+ timeout (or (:timeout config) 30000 )
217+ headers {" Authorization" (str " Bearer " token)
218+ " Content-Type" " application/json" }
219+ payload (build-payload (:user-messages config) (:extra-payload config))
220+
221+ {:keys [status body error]} (http/post endpoint
222+ {:headers headers
223+ :body (json/generate-string payload)
224+ :timeout timeout})]
225+ (if (and (not error) (= 200 status))
226+ (let [response (parse-bedrock-response body)
227+ {:keys [on-message-received on-error on-prepare-tool-call on-tools-called on-usage-updated]} callbacks]
228+ (if-let [tool-calls (:tool_calls response)]
229+ (do
230+ (on-prepare-tool-call tool-calls)
231+ {:tools-to-call tool-calls})
232+ (do
233+ (on-message-received {:type :text :text (:content response)})
234+ {:output-text (:content response)})))
235+ (do
236+ (logger/error " Bedrock API error" {:status status :error error :body body})
237+ (throw (ex-info " Bedrock API error" {:status status :body body}))))))
238+
239+ (defn stream-chat! [config callbacks]
240+ (let [token (or (:key config) (System/getenv " BEDROCK_API_KEY" ))
241+ model-id (resolve-model-id (:model config) config)
242+ endpoint (build-endpoint config model-id true )
243+ timeout (or (:timeout config) 30000 )
244+ headers {" Authorization" (str " Bearer " token)
245+ " Content-Type" " application/json" }
246+ payload (build-payload (:user-messages config) (:extra-payload config))
247+
248+ {:keys [status body error]} (http/post endpoint
249+ {:headers headers
250+ :body (json/generate-string payload)
251+ :timeout timeout})]
252+ (if (and (not error) (= 200 status))
253+ (let [{:keys [on-message-received on-error]} callbacks
254+ events (parse-event-stream body)
255+ texts (extract-text-deltas events)]
256+ (doseq [text texts]
257+ (on-message-received {:type :text :text text}))
258+ {:output-text (str/join " " texts)})
259+ (do
260+ (logger/error " Bedrock Stream API error" {:status status :error error})
261+ (throw (ex-info " Bedrock Stream API error" {:status status}))))))
0 commit comments