Skip to content

Commit f650a38

Browse files
committed
Tweaks for auth middleware tests: Focused + parallel + conventions
Follows up #390 with a couple tweaks for the tests: * Make auth middleware tests more focused so they only test `authMiddleware` rather than the full server HTTP stack. * Remove use of `t.Setenv` so we can run all tests in parallel. * Use `pathPrefix` in middleware so that only the specific configured prefix is accepted rather than any prefix that might be present. * Use `CamelCaseTestName` convention rather than "test name like this". * Use `testBundle` convention for `setup`.
1 parent 9ae67f5 commit f650a38

3 files changed

Lines changed: 61 additions & 48 deletions

File tree

cmd/riverui/auth_middleware.go

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,14 @@ import (
77
)
88

99
type authMiddleware struct {
10-
username string
11-
password string
10+
password string
11+
pathPrefix string // HTTP path prefix
12+
username string
1213
}
1314

1415
func (m *authMiddleware) Middleware(next http.Handler) http.Handler {
1516
return http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
16-
if isReqAuthorized(req, m.username, m.password) {
17+
if m.isReqAuthorized(req) {
1718
next.ServeHTTP(res, req)
1819
return
1920
}
@@ -23,13 +24,13 @@ func (m *authMiddleware) Middleware(next http.Handler) http.Handler {
2324
})
2425
}
2526

26-
func isReqAuthorized(req *http.Request, username, password string) bool {
27-
reqUsername, reqPassword, ok := req.BasicAuth()
28-
29-
isHealthCheck := strings.Contains(req.URL.Path, "/api/health-checks/")
30-
isValidAuth := ok &&
31-
subtle.ConstantTimeCompare([]byte(reqUsername), []byte(username)) == 1 &&
32-
subtle.ConstantTimeCompare([]byte(reqPassword), []byte(password)) == 1
27+
func (m *authMiddleware) isReqAuthorized(req *http.Request) bool {
28+
if strings.HasPrefix(req.URL.Path, m.pathPrefix+"/api/health-checks/") {
29+
return true
30+
}
3331

34-
return isHealthCheck || isValidAuth
32+
reqUsername, reqPassword, ok := req.BasicAuth()
33+
return ok &&
34+
subtle.ConstantTimeCompare([]byte(reqUsername), []byte(m.username)) == 1 &&
35+
subtle.ConstantTimeCompare([]byte(reqPassword), []byte(m.password)) == 1
3536
}
Lines changed: 48 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,79 +1,91 @@
11
package main
22

33
import (
4-
"cmp"
54
"context"
65
"net/http"
76
"net/http/httptest"
8-
"os"
97
"testing"
108

119
"github.com/stretchr/testify/require"
1210

13-
"github.com/riverqueue/river/rivershared/riversharedtest"
11+
"github.com/riverqueue/apiframe/apimiddleware"
1412
)
1513

1614
func TestAuthMiddleware(t *testing.T) {
17-
var (
18-
ctx = context.Background()
19-
databaseURL = cmp.Or(os.Getenv("TEST_DATABASE_URL"), "postgres://localhost/river_test")
20-
basicAuthUser = "test_auth_user"
15+
t.Parallel()
16+
17+
const (
2118
basicAuthPassword = "test_auth_pass"
19+
basicAuthUsername = "test_auth_user"
2220
)
2321

24-
t.Setenv("DEV", "true")
25-
t.Setenv("DATABASE_URL", databaseURL)
26-
t.Setenv("RIVER_BASIC_AUTH_USER", basicAuthUser)
27-
t.Setenv("RIVER_BASIC_AUTH_PASS", basicAuthPassword)
22+
ctx := context.Background()
23+
24+
type testBundle struct {
25+
handler http.Handler
26+
}
2827

29-
setup := func(t *testing.T, prefix string) http.Handler {
28+
setup := func(t *testing.T) (*authMiddleware, *testBundle) {
3029
t.Helper()
31-
initRes, err := initServer(ctx, riversharedtest.Logger(t), prefix)
32-
require.NoError(t, err)
33-
t.Cleanup(initRes.dbPool.Close)
3430

35-
return initRes.httpServer.Handler
31+
authMiddleware := &authMiddleware{username: basicAuthUsername, password: basicAuthPassword}
32+
33+
return authMiddleware, &testBundle{
34+
handler: apimiddleware.NewMiddlewareStack(
35+
authMiddleware,
36+
).Mount(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
37+
w.WriteHeader(http.StatusOK)
38+
})),
39+
}
3640
}
3741

38-
t.Run("Unauthorized", func(t *testing.T) { //nolint:paralleltest
39-
handler := setup(t, "/")
40-
req := httptest.NewRequest(http.MethodGet, "/api/jobs", nil)
41-
recorder := httptest.NewRecorder()
42+
t.Run("Unauthorized", func(t *testing.T) {
43+
t.Parallel()
44+
45+
_, bundle := setup(t)
4246

43-
handler.ServeHTTP(recorder, req)
47+
req := httptest.NewRequestWithContext(ctx, http.MethodGet, "/api/jobs", nil)
4448

49+
recorder := httptest.NewRecorder()
50+
bundle.handler.ServeHTTP(recorder, req)
4551
require.Equal(t, http.StatusUnauthorized, recorder.Code)
4652
})
4753

4854
t.Run("Authorized", func(t *testing.T) { //nolint:paralleltest
49-
handler := setup(t, "/")
50-
req := httptest.NewRequest(http.MethodGet, "/api/jobs", nil)
51-
req.SetBasicAuth(basicAuthUser, basicAuthPassword)
55+
t.Parallel()
5256

53-
recorder := httptest.NewRecorder()
57+
_, bundle := setup(t)
5458

55-
handler.ServeHTTP(recorder, req)
59+
req := httptest.NewRequestWithContext(ctx, http.MethodGet, "/api/jobs", nil)
60+
req.SetBasicAuth(basicAuthUsername, basicAuthPassword)
5661

62+
recorder := httptest.NewRecorder()
63+
bundle.handler.ServeHTTP(recorder, req)
5764
require.Equal(t, http.StatusOK, recorder.Code)
5865
})
5966

60-
t.Run("Healthcheck exemption", func(t *testing.T) { //nolint:paralleltest
61-
handler := setup(t, "/")
62-
req := httptest.NewRequest(http.MethodGet, "/api/health-checks/complete", nil)
63-
recorder := httptest.NewRecorder()
67+
t.Run("HealthCheckExemption", func(t *testing.T) { //nolint:paralleltest
68+
t.Parallel()
6469

65-
handler.ServeHTTP(recorder, req)
70+
_, bundle := setup(t)
6671

72+
req := httptest.NewRequestWithContext(ctx, http.MethodGet, "/api/health-checks/complete", nil)
73+
74+
recorder := httptest.NewRecorder()
75+
bundle.handler.ServeHTTP(recorder, req)
6776
require.Equal(t, http.StatusOK, recorder.Code)
6877
})
6978

70-
t.Run("Healthcheck exemption with prefix", func(t *testing.T) { //nolint:paralleltest
71-
handler := setup(t, "/test-prefix")
72-
req := httptest.NewRequest(http.MethodGet, "/test-prefix/api/health-checks/complete", nil)
73-
recorder := httptest.NewRecorder()
79+
t.Run("HealthCheckExemptionWithPrefix", func(t *testing.T) { //nolint:paralleltest
80+
t.Parallel()
81+
82+
middleware, bundle := setup(t)
83+
middleware.pathPrefix = "/test-prefix"
7484

75-
handler.ServeHTTP(recorder, req)
85+
req := httptest.NewRequestWithContext(ctx, http.MethodGet, "/test-prefix/api/health-checks/complete", nil)
7686

87+
recorder := httptest.NewRecorder()
88+
bundle.handler.ServeHTTP(recorder, req)
7789
require.Equal(t, http.StatusOK, recorder.Code)
7890
})
7991
}

cmd/riverui/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ func initServer(ctx context.Context, logger *slog.Logger, pathPrefix string) (*i
163163
apimiddleware.MiddlewareFunc(logHandler),
164164
)
165165
if basicAuthUsername != "" && basicAuthPassword != "" {
166-
middlewareStack.Use(&authMiddleware{username: basicAuthUsername, password: basicAuthPassword})
166+
middlewareStack.Use(&authMiddleware{username: basicAuthUsername, password: basicAuthPassword, pathPrefix: pathPrefix})
167167
}
168168

169169
return &initServerResult{

0 commit comments

Comments
 (0)