blob: a461ec70371740d357efecfbc9f3f87dbf8924ab [file] [log] [blame]
// Copyright 2018 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
// Package filter provides an implementation of a packet filter.
package filter
import (
"log"
"sync/atomic"
"github.com/google/netstack/tcpip"
"github.com/google/netstack/tcpip/buffer"
"github.com/google/netstack/tcpip/header"
"github.com/google/netstack/tcpip/ports"
)
// debug enables verbose per-packet logging. Every diagnostic log.Printf in
// this file is guarded by `if debug`, so with debug == false none of them run.
const debug = false
// Filter is a packet filter. It evaluates the main rule set, NAT rules (for
// outgoing packets) and RDR rules (for incoming packets), and keeps a state
// table so that packets of already-accepted flows are passed and translated
// without re-evaluating the rules.
type Filter struct {
	enabled     atomic.Value       // holds a bool; checked by Run on every packet.
	portManager *ports.PortManager // used to reserve source ports for NAT.
	rulesetMain RulesetMain        // scanned by matchMain.
	rulesetNAT  RulesetNAT         // scanned by matchNAT.
	rulesetRDR  RulesetRDR         // scanned by matchRDR.
	states      *States            // state table consulted and updated by Run.
}
// New returns a Filter that is enabled from the start, reserves NAT ports
// through pm, and uses a fresh state table obtained from NewStates.
func New(pm *ports.PortManager) *Filter {
	f := new(Filter)
	f.portManager = pm
	f.states = NewStates()
	f.enabled.Store(true)
	return f
}
// Enable switches the packet filter on (b == true) or off (b == false) and
// logs the new setting.
func (f *Filter) Enable(b bool) {
	state := "disabled"
	if b {
		state = "enabled"
	}
	log.Printf("packet filter: %s", state)
	f.enabled.Store(b)
}
// IsEnabled reports whether the packet filter is currently enabled.
// The type assertion panics if no value was ever stored in f.enabled, that is,
// if neither New nor Enable has been called on f.
func (f *Filter) IsEnabled() bool {
	on := f.enabled.Load().(bool)
	return on
}
// Run is the entry point to the packet filter. It should be called from
// two hook locations in the network stack: one for incoming packets, another
// for outgoing packets.
//
// hdr is expected to start with the network header for netProto; the bytes
// that follow the IP header inside hdr are treated as the transport header.
// payload is the rest of the packet. Expired state entries are purged on every
// call. Run returns Pass or Drop.
//
// Transport protocols are only dispatched when they match the network
// protocol: ICMPv4 must be carried in IPv4 and ICMPv6 in IPv6. Any other
// combination is dropped, so that runForICMPv4 never rewrites an IPv6 header
// as IPv4, and so that an ICMPv6 protocol number in an IPv4 packet cannot be
// used to skip the rule sets.
func (f *Filter) Run(dir Direction, netProto tcpip.NetworkProtocolNumber, hdr buffer.Prependable, payload buffer.VectorisedView) Action {
	if !f.IsEnabled() {
		// The filter is disabled.
		return Pass
	}

	f.states.purgeExpiredEntries(f.portManager)

	// Parse the network protocol header.
	var transProto tcpip.TransportProtocolNumber
	var srcAddr, dstAddr tcpip.Address
	var payloadLength uint16
	var transportHeader []byte
	switch netProto {
	case header.IPv4ProtocolNumber:
		ipv4 := header.IPv4(hdr.View())
		if !ipv4.IsValid(hdr.UsedLength() + payload.Size()) {
			if debug {
				log.Printf("packet filter: ipv4 packet is not valid")
			}
			return Drop
		}
		transProto = ipv4.TransportProtocol()
		srcAddr = ipv4.SourceAddress()
		dstAddr = ipv4.DestinationAddress()
		payloadLength = ipv4.PayloadLength()
		transportHeader = ipv4[ipv4.HeaderLength():]
	case header.IPv6ProtocolNumber:
		ipv6 := header.IPv6(hdr.View())
		if !ipv6.IsValid(hdr.UsedLength() + payload.Size()) {
			if debug {
				log.Printf("packet filter: ipv6 packet is not valid")
			}
			return Drop
		}
		transProto = ipv6.TransportProtocol()
		srcAddr = ipv6.SourceAddress()
		dstAddr = ipv6.DestinationAddress()
		payloadLength = ipv6.PayloadLength()
		transportHeader = ipv6[header.IPv6MinimumSize:]
	case header.ARPProtocolNumber:
		// TODO: Anything?
		return Pass
	default:
		if debug {
			log.Printf("packet filter: drop unknown network protocol: %v (%s)", netProto, dir)
		}
		return Drop
	}

	lockKey := makeStatesLockKey(dir, srcAddr, dstAddr, transProto, transportHeader)
	f.states.lock(lockKey)
	defer f.states.unlock(lockKey)

	switch transProto {
	case header.ICMPv4ProtocolNumber:
		// runForICMPv4 interprets and rewrites hdr as an IPv4 header.
		if netProto != header.IPv4ProtocolNumber {
			if debug {
				log.Printf("packet filter: drop icmpv4 carried in network protocol: %v (%v)", netProto, dir)
			}
			return Drop
		}
		return f.runForICMPv4(dir, srcAddr, dstAddr, payloadLength, hdr, transportHeader, payload)
	case header.ICMPv6ProtocolNumber:
		// ICMPv6 inside IPv6 is passed without inspection. Inside any other
		// network protocol it is malformed, and passing it would bypass the rules.
		if netProto != header.IPv6ProtocolNumber {
			if debug {
				log.Printf("packet filter: drop icmpv6 carried in network protocol: %v (%v)", netProto, dir)
			}
			return Drop
		}
		return Pass
	case header.UDPProtocolNumber:
		return f.runForUDP(dir, netProto, srcAddr, dstAddr, payloadLength, hdr, transportHeader, payload)
	case header.TCPProtocolNumber:
		return f.runForTCP(dir, netProto, srcAddr, dstAddr, payloadLength, hdr, transportHeader, payload)
	default:
		if debug {
			log.Printf("packet filter: drop unknown transport protocol: %d (%v)", transProto, dir)
		}
		return Drop
	}
}
// runForICMPv4 filters an ICMPv4 packet carried in IPv4 and returns Pass or
// Drop.
//
// When findStateICMPv4 returns a state s, the packet passes. Before that, any
// translation recorded in s (s.lanAddr/s.lanPort vs s.gwyAddr/s.gwyPort) is
// applied: the destination side for Incoming packets and the source side for
// Outgoing packets. findStateICMPv4 may return a UDP or TCP state rather than
// an ICMPv4 one, so the rewrite helper is selected by s.transProto. When no
// state is found and an error is returned, the packet is dropped.
//
// Otherwise, an Outgoing packet matching a NAT rule has its source address
// rewritten to nat.newSrcAddr. The main rule set is then consulted. A new
// state is created when the matching rule has keepState set or NAT was applied.
func (f *Filter) runForICMPv4(dir Direction, srcAddr, dstAddr tcpip.Address, payloadLength uint16, hdr buffer.Prependable, transportHeader []byte, payload buffer.VectorisedView) Action {
	if s, err := f.states.findStateICMPv4(dir, header.IPv4ProtocolNumber, header.ICMPv4ProtocolNumber, srcAddr, dstAddr, payloadLength, transportHeader, payload); s != nil {
		if debug {
			log.Printf("packet filter: icmp state found: %v", s)
		}
		// If NAT or RDR is in effect, rewrite address and port.
		// Note that findStateICMPv4 may return a state for a different transport protocol.
		switch s.transProto {
		case header.ICMPv4ProtocolNumber:
			if s.lanAddr != s.gwyAddr {
				switch dir {
				case Incoming:
					rewritePacketICMPv4(s.lanAddr, false, hdr, transportHeader)
				case Outgoing:
					rewritePacketICMPv4(s.gwyAddr, true, hdr, transportHeader)
				}
			}
		case header.UDPProtocolNumber:
			// NOTE(review): transportHeader here is the ICMPv4 header, yet it is
			// handed to the UDP rewrite helper. Confirm that rewritePacketUDPv4
			// cannot clobber ICMP fields (e.g. the checksum) in this case.
			if s.lanAddr != s.gwyAddr || s.lanPort != s.gwyPort {
				switch dir {
				case Incoming:
					rewritePacketUDPv4(s.lanAddr, s.lanPort, false, hdr, transportHeader)
				case Outgoing:
					rewritePacketUDPv4(s.gwyAddr, s.gwyPort, true, hdr, transportHeader)
				}
			}
		case header.TCPProtocolNumber:
			// NOTE(review): same concern as the UDP case; transportHeader is the
			// ICMPv4 header, not a TCP header.
			if s.lanAddr != s.gwyAddr || s.lanPort != s.gwyPort {
				switch dir {
				case Incoming:
					rewritePacketTCPv4(s.lanAddr, s.lanPort, false, hdr, transportHeader)
				case Outgoing:
					rewritePacketTCPv4(s.gwyAddr, s.gwyPort, true, hdr, transportHeader)
				}
			}
		default:
			// A state of any other protocol is an internal invariant violation.
			panic("Unsupported transport protocol")
		}
		return Pass
	} else if err != nil {
		if debug {
			log.Printf("packet filter: %v", err)
		}
		return Drop
	}
	var nat *NAT
	var origAddr tcpip.Address
	if dir == Outgoing {
		if nat = f.matchNAT(header.ICMPv4ProtocolNumber, srcAddr); nat != nil {
			// Rewrite srcAddr in the packet.
			// The original values are saved in origAddr.
			origAddr = srcAddr
			srcAddr = nat.newSrcAddr
			rewritePacketICMPv4(srcAddr, true, hdr, transportHeader)
		}
	}
	// The main rules see the post-NAT source address.
	// TODO: Add interface parameter.
	rm := f.matchMain(dir, header.ICMPv4ProtocolNumber, srcAddr, 0, dstAddr, 0)
	if rm != nil {
		if rm.log {
			// TODO: Improve the log format.
			log.Printf("packet filter: Rule matched: %+v", rm)
		}
		if rm.action == DropReset {
			if nat != nil {
				// Revert the packet modified for NAT.
				rewritePacketICMPv4(origAddr, true, hdr, transportHeader)
			}
			// TODO: Send a Reset packet.
			return Drop
		} else if rm.action == Drop {
			return Drop
		}
	}
	// ICMP has no ports, so the port arguments are zero; no NAT port was
	// reserved, so the new-address/new-port arguments are empty.
	if (rm != nil && rm.keepState) || nat != nil {
		f.states.createState(dir, header.ICMPv4ProtocolNumber, srcAddr, 0, dstAddr, 0, origAddr, 0, "", 0, nat != nil, false, payloadLength, hdr, transportHeader, payload)
	}
	return Pass
}
// runForUDP filters a UDP datagram carried in netProto (IPv4 or IPv6) and
// returns Pass or Drop.
//
// If the datagram belongs to an existing state, the translation recorded in
// that state is applied and the datagram passes. Otherwise RDR rules are
// applied to Incoming datagrams and NAT rules to Outgoing ones. The main rule
// set is then consulted, and a new state is created when the matching rule has
// keepState set or a translation was applied.
//
// NAT reserves a new source port from f.portManager. If the main rule set
// subsequently drops the datagram, createState is never reached, so the
// reservation is released here instead of being leaked.
func (f *Filter) runForUDP(dir Direction, netProto tcpip.NetworkProtocolNumber, srcAddr, dstAddr tcpip.Address, payloadLength uint16, hdr buffer.Prependable, transportHeader []byte, payload buffer.VectorisedView) Action {
	if len(transportHeader) < header.UDPMinimumSize {
		if debug {
			log.Printf("packet filter: udp packet too short")
		}
		return Drop
	}
	udp := header.UDP(transportHeader)
	srcPort := udp.SourcePort()
	dstPort := udp.DestinationPort()

	if s, err := f.states.findStateUDP(dir, netProto, header.UDPProtocolNumber, srcAddr, srcPort, dstAddr, dstPort, payloadLength, transportHeader); s != nil {
		if debug {
			log.Printf("packet filter: udp state found: %v", s)
		}
		// If NAT or RDR is in effect, rewrite address and port.
		if s.lanAddr != s.gwyAddr || s.lanPort != s.gwyPort {
			switch netProto {
			case header.IPv4ProtocolNumber:
				switch dir {
				case Incoming:
					rewritePacketUDPv4(s.lanAddr, s.lanPort, false, hdr, transportHeader)
				case Outgoing:
					rewritePacketUDPv4(s.gwyAddr, s.gwyPort, true, hdr, transportHeader)
				}
			case header.IPv6ProtocolNumber:
				switch dir {
				case Incoming:
					rewritePacketUDPv6(s.lanAddr, s.lanPort, false, hdr, transportHeader)
				case Outgoing:
					rewritePacketUDPv6(s.gwyAddr, s.gwyPort, true, hdr, transportHeader)
				}
			}
		}
		return Pass
	} else if err != nil {
		if debug {
			log.Printf("packet filter: %v", err)
		}
		return Drop
	}

	var nat *NAT
	var rdr *RDR
	var origAddr tcpip.Address
	var origPort uint16
	var newAddr tcpip.Address
	var newPort uint16
	// natNetProtos is the protocol set newPort was reserved under. It is
	// needed again to release the port if the datagram is dropped below.
	var natNetProtos []tcpip.NetworkProtocolNumber
	switch dir {
	case Incoming:
		if rdr = f.matchRDR(header.UDPProtocolNumber, dstAddr, dstPort); rdr != nil {
			if debug {
				log.Printf("packet filter: RDR rule matched: proto: %d, dstAddr: %s, dstPort: %d, newDstAddr: %s, newDstPort: %d, nic: %d", rdr.transProto, rdr.dstAddr, rdr.dstPort, rdr.newDstAddr, rdr.newDstPort, rdr.nic)
			}
			// Rewrite dstAddr and dstPort in the packet.
			// The original values are saved in origAddr and origPort.
			origAddr = dstAddr
			dstAddr = rdr.newDstAddr
			origPort = dstPort
			dstPort = rdr.newDstPort
			if debug {
				log.Printf("packet filter: RDR: rewrite orig(%s:%d) with new(%s:%d)", origAddr, origPort, dstAddr, dstPort)
			}
			switch netProto {
			case header.IPv4ProtocolNumber:
				rewritePacketUDPv4(dstAddr, dstPort, false, hdr, transportHeader)
			case header.IPv6ProtocolNumber:
				rewritePacketUDPv6(dstAddr, dstPort, false, hdr, transportHeader)
			}
		}
	case Outgoing:
		if nat = f.matchNAT(header.UDPProtocolNumber, srcAddr); nat != nil {
			if debug {
				log.Printf("packet filter: NAT rule matched: proto: %d, srcNet: %s(%s), srcAddr: %s, nic: %d", nat.transProto, nat.srcSubnet.ID(), tcpip.Address(nat.srcSubnet.Mask()), nat.newSrcAddr, nat.nic)
			}
			newAddr = nat.newSrcAddr
			// Reserve a new port.
			natNetProtos = []tcpip.NetworkProtocolNumber{header.IPv4ProtocolNumber, header.IPv6ProtocolNumber}
			var e *tcpip.Error
			newPort, e = f.portManager.ReservePort(natNetProtos, header.UDPProtocolNumber, newAddr, 0, false)
			if e != nil {
				if debug {
					log.Printf("packet filter: ReservePort: %v", e)
				}
				return Drop
			}
			// Rewrite srcAddr and srcPort in the packet.
			// The original values are saved in origAddr and origPort.
			origAddr = srcAddr
			srcAddr = newAddr
			origPort = srcPort
			srcPort = newPort
			if debug {
				log.Printf("packet filter: NAT: rewrite orig(%s:%d) with new(%s:%d)", origAddr, origPort, srcAddr, srcPort)
			}
			switch netProto {
			case header.IPv4ProtocolNumber:
				rewritePacketUDPv4(srcAddr, srcPort, true, hdr, transportHeader)
			case header.IPv6ProtocolNumber:
				rewritePacketUDPv6(srcAddr, srcPort, true, hdr, transportHeader)
			}
		}
	}

	// The main rules see the post-NAT/RDR addresses and ports.
	// TODO: Add interface parameter.
	rm := f.matchMain(dir, header.UDPProtocolNumber, srcAddr, srcPort, dstAddr, dstPort)
	if rm != nil {
		if rm.log {
			// TODO: Improve the log format.
			log.Printf("packet filter: Rule matched: %+v", rm)
		}
		switch rm.action {
		case DropReset:
			if nat != nil {
				// Revert the packet modified for NAT.
				switch netProto {
				case header.IPv4ProtocolNumber:
					rewritePacketUDPv4(origAddr, origPort, true, hdr, transportHeader)
				case header.IPv6ProtocolNumber:
					rewritePacketUDPv6(origAddr, origPort, true, hdr, transportHeader)
				}
			}
			if rdr != nil {
				// Revert the packet modified for RDR.
				switch netProto {
				case header.IPv4ProtocolNumber:
					rewritePacketUDPv4(origAddr, origPort, false, hdr, transportHeader)
				case header.IPv6ProtocolNumber:
					rewritePacketUDPv6(origAddr, origPort, false, hdr, transportHeader)
				}
			}
			// TODO: Send a Reset packet.
			fallthrough
		case Drop:
			if nat != nil {
				// No state is created for a dropped datagram, so nothing else
				// refers to the reserved NAT port; release it to avoid a leak.
				f.portManager.ReleasePort(natNetProtos, header.UDPProtocolNumber, newAddr, newPort)
			}
			return Drop
		}
	}
	if (rm != nil && rm.keepState) || nat != nil || rdr != nil {
		f.states.createState(dir, header.UDPProtocolNumber, srcAddr, srcPort, dstAddr, dstPort, origAddr, origPort, newAddr, newPort, nat != nil, rdr != nil, payloadLength, hdr, transportHeader, payload)
	}
	return Pass
}
// runForTCP filters a TCP segment carried in netProto (IPv4 or IPv6) and
// returns Pass or Drop.
//
// If the segment belongs to an existing state, the translation recorded in
// that state is applied and the segment passes. Otherwise RDR rules are
// applied to Incoming segments and NAT rules to Outgoing ones. The main rule
// set is then consulted, and a new state is created when the matching rule has
// keepState set or a translation was applied.
//
// NAT reserves a new source port from f.portManager. If the main rule set
// subsequently drops the segment, createState is never reached, so the
// reservation is released here instead of being leaked.
func (f *Filter) runForTCP(dir Direction, netProto tcpip.NetworkProtocolNumber, srcAddr, dstAddr tcpip.Address, payloadLength uint16, hdr buffer.Prependable, transportHeader []byte, payload buffer.VectorisedView) Action {
	if len(transportHeader) < header.TCPMinimumSize {
		if debug {
			log.Printf("packet filter: tcp packet too short")
		}
		return Drop
	}
	tcp := header.TCP(transportHeader)
	srcPort := tcp.SourcePort()
	dstPort := tcp.DestinationPort()

	if s, err := f.states.findStateTCP(dir, netProto, header.TCPProtocolNumber, srcAddr, srcPort, dstAddr, dstPort, payloadLength, transportHeader); s != nil {
		if debug {
			log.Printf("packet filter: tcp state found: %v", s)
		}
		// If NAT or RDR is in effect, rewrite address and port.
		if s.lanAddr != s.gwyAddr || s.lanPort != s.gwyPort {
			switch netProto {
			case header.IPv4ProtocolNumber:
				switch dir {
				case Incoming:
					rewritePacketTCPv4(s.lanAddr, s.lanPort, false, hdr, transportHeader)
				case Outgoing:
					rewritePacketTCPv4(s.gwyAddr, s.gwyPort, true, hdr, transportHeader)
				}
			case header.IPv6ProtocolNumber:
				switch dir {
				case Incoming:
					rewritePacketTCPv6(s.lanAddr, s.lanPort, false, hdr, transportHeader)
				case Outgoing:
					rewritePacketTCPv6(s.gwyAddr, s.gwyPort, true, hdr, transportHeader)
				}
			}
		}
		return Pass
	} else if err != nil {
		if debug {
			log.Printf("packet filter: %v", err)
		}
		return Drop
	}

	var nat *NAT
	var rdr *RDR
	var origAddr tcpip.Address
	var origPort uint16
	var newAddr tcpip.Address
	var newPort uint16
	// natNetProtos is the protocol set newPort was reserved under. It is
	// needed again to release the port if the segment is dropped below.
	var natNetProtos []tcpip.NetworkProtocolNumber
	switch dir {
	case Incoming:
		if rdr = f.matchRDR(header.TCPProtocolNumber, dstAddr, dstPort); rdr != nil {
			if debug {
				log.Printf("packet filter: RDR rule matched: proto: %d, dstAddr: %s, dstPort: %d, newDstAddr: %s, newDstPort: %d, nic: %d", rdr.transProto, rdr.dstAddr, rdr.dstPort, rdr.newDstAddr, rdr.newDstPort, rdr.nic)
			}
			// Rewrite dstAddr and dstPort in the packet.
			// The original values are saved in origAddr and origPort.
			origAddr = dstAddr
			dstAddr = rdr.newDstAddr
			origPort = dstPort
			dstPort = rdr.newDstPort
			if debug {
				log.Printf("packet filter: RDR: rewrite orig(%s:%d) with new(%s:%d)", origAddr, origPort, dstAddr, dstPort)
			}
			switch netProto {
			case header.IPv4ProtocolNumber:
				rewritePacketTCPv4(dstAddr, dstPort, false, hdr, transportHeader)
			case header.IPv6ProtocolNumber:
				rewritePacketTCPv6(dstAddr, dstPort, false, hdr, transportHeader)
			}
		}
	case Outgoing:
		if nat = f.matchNAT(header.TCPProtocolNumber, srcAddr); nat != nil {
			if debug {
				log.Printf("packet filter: NAT rule matched: proto: %d, srcNet: %s(%s), srcAddr: %s, nic: %d", nat.transProto, nat.srcSubnet.ID(), tcpip.Address(nat.srcSubnet.Mask()), nat.newSrcAddr, nat.nic)
			}
			newAddr = nat.newSrcAddr
			// Reserve a new port.
			natNetProtos = []tcpip.NetworkProtocolNumber{header.IPv4ProtocolNumber, header.IPv6ProtocolNumber}
			var e *tcpip.Error
			newPort, e = f.portManager.ReservePort(natNetProtos, header.TCPProtocolNumber, newAddr, 0, false)
			if e != nil {
				if debug {
					log.Printf("packet filter: ReservePort: %v", e)
				}
				return Drop
			}
			// Rewrite srcAddr and srcPort in the packet.
			// The original values are saved in origAddr and origPort.
			origAddr = srcAddr
			srcAddr = newAddr
			origPort = srcPort
			srcPort = newPort
			if debug {
				log.Printf("packet filter: NAT: rewrite orig(%s:%d) with new(%s:%d)", origAddr, origPort, srcAddr, srcPort)
			}
			switch netProto {
			case header.IPv4ProtocolNumber:
				rewritePacketTCPv4(srcAddr, srcPort, true, hdr, transportHeader)
			case header.IPv6ProtocolNumber:
				rewritePacketTCPv6(srcAddr, srcPort, true, hdr, transportHeader)
			}
		}
	}

	// The main rules see the post-NAT/RDR addresses and ports.
	// TODO: Add interface parameter.
	rm := f.matchMain(dir, header.TCPProtocolNumber, srcAddr, srcPort, dstAddr, dstPort)
	if rm != nil {
		if rm.log {
			// TODO: Improve the log format.
			log.Printf("packet filter: Rule matched: %+v", rm)
		}
		switch rm.action {
		case DropReset:
			if nat != nil {
				// Revert the packet modified for NAT.
				switch netProto {
				case header.IPv4ProtocolNumber:
					rewritePacketTCPv4(origAddr, origPort, true, hdr, transportHeader)
				case header.IPv6ProtocolNumber:
					rewritePacketTCPv6(origAddr, origPort, true, hdr, transportHeader)
				}
			}
			if rdr != nil {
				// Revert the packet modified for RDR.
				switch netProto {
				case header.IPv4ProtocolNumber:
					rewritePacketTCPv4(origAddr, origPort, false, hdr, transportHeader)
				case header.IPv6ProtocolNumber:
					rewritePacketTCPv6(origAddr, origPort, false, hdr, transportHeader)
				}
			}
			// TODO: Send a Reset packet.
			fallthrough
		case Drop:
			if nat != nil {
				// No state is created for a dropped segment, so nothing else
				// refers to the reserved NAT port; release it to avoid a leak.
				f.portManager.ReleasePort(natNetProtos, header.TCPProtocolNumber, newAddr, newPort)
			}
			return Drop
		}
	}
	if (rm != nil && rm.keepState) || nat != nil || rdr != nil {
		f.states.createState(dir, header.TCPProtocolNumber, srcAddr, srcPort, dstAddr, dstPort, origAddr, origPort, newAddr, newPort, nat != nil, rdr != nil, payloadLength, hdr, transportHeader, payload)
	}
	return Pass
}
// matchMain scans the main rule set in order and returns the rule that decides
// the packet, or nil if no rule matches.
//
// A rule matches when its direction and transport protocol are equal to the
// packet's, and each of its optional criteria also holds:
//   - a nil subnet matches any address, and the Contains result is inverted
//     when the corresponding *InvertMatch flag is set;
//   - a zero port matches any port.
//
// The last matching rule wins, unless a matching rule is marked quick, in
// which case that rule is returned immediately. The rule set is read-locked
// while it is scanned.
func (f *Filter) matchMain(dir Direction, transProto tcpip.TransportProtocolNumber, srcAddr tcpip.Address, srcPort uint16, dstAddr tcpip.Address, dstPort uint16) *Rule {
	f.rulesetMain.RLock()
	defer f.rulesetMain.RUnlock()

	var last *Rule
	for i := range f.rulesetMain.v {
		rule := &f.rulesetMain.v[i]
		if rule.direction != dir || rule.transProto != transProto {
			continue
		}
		if rule.srcSubnet != nil && rule.srcSubnet.Contains(srcAddr) == rule.srcSubnetInvertMatch {
			continue
		}
		if rule.srcPort != 0 && rule.srcPort != srcPort {
			continue
		}
		if rule.dstSubnet != nil && rule.dstSubnet.Contains(dstAddr) == rule.dstSubnetInvertMatch {
			continue
		}
		if rule.dstPort != 0 && rule.dstPort != dstPort {
			continue
		}
		if rule.quick {
			return rule
		}
		last = rule
	}
	return last
}
// matchNAT returns the first NAT rule whose transport protocol is transProto
// and whose source subnet contains srcAddr, or nil if there is none. The NAT
// rule set is read-locked while it is scanned.
//
// NOTE(review): the returned pointer refers into the rule slice and is used
// after the read lock is released; confirm rules are not mutated in place.
func (f *Filter) matchNAT(transProto tcpip.TransportProtocolNumber, srcAddr tcpip.Address) *NAT {
	f.rulesetNAT.RLock()
	defer f.rulesetNAT.RUnlock()

	rules := f.rulesetNAT.v
	for idx := 0; idx < len(rules); idx++ {
		candidate := &rules[idx]
		if candidate.transProto != transProto {
			continue
		}
		if candidate.srcSubnet.Contains(srcAddr) {
			return candidate
		}
	}
	return nil
}
// matchRDR returns the first RDR rule whose transport protocol, destination
// address and destination port all equal the given ones, or nil if there is
// none. The RDR rule set is read-locked while it is scanned.
func (f *Filter) matchRDR(transProto tcpip.TransportProtocolNumber, dstAddr tcpip.Address, dstPort uint16) *RDR {
	f.rulesetRDR.RLock()
	defer f.rulesetRDR.RUnlock()

	rules := f.rulesetRDR.v
	for idx := 0; idx < len(rules); idx++ {
		candidate := &rules[idx]
		sameProto := candidate.transProto == transProto
		if sameProto && candidate.dstAddr == dstAddr && candidate.dstPort == dstPort {
			return candidate
		}
	}
	return nil
}