mirror of
				https://github.com/superseriousbusiness/gotosocial.git
				synced 2025-10-31 02:02:25 -05:00 
			
		
		
		
	[feature] Support multiple subscriptions on single websocket connection (#1489)
- Allow Oauth authentication on websocket endpoint - Make streamType query parameter optional - Read websocket commands from client and update subscriptions
This commit is contained in:
		
					parent
					
						
							
								cb2f84e551
							
						
					
				
			
			
				commit
				
					
						e323a930bf
					
				
			
		
					 4 changed files with 74 additions and 26 deletions
				
			
		|  | @ -20,14 +20,16 @@ package streaming | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"errors" |  | ||||||
| 	"fmt" |  | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"codeberg.org/gruf/go-kv" | 	"codeberg.org/gruf/go-kv" | ||||||
| 	apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" | 	apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/log" | 	"github.com/superseriousbusiness/gotosocial/internal/log" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/oauth" | ||||||
|  | 	streampkg "github.com/superseriousbusiness/gotosocial/internal/stream" | ||||||
|  | 	"golang.org/x/exp/slices" | ||||||
| 
 | 
 | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| 	"github.com/gorilla/websocket" | 	"github.com/gorilla/websocket" | ||||||
|  | @ -134,32 +136,37 @@ import ( | ||||||
| //		'400': | //		'400': | ||||||
| //			description: bad request | //			description: bad request | ||||||
| func (m *Module) StreamGETHandler(c *gin.Context) { | func (m *Module) StreamGETHandler(c *gin.Context) { | ||||||
| 	streamType := c.Query(StreamQueryKey) |  | ||||||
| 	if streamType == "" { |  | ||||||
| 		err := fmt.Errorf("no stream type provided under query key %s", StreamQueryKey) |  | ||||||
| 		apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	var token string |  | ||||||
| 
 | 
 | ||||||
| 	// First we check for a query param provided access token | 	// First we check for a query param provided access token | ||||||
| 	if token = c.Query(AccessTokenQueryKey); token == "" { | 	token := c.Query(AccessTokenQueryKey) | ||||||
|  | 	if token == "" { | ||||||
| 		// Else we check the HTTP header provided token | 		// Else we check the HTTP header provided token | ||||||
| 		if token = c.GetHeader(AccessTokenHeader); token == "" { | 		token = c.GetHeader(AccessTokenHeader) | ||||||
| 			const errStr = "no access token provided" |  | ||||||
| 			err := gtserror.NewErrorUnauthorized(errors.New(errStr), errStr) |  | ||||||
| 			apiutil.ErrorHandler(c, err, m.processor.InstanceGetV1) |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	account, errWithCode := m.processor.Stream().Authorize(c.Request.Context(), token) | 	var account *gtsmodel.Account | ||||||
|  | 	if token != "" { | ||||||
|  | 		// Check the explicit token | ||||||
|  | 		var errWithCode gtserror.WithCode | ||||||
|  | 		account, errWithCode = m.processor.Stream().Authorize(c.Request.Context(), token) | ||||||
| 		if errWithCode != nil { | 		if errWithCode != nil { | ||||||
| 			apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) | 			apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
|  | 	} else { | ||||||
|  | 		// If no explicit token was provided, try regular oauth | ||||||
|  | 		auth, errStr := oauth.Authed(c, true, true, true, true) | ||||||
|  | 		if errStr != nil { | ||||||
|  | 			err := gtserror.NewErrorUnauthorized(errStr, errStr.Error()) | ||||||
|  | 			apiutil.ErrorHandler(c, err, m.processor.InstanceGetV1) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 		account = auth.Account | ||||||
|  | 	} | ||||||
| 
 | 
 | ||||||
|  | 	// Get the initial stream type, if there is one. | ||||||
|  | 	// streamType will be an empty string if one wasn't supplied. Open() will deal with this | ||||||
|  | 	streamType := c.Query(StreamQueryKey) | ||||||
| 	stream, errWithCode := m.processor.Stream().Open(c.Request.Context(), account, streamType) | 	stream, errWithCode := m.processor.Stream().Open(c.Request.Context(), account, streamType) | ||||||
| 	if errWithCode != nil { | 	if errWithCode != nil { | ||||||
| 		apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) | 		apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) | ||||||
|  | @ -219,8 +226,9 @@ func (m *Module) StreamGETHandler(c *gin.Context) { | ||||||
| 				// We have to listen for received websocket messages in | 				// We have to listen for received websocket messages in | ||||||
| 				// order to trigger the underlying wsConn.PingHandler(). | 				// order to trigger the underlying wsConn.PingHandler(). | ||||||
| 				// | 				// | ||||||
| 				// So we wait on received messages but only act on errors. | 				// Read JSON objects from the client and act on them | ||||||
| 				_, _, err := wsConn.ReadMessage() | 				var msg map[string]string | ||||||
|  | 				err := wsConn.ReadJSON(&msg) | ||||||
| 				if err != nil { | 				if err != nil { | ||||||
| 					if ctx.Err() == nil { | 					if ctx.Err() == nil { | ||||||
| 						// Only log error if the connection was not closed | 						// Only log error if the connection was not closed | ||||||
|  | @ -229,6 +237,33 @@ func (m *Module) StreamGETHandler(c *gin.Context) { | ||||||
| 					} | 					} | ||||||
| 					return | 					return | ||||||
| 				} | 				} | ||||||
|  | 				l.Tracef("received message from websocket: %v", msg) | ||||||
|  | 
 | ||||||
|  | 				// If the message contains 'stream' and 'type' fields, we can | ||||||
|  | 				// update the set of timelines that are subscribed for events. | ||||||
|  | 				// everything else is ignored. | ||||||
|  | 				action := msg["type"] | ||||||
|  | 				streamType := msg["stream"] | ||||||
|  | 
 | ||||||
|  | 				// Ignore if the streamType is unknown (or missing), so a bad | ||||||
|  | 				// client can't cause extra memory allocations | ||||||
|  | 				if !slices.Contains(streampkg.AllStatusTimelines, streamType) { | ||||||
|  | 					l.Warnf("Unknown 'stream' field: %v", msg) | ||||||
|  | 					continue | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 				switch action { | ||||||
|  | 				case "subscribe": | ||||||
|  | 					stream.Lock() | ||||||
|  | 					stream.Timelines[streamType] = true | ||||||
|  | 					stream.Unlock() | ||||||
|  | 				case "unsubscribe": | ||||||
|  | 					stream.Lock() | ||||||
|  | 					delete(stream.Timelines, streamType) | ||||||
|  | 					stream.Unlock() | ||||||
|  | 				default: | ||||||
|  | 					l.Warnf("Invalid 'type' field: %v", msg) | ||||||
|  | 				} | ||||||
| 			} | 			} | ||||||
| 		}() | 		}() | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -45,9 +45,17 @@ func (p *Processor) Open(ctx context.Context, account *gtsmodel.Account, streamT | ||||||
| 		return nil, gtserror.NewErrorInternalError(fmt.Errorf("error generating stream id: %s", err)) | 		return nil, gtserror.NewErrorInternalError(fmt.Errorf("error generating stream id: %s", err)) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	// Each stream can be subscibed to multiple timelines. | ||||||
|  | 	// Record them in a set, and include the initial one | ||||||
|  | 	// if it was given to us | ||||||
|  | 	timelines := map[string]bool{} | ||||||
|  | 	if streamTimeline != "" { | ||||||
|  | 		timelines[streamTimeline] = true | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	thisStream := &stream.Stream{ | 	thisStream := &stream.Stream{ | ||||||
| 		ID:        streamID, | 		ID:        streamID, | ||||||
| 		Timeline:  streamTimeline, | 		Timelines: timelines, | ||||||
| 		Messages:  make(chan *stream.Message, 100), | 		Messages:  make(chan *stream.Message, 100), | ||||||
| 		Hangup:    make(chan interface{}, 1), | 		Hangup:    make(chan interface{}, 1), | ||||||
| 		Connected: true, | 		Connected: true, | ||||||
|  |  | ||||||
|  | @ -63,12 +63,15 @@ func (p *Processor) toAccount(payload string, event string, timelines []string, | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		for _, t := range timelines { | 		for _, t := range timelines { | ||||||
| 			if s.Timeline == string(t) { | 			if _, found := s.Timelines[t]; found { | ||||||
| 				s.Messages <- &stream.Message{ | 				s.Messages <- &stream.Message{ | ||||||
| 					Stream:  []string{string(t)}, | 					Stream:  []string{string(t)}, | ||||||
| 					Event:   string(event), | 					Event:   string(event), | ||||||
| 					Payload: payload, | 					Payload: payload, | ||||||
| 				} | 				} | ||||||
|  | 				// break out to the outer loop, to avoid sending duplicates | ||||||
|  | 				// of the same event to the same stream | ||||||
|  | 				break | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | @ -63,8 +63,10 @@ type StreamsForAccount struct { | ||||||
| type Stream struct { | type Stream struct { | ||||||
| 	// ID of this stream, generated during creation. | 	// ID of this stream, generated during creation. | ||||||
| 	ID string | 	ID string | ||||||
| 	// Timeline of this stream: user/public/etc | 	// A set of timelines of this stream: user/public/etc | ||||||
| 	Timeline string | 	// a matching key means the timeline is subscribed. The value | ||||||
|  | 	// is ignored | ||||||
|  | 	Timelines map[string]bool | ||||||
| 	// Channel of messages for the client to read from | 	// Channel of messages for the client to read from | ||||||
| 	Messages chan *Message | 	Messages chan *Message | ||||||
| 	// Channel to close when the client drops away | 	// Channel to close when the client drops away | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue