Skip to content

Commit fc4528b

Browse files
committed
protect against time of check vs use race condition
1 parent d5bb33e commit fc4528b

4 files changed

Lines changed: 195 additions & 16 deletions

File tree

cmd/utils.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,14 +138,19 @@ func loadStack(cfg *config.Config, branch string) (*loadStackResult, error) {
138138
}
139139

140140
// handleSaveError translates a stack.Save error into the appropriate user
141-
// message and exit error. Lock contention returns ErrLockFailed (exit 8);
142-
// other write failures return ErrSilent (exit 1).
141+
// message and exit error. Lock contention and stale-file detection both
142+
// return ErrLockFailed (exit 8); other write failures return ErrSilent (exit 1).
143143
func handleSaveError(cfg *config.Config, err error) error {
144144
var lockErr *stack.LockError
145145
if errors.As(err, &lockErr) {
146146
cfg.Errorf("another process is currently editing the stack — try again later")
147147
return ErrLockFailed
148148
}
149+
var staleErr *stack.StaleError
150+
if errors.As(err, &staleErr) {
151+
cfg.Errorf("stack file was modified by another process — please re-run the command")
152+
return ErrLockFailed
153+
}
149154
cfg.Errorf("failed to save stack state: %s", err)
150155
return ErrSilent
151156
}

