| // Copyright 2018 The Fuchsia Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| package filter |
| |
| import ( |
| "encoding/binary" |
| "fmt" |
| "sync" |
| "sync/atomic" |
| "time" |
| |
| "syslog" |
| |
| "gvisor.dev/gvisor/pkg/tcpip" |
| "gvisor.dev/gvisor/pkg/tcpip/buffer" |
| "gvisor.dev/gvisor/pkg/tcpip/header" |
| "gvisor.dev/gvisor/pkg/tcpip/ports" |
| ) |
| |
| // Key is a key for the connection state maps. |
| type Key struct { |
| transProto tcpip.TransportProtocolNumber |
| srcAddr tcpip.Address |
| srcPort uint16 |
| dstAddr tcpip.Address |
| dstPort uint16 |
| } |
| |
| type State struct { |
| transProto tcpip.TransportProtocolNumber |
| dir Direction |
| nic tcpip.NICID |
| lanAddr tcpip.Address |
| lanPort uint16 |
| gwyAddr tcpip.Address |
| gwyPort uint16 |
| extAddr tcpip.Address |
| extPort uint16 |
| rsvdAddr tcpip.Address |
| rsvdPort uint16 |
| srcEp *endpoint |
| dstEp *endpoint |
| createTime time.Time |
| expireTime time.Time |
| packets uint32 |
| bytes uint32 |
| } |
| |
| func (s *State) String() string { |
| return fmt.Sprintf("%s:%d %s:%d %s:%d [Lo=%d Hi=%d Win=%d WS=%d] [Lo=%d Hi=%d Win=%d WS=%d] %s:%s", |
| s.lanAddr, s.lanPort, s.gwyAddr, s.gwyPort, s.extAddr, s.extPort, |
| s.srcEp.seqLo, s.srcEp.seqHi, s.srcEp.maxWin, s.srcEp.wscale, |
| s.dstEp.seqLo, s.dstEp.seqHi, s.dstEp.maxWin, s.dstEp.wscale, |
| s.srcEp.state, s.dstEp.state) |
| } |
| |
// Limits used by the TCP sequence tracking checks.
const (
	// TCPMaxAckWindow bounds how far an ACK may deviate from the peer's
	// highest sent sequence number. It is defined as slightly larger than
	// the max TCP data length (65535).
	TCPMaxAckWindow = 65535 + 1500

	// TCPMaxZWPDataLen is the max data length we expect for a TCP Zero
	// Window Probe.
	TCPMaxZWPDataLen = 1
)
| |
// State lifetimes. Each value is added to the current time to produce a new
// expireTime whenever a packet updates a state.
const (
	// ICMPExpireDefault is applied on every ICMPv4 query/reply update.
	ICMPExpireDefault = 10 * time.Second

	// UDPExpireAfterMultiple is applied once both endpoints have reached
	// UDPMultiple; UDPExpireDefault is applied otherwise.
	UDPExpireAfterMultiple = 1 * time.Minute
	UDPExpireDefault       = 20 * time.Second

	// TCPExpireAfterFinWait is applied once both endpoints are at or past
	// TCPFinWait; TCPExpireAfterClosing once either is at or past TCPClosing.
	TCPExpireAfterFinWait = 5 * time.Second
	TCPExpireAfterClosing = 5 * time.Minute
	// NOTE(review): despite the names, updateStateTCP applies
	// TCPExpireAfterEstablished while either endpoint is still below
	// TCPEstablished, and TCPExpireDefault once both are established.
	TCPExpireAfterEstablished = 30 * time.Second
	TCPExpireDefault          = 24 * time.Hour

	// ExpireIntervalMin is the delay after a purge before the next purge is
	// allowed to run.
	ExpireIntervalMin = 10 * time.Second
)
| |
// endpointState is the connection state of a single endpoint of a flow.
type endpointState int

// The state values are ordered on purpose: the tracking code compares two
// endpointState values numerically so the state logic can be written
// compactly. An endpoint's state is assumed to only ever increase during its
// lifetime, e.g.
//
//	TCPFirstPacket => TCPOpening => TCPEstablished => TCPClosing =>
//	TCPFinWait => TCPClosed
const (
	// ICMP states.
	// (TODO: consider more definitions.)
	ICMPFirstPacket endpointState = iota

	// UDP states.
	UDPFirstPacket
	UDPSingle
	UDPMultiple

	// TCP states.
	TCPFirstPacket
	TCPOpening
	TCPEstablished
	TCPClosing
	TCPFinWait
	TCPClosed
)
| |
| // endpoint maintains the current state and the sequence number information of an endpoint. |
| type endpoint struct { |
| seqLo seqnum // Max seqnum sent. |
| seqHi seqnum // Max seqnum the peer ACK'ed + win. |
| maxWin uint16 |
| wscale int8 // 0 to 14. -1 means wscale is not supported. |
| state endpointState |
| } |
| |
| func (s *State) updateStateICMPv4(dir Direction, dataLen uint16) error { |
| s.packets++ |
| s.bytes += uint32(dataLen) |
| s.expireTime = time.Now().Add(ICMPExpireDefault) |
| return nil |
| } |
| |
| func (s *State) updateStateUDP(dir Direction, dataLen uint16) error { |
| var srcEp, dstEp *endpoint |
| if dir == s.dir { |
| srcEp = s.srcEp |
| dstEp = s.dstEp |
| } else { |
| srcEp = s.dstEp |
| dstEp = s.srcEp |
| } |
| |
| s.packets++ |
| s.bytes += uint32(dataLen) |
| |
| // Update state. |
| if srcEp.state < UDPSingle { |
| srcEp.state = UDPSingle |
| } |
| if dstEp.state == UDPSingle { |
| dstEp.state = UDPMultiple |
| } |
| |
| // Update expire time. |
| if srcEp.state == UDPMultiple && dstEp.state == UDPMultiple { |
| s.expireTime = time.Now().Add(UDPExpireAfterMultiple) |
| } else { |
| s.expireTime = time.Now().Add(UDPExpireDefault) |
| } |
| |
| syslog.VLogTf(syslog.TraceVerbosity, tag, "updated state: %v", s) |
| return nil |
| } |
| |
| func (s *State) updateStateTCP(dir Direction, dataLen uint16, win uint16, seq, ack seqnum, flags uint8, wscale int) error { |
| end := seq.Add(uint32(dataLen)) |
| |
| if flags&header.TCPFlagSyn != 0 { |
| end++ |
| } |
| if flags&header.TCPFlagFin != 0 { |
| end++ |
| } |
| |
| var srcEp, dstEp *endpoint |
| if dir == s.dir { |
| srcEp = s.srcEp |
| dstEp = s.dstEp |
| } else { |
| srcEp = s.dstEp |
| dstEp = s.srcEp |
| } |
| |
| // Sequence tracking algorithm from Guido van Rooij's paper. |
| |
| if srcEp.state == TCPFirstPacket { |
| // This is the first packet from this end. Initialize the state. |
| srcEp.seqLo = end |
| srcEp.seqHi = end + TCPMaxZWPDataLen |
| srcEp.maxWin = TCPMaxZWPDataLen |
| srcEp.wscale = int8(wscale) |
| } |
| |
| sws := uint8(0) |
| dws := uint8(0) |
| if srcEp.wscale >= 0 && dstEp.wscale >= 0 && flags&header.TCPFlagSyn == 0 { |
| sws = uint8(srcEp.wscale) |
| dws = uint8(dstEp.wscale) |
| } |
| |
| if flags&header.TCPFlagAck == 0 { |
| // Pretend the ACK flag was set. |
| ack = dstEp.seqLo |
| } else if ack == 0 && |
| flags&(header.TCPFlagAck|header.TCPFlagRst) == (header.TCPFlagAck|header.TCPFlagRst) { |
| // Broken TCP stacks set the ACK flag in RST packets, but leave the ack |
| // field 0. Pretend the ACK is valid. |
| ack = dstEp.seqLo |
| } |
| |
| if seq == end { |
| // If there's no data, assume seq is valid and only look at ack below. |
| seq = srcEp.seqLo |
| end = seq |
| } |
| |
| ackskew := int32(dstEp.seqLo - ack) |
| |
| // Check the boundaries for seq and ack (See Rooij's paper): |
| // I. seq + dataLen <= srcEp.seqHi |
| // II. seq >= srcEp.seqLo - dstEp.maxWin |
| // III. ack <= dstEp.seqLo + TCPMaxAckWindow |
| // IV. ack >= dstEp.seqLo - TCPMaxAckWindow |
| if !end.LessThanEq(srcEp.seqHi) || |
| !seq.GreaterThanEq(srcEp.seqLo.Sub(uint32(dstEp.maxWin)<<dws)) || |
| ackskew < -(TCPMaxAckWindow<<dws) || |
| ackskew > (TCPMaxAckWindow<<dws) { |
| return ErrBadTCPState |
| } |
| |
| if ackskew < 0 { |
| dstEp.seqLo = ack |
| } |
| |
| s.packets++ |
| s.bytes += uint32(dataLen) |
| |
| if srcEp.maxWin < win { |
| srcEp.maxWin = win |
| } |
| if end.GreaterThan(srcEp.seqLo) { |
| srcEp.seqLo = end |
| } |
| if ack.Add(uint32(win) << sws).GreaterThanEq(dstEp.seqHi) { |
| d := uint32(win) << sws |
| if d < TCPMaxZWPDataLen { |
| d = TCPMaxZWPDataLen |
| } |
| dstEp.seqHi = ack.Add(d) |
| } |
| |
| // Update state. |
| if flags&header.TCPFlagSyn != 0 { |
| if srcEp.state < TCPOpening { |
| srcEp.state = TCPOpening |
| } |
| } |
| if flags&header.TCPFlagFin != 0 { |
| if srcEp.state < TCPClosing { |
| srcEp.state = TCPClosing |
| } |
| } |
| if flags&header.TCPFlagAck != 0 { |
| if dstEp.state == TCPOpening { |
| dstEp.state = TCPEstablished |
| } else if dstEp.state == TCPClosing { |
| dstEp.state = TCPFinWait |
| } |
| } |
| if flags&header.TCPFlagRst != 0 { |
| srcEp.state = TCPClosed |
| dstEp.state = TCPClosed |
| } |
| |
| // Update expire time. |
| if srcEp.state >= TCPFinWait && dstEp.state >= TCPFinWait { |
| s.expireTime = time.Now().Add(TCPExpireAfterFinWait) |
| } else if srcEp.state >= TCPClosing || dstEp.state >= TCPClosing { |
| s.expireTime = time.Now().Add(TCPExpireAfterClosing) |
| } else if srcEp.state < TCPEstablished || dstEp.state < TCPEstablished { |
| s.expireTime = time.Now().Add(TCPExpireAfterEstablished) |
| } else { |
| s.expireTime = time.Now().Add(TCPExpireDefault) |
| } |
| |
| if chatty { |
| syslog.VLogTf(syslog.TraceVerbosity, tag, "updated state: %v", s) |
| } |
| return nil |
| } |
| |
| // updateStateTCPinICMP is a subset of updateStateTCP, which just checks seq. |
| func (s *State) updateStateTCPinICMP(dir Direction, seq seqnum) error { |
| var srcEp, dstEp *endpoint |
| if dir == s.dir { |
| srcEp = s.srcEp |
| dstEp = s.dstEp |
| } else { |
| srcEp = s.dstEp |
| dstEp = s.srcEp |
| } |
| |
| dws := uint8(0) |
| if srcEp.wscale >= 0 && dstEp.wscale >= 0 { |
| dws = uint8(dstEp.wscale) |
| } |
| |
| if !seq.GreaterThanEq(srcEp.seqLo.Sub(uint32(dstEp.maxWin) << dws)) { |
| return ErrBadTCPState |
| } |
| return nil |
| } |
| |
| // States is a collection of State we are tracking. |
| type States struct { |
| purgeEnabled uint32 |
| |
| mut sync.Mutex |
| extToGwy map[Key]*State |
| lanToExt map[Key]*State |
| } |
| |
| func NewStates() *States { |
| ss := &States{ |
| purgeEnabled: 0, |
| extToGwy: make(map[Key]*State), |
| lanToExt: make(map[Key]*State), |
| } |
| return ss |
| } |
| |
| func (ss *States) enablePurge() { |
| atomic.StoreUint32(&ss.purgeEnabled, 1) |
| } |
| |
| func (ss *States) purgeExpiredEntries(pm *ports.PortManager) { |
| if atomic.CompareAndSwapUint32(&ss.purgeEnabled, 1, 0) { |
| defer time.AfterFunc(ExpireIntervalMin, ss.enablePurge) |
| |
| ss.mut.Lock() |
| defer ss.mut.Unlock() |
| |
| now := time.Now() |
| for k, s := range ss.extToGwy { |
| if now.After(s.expireTime) { |
| syslog.VLogTf(syslog.TraceVerbosity, tag, "delete state: %v (ExtToGwy) expire: %v now: %v", s, s.expireTime, now) |
| delete(ss.extToGwy, k) |
| } |
| if s.rsvdPort != 0 { |
| // Release the reserved port. |
| netProtos := []tcpip.NetworkProtocolNumber{header.IPv4ProtocolNumber, header.IPv6ProtocolNumber} |
| pm.ReleasePort(netProtos, s.transProto, s.rsvdAddr, s.rsvdPort, ports.Flags{}, s.nic) |
| } |
| } |
| for k, s := range ss.lanToExt { |
| if now.After(s.expireTime) { |
| syslog.VLogTf(syslog.TraceVerbosity, tag, "delete state: %v (LanToExt) expire: %v now: %v", s, s.expireTime, now) |
| delete(ss.lanToExt, k) |
| } |
| } |
| } |
| } |
| |
| func (ss *States) getState(dir Direction, transProto tcpip.TransportProtocolNumber, srcAddr tcpip.Address, srcPort uint16, dstAddr tcpip.Address, dstPort uint16) *State { |
| switch dir { |
| case Incoming: |
| return ss.extToGwy[Key{transProto, srcAddr, srcPort, dstAddr, dstPort}] |
| case Outgoing: |
| return ss.lanToExt[Key{transProto, srcAddr, srcPort, dstAddr, dstPort}] |
| default: |
| panic("unknown direction") |
| } |
| } |
| |
| func (ss *States) createState(dir Direction, nic tcpip.NICID, transProto tcpip.TransportProtocolNumber, srcAddr tcpip.Address, srcPort uint16, dstAddr tcpip.Address, dstPort uint16, origAddr tcpip.Address, origPort uint16, rsvdAddr tcpip.Address, rsvdPort uint16, isNAT bool, isRDR bool, payloadLength uint16, hdr buffer.Prependable, transportHeader []byte, payload buffer.VectorisedView) *State { |
| var srcEp, dstEp *endpoint |
| var createTime, expireTime time.Time |
| var dataLen uint16 |
| switch transProto { |
| case header.ICMPv4ProtocolNumber: |
| if len(transportHeader) < header.ICMPv4MinimumSize { |
| return nil |
| } |
| icmp := header.ICMPv4(transportHeader) |
| if isICMPv4ErrorMessage(icmp.Type()) { |
| // We don't have to create a state for this. |
| return nil |
| } |
| dataLen = payloadLength - 8 // ICMP header size is 8. |
| var id uint16 |
| if len(transportHeader) > 4 { |
| id = binary.BigEndian.Uint16(transportHeader[4:]) |
| } else if len(transportHeader)+payload.Size() > 4 { |
| id = binary.BigEndian.Uint16(payload.First()[4-len(transportHeader):]) |
| } |
| srcPort = id |
| dstPort = id |
| srcEp = &endpoint{ |
| seqLo: 0, |
| seqHi: 0, |
| maxWin: 0, |
| state: ICMPFirstPacket, |
| } |
| dstEp = &endpoint{ |
| seqLo: 0, |
| seqHi: 0, |
| maxWin: 0, |
| state: ICMPFirstPacket, |
| } |
| createTime = time.Now() |
| expireTime = createTime.Add(20 * time.Second) |
| case header.UDPProtocolNumber: |
| dataLen = payloadLength - header.UDPMinimumSize |
| srcEp = &endpoint{ |
| seqLo: 0, |
| seqHi: 0, |
| maxWin: 0, |
| state: UDPSingle, |
| } |
| dstEp = &endpoint{ |
| seqLo: 0, |
| seqHi: 0, |
| maxWin: 0, |
| state: UDPFirstPacket, |
| } |
| createTime = time.Now() |
| expireTime = createTime.Add(30 * time.Second) |
| case header.TCPProtocolNumber: |
| tcp := header.TCP(transportHeader) |
| flags := tcp.Flags() |
| dataLen = payloadLength - uint16(tcp.DataOffset()) |
| wscale := -1 |
| if flags&header.TCPFlagSyn != 0 { |
| dataLen++ |
| synOpts := header.ParseSynOptions(tcp.Options(), flags&header.TCPFlagAck != 0) |
| wscale = synOpts.WS |
| } |
| if flags&header.TCPFlagFin != 0 { |
| dataLen++ |
| } |
| seqLo := seqnum(tcp.SequenceNumber()).Add(uint32(dataLen)) |
| maxWin := tcp.WindowSize() |
| if maxWin < TCPMaxZWPDataLen { |
| maxWin = TCPMaxZWPDataLen |
| } |
| srcEp = &endpoint{ |
| seqLo: seqLo, |
| seqHi: seqLo + TCPMaxZWPDataLen, |
| maxWin: maxWin, |
| wscale: int8(wscale), |
| state: TCPOpening, |
| } |
| dstEp = &endpoint{ // Assign temporary values as we haven't seen a packet. |
| seqLo: 0, |
| seqHi: TCPMaxZWPDataLen, |
| maxWin: TCPMaxZWPDataLen, |
| wscale: -1, |
| state: TCPFirstPacket, |
| } |
| createTime = time.Now() |
| expireTime = createTime.Add(60 * time.Second) |
| } |
| |
| var lanAddr, gwyAddr, extAddr tcpip.Address |
| var lanPort, gwyPort, extPort uint16 |
| switch dir { |
| case Incoming: |
| extAddr = srcAddr |
| extPort = srcPort |
| if isRDR { |
| gwyAddr = origAddr |
| gwyPort = origPort |
| } else { |
| gwyAddr = dstAddr |
| gwyPort = dstPort |
| } |
| lanAddr = dstAddr |
| lanPort = dstPort |
| case Outgoing: |
| if isNAT { |
| lanAddr = origAddr |
| lanPort = origPort |
| } else { |
| lanAddr = srcAddr |
| lanPort = srcPort |
| } |
| gwyAddr = srcAddr |
| gwyPort = srcPort |
| extAddr = dstAddr |
| extPort = dstPort |
| } |
| |
| s := &State{ |
| transProto: transProto, |
| dir: dir, |
| lanAddr: lanAddr, |
| lanPort: lanPort, |
| gwyAddr: gwyAddr, |
| gwyPort: gwyPort, |
| extAddr: extAddr, |
| extPort: extPort, |
| rsvdAddr: rsvdAddr, |
| rsvdPort: rsvdPort, |
| srcEp: srcEp, |
| dstEp: dstEp, |
| createTime: createTime, |
| expireTime: expireTime, |
| packets: 1, |
| bytes: uint32(dataLen), |
| } |
| |
| kLanToExt := Key{transProto, lanAddr, lanPort, extAddr, extPort} |
| kExtToGwy := Key{transProto, extAddr, extPort, gwyAddr, gwyPort} |
| ss.lanToExt[kLanToExt] = s |
| ss.extToGwy[kExtToGwy] = s |
| |
| syslog.VLogTf(syslog.TraceVerbosity, tag, "new state: %v", s) |
| |
| return s |
| } |
| |
| // isICMPv4ErrorMessage returns true if t is ICMPv4 Error Message Type. |
| // It returns false if t is ICMPv4 Informational Message Type. |
| func isICMPv4ErrorMessage(t header.ICMPv4Type) bool { |
| return t == header.ICMPv4DstUnreachable || |
| t == header.ICMPv4SrcQuench || |
| t == header.ICMPv4Redirect || |
| t == header.ICMPv4TimeExceeded || |
| t == header.ICMPv4ParamProblem |
| } |
| |
| func (ss *States) findStateICMPv4(dir Direction, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, srcAddr tcpip.Address, dstAddr tcpip.Address, payloadLength uint16, transportHeader []byte, payload buffer.VectorisedView) (*State, error) { |
| if len(transportHeader) < header.ICMPv4MinimumSize { |
| return nil, ErrPacketTooShort |
| } |
| icmp := header.ICMPv4(transportHeader) |
| if isICMPv4ErrorMessage(icmp.Type()) { |
| // This ICMPv4 packet is reporting an error detected in a transport layer, and |
| // includes the IP header and the first 8 bytes of the transport header of the packet |
| // that had the error. |
| // For NAT and RDR, we have to rewrite the address and port in the packet. |
| // We also test if the values in transport header are consistent with the connection |
| // state we are tracking (but we have only the first 8 bytes and cannot do a full |
| // check). |
| if len(transportHeader) < header.IPv4MinimumSize { |
| return nil, ErrPacketTooShort |
| } |
| // First, look for an IP header at the offset 8. |
| b2 := transportHeader[8:] |
| ipv4 := header.IPv4(b2) |
| transProto2 := ipv4.TransportProtocol() |
| srcAddr = ipv4.SourceAddress() |
| dstAddr = ipv4.DestinationAddress() |
| if len(transportHeader) < int(ipv4.HeaderLength()+8) { // Need the first 8 bytes of transport header. |
| return nil, ErrPacketTooShort |
| } |
| // Here's the transport header. |
| th2 := b2[ipv4.HeaderLength():] |
| |
| switch transProto2 { |
| case header.UDPProtocolNumber: |
| udp := header.UDP(th2) |
| srcPort := udp.SourcePort() |
| dstPort := udp.DestinationPort() |
| s := ss.getState(dir, transProto2, srcAddr, srcPort, dstAddr, dstPort) |
| // There is nothing we can use in the first 8 bytes to update the state. |
| return s, nil |
| case header.TCPProtocolNumber: |
| tcp := header.TCP(th2) |
| srcPort := tcp.SourcePort() |
| dstPort := tcp.DestinationPort() |
| s := ss.getState(dir, transProto2, srcAddr, srcPort, dstAddr, dstPort) |
| if s == nil { |
| return nil, nil |
| } |
| // We have the first 8 bytes of the TCP header only, which means |
| // we just use seq to test the state, |
| seq := seqnum(tcp.SequenceNumber()) |
| err := s.updateStateTCPinICMP(dir, seq) |
| if err != nil { |
| return nil, err |
| } |
| return s, nil |
| default: |
| return nil, ErrUnknownProtocol |
| } |
| } else { |
| dataLen := payloadLength - 8 // ICMP header size is 8. |
| |
| var id uint16 |
| // ICMP query/reply message. |
| if len(transportHeader) > 4 { |
| id = binary.BigEndian.Uint16(transportHeader[4:]) |
| } else if len(transportHeader)+payload.Size() > 4 { |
| id = binary.BigEndian.Uint16(payload.First()[4-len(transportHeader):]) |
| } |
| s := ss.getState(dir, transProto, srcAddr, id, dstAddr, id) |
| if s == nil { |
| return nil, nil |
| } |
| err := s.updateStateICMPv4(dir, dataLen) |
| if err != nil { |
| return nil, err |
| } |
| return s, nil |
| } |
| } |
| |
| func (ss *States) findStateUDP(dir Direction, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, srcAddr tcpip.Address, srcPort uint16, dstAddr tcpip.Address, dstPort uint16, payloadLength uint16, transportHeader []byte) (*State, error) { |
| s := ss.getState(dir, transProto, srcAddr, srcPort, dstAddr, dstPort) |
| if s == nil { |
| return nil, nil |
| } |
| |
| dataLen := payloadLength - header.UDPMinimumSize |
| |
| err := s.updateStateUDP(dir, dataLen) |
| if err != nil { |
| return nil, err |
| } |
| return s, nil |
| } |
| |
| func (ss *States) findStateTCP(dir Direction, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, srcAddr tcpip.Address, srcPort uint16, dstAddr tcpip.Address, dstPort uint16, payloadLength uint16, transportHeader []byte) (*State, error) { |
| s := ss.getState(dir, transProto, srcAddr, srcPort, dstAddr, dstPort) |
| if s == nil { |
| return nil, nil |
| } |
| tcp := header.TCP(transportHeader) |
| dataLen := payloadLength - uint16(tcp.DataOffset()) |
| win := tcp.WindowSize() |
| seq := seqnum(tcp.SequenceNumber()) |
| ack := seqnum(tcp.AckNumber()) |
| flags := tcp.Flags() |
| wscale := -1 |
| if flags&header.TCPFlagSyn != 0 { |
| synOpts := header.ParseSynOptions(tcp.Options(), flags&header.TCPFlagAck != 0) |
| wscale = synOpts.WS |
| } |
| |
| err := s.updateStateTCP(dir, dataLen, win, seq, ack, flags, wscale) |
| if err != nil { |
| return nil, err |
| } |
| return s, nil |
| } |