mirror of
				https://github.com/superseriousbusiness/gotosocial.git
				synced 2025-10-31 05:52:25 -05:00 
			
		
		
		
	[chore/bugfix] Break Websockets logic into smaller read/write functions, don't log expected errors (#1932)
* [chore/bugfix] Break Websockets logic into smaller read/write functions, don't log expected errors * tweak * tidy up, use control message
This commit is contained in:
		
					parent
					
						
							
								ba0bc06b8c
							
						
					
				
			
			
				commit
				
					
						3d16962173
					
				
			
		
					 2 changed files with 237 additions and 133 deletions
				
			
		|  | @ -149,60 +149,78 @@ import ( | |||
| //		'400': | ||||
| //			description: bad request | ||||
| func (m *Module) StreamGETHandler(c *gin.Context) { | ||||
| 	var ( | ||||
| 		account     *gtsmodel.Account | ||||
| 		errWithCode gtserror.WithCode | ||||
| 	) | ||||
| 
 | ||||
| 	// First we check for a query param provided access token | ||||
| 	// Try query param access token. | ||||
| 	token := c.Query(AccessTokenQueryKey) | ||||
| 	if token == "" { | ||||
| 		// Else we check the HTTP header provided token | ||||
| 		// Try fallback HTTP header provided token. | ||||
| 		token = c.GetHeader(AccessTokenHeader) | ||||
| 	} | ||||
| 
 | ||||
| 	var account *gtsmodel.Account | ||||
| 	if token != "" { | ||||
| 		// Check the explicit token | ||||
| 		var errWithCode gtserror.WithCode | ||||
| 		// Token was provided, use it to authorize stream. | ||||
| 		account, errWithCode = m.processor.Stream().Authorize(c.Request.Context(), token) | ||||
| 		if errWithCode != nil { | ||||
| 			apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) | ||||
| 			return | ||||
| 		} | ||||
| 	} else { | ||||
| 		// If no explicit token was provided, try regular oauth | ||||
| 		auth, errStr := oauth.Authed(c, true, true, true, true) | ||||
| 		if errStr != nil { | ||||
| 			err := gtserror.NewErrorUnauthorized(errStr, errStr.Error()) | ||||
| 			apiutil.ErrorHandler(c, err, m.processor.InstanceGetV1) | ||||
| 			return | ||||
| 		} | ||||
| 		account = auth.Account | ||||
| 		// No explicit token was provided: | ||||
| 		// try regular oauth as a last resort. | ||||
| 		account, errWithCode = func() (*gtsmodel.Account, gtserror.WithCode) { | ||||
| 			authed, err := oauth.Authed(c, true, true, true, true) | ||||
| 			if err != nil { | ||||
| 				return nil, gtserror.NewErrorUnauthorized(err, err.Error()) | ||||
| 			} | ||||
| 
 | ||||
| 			return authed.Account, nil | ||||
| 		}() | ||||
| 	} | ||||
| 
 | ||||
| 	// Get the initial stream type, if there is one. | ||||
| 	// By appending other query params to the streamType, | ||||
| 	// we can allow for streaming for specific list IDs | ||||
| 	// or hashtags. | ||||
| 	if errWithCode != nil { | ||||
| 		apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	// Get the initial requested stream type, if there is one. | ||||
| 	streamType := c.Query(StreamQueryKey) | ||||
| 
 | ||||
| 	// By appending other query params to the streamType, we | ||||
| 	// can allow streaming for specific list IDs or hashtags. | ||||
| 	// The streamType in this case will end up looking like | ||||
| 	// `hashtag:example` or `list:01H3YF48G8B7KTPQFS8D2QBVG8`. | ||||
| 	if list := c.Query(StreamListKey); list != "" { | ||||
| 		streamType += ":" + list | ||||
| 	} else if tag := c.Query(StreamTagKey); tag != "" { | ||||
| 		streamType += ":" + tag | ||||
| 	} | ||||
| 
 | ||||
| 	stream, errWithCode := m.processor.Stream().Open(c.Request.Context(), account, streamType) | ||||
| 	// Open a stream with the processor; this lets processor | ||||
| 	// functions pass messages into a channel, which we can | ||||
| 	// then read from and put into a websockets connection. | ||||
| 	stream, errWithCode := m.processor.Stream().Open( | ||||
| 		c.Request.Context(), | ||||
| 		account, | ||||
| 		streamType, | ||||
| 	) | ||||
| 	if errWithCode != nil { | ||||
| 		apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	l := log.WithContext(c.Request.Context()). | ||||
| 	l := log. | ||||
| 		WithContext(c.Request.Context()). | ||||
| 		WithFields(kv.Fields{ | ||||
| 			{"account", account.Username}, | ||||
| 			{"username", account.Username}, | ||||
| 			{"streamID", stream.ID}, | ||||
| 			{"streamType", streamType}, | ||||
| 		}...) | ||||
| 
 | ||||
| 	// Upgrade the incoming HTTP request, which hijacks the underlying | ||||
| 	// connection and reuses it for the websocket (non-http) protocol. | ||||
| 	// Upgrade the incoming HTTP request. This hijacks the | ||||
| 	// underlying connection and reuses it for the websocket | ||||
| 	// (non-http) protocol. | ||||
| 	// | ||||
| 	// If the upgrade fails, then Upgrade replies to the client | ||||
| 	// with an HTTP error response. | ||||
| 	wsConn, err := m.wsUpgrade.Upgrade(c.Writer, c.Request, nil) | ||||
| 	if err != nil { | ||||
| 		l.Errorf("error upgrading websocket connection: %v", err) | ||||
|  | @ -210,125 +228,208 @@ func (m *Module) StreamGETHandler(c *gin.Context) { | |||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	l.Info("opened websocket connection") | ||||
| 
 | ||||
| 	// We perform the main websocket rw loops in a separate | ||||
| 	// goroutine in order to let the upgrade handler return. | ||||
| 	// This prevents the upgrade handler from holding open any | ||||
| 	// throttle / rate-limit request tokens which could become | ||||
| 	// problematic on instances with multiple users. | ||||
| 	go m.handleWSConn(account.Username, wsConn, stream) | ||||
| } | ||||
| 
 | ||||
| // handleWSConn handles a two-way websocket streaming connection. | ||||
| // It will both read messages from the connection, and push messages | ||||
| // into the connection. If any errors are encountered while reading | ||||
| // or writing (including expected errors like clients leaving), the | ||||
| // connection will be closed. | ||||
| func (m *Module) handleWSConn(username string, wsConn *websocket.Conn, stream *streampkg.Stream) { | ||||
| 	// Create new context for the lifetime of this connection. | ||||
| 	ctx, cancel := context.WithCancel(context.Background()) | ||||
| 
 | ||||
| 	l := log. | ||||
| 		WithContext(ctx). | ||||
| 		WithFields(kv.Fields{ | ||||
| 			{"username", username}, | ||||
| 			{"streamID", stream.ID}, | ||||
| 		}...) | ||||
| 
 | ||||
| 	// Create ticker to send keepalive pings | ||||
| 	pinger := time.NewTicker(m.dTicker) | ||||
| 
 | ||||
| 	// Read messages coming from the Websocket client connection into the server. | ||||
| 	go func() { | ||||
| 		// We perform the main websocket send loop in a separate | ||||
| 		// goroutine in order to let the upgrade handler return. | ||||
| 		// This prevents the upgrade handler from holding open any | ||||
| 		// throttle / rate-limit request tokens which could become | ||||
| 		// problematic on instances with multiple users. | ||||
| 		l.Info("opened websocket connection") | ||||
| 		defer l.Info("closed websocket connection") | ||||
| 		defer cancel() | ||||
| 		m.readFromWSConn(ctx, username, wsConn, stream) | ||||
| 	}() | ||||
| 
 | ||||
| 		// Create new context for lifetime of the connection | ||||
| 		ctx, cncl := context.WithCancel(context.Background()) | ||||
| 	// Write messages coming from the processor into the Websocket client connection. | ||||
| 	go func() { | ||||
| 		defer cancel() | ||||
| 		m.writeToWSConn(ctx, username, wsConn, stream, pinger) | ||||
| 	}() | ||||
| 
 | ||||
| 		// Create ticker to send alive pings | ||||
| 		pinger := time.NewTicker(m.dTicker) | ||||
| 	// Wait for either the read or write functions to close, to indicate | ||||
| 	// that the client has left, or something else has gone wrong. | ||||
| 	<-ctx.Done() | ||||
| 
 | ||||
| 		defer func() { | ||||
| 			// Signal done | ||||
| 			cncl() | ||||
| 	// Tidy up underlying websocket connection. | ||||
| 	if err := wsConn.Close(); err != nil { | ||||
| 		l.Errorf("error closing websocket connection: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 			// Close websocket conn | ||||
| 			_ = wsConn.Close() | ||||
| 	// Close processor channel so the processor knows | ||||
| 	// not to send any more messages to this stream. | ||||
| 	close(stream.Hangup) | ||||
| 
 | ||||
| 			// Close processor stream | ||||
| 			close(stream.Hangup) | ||||
| 	// Stop ping ticker (tiny resource saving). | ||||
| 	pinger.Stop() | ||||
| 
 | ||||
| 			// Stop ping ticker | ||||
| 			pinger.Stop() | ||||
| 		}() | ||||
| 	l.Info("closed websocket connection") | ||||
| } | ||||
| 
 | ||||
| 		go func() { | ||||
| 			// Signal done | ||||
| 			defer cncl() | ||||
| // readFromWSConn reads control messages coming in from the given | ||||
| // websockets connection, and modifies the subscription StreamTypes | ||||
| // of the given stream accordingly after acquiring a lock on it. | ||||
| // | ||||
| // This is a blocking function; will return only on read error or | ||||
| // if the given context is canceled. | ||||
| func (m *Module) readFromWSConn( | ||||
| 	ctx context.Context, | ||||
| 	username string, | ||||
| 	wsConn *websocket.Conn, | ||||
| 	stream *streampkg.Stream, | ||||
| ) { | ||||
| 	l := log. | ||||
| 		WithContext(ctx). | ||||
| 		WithFields(kv.Fields{ | ||||
| 			{"username", username}, | ||||
| 			{"streamID", stream.ID}, | ||||
| 		}...) | ||||
| 
 | ||||
| 			for { | ||||
| 				// We have to listen for received websocket messages in | ||||
| 				// order to trigger the underlying wsConn.PingHandler(). | ||||
| 				// | ||||
| 				// Read JSON objects from the client and act on them | ||||
| 				var msg map[string]string | ||||
| 				err := wsConn.ReadJSON(&msg) | ||||
| 				if err != nil { | ||||
| 					if ctx.Err() == nil { | ||||
| 						// Only log error if the connection was not closed | ||||
| 						// by us. Uncanceled context indicates this is the case. | ||||
| 						l.Errorf("error reading from websocket: %v", err) | ||||
| 					} | ||||
| 					return | ||||
| 				} | ||||
| 				l.Tracef("received message from websocket: %v", msg) | ||||
| readLoop: | ||||
| 	for { | ||||
| 		select { | ||||
| 		case <-ctx.Done(): | ||||
| 			// Connection closed. | ||||
| 			break readLoop | ||||
| 
 | ||||
| 				// If the message contains 'stream' and 'type' fields, we can | ||||
| 				// update the set of timelines that are subscribed for events. | ||||
| 				updateType, ok := msg["type"] | ||||
| 				if !ok { | ||||
| 					l.Warn("'type' field not provided") | ||||
| 					continue | ||||
| 		default: | ||||
| 			// Read JSON objects from the client and act on them. | ||||
| 			var msg map[string]string | ||||
| 			if err := wsConn.ReadJSON(&msg); err != nil { | ||||
| 				// Only log an error if something weird happened. | ||||
| 				// See: https://www.rfc-editor.org/rfc/rfc6455.html#section-11.7 | ||||
| 				if websocket.IsUnexpectedCloseError(err, []int{ | ||||
| 					websocket.CloseNormalClosure, | ||||
| 					websocket.CloseGoingAway, | ||||
| 					websocket.CloseNoStatusReceived, | ||||
| 				}...) { | ||||
| 					l.Errorf("error reading from websocket: %v", err) | ||||
| 				} | ||||
| 
 | ||||
| 				updateStream, ok := msg["stream"] | ||||
| 				if !ok { | ||||
| 					l.Warn("'stream' field not provided") | ||||
| 					continue | ||||
| 				} | ||||
| 
 | ||||
| 				// Ignore if the updateStreamType is unknown (or missing), | ||||
| 				// so a bad client can't cause extra memory allocations | ||||
| 				if !slices.Contains(streampkg.AllStatusTimelines, updateStream) { | ||||
| 					l.Warnf("unknown 'stream' field: %v", msg) | ||||
| 					continue | ||||
| 				} | ||||
| 
 | ||||
| 				updateList, ok := msg["list"] | ||||
| 				if ok { | ||||
| 					updateStream += ":" + updateList | ||||
| 				} | ||||
| 
 | ||||
| 				switch updateType { | ||||
| 				case "subscribe": | ||||
| 					stream.Lock() | ||||
| 					stream.StreamTypes[updateStream] = true | ||||
| 					stream.Unlock() | ||||
| 				case "unsubscribe": | ||||
| 					stream.Lock() | ||||
| 					delete(stream.StreamTypes, updateStream) | ||||
| 					stream.Unlock() | ||||
| 				default: | ||||
| 					l.Warnf("invalid 'type' field: %v", msg) | ||||
| 				} | ||||
| 				// The connection is gone; no | ||||
| 				// further streaming possible. | ||||
| 				break readLoop | ||||
| 			} | ||||
| 		}() | ||||
| 
 | ||||
| 		for { | ||||
| 			select { | ||||
| 			// Connection closed | ||||
| 			case <-ctx.Done(): | ||||
| 				return | ||||
| 			// Messages *from* the WS connection are infrequent | ||||
| 			// and usually interesting, so log this at info. | ||||
| 			l.Infof("received message from websocket: %v", msg) | ||||
| 
 | ||||
| 			// Received next stream message | ||||
| 			case msg := <-stream.Messages: | ||||
| 				l.Tracef("sending message to websocket: %+v", msg) | ||||
| 				if err := wsConn.WriteJSON(msg); err != nil { | ||||
| 					l.Debugf("error writing json to websocket: %v", err) | ||||
| 					return | ||||
| 				} | ||||
| 			// If the message contains 'stream' and 'type' fields, we can | ||||
| 			// update the set of timelines that are subscribed for events. | ||||
| 			updateType, ok := msg["type"] | ||||
| 			if !ok { | ||||
| 				l.Warn("'type' field not provided") | ||||
| 				continue | ||||
| 			} | ||||
| 
 | ||||
| 				// Reset on each successful send. | ||||
| 				pinger.Reset(m.dTicker) | ||||
| 			updateStream, ok := msg["stream"] | ||||
| 			if !ok { | ||||
| 				l.Warn("'stream' field not provided") | ||||
| 				continue | ||||
| 			} | ||||
| 
 | ||||
| 			// Send keep-alive "ping" | ||||
| 			case <-pinger.C: | ||||
| 				l.Trace("pinging websocket ...") | ||||
| 				if err := wsConn.WriteMessage( | ||||
| 					websocket.PingMessage, | ||||
| 					[]byte{}, | ||||
| 				); err != nil { | ||||
| 					l.Debugf("error writing ping to websocket: %v", err) | ||||
| 					return | ||||
| 				} | ||||
| 			// Ignore if the updateStreamType is unknown (or missing), | ||||
| 			// so a bad client can't cause extra memory allocations | ||||
| 			if !slices.Contains(streampkg.AllStatusTimelines, updateStream) { | ||||
| 				l.Warnf("unknown 'stream' field: %v", msg) | ||||
| 				continue | ||||
| 			} | ||||
| 
 | ||||
| 			updateList, ok := msg["list"] | ||||
| 			if ok { | ||||
| 				updateStream += ":" + updateList | ||||
| 			} | ||||
| 
 | ||||
| 			switch updateType { | ||||
| 			case "subscribe": | ||||
| 				stream.Lock() | ||||
| 				stream.StreamTypes[updateStream] = true | ||||
| 				stream.Unlock() | ||||
| 			case "unsubscribe": | ||||
| 				stream.Lock() | ||||
| 				delete(stream.StreamTypes, updateStream) | ||||
| 				stream.Unlock() | ||||
| 			default: | ||||
| 				l.Warnf("invalid 'type' field: %v", msg) | ||||
| 			} | ||||
| 		} | ||||
| 	}() | ||||
| 	} | ||||
| 
 | ||||
| 	l.Debug("finished reading from websocket connection") | ||||
| } | ||||
| 
 | ||||
| // writeToWSConn receives messages coming from the processor via the | ||||
| // given stream, and writes them into the given websockets connection. | ||||
| // This function also handles sending ping messages into the websockets | ||||
| // connection to keep it alive when no other activity occurs. | ||||
| // | ||||
| // This is a blocking function; will return only on write error or | ||||
| // if the given context is canceled. | ||||
| func (m *Module) writeToWSConn( | ||||
| 	ctx context.Context, | ||||
| 	username string, | ||||
| 	wsConn *websocket.Conn, | ||||
| 	stream *streampkg.Stream, | ||||
| 	pinger *time.Ticker, | ||||
| ) { | ||||
| 	l := log. | ||||
| 		WithContext(ctx). | ||||
| 		WithFields(kv.Fields{ | ||||
| 			{"username", username}, | ||||
| 			{"streamID", stream.ID}, | ||||
| 		}...) | ||||
| 
 | ||||
| writeLoop: | ||||
| 	for { | ||||
| 		select { | ||||
| 		case <-ctx.Done(): | ||||
| 			// Connection closed. | ||||
| 			break writeLoop | ||||
| 
 | ||||
| 		case msg := <-stream.Messages: | ||||
| 			// Received a new message from the processor. | ||||
| 			l.Tracef("writing message to websocket: %+v", msg) | ||||
| 			if err := wsConn.WriteJSON(msg); err != nil { | ||||
| 				l.Debugf("error writing json to websocket: %v", err) | ||||
| 				break writeLoop | ||||
| 			} | ||||
| 
 | ||||
| 			// Reset pinger on successful send, since | ||||
| 			// we know the connection is still there. | ||||
| 			pinger.Reset(m.dTicker) | ||||
| 
 | ||||
| 		case <-pinger.C: | ||||
| 			// Time to send a keep-alive "ping". | ||||
| 			l.Trace("writing ping control message to websocket") | ||||
| 			if err := wsConn.WriteControl(websocket.PingMessage, nil, time.Time{}); err != nil { | ||||
| 				l.Debugf("error writing ping to websocket: %v", err) | ||||
| 				break writeLoop | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	l.Debug("finished writing to websocket connection") | ||||
| } | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue