mirror of
https://github.com/superseriousbusiness/gotosocial.git
synced 2025-10-29 15:42:24 -05:00
[bugfix] fix possible mutex lockup during streaming code (#2633)
* rewrite Stream{} to use much less mutex locking, update related code
* use new context for the stream context
* ensure stream gets closed on return of writeTo / readFrom WSConn()
* ensure stream write timeout gets cancelled
* remove embedded context type from Stream{}, reformat log messages for consistency
* use c.Request.Context() for context passed into Stream().Open()
* only return 1 boolean, fix tests to expect multiple stream types in messages
* changes to ping logic
* further improved ping logic
* don't export unused function types, update message sending to only include relevant stream type
* ensure stream gets closed 🤦
* update to error log on failed json marshal (instead of panic)
* inverse websocket read error checking to _ignore_ expected close errors
This commit is contained in:
parent
8cafa6b74b
commit
291e180990
14 changed files with 535 additions and 451 deletions
|
|
@ -18,38 +18,16 @@
|
|||
package stream
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"context"
|
||||
|
||||
"github.com/superseriousbusiness/gotosocial/internal/stream"
|
||||
)
|
||||
|
||||
// Delete streams the delete of the given statusID to *ALL* open streams.
|
||||
func (p *Processor) Delete(statusID string) error {
|
||||
errs := []string{}
|
||||
|
||||
// get all account IDs with open streams
|
||||
accountIDs := []string{}
|
||||
p.streamMap.Range(func(k interface{}, _ interface{}) bool {
|
||||
key, ok := k.(string)
|
||||
if !ok {
|
||||
panic("streamMap key was not a string (account id)")
|
||||
}
|
||||
|
||||
accountIDs = append(accountIDs, key)
|
||||
return true
|
||||
func (p *Processor) Delete(ctx context.Context, statusID string) {
|
||||
p.streams.PostAll(ctx, stream.Message{
|
||||
Payload: statusID,
|
||||
Event: stream.EventTypeDelete,
|
||||
Stream: stream.AllStatusTimelines,
|
||||
})
|
||||
|
||||
// stream the delete to every account
|
||||
for _, accountID := range accountIDs {
|
||||
if err := p.toAccount(statusID, stream.EventTypeDelete, stream.AllStatusTimelines, accountID); err != nil {
|
||||
errs = append(errs, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
if len(errs) != 0 {
|
||||
return fmt.Errorf("one or more errors streaming status delete: %s", strings.Join(errs, ";"))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,20 +18,29 @@
|
|||
package stream
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"codeberg.org/gruf/go-byteutil"
|
||||
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/log"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/stream"
|
||||
)
|
||||
|
||||
// Notify streams the given notification to any open, appropriate streams belonging to the given account.
|
||||
func (p *Processor) Notify(n *apimodel.Notification, account *gtsmodel.Account) error {
|
||||
bytes, err := json.Marshal(n)
|
||||
func (p *Processor) Notify(ctx context.Context, account *gtsmodel.Account, notif *apimodel.Notification) {
|
||||
b, err := json.Marshal(notif)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error marshalling notification to json: %s", err)
|
||||
log.Errorf(ctx, "error marshaling json: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
return p.toAccount(string(bytes), stream.EventTypeNotification, []string{stream.TimelineNotifications, stream.TimelineHome}, account.ID)
|
||||
p.streams.Post(ctx, account.ID, stream.Message{
|
||||
Payload: byteutil.B2S(b),
|
||||
Event: stream.EventTypeNotification,
|
||||
Stream: []string{
|
||||
stream.TimelineNotifications,
|
||||
stream.TimelineHome,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -49,10 +49,11 @@ func (suite *NotificationTestSuite) TestStreamNotification() {
|
|||
Account: followAccountAPIModel,
|
||||
}
|
||||
|
||||
err = suite.streamProcessor.Notify(notification, account)
|
||||
suite.NoError(err)
|
||||
suite.streamProcessor.Notify(context.Background(), account, notification)
|
||||
|
||||
msg, ok := openStream.Recv(context.Background())
|
||||
suite.True(ok)
|
||||
|
||||
msg := <-openStream.Messages
|
||||
dst := new(bytes.Buffer)
|
||||
err = json.Indent(dst, []byte(msg.Payload), "", " ")
|
||||
suite.NoError(err)
|
||||
|
|
|
|||
|
|
@ -19,13 +19,10 @@ package stream
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"codeberg.org/gruf/go-kv"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/id"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/log"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/stream"
|
||||
)
|
||||
|
|
@ -37,97 +34,5 @@ func (p *Processor) Open(ctx context.Context, account *gtsmodel.Account, streamT
|
|||
{"streamType", streamType},
|
||||
}...)
|
||||
l.Debug("received open stream request")
|
||||
|
||||
var (
|
||||
streamID string
|
||||
err error
|
||||
)
|
||||
|
||||
// Each stream needs a unique ID so we know to close it.
|
||||
streamID, err = id.NewRandomULID()
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error generating stream id: %w", err))
|
||||
}
|
||||
|
||||
// Each stream can be subscibed to multiple types.
|
||||
// Record them in a set, and include the initial one
|
||||
// if it was given to us.
|
||||
streamTypes := map[string]any{}
|
||||
if streamType != "" {
|
||||
streamTypes[streamType] = true
|
||||
}
|
||||
|
||||
newStream := &stream.Stream{
|
||||
ID: streamID,
|
||||
StreamTypes: streamTypes,
|
||||
Messages: make(chan *stream.Message, 100),
|
||||
Hangup: make(chan interface{}, 1),
|
||||
Connected: true,
|
||||
}
|
||||
go p.waitToCloseStream(account, newStream)
|
||||
|
||||
v, ok := p.streamMap.Load(account.ID)
|
||||
if ok {
|
||||
// There is an entry in the streamMap
|
||||
// for this account. Parse it out.
|
||||
streamsForAccount, ok := v.(*stream.StreamsForAccount)
|
||||
if !ok {
|
||||
return nil, gtserror.NewErrorInternalError(errors.New("stream map error"))
|
||||
}
|
||||
|
||||
// Append new stream to existing entry.
|
||||
streamsForAccount.Lock()
|
||||
streamsForAccount.Streams = append(streamsForAccount.Streams, newStream)
|
||||
streamsForAccount.Unlock()
|
||||
} else {
|
||||
// There is no entry in the streamMap for
|
||||
// this account yet. Create one and store it.
|
||||
p.streamMap.Store(account.ID, &stream.StreamsForAccount{
|
||||
Streams: []*stream.Stream{
|
||||
newStream,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return newStream, nil
|
||||
}
|
||||
|
||||
// waitToCloseStream waits until the hangup channel is closed for the given stream.
|
||||
// It then iterates through the map of streams stored by the processor, removes the stream from it,
|
||||
// and then closes the messages channel of the stream to indicate that the channel should no longer be read from.
|
||||
func (p *Processor) waitToCloseStream(account *gtsmodel.Account, thisStream *stream.Stream) {
|
||||
<-thisStream.Hangup // wait for a hangup message
|
||||
|
||||
// lock the stream to prevent more messages being put in it while we work
|
||||
thisStream.Lock()
|
||||
defer thisStream.Unlock()
|
||||
|
||||
// indicate the stream is no longer connected
|
||||
thisStream.Connected = false
|
||||
|
||||
// load and parse the entry for this account from the stream map
|
||||
v, ok := p.streamMap.Load(account.ID)
|
||||
if !ok || v == nil {
|
||||
return
|
||||
}
|
||||
streamsForAccount, ok := v.(*stream.StreamsForAccount)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// lock the streams for account while we remove this stream from its slice
|
||||
streamsForAccount.Lock()
|
||||
defer streamsForAccount.Unlock()
|
||||
|
||||
// put everything into modified streams *except* the stream we're removing
|
||||
modifiedStreams := []*stream.Stream{}
|
||||
for _, s := range streamsForAccount.Streams {
|
||||
if s.ID != thisStream.ID {
|
||||
modifiedStreams = append(modifiedStreams, s)
|
||||
}
|
||||
}
|
||||
streamsForAccount.Streams = modifiedStreams
|
||||
|
||||
// finally close the messages channel so no more messages can be read from it
|
||||
close(thisStream.Messages)
|
||||
return p.streams.Open(account.ID, streamType), nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,21 +18,26 @@
|
|||
package stream
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"codeberg.org/gruf/go-byteutil"
|
||||
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/log"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/stream"
|
||||
)
|
||||
|
||||
// StatusUpdate streams the given edited status to any open, appropriate
|
||||
// streams belonging to the given account.
|
||||
func (p *Processor) StatusUpdate(s *apimodel.Status, account *gtsmodel.Account, streamTypes []string) error {
|
||||
bytes, err := json.Marshal(s)
|
||||
// StatusUpdate streams the given edited status to any open, appropriate streams belonging to the given account.
|
||||
func (p *Processor) StatusUpdate(ctx context.Context, account *gtsmodel.Account, status *apimodel.Status, streamType string) {
|
||||
b, err := json.Marshal(status)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error marshalling status to json: %s", err)
|
||||
log.Errorf(ctx, "error marshaling json: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
return p.toAccount(string(bytes), stream.EventTypeStatusUpdate, streamTypes, account.ID)
|
||||
p.streams.Post(ctx, account.ID, stream.Message{
|
||||
Payload: byteutil.B2S(b),
|
||||
Event: stream.EventTypeStatusUpdate,
|
||||
Stream: []string{streamType},
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -42,10 +42,11 @@ func (suite *StatusUpdateTestSuite) TestStreamNotification() {
|
|||
apiStatus, err := typeutils.NewConverter(&suite.state).StatusToAPIStatus(context.Background(), editedStatus, account)
|
||||
suite.NoError(err)
|
||||
|
||||
err = suite.streamProcessor.StatusUpdate(apiStatus, account, []string{stream.TimelineHome})
|
||||
suite.NoError(err)
|
||||
suite.streamProcessor.StatusUpdate(context.Background(), account, apiStatus, stream.TimelineHome)
|
||||
|
||||
msg, ok := openStream.Recv(context.Background())
|
||||
suite.True(ok)
|
||||
|
||||
msg := <-openStream.Messages
|
||||
dst := new(bytes.Buffer)
|
||||
err = json.Indent(dst, []byte(msg.Payload), "", " ")
|
||||
suite.NoError(err)
|
||||
|
|
|
|||
|
|
@ -18,8 +18,6 @@
|
|||
package stream
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/superseriousbusiness/gotosocial/internal/oauth"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/state"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/stream"
|
||||
|
|
@ -28,53 +26,13 @@ import (
|
|||
type Processor struct {
|
||||
state *state.State
|
||||
oauthServer oauth.Server
|
||||
streamMap *sync.Map
|
||||
streams stream.Streams
|
||||
}
|
||||
|
||||
func New(state *state.State, oauthServer oauth.Server) Processor {
|
||||
return Processor{
|
||||
state: state,
|
||||
oauthServer: oauthServer,
|
||||
streamMap: &sync.Map{},
|
||||
streams: stream.Streams{},
|
||||
}
|
||||
}
|
||||
|
||||
// toAccount streams the given payload with the given event type to any streams currently open for the given account ID.
|
||||
func (p *Processor) toAccount(payload string, event string, streamTypes []string, accountID string) error {
|
||||
// Load all streams open for this account.
|
||||
v, ok := p.streamMap.Load(accountID)
|
||||
if !ok {
|
||||
return nil // No entry = nothing to stream.
|
||||
}
|
||||
streamsForAccount := v.(*stream.StreamsForAccount)
|
||||
|
||||
streamsForAccount.Lock()
|
||||
defer streamsForAccount.Unlock()
|
||||
|
||||
for _, s := range streamsForAccount.Streams {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
|
||||
if !s.Connected {
|
||||
continue
|
||||
}
|
||||
|
||||
typeLoop:
|
||||
for _, streamType := range streamTypes {
|
||||
if _, found := s.StreamTypes[streamType]; found {
|
||||
s.Messages <- &stream.Message{
|
||||
Stream: []string{streamType},
|
||||
Event: string(event),
|
||||
Payload: payload,
|
||||
}
|
||||
|
||||
// Break out to the outer loop,
|
||||
// to avoid sending duplicates of
|
||||
// the same event to the same stream.
|
||||
break typeLoop
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,20 +18,26 @@
|
|||
package stream
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"codeberg.org/gruf/go-byteutil"
|
||||
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/log"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/stream"
|
||||
)
|
||||
|
||||
// Update streams the given update to any open, appropriate streams belonging to the given account.
|
||||
func (p *Processor) Update(s *apimodel.Status, account *gtsmodel.Account, streamTypes []string) error {
|
||||
bytes, err := json.Marshal(s)
|
||||
func (p *Processor) Update(ctx context.Context, account *gtsmodel.Account, status *apimodel.Status, streamType string) {
|
||||
b, err := json.Marshal(status)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error marshalling status to json: %s", err)
|
||||
log.Errorf(ctx, "error marshaling json: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
return p.toAccount(string(bytes), stream.EventTypeUpdate, streamTypes, account.ID)
|
||||
p.streams.Post(ctx, account.ID, stream.Message{
|
||||
Payload: byteutil.B2S(b),
|
||||
Event: stream.EventTypeUpdate,
|
||||
Stream: []string{streamType},
|
||||
})
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue