Skip to content

Commit 1e211df

Browse files
authored
fix: BED-4600 - Add Request Timeout (#188)
* BED-4600: Add page request timeout to avoid indefinitely hanging on a failed call
1 parent f757f5f commit 1e211df

2 files changed

Lines changed: 154 additions & 7 deletions

File tree

client/client.go

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525
"fmt"
2626
"net/http"
2727
"net/url"
28+
"time"
2829

2930
"github.com/bloodhoundad/azurehound/v2/client/config"
3031
"github.com/bloodhoundad/azurehound/v2/client/query"
@@ -34,6 +35,9 @@ import (
3435
"github.com/bloodhoundad/azurehound/v2/pipeline"
3536
)
3637

38+
// This prevents a hung connection from blocking the entire collection pipeline indefinitely.
39+
var pageRequestTimeout = 2 * time.Minute
40+
3741
func NewClient(config config.Config) (AzureClient, error) {
3842
if msgraph, err := rest.NewRestClient(config.GraphUrl(), config); err != nil {
3943
return nil, err
@@ -118,8 +122,13 @@ func getAzureObjectList[T any](client rest.RestClient, ctx context.Context, path
118122
err error
119123
)
120124

125+
// Create a per-page timeout so a single hung API response cannot block
126+
// the entire collection pipeline indefinitely.
127+
pageCtx, pageCancel := context.WithTimeout(ctx, pageRequestTimeout)
128+
121129
if nextLink != "" {
122130
if nextUrl, err := url.Parse(nextLink); err != nil {
131+
pageCancel()
123132
errResult.Error = err
124133
_ = pipeline.Send(ctx.Done(), out, errResult)
125134
return
@@ -128,33 +137,40 @@ func getAzureObjectList[T any](client rest.RestClient, ctx context.Context, path
128137
if params != nil {
129138
paramsMap = params.AsMap()
130139
}
131-
if req, err := rest.NewRequest(ctx, "GET", nextUrl, nil, paramsMap, nil); err != nil {
140+
if req, err := rest.NewRequest(pageCtx, "GET", nextUrl, nil, paramsMap, nil); err != nil {
141+
pageCancel()
132142
errResult.Error = err
133143
_ = pipeline.Send(ctx.Done(), out, errResult)
134144
return
135145
} else if res, err = client.Send(req); err != nil {
146+
pageCancel()
136147
errResult.Error = err
137148
_ = pipeline.Send(ctx.Done(), out, errResult)
138149
return
139150
}
140151
}
141152
} else {
142-
if res, err = client.Get(ctx, path, params, nil); err != nil {
153+
if res, err = client.Get(pageCtx, path, params, nil); err != nil {
154+
pageCancel()
143155
errResult.Error = err
144156
_ = pipeline.Send(ctx.Done(), out, errResult)
145157
return
146158
}
147159
}
148160

149161
if err := rest.Decode(res.Body, &list); err != nil {
162+
pageCancel()
150163
errResult.Error = err
151164
_ = pipeline.Send(ctx.Done(), out, errResult)
152165
return
153-
} else {
154-
for _, u := range list.Value {
155-
if ok := pipeline.Send(ctx.Done(), out, AzureResult[T]{Ok: u}); !ok {
156-
return
157-
}
166+
}
167+
168+
// Page fetch complete; release the timeout context
169+
pageCancel()
170+
171+
for _, u := range list.Value {
172+
if ok := pipeline.Send(ctx.Done(), out, AzureResult[T]{Ok: u}); !ok {
173+
return
158174
}
159175
}
160176

client/client_test.go

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
package client
2+
3+
import (
4+
"context"
5+
"io"
6+
"net/http"
7+
"strings"
8+
"testing"
9+
"time"
10+
11+
"github.com/bloodhoundad/azurehound/v2/client/query"
12+
"github.com/stretchr/testify/require"
13+
)
14+
15+
// fakeRestClient is a minimal test double for rest.RestClient that allows
16+
// controlling the Get response per test case.
17+
type fakeRestClient struct {
18+
getFunc func(ctx context.Context, path string, params query.Params, headers map[string]string) (*http.Response, error)
19+
}
20+
21+
func (s *fakeRestClient) Get(ctx context.Context, path string, params query.Params, headers map[string]string) (*http.Response, error) {
22+
return s.getFunc(ctx, path, params, headers)
23+
}
24+
25+
func (s *fakeRestClient) Delete(context.Context, string, interface{}, query.Params, map[string]string) (*http.Response, error) {
26+
return nil, nil
27+
}
28+
func (s *fakeRestClient) Patch(context.Context, string, interface{}, query.Params, map[string]string) (*http.Response, error) {
29+
return nil, nil
30+
}
31+
func (s *fakeRestClient) Post(context.Context, string, interface{}, query.Params, map[string]string) (*http.Response, error) {
32+
return nil, nil
33+
}
34+
func (s *fakeRestClient) Put(context.Context, string, interface{}, query.Params, map[string]string) (*http.Response, error) {
35+
return nil, nil
36+
}
37+
func (s *fakeRestClient) Send(req *http.Request) (*http.Response, error) { return nil, nil }
38+
func (s *fakeRestClient) AddAuthenticationToRequest(req *http.Request) (*http.Request, error) {
39+
return req, nil
40+
}
41+
func (s *fakeRestClient) CloseIdleConnections() {}
42+
43+
func TestGetAzureObjectList_SuccessfulResponse(t *testing.T) {
44+
body := `{"value": [{"id": "1"}, {"id": "2"}]}`
45+
client := &fakeRestClient{
46+
getFunc: func(ctx context.Context, path string, params query.Params, headers map[string]string) (*http.Response, error) {
47+
return &http.Response{
48+
StatusCode: http.StatusOK,
49+
Body: io.NopCloser(strings.NewReader(body)),
50+
}, nil
51+
},
52+
}
53+
54+
out := make(chan AzureResult[map[string]string])
55+
go getAzureObjectList(client, context.Background(), "/test/path", nil, out)
56+
57+
var results []map[string]string
58+
for result := range out {
59+
require.NoError(t, result.Error)
60+
results = append(results, result.Ok)
61+
}
62+
63+
require.Len(t, results, 2)
64+
require.Equal(t, "1", results[0]["id"])
65+
require.Equal(t, "2", results[1]["id"])
66+
}
67+
68+
func TestGetAzureObjectList_HungResponseTimesOut(t *testing.T) {
69+
// Shorten the timeout so the test completes quickly
70+
original := pageRequestTimeout
71+
pageRequestTimeout = 500 * time.Millisecond
72+
defer func() { pageRequestTimeout = original }()
73+
74+
client := &fakeRestClient{
75+
getFunc: func(ctx context.Context, path string, params query.Params, headers map[string]string) (*http.Response, error) {
76+
// Verify the context has a deadline (set by pageRequestTimeout)
77+
_, hasDeadline := ctx.Deadline()
78+
require.True(t, hasDeadline, "expected context passed to Get to have a deadline from pageRequestTimeout")
79+
80+
// Simulate a hung connection: block until the context expires
81+
<-ctx.Done()
82+
return nil, ctx.Err()
83+
},
84+
}
85+
86+
out := make(chan AzureResult[map[string]string])
87+
go getAzureObjectList(client, context.Background(), "/test/path", nil, out)
88+
89+
// The channel should produce an error and close well within a few seconds
90+
select {
91+
case result, ok := <-out:
92+
require.True(t, ok, "expected a value on out, channel closed")
93+
require.Error(t, result.Error, "expected an error result from timed-out request")
94+
require.ErrorIs(t, result.Error, context.DeadlineExceeded)
95+
case <-time.After(5 * time.Second):
96+
t.Fatal("getAzureObjectList did not return within expected timeout; pipeline is hung")
97+
}
98+
99+
// drain and ensure channel closes
100+
for range out {
101+
}
102+
}
103+
104+
func TestGetAzureObjectList_ParentContextCanceled(t *testing.T) {
105+
ctx, cancel := context.WithCancel(context.Background())
106+
107+
client := &fakeRestClient{
108+
getFunc: func(ctx context.Context, path string, params query.Params, headers map[string]string) (*http.Response, error) {
109+
<-ctx.Done()
110+
return nil, ctx.Err()
111+
},
112+
}
113+
114+
out := make(chan AzureResult[map[string]string])
115+
go getAzureObjectList(client, ctx, "/test/path", nil, out)
116+
117+
// Cancel the parent context after a short delay
118+
time.AfterFunc(100*time.Millisecond, cancel)
119+
120+
select {
121+
case result, ok := <-out:
122+
if ok {
123+
require.Error(t, result.Error)
124+
}
125+
case <-time.After(5 * time.Second):
126+
t.Fatal("getAzureObjectList did not respect parent context cancellation")
127+
}
128+
129+
for range out {
130+
}
131+
}

0 commit comments

Comments
 (0)