Skip to content

Commit ffff69f

Browse files
authored
Merge pull request #29 from github/adds-panic-handler
Extends the Function Stage to Handle Panics
2 parents 951775d + c0d12d7 commit ffff69f

4 files changed

Lines changed: 110 additions & 32 deletions

File tree

pipe/function.go

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -32,29 +32,41 @@ func Function(name string, f StageFunc) Stage {
3232
// goStage is a `Stage` that does its work by running an arbitrary
3333
// `stageFunc` in a goroutine.
3434
type goStage struct {
35-
name string
36-
f StageFunc
37-
done chan struct{}
38-
err error
35+
name string
36+
f StageFunc
37+
done chan struct{}
38+
err error
39+
panicHandler StagePanicHandler
3940
}
4041

4142
func (s *goStage) Name() string {
4243
return s.name
4344
}
4445

46+
func (s *goStage) SetPanicHandler(ph StagePanicHandler) {
47+
s.panicHandler = ph
48+
}
49+
4550
func (s *goStage) Start(ctx context.Context, env Env, stdin io.ReadCloser) (io.ReadCloser, error) {
4651
r, w := io.Pipe()
52+
4753
go func() {
48-
s.err = s.f(ctx, env, stdin, w)
49-
if err := w.Close(); err != nil && s.err == nil {
50-
s.err = fmt.Errorf("error closing output pipe for stage %q: %w", s.Name(), err)
51-
}
52-
if stdin != nil {
53-
if err := stdin.Close(); err != nil && s.err == nil {
54-
s.err = fmt.Errorf("error closing stdin for stage %q: %w", s.Name(), err)
54+
defer func() {
55+
// Cleanup resources on exit
56+
if err := w.Close(); err != nil && s.err == nil {
57+
s.err = fmt.Errorf("error closing output pipe for stage %q: %w", s.Name(), err)
5558
}
56-
}
57-
close(s.done)
59+
if stdin != nil {
60+
if err := stdin.Close(); err != nil && s.err == nil {
61+
s.err = fmt.Errorf("error closing stdin for stage %q: %w", s.Name(), err)
62+
}
63+
}
64+
close(s.done)
65+
}()
66+
67+
defer s.recoverPanic()
68+
69+
s.err = s.f(ctx, env, stdin, w)
5870
}()
5971

6072
return r, nil
@@ -64,3 +76,13 @@ func (s *goStage) Wait() error {
6476
<-s.done
6577
return s.err
6678
}
79+
80+
func (s *goStage) recoverPanic() {
81+
if s.panicHandler == nil {
82+
return
83+
}
84+
85+
if p := recover(); p != nil {
86+
s.err = s.panicHandler(p)
87+
}
88+
}

pipe/panic.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
package pipe
2+
3+
// StagePanicHandlerAware is an interface that Stages can implement to receive
4+
// a panic handler from the pipeline. This is particularly useful for stages
5+
// that execute work in a separate goroutine and need to manage panics occurring
6+
// within that goroutine.
7+
type StagePanicHandlerAware interface {
8+
SetPanicHandler(StagePanicHandler)
9+
}
10+
11+
// StagePanicHandler is a function that handles panics in a pipeline's stages.
12+
type StagePanicHandler func(p any) error

pipe/pipeline.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ type Pipeline struct {
6464
started uint32
6565

6666
eventHandler func(e *Event)
67+
panicHandler StagePanicHandler
6768
}
6869

6970
var emptyEventHandler = func(e *Event) {}
@@ -179,6 +180,20 @@ func WithEventHandler(handler func(e *Event)) Option {
179180
}
180181
}
181182

183+
// WithStagePanicHandler sets a panic handler for the stages within a pipeline.
184+
// When a pipeline stage panics, the provided handler will be invoked, allowing
185+
// the client to handle the panic in whatever way they see fit.
186+
//
187+
// Note:
188+
// - Only the Function stage supports this functionality.
189+
// - The client is responsible for deciding whether to recover from the panic or panicking again.
190+
// - If a panic handler is not set, the panic will be propagated normally.
191+
func WithStagePanicHandler(ph StagePanicHandler) Option {
192+
return func(p *Pipeline) {
193+
p.panicHandler = ph
194+
}
195+
}
196+
182197
func (p *Pipeline) hasStarted() bool {
183198
return atomic.LoadUint32(&p.started) != 0
184199
}
@@ -265,6 +280,10 @@ func (p *Pipeline) Start(ctx context.Context) error {
265280
}
266281

267282
for i, s := range p.stages {
283+
if phs, ok := s.(StagePanicHandlerAware); ok && p.panicHandler != nil {
284+
phs.SetPanicHandler(p.panicHandler)
285+
}
286+
268287
var err error
269288
stdout, err := s.Start(ctx, p.env, nextStdin)
270289
if err != nil {

pipe/pipeline_test.go

Lines changed: 44 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -436,28 +436,53 @@ func TestFunction(t *testing.T) {
436436

437437
dir := t.TempDir()
438438

439-
p := pipe.New(pipe.WithDir(dir))
440-
p.Add(
441-
pipe.Print("hello world"),
442-
pipe.Function(
443-
"farewell",
444-
func(_ context.Context, _ pipe.Env, stdin io.Reader, stdout io.Writer) error {
445-
buf, err := io.ReadAll(stdin)
446-
if err != nil {
439+
t.Run("successful function", func(t *testing.T) {
440+
p := pipe.New(pipe.WithDir(dir))
441+
p.Add(
442+
pipe.Print("hello world"),
443+
pipe.Function(
444+
"farewell",
445+
func(_ context.Context, _ pipe.Env, stdin io.Reader, stdout io.Writer) error {
446+
buf, err := io.ReadAll(stdin)
447+
if err != nil {
448+
return err
449+
}
450+
if string(buf) != "hello world" {
451+
return fmt.Errorf("expected \"hello world\"; got %q", string(buf))
452+
}
453+
_, err = stdout.Write([]byte("goodbye, cruel world"))
447454
return err
448-
}
449-
if string(buf) != "hello world" {
450-
return fmt.Errorf("expected \"hello world\"; got %q", string(buf))
451-
}
452-
_, err = stdout.Write([]byte("goodbye, cruel world"))
455+
},
456+
),
457+
)
458+
459+
out, err := p.Output(ctx)
460+
assert.NoError(t, err)
461+
assert.EqualValues(t, "goodbye, cruel world", out)
462+
})
463+
464+
t.Run("panic with handler", func(t *testing.T) {
465+
p := pipe.New(
466+
pipe.WithDir(dir),
467+
pipe.WithStagePanicHandler(func(p any) error {
468+
err := fmt.Errorf("panic handled: %v", p)
453469
return err
454-
},
455-
),
456-
)
470+
}),
471+
)
472+
p.Add(
473+
pipe.Print("hello world"),
474+
pipe.Function(
475+
"farewell",
476+
func(_ context.Context, _ pipe.Env, stdin io.Reader, stdout io.Writer) error {
477+
panic("this is a panic")
478+
},
479+
),
480+
)
457481

458-
out, err := p.Output(ctx)
459-
assert.NoError(t, err)
460-
assert.EqualValues(t, "goodbye, cruel world", out)
482+
out, err := p.Output(ctx)
483+
assert.ErrorContains(t, err, "panic handled")
484+
assert.Empty(t, out)
485+
})
461486
}
462487

463488
func TestPipelineWithFunction(t *testing.T) {

0 commit comments

Comments
 (0)