Skip to content

Commit 6b0ff32

Browse files
JAORMX and claude
authored
Wire server discovery protocol into thv serve (#4319)
* Wire server discovery protocol into thv serve

  The discovery package (pkg/server/discovery/) was already implemented and tested but had zero imports in the codebase. This wires it into the serve command so clients (CLI, Studio) can auto-discover a running server without hardcoded ports or environment variables.

  On startup, thv serve now generates a cryptographic nonce, writes a discovery file to $XDG_STATE_HOME/toolhive/server/server.json with the actual listen URL (supporting port 0 and Unix sockets), and returns the nonce via the X-Toolhive-Nonce health check header. On shutdown the file is removed.

  The skills client now tries discovery before falling back to the TOOLHIVE_API_URL env var or the default localhost:8080, with loopback and socket-path validation on discovered URLs.

  Additional fixes: SIGTERM handling in the serve command, a 30-second shutdown timeout (was unbounded), symlink rejection on the discovery file read path, directory permission tightening after MkdirAll, and constant-time nonce comparison.

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
  Signed-off-by: Juan Antonio Osorio <ozz@stacklok.com>

* Address review feedback on server discovery

  - Wrap writeDiscoveryFile check-then-write in WithFileLock to prevent TOCTOU race when two servers start simultaneously
  - Log FindProcess errors at Debug level instead of silently discarding
  - Consolidate ListenURL tests into a table-driven test
  - Rename healtcheck_test.go to healthcheck_test.go (fix typo)

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
  Signed-off-by: Juan Antonio Osorio <ozz@stacklok.com>

* Create discovery directory before acquiring lock file

  The discovery lock file is created in the same directory as server.json, but the directory may not exist on a fresh system. MkdirAll was called inside the lock callback (via WriteServerInfo), but the lock acquisition itself needs the directory to already exist. Create the directory before calling WithFileLock so the lock file can be written.

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Address review feedback on discovery wiring

  - Extract shared HTTPClientForURL in the discovery package to deduplicate transport setup between health.go and the skills client
  - Propagate context.Context through NewDefaultClient and resolveViaDiscovery instead of using context.Background()
  - Add comment explaining intentional opts-shadowing order so caller-supplied options can override discovery defaults
  - Use url.JoinPath in buildHealthClient instead of string concatenation

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

---------

Signed-off-by: Juan Antonio Osorio <ozz@stacklok.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 9c8f5ec commit 6b0ff32

21 files changed

Lines changed: 1339 additions & 110 deletions

cmd/thv/app/skill_build.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ func skillBuildCmdFunc(cmd *cobra.Command, args []string) error {
3939
return fmt.Errorf("failed to resolve path: %w", err)
4040
}
4141

42-
c := newSkillClient()
42+
c := newSkillClient(cmd.Context())
4343

4444
result, err := c.Build(cmd.Context(), skills.BuildOptions{
4545
Path: absPath,

cmd/thv/app/skill_helpers.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
package app
55

66
import (
7+
"context"
78
"errors"
89
"fmt"
910

@@ -14,8 +15,9 @@ import (
1415
)
1516

1617
// newSkillClient creates a new Skills API HTTP client using default settings.
17-
func newSkillClient() *skillclient.Client {
18-
return skillclient.NewDefaultClient()
18+
// The context is used for server discovery; it is not stored.
19+
func newSkillClient(ctx context.Context) *skillclient.Client {
20+
return skillclient.NewDefaultClient(ctx)
1921
}
2022

2123
// completeSkillNames provides shell completion for installed skill names.
@@ -24,7 +26,7 @@ func completeSkillNames(cmd *cobra.Command, args []string, _ string) ([]string,
2426
return nil, cobra.ShellCompDirectiveNoFileComp
2527
}
2628

27-
c := newSkillClient()
29+
c := newSkillClient(cmd.Context())
2830
installed, err := c.List(cmd.Context(), skills.ListOptions{})
2931
if err != nil {
3032
return nil, cobra.ShellCompDirectiveError

cmd/thv/app/skill_info.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ func init() {
4343
}
4444

4545
func skillInfoCmdFunc(cmd *cobra.Command, args []string) error {
46-
c := newSkillClient()
46+
c := newSkillClient(cmd.Context())
4747

4848
info, err := c.Info(cmd.Context(), skills.InfoOptions{
4949
Name: args[0],

cmd/thv/app/skill_install.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ func init() {
4242
}
4343

4444
func skillInstallCmdFunc(cmd *cobra.Command, args []string) error {
45-
c := newSkillClient()
45+
c := newSkillClient(cmd.Context())
4646

4747
_, err := c.Install(cmd.Context(), skills.InstallOptions{
4848
Name: args[0],

cmd/thv/app/skill_list.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ func init() {
4747
}
4848

4949
func skillListCmdFunc(cmd *cobra.Command, _ []string) error {
50-
c := newSkillClient()
50+
c := newSkillClient(cmd.Context())
5151

5252
installed, err := c.List(cmd.Context(), skills.ListOptions{
5353
Scope: skills.Scope(skillListScope),

cmd/thv/app/skill_push.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ func init() {
2222
}
2323

2424
func skillPushCmdFunc(cmd *cobra.Command, args []string) error {
25-
c := newSkillClient()
25+
c := newSkillClient(cmd.Context())
2626

2727
err := c.Push(cmd.Context(), skills.PushOptions{
2828
Reference: args[0],

cmd/thv/app/skill_uninstall.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ func init() {
3939
}
4040

4141
func skillUninstallCmdFunc(cmd *cobra.Command, args []string) error {
42-
c := newSkillClient()
42+
c := newSkillClient(cmd.Context())
4343

4444
err := c.Uninstall(cmd.Context(), skills.UninstallOptions{
4545
Name: args[0],

cmd/thv/app/skill_validate.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ func skillValidateCmdFunc(cmd *cobra.Command, args []string) error {
3737
return fmt.Errorf("failed to resolve path: %w", err)
3838
}
3939

40-
c := newSkillClient()
40+
c := newSkillClient(cmd.Context())
4141

4242
result, err := c.Validate(cmd.Context(), absPath)
4343
if err != nil {

pkg/api/server.go

Lines changed: 107 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ package api
1717

1818
import (
1919
"context"
20+
"crypto/rand"
21+
"encoding/hex"
2022
"errors"
2123
"fmt"
2224
"io"
@@ -39,9 +41,11 @@ import (
3941
"github.com/stacklok/toolhive/pkg/config"
4042
"github.com/stacklok/toolhive/pkg/container"
4143
"github.com/stacklok/toolhive/pkg/container/runtime"
44+
"github.com/stacklok/toolhive/pkg/fileutils"
4245
"github.com/stacklok/toolhive/pkg/groups"
4346
"github.com/stacklok/toolhive/pkg/recovery"
4447
"github.com/stacklok/toolhive/pkg/registry"
48+
"github.com/stacklok/toolhive/pkg/server/discovery"
4549
"github.com/stacklok/toolhive/pkg/skills"
4650
"github.com/stacklok/toolhive/pkg/skills/gitresolver"
4751
"github.com/stacklok/toolhive/pkg/skills/skillsvc"
@@ -55,6 +59,7 @@ const (
5559
middlewareTimeout = 60 * time.Second
5660
readHeaderTimeout = 10 * time.Second
5761
shutdownTimeout = 30 * time.Second
62+
nonceBytes = 16
5863
socketPermissions = 0660 // Socket file permissions (owner/group read-write)
5964
maxRequestBodySize = 1 << 20 // 1MB - Maximum request body size
6065
)
@@ -65,6 +70,7 @@ type ServerBuilder struct {
6570
isUnixSocket bool
6671
debugMode bool
6772
enableDocs bool
73+
nonce string
6874
oidcConfig *auth.TokenValidatorConfig
6975
middlewares []func(http.Handler) http.Handler
7076
customRoutes map[string]http.Handler
@@ -108,6 +114,14 @@ func (b *ServerBuilder) WithDocs(enableDocs bool) *ServerBuilder {
108114
return b
109115
}
110116

117+
// WithNonce sets the server instance nonce used for discovery verification.
118+
// When non-empty, the server writes a discovery file on startup and returns
119+
// the nonce in the X-Toolhive-Nonce health check header.
120+
func (b *ServerBuilder) WithNonce(nonce string) *ServerBuilder {
121+
b.nonce = nonce
122+
return b
123+
}
124+
111125
// WithOIDCConfig sets the OIDC configuration
112126
func (b *ServerBuilder) WithOIDCConfig(oidcConfig *auth.TokenValidatorConfig) *ServerBuilder {
113127
b.oidcConfig = oidcConfig
@@ -297,7 +311,7 @@ func (b *ServerBuilder) setupDefaultRoutes(r *chi.Mux) {
297311

298312
// All other routes get standard timeout
299313
standardRouters := map[string]http.Handler{
300-
"/health": v1.HealthcheckRouter(b.containerRuntime),
314+
"/health": v1.HealthcheckRouter(b.containerRuntime, b.nonce),
301315
"/api/v1beta/version": v1.VersionRouter(),
302316
"/api/v1beta/registry": v1.RegistryRouter(true),
303317
"/api/v1beta/discovery": v1.DiscoveryRouter(),
@@ -504,6 +518,7 @@ type Server struct {
504518
address string
505519
isUnixSocket bool
506520
addrType string
521+
nonce string
507522
storeCloser io.Closer
508523
}
509524

@@ -532,14 +547,29 @@ func NewServer(ctx context.Context, builder *ServerBuilder) (*Server, error) {
532547
address: builder.address,
533548
isUnixSocket: builder.isUnixSocket,
534549
addrType: addrType,
550+
nonce: builder.nonce,
535551
storeCloser: builder.skillStoreCloser,
536552
}, nil
537553
}
538554

555+
// ListenURL returns the URL where the server is listening, using the actual
556+
// bound address from the listener (important when binding to port 0).
557+
func (s *Server) ListenURL() string {
558+
if s.isUnixSocket {
559+
return fmt.Sprintf("unix://%s", s.address)
560+
}
561+
return fmt.Sprintf("http://%s", s.listener.Addr().String())
562+
}
563+
539564
// Start starts the server and blocks until the context is cancelled
540565
func (s *Server) Start(ctx context.Context) error {
541566
slog.Info("starting server", "type", s.addrType, "address", s.address)
542567

568+
// Write server discovery file so clients can find this instance.
569+
if err := s.writeDiscoveryFile(ctx); err != nil {
570+
return err
571+
}
572+
543573
// Start server in a goroutine
544574
serverErr := make(chan error, 1)
545575
go func() {
@@ -562,6 +592,61 @@ func (s *Server) Start(ctx context.Context) error {
562592
}
563593
}
564594

595+
// writeDiscoveryFile writes the server discovery file if a nonce is configured.
596+
// It checks for an existing healthy server first to prevent silent orphaning.
597+
// The entire check-then-write sequence is wrapped in a file lock to prevent
598+
// TOCTOU races when two servers start simultaneously.
599+
func (s *Server) writeDiscoveryFile(ctx context.Context) error {
600+
if s.nonce == "" {
601+
return nil
602+
}
603+
604+
// Ensure the discovery directory exists before acquiring the lock,
605+
// since the lock file is created in the same directory.
606+
discoveryPath := discovery.FilePath()
607+
if err := os.MkdirAll(filepath.Dir(discoveryPath), 0700); err != nil {
608+
return fmt.Errorf("failed to create discovery directory: %w", err)
609+
}
610+
611+
return fileutils.WithFileLock(discoveryPath, func() error {
612+
// Guard against overwriting another server's discovery file.
613+
result, err := discovery.Discover(ctx)
614+
if err != nil {
615+
slog.Debug("discovery check failed, proceeding with startup", "error", err)
616+
} else {
617+
switch result.State {
618+
case discovery.StateRunning:
619+
return fmt.Errorf("another ToolHive server is already running at %s (PID %d)", result.Info.URL, result.Info.PID)
620+
case discovery.StateStale:
621+
slog.Debug("cleaning up stale discovery file", "pid", result.Info.PID)
622+
if err := discovery.CleanupStale(); err != nil {
623+
slog.Warn("failed to clean up stale discovery file", "error", err)
624+
}
625+
case discovery.StateUnhealthy:
626+
// The process is alive but not responding to health checks.
627+
// This can happen after a crash-restart where the old process
628+
// is hung. We intentionally overwrite the discovery file so
629+
// this new server becomes discoverable.
630+
slog.Warn("existing server is unhealthy, overwriting discovery file", "pid", result.Info.PID)
631+
case discovery.StateNotFound:
632+
// No existing server, proceed normally.
633+
}
634+
}
635+
636+
info := &discovery.ServerInfo{
637+
URL: s.ListenURL(),
638+
PID: os.Getpid(),
639+
Nonce: s.nonce,
640+
StartedAt: time.Now().UTC(),
641+
}
642+
if err := discovery.WriteServerInfo(info); err != nil {
643+
return fmt.Errorf("failed to write discovery file: %w", err)
644+
}
645+
slog.Debug("wrote discovery file", "url", info.URL, "pid", info.PID)
646+
return nil
647+
})
648+
}
649+
565650
// shutdown gracefully shuts down the server
566651
func (s *Server) shutdown() error {
567652
shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
@@ -579,6 +664,11 @@ func (s *Server) shutdown() error {
579664

580665
// cleanup performs cleanup operations
581666
func (s *Server) cleanup() {
667+
if s.nonce != "" {
668+
if err := discovery.RemoveServerInfo(); err != nil {
669+
slog.Warn("failed to remove discovery file", "error", err)
670+
}
671+
}
582672
if s.storeCloser != nil {
583673
if err := s.storeCloser.Close(); err != nil {
584674
slog.Warn("failed to close skill store", "error", err)
@@ -653,6 +743,16 @@ func (a *clientPathAdapter) ListSkillSupportingClients() []string {
653743
return result
654744
}
655745

746+
// generateNonce creates a cryptographically random nonce for server instance
747+
// identification. It returns a 32-character hex string (16 random bytes).
748+
func generateNonce() (string, error) {
749+
b := make([]byte, nonceBytes)
750+
if _, err := rand.Read(b); err != nil {
751+
return "", fmt.Errorf("failed to generate server nonce: %w", err)
752+
}
753+
return hex.EncodeToString(b), nil
754+
}
755+
656756
// Serve starts the server on the given address and serves the API.
657757
// It is assumed that the caller sets up appropriate signal handling.
658758
// If isUnixSocket is true, address is treated as a UNIX socket path.
@@ -666,11 +766,17 @@ func Serve(
666766
oidcConfig *auth.TokenValidatorConfig,
667767
middlewares ...func(http.Handler) http.Handler,
668768
) error {
769+
nonce, err := generateNonce()
770+
if err != nil {
771+
return err
772+
}
773+
669774
builder := NewServerBuilder().
670775
WithAddress(address).
671776
WithUnixSocket(isUnixSocket).
672777
WithDebugMode(debugMode).
673778
WithDocs(enableDocs).
779+
WithNonce(nonce).
674780
WithOIDCConfig(oidcConfig).
675781
WithMiddleware(middlewares...)
676782

pkg/api/server_test.go

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package api
5+
6+
import (
7+
"fmt"
8+
"net"
9+
"regexp"
10+
"testing"
11+
12+
"github.com/stretchr/testify/assert"
13+
"github.com/stretchr/testify/require"
14+
)
15+
16+
func TestGenerateNonce(t *testing.T) {
17+
t.Parallel()
18+
19+
t.Run("returns valid 32-char hex string", func(t *testing.T) {
20+
t.Parallel()
21+
22+
nonce, err := generateNonce()
23+
require.NoError(t, err)
24+
25+
assert.Len(t, nonce, 32)
26+
assert.Regexp(t, regexp.MustCompile(`^[0-9a-f]{32}$`), nonce)
27+
})
28+
29+
t.Run("returns unique values on successive calls", func(t *testing.T) {
30+
t.Parallel()
31+
32+
nonce1, err := generateNonce()
33+
require.NoError(t, err)
34+
35+
nonce2, err := generateNonce()
36+
require.NoError(t, err)
37+
38+
assert.NotEqual(t, nonce1, nonce2)
39+
})
40+
}
41+
42+
func TestListenURL(t *testing.T) {
43+
t.Parallel()
44+
45+
tests := []struct {
46+
name string
47+
server func(t *testing.T) *Server
48+
expected func(s *Server) string
49+
}{
50+
{
51+
name: "TCP returns http URL with actual port",
52+
server: func(t *testing.T) *Server {
53+
t.Helper()
54+
listener, err := net.Listen("tcp", "127.0.0.1:0")
55+
require.NoError(t, err)
56+
t.Cleanup(func() { listener.Close() })
57+
return &Server{
58+
listener: listener,
59+
isUnixSocket: false,
60+
address: "127.0.0.1:0",
61+
}
62+
},
63+
expected: func(s *Server) string {
64+
return fmt.Sprintf("http://%s", s.listener.Addr().String())
65+
},
66+
},
67+
{
68+
name: "Unix socket returns unix URL",
69+
server: func(_ *testing.T) *Server {
70+
return &Server{
71+
isUnixSocket: true,
72+
address: "/tmp/test.sock",
73+
}
74+
},
75+
expected: func(_ *Server) string {
76+
return "unix:///tmp/test.sock"
77+
},
78+
},
79+
}
80+
81+
for _, tt := range tests {
82+
t.Run(tt.name, func(t *testing.T) {
83+
t.Parallel()
84+
s := tt.server(t)
85+
assert.Equal(t, tt.expected(s), s.ListenURL())
86+
})
87+
}
88+
}

0 commit comments

Comments
 (0)