Skip to content

Commit 7c355a9

Browse files
Enforce Cedar policies on optimizer find_tool and call_tool
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 9f7d1a8 commit 7c355a9

19 files changed

Lines changed: 1121 additions & 132 deletions

cmd/vmcp/app/commands.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,8 @@ func runServe(cmd *cobra.Command, _ []string) error {
456456

457457
slog.Info(fmt.Sprintf("Setting up incoming authentication (type: %s)", cfg.IncomingAuth.Type))
458458

459-
authMiddleware, authzMiddleware, authInfoHandler, err := factory.NewIncomingAuthMiddleware(ctx, cfg.IncomingAuth)
459+
authMiddleware, cedarAuthorizer, authInfoHandler, err :=
460+
factory.NewIncomingAuthMiddleware(ctx, cfg.IncomingAuth)
460461
if err != nil {
461462
return fmt.Errorf("failed to create authentication middleware: %w", err)
462463
}
@@ -539,7 +540,7 @@ func runServe(cmd *cobra.Command, _ []string) error {
539540
Host: host,
540541
Port: port,
541542
AuthMiddleware: authMiddleware,
542-
AuthzMiddleware: authzMiddleware,
543+
CedarAuthorizer: cedarAuthorizer,
543544
AuthInfoHandler: authInfoHandler,
544545
TelemetryProvider: telemetryProvider,
545546
AuditConfig: cfg.Audit,

pkg/authz/config.go

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,22 +26,40 @@ var LoadConfig = authorizers.LoadConfig
2626
// NewConfig is an alias for authorizers.NewConfig for backward compatibility.
2727
var NewConfig = authorizers.NewConfig
2828

29-
// CreateMiddlewareFromConfig creates an HTTP middleware from the configuration.
30-
func CreateMiddlewareFromConfig(c *Config, serverName string) (types.MiddlewareFunction, error) {
31-
// Get the factory for this config type
29+
// CreateAuthorizerFromConfig creates an Authorizer from the configuration.
30+
func CreateAuthorizerFromConfig(c *Config, serverName string) (authorizers.Authorizer, error) {
3231
factory := authorizers.GetFactory(string(c.Type))
3332
if factory == nil {
3433
return nil, fmt.Errorf("unsupported configuration type: %s", c.Type)
3534
}
3635

37-
// Create the authorizer using the factory, passing the full raw config
3836
authz, err := factory.CreateAuthorizer(c.RawConfig(), serverName)
3937
if err != nil {
4038
return nil, fmt.Errorf("failed to create %s authorizer: %w", c.Type, err)
4139
}
4240

43-
// Return the middleware
44-
return func(handler http.Handler) http.Handler { return Middleware(authz, handler) }, nil
41+
return authz, nil
42+
}
43+
44+
// CreateMiddlewareFromAuthorizer wraps an existing Authorizer as HTTP middleware.
45+
// The passThroughTools parameter is optional; tool names in this set bypass the
46+
// response filter's policy check in tools/list responses. This is used when the
47+
// optimizer is enabled: its meta-tools (find_tool, call_tool) would otherwise be
48+
// rejected by Cedar default-deny since no policy references them by name.
49+
// Authorization for the underlying backend tools is handled inside the optimizer
50+
// decorator, so letting the meta-tools pass through is safe.
51+
func CreateMiddlewareFromAuthorizer(a authorizers.Authorizer, passThroughTools map[string]struct{}) types.MiddlewareFunction {
52+
return func(handler http.Handler) http.Handler { return Middleware(a, handler, passThroughTools) }
53+
}
54+
55+
// CreateMiddlewareFromConfig creates an HTTP middleware from the configuration.
56+
func CreateMiddlewareFromConfig(c *Config, serverName string) (types.MiddlewareFunction, error) {
57+
authz, err := CreateAuthorizerFromConfig(c, serverName)
58+
if err != nil {
59+
return nil, err
60+
}
61+
62+
return CreateMiddlewareFromAuthorizer(authz, nil), nil
4563
}
4664

4765
// GetMiddlewareFromFile loads the authorization configuration from a file and creates an HTTP middleware.

pkg/authz/config_test.go

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,3 +438,93 @@ func TestCreateMiddlewareFromConfigErrors(t *testing.T) {
438438
assert.Contains(t, err.Error(), "unsupported configuration type")
439439
})
440440
}
441+
442+
func TestCreateAuthorizerFromConfig(t *testing.T) {
443+
t.Parallel()
444+
445+
t.Run("valid config returns authorizer", func(t *testing.T) {
446+
t.Parallel()
447+
448+
config := mustNewConfig(t, cedar.Config{
449+
Version: "1.0",
450+
Type: cedar.ConfigType,
451+
Options: &cedar.ConfigOptions{
452+
Policies: []string{`permit(principal, action == Action::"call_tool", resource == Tool::"weather");`},
453+
EntitiesJSON: "[]",
454+
},
455+
})
456+
457+
a, err := CreateAuthorizerFromConfig(config, "testserver")
458+
require.NoError(t, err)
459+
require.NotNil(t, a)
460+
461+
// Verify the authorizer works by checking a tool call
462+
ctx := auth.WithIdentity(t.Context(), &auth.Identity{
463+
PrincipalInfo: auth.PrincipalInfo{Subject: "user1", Claims: map[string]interface{}{"sub": "user1"}},
464+
})
465+
authorized, err := a.AuthorizeWithJWTClaims(ctx, "tool", "call", "weather", nil)
466+
require.NoError(t, err)
467+
assert.True(t, authorized)
468+
})
469+
470+
t.Run("unsupported config type returns error", func(t *testing.T) {
471+
t.Parallel()
472+
473+
config := &Config{
474+
Version: "1.0",
475+
Type: "unsupported-type",
476+
}
477+
478+
a, err := CreateAuthorizerFromConfig(config, "testserver")
479+
assert.Error(t, err)
480+
assert.Nil(t, a)
481+
assert.Contains(t, err.Error(), "unsupported configuration type")
482+
})
483+
}
484+
485+
func TestCreateMiddlewareFromAuthorizer(t *testing.T) {
486+
t.Parallel()
487+
488+
config := mustNewConfig(t, cedar.Config{
489+
Version: "1.0",
490+
Type: cedar.ConfigType,
491+
Options: &cedar.ConfigOptions{
492+
Policies: []string{`permit(principal, action == Action::"call_tool", resource == Tool::"weather");`},
493+
EntitiesJSON: "[]",
494+
},
495+
})
496+
497+
a, err := CreateAuthorizerFromConfig(config, "testserver")
498+
require.NoError(t, err)
499+
500+
mw := CreateMiddlewareFromAuthorizer(a, nil)
501+
require.NotNil(t, mw)
502+
503+
// Verify middleware wraps a handler correctly
504+
handlerCalled := false
505+
testHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
506+
handlerCalled = true
507+
w.WriteHeader(http.StatusOK)
508+
})
509+
510+
handler := mcpparser.ParsingMiddleware(mw(testHandler))
511+
512+
// Build a ping request (always allowed)
513+
body, err := json.Marshal(map[string]interface{}{
514+
"jsonrpc": "2.0", "id": 1, "method": "ping", "params": map[string]interface{}{},
515+
})
516+
require.NoError(t, err)
517+
518+
req, err := http.NewRequest(http.MethodPost, "/messages", bytes.NewBuffer(body))
519+
require.NoError(t, err)
520+
req.Header.Set("Content-Type", "application/json")
521+
522+
identity := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "user1", Claims: map[string]interface{}{"sub": "user1"}}}
523+
req = req.WithContext(auth.WithIdentity(req.Context(), identity))
524+
525+
rr := httptest.NewRecorder()
526+
handler.ServeHTTP(rr, req)
527+
528+
assert.True(t, handlerCalled)
529+
assert.Equal(t, http.StatusOK, rr.Code)
530+
}

