Skip to content

Commit 806ff7c

Browse files
committed
SendMessage status validation
1 parent 782cba9 commit 806ff7c

2 files changed

Lines changed: 72 additions & 33 deletions

File tree

lib/screentracker/conversation.go

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,12 @@ type ConversationConfig struct {
3333
// Function to format the messages received from the agent
3434
// userInput is the last user message
3535
FormatMessage func(message string, userInput string) string
36-
// SkipWritingMessage skips the writing of a message to the agent
36+
// SkipWritingMessage skips the writing of a message to the agent.
3737
// This is used in tests
3838
SkipWritingMessage bool
39+
// SkipSendMessageStatusCheck skips the check for whether the message can be sent.
40+
// This is used in tests
41+
SkipSendMessageStatusCheck bool
3942
}
4043

4144
type ConversationRole string
@@ -287,14 +290,26 @@ func (c *Conversation) writeMessageWithConfirmation(ctx context.Context, message
287290
return nil
288291
}
289292

293+
var MessageValidationErrorWhitespace = xerrors.New("message must be trimmed of leading and trailing whitespace")
294+
var MessageValidationErrorEmpty = xerrors.New("message must not be empty")
295+
var MessageValidationErrorChanging = xerrors.New("message can only be sent when the agent is waiting for user input")
296+
290297
func (c *Conversation) SendMessage(messageParts ...MessagePart) error {
291298
c.lock.Lock()
292299
defer c.lock.Unlock()
293300

301+
if !c.cfg.SkipSendMessageStatusCheck && c.statusInner() != ConversationStatusStable {
302+
return MessageValidationErrorChanging
303+
}
304+
294305
message := PartsToString(messageParts...)
295306
if message != msgfmt.TrimWhitespace(message) {
296307
// msgfmt formatting functions assume this
297-
return xerrors.Errorf("message must be trimmed of leading and trailing whitespace")
308+
return MessageValidationErrorWhitespace
309+
}
310+
if message == "" {
311+
// writeMessageWithConfirmation requires a non-empty message
312+
return MessageValidationErrorEmpty
298313
}
299314

300315
screenBeforeMessage := c.cfg.AgentIO.ReadScreen()
@@ -315,10 +330,8 @@ func (c *Conversation) SendMessage(messageParts ...MessagePart) error {
315330
return nil
316331
}
317332

318-
func (c *Conversation) Status() ConversationStatus {
319-
c.lock.Lock()
320-
defer c.lock.Unlock()
321-
333+
// Assumes that the caller holds the lock
334+
func (c *Conversation) statusInner() ConversationStatus {
322335
// sanity checks
323336
if c.snapshotBuffer.Capacity() != c.stableSnapshotsThreshold {
324337
panic(fmt.Sprintf("snapshot buffer capacity %d is not equal to snapshot threshold %d. can't check stability", c.snapshotBuffer.Capacity(), c.stableSnapshotsThreshold))
@@ -347,6 +360,13 @@ func (c *Conversation) Status() ConversationStatus {
347360
return ConversationStatusStable
348361
}
349362

363+
func (c *Conversation) Status() ConversationStatus {
364+
c.lock.Lock()
365+
defer c.lock.Unlock()
366+
367+
return c.statusInner()
368+
}
369+
350370
func (c *Conversation) Messages() []ConversationMessage {
351371
c.lock.Lock()
352372
defer c.lock.Unlock()

lib/screentracker/conversation_test.go

Lines changed: 46 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -135,22 +135,22 @@ func TestMessages(t *testing.T) {
135135
sendMsg := func(c *st.Conversation, msg string) error {
136136
return c.SendMessage(st.MessagePartText{Content: msg})
137137
}
138-
newConversation := func(cfg st.ConversationConfig) *st.Conversation {
139-
if cfg.GetTime == nil {
140-
cfg.GetTime = func() time.Time { return now }
138+
newConversation := func(opts ...func(*st.ConversationConfig)) *st.Conversation {
139+
cfg := st.ConversationConfig{
140+
GetTime: func() time.Time { return now },
141+
SnapshotInterval: 1 * time.Second,
142+
ScreenStabilityLength: 2 * time.Second,
143+
SkipWritingMessage: true,
144+
SkipSendMessageStatusCheck: true,
141145
}
142-
if cfg.SnapshotInterval == 0 {
143-
cfg.SnapshotInterval = 1 * time.Second
144-
}
145-
if cfg.ScreenStabilityLength == 0 {
146-
cfg.ScreenStabilityLength = 2 * time.Second
146+
for _, opt := range opts {
147+
opt(&cfg)
147148
}
148-
cfg.SkipWritingMessage = true
149149
return st.NewConversation(context.Background(), cfg)
150150
}
151151

152152
t.Run("messages are copied", func(t *testing.T) {
153-
c := newConversation(st.ConversationConfig{})
153+
c := newConversation()
154154
messages := c.Messages()
155155
assert.Equal(t, []st.ConversationMessage{
156156
agentMsg(0, ""),
@@ -164,11 +164,10 @@ func TestMessages(t *testing.T) {
164164
})
165165

166166
t.Run("whitespace-padding", func(t *testing.T) {
167-
c := newConversation(st.ConversationConfig{})
167+
c := newConversation()
168168
for _, msg := range []string{"123 ", " 123", "123\t\t", "\n123", "123\n\t", " \t123\n\t"} {
169169
err := c.SendMessage(st.MessagePartText{Content: msg})
170-
assert.Error(t, err)
171-
assert.Contains(t, err.Error(), "message must be trimmed of leading and trailing whitespace")
170+
assert.Error(t, err, st.MessageValidationErrorWhitespace)
172171
}
173172
})
174173

@@ -178,8 +177,8 @@ func TestMessages(t *testing.T) {
178177
}{
179178
Time: now,
180179
}
181-
c := newConversation(st.ConversationConfig{
182-
GetTime: func() time.Time { return nowWrapper.Time },
180+
c := newConversation(func(cfg *st.ConversationConfig) {
181+
cfg.GetTime = func() time.Time { return nowWrapper.Time }
183182
})
184183

185184
c.AddSnapshot("1")
@@ -194,8 +193,8 @@ func TestMessages(t *testing.T) {
194193

195194
t.Run("tracking messages", func(t *testing.T) {
196195
agent := &testAgent{}
197-
c := newConversation(st.ConversationConfig{
198-
AgentIO: agent,
196+
c := newConversation(func(cfg *st.ConversationConfig) {
197+
cfg.AgentIO = agent
199198
})
200199
// agent message is recorded when the first snapshot is added
201200
c.AddSnapshot("1")
@@ -260,8 +259,8 @@ func TestMessages(t *testing.T) {
260259

261260
t.Run("tracking messages overlap", func(t *testing.T) {
262261
agent := &testAgent{}
263-
c := newConversation(st.ConversationConfig{
264-
AgentIO: agent,
262+
c := newConversation(func(cfg *st.ConversationConfig) {
263+
cfg.AgentIO = agent
265264
})
266265

267266
// common overlap between screens is removed after a user message
@@ -289,11 +288,11 @@ func TestMessages(t *testing.T) {
289288

290289
t.Run("format-message", func(t *testing.T) {
291290
agent := &testAgent{}
292-
c := newConversation(st.ConversationConfig{
293-
AgentIO: agent,
294-
FormatMessage: func(message string, userInput string) string {
291+
c := newConversation(func(cfg *st.ConversationConfig) {
292+
cfg.AgentIO = agent
293+
cfg.FormatMessage = func(message string, userInput string) string {
295294
return message + " " + userInput
296-
},
295+
}
297296
})
298297
agent.screen = "1"
299298
assert.NoError(t, sendMsg(c, "2"))
@@ -312,11 +311,11 @@ func TestMessages(t *testing.T) {
312311

313312
t.Run("format-message", func(t *testing.T) {
314313
agent := &testAgent{}
315-
c := newConversation(st.ConversationConfig{
316-
AgentIO: agent,
317-
FormatMessage: func(message string, userInput string) string {
314+
c := newConversation(func(cfg *st.ConversationConfig) {
315+
cfg.AgentIO = agent
316+
cfg.FormatMessage = func(message string, userInput string) string {
318317
return "formatted"
319-
},
318+
}
320319
})
321320
assert.Equal(t, []st.ConversationMessage{
322321
{
@@ -326,7 +325,27 @@ func TestMessages(t *testing.T) {
326325
Time: now,
327326
},
328327
}, c.Messages())
328+
})
329+
330+
t.Run("send-message-status-check", func(t *testing.T) {
331+
c := newConversation(func(cfg *st.ConversationConfig) {
332+
cfg.SkipSendMessageStatusCheck = false
333+
cfg.SnapshotInterval = 1 * time.Second
334+
cfg.ScreenStabilityLength = 2 * time.Second
335+
cfg.AgentIO = &testAgent{}
336+
})
337+
assert.Error(t, sendMsg(c, "1"), st.MessageValidationErrorChanging)
338+
for range 3 {
339+
c.AddSnapshot("1")
340+
}
341+
assert.NoError(t, sendMsg(c, "4"))
342+
c.AddSnapshot("2")
343+
assert.Error(t, sendMsg(c, "5"), st.MessageValidationErrorChanging)
344+
})
329345

346+
t.Run("send-message-empty-message", func(t *testing.T) {
347+
c := newConversation()
348+
assert.Error(t, sendMsg(c, ""), st.MessageValidationErrorEmpty)
330349
})
331350
}
332351

0 commit comments

Comments
 (0)