mirror of
				https://github.com/superseriousbusiness/gotosocial.git
				synced 2025-10-30 23:52:26 -05:00 
			
		
		
		
	
		
			
	
	
		
			352 lines
		
	
	
	
		
			7.7 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
		
		
			
		
	
	
			352 lines
		
	
	
	
		
			7.7 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
|  | // Copyright 2013 The Go Authors. All rights reserved. | ||
|  | // Use of this source code is governed by a BSD-style | ||
|  | // license that can be found in the LICENSE file. | ||
|  | 
 | ||
|  | package ssh | ||
|  | 
 | ||
|  | import ( | ||
|  | 	"encoding/binary" | ||
|  | 	"fmt" | ||
|  | 	"io" | ||
|  | 	"log" | ||
|  | 	"sync" | ||
|  | 	"sync/atomic" | ||
|  | ) | ||
|  | 
 | ||
// debugMux switches on verbose logging of connection-protocol traffic.
// When true, outgoing mux messages, decoded incoming packets and the read
// loop's exit reason are written through the standard log package.
const debugMux = false
|  | 
 | ||
|  | // chanList is a thread safe channel list. | ||
|  | type chanList struct { | ||
|  | 	// protects concurrent access to chans | ||
|  | 	sync.Mutex | ||
|  | 
 | ||
|  | 	// chans are indexed by the local id of the channel, which the | ||
|  | 	// other side should send in the PeersId field. | ||
|  | 	chans []*channel | ||
|  | 
 | ||
|  | 	// This is a debugging aid: it offsets all IDs by this | ||
|  | 	// amount. This helps distinguish otherwise identical | ||
|  | 	// server/client muxes | ||
|  | 	offset uint32 | ||
|  | } | ||
|  | 
 | ||
|  | // Assigns a channel ID to the given channel. | ||
|  | func (c *chanList) add(ch *channel) uint32 { | ||
|  | 	c.Lock() | ||
|  | 	defer c.Unlock() | ||
|  | 	for i := range c.chans { | ||
|  | 		if c.chans[i] == nil { | ||
|  | 			c.chans[i] = ch | ||
|  | 			return uint32(i) + c.offset | ||
|  | 		} | ||
|  | 	} | ||
|  | 	c.chans = append(c.chans, ch) | ||
|  | 	return uint32(len(c.chans)-1) + c.offset | ||
|  | } | ||
|  | 
 | ||
|  | // getChan returns the channel for the given ID. | ||
|  | func (c *chanList) getChan(id uint32) *channel { | ||
|  | 	id -= c.offset | ||
|  | 
 | ||
|  | 	c.Lock() | ||
|  | 	defer c.Unlock() | ||
|  | 	if id < uint32(len(c.chans)) { | ||
|  | 		return c.chans[id] | ||
|  | 	} | ||
|  | 	return nil | ||
|  | } | ||
|  | 
 | ||
|  | func (c *chanList) remove(id uint32) { | ||
|  | 	id -= c.offset | ||
|  | 	c.Lock() | ||
|  | 	if id < uint32(len(c.chans)) { | ||
|  | 		c.chans[id] = nil | ||
|  | 	} | ||
|  | 	c.Unlock() | ||
|  | } | ||
|  | 
 | ||
|  | // dropAll forgets all channels it knows, returning them in a slice. | ||
|  | func (c *chanList) dropAll() []*channel { | ||
|  | 	c.Lock() | ||
|  | 	defer c.Unlock() | ||
|  | 	var r []*channel | ||
|  | 
 | ||
|  | 	for _, ch := range c.chans { | ||
|  | 		if ch == nil { | ||
|  | 			continue | ||
|  | 		} | ||
|  | 		r = append(r, ch) | ||
|  | 	} | ||
|  | 	c.chans = nil | ||
|  | 	return r | ||
|  | } | ||
|  | 
 | ||
|  | // mux represents the state for the SSH connection protocol, which | ||
|  | // multiplexes many channels onto a single packet transport. | ||
|  | type mux struct { | ||
|  | 	conn     packetConn | ||
|  | 	chanList chanList | ||
|  | 
 | ||
|  | 	incomingChannels chan NewChannel | ||
|  | 
 | ||
|  | 	globalSentMu     sync.Mutex | ||
|  | 	globalResponses  chan interface{} | ||
|  | 	incomingRequests chan *Request | ||
|  | 
 | ||
|  | 	errCond *sync.Cond | ||
|  | 	err     error | ||
|  | } | ||
|  | 
 | ||
// globalOff is a process-wide counter. In debug mode newMux increments it
// atomically so that each chanList receives its own ID offset.
var globalOff uint32
|  | 
 | ||
|  | func (m *mux) Wait() error { | ||
|  | 	m.errCond.L.Lock() | ||
|  | 	defer m.errCond.L.Unlock() | ||
|  | 	for m.err == nil { | ||
|  | 		m.errCond.Wait() | ||
|  | 	} | ||
|  | 	return m.err | ||
|  | } | ||
|  | 
 | ||
|  | // newMux returns a mux that runs over the given connection. | ||
|  | func newMux(p packetConn) *mux { | ||
|  | 	m := &mux{ | ||
|  | 		conn:             p, | ||
|  | 		incomingChannels: make(chan NewChannel, chanSize), | ||
|  | 		globalResponses:  make(chan interface{}, 1), | ||
|  | 		incomingRequests: make(chan *Request, chanSize), | ||
|  | 		errCond:          newCond(), | ||
|  | 	} | ||
|  | 	if debugMux { | ||
|  | 		m.chanList.offset = atomic.AddUint32(&globalOff, 1) | ||
|  | 	} | ||
|  | 
 | ||
|  | 	go m.loop() | ||
|  | 	return m | ||
|  | } | ||
|  | 
 | ||
|  | func (m *mux) sendMessage(msg interface{}) error { | ||
|  | 	p := Marshal(msg) | ||
|  | 	if debugMux { | ||
|  | 		log.Printf("send global(%d): %#v", m.chanList.offset, msg) | ||
|  | 	} | ||
|  | 	return m.conn.writePacket(p) | ||
|  | } | ||
|  | 
 | ||
|  | func (m *mux) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) { | ||
|  | 	if wantReply { | ||
|  | 		m.globalSentMu.Lock() | ||
|  | 		defer m.globalSentMu.Unlock() | ||
|  | 	} | ||
|  | 
 | ||
|  | 	if err := m.sendMessage(globalRequestMsg{ | ||
|  | 		Type:      name, | ||
|  | 		WantReply: wantReply, | ||
|  | 		Data:      payload, | ||
|  | 	}); err != nil { | ||
|  | 		return false, nil, err | ||
|  | 	} | ||
|  | 
 | ||
|  | 	if !wantReply { | ||
|  | 		return false, nil, nil | ||
|  | 	} | ||
|  | 
 | ||
|  | 	msg, ok := <-m.globalResponses | ||
|  | 	if !ok { | ||
|  | 		return false, nil, io.EOF | ||
|  | 	} | ||
|  | 	switch msg := msg.(type) { | ||
|  | 	case *globalRequestFailureMsg: | ||
|  | 		return false, msg.Data, nil | ||
|  | 	case *globalRequestSuccessMsg: | ||
|  | 		return true, msg.Data, nil | ||
|  | 	default: | ||
|  | 		return false, nil, fmt.Errorf("ssh: unexpected response to request: %#v", msg) | ||
|  | 	} | ||
|  | } | ||
|  | 
 | ||
|  | // ackRequest must be called after processing a global request that | ||
|  | // has WantReply set. | ||
|  | func (m *mux) ackRequest(ok bool, data []byte) error { | ||
|  | 	if ok { | ||
|  | 		return m.sendMessage(globalRequestSuccessMsg{Data: data}) | ||
|  | 	} | ||
|  | 	return m.sendMessage(globalRequestFailureMsg{Data: data}) | ||
|  | } | ||
|  | 
 | ||
|  | func (m *mux) Close() error { | ||
|  | 	return m.conn.Close() | ||
|  | } | ||
|  | 
 | ||
|  | // loop runs the connection machine. It will process packets until an | ||
|  | // error is encountered. To synchronize on loop exit, use mux.Wait. | ||
|  | func (m *mux) loop() { | ||
|  | 	var err error | ||
|  | 	for err == nil { | ||
|  | 		err = m.onePacket() | ||
|  | 	} | ||
|  | 
 | ||
|  | 	for _, ch := range m.chanList.dropAll() { | ||
|  | 		ch.close() | ||
|  | 	} | ||
|  | 
 | ||
|  | 	close(m.incomingChannels) | ||
|  | 	close(m.incomingRequests) | ||
|  | 	close(m.globalResponses) | ||
|  | 
 | ||
|  | 	m.conn.Close() | ||
|  | 
 | ||
|  | 	m.errCond.L.Lock() | ||
|  | 	m.err = err | ||
|  | 	m.errCond.Broadcast() | ||
|  | 	m.errCond.L.Unlock() | ||
|  | 
 | ||
|  | 	if debugMux { | ||
|  | 		log.Println("loop exit", err) | ||
|  | 	} | ||
|  | } | ||
|  | 
 | ||
|  | // onePacket reads and processes one packet. | ||
|  | func (m *mux) onePacket() error { | ||
|  | 	packet, err := m.conn.readPacket() | ||
|  | 	if err != nil { | ||
|  | 		return err | ||
|  | 	} | ||
|  | 
 | ||
|  | 	if debugMux { | ||
|  | 		if packet[0] == msgChannelData || packet[0] == msgChannelExtendedData { | ||
|  | 			log.Printf("decoding(%d): data packet - %d bytes", m.chanList.offset, len(packet)) | ||
|  | 		} else { | ||
|  | 			p, _ := decode(packet) | ||
|  | 			log.Printf("decoding(%d): %d %#v - %d bytes", m.chanList.offset, packet[0], p, len(packet)) | ||
|  | 		} | ||
|  | 	} | ||
|  | 
 | ||
|  | 	switch packet[0] { | ||
|  | 	case msgChannelOpen: | ||
|  | 		return m.handleChannelOpen(packet) | ||
|  | 	case msgGlobalRequest, msgRequestSuccess, msgRequestFailure: | ||
|  | 		return m.handleGlobalPacket(packet) | ||
|  | 	} | ||
|  | 
 | ||
|  | 	// assume a channel packet. | ||
|  | 	if len(packet) < 5 { | ||
|  | 		return parseError(packet[0]) | ||
|  | 	} | ||
|  | 	id := binary.BigEndian.Uint32(packet[1:]) | ||
|  | 	ch := m.chanList.getChan(id) | ||
|  | 	if ch == nil { | ||
|  | 		return m.handleUnknownChannelPacket(id, packet) | ||
|  | 	} | ||
|  | 
 | ||
|  | 	return ch.handlePacket(packet) | ||
|  | } | ||
|  | 
 | ||
|  | func (m *mux) handleGlobalPacket(packet []byte) error { | ||
|  | 	msg, err := decode(packet) | ||
|  | 	if err != nil { | ||
|  | 		return err | ||
|  | 	} | ||
|  | 
 | ||
|  | 	switch msg := msg.(type) { | ||
|  | 	case *globalRequestMsg: | ||
|  | 		m.incomingRequests <- &Request{ | ||
|  | 			Type:      msg.Type, | ||
|  | 			WantReply: msg.WantReply, | ||
|  | 			Payload:   msg.Data, | ||
|  | 			mux:       m, | ||
|  | 		} | ||
|  | 	case *globalRequestSuccessMsg, *globalRequestFailureMsg: | ||
|  | 		m.globalResponses <- msg | ||
|  | 	default: | ||
|  | 		panic(fmt.Sprintf("not a global message %#v", msg)) | ||
|  | 	} | ||
|  | 
 | ||
|  | 	return nil | ||
|  | } | ||
|  | 
 | ||
|  | // handleChannelOpen schedules a channel to be Accept()ed. | ||
|  | func (m *mux) handleChannelOpen(packet []byte) error { | ||
|  | 	var msg channelOpenMsg | ||
|  | 	if err := Unmarshal(packet, &msg); err != nil { | ||
|  | 		return err | ||
|  | 	} | ||
|  | 
 | ||
|  | 	if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 { | ||
|  | 		failMsg := channelOpenFailureMsg{ | ||
|  | 			PeersID:  msg.PeersID, | ||
|  | 			Reason:   ConnectionFailed, | ||
|  | 			Message:  "invalid request", | ||
|  | 			Language: "en_US.UTF-8", | ||
|  | 		} | ||
|  | 		return m.sendMessage(failMsg) | ||
|  | 	} | ||
|  | 
 | ||
|  | 	c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData) | ||
|  | 	c.remoteId = msg.PeersID | ||
|  | 	c.maxRemotePayload = msg.MaxPacketSize | ||
|  | 	c.remoteWin.add(msg.PeersWindow) | ||
|  | 	m.incomingChannels <- c | ||
|  | 	return nil | ||
|  | } | ||
|  | 
 | ||
|  | func (m *mux) OpenChannel(chanType string, extra []byte) (Channel, <-chan *Request, error) { | ||
|  | 	ch, err := m.openChannel(chanType, extra) | ||
|  | 	if err != nil { | ||
|  | 		return nil, nil, err | ||
|  | 	} | ||
|  | 
 | ||
|  | 	return ch, ch.incomingRequests, nil | ||
|  | } | ||
|  | 
 | ||
|  | func (m *mux) openChannel(chanType string, extra []byte) (*channel, error) { | ||
|  | 	ch := m.newChannel(chanType, channelOutbound, extra) | ||
|  | 
 | ||
|  | 	ch.maxIncomingPayload = channelMaxPacket | ||
|  | 
 | ||
|  | 	open := channelOpenMsg{ | ||
|  | 		ChanType:         chanType, | ||
|  | 		PeersWindow:      ch.myWindow, | ||
|  | 		MaxPacketSize:    ch.maxIncomingPayload, | ||
|  | 		TypeSpecificData: extra, | ||
|  | 		PeersID:          ch.localId, | ||
|  | 	} | ||
|  | 	if err := m.sendMessage(open); err != nil { | ||
|  | 		return nil, err | ||
|  | 	} | ||
|  | 
 | ||
|  | 	switch msg := (<-ch.msg).(type) { | ||
|  | 	case *channelOpenConfirmMsg: | ||
|  | 		return ch, nil | ||
|  | 	case *channelOpenFailureMsg: | ||
|  | 		return nil, &OpenChannelError{msg.Reason, msg.Message} | ||
|  | 	default: | ||
|  | 		return nil, fmt.Errorf("ssh: unexpected packet in response to channel open: %T", msg) | ||
|  | 	} | ||
|  | } | ||
|  | 
 | ||
|  | func (m *mux) handleUnknownChannelPacket(id uint32, packet []byte) error { | ||
|  | 	msg, err := decode(packet) | ||
|  | 	if err != nil { | ||
|  | 		return err | ||
|  | 	} | ||
|  | 
 | ||
|  | 	switch msg := msg.(type) { | ||
|  | 	// RFC 4254 section 5.4 says unrecognized channel requests should | ||
|  | 	// receive a failure response. | ||
|  | 	case *channelRequestMsg: | ||
|  | 		if msg.WantReply { | ||
|  | 			return m.sendMessage(channelRequestFailureMsg{ | ||
|  | 				PeersID: msg.PeersID, | ||
|  | 			}) | ||
|  | 		} | ||
|  | 		return nil | ||
|  | 	default: | ||
|  | 		return fmt.Errorf("ssh: invalid channel %d", id) | ||
|  | 	} | ||
|  | } |