pkg/authz/integration_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ func TestIntegrationListFiltering(t *testing.T) {
257257
})
258258

259259
// Apply the middleware chain: MCP parsing first, then authorization
260-
middleware := mcpparser.ParsingMiddleware(Middleware(authorizer, mockHandler))
260+
middleware := mcpparser.ParsingMiddleware(Middleware(authorizer, mockHandler, nil))
261261

262262
// Execute the request through the middleware
263263
middleware.ServeHTTP(rr, req)
@@ -426,7 +426,7 @@ func TestIntegrationNonListOperations(t *testing.T) {
426426
})
427427

428428
// Apply the middleware chain: MCP parsing first, then authorization
429-
middleware := mcpparser.ParsingMiddleware(Middleware(authorizer, mockHandler))
429+
middleware := mcpparser.ParsingMiddleware(Middleware(authorizer, mockHandler, nil))
430430

431431
// Execute the request through the middleware
432432
middleware.ServeHTTP(rr, req)

pkg/authz/middleware.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ func handleUnauthorized(w http.ResponseWriter, msgID interface{}, err error) {
171171
// The authorizer parameter should implement the authorizers.Authorizer interface,
172172
// which can be created using authz.CreateMiddlewareFromConfig() or directly
173173
// from an authorizer package (e.g., cedar.NewCedarAuthorizer()).
174-
func Middleware(a authorizers.Authorizer, next http.Handler) http.Handler {
174+
func Middleware(a authorizers.Authorizer, next http.Handler, passThroughTools map[string]struct{}) http.Handler {
175175
// Cache is shared across requests for the same proxy.
176176
// Populated from tools/list responses, read during tools/call.
177177
annotationCache := NewAnnotationCache()
@@ -218,7 +218,7 @@ func Middleware(a authorizers.Authorizer, next http.Handler) http.Handler {
218218
if featureOp.Operation == authorizers.MCPOperationList {
219219

220220
// Create a response filtering writer to intercept and filter the response
221-
filteringWriter := NewResponseFilteringWriter(w, a, r, parsedRequest.Method, annotationCache)
221+
filteringWriter := NewResponseFilteringWriter(w, a, r, parsedRequest.Method, annotationCache, passThroughTools)
222222

223223
// Call the next handler with the filtering writer
224224
next.ServeHTTP(filteringWriter, r)

pkg/authz/middleware_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,7 @@ func TestMiddleware(t *testing.T) {
398398
})
399399

400400
// Apply the middleware chain: MCP parsing first, then authorization
401-
middleware := mcpparser.ParsingMiddleware(Middleware(authorizer, handler))
401+
middleware := mcpparser.ParsingMiddleware(Middleware(authorizer, handler, nil))
402402

403403
// Serve the request
404404
middleware.ServeHTTP(rr, req)
@@ -430,7 +430,7 @@ func TestMiddlewareWithGETRequest(t *testing.T) {
430430
})
431431

432432
// Apply the middleware chain: MCP parsing first, then authorization
433-
middleware := mcpparser.ParsingMiddleware(Middleware(authorizer, handler))
433+
middleware := mcpparser.ParsingMiddleware(Middleware(authorizer, handler, nil))
434434

435435
// Create a GET request
436436
req, err := http.NewRequest(http.MethodGet, "/messages", nil)
@@ -807,7 +807,7 @@ func TestMiddlewareToolsListTestkit(t *testing.T) {
807807
})
808808
},
809809
mcpparser.ParsingMiddleware,
810-
func(h http.Handler) http.Handler { return Middleware(authorizer, h) },
810+
func(h http.Handler) http.Handler { return Middleware(authorizer, h, nil) },
811811
))
812812
server, client, err := testkit.NewStreamableTestServer(opts...)
813813
require.NoError(t, err)
@@ -977,7 +977,7 @@ func TestMiddlewareToolsCallTestkit(t *testing.T) {
977977
})
978978
},
979979
mcpparser.ParsingMiddleware,
980-
func(h http.Handler) http.Handler { return Middleware(authorizer, h) },
980+
func(h http.Handler) http.Handler { return Middleware(authorizer, h, nil) },
981981
))
982982
server, client, err := testkit.NewStreamableTestServer(opts...)
983983
require.NoError(t, err)

pkg/authz/response_filter.go

Lines changed: 42 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -24,28 +24,33 @@ var errBug = errors.New("there's a bug")
2424
// ResponseFilteringWriter wraps an http.ResponseWriter to intercept and filter responses
2525
type ResponseFilteringWriter struct {
2626
http.ResponseWriter
27-
authorizer authorizers.Authorizer
28-
request *http.Request
29-
method string
30-
buffer *bytes.Buffer
31-
statusCode int
32-
annotationCache *AnnotationCache
27+
authorizer authorizers.Authorizer
28+
request *http.Request
29+
method string
30+
buffer *bytes.Buffer
31+
statusCode int
32+
annotationCache *AnnotationCache
33+
passThroughTools map[string]struct{}
3334
}
3435

3536
// NewResponseFilteringWriter creates a new response filtering writer.
3637
// The annotationCache parameter is optional; pass nil to disable annotation caching.
38+
// The passThroughTools parameter is optional; tools whose names appear in this set
39+
// bypass policy filtering because authorization is enforced elsewhere (e.g., inside
40+
// the optimizer decorator for find_tool/call_tool).
3741
func NewResponseFilteringWriter(
3842
w http.ResponseWriter, authorizer authorizers.Authorizer, r *http.Request, method string,
39-
annotationCache *AnnotationCache,
43+
annotationCache *AnnotationCache, passThroughTools map[string]struct{},
4044
) *ResponseFilteringWriter {
4145
return &ResponseFilteringWriter{
42-
ResponseWriter: w,
43-
authorizer: authorizer,
44-
request: r,
45-
method: method,
46-
buffer: &bytes.Buffer{},
47-
statusCode: http.StatusOK,
48-
annotationCache: annotationCache,
46+
ResponseWriter: w,
47+
authorizer: authorizer,
48+
request: r,
49+
method: method,
50+
buffer: &bytes.Buffer{},
51+
statusCode: http.StatusOK,
52+
annotationCache: annotationCache,
53+
passThroughTools: passThroughTools,
4954
}
5055
}
5156

