|
1 | 1 | package main |
2 | 2 |
|
3 | 3 | import ( |
4 | | - "cmp" |
5 | 4 | "context" |
6 | 5 | "net/http" |
7 | 6 | "net/http/httptest" |
8 | | - "os" |
9 | 7 | "testing" |
10 | 8 |
|
11 | 9 | "github.com/stretchr/testify/require" |
12 | 10 |
|
13 | | - "github.com/riverqueue/river/rivershared/riversharedtest" |
| 11 | + "github.com/riverqueue/apiframe/apimiddleware" |
14 | 12 | ) |
15 | 13 |
|
16 | 14 | func TestAuthMiddleware(t *testing.T) { |
17 | | - var ( |
18 | | - ctx = context.Background() |
19 | | - databaseURL = cmp.Or(os.Getenv("TEST_DATABASE_URL"), "postgres://localhost/river_test") |
20 | | - basicAuthUser = "test_auth_user" |
| 15 | + t.Parallel() |
| 16 | + |
| 17 | + const ( |
21 | 18 | basicAuthPassword = "test_auth_pass" |
| 19 | + basicAuthUsername = "test_auth_user" |
22 | 20 | ) |
23 | 21 |
|
24 | | - t.Setenv("DEV", "true") |
25 | | - t.Setenv("DATABASE_URL", databaseURL) |
26 | | - t.Setenv("RIVER_BASIC_AUTH_USER", basicAuthUser) |
27 | | - t.Setenv("RIVER_BASIC_AUTH_PASS", basicAuthPassword) |
| 22 | + ctx := context.Background() |
| 23 | + |
| 24 | + type testBundle struct { |
| 25 | + handler http.Handler |
| 26 | + } |
28 | 27 |
|
29 | | - setup := func(t *testing.T, prefix string) http.Handler { |
| 28 | + setup := func(t *testing.T) (*authMiddleware, *testBundle) { |
30 | 29 | t.Helper() |
31 | | - initRes, err := initServer(ctx, riversharedtest.Logger(t), prefix) |
32 | | - require.NoError(t, err) |
33 | | - t.Cleanup(initRes.dbPool.Close) |
34 | 30 |
|
35 | | - return initRes.httpServer.Handler |
| 31 | + authMiddleware := &authMiddleware{username: basicAuthUsername, password: basicAuthPassword} |
| 32 | + |
| 33 | + return authMiddleware, &testBundle{ |
| 34 | + handler: apimiddleware.NewMiddlewareStack( |
| 35 | + authMiddleware, |
| 36 | + ).Mount(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| 37 | + w.WriteHeader(http.StatusOK) |
| 38 | + })), |
| 39 | + } |
36 | 40 | } |
37 | 41 |
|
38 | | - t.Run("Unauthorized", func(t *testing.T) { //nolint:paralleltest |
39 | | - handler := setup(t, "/") |
40 | | - req := httptest.NewRequest(http.MethodGet, "/api/jobs", nil) |
41 | | - recorder := httptest.NewRecorder() |
| 42 | + t.Run("Unauthorized", func(t *testing.T) { |
| 43 | + t.Parallel() |
| 44 | + |
| 45 | + _, bundle := setup(t) |
42 | 46 |
|
43 | | - handler.ServeHTTP(recorder, req) |
| 47 | + req := httptest.NewRequestWithContext(ctx, http.MethodGet, "/api/jobs", nil) |
44 | 48 |
|
| 49 | + recorder := httptest.NewRecorder() |
| 50 | + bundle.handler.ServeHTTP(recorder, req) |
45 | 51 | require.Equal(t, http.StatusUnauthorized, recorder.Code) |
46 | 52 | }) |
47 | 53 |
|
48 | 54 | t.Run("Authorized", func(t *testing.T) { //nolint:paralleltest |
49 | | - handler := setup(t, "/") |
50 | | - req := httptest.NewRequest(http.MethodGet, "/api/jobs", nil) |
51 | | - req.SetBasicAuth(basicAuthUser, basicAuthPassword) |
| 55 | + t.Parallel() |
52 | 56 |
|
53 | | - recorder := httptest.NewRecorder() |
| 57 | + _, bundle := setup(t) |
54 | 58 |
|
55 | | - handler.ServeHTTP(recorder, req) |
| 59 | + req := httptest.NewRequestWithContext(ctx, http.MethodGet, "/api/jobs", nil) |
| 60 | + req.SetBasicAuth(basicAuthUsername, basicAuthPassword) |
56 | 61 |
|
| 62 | + recorder := httptest.NewRecorder() |
| 63 | + bundle.handler.ServeHTTP(recorder, req) |
57 | 64 | require.Equal(t, http.StatusOK, recorder.Code) |
58 | 65 | }) |
59 | 66 |
|
60 | | - t.Run("Healthcheck exemption", func(t *testing.T) { //nolint:paralleltest |
61 | | - handler := setup(t, "/") |
62 | | - req := httptest.NewRequest(http.MethodGet, "/api/health-checks/complete", nil) |
63 | | - recorder := httptest.NewRecorder() |
| 67 | + t.Run("HealthCheckExemption", func(t *testing.T) { //nolint:paralleltest |
| 68 | + t.Parallel() |
64 | 69 |
|
65 | | - handler.ServeHTTP(recorder, req) |
| 70 | + _, bundle := setup(t) |
66 | 71 |
|
| 72 | + req := httptest.NewRequestWithContext(ctx, http.MethodGet, "/api/health-checks/complete", nil) |
| 73 | + |
| 74 | + recorder := httptest.NewRecorder() |
| 75 | + bundle.handler.ServeHTTP(recorder, req) |
67 | 76 | require.Equal(t, http.StatusOK, recorder.Code) |
68 | 77 | }) |
69 | 78 |
|
70 | | - t.Run("Healthcheck exemption with prefix", func(t *testing.T) { //nolint:paralleltest |
71 | | - handler := setup(t, "/test-prefix") |
72 | | - req := httptest.NewRequest(http.MethodGet, "/test-prefix/api/health-checks/complete", nil) |
73 | | - recorder := httptest.NewRecorder() |
| 79 | + t.Run("HealthCheckExemptionWithPrefix", func(t *testing.T) { //nolint:paralleltest |
| 80 | + t.Parallel() |
| 81 | + |
| 82 | + middleware, bundle := setup(t) |
| 83 | + middleware.pathPrefix = "/test-prefix" |
74 | 84 |
|
75 | | - handler.ServeHTTP(recorder, req) |
| 85 | + req := httptest.NewRequestWithContext(ctx, http.MethodGet, "/test-prefix/api/health-checks/complete", nil) |
76 | 86 |
|
| 87 | + recorder := httptest.NewRecorder() |
| 88 | + bundle.handler.ServeHTTP(recorder, req) |
77 | 89 | require.Equal(t, http.StatusOK, recorder.Code) |
78 | 90 | }) |
79 | 91 | } |
0 commit comments