blob: 14bbf798161e7dc6fa45bde1d06024c18e28e477 [file] [log] [blame]
package portallocator
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"sync"
"github.com/containerd/log"
)
// ipMapping associates an IP address with the per-protocol port maps
// (protoMap) that are tracked for that address.
type ipMapping map[netip.Addr]protoMap
// Package-level sentinel errors and the state backing the singleton
// returned by Get.
var (
// errAllPortsAllocated is returned when no more ports are available
errAllPortsAllocated = errors.New("all ports are allocated")
// errUnknownProtocol is returned when an unknown protocol was specified
errUnknownProtocol = errors.New("unknown protocol")
// once ensures that instance is constructed a single time by Get.
once sync.Once
// instance is the package-wide PortAllocator handed out by Get.
instance *PortAllocator
)
// alreadyAllocatedErr reports that a specific port was requested for an
// address on which that port is already in use.
type alreadyAllocatedErr struct {
	ip   string // textual form of the address the bind was attempted on
	port int    // port number that was requested
}

// Error implements the error interface, naming the address and port that
// could not be bound.
func (e alreadyAllocatedErr) Error() string {
	const format = "Bind for %s:%d failed: port is already allocated"
	return fmt.Sprintf(format, e.ip, e.port)
}
// PortAllocator manages the transport ports database.
type PortAllocator struct {
	// mutex is held while ipMap is read or modified.
	mutex sync.Mutex
	// defaultIP is substituted when a caller passes a nil IP address.
	defaultIP net.IP
	// ipMap holds, per IP address, the port maps of every protocol.
	ipMap ipMapping
	// begin and end are the default port range handed to the port maps
	// that are created for newly seen addresses.
	begin int
	end   int
}

// portRange describes a range of ports and remembers where the previous
// search within it stopped.
type portRange struct {
	begin int // first port of the range
	end   int // last port of the range (inclusive)
	last  int // port most recently allocated from this range; searches resume after it
}

// portMap tracks the allocated ports of one protocol on one address.
type portMap struct {
	// p is the set of ports that are currently allocated.
	p map[int]struct{}
	// defaultRange is the key, within portRanges, of the default range.
	defaultRange string
	// portRanges holds the known ranges, keyed by getRangeKey.
	portRanges map[string]*portRange
}

// protoMap maps a protocol name to the portMap for that protocol.
type protoMap map[string]*portMap
// GetPortRange reports the default port range of the shared PortAllocator,
// converting both bounds to uint16.
//
// This function is for internal use in tests, and must not be used
// for other purposes.
func GetPortRange() (start, end uint16) {
	allocator := Get()
	start = uint16(allocator.begin)
	end = uint16(allocator.end)
	return start, end
}
// Get returns the shared PortAllocator, constructing it on the first call.
func Get() *PortAllocator {
	// The allocator is a singleton: the initialiser runs at most once, no
	// matter how many times Get is called.
	initialise := func() {
		instance = newInstance()
	}
	once.Do(initialise)
	return instance
}
// newInstance builds a PortAllocator whose default range is the value
// returned by dynamicPortRange, with port maps pre-created for the IPv4 and
// IPv6 unspecified addresses and 0.0.0.0 as the default IP.
func newInstance() *PortAllocator {
	first, last := dynamicPortRange()

	allocator := new(PortAllocator)
	allocator.ipMap = makeIpMapping(first, last)
	allocator.defaultIP = net.IPv4zero
	allocator.begin, allocator.end = first, last
	return allocator
}
// dynamicPortRange returns the range reported by getDynamicPortRange. When
// that lookup fails, the failure is logged at info level and the built-in
// default range (defaultPortRangeStart-defaultPortRangeEnd) is returned.
func dynamicPortRange() (start, end int) {
	first, last, err := getDynamicPortRange()
	if err == nil {
		return first, last
	}
	log.G(context.TODO()).WithError(err).Infof("falling back to default port range %d-%d", defaultPortRangeStart, defaultPortRangeEnd)
	return defaultPortRangeStart, defaultPortRangeEnd
}
// makeIpMapping returns a fresh ipMapping holding an independent protoMap
// for each of the unspecified addresses (0.0.0.0 and ::), every one of them
// using begin-end as its default port range.
func makeIpMapping(begin, end int) ipMapping {
	mapping := make(ipMapping, 2)
	for _, unspecified := range []netip.Addr{netip.IPv4Unspecified(), netip.IPv6Unspecified()} {
		mapping[unspecified] = makeProtoMap(begin, end)
	}
	return mapping
}
// makeProtoMap returns a protoMap with an empty portMap for each supported
// protocol (tcp, udp and sctp), all using begin-end as their default range.
func makeProtoMap(begin, end int) protoMap {
	protocols := [...]string{"tcp", "udp", "sctp"}
	result := make(protoMap, len(protocols))
	for _, name := range protocols {
		result[name] = newPortMap(begin, end)
	}
	return result
}
// RequestPort requests a port from the global ports pool for the given ip
// and proto. A port of 0 asks for the first free port; any other value asks
// for exactly that port and yields an error when it is already taken.
// A nil ip selects the allocator's default IP.
func (p *PortAllocator) RequestPort(ip net.IP, proto string, port int) (int, error) {
	// A single port is the degenerate range port-port; the nil-ip default is
	// applied by RequestPortInRange.
	return p.RequestPortInRange(ip, proto, port, port)
}
// RequestPortInRange behaves like [PortAllocator.RequestPortsInRange] for a
// single IP address. When ip is nil, the port is requested for the
// allocator's default IP (0.0.0.0) instead.
func (p *PortAllocator) RequestPortInRange(ip net.IP, proto string, portStart, portEnd int) (int, error) {
	targets := []net.IP{ip}
	if ip == nil {
		// FIXME(thaJeztah): consider making this a required argument and producing an error instead, or set default when constructing.
		targets[0] = p.defaultIP
	}
	return p.RequestPortsInRange(targets, proto, portStart, portEnd)
}
// RequestPortsInRange requests one port number from the global ports pool
// that is allocated for proto on each of ips.
//
//   - portStart == portEnd == 0: the first free port of the default range.
//   - portStart != portEnd: the first free port of the requested range.
//   - portStart == portEnd != 0: exactly that port, or an error when it is
//     already in use.
//
// It returns errUnknownProtocol for a proto other than tcp, udp or sctp, an
// error for an invalid range, an empty ips or an unparsable address, an
// alreadyAllocatedErr when a specific port is taken, and
// errAllPortsAllocated when the searched range has no free port.
func (p *PortAllocator) RequestPortsInRange(ips []net.IP, proto string, portStart, portEnd int) (int, error) {
	switch proto {
	case "tcp", "udp", "sctp":
		// supported
	default:
		return 0, errUnknownProtocol
	}
	// As soon as either bound is set the caller asked for a custom range, in
	// which case both bounds must be set and correctly ordered.
	customRange := portStart != 0 || portEnd != 0
	if customRange && (portStart == 0 || portEnd == 0 || portEnd < portStart) {
		return 0, fmt.Errorf("invalid port range: %d-%d", portStart, portEnd)
	}
	if len(ips) == 0 {
		return 0, fmt.Errorf("no IP addresses specified")
	}

	p.mutex.Lock()
	defer p.mutex.Unlock()

	// candidate is a portMap in which the chosen port must be free; reserve
	// tells whether the port must also be recorded there.
	type candidate struct {
		ports   *portMap
		reserve bool
	}
	candidates := map[netip.Addr]*candidate{}

	// Every requested address gets a candidate that will receive the
	// allocation. Addresses that were never seen before first get a full set
	// of per-protocol port maps.
	var wantV4, wantV6 bool
	for _, ip := range ips {
		addr, valid := netip.AddrFromSlice(ip)
		if !valid {
			return 0, fmt.Errorf("invalid IP address: %s", ip)
		}
		addr = addr.Unmap()
		if addr.Is4() {
			wantV4 = true
		} else {
			wantV6 = true
		}
		protocols, known := p.ipMap[addr]
		if !known {
			protocols = makeProtoMap(p.begin, p.end)
			p.ipMap[addr] = protocols
		}
		candidates[addr] = &candidate{ports: protocols[proto], reserve: true}
	}

	// When an unspecified address was requested, the port has to be free on
	// every known address of that family. Otherwise it only has to be free on
	// the requested addresses plus the unspecified address of each requested
	// family. These extra candidates are checked but never written to.
	for _, unspecified := range []netip.Addr{netip.IPv4Unspecified(), netip.IPv6Unspecified()} {
		if _, requested := candidates[unspecified]; requested {
			for addr, protocols := range p.ipMap {
				if addr.Is4() != unspecified.Is4() {
					continue
				}
				if _, present := candidates[addr]; present {
					continue
				}
				candidates[addr] = &candidate{ports: protocols[proto]}
			}
			continue
		}
		if (unspecified.Is4() && wantV4) || (unspecified.Is6() && wantV6) {
			candidates[unspecified] = &candidate{ports: p.ipMap[unspecified][proto]}
		}
	}

	// A specific port was requested: it is all or nothing.
	if portStart > 0 && portStart == portEnd {
		for addr, c := range candidates {
			if _, taken := c.ports.p[portStart]; taken {
				return 0, alreadyAllocatedErr{ip: addr.String(), port: portStart}
			}
		}
		for _, c := range candidates {
			if c.reserve {
				c.ports.p[portStart] = struct{}{}
			}
		}
		return portStart, nil
	}

	// A range was requested: create or fetch the range of every candidate.
	ranges := map[netip.Addr]*portRange{}
	for addr, c := range candidates {
		ranges[addr] = c.ports.getPortRange(portStart, portEnd)
	}

	// Starting (arbitrarily) just after the port last allocated for the first
	// address, walk the range once, wrapping around at its end, until a port
	// is found that is free in every candidate.
	firstAddr, _ := netip.AddrFromSlice(ips[0])
	search := ranges[firstAddr.Unmap()]
	port := search.last
	for remaining := search.end - search.begin + 1; remaining > 0; remaining-- {
		port++
		if port > search.end {
			port = search.begin
		}

		free := true
		for _, c := range candidates {
			if _, taken := c.ports.p[port]; taken {
				free = false
				break
			}
		}
		if !free {
			continue
		}

		for addr, c := range candidates {
			if !c.reserve {
				continue
			}
			c.ports.p[port] = struct{}{}
			ranges[addr].last = port
		}
		return port, nil
	}
	return 0, errAllPortsAllocated
}
// ReleasePort releases port from the global ports pool for the given ip and
// proto. A nil ip selects the allocator's default IP.
//
// Releasing is best-effort: an ip that cannot be parsed, an address that has
// no port maps, a proto that is not tracked, or a port that was not
// allocated all leave the allocator unchanged.
func (p *PortAllocator) ReleasePort(ip net.IP, proto string, port int) {
	p.mutex.Lock()
	defer p.mutex.Unlock()

	if ip == nil {
		ip = p.defaultIP // FIXME(thaJeztah): consider making this a required argument and producing an error instead, or set default when constructing.
	}
	addr, ok := netip.AddrFromSlice(ip)
	if !ok {
		return
	}
	protocols, ok := p.ipMap[addr.Unmap()]
	if !ok {
		return
	}
	// Indexing with an unknown protocol yields a nil *portMap; reading its
	// field would panic, so treat that case as "nothing to release".
	ports := protocols[proto]
	if ports == nil {
		return
	}
	delete(ports.p, port)
}
// ReleaseAll releases all ports for all ips by replacing the allocator's
// state with a fresh mapping built from the current dynamic port range.
//
// The allocator's default range is updated together with the mapping, so
// that port maps created later for newly seen addresses use the same range
// as the ones created here.
func (p *PortAllocator) ReleaseAll() {
	// Look the range up before taking the lock; the lookup does not touch
	// the allocator's state.
	begin, end := dynamicPortRange()

	p.mutex.Lock()
	defer p.mutex.Unlock()
	p.begin, p.end = begin, end
	p.ipMap = makeIpMapping(begin, end)
}
// getRangeKey returns the key under which the range portStart-portEnd is
// stored in a portMap, formatted as "<portStart>-<portEnd>".
func getRangeKey(portStart, portEnd int) string {
	const keyFormat = "%d-%d"
	return fmt.Sprintf(keyFormat, portStart, portEnd)
}
// newPortRange returns a portRange covering portStart-portEnd. Its last
// field starts out at portEnd, so a search that resumes after last wraps
// around and begins at portStart.
func newPortRange(portStart, portEnd int) *portRange {
	r := new(portRange)
	r.begin = portStart
	r.end = portEnd
	r.last = portEnd
	return r
}
// newPortMap returns a portMap with no allocated ports whose default range
// is portStart-portEnd; that range is registered in portRanges up front.
func newPortMap(portStart, portEnd int) *portMap {
	key := getRangeKey(portStart, portEnd)

	ranges := make(map[string]*portRange, 1)
	ranges[key] = newPortRange(portStart, portEnd)

	return &portMap{
		p:            make(map[int]struct{}),
		defaultRange: key,
		portRanges:   ranges,
	}
}
// getPortRange returns the portRange registered for portStart-portEnd,
// creating and registering it on first use. Passing 0 for both bounds
// selects the portMap's default range.
func (pm *portMap) getPortRange(portStart, portEnd int) *portRange {
	key := pm.defaultRange
	if portStart != 0 || portEnd != 0 {
		key = getRangeKey(portStart, portEnd)
	}

	known, found := pm.portRanges[key]
	if !found {
		// First request for this range: remember it so that later searches
		// resume where the previous one stopped.
		known = newPortRange(portStart, portEnd)
		pm.portRanges[key] = known
	}
	return known
}