Skip to content

Commit

Permalink
fix: slacktest GetSeenOutboundMessages race condition
Browse files Browse the repository at this point in the history
This patch addresses issue slack-go#1361.

The storage backing GetSeenOutboundMessages is updated in a goroutine
that may or may not be executed in time for assertions against this
method. This patch updates the storage to be updated synchronously in
the handler, while maintaining the asynchronous queue behavior for
websockets handlers.

The test case has been updated to remove the sleep statement that likely
worked around this very issue.

I opted to change the locking behavior to be more closely related with
the messageCollection type. There is a smaller version of this fix that
move the lock/update/unlock block into the callsite of each queue
update, if that is preferable.
  • Loading branch information
askreet committed Dec 13, 2024
1 parent e764011 commit eef3a5d
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 49 deletions.
10 changes: 3 additions & 7 deletions slacktest/funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,9 @@ func (sts *Server) queueForWebsocket(s, hubname string) {
channel, err := getHubForServer(hubname)
if err != nil {
log.Printf("Unable to get server's channels: %s", err.Error())
} else {
channel.sent <- s
}
sts.seenOutboundMessages.Lock()
sts.seenOutboundMessages.messages = append(sts.seenOutboundMessages.messages, s)
sts.seenOutboundMessages.Unlock()
channel.sent <- s
}