@@ -283,39 +288,32 @@ func (rfw *ResponseFilteringWriter) filterToolsResponse(response *jsonrpc2.Respo
283288
// subsequent tools/call requests can look up annotations.
284289
rfw.annotationCache.SetFromToolsList(listResult.Tools)
285290

286-
// Note: instantiating the list ensures that no null value is sent over the wire.
287-
// This is basically defensive programming, but for clients.
288-
filteredTools := []mcp.Tool{}
289-
for i, tool := range listResult.Tools {
290-
// Inject this tool's annotations into the context so Cedar policies
291-
// that use when clauses on resource attributes (e.g. resource.readOnlyHint)
292-
// can evaluate correctly. Without this, the authorization check runs
293-
// against a context with no annotations and all when clauses fail.
294-
ctx := rfw.request.Context()
295-
ann := &listResult.Tools[i].Annotations
296-
if hasAnyHint(ann) {
297-
ctx = authorizers.WithToolAnnotations(ctx, convertMCPAnnotation(ann))
298-
}
299-
300-
// Check if the user is authorized to call this tool
301-
authorized, err := rfw.authorizer.AuthorizeWithJWTClaims(
302-
ctx,
303-
authorizers.MCPFeatureTool,
304-
authorizers.MCPOperationCall,
305-
tool.Name,
306-
nil, // No arguments for the authorization check
307-
)
308-
if err != nil {
309-
slog.Warn("Authorization check failed for tool, skipping",
310-
"tool", tool.Name, "error", err)
311-
continue
312-
}
313-
314-
if authorized {
315-
filteredTools = append(filteredTools, tool)
291+
// When the optimizer is enabled, its meta-tools (find_tool, call_tool) appear
292+
// in tools/list instead of real backend tools. These meta-tools won't match
293+
// any operator-written Cedar policy (which references real tool names), so
294+
// default-deny would filter them out — leaving the client with zero tools.
295+
// Authorization for the underlying backend tools is enforced inside the
296+
// optimizer decorator itself (find_tool filters results, call_tool gates
297+
// invocations), so the meta-tools can safely pass through the response filter.
298+
// See: https://github.com/stacklok/toolhive/issues/4373
299+
passThrough := []mcp.Tool{}
300+
regular := []mcp.Tool{}
301+
for _, t := range listResult.Tools {
302+
if _, ok := rfw.passThroughTools[t.Name]; ok {
303+
passThrough = append(passThrough, t)
304+
} else {
305+
regular = append(regular, t)
316306
}
317307
}
318308

309+
// FilterToolsByPolicy checks each tool against the caller's Cedar policies
310+
// (injecting annotations into context for when-clause evaluation) and returns
311+
// only tools the caller is authorized to call.
312+
policyFiltered := FilterToolsByPolicy(rfw.request.Context(), rfw.authorizer, regular)
313+
filteredTools := make([]mcp.Tool, 0, len(passThrough)+len(policyFiltered))
314+
filteredTools = append(filteredTools, passThrough...)
315+
filteredTools = append(filteredTools, policyFiltered...)
316+
319317
// Create a new result with filtered tools
320318
filteredResult := mcp.ListToolsResult{
321319
PaginatedResult: listResult.PaginatedResult,

0 commit comments

Comments
 (0)