internal/stack/lock.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,20 @@ type LockError struct {
1919
func (e *LockError) Error() string { return e.Err.Error() }
2020
func (e *LockError) Unwrap() error { return e.Err }
2121

22+
// StaleError is returned when the stack file was modified on disk since it
23+
// was loaded. This indicates another process wrote to the file concurrently.
24+
// Callers can check for this with errors.As.
25+
type StaleError struct {
26+
Err error
27+
}
28+
29+
func (e *StaleError) Error() string { return e.Err.Error() }
30+
func (e *StaleError) Unwrap() error { return e.Err }
31+
2232
// LockTimeout is how long Lock() will wait for the exclusive lock before
2333
// giving up. With the lock held only during file writes (milliseconds),
2434
// this timeout primarily guards against a hung process holding the lock.
25-
const LockTimeout = 5 * time.Second
35+
var LockTimeout = 5 * time.Second
2636

2737
// lockRetryInterval is the sleep between non-blocking lock attempts.
2838
const lockRetryInterval = 100 * time.Millisecond

internal/stack/lock_test.go

Lines changed: 113 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package stack
22

33
import (
4+
"errors"
45
"fmt"
56
"os"
67
"path/filepath"
@@ -77,7 +78,7 @@ func TestLock_SerializesConcurrentAccess(t *testing.T) {
7778
require.NoError(t, Save(dir, sf))
7879

7980
// Run 10 concurrent goroutines, each adding a stack under lock.
80-
// Uses Lock + Load + SaveLocked for atomic read-modify-write.
81+
// Uses Lock + Load + writeStackFile for atomic read-modify-write.
8182
errCh := make(chan error, 10)
8283
var wg sync.WaitGroup
8384
for i := 0; i < 10; i++ {
@@ -99,8 +100,8 @@ func TestLock_SerializesConcurrentAccess(t *testing.T) {
99100
}
100101

101102
loaded.AddStack(makeStack("main", "branch"))
102-
if err := SaveLocked(dir, loaded); err != nil {
103-
errCh <- fmt.Errorf("goroutine %d SaveLocked: %w", idx, err)
103+
if err := writeStackFile(dir, loaded); err != nil {
104+
errCh <- fmt.Errorf("goroutine %d writeStackFile: %w", idx, err)
104105
}
105106
}(i)
106107
}
@@ -132,3 +133,112 @@ func TestLock_FileLeftOnDisk(t *testing.T) {
132133
require.NoError(t, err, "should be able to re-lock after unlock")
133134
lock2.Unlock()
134135
}
136+
137+
func TestLock_TimesOut(t *testing.T) {
138+
dir := t.TempDir()
139+
140+
// Hold the lock so the second attempt can never acquire it.
141+
lock1, err := Lock(dir)
142+
require.NoError(t, err)
143+
defer lock1.Unlock()
144+
145+
// Save original timeout and set a short one for the test.
146+
origTimeout := LockTimeout
147+
LockTimeout = 200 * time.Millisecond
148+
defer func() { LockTimeout = origTimeout }()
149+
150+
start := time.Now()
151+
lock2, err := Lock(dir)
152+
elapsed := time.Since(start)
153+
154+
assert.Nil(t, lock2, "should not acquire lock")
155+
require.Error(t, err)
156+
157+
var lockErr *LockError
158+
require.True(t, errors.As(err, &lockErr), "error should be *LockError, got %T", err)
159+
assert.Contains(t, lockErr.Error(), "timed out")
160+
161+
// Should have waited roughly LockTimeout before giving up.
162+
assert.GreaterOrEqual(t, elapsed, 150*time.Millisecond, "should wait near the timeout")
163+
}
164+
165+
func TestSave_DetectsStaleFile(t *testing.T) {
166+
dir := t.TempDir()
167+
168+
// Write an initial stack file.
169+
sf := &StackFile{SchemaVersion: 1, Stacks: []Stack{}}
170+
require.NoError(t, Save(dir, sf))
171+
172+
// Load — captures the on-disk checksum.
173+
loaded, err := Load(dir)
174+
require.NoError(t, err)
175+
176+
// Simulate another process: load, modify, save.
177+
other, err := Load(dir)
178+
require.NoError(t, err)
179+
other.AddStack(makeStack("main", "sneaky"))
180+
require.NoError(t, Save(dir, other))
181+
182+
// Our loaded copy tries to save — should detect staleness.
183+
loaded.AddStack(makeStack("main", "my-branch"))
184+
err = Save(dir, loaded)
185+
require.Error(t, err)
186+
187+
var staleErr *StaleError
188+
require.True(t, errors.As(err, &staleErr), "error should be *StaleError, got %T", err)
189+
assert.Contains(t, staleErr.Error(), "modified by another process")
190+
}
191+
192+
func TestSave_AllowsWriteWhenFileUnchanged(t *testing.T) {
193+
dir := t.TempDir()
194+
195+
// Write, load, modify, save — no concurrent changes.
196+
sf := &StackFile{SchemaVersion: 1, Stacks: []Stack{}}
197+
require.NoError(t, Save(dir, sf))
198+
199+
loaded, err := Load(dir)
200+
require.NoError(t, err)
201+
202+
loaded.AddStack(makeStack("main", "feature"))
203+
require.NoError(t, Save(dir, loaded))
204+
205+
// Verify the write actually persisted.
206+
final, err := Load(dir)
207+
require.NoError(t, err)
208+
assert.Len(t, final.Stacks, 1)
209+
}
210+
211+
func TestSave_AllowsFirstWrite(t *testing.T) {
212+
dir := t.TempDir()
213+
214+
// File doesn't exist — Load returns nil checksum, Save should succeed.
215+
sf, err := Load(dir)
216+
require.NoError(t, err)
217+
assert.Empty(t, sf.Stacks)
218+
219+
sf.AddStack(makeStack("main", "first"))
220+
require.NoError(t, Save(dir, sf), "first save to a new file should succeed")
221+
222+
final, err := Load(dir)
223+
require.NoError(t, err)
224+
assert.Len(t, final.Stacks, 1)
225+
}
226+
227+
func TestSave_DoubleSaveSucceeds(t *testing.T) {
228+
dir := t.TempDir()
229+
230+
sf, err := Load(dir)
231+
require.NoError(t, err)
232+
233+
sf.AddStack(makeStack("main", "first"))
234+
require.NoError(t, Save(dir, sf), "first save should succeed")
235+
236+
// A second Save on the same instance must not spuriously fail —
237+
// writeStackFile refreshes loadChecksum after writing.
238+
sf.AddStack(makeStack("main", "second"))
239+
require.NoError(t, Save(dir, sf), "second save on same instance should succeed")
240+
241+
final, err := Load(dir)
242+
require.NoError(t, err)
243+
assert.Len(t, final.Stacks, 2)
244+
}

internal/stack/stack.go

Lines changed: 64 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package stack
22

33
import (
4+
"bytes"
5+
"crypto/sha256"
46
"encoding/json"
57
"errors"
68
"fmt"
@@ -168,6 +170,11 @@ type StackFile struct {
168170
SchemaVersion int `json:"schemaVersion"`
169171
Repository string `json:"repository"`
170172
Stacks []Stack `json:"stacks"`
173+
174+
// loadChecksum is the SHA-256 of the raw file bytes at Load time.
175+
// Save uses it to detect concurrent modifications (optimistic concurrency).
176+
// nil means the file did not exist when loaded.
177+
loadChecksum []byte
171178
}
172179

173180
// FindAllStacksForBranch returns all stacks that contain the given branch.
@@ -233,11 +240,14 @@ func stackFilePath(gitDir string) string {
233240

234241
// Load reads the stack file from the given git directory.
235242
// Returns an empty StackFile if the file does not exist.
243+
// The returned StackFile records a checksum of the on-disk content so that
244+
// Save can detect concurrent modifications.
236245
func Load(gitDir string) (*StackFile, error) {
237246
path := stackFilePath(gitDir)
238247
data, err := os.ReadFile(path)
239248
if err != nil {
240249
if errors.Is(err, os.ErrNotExist) {
250+
// loadChecksum stays nil — sentinel for "file absent at load time".
241251
return &StackFile{
242252
SchemaVersion: schemaVersion,
243253
Stacks: []Stack{},
@@ -255,24 +265,32 @@ func Load(gitDir string) (*StackFile, error) {
255265
return nil, fmt.Errorf("stack file has schema version %d, but this version of gh-stack only supports up to version %d — please upgrade gh-stack", sf.SchemaVersion, schemaVersion)
256266
}
257267

268+
sum := sha256.Sum256(data)
269+
sf.loadChecksum = sum[:]
258270
return &sf, nil
259271
}
260272

261-
// Save acquires an exclusive lock on the stack file, writes sf as JSON, and
262-
// releases the lock. The lock is held only for the duration of the write.
263-
// Returns *LockError if the lock times out due to contention.
273+
// Save acquires an exclusive lock on the stack file, verifies the file hasn't
274+
// been modified since Load (optimistic concurrency), writes sf as JSON, and
275+
// releases the lock. The lock is held only for the read-compare-write window.
276+
// Returns *LockError if the lock times out, or *StaleError if another process
277+
// modified the file since it was loaded.
264278
func Save(gitDir string, sf *StackFile) error {
265279
lock, err := Lock(gitDir)
266280
if err != nil {
267281
return err // *LockError for contention, plain error for I/O failures
268282
}
269283
defer lock.Unlock()
284+
285+
if err := checkStale(gitDir, sf); err != nil {
286+
return err
287+
}
270288
return writeStackFile(gitDir, sf)
271289
}
272290

273291
// SaveNonBlocking attempts to save without blocking. If another process holds
274-
// the lock, the save is silently skipped. Use this for best-effort metadata
275-
// persistence (e.g. syncing PR state in view).
292+
// the lock or the file was modified since Load, the save is silently skipped.
293+
// Use this for best-effort metadata persistence (e.g. syncing PR state in view).
276294
func SaveNonBlocking(gitDir string, sf *StackFile) {
277295
path := filepath.Join(gitDir, lockFileName)
278296
f, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR, 0644)
@@ -285,14 +303,46 @@ func SaveNonBlocking(gitDir string, sf *StackFile) {
285303
}
286304
lock := &FileLock{f: f}
287305
defer lock.Unlock()
306+
307+
if checkStale(gitDir, sf) != nil {
308+
return
309+
}
288310
_ = writeStackFile(gitDir, sf)
289311
}
290312

291-
// SaveLocked writes the stack file without acquiring the lock. The caller
292-
// must already hold the lock (via Lock) to protect the write. Use this when
293-
// you need an atomic Load-Modify-Save sequence.
294-
func SaveLocked(gitDir string, sf *StackFile) error {
295-
return writeStackFile(gitDir, sf)
313+
// checkStale compares the current on-disk content against the checksum
314+
// captured at Load time. Returns *StaleError if the file was modified
315+
// by another process. The caller must hold the lock.
316+
func checkStale(gitDir string, sf *StackFile) error {
317+
path := stackFilePath(gitDir)
318+
data, err := os.ReadFile(path)
319+
320+
if errors.Is(err, os.ErrNotExist) {
321+
// File absent on disk.
322+
if sf.loadChecksum == nil {
323+
return nil // was absent at Load time too — no conflict
324+
}
325+
// File existed at Load but is now gone. Allow the write to
326+
// recreate it rather than erroring; this is not a lost-update.
327+
return nil
328+
}
329+
if err != nil {
330+
return fmt.Errorf("reading stack file for staleness check: %w", err)
331+
}
332+
333+
// File exists on disk.
334+
if sf.loadChecksum == nil {
335+
// File was absent at Load but another process created it.
336+
return &StaleError{Err: fmt.Errorf(
337+
"stack file was created by another process since it was loaded")}
338+
}
339+
340+
sum := sha256.Sum256(data)
341+
if !bytes.Equal(sf.loadChecksum, sum[:]) {
342+
return &StaleError{Err: fmt.Errorf(
343+
"stack file was modified by another process since it was loaded")}
344+
}
345+
return nil
296346
}
297347

298348
func writeStackFile(gitDir string, sf *StackFile) error {
@@ -308,5 +358,9 @@ func writeStackFile(gitDir string, sf *StackFile) error {
308358
if err := os.WriteFile(path, data, 0644); err != nil {
309359
return fmt.Errorf("writing stack file: %w", err)
310360
}
361+
// Refresh checksum so a second Save on the same StackFile doesn't
362+
// spuriously fail the staleness check.
363+
sum := sha256.Sum256(data)
364+
sf.loadChecksum = sum[:]
311365
return nil
312366
}

0 commit comments

Comments
 (0)