func handlePendingMessages(c *websocket.Conn, hubname string) {
Expand All @@ -43,9 +41,7 @@ func (sts *Server) postProcessMessage(m, hubname string) {
log.Printf("Unable to get server's channels: %s", err.Error())
return
}
sts.seenInboundMessages.Lock()
sts.seenInboundMessages.messages = append(sts.seenInboundMessages.messages, m)
sts.seenInboundMessages.Unlock()
sts.seenInboundMessages.observe(m)
// send to firehose
channel.seen <- m
}
Expand Down
53 changes: 27 additions & 26 deletions slacktest/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,53 +106,45 @@ func (sts *Server) GetGroups() []slack.Group {

// GetSeenInboundMessages returns all messages seen via websocket excluding pings
func (sts *Server) GetSeenInboundMessages() []string {
sts.seenInboundMessages.RLock()
m := sts.seenInboundMessages.messages
sts.seenInboundMessages.RUnlock()
return m
return sts.seenInboundMessages.get()
}

// GetSeenOutboundMessages returns all messages seen via websocket excluding pings
func (sts *Server) GetSeenOutboundMessages() []string {
sts.seenOutboundMessages.RLock()
m := sts.seenOutboundMessages.messages
sts.seenOutboundMessages.RUnlock()
return m
return sts.seenOutboundMessages.get()
}

// SawOutgoingMessage checks if a message was sent to connected websocket clients
func (sts *Server) SawOutgoingMessage(msg string) bool {
sts.seenOutboundMessages.RLock()
defer sts.seenOutboundMessages.RUnlock()
for _, m := range sts.seenOutboundMessages.messages {
for _, m := range sts.seenOutboundMessages.get() {
evt := &slack.MessageEvent{}
jErr := json.Unmarshal([]byte(m), evt)
if jErr != nil {
err := json.Unmarshal([]byte(m), evt)
if err != nil {
continue
}

if evt.Text == msg {
return true
}
}

return false
}

// SawMessage checks if an incoming message was seen
func (sts *Server) SawMessage(msg string) bool {
sts.seenInboundMessages.RLock()
defer sts.seenInboundMessages.RUnlock()
for _, m := range sts.seenInboundMessages.messages {
for _, m := range sts.seenInboundMessages.get() {
evt := &slack.MessageEvent{}
jErr := json.Unmarshal([]byte(m), evt)
if jErr != nil {
err := json.Unmarshal([]byte(m), evt)
if err != nil {
// This event isn't a message event so we'll skip it
continue
}
if evt.Text == msg {
return true
}
}

return false
}

Expand Down Expand Up @@ -184,11 +176,14 @@ func (sts *Server) SendMessageToBot(channel, msg string) {
m.User = defaultNonBotUserID
m.Text = fmt.Sprintf("<@%s> %s", sts.BotID, msg)
m.Timestamp = fmt.Sprintf("%d", time.Now().Unix())
j, jErr := json.Marshal(m)
if jErr != nil {
log.Printf("Unable to marshal message for bot: %s", jErr.Error())

j, err := json.Marshal(m)
if err != nil {
log.Printf("Unable to marshal message for bot: %s", err.Error())
return
}

sts.seenOutboundMessages.observe(string(j))
go sts.queueForWebsocket(string(j), sts.ServerAddr)
}

Expand All @@ -200,11 +195,14 @@ func (sts *Server) SendDirectMessageToBot(msg string) {
m.User = defaultNonBotUserID
m.Text = msg
m.Timestamp = fmt.Sprintf("%d", time.Now().Unix())
j, jErr := json.Marshal(m)
if jErr != nil {
log.Printf("Unable to marshal private message for bot: %s", jErr.Error())

j, err := json.Marshal(m)
if err != nil {
log.Printf("Unable to marshal private message for bot: %s", err.Error())
return
}

sts.seenOutboundMessages.observe(string(j))
go sts.queueForWebsocket(string(j), sts.ServerAddr)
}

Expand All @@ -216,18 +214,21 @@ func (sts *Server) SendMessageToChannel(channel, msg string) {
m.Text = msg
m.User = defaultNonBotUserID
m.Timestamp = fmt.Sprintf("%d", time.Now().Unix())

j, jErr := json.Marshal(m)
if jErr != nil {
log.Printf("Unable to marshal message for channel: %s", jErr.Error())
return
}
stringMsg := string(j)
go sts.queueForWebsocket(stringMsg, sts.ServerAddr)

sts.seenOutboundMessages.observe(string(j))
go sts.queueForWebsocket(string(j), sts.ServerAddr)
}

// SendToWebsocket send `s` as is to connected clients.
// This is useful for sending your own custom json to the websocket
func (sts *Server) SendToWebsocket(s string) {
sts.seenOutboundMessages.observe(s)
go sts.queueForWebsocket(s, sts.ServerAddr)
}

Expand Down
24 changes: 12 additions & 12 deletions slacktest/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func TestCustomNewServer(t *testing.T) {

func TestServerSendMessageToChannel(t *testing.T) {
s := NewTestServer()
go s.Start()
s.Start()
s.SendMessageToChannel("C123456789", "some text")
time.Sleep(2 * time.Second)
assert.True(t, s.SawOutgoingMessage("some text"))
Expand All @@ -36,7 +36,7 @@ func TestServerSendMessageToChannel(t *testing.T) {

func TestServerSendMessageToBot(t *testing.T) {
s := NewTestServer()
go s.Start()
s.Start()
s.SendMessageToBot("C123456789", "some text")
expectedMsg := fmt.Sprintf("<@%s> %s", s.BotID, "some text")
time.Sleep(2 * time.Second)
Expand All @@ -46,7 +46,7 @@ func TestServerSendMessageToBot(t *testing.T) {

func TestBotDirectMessageBotHandler(t *testing.T) {
s := NewTestServer()
go s.Start()
s.Start()
s.SendDirectMessageToBot("some text")
expectedMsg := "some text"
time.Sleep(2 * time.Second)
Expand All @@ -55,14 +55,14 @@ func TestBotDirectMessageBotHandler(t *testing.T) {
}

func TestGetSeenOutboundMessages(t *testing.T) {
maxWait := 5 * time.Second
s := NewTestServer()
go s.Start()
s.Start()

s.SendMessageToChannel("foo", "should see this message")
time.Sleep(maxWait)

seenOutbound := s.GetSeenOutboundMessages()
assert.True(t, len(seenOutbound) > 0)
assert.Len(t, seenOutbound, 1)

hadMessage := false
for _, msg := range seenOutbound {
var m = slack.Message{}
Expand All @@ -79,7 +79,7 @@ func TestGetSeenOutboundMessages(t *testing.T) {
func TestGetSeenInboundMessages(t *testing.T) {
maxWait := 5 * time.Second
s := NewTestServer()
go s.Start()
s.Start()

api := slack.New("ABCDEFG", slack.OptionAPIURL(s.GetAPIURL()))
rtm := api.NewRTM()
Expand Down Expand Up @@ -108,7 +108,7 @@ func TestGetSeenInboundMessages(t *testing.T) {
func TestSendChannelInvite(t *testing.T) {
maxWait := 5 * time.Second
s := NewTestServer()
go s.Start()
s.Start()
rtm := s.GetTestRTMInstance()
go rtm.ManageConnection()
evChan := make(chan (slack.Channel), 1)
Expand Down Expand Up @@ -137,7 +137,7 @@ func TestSendChannelInvite(t *testing.T) {
func TestSendGroupInvite(t *testing.T) {
maxWait := 5 * time.Second
s := NewTestServer()
go s.Start()
s.Start()
rtm := s.GetTestRTMInstance()
go rtm.ManageConnection()
evChan := make(chan (slack.Channel), 1)
Expand Down Expand Up @@ -165,12 +165,12 @@ func TestSendGroupInvite(t *testing.T) {

func TestServerSawMessage(t *testing.T) {
s := NewTestServer()
go s.Start()
s.Start()
assert.False(t, s.SawMessage("foo"), "should not have seen any message")
}

func TestServerSawOutgoingMessage(t *testing.T) {
s := NewTestServer()
go s.Start()
s.Start()
assert.False(t, s.SawOutgoingMessage("foo"), "should not have seen any message")
}
22 changes: 18 additions & 4 deletions slacktest/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,29 @@ type hub struct {
}

type messageChannels struct {
seen chan (string)
sent chan (string)
posted chan (slack.Message)
seen chan string
sent chan string
posted chan slack.Message
}
type messageCollection struct {
sync.RWMutex
messages []string
}

func (mc *messageCollection) observe(msg string) {
mc.Lock()
defer mc.Unlock()
mc.messages = append(mc.messages, msg)
}

func (mc *messageCollection) get() []string {
mc.RLock()
defer mc.RUnlock()

m := mc.messages
return m
}

type serverChannels struct {
sync.RWMutex
channels []slack.Channel
Expand All @@ -68,7 +82,7 @@ type Server struct {
BotName string
BotID string
ServerAddr string
SeenFeed chan (string)
SeenFeed chan string
channels *serverChannels
groups *serverGroups
seenInboundMessages *messageCollection
Expand Down

0 comments on commit eef3a5d

Please sign in to comment.