Skip to content

Commit a265d11

Browse files
Copilotlpcox
andauthored
perf: pre-compute allowedToolSets at init time for O(1) lookup in isToolAllowed and registerToolsFromBackend
Agent-Logs-Url: https://github.com/github/gh-aw-mcpg/sessions/00582754-0d76-4bb6-a950-0cd6ff534f1b Co-authored-by: lpcox <15877973+lpcox@users.noreply.github.com>
1 parent dc17bd8 commit a265d11

3 files changed

Lines changed: 50 additions & 42 deletions

File tree

internal/server/call_backend_tool_test.go

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -401,13 +401,14 @@ func TestIsToolAllowed(t *testing.T) {
401401

402402
for _, tc := range tests {
403403
t.Run(tc.name, func(t *testing.T) {
404-
us := &UnifiedServer{
405-
cfg: &config.Config{
406-
Servers: map[string]*config.ServerConfig{
407-
"s": {Tools: tc.tools},
408-
},
404+
cfg := &config.Config{
405+
Servers: map[string]*config.ServerConfig{
406+
"s": {Tools: tc.tools},
409407
},
410408
}
409+
us := &UnifiedServer{
410+
allowedToolSets: buildAllowedToolSets(cfg),
411+
}
411412
got := us.isToolAllowed("s", tc.toolName)
412413
assert.Equal(t, tc.wantAllow, got)
413414
})
@@ -416,17 +417,14 @@ func TestIsToolAllowed(t *testing.T) {
416417

417418
// TestIsToolAllowed_NilConfig verifies that a nil config allows all tools.
418419
func TestIsToolAllowed_NilConfig(t *testing.T) {
419-
us := &UnifiedServer{cfg: nil}
420+
us := &UnifiedServer{allowedToolSets: buildAllowedToolSets(nil)}
420421
assert.True(t, us.isToolAllowed("s", "anything"), "nil cfg should allow all tools")
421422
}
422423

423424
// TestIsToolAllowed_UnknownServer verifies that an unknown server ID allows all tools.
424425
func TestIsToolAllowed_UnknownServer(t *testing.T) {
425-
us := &UnifiedServer{
426-
cfg: &config.Config{
427-
Servers: map[string]*config.ServerConfig{},
428-
},
429-
}
426+
cfg := &config.Config{Servers: map[string]*config.ServerConfig{}}
427+
us := &UnifiedServer{allowedToolSets: buildAllowedToolSets(cfg)}
430428
assert.True(t, us.isToolAllowed("unknown", "tool"), "unknown server should allow all tools")
431429
}
432430

internal/server/tool_registry.go

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -185,25 +185,19 @@ func (us *UnifiedServer) registerToolsFromBackend(serverID string) error {
185185
// Filter tools by the server's allowed-tools list (if configured).
186186
// This prevents non-allowed tools from appearing in tools/list responses
187187
// and is defense-in-depth alongside the callBackendTool enforcement.
188-
if us.cfg != nil {
189-
if serverCfg, ok := us.cfg.Servers[serverID]; ok && len(serverCfg.Tools) > 0 {
190-
allowedSet := make(map[string]bool, len(serverCfg.Tools))
191-
for _, t := range serverCfg.Tools {
192-
allowedSet[t] = true
188+
if allowedSet, ok := us.allowedToolSets[serverID]; ok && len(allowedSet) > 0 {
189+
n := 0
190+
for _, tool := range listResult.Tools {
191+
if allowedSet[tool.Name] {
192+
listResult.Tools[n] = tool
193+
n++
193194
}
194-
n := 0
195-
for _, tool := range listResult.Tools {
196-
if allowedSet[tool.Name] {
197-
listResult.Tools[n] = tool
198-
n++
199-
}
200-
}
201-
if n < len(listResult.Tools) {
202-
logger.LogInfo("backend", "[allowed-tools] Filtered %d tools from %s: keeping %d of %d",
203-
len(listResult.Tools)-n, serverID, n, len(listResult.Tools))
204-
}
205-
listResult.Tools = listResult.Tools[:n]
206195
}
196+
if n < len(listResult.Tools) {
197+
logger.LogInfo("backend", "[allowed-tools] Filtered %d tools from %s: keeping %d of %d",
198+
len(listResult.Tools)-n, serverID, n, len(listResult.Tools))
199+
}
200+
listResult.Tools = listResult.Tools[:n]
207201
}
208202

209203
// Collect tools for logging

internal/server/unified.go

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,11 @@ type UnifiedServer struct {
9090
payloadPathPrefix string // Path prefix to use when returning payloadPath to clients (allows remapping host paths to client/agent container paths)
9191
payloadSizeThreshold int // Size threshold (in bytes) for storing payloads to disk. Payloads larger than this are stored to disk, smaller ones are returned inline.
9292

93+
// allowedToolSets holds a pre-computed set of allowed tool names per server ID.
94+
// Built once during NewUnified from the config Tools lists. A missing or nil entry
95+
// means all tools are permitted for that server.
96+
allowedToolSets map[string]map[string]bool
97+
9398
// DIFC components
9499
guardRegistry *guard.Registry
95100
agentRegistry *difc.AgentRegistry
@@ -144,6 +149,7 @@ func NewUnified(ctx context.Context, cfg *config.Config) (*UnifiedServer, error)
144149
payloadDir: payloadDir,
145150
payloadPathPrefix: payloadPathPrefix,
146151
payloadSizeThreshold: payloadSizeThreshold,
152+
allowedToolSets: buildAllowedToolSets(cfg),
147153

148154
// Initialize DIFC components
149155
guardRegistry: guard.NewRegistry(),
@@ -371,25 +377,35 @@ func newErrorCallToolResult(err error) (*sdk.CallToolResult, interface{}, error)
371377
}, nil, err
372378
}
373379

380+
// buildAllowedToolSets converts the per-server Tools lists from the config into pre-computed
381+
// map[string]bool sets for O(1) lookup. Servers with no Tools list are not added to the map,
382+
// which signals that all tools are permitted.
383+
func buildAllowedToolSets(cfg *config.Config) map[string]map[string]bool {
384+
sets := make(map[string]map[string]bool)
385+
if cfg == nil {
386+
return sets
387+
}
388+
for serverID, serverCfg := range cfg.Servers {
389+
if len(serverCfg.Tools) > 0 {
390+
set := make(map[string]bool, len(serverCfg.Tools))
391+
for _, t := range serverCfg.Tools {
392+
set[t] = true
393+
}
394+
sets[serverID] = set
395+
}
396+
}
397+
return sets
398+
}
399+
374400
// isToolAllowed reports whether toolName is permitted by the server's configured
375401
// allowed-tools list. When no list is configured (empty), all tools are allowed.
402+
// Uses the pre-computed allowedToolSets map for O(1) lookup.
376403
func (us *UnifiedServer) isToolAllowed(serverID, toolName string) bool {
377-
if us.cfg == nil {
404+
set, ok := us.allowedToolSets[serverID]
405+
if !ok || set == nil {
378406
return true
379407
}
380-
serverCfg, ok := us.cfg.Servers[serverID]
381-
if !ok {
382-
return true
383-
}
384-
if len(serverCfg.Tools) == 0 {
385-
return true
386-
}
387-
for _, allowed := range serverCfg.Tools {
388-
if allowed == toolName {
389-
return true
390-
}
391-
}
392-
return false
408+
return set[toolName]
393409
}
394410

395411
// callBackendTool calls a tool on a backend server with DIFC enforcement

0 commit comments

Comments
 (0)