Skip to content

Commit 0440fb9

Browse files
committed
feat: support multi provider
1 parent 7093b73 commit 0440fb9

10 files changed

Lines changed: 197 additions & 171 deletions

File tree

cmd/codebot/main.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"os"
77

88
"github.com/voocel/codebot/internal/bootstrap"
9+
"github.com/voocel/codebot/internal/config"
910
"github.com/voocel/codebot/internal/ui"
1011
)
1112

@@ -41,7 +42,7 @@ func main() {
4142
return
4243
}
4344

44-
modelName := rt.Settings.DefaultModel
45+
modelName := config.FormatModelID(rt.Settings.Provider, rt.Settings.Model)
4546
if rt.Session != nil && rt.Session.ModelName() != "" {
4647
modelName = rt.Session.ModelName()
4748
}

internal/agent/compaction.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,11 @@ func (s *Session) compactWithReason(reason string) error {
3232
s.mu.Lock()
3333
prov := s.provider
3434
model := s.modelName
35-
apiKey := s.apiKey
36-
baseURL := s.baseURL
3735
ctxWindow := s.settings.ContextWindow
3836
store := s.store
3937
s.mu.Unlock()
4038

39+
apiKey, baseURL := s.resolveCredentials(prov)
4140
compactModel, err := s.createModel(prov, model, apiKey, baseURL)
4241
if err != nil {
4342
return fmt.Errorf("create compaction model: %w", err)

internal/agent/session.go

Lines changed: 18 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,7 @@ type Session struct {
5151

5252
provider string
5353
modelName string
54-
apiKey string
55-
baseURL string
54+
providers map[string]config.ProviderConfig
5655
cwd string
5756

5857
createModel ModelFactory
@@ -96,10 +95,9 @@ func NewSession(cfg SessionConfig) *Session {
9695
mgr: cfg.Manager,
9796
registry: cfg.Registry,
9897
settings: cfg.Settings,
99-
provider: cfg.Settings.DefaultProvider,
100-
modelName: cfg.Settings.DefaultModel,
101-
apiKey: cfg.Settings.APIKey,
102-
baseURL: cfg.Settings.BaseURL,
98+
provider: cfg.Settings.Provider,
99+
modelName: cfg.Settings.Model,
100+
providers: cfg.Settings.Providers,
103101
cwd: cfg.Cwd,
104102
createModel: modelFactory,
105103
lazyPersist: cfg.LazyPersist,
@@ -176,8 +174,8 @@ func (s *Session) ClearConversation() {
176174
// --------------------------------------------------------------------------
177175

178176
// SetModel switches the LLM model and persists the change.
179-
func (s *Session) SetModel(prov, model, apiKey string) error {
180-
_, baseURL := s.resolveCredentials(prov)
177+
func (s *Session) SetModel(prov, model string) error {
178+
apiKey, baseURL := s.resolveCredentials(prov)
181179
s.mu.Lock()
182180
store := s.store
183181
s.mu.Unlock()
@@ -197,8 +195,6 @@ func (s *Session) SetModel(prov, model, apiKey string) error {
197195
s.mu.Lock()
198196
s.provider = prov
199197
s.modelName = model
200-
s.apiKey = apiKey
201-
s.baseURL = baseURL
202198
s.mu.Unlock()
203199

204200
s.emit(SessionEvent{
@@ -224,9 +220,7 @@ func (s *Session) ResolveAndSetModel(pattern string) (string, error) {
224220
return "", err
225221
}
226222

227-
apiKey, _ := s.resolveCredentials(entry.Provider)
228-
229-
if err := s.SetModel(entry.Provider, entry.ID, apiKey); err != nil {
223+
if err := s.SetModel(entry.Provider, entry.ID); err != nil {
230224
return "", err
231225
}
232226
if thinkingLevel != "" {
@@ -282,25 +276,26 @@ func (s *Session) Provider() string {
282276
return s.provider
283277
}
284278

285-
// APIKey returns the current API key.
279+
// APIKey returns the API key for the current provider.
286280
func (s *Session) APIKey() string {
287-
s.mu.Lock()
288-
defer s.mu.Unlock()
289-
return s.apiKey
281+
apiKey, _ := s.resolveCredentials(s.Provider())
282+
return apiKey
290283
}
291284

292-
// BaseURL returns the current base URL.
285+
// BaseURL returns the base URL for the current provider.
293286
func (s *Session) BaseURL() string {
294-
s.mu.Lock()
295-
defer s.mu.Unlock()
296-
return s.baseURL
287+
_, baseURL := s.resolveCredentials(s.Provider())
288+
return baseURL
297289
}
298290

299-
// resolveCredentials returns the API key and base URL from settings.
291+
// resolveCredentials returns the API key and base URL for a provider.
300292
func (s *Session) resolveCredentials(prov string) (apiKey, baseURL string) {
301293
s.mu.Lock()
302294
defer s.mu.Unlock()
303-
return s.settings.APIKey, s.settings.BaseURL
295+
if pc, ok := s.providers[prov]; ok {
296+
return pc.APIKey, pc.BaseURL
297+
}
298+
return "", ""
304299
}
305300

306301
// --------------------------------------------------------------------------
@@ -406,8 +401,6 @@ func (s *Session) SwitchSession(id string) error {
406401
s.store = newStore
407402
s.provider = targetProvider
408403
s.modelName = targetModel
409-
s.apiKey = targetKey
410-
s.baseURL = targetBase
411404
s.autoNamed = newStore.Header().Name != ""
412405
if snapshot.Thinking != "" {
413406
clamped := snapshot.Thinking

internal/agent/session_test.go

Lines changed: 50 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -77,12 +77,14 @@ func TestSwitchSessionKeepsCurrentStateOnModelRestoreFailure(t *testing.T) {
7777
Store: current,
7878
Manager: mgr,
7979
Settings: config.Resolved{
80-
DefaultProvider: "openai",
81-
DefaultModel: "good-model",
82-
APIKey: "k",
83-
ContextWindow: 128000,
84-
AutoCompaction: false,
85-
MaxTurns: 30,
80+
Provider: "openai",
81+
Model: "good-model",
82+
Providers: map[string]config.ProviderConfig{
83+
"openai": {APIKey: "k"},
84+
},
85+
ContextWindow: 128000,
86+
AutoCompaction: false,
87+
MaxTurns: 30,
8688
},
8789
Cwd: dir,
8890
CreateModel: func(_ string, model string, _ string, _ string) (agentcore.ChatModel, error) {
@@ -140,12 +142,14 @@ func TestSetModelKeepsStateWhenPersistFails(t *testing.T) {
140142
Store: store,
141143
Manager: mgr,
142144
Settings: config.Resolved{
143-
DefaultProvider: "openai",
144-
DefaultModel: "good-model",
145-
APIKey: "k",
146-
ContextWindow: 128000,
147-
AutoCompaction: false,
148-
MaxTurns: 30,
145+
Provider: "openai",
146+
Model: "good-model",
147+
Providers: map[string]config.ProviderConfig{
148+
"openai": {APIKey: "k"},
149+
},
150+
ContextWindow: 128000,
151+
AutoCompaction: false,
152+
MaxTurns: 30,
149153
},
150154
Cwd: dir,
151155
CreateModel: func(_ string, _ string, _ string, _ string) (agentcore.ChatModel, error) {
@@ -165,7 +169,7 @@ func TestSetModelKeepsStateWhenPersistFails(t *testing.T) {
165169
oldProvider := s.Provider()
166170
oldModel := s.ModelName()
167171

168-
err = s.SetModel("openai", "new-model", "k")
172+
err = s.SetModel("openai", "new-model")
169173
if err == nil {
170174
t.Fatalf("expected set model failure")
171175
}
@@ -280,49 +284,45 @@ func TestRestoreAllToolsRebuildsPrompt(t *testing.T) {
280284
}
281285
}
282286

283-
func TestResolveCredentialsDefaultProvider(t *testing.T) {
287+
func TestResolveCredentialsPerProvider(t *testing.T) {
284288
t.Parallel()
285289
ag := agentcore.NewAgent(agentcore.WithModel(&stubChatModel{}))
286290
s := NewSession(SessionConfig{
287291
Agent: ag,
288292
Settings: config.Resolved{
289-
DefaultProvider: "openai",
290-
APIKey: "openai-key",
291-
BaseURL: "https://openai.example.com",
293+
Provider: "openai",
294+
Model: "gpt-5",
295+
Providers: map[string]config.ProviderConfig{
296+
"openai": {APIKey: "openai-key", BaseURL: "https://openai.example.com"},
297+
"anthropic": {APIKey: "ant-key"},
298+
},
292299
},
293300
Cwd: t.TempDir(),
294301
})
295302
t.Cleanup(s.Close)
296303

304+
// Default provider
297305
apiKey, baseURL := s.resolveCredentials("openai")
298306
if apiKey != "openai-key" {
299307
t.Fatalf("expected openai-key, got %s", apiKey)
300308
}
301309
if baseURL != "https://openai.example.com" {
302310
t.Fatalf("expected https://openai.example.com, got %s", baseURL)
303311
}
304-
}
305312

306-
func TestResolveCredentialsCrossProviderUsesMainCredentials(t *testing.T) {
307-
t.Parallel()
308-
ag := agentcore.NewAgent(agentcore.WithModel(&stubChatModel{}))
309-
s := NewSession(SessionConfig{
310-
Agent: ag,
311-
Settings: config.Resolved{
312-
DefaultProvider: "openai",
313-
APIKey: "openai-key",
314-
BaseURL: "https://openai.example.com",
315-
},
316-
Cwd: t.TempDir(),
317-
})
318-
t.Cleanup(s.Close)
319-
320-
apiKey, baseURL := s.resolveCredentials("anthropic")
321-
if apiKey != "openai-key" {
322-
t.Fatalf("expected openai-key, got %s", apiKey)
313+
// Cross-provider resolves its own credentials
314+
apiKey, baseURL = s.resolveCredentials("anthropic")
315+
if apiKey != "ant-key" {
316+
t.Fatalf("expected ant-key, got %s", apiKey)
323317
}
324-
if baseURL != "https://openai.example.com" {
325-
t.Fatalf("expected https://openai.example.com, got %s", baseURL)
318+
if baseURL != "" {
319+
t.Fatalf("expected empty baseURL for anthropic, got %s", baseURL)
320+
}
321+
322+
// Unknown provider returns empty
323+
apiKey, baseURL = s.resolveCredentials("unknown")
324+
if apiKey != "" || baseURL != "" {
325+
t.Fatalf("expected empty for unknown provider, got %s/%s", apiKey, baseURL)
326326
}
327327
}
328328

@@ -355,12 +355,14 @@ func TestSwitchSessionCrossProviderCredentials(t *testing.T) {
355355
Store: current,
356356
Manager: mgr,
357357
Settings: config.Resolved{
358-
DefaultProvider: "openai",
359-
DefaultModel: "gpt-5",
360-
APIKey: "openai-key",
361-
BaseURL: "https://openai.example.com",
362-
ContextWindow: 128000,
363-
MaxTurns: 30,
358+
Provider: "openai",
359+
Model: "gpt-5",
360+
Providers: map[string]config.ProviderConfig{
361+
"openai": {APIKey: "openai-key", BaseURL: "https://openai.example.com"},
362+
"anthropic": {APIKey: "ant-key"},
363+
},
364+
ContextWindow: 128000,
365+
MaxTurns: 30,
364366
},
365367
Cwd: dir,
366368
CreateModel: func(_, _ string, apiKey, baseURL string) (agentcore.ChatModel, error) {
@@ -375,18 +377,12 @@ func TestSwitchSessionCrossProviderCredentials(t *testing.T) {
375377
t.Fatalf("switch session: %v", err)
376378
}
377379

378-
// Cross-provider uses the same credentials from settings.
379-
if capturedKey != "openai-key" {
380-
t.Fatalf("expected CreateModel to receive openai-key, got %s", capturedKey)
381-
}
382-
if capturedBase != "https://openai.example.com" {
383-
t.Fatalf("expected https://openai.example.com, got %s", capturedBase)
384-
}
385-
if s.APIKey() != "openai-key" {
386-
t.Fatalf("expected session apiKey=openai-key, got %s", s.APIKey())
380+
// Cross-provider uses anthropic's own credentials.
381+
if capturedKey != "ant-key" {
382+
t.Fatalf("expected CreateModel to receive ant-key, got %s", capturedKey)
387383
}
388-
if s.BaseURL() != "https://openai.example.com" {
389-
t.Fatalf("expected session baseURL=https://openai.example.com, got %s", s.BaseURL())
384+
if capturedBase != "" {
385+
t.Fatalf("expected empty baseURL for anthropic, got %s", capturedBase)
390386
}
391387
}
392388

internal/bootstrap/boot.go

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,11 @@ func Boot(opts Options) (*Runtime, error) {
8787
createModel = provider.CreateModel
8888
}
8989

90-
// Interactive setup when API key is missing.
91-
if settings.APIKey == "" {
90+
// Interactive setup when API key is missing for the active provider.
91+
apiKey, _ := settings.ProviderCredentials(settings.Provider)
92+
if apiKey == "" {
9293
if opts.NonTTYMode {
93-
return nil, fmt.Errorf("api key not set, configure api_key in %s",
94+
return nil, fmt.Errorf("api key not set, configure providers in %s",
9495
config.SettingsPath(cwd))
9596
}
9697
err := config.RunSetup(cwd, settings, func(prov string) []config.ModelOption {
@@ -127,16 +128,15 @@ func Boot(opts Options) (*Runtime, error) {
127128
}
128129
}
129130

130-
activeProvider := settings.DefaultProvider
131+
activeProvider := settings.Provider
131132
if snapshot.Provider != "" {
132133
activeProvider = snapshot.Provider
133134
}
134-
activeModel := settings.DefaultModel
135+
activeModel := settings.Model
135136
if snapshot.Model != "" {
136137
activeModel = snapshot.Model
137138
}
138-
activeAPIKey := settings.APIKey
139-
activeBaseURL := settings.BaseURL
139+
activeAPIKey, activeBaseURL := settings.ProviderCredentials(activeProvider)
140140

141141
chatModel, err := createModel(activeProvider, activeModel, activeAPIKey, activeBaseURL)
142142
if err != nil {
@@ -170,8 +170,7 @@ func Boot(opts Options) (*Runtime, error) {
170170
AllTools: builtTools,
171171
CreateModel: createModel,
172172
Provider: activeProvider,
173-
APIKey: activeAPIKey,
174-
BaseURL: activeBaseURL,
173+
Providers: settings.Providers,
175174
SmallModel: settings.SmallModel,
176175
})
177176
builtTools = append(builtTools, subagentTool)
@@ -243,10 +242,8 @@ func Boot(opts Options) (*Runtime, error) {
243242
ag.SetThinkingLevel(agentcore.ThinkingLevel(settings.ThinkingLevel))
244243
}
245244

246-
settings.DefaultProvider = activeProvider
247-
settings.DefaultModel = activeModel
248-
settings.APIKey = activeAPIKey
249-
settings.BaseURL = activeBaseURL
245+
settings.Provider = activeProvider
246+
settings.Model = activeModel
250247

251248
sess := agent.NewSession(agent.SessionConfig{
252249
Agent: ag,

0 commit comments

Comments
 (0)