mirror of
				https://github.com/superseriousbusiness/gotosocial.git
				synced 2025-10-31 08:22:27 -05:00 
			
		
		
		
	[chore/bugfix] Break Websockets logic into smaller read/write functions, don't log expected errors (#1932)
* [chore/bugfix] Break Websockets logic into smaller read/write functions, don't log expected errors * tweak * tidy up, use control message
This commit is contained in:
		
					parent
					
						
							
								ba0bc06b8c
							
						
					
				
			
			
				commit
				
					
						3d16962173
					
				
			
		
					 2 changed files with 237 additions and 133 deletions
				
			
		|  | @ -149,60 +149,78 @@ import ( | ||||||
| //		'400': | //		'400': | ||||||
| //			description: bad request | //			description: bad request | ||||||
| func (m *Module) StreamGETHandler(c *gin.Context) { | func (m *Module) StreamGETHandler(c *gin.Context) { | ||||||
|  | 	var ( | ||||||
|  | 		account     *gtsmodel.Account | ||||||
|  | 		errWithCode gtserror.WithCode | ||||||
|  | 	) | ||||||
| 
 | 
 | ||||||
| 	// First we check for a query param provided access token | 	// Try query param access token. | ||||||
| 	token := c.Query(AccessTokenQueryKey) | 	token := c.Query(AccessTokenQueryKey) | ||||||
| 	if token == "" { | 	if token == "" { | ||||||
| 		// Else we check the HTTP header provided token | 		// Try fallback HTTP header provided token. | ||||||
| 		token = c.GetHeader(AccessTokenHeader) | 		token = c.GetHeader(AccessTokenHeader) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	var account *gtsmodel.Account |  | ||||||
| 	if token != "" { | 	if token != "" { | ||||||
| 		// Check the explicit token | 		// Token was provided, use it to authorize stream. | ||||||
| 		var errWithCode gtserror.WithCode |  | ||||||
| 		account, errWithCode = m.processor.Stream().Authorize(c.Request.Context(), token) | 		account, errWithCode = m.processor.Stream().Authorize(c.Request.Context(), token) | ||||||
| 		if errWithCode != nil { |  | ||||||
| 			apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
| 	} else { | 	} else { | ||||||
| 		// If no explicit token was provided, try regular oauth | 		// No explicit token was provided: | ||||||
| 		auth, errStr := oauth.Authed(c, true, true, true, true) | 		// try regular oauth as a last resort. | ||||||
| 		if errStr != nil { | 		account, errWithCode = func() (*gtsmodel.Account, gtserror.WithCode) { | ||||||
| 			err := gtserror.NewErrorUnauthorized(errStr, errStr.Error()) | 			authed, err := oauth.Authed(c, true, true, true, true) | ||||||
| 			apiutil.ErrorHandler(c, err, m.processor.InstanceGetV1) | 			if err != nil { | ||||||
| 			return | 				return nil, gtserror.NewErrorUnauthorized(err, err.Error()) | ||||||
| 		} | 			} | ||||||
| 		account = auth.Account | 
 | ||||||
|  | 			return authed.Account, nil | ||||||
|  | 		}() | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Get the initial stream type, if there is one. | 	if errWithCode != nil { | ||||||
| 	// By appending other query params to the streamType, | 		apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) | ||||||
| 	// we can allow for streaming for specific list IDs | 		return | ||||||
| 	// or hashtags. | 	} | ||||||
|  | 
 | ||||||
|  | 	// Get the initial requested stream type, if there is one. | ||||||
| 	streamType := c.Query(StreamQueryKey) | 	streamType := c.Query(StreamQueryKey) | ||||||
|  | 
 | ||||||
|  | 	// By appending other query params to the streamType, we | ||||||
|  | 	// can allow streaming for specific list IDs or hashtags. | ||||||
|  | 	// The streamType in this case will end up looking like | ||||||
|  | 	// `hashtag:example` or `list:01H3YF48G8B7KTPQFS8D2QBVG8`. | ||||||
| 	if list := c.Query(StreamListKey); list != "" { | 	if list := c.Query(StreamListKey); list != "" { | ||||||
| 		streamType += ":" + list | 		streamType += ":" + list | ||||||
| 	} else if tag := c.Query(StreamTagKey); tag != "" { | 	} else if tag := c.Query(StreamTagKey); tag != "" { | ||||||
| 		streamType += ":" + tag | 		streamType += ":" + tag | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	stream, errWithCode := m.processor.Stream().Open(c.Request.Context(), account, streamType) | 	// Open a stream with the processor; this lets processor | ||||||
|  | 	// functions pass messages into a channel, which we can | ||||||
|  | 	// then read from and put into a websockets connection. | ||||||
|  | 	stream, errWithCode := m.processor.Stream().Open( | ||||||
|  | 		c.Request.Context(), | ||||||
|  | 		account, | ||||||
|  | 		streamType, | ||||||
|  | 	) | ||||||
| 	if errWithCode != nil { | 	if errWithCode != nil { | ||||||
| 		apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) | 		apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	l := log.WithContext(c.Request.Context()). | 	l := log. | ||||||
|  | 		WithContext(c.Request.Context()). | ||||||
| 		WithFields(kv.Fields{ | 		WithFields(kv.Fields{ | ||||||
| 			{"account", account.Username}, | 			{"username", account.Username}, | ||||||
| 			{"streamID", stream.ID}, | 			{"streamID", stream.ID}, | ||||||
| 			{"streamType", streamType}, |  | ||||||
| 		}...) | 		}...) | ||||||
| 
 | 
 | ||||||
| 	// Upgrade the incoming HTTP request, which hijacks the underlying | 	// Upgrade the incoming HTTP request. This hijacks the | ||||||
| 	// connection and reuses it for the websocket (non-http) protocol. | 	// underlying connection and reuses it for the websocket | ||||||
|  | 	// (non-http) protocol. | ||||||
|  | 	// | ||||||
|  | 	// If the upgrade fails, then Upgrade replies to the client | ||||||
|  | 	// with an HTTP error response. | ||||||
| 	wsConn, err := m.wsUpgrade.Upgrade(c.Writer, c.Request, nil) | 	wsConn, err := m.wsUpgrade.Upgrade(c.Writer, c.Request, nil) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		l.Errorf("error upgrading websocket connection: %v", err) | 		l.Errorf("error upgrading websocket connection: %v", err) | ||||||
|  | @ -210,125 +228,208 @@ func (m *Module) StreamGETHandler(c *gin.Context) { | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	l.Info("opened websocket connection") | ||||||
|  | 
 | ||||||
|  | 	// We perform the main websocket rw loops in a separate | ||||||
|  | 	// goroutine in order to let the upgrade handler return. | ||||||
|  | 	// This prevents the upgrade handler from holding open any | ||||||
|  | 	// throttle / rate-limit request tokens which could become | ||||||
|  | 	// problematic on instances with multiple users. | ||||||
|  | 	go m.handleWSConn(account.Username, wsConn, stream) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // handleWSConn handles a two-way websocket streaming connection. | ||||||
|  | // It will both read messages from the connection, and push messages | ||||||
|  | // into the connection. If any errors are encountered while reading | ||||||
|  | // or writing (including expected errors like clients leaving), the | ||||||
|  | // connection will be closed. | ||||||
|  | func (m *Module) handleWSConn(username string, wsConn *websocket.Conn, stream *streampkg.Stream) { | ||||||
|  | 	// Create new context for the lifetime of this connection. | ||||||
|  | 	ctx, cancel := context.WithCancel(context.Background()) | ||||||
|  | 
 | ||||||
|  | 	l := log. | ||||||
|  | 		WithContext(ctx). | ||||||
|  | 		WithFields(kv.Fields{ | ||||||
|  | 			{"username", username}, | ||||||
|  | 			{"streamID", stream.ID}, | ||||||
|  | 		}...) | ||||||
|  | 
 | ||||||
|  | 	// Create ticker to send keepalive pings | ||||||
|  | 	pinger := time.NewTicker(m.dTicker) | ||||||
|  | 
 | ||||||
|  | 	// Read messages coming from the Websocket client connection into the server. | ||||||
| 	go func() { | 	go func() { | ||||||
| 		// We perform the main websocket send loop in a separate | 		defer cancel() | ||||||
| 		// goroutine in order to let the upgrade handler return. | 		m.readFromWSConn(ctx, username, wsConn, stream) | ||||||
| 		// This prevents the upgrade handler from holding open any | 	}() | ||||||
| 		// throttle / rate-limit request tokens which could become |  | ||||||
| 		// problematic on instances with multiple users. |  | ||||||
| 		l.Info("opened websocket connection") |  | ||||||
| 		defer l.Info("closed websocket connection") |  | ||||||
| 
 | 
 | ||||||
| 		// Create new context for lifetime of the connection | 	// Write messages coming from the processor into the Websocket client connection. | ||||||
| 		ctx, cncl := context.WithCancel(context.Background()) | 	go func() { | ||||||
|  | 		defer cancel() | ||||||
|  | 		m.writeToWSConn(ctx, username, wsConn, stream, pinger) | ||||||
|  | 	}() | ||||||
| 
 | 
 | ||||||
| 		// Create ticker to send alive pings | 	// Wait for either the read or write functions to close, to indicate | ||||||
| 		pinger := time.NewTicker(m.dTicker) | 	// that the client has left, or something else has gone wrong. | ||||||
|  | 	<-ctx.Done() | ||||||
| 
 | 
 | ||||||
| 		defer func() { | 	// Tidy up underlying websocket connection. | ||||||
| 			// Signal done | 	if err := wsConn.Close(); err != nil { | ||||||
| 			cncl() | 		l.Errorf("error closing websocket connection: %v", err) | ||||||
|  | 	} | ||||||
| 
 | 
 | ||||||
| 			// Close websocket conn | 	// Close processor channel so the processor knows | ||||||
| 			_ = wsConn.Close() | 	// not to send any more messages to this stream. | ||||||
|  | 	close(stream.Hangup) | ||||||
| 
 | 
 | ||||||
| 			// Close processor stream | 	// Stop ping ticker (tiny resource saving). | ||||||
| 			close(stream.Hangup) | 	pinger.Stop() | ||||||
| 
 | 
 | ||||||
| 			// Stop ping ticker | 	l.Info("closed websocket connection") | ||||||
| 			pinger.Stop() | } | ||||||
| 		}() |  | ||||||
| 
 | 
 | ||||||
| 		go func() { | // readFromWSConn reads control messages coming in from the given | ||||||
| 			// Signal done | // websockets connection, and modifies the subscription StreamTypes | ||||||
| 			defer cncl() | // of the given stream accordingly after acquiring a lock on it. | ||||||
|  | // | ||||||
|  | // This is a blocking function; will return only on read error or | ||||||
|  | // if the given context is canceled. | ||||||
|  | func (m *Module) readFromWSConn( | ||||||
|  | 	ctx context.Context, | ||||||
|  | 	username string, | ||||||
|  | 	wsConn *websocket.Conn, | ||||||
|  | 	stream *streampkg.Stream, | ||||||
|  | ) { | ||||||
|  | 	l := log. | ||||||
|  | 		WithContext(ctx). | ||||||
|  | 		WithFields(kv.Fields{ | ||||||
|  | 			{"username", username}, | ||||||
|  | 			{"streamID", stream.ID}, | ||||||
|  | 		}...) | ||||||
| 
 | 
 | ||||||
| 			for { | readLoop: | ||||||
| 				// We have to listen for received websocket messages in | 	for { | ||||||
| 				// order to trigger the underlying wsConn.PingHandler(). | 		select { | ||||||
| 				// | 		case <-ctx.Done(): | ||||||
| 				// Read JSON objects from the client and act on them | 			// Connection closed. | ||||||
| 				var msg map[string]string | 			break readLoop | ||||||
| 				err := wsConn.ReadJSON(&msg) |  | ||||||
| 				if err != nil { |  | ||||||
| 					if ctx.Err() == nil { |  | ||||||
| 						// Only log error if the connection was not closed |  | ||||||
| 						// by us. Uncanceled context indicates this is the case. |  | ||||||
| 						l.Errorf("error reading from websocket: %v", err) |  | ||||||
| 					} |  | ||||||
| 					return |  | ||||||
| 				} |  | ||||||
| 				l.Tracef("received message from websocket: %v", msg) |  | ||||||
| 
 | 
 | ||||||
| 				// If the message contains 'stream' and 'type' fields, we can | 		default: | ||||||
| 				// update the set of timelines that are subscribed for events. | 			// Read JSON objects from the client and act on them. | ||||||
| 				updateType, ok := msg["type"] | 			var msg map[string]string | ||||||
| 				if !ok { | 			if err := wsConn.ReadJSON(&msg); err != nil { | ||||||
| 					l.Warn("'type' field not provided") | 				// Only log an error if something weird happened. | ||||||
| 					continue | 				// See: https://www.rfc-editor.org/rfc/rfc6455.html#section-11.7 | ||||||
|  | 				if websocket.IsUnexpectedCloseError(err, []int{ | ||||||
|  | 					websocket.CloseNormalClosure, | ||||||
|  | 					websocket.CloseGoingAway, | ||||||
|  | 					websocket.CloseNoStatusReceived, | ||||||
|  | 				}...) { | ||||||
|  | 					l.Errorf("error reading from websocket: %v", err) | ||||||
| 				} | 				} | ||||||
| 
 | 
 | ||||||
| 				updateStream, ok := msg["stream"] | 				// The connection is gone; no | ||||||
| 				if !ok { | 				// further streaming possible. | ||||||
| 					l.Warn("'stream' field not provided") | 				break readLoop | ||||||
| 					continue |  | ||||||
| 				} |  | ||||||
| 
 |  | ||||||
| 				// Ignore if the updateStreamType is unknown (or missing), |  | ||||||
| 				// so a bad client can't cause extra memory allocations |  | ||||||
| 				if !slices.Contains(streampkg.AllStatusTimelines, updateStream) { |  | ||||||
| 					l.Warnf("unknown 'stream' field: %v", msg) |  | ||||||
| 					continue |  | ||||||
| 				} |  | ||||||
| 
 |  | ||||||
| 				updateList, ok := msg["list"] |  | ||||||
| 				if ok { |  | ||||||
| 					updateStream += ":" + updateList |  | ||||||
| 				} |  | ||||||
| 
 |  | ||||||
| 				switch updateType { |  | ||||||
| 				case "subscribe": |  | ||||||
| 					stream.Lock() |  | ||||||
| 					stream.StreamTypes[updateStream] = true |  | ||||||
| 					stream.Unlock() |  | ||||||
| 				case "unsubscribe": |  | ||||||
| 					stream.Lock() |  | ||||||
| 					delete(stream.StreamTypes, updateStream) |  | ||||||
| 					stream.Unlock() |  | ||||||
| 				default: |  | ||||||
| 					l.Warnf("invalid 'type' field: %v", msg) |  | ||||||
| 				} |  | ||||||
| 			} | 			} | ||||||
| 		}() |  | ||||||
| 
 | 
 | ||||||
| 		for { | 			// Messages *from* the WS connection are infrequent | ||||||
| 			select { | 			// and usually interesting, so log this at info. | ||||||
| 			// Connection closed | 			l.Infof("received message from websocket: %v", msg) | ||||||
| 			case <-ctx.Done(): |  | ||||||
| 				return |  | ||||||
| 
 | 
 | ||||||
| 			// Received next stream message | 			// If the message contains 'stream' and 'type' fields, we can | ||||||
| 			case msg := <-stream.Messages: | 			// update the set of timelines that are subscribed for events. | ||||||
| 				l.Tracef("sending message to websocket: %+v", msg) | 			updateType, ok := msg["type"] | ||||||
| 				if err := wsConn.WriteJSON(msg); err != nil { | 			if !ok { | ||||||
| 					l.Debugf("error writing json to websocket: %v", err) | 				l.Warn("'type' field not provided") | ||||||
| 					return | 				continue | ||||||
| 				} | 			} | ||||||
| 
 | 
 | ||||||
| 				// Reset on each successful send. | 			updateStream, ok := msg["stream"] | ||||||
| 				pinger.Reset(m.dTicker) | 			if !ok { | ||||||
|  | 				l.Warn("'stream' field not provided") | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
| 
 | 
 | ||||||
| 			// Send keep-alive "ping" | 			// Ignore if the updateStreamType is unknown (or missing), | ||||||
| 			case <-pinger.C: | 			// so a bad client can't cause extra memory allocations | ||||||
| 				l.Trace("pinging websocket ...") | 			if !slices.Contains(streampkg.AllStatusTimelines, updateStream) { | ||||||
| 				if err := wsConn.WriteMessage( | 				l.Warnf("unknown 'stream' field: %v", msg) | ||||||
| 					websocket.PingMessage, | 				continue | ||||||
| 					[]byte{}, | 			} | ||||||
| 				); err != nil { | 
 | ||||||
| 					l.Debugf("error writing ping to websocket: %v", err) | 			updateList, ok := msg["list"] | ||||||
| 					return | 			if ok { | ||||||
| 				} | 				updateStream += ":" + updateList | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			switch updateType { | ||||||
|  | 			case "subscribe": | ||||||
|  | 				stream.Lock() | ||||||
|  | 				stream.StreamTypes[updateStream] = true | ||||||
|  | 				stream.Unlock() | ||||||
|  | 			case "unsubscribe": | ||||||
|  | 				stream.Lock() | ||||||
|  | 				delete(stream.StreamTypes, updateStream) | ||||||
|  | 				stream.Unlock() | ||||||
|  | 			default: | ||||||
|  | 				l.Warnf("invalid 'type' field: %v", msg) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	}() | 	} | ||||||
|  | 
 | ||||||
|  | 	l.Debug("finished reading from websocket connection") | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // writeToWSConn receives messages coming from the processor via the | ||||||
|  | // given stream, and writes them into the given websockets connection. | ||||||
|  | // This function also handles sending ping messages into the websockets | ||||||
|  | // connection to keep it alive when no other activity occurs. | ||||||
|  | // | ||||||
|  | // This is a blocking function; will return only on write error or | ||||||
|  | // if the given context is canceled. | ||||||
|  | func (m *Module) writeToWSConn( | ||||||
|  | 	ctx context.Context, | ||||||
|  | 	username string, | ||||||
|  | 	wsConn *websocket.Conn, | ||||||
|  | 	stream *streampkg.Stream, | ||||||
|  | 	pinger *time.Ticker, | ||||||
|  | ) { | ||||||
|  | 	l := log. | ||||||
|  | 		WithContext(ctx). | ||||||
|  | 		WithFields(kv.Fields{ | ||||||
|  | 			{"username", username}, | ||||||
|  | 			{"streamID", stream.ID}, | ||||||
|  | 		}...) | ||||||
|  | 
 | ||||||
|  | writeLoop: | ||||||
|  | 	for { | ||||||
|  | 		select { | ||||||
|  | 		case <-ctx.Done(): | ||||||
|  | 			// Connection closed. | ||||||
|  | 			break writeLoop | ||||||
|  | 
 | ||||||
|  | 		case msg := <-stream.Messages: | ||||||
|  | 			// Received a new message from the processor. | ||||||
|  | 			l.Tracef("writing message to websocket: %+v", msg) | ||||||
|  | 			if err := wsConn.WriteJSON(msg); err != nil { | ||||||
|  | 				l.Debugf("error writing json to websocket: %v", err) | ||||||
|  | 				break writeLoop | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			// Reset pinger on successful send, since | ||||||
|  | 			// we know the connection is still there. | ||||||
|  | 			pinger.Reset(m.dTicker) | ||||||
|  | 
 | ||||||
|  | 		case <-pinger.C: | ||||||
|  | 			// Time to send a keep-alive "ping". | ||||||
|  | 			l.Trace("writing ping control message to websocket") | ||||||
|  | 			if err := wsConn.WriteControl(websocket.PingMessage, nil, time.Time{}); err != nil { | ||||||
|  | 				l.Debugf("error writing ping to websocket: %v", err) | ||||||
|  | 				break writeLoop | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	l.Debug("finished writing to websocket connection") | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -42,15 +42,18 @@ type Module struct { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func New(processor *processing.Processor, dTicker time.Duration, wsBuf int) *Module { | func New(processor *processing.Processor, dTicker time.Duration, wsBuf int) *Module { | ||||||
|  | 	// We expect CORS requests for websockets, | ||||||
|  | 	// (via eg., semaphore.social) so be lenient. | ||||||
|  | 	// TODO: make this customizable? | ||||||
|  | 	checkOrigin := func(r *http.Request) bool { return true } | ||||||
|  | 
 | ||||||
| 	return &Module{ | 	return &Module{ | ||||||
| 		processor: processor, | 		processor: processor, | ||||||
| 		dTicker:   dTicker, | 		dTicker:   dTicker, | ||||||
| 		wsUpgrade: websocket.Upgrader{ | 		wsUpgrade: websocket.Upgrader{ | ||||||
| 			ReadBufferSize:  wsBuf, // we don't expect reads | 			ReadBufferSize:  wsBuf, | ||||||
| 			WriteBufferSize: wsBuf, | 			WriteBufferSize: wsBuf, | ||||||
| 
 | 			CheckOrigin:     checkOrigin, | ||||||
| 			// we expect cors requests (via eg., semaphore.social) so be lenient |  | ||||||
| 			CheckOrigin: func(r *http.Request) bool { return true }, |  | ||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue