| // Copyright 2018 The Fuchsia Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| package filter |
| |
| import ( |
| "encoding/binary" |
| "fmt" |
| "log" |
| "sync" |
| "sync/atomic" |
| "time" |
| |
| "github.com/google/netstack/tcpip" |
| "github.com/google/netstack/tcpip/buffer" |
| "github.com/google/netstack/tcpip/header" |
| "github.com/google/netstack/tcpip/ports" |
| ) |
| |
| // Key is a key for the connection state maps. |
| type Key struct { |
| transProto tcpip.TransportProtocolNumber |
| srcAddr tcpip.Address |
| srcPort uint16 |
| dstAddr tcpip.Address |
| dstPort uint16 |
| } |
| |
| type State struct { |
| transProto tcpip.TransportProtocolNumber |
| dir Direction |
| lanAddr tcpip.Address |
| lanPort uint16 |
| gwyAddr tcpip.Address |
| gwyPort uint16 |
| extAddr tcpip.Address |
| extPort uint16 |
| rsvdAddr tcpip.Address |
| rsvdPort uint16 |
| srcEp *Endpoint |
| dstEp *Endpoint |
| createTime time.Time |
| expireTime time.Time |
| packets uint32 |
| bytes uint32 |
| } |
| |
| func (s *State) String() string { |
| return fmt.Sprintf("%s:%d %s:%d %s:%d [Lo=%d Hi=%d Win=%d WS=%d] [Lo=%d Hi=%d Win=%d WS=%d] %s:%s", |
| s.lanAddr, s.lanPort, s.gwyAddr, s.gwyPort, s.extAddr, s.extPort, |
| s.srcEp.seqLo, s.srcEp.seqHi, s.srcEp.maxWin, s.srcEp.wscale, |
| s.dstEp.seqLo, s.dstEp.seqHi, s.dstEp.maxWin, s.dstEp.wscale, |
| s.srcEp.state, s.dstEp.state) |
| } |
| |
| const ( |
| TCPMaxAckWindow = 65535 + 1500 // Defined as slightly larger than the max TCP data length (65535). |
| TCPMaxZWPDataLen = 1 // Max data length we expect for TCP Zero Window Probe. |
| ) |
| |
| const ( |
| ICMPExpireDefault = 10 * time.Second |
| |
| UDPExpireAfterMultiple = 1 * time.Minute |
| UDPExpireDefault = 20 * time.Second |
| |
| TCPExpireAfterFinWait = 5 * time.Second |
| TCPExpireAfterClosing = 5 * time.Minute |
| TCPExpireAfterEstablished = 30 * time.Second |
| TCPExpireDefault = 24 * time.Hour |
| |
| ExpireIntervalMin = 10 * time.Second |
| ) |
| |
| // EndpointState represents the connection state of an Endpoint. |
| type EndpointState int |
| |
| const ( |
| // Note that we currently allow numeric comparison between two |
| // EndpointStates so that the state related logic can be described |
| // compactly. We assume an EndpointState's numeric value will only |
| // increase monotonically during the lifetime of the endpoint |
| // (e.g. TCPFirstPacket => TCPOpening => TCPEstablished => TCPClosing => |
| // TCPFinWait => TCPClosed). |
| |
| // ICMP states. |
| // (TODO: consider more definitions.) |
| ICMPFirstPacket EndpointState = iota |
| |
| // UDP states. |
| UDPFirstPacket |
| UDPSingle |
| UDPMultiple |
| |
| // TCP states. |
| TCPFirstPacket |
| TCPOpening |
| TCPEstablished |
| TCPClosing |
| TCPFinWait |
| TCPClosed |
| ) |
| |
| func (state EndpointState) String() string { |
| switch state { |
| case ICMPFirstPacket: |
| return "ICMPFirstPacket" |
| case UDPFirstPacket: |
| return "UDPFirstPacket" |
| case UDPSingle: |
| return "UDPSingle" |
| case UDPMultiple: |
| return "UDPMultiple" |
| case TCPFirstPacket: |
| return "TCPFirstPacket" |
| case TCPOpening: |
| return "TCPOpening" |
| case TCPEstablished: |
| return "TCPEstablished" |
| case TCPClosing: |
| return "TCPClosing" |
| case TCPFinWait: |
| return "TCPFinWait" |
| case TCPClosed: |
| return "TCPClosed" |
| default: |
| panic("Unknown state") |
| } |
| } |
| |
| // Endpoint maintains the current state and the sequence number information of an endpoint. |
| type Endpoint struct { |
| seqLo seqnum // Max seqnum sent. |
| seqHi seqnum // Max seqnum the peer ACK'ed + win. |
| maxWin uint16 |
| wscale int8 // 0 to 14. -1 means wscale is not supported. |
| state EndpointState |
| } |
| |
| func (s *State) updateStateICMPv4(dir Direction, dataLen uint16) error { |
| s.packets++ |
| s.bytes += uint32(dataLen) |
| s.expireTime = time.Now().Add(ICMPExpireDefault) |
| return nil |
| } |
| |
| func (s *State) updateStateUDP(dir Direction, dataLen uint16) error { |
| var srcEp, dstEp *Endpoint |
| if dir == s.dir { |
| srcEp = s.srcEp |
| dstEp = s.dstEp |
| } else { |
| srcEp = s.dstEp |
| dstEp = s.srcEp |
| } |
| |
| s.packets++ |
| s.bytes += uint32(dataLen) |
| |
| // Update state. |
| if srcEp.state < UDPSingle { |
| srcEp.state = UDPSingle |
| } |
| if dstEp.state == UDPSingle { |
| dstEp.state = UDPMultiple |
| } |
| |
| // Update expire time. |
| if srcEp.state == UDPMultiple && dstEp.state == UDPMultiple { |
| s.expireTime = time.Now().Add(UDPExpireAfterMultiple) |
| } else { |
| s.expireTime = time.Now().Add(UDPExpireDefault) |
| } |
| |
| if debug { |
| log.Printf("packet filter: updated state: %v", s) |
| } |
| |
| return nil |
| } |
| |
| func (s *State) updateStateTCP(dir Direction, dataLen uint16, win uint16, seq, ack seqnum, flags uint8, wscale int) error { |
| end := seq.Add(uint32(dataLen)) |
| |
| if flags&header.TCPFlagSyn != 0 { |
| end++ |
| } |
| if flags&header.TCPFlagFin != 0 { |
| end++ |
| } |
| |
| var srcEp, dstEp *Endpoint |
| if dir == s.dir { |
| srcEp = s.srcEp |
| dstEp = s.dstEp |
| } else { |
| srcEp = s.dstEp |
| dstEp = s.srcEp |
| } |
| |
| // Sequence tracking algorithm from Guido van Rooij's paper. |
| |
| if srcEp.state == TCPFirstPacket { |
| // This is the first packet from this end. Initialize the state. |
| srcEp.seqLo = end |
| srcEp.seqHi = end + TCPMaxZWPDataLen |
| srcEp.maxWin = TCPMaxZWPDataLen |
| srcEp.wscale = int8(wscale) |
| } |
| |
| sws := uint8(0) |
| dws := uint8(0) |
| if srcEp.wscale >= 0 && dstEp.wscale >= 0 && flags&header.TCPFlagSyn == 0 { |
| sws = uint8(srcEp.wscale) |
| dws = uint8(dstEp.wscale) |
| } |
| |
| if flags&header.TCPFlagAck == 0 { |
| // Pretend the ACK flag was set. |
| ack = dstEp.seqLo |
| } else if ack == 0 && |
| flags&(header.TCPFlagAck|header.TCPFlagRst) == (header.TCPFlagAck|header.TCPFlagRst) { |
| // Broken TCP stacks set the ACK flag in RST packets, but leave the ack |
| // field 0. Pretend the ACK is valid. |
| ack = dstEp.seqLo |
| } |
| |
| if seq == end { |
| // If there's no data, assume seq is valid and only look at ack below. |
| seq = srcEp.seqLo |
| end = seq |
| } |
| |
| ackskew := int32(dstEp.seqLo - ack) |
| |
| // Check the boundaries for seq and ack (See Rooij's paper): |
| // I. seq + dataLen <= srcEp.seqHi |
| // II. seq >= srcEp.seqLo - dstEp.maxWin |
| // III. ack <= dstEp.seqLo + TCPMaxAckWindow |
| // IV. ack >= dstEp.seqLo - TCPMaxAckWindow |
| if !end.LessThanEq(srcEp.seqHi) || |
| !seq.GreaterThanEq(srcEp.seqLo.Sub(uint32(dstEp.maxWin)<<dws)) || |
| ackskew < -TCPMaxAckWindow || |
| ackskew > (TCPMaxAckWindow<<dws) { |
| return ErrBadTCPState |
| } |
| |
| if ackskew < 0 { |
| dstEp.seqLo = ack |
| } |
| |
| s.packets++ |
| s.bytes += uint32(dataLen) |
| |
| if srcEp.maxWin < win { |
| srcEp.maxWin = win |
| } |
| if end.GreaterThan(srcEp.seqLo) { |
| srcEp.seqLo = end |
| } |
| if ack.Add(uint32(win) << sws).GreaterThanEq(dstEp.seqHi) { |
| d := uint32(win) << sws |
| if d < TCPMaxZWPDataLen { |
| d = TCPMaxZWPDataLen |
| } |
| dstEp.seqHi = ack.Add(d) |
| } |
| |
| // Update state. |
| if flags&header.TCPFlagSyn != 0 { |
| if srcEp.state < TCPOpening { |
| srcEp.state = TCPOpening |
| } |
| } |
| if flags&header.TCPFlagFin != 0 { |
| if srcEp.state < TCPClosing { |
| srcEp.state = TCPClosing |
| } |
| } |
| if flags&header.TCPFlagAck != 0 { |
| if dstEp.state == TCPOpening { |
| dstEp.state = TCPEstablished |
| } else if dstEp.state == TCPClosing { |
| dstEp.state = TCPFinWait |
| } |
| } |
| if flags&header.TCPFlagRst != 0 { |
| srcEp.state = TCPClosed |
| dstEp.state = TCPClosed |
| } |
| |
| // Update expire time. |
| if srcEp.state >= TCPFinWait && dstEp.state >= TCPFinWait { |
| s.expireTime = time.Now().Add(TCPExpireAfterFinWait) |
| } else if srcEp.state >= TCPClosing || dstEp.state >= TCPClosing { |
| s.expireTime = time.Now().Add(TCPExpireAfterClosing) |
| } else if srcEp.state < TCPEstablished || dstEp.state < TCPEstablished { |
| s.expireTime = time.Now().Add(TCPExpireAfterEstablished) |
| } else { |
| s.expireTime = time.Now().Add(TCPExpireDefault) |
| } |
| return nil |
| } |
| |
| // updateStateTCPinICMP is a subset of updateStateTCP, which just checks seq. |
| func (s *State) updateStateTCPinICMP(dir Direction, seq seqnum) error { |
| var srcEp, dstEp *Endpoint |
| if dir == s.dir { |
| srcEp = s.srcEp |
| dstEp = s.dstEp |
| } else { |
| srcEp = s.dstEp |
| dstEp = s.srcEp |
| } |
| |
| dws := uint8(0) |
| if srcEp.wscale >= 0 && dstEp.wscale >= 0 { |
| dws = uint8(dstEp.wscale) |
| } |
| |
| if !seq.GreaterThanEq(srcEp.seqLo.Sub(uint32(dstEp.maxWin) << dws)) { |
| return ErrBadTCPState |
| } |
| return nil |
| } |
| |
| type StatesLockKey tcpip.FullAddress |
| |
| func makeStatesLockKey(dir Direction, srcAddr tcpip.Address, dstAddr tcpip.Address, transProto tcpip.TransportProtocolNumber, transportHeader []byte) StatesLockKey { |
| var srcPort, dstPort uint16 |
| switch transProto { |
| case header.UDPProtocolNumber: |
| udp := header.UDP(transportHeader) |
| srcPort = udp.SourcePort() |
| dstPort = udp.DestinationPort() |
| case header.TCPProtocolNumber: |
| tcp := header.TCP(transportHeader) |
| srcPort = tcp.SourcePort() |
| dstPort = tcp.DestinationPort() |
| } |
| var extPort uint16 |
| var extAddr tcpip.Address |
| switch dir { |
| case Incoming: |
| extPort = srcPort |
| extAddr = srcAddr |
| case Outgoing: |
| extPort = dstPort |
| extAddr = dstAddr |
| } |
| |
| return StatesLockKey{Addr: extAddr, Port: extPort} |
| } |
| |
| // States is a collection of State we are tracking. |
| type States struct { |
| purgeEnabled uint32 |
| |
| mut sync.RWMutex // Guards access to lockKeyToMut below |
| lockKeyToMut map[StatesLockKey]*sync.Mutex // Guards access to individual maps below |
| |
| extToGwy map[Key]*State |
| lanToExt map[Key]*State |
| } |
| |
| func NewStates() *States { |
| ss := &States{ |
| purgeEnabled: 0, |
| lockKeyToMut: make(map[StatesLockKey]*sync.Mutex), |
| extToGwy: make(map[Key]*State), |
| lanToExt: make(map[Key]*State), |
| } |
| return ss |
| } |
| |
| func (ss *States) lock(lockKey StatesLockKey) { |
| ss.mut.RLock() |
| |
| mu, ok := ss.lockKeyToMut[lockKey] |
| if ok { |
| mu.Lock() |
| return |
| } |
| |
| ss.mut.RUnlock() |
| ss.mut.Lock() |
| |
| mu, ok = ss.lockKeyToMut[lockKey] |
| if !ok { |
| mu = &sync.Mutex{} |
| ss.lockKeyToMut[lockKey] = mu |
| } |
| ss.mut.Unlock() |
| |
| ss.mut.RLock() |
| mu.Lock() |
| } |
| |
| func (ss *States) unlock(lockKey StatesLockKey) { |
| ss.lockKeyToMut[lockKey].Unlock() |
| ss.mut.RUnlock() |
| } |
| |
| func (ss *States) enablePurge() { |
| atomic.StoreUint32(&ss.purgeEnabled, 1) |
| } |
| |
| func (ss *States) purgeExpiredEntries(pm *ports.PortManager) { |
| if atomic.CompareAndSwapUint32(&ss.purgeEnabled, 1, 0) { |
| defer time.AfterFunc(ExpireIntervalMin, ss.enablePurge) |
| |
| ss.mut.Lock() |
| defer ss.mut.Unlock() |
| |
| now := time.Now() |
| for k, s := range ss.extToGwy { |
| if s.expireTime.After(now) { |
| if debug { |
| log.Printf("packet filter: delete state: %v (ExtToGwy)", s) |
| } |
| delete(ss.lockKeyToMut, StatesLockKey{Addr: k.srcAddr, Port: k.srcPort}) |
| delete(ss.extToGwy, k) |
| } |
| if s.rsvdPort != 0 { |
| // Release the reserved port. |
| netProtos := []tcpip.NetworkProtocolNumber{header.IPv4ProtocolNumber, header.IPv6ProtocolNumber} |
| pm.ReleasePort(netProtos, s.transProto, s.rsvdAddr, s.rsvdPort) |
| } |
| } |
| for k, s := range ss.lanToExt { |
| if s.expireTime.After(now) { |
| if debug { |
| log.Printf("packet filter: delete state: %v (LanToExt)", s) |
| } |
| delete(ss.lanToExt, k) |
| } |
| } |
| } |
| } |
| |
| func (ss *States) getState(dir Direction, transProto tcpip.TransportProtocolNumber, srcAddr tcpip.Address, srcPort uint16, dstAddr tcpip.Address, dstPort uint16) *State { |
| switch dir { |
| case Incoming: |
| return ss.extToGwy[Key{transProto, srcAddr, srcPort, dstAddr, dstPort}] |
| case Outgoing: |
| return ss.lanToExt[Key{transProto, srcAddr, srcPort, dstAddr, dstPort}] |
| default: |
| panic("unknown direction") |
| } |
| } |
| |
| func (ss *States) createState(dir Direction, transProto tcpip.TransportProtocolNumber, srcAddr tcpip.Address, srcPort uint16, dstAddr tcpip.Address, dstPort uint16, origAddr tcpip.Address, origPort uint16, rsvdAddr tcpip.Address, rsvdPort uint16, isNAT bool, isRDR bool, payloadLength uint16, hdr buffer.Prependable, transportHeader []byte, payload buffer.VectorisedView) *State { |
| var srcEp, dstEp *Endpoint |
| var createTime, expireTime time.Time |
| var dataLen uint16 |
| switch transProto { |
| case header.ICMPv4ProtocolNumber: |
| if len(transportHeader) < header.ICMPv4MinimumSize { |
| return nil |
| } |
| icmp := header.ICMPv4(transportHeader) |
| if isICMPv4ErrorMessage(icmp.Type()) { |
| // We don't have to create a state for this. |
| return nil |
| } |
| dataLen = payloadLength - 8 // ICMP header size is 8. |
| var id uint16 |
| if len(transportHeader) > 4 { |
| id = binary.BigEndian.Uint16(transportHeader[4:]) |
| } else if len(transportHeader)+payload.Size() > 4 { |
| id = binary.BigEndian.Uint16(payload.First()[4-len(transportHeader):]) |
| } |
| srcPort = id |
| dstPort = id |
| srcEp = &Endpoint{ |
| seqLo: 0, |
| seqHi: 0, |
| maxWin: 0, |
| state: ICMPFirstPacket, |
| } |
| dstEp = &Endpoint{ |
| seqLo: 0, |
| seqHi: 0, |
| maxWin: 0, |
| state: ICMPFirstPacket, |
| } |
| createTime = time.Now() |
| expireTime = createTime.Add(20 * time.Second) |
| case header.UDPProtocolNumber: |
| dataLen = payloadLength - header.UDPMinimumSize |
| srcEp = &Endpoint{ |
| seqLo: 0, |
| seqHi: 0, |
| maxWin: 0, |
| state: UDPSingle, |
| } |
| dstEp = &Endpoint{ |
| seqLo: 0, |
| seqHi: 0, |
| maxWin: 0, |
| state: UDPFirstPacket, |
| } |
| createTime = time.Now() |
| expireTime = createTime.Add(30 * time.Second) |
| case header.TCPProtocolNumber: |
| tcp := header.TCP(transportHeader) |
| flags := tcp.Flags() |
| dataLen = payloadLength - uint16(tcp.DataOffset()) |
| wscale := -1 |
| if flags&header.TCPFlagSyn != 0 { |
| dataLen++ |
| synOpts := header.ParseSynOptions(tcp.Options(), flags&header.TCPFlagAck != 0) |
| wscale = synOpts.WS |
| } |
| if flags&header.TCPFlagFin != 0 { |
| dataLen++ |
| } |
| seqLo := seqnum(tcp.SequenceNumber()).Add(uint32(dataLen)) |
| maxWin := tcp.WindowSize() |
| if maxWin < TCPMaxZWPDataLen { |
| maxWin = TCPMaxZWPDataLen |
| } |
| srcEp = &Endpoint{ |
| seqLo: seqLo, |
| seqHi: seqLo + TCPMaxZWPDataLen, |
| maxWin: maxWin, |
| wscale: int8(wscale), |
| state: TCPOpening, |
| } |
| dstEp = &Endpoint{ // Assign temporary values as we haven't seen a packet. |
| seqLo: 0, |
| seqHi: TCPMaxZWPDataLen, |
| maxWin: TCPMaxZWPDataLen, |
| wscale: -1, |
| state: TCPFirstPacket, |
| } |
| createTime = time.Now() |
| expireTime = createTime.Add(60 * time.Second) |
| } |
| |
| var lanAddr, gwyAddr, extAddr tcpip.Address |
| var lanPort, gwyPort, extPort uint16 |
| switch dir { |
| case Incoming: |
| extAddr = srcAddr |
| extPort = srcPort |
| if isRDR { |
| gwyAddr = origAddr |
| gwyPort = origPort |
| } else { |
| gwyAddr = dstAddr |
| gwyPort = dstPort |
| } |
| lanAddr = dstAddr |
| lanPort = dstPort |
| case Outgoing: |
| if isNAT { |
| lanAddr = origAddr |
| lanPort = origPort |
| } else { |
| lanAddr = srcAddr |
| lanPort = srcPort |
| } |
| gwyAddr = srcAddr |
| gwyPort = srcPort |
| extAddr = dstAddr |
| extPort = dstPort |
| } |
| |
| s := &State{ |
| transProto: transProto, |
| dir: dir, |
| lanAddr: lanAddr, |
| lanPort: lanPort, |
| gwyAddr: gwyAddr, |
| gwyPort: gwyPort, |
| extAddr: extAddr, |
| extPort: extPort, |
| rsvdAddr: rsvdAddr, |
| rsvdPort: rsvdPort, |
| srcEp: srcEp, |
| dstEp: dstEp, |
| createTime: createTime, |
| expireTime: expireTime, |
| packets: 1, |
| bytes: uint32(dataLen), |
| } |
| |
| kLanToExt := Key{transProto, lanAddr, lanPort, extAddr, extPort} |
| kExtToGwy := Key{transProto, extAddr, extPort, gwyAddr, gwyPort} |
| ss.lanToExt[kLanToExt] = s |
| ss.extToGwy[kExtToGwy] = s |
| |
| if debug { |
| log.Printf("packet filter: new state: %v", s) |
| } |
| |
| return s |
| } |
| |
| // isICMPv4ErrorMessage returns true if t is ICMPv4 Error Message Type. |
| // It returns false if t is ICMPv4 Informational Message Type. |
| func isICMPv4ErrorMessage(t header.ICMPv4Type) bool { |
| return t == header.ICMPv4DstUnreachable || |
| t == header.ICMPv4SrcQuench || |
| t == header.ICMPv4Redirect || |
| t == header.ICMPv4TimeExceeded || |
| t == header.ICMPv4ParamProblem |
| } |
| |
| func (ss *States) findStateICMPv4(dir Direction, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, srcAddr tcpip.Address, dstAddr tcpip.Address, payloadLength uint16, transportHeader []byte, payload buffer.VectorisedView) (*State, error) { |
| if len(transportHeader) < header.ICMPv4MinimumSize { |
| return nil, ErrPacketTooShort |
| } |
| icmp := header.ICMPv4(transportHeader) |
| if isICMPv4ErrorMessage(icmp.Type()) { |
| // This ICMPv4 packet is reporting an error detected in a transport layer, and |
| // includes the IP header and the first 8 bytes of the transport header of the packet |
| // that had the error. |
| // For NAT and RDR, we have to rewrite the address and port in the packet. |
| // We also test if the values in transport header are consistent with the connection |
| // state we are tracking (but we have only the first 8 bytes and cannot do a full |
| // check). |
| if len(transportHeader) < header.IPv4MinimumSize { |
| return nil, ErrPacketTooShort |
| } |
| // First, look for an IP header at the offset 8. |
| b2 := transportHeader[8:] |
| ipv4 := header.IPv4(b2) |
| transProto2 := ipv4.TransportProtocol() |
| srcAddr = ipv4.SourceAddress() |
| dstAddr = ipv4.DestinationAddress() |
| if len(transportHeader) < int(ipv4.HeaderLength()+8) { // Need the first 8 bytes of transport header. |
| return nil, ErrPacketTooShort |
| } |
| // Here's the transport header. |
| th2 := b2[ipv4.HeaderLength():] |
| |
| switch transProto2 { |
| case header.UDPProtocolNumber: |
| udp := header.UDP(th2) |
| srcPort := udp.SourcePort() |
| dstPort := udp.DestinationPort() |
| s := ss.getState(dir, transProto2, srcAddr, srcPort, dstAddr, dstPort) |
| // There is nothing we can use in the first 8 bytes to update the state. |
| return s, nil |
| case header.TCPProtocolNumber: |
| tcp := header.TCP(th2) |
| srcPort := tcp.SourcePort() |
| dstPort := tcp.DestinationPort() |
| s := ss.getState(dir, transProto2, srcAddr, srcPort, dstAddr, dstPort) |
| if s == nil { |
| return nil, nil |
| } |
| // We have the first 8 bytes of the TCP header only, which means |
| // we just use seq to test the state, |
| seq := seqnum(tcp.SequenceNumber()) |
| err := s.updateStateTCPinICMP(dir, seq) |
| if err != nil { |
| return nil, err |
| } |
| return s, nil |
| default: |
| return nil, ErrUnknownProtocol |
| } |
| } else { |
| dataLen := payloadLength - 8 // ICMP header size is 8. |
| |
| var id uint16 |
| // ICMP query/reply message. |
| if len(transportHeader) > 4 { |
| id = binary.BigEndian.Uint16(transportHeader[4:]) |
| } else if len(transportHeader)+payload.Size() > 4 { |
| id = binary.BigEndian.Uint16(payload.First()[4-len(transportHeader):]) |
| } |
| s := ss.getState(dir, transProto, srcAddr, id, dstAddr, id) |
| if s == nil { |
| return nil, nil |
| } |
| err := s.updateStateICMPv4(dir, dataLen) |
| if err != nil { |
| return nil, err |
| } |
| return s, nil |
| } |
| } |
| |
| func (ss *States) findStateUDP(dir Direction, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, srcAddr tcpip.Address, srcPort uint16, dstAddr tcpip.Address, dstPort uint16, payloadLength uint16, transportHeader []byte) (*State, error) { |
| s := ss.getState(dir, transProto, srcAddr, srcPort, dstAddr, dstPort) |
| if s == nil { |
| return nil, nil |
| } |
| |
| dataLen := payloadLength - header.UDPMinimumSize |
| |
| err := s.updateStateUDP(dir, dataLen) |
| if err != nil { |
| return nil, err |
| } |
| return s, nil |
| } |
| |
| func (ss *States) findStateTCP(dir Direction, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, srcAddr tcpip.Address, srcPort uint16, dstAddr tcpip.Address, dstPort uint16, payloadLength uint16, transportHeader []byte) (*State, error) { |
| s := ss.getState(dir, transProto, srcAddr, srcPort, dstAddr, dstPort) |
| if s == nil { |
| return nil, nil |
| } |
| tcp := header.TCP(transportHeader) |
| dataLen := payloadLength - uint16(tcp.DataOffset()) |
| win := tcp.WindowSize() |
| seq := seqnum(tcp.SequenceNumber()) |
| ack := seqnum(tcp.AckNumber()) |
| flags := tcp.Flags() |
| wscale := -1 |
| if flags&header.TCPFlagSyn != 0 { |
| synOpts := header.ParseSynOptions(tcp.Options(), flags&header.TCPFlagAck != 0) |
| wscale = synOpts.WS |
| } |
| |
| err := s.updateStateTCP(dir, dataLen, win, seq, ack, flags, wscale) |
| if err != nil { |
| return nil, err |
| } |
| return s, nil |
| } |