| 
									
										
										
										
											2021-08-12 21:03:24 +02:00
										 |  |  | // Copyright 2013 The Go Authors. All rights reserved. | 
					
						
							|  |  |  | // Use of this source code is governed by a BSD-style | 
					
						
							|  |  |  | // license that can be found in the LICENSE file. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | package ssh | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import ( | 
					
						
							|  |  |  | 	"encoding/binary" | 
					
						
							|  |  |  | 	"fmt" | 
					
						
							|  |  |  | 	"io" | 
					
						
							|  |  |  | 	"log" | 
					
						
							|  |  |  | 	"sync" | 
					
						
							|  |  |  | 	"sync/atomic" | 
					
						
							|  |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // debugMux, if set, causes messages in the connection protocol to be | 
					
						
							|  |  |  | // logged. | 
					
						
							|  |  |  | const debugMux = false | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // chanList is a thread safe channel list. | 
					
						
							|  |  |  | type chanList struct { | 
					
						
							|  |  |  | 	// protects concurrent access to chans | 
					
						
							|  |  |  | 	sync.Mutex | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// chans are indexed by the local id of the channel, which the | 
					
						
							|  |  |  | 	// other side should send in the PeersId field. | 
					
						
							|  |  |  | 	chans []*channel | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// This is a debugging aid: it offsets all IDs by this | 
					
						
							|  |  |  | 	// amount. This helps distinguish otherwise identical | 
					
						
							|  |  |  | 	// server/client muxes | 
					
						
							|  |  |  | 	offset uint32 | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // Assigns a channel ID to the given channel. | 
					
						
							|  |  |  | func (c *chanList) add(ch *channel) uint32 { | 
					
						
							|  |  |  | 	c.Lock() | 
					
						
							|  |  |  | 	defer c.Unlock() | 
					
						
							|  |  |  | 	for i := range c.chans { | 
					
						
							|  |  |  | 		if c.chans[i] == nil { | 
					
						
							|  |  |  | 			c.chans[i] = ch | 
					
						
							|  |  |  | 			return uint32(i) + c.offset | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	c.chans = append(c.chans, ch) | 
					
						
							|  |  |  | 	return uint32(len(c.chans)-1) + c.offset | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // getChan returns the channel for the given ID. | 
					
						
							|  |  |  | func (c *chanList) getChan(id uint32) *channel { | 
					
						
							|  |  |  | 	id -= c.offset | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	c.Lock() | 
					
						
							|  |  |  | 	defer c.Unlock() | 
					
						
							|  |  |  | 	if id < uint32(len(c.chans)) { | 
					
						
							|  |  |  | 		return c.chans[id] | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (c *chanList) remove(id uint32) { | 
					
						
							|  |  |  | 	id -= c.offset | 
					
						
							|  |  |  | 	c.Lock() | 
					
						
							|  |  |  | 	if id < uint32(len(c.chans)) { | 
					
						
							|  |  |  | 		c.chans[id] = nil | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	c.Unlock() | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // dropAll forgets all channels it knows, returning them in a slice. | 
					
						
							|  |  |  | func (c *chanList) dropAll() []*channel { | 
					
						
							|  |  |  | 	c.Lock() | 
					
						
							|  |  |  | 	defer c.Unlock() | 
					
						
							|  |  |  | 	var r []*channel | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	for _, ch := range c.chans { | 
					
						
							|  |  |  | 		if ch == nil { | 
					
						
							|  |  |  | 			continue | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		r = append(r, ch) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	c.chans = nil | 
					
						
							|  |  |  | 	return r | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // mux represents the state for the SSH connection protocol, which | 
					
						
							|  |  |  | // multiplexes many channels onto a single packet transport. | 
					
						
							|  |  |  | type mux struct { | 
					
						
							|  |  |  | 	conn     packetConn | 
					
						
							|  |  |  | 	chanList chanList | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	incomingChannels chan NewChannel | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	globalSentMu     sync.Mutex | 
					
						
							|  |  |  | 	globalResponses  chan interface{} | 
					
						
							|  |  |  | 	incomingRequests chan *Request | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	errCond *sync.Cond | 
					
						
							|  |  |  | 	err     error | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // When debugging, each new chanList instantiation has a different | 
					
						
							|  |  |  | // offset. | 
					
						
							|  |  |  | var globalOff uint32 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (m *mux) Wait() error { | 
					
						
							|  |  |  | 	m.errCond.L.Lock() | 
					
						
							|  |  |  | 	defer m.errCond.L.Unlock() | 
					
						
							|  |  |  | 	for m.err == nil { | 
					
						
							|  |  |  | 		m.errCond.Wait() | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return m.err | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // newMux returns a mux that runs over the given connection. | 
					
						
							|  |  |  | func newMux(p packetConn) *mux { | 
					
						
							|  |  |  | 	m := &mux{ | 
					
						
							|  |  |  | 		conn:             p, | 
					
						
							|  |  |  | 		incomingChannels: make(chan NewChannel, chanSize), | 
					
						
							|  |  |  | 		globalResponses:  make(chan interface{}, 1), | 
					
						
							|  |  |  | 		incomingRequests: make(chan *Request, chanSize), | 
					
						
							|  |  |  | 		errCond:          newCond(), | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	if debugMux { | 
					
						
							|  |  |  | 		m.chanList.offset = atomic.AddUint32(&globalOff, 1) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	go m.loop() | 
					
						
							|  |  |  | 	return m | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (m *mux) sendMessage(msg interface{}) error { | 
					
						
							|  |  |  | 	p := Marshal(msg) | 
					
						
							|  |  |  | 	if debugMux { | 
					
						
							|  |  |  | 		log.Printf("send global(%d): %#v", m.chanList.offset, msg) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return m.conn.writePacket(p) | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (m *mux) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) { | 
					
						
							|  |  |  | 	if wantReply { | 
					
						
							|  |  |  | 		m.globalSentMu.Lock() | 
					
						
							|  |  |  | 		defer m.globalSentMu.Unlock() | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if err := m.sendMessage(globalRequestMsg{ | 
					
						
							|  |  |  | 		Type:      name, | 
					
						
							|  |  |  | 		WantReply: wantReply, | 
					
						
							|  |  |  | 		Data:      payload, | 
					
						
							|  |  |  | 	}); err != nil { | 
					
						
							|  |  |  | 		return false, nil, err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if !wantReply { | 
					
						
							|  |  |  | 		return false, nil, nil | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	msg, ok := <-m.globalResponses | 
					
						
							|  |  |  | 	if !ok { | 
					
						
							|  |  |  | 		return false, nil, io.EOF | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	switch msg := msg.(type) { | 
					
						
							|  |  |  | 	case *globalRequestFailureMsg: | 
					
						
							|  |  |  | 		return false, msg.Data, nil | 
					
						
							|  |  |  | 	case *globalRequestSuccessMsg: | 
					
						
							|  |  |  | 		return true, msg.Data, nil | 
					
						
							|  |  |  | 	default: | 
					
						
							|  |  |  | 		return false, nil, fmt.Errorf("ssh: unexpected response to request: %#v", msg) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // ackRequest must be called after processing a global request that | 
					
						
							|  |  |  | // has WantReply set. | 
					
						
							|  |  |  | func (m *mux) ackRequest(ok bool, data []byte) error { | 
					
						
							|  |  |  | 	if ok { | 
					
						
							|  |  |  | 		return m.sendMessage(globalRequestSuccessMsg{Data: data}) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return m.sendMessage(globalRequestFailureMsg{Data: data}) | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (m *mux) Close() error { | 
					
						
							|  |  |  | 	return m.conn.Close() | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // loop runs the connection machine. It will process packets until an | 
					
						
							|  |  |  | // error is encountered. To synchronize on loop exit, use mux.Wait. | 
					
						
							|  |  |  | func (m *mux) loop() { | 
					
						
							|  |  |  | 	var err error | 
					
						
							|  |  |  | 	for err == nil { | 
					
						
							|  |  |  | 		err = m.onePacket() | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	for _, ch := range m.chanList.dropAll() { | 
					
						
							|  |  |  | 		ch.close() | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	close(m.incomingChannels) | 
					
						
							|  |  |  | 	close(m.incomingRequests) | 
					
						
							|  |  |  | 	close(m.globalResponses) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	m.conn.Close() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	m.errCond.L.Lock() | 
					
						
							|  |  |  | 	m.err = err | 
					
						
							|  |  |  | 	m.errCond.Broadcast() | 
					
						
							|  |  |  | 	m.errCond.L.Unlock() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if debugMux { | 
					
						
							|  |  |  | 		log.Println("loop exit", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // onePacket reads and processes one packet. | 
					
						
							|  |  |  | func (m *mux) onePacket() error { | 
					
						
							|  |  |  | 	packet, err := m.conn.readPacket() | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if debugMux { | 
					
						
							|  |  |  | 		if packet[0] == msgChannelData || packet[0] == msgChannelExtendedData { | 
					
						
							|  |  |  | 			log.Printf("decoding(%d): data packet - %d bytes", m.chanList.offset, len(packet)) | 
					
						
							|  |  |  | 		} else { | 
					
						
							|  |  |  | 			p, _ := decode(packet) | 
					
						
							|  |  |  | 			log.Printf("decoding(%d): %d %#v - %d bytes", m.chanList.offset, packet[0], p, len(packet)) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	switch packet[0] { | 
					
						
							|  |  |  | 	case msgChannelOpen: | 
					
						
							|  |  |  | 		return m.handleChannelOpen(packet) | 
					
						
							|  |  |  | 	case msgGlobalRequest, msgRequestSuccess, msgRequestFailure: | 
					
						
							|  |  |  | 		return m.handleGlobalPacket(packet) | 
					
						
							| 
									
										
										
										
											2023-10-09 10:13:09 +02:00
										 |  |  | 	case msgPing: | 
					
						
							|  |  |  | 		var msg pingMsg | 
					
						
							|  |  |  | 		if err := Unmarshal(packet, &msg); err != nil { | 
					
						
							|  |  |  | 			return fmt.Errorf("failed to unmarshal ping@openssh.com message: %w", err) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		return m.sendMessage(pongMsg(msg)) | 
					
						
							| 
									
										
										
										
											2021-08-12 21:03:24 +02:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// assume a channel packet. | 
					
						
							|  |  |  | 	if len(packet) < 5 { | 
					
						
							|  |  |  | 		return parseError(packet[0]) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	id := binary.BigEndian.Uint32(packet[1:]) | 
					
						
							|  |  |  | 	ch := m.chanList.getChan(id) | 
					
						
							|  |  |  | 	if ch == nil { | 
					
						
							|  |  |  | 		return m.handleUnknownChannelPacket(id, packet) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return ch.handlePacket(packet) | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (m *mux) handleGlobalPacket(packet []byte) error { | 
					
						
							|  |  |  | 	msg, err := decode(packet) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	switch msg := msg.(type) { | 
					
						
							|  |  |  | 	case *globalRequestMsg: | 
					
						
							|  |  |  | 		m.incomingRequests <- &Request{ | 
					
						
							|  |  |  | 			Type:      msg.Type, | 
					
						
							|  |  |  | 			WantReply: msg.WantReply, | 
					
						
							|  |  |  | 			Payload:   msg.Data, | 
					
						
							|  |  |  | 			mux:       m, | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	case *globalRequestSuccessMsg, *globalRequestFailureMsg: | 
					
						
							|  |  |  | 		m.globalResponses <- msg | 
					
						
							|  |  |  | 	default: | 
					
						
							|  |  |  | 		panic(fmt.Sprintf("not a global message %#v", msg)) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // handleChannelOpen schedules a channel to be Accept()ed. | 
					
						
							|  |  |  | func (m *mux) handleChannelOpen(packet []byte) error { | 
					
						
							|  |  |  | 	var msg channelOpenMsg | 
					
						
							|  |  |  | 	if err := Unmarshal(packet, &msg); err != nil { | 
					
						
							|  |  |  | 		return err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 { | 
					
						
							|  |  |  | 		failMsg := channelOpenFailureMsg{ | 
					
						
							|  |  |  | 			PeersID:  msg.PeersID, | 
					
						
							|  |  |  | 			Reason:   ConnectionFailed, | 
					
						
							|  |  |  | 			Message:  "invalid request", | 
					
						
							|  |  |  | 			Language: "en_US.UTF-8", | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		return m.sendMessage(failMsg) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData) | 
					
						
							|  |  |  | 	c.remoteId = msg.PeersID | 
					
						
							|  |  |  | 	c.maxRemotePayload = msg.MaxPacketSize | 
					
						
							|  |  |  | 	c.remoteWin.add(msg.PeersWindow) | 
					
						
							|  |  |  | 	m.incomingChannels <- c | 
					
						
							|  |  |  | 	return nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (m *mux) OpenChannel(chanType string, extra []byte) (Channel, <-chan *Request, error) { | 
					
						
							|  |  |  | 	ch, err := m.openChannel(chanType, extra) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return nil, nil, err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return ch, ch.incomingRequests, nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (m *mux) openChannel(chanType string, extra []byte) (*channel, error) { | 
					
						
							|  |  |  | 	ch := m.newChannel(chanType, channelOutbound, extra) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	ch.maxIncomingPayload = channelMaxPacket | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	open := channelOpenMsg{ | 
					
						
							|  |  |  | 		ChanType:         chanType, | 
					
						
							|  |  |  | 		PeersWindow:      ch.myWindow, | 
					
						
							|  |  |  | 		MaxPacketSize:    ch.maxIncomingPayload, | 
					
						
							|  |  |  | 		TypeSpecificData: extra, | 
					
						
							|  |  |  | 		PeersID:          ch.localId, | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	if err := m.sendMessage(open); err != nil { | 
					
						
							|  |  |  | 		return nil, err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	switch msg := (<-ch.msg).(type) { | 
					
						
							|  |  |  | 	case *channelOpenConfirmMsg: | 
					
						
							|  |  |  | 		return ch, nil | 
					
						
							|  |  |  | 	case *channelOpenFailureMsg: | 
					
						
							|  |  |  | 		return nil, &OpenChannelError{msg.Reason, msg.Message} | 
					
						
							|  |  |  | 	default: | 
					
						
							|  |  |  | 		return nil, fmt.Errorf("ssh: unexpected packet in response to channel open: %T", msg) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (m *mux) handleUnknownChannelPacket(id uint32, packet []byte) error { | 
					
						
							|  |  |  | 	msg, err := decode(packet) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	switch msg := msg.(type) { | 
					
						
							|  |  |  | 	// RFC 4254 section 5.4 says unrecognized channel requests should | 
					
						
							|  |  |  | 	// receive a failure response. | 
					
						
							|  |  |  | 	case *channelRequestMsg: | 
					
						
							|  |  |  | 		if msg.WantReply { | 
					
						
							|  |  |  | 			return m.sendMessage(channelRequestFailureMsg{ | 
					
						
							|  |  |  | 				PeersID: msg.PeersID, | 
					
						
							|  |  |  | 			}) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		return nil | 
					
						
							|  |  |  | 	default: | 
					
						
							|  |  |  | 		return fmt.Errorf("ssh: invalid channel %d", id) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } |