Merge remote-tracking branch 'upstream/master'

Change-Id: Ib612be3e101fa55248e78eb90fee66dfa2a92e77
diff --git a/tcpip/hash/jenkins/jenkins.go b/tcpip/hash/jenkins/jenkins.go
new file mode 100644
index 0000000..e66d5f1
--- /dev/null
+++ b/tcpip/hash/jenkins/jenkins.go
@@ -0,0 +1,80 @@
+// Copyright 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package jenkins implements Jenkins's one_at_a_time, non-cryptographic hash
+// function created by Bob Jenkins.
+//
+// See https://en.wikipedia.org/wiki/Jenkins_hash_function#one_at_a_time
+//
+package jenkins
+
+import (
+	"hash"
+)
+
+// Sum32 represents Jenkins's one_at_a_time hash; its value is the running state.
+//
+// Use the Sum32 type directly (as opposed to New32 below)
+// to avoid allocations. A nonzero initial value acts as a seed.
+type Sum32 uint32
+
+// New32 returns a new 32-bit Jenkins's one_at_a_time hash.Hash32.
+//
+// Its Sum method will lay the value out in big-endian byte order.
+func New32() hash.Hash32 {
+	var s Sum32
+	return &s
+}
+
+// Reset resets the hash to the zero state; any seed it started from is discarded.
+func (s *Sum32) Reset() { *s = 0 }
+
+// Sum32 returns the hash value after the final mixing steps, without modifying the running state.
+func (s *Sum32) Sum32() uint32 {
+	hash := *s
+
+	hash += (hash << 3)
+	hash ^= hash >> 11
+	hash += hash << 15
+
+	return uint32(hash)
+}
+
+// Write mixes data into the running hash one byte at a time.
+//
+// It never returns an error and always reports len(data) bytes written.
+func (s *Sum32) Write(data []byte) (int, error) {
+	hash := *s
+	for _, b := range data {
+		hash += Sum32(b)
+		hash += hash << 10
+		hash ^= hash >> 6
+	}
+	*s = hash
+	return len(data), nil
+}
+
+// Size returns the number of bytes Sum will return.
+func (s *Sum32) Size() int { return 4 }
+
+// BlockSize returns the hash's underlying block size (input is consumed bytewise).
+func (s *Sum32) BlockSize() int { return 1 }
+
+// Sum appends the current hash, in big-endian byte order, to in and returns the resulting slice.
+//
+// It does not change the underlying hash state.
+func (s *Sum32) Sum(in []byte) []byte {
+	v := s.Sum32()
+	return append(in, byte(v>>24), byte(v>>16), byte(v>>8), byte(v))
+}
diff --git a/tcpip/hash/jenkins/jenkins_test.go b/tcpip/hash/jenkins/jenkins_test.go
new file mode 100644
index 0000000..9d86174
--- /dev/null
+++ b/tcpip/hash/jenkins/jenkins_test.go
@@ -0,0 +1,176 @@
+// Copyright 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package jenkins
+
+import (
+	"bytes"
+	"encoding/binary"
+	"hash"
+	"hash/fnv"
+	"math"
+	"testing"
+)
+
+func TestGolden32(t *testing.T) {
+	var golden32 = []struct {
+		out []byte
+		in  string
+	}{
+		{[]byte{0x00, 0x00, 0x00, 0x00}, ""},
+		{[]byte{0xca, 0x2e, 0x94, 0x42}, "a"},
+		{[]byte{0x45, 0xe6, 0x1e, 0x58}, "ab"},
+		{[]byte{0xed, 0x13, 0x1f, 0x5b}, "abc"},
+	}
+
+	h := New32()
+
+	for _, g := range golden32 {
+		h.Reset()
+		done, err := h.Write([]byte(g.in))
+		if err != nil {
+			t.Fatalf("write error: %s", err)
+		}
+		if done != len(g.in) {
+			t.Fatalf("wrote only %d out of %d bytes", done, len(g.in))
+		}
+		if actual := h.Sum(nil); !bytes.Equal(g.out, actual) {
+			t.Errorf("hash(%q) = 0x%x want 0x%x", g.in, actual, g.out)
+		}
+	}
+}
+
+func TestIntegrity32(t *testing.T) {
+	data := []byte{'1', '2', 3, 4, 5}
+
+	h := New32()
+	h.Write(data)
+	sum := h.Sum(nil)
+
+	if size := h.Size(); size != len(sum) {
+		t.Fatalf("Size()=%d but len(Sum())=%d", size, len(sum))
+	}
+
+	if a := h.Sum(nil); !bytes.Equal(sum, a) {
+		t.Fatalf("first Sum()=0x%x, second Sum()=0x%x", sum, a)
+	}
+
+	h.Reset()
+	h.Write(data)
+	if a := h.Sum(nil); !bytes.Equal(sum, a) {
+		t.Fatalf("Sum()=0x%x, but after Reset() Sum()=0x%x", sum, a)
+	}
+
+	h.Reset()
+	h.Write(data[:2])
+	h.Write(data[2:])
+	if a := h.Sum(nil); !bytes.Equal(sum, a) {
+		t.Fatalf("Sum()=0x%x, but with partial writes, Sum()=0x%x", sum, a)
+	}
+
+	sum32 := h.(hash.Hash32).Sum32()
+	if sum32 != binary.BigEndian.Uint32(sum) {
+		t.Fatalf("Sum()=0x%x, but Sum32()=0x%x", sum, sum32)
+	}
+}
+
+func BenchmarkJenkins32KB(b *testing.B) {
+	h := New32()
+
+	b.SetBytes(1024)
+	data := make([]byte, 1024)
+	for i := range data {
+		data[i] = byte(i)
+	}
+	in := make([]byte, 0, h.Size())
+
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		h.Reset()
+		h.Write(data)
+		h.Sum(in)
+	}
+}
+
+func BenchmarkFnv32(b *testing.B) {
+	arr := make([]int64, 1000)
+	for i := 0; i < b.N; i++ {
+		var payload [8]byte
+		binary.BigEndian.PutUint32(payload[:4], uint32(i))
+		binary.BigEndian.PutUint32(payload[4:], uint32(i))
+
+		h := fnv.New32()
+		h.Write(payload[:])
+		idx := int(h.Sum32() % uint32(len(arr))) // reduce unsigned: int(uint32) may be negative on 32-bit
+		arr[idx]++
+	}
+	b.StopTimer()
+	c := 0
+	if b.N > 1000000 {
+		for i := 0; i < len(arr)-1; i++ {
+			if math.Abs(float64(arr[i]-arr[i+1]))/float64(arr[i]) > float64(0.1) {
+				if c == 0 {
+					b.Logf("i %d val[i] %d val[i+1] %d b.N %d", i, arr[i], arr[i+1], b.N)
+				}
+				c++
+			}
+		}
+		if c > 0 {
+			b.Logf("Unbalanced buckets: %d", c)
+		}
+	}
+}
+
+func BenchmarkSum32(b *testing.B) {
+	arr := make([]int64, 1000)
+	for i := 0; i < b.N; i++ {
+		var payload [8]byte
+		binary.BigEndian.PutUint32(payload[:4], uint32(i))
+		binary.BigEndian.PutUint32(payload[4:], uint32(i))
+		h := Sum32(0)
+		h.Write(payload[:])
+		idx := int(h.Sum32() % uint32(len(arr)))
+		arr[idx]++
+	}
+	b.StopTimer()
+	if b.N > 1000000 {
+		for i := 0; i < len(arr)-1; i++ {
+			if math.Abs(float64(arr[i]-arr[i+1]))/float64(arr[i]) > float64(0.1) {
+				b.Logf("val[%3d]=%8d\tval[%3d]=%8d\tb.N=%d", i, arr[i], i+1, arr[i+1], b.N)
+				break
+			}
+		}
+	}
+}
+
+func BenchmarkNew32(b *testing.B) {
+	arr := make([]int64, 1000)
+	for i := 0; i < b.N; i++ {
+		var payload [8]byte
+		binary.BigEndian.PutUint32(payload[:4], uint32(i))
+		binary.BigEndian.PutUint32(payload[4:], uint32(i))
+		h := New32()
+		h.Write(payload[:])
+		idx := int(h.Sum32() % uint32(len(arr)))
+		arr[idx]++
+	}
+	b.StopTimer()
+	if b.N > 1000000 {
+		for i := 0; i < len(arr)-1; i++ {
+			if math.Abs(float64(arr[i]-arr[i+1]))/float64(arr[i]) > float64(0.1) {
+				b.Logf("val[%3d]=%8d\tval[%3d]=%8d\tb.N=%d", i, arr[i], i+1, arr[i+1], b.N)
+				break
+			}
+		}
+	}
+}
diff --git a/tcpip/ports/ports.go b/tcpip/ports/ports.go
index ba80a6f..eb0ada8 100644
--- a/tcpip/ports/ports.go
+++ b/tcpip/ports/ports.go
@@ -42,23 +42,47 @@
 	allocatedPorts map[portDescriptor]bindAddresses
 }
 
+type portNode struct {
+	reuse bool // reuse records whether the first reservation allowed port reuse.
+	refs  int  // refs counts live reservations; ReleasePort deletes the entry at zero.
+}
+
 // bindAddresses is a set of IP addresses.
-type bindAddresses map[tcpip.Address]struct{}
+type bindAddresses map[tcpip.Address]portNode
 
 // isAvailable checks whether an IP address is available to bind to.
-func (b bindAddresses) isAvailable(addr tcpip.Address) bool {
+func (b bindAddresses) isAvailable(addr tcpip.Address, reuse bool) bool {
 	if addr == anyIPAddress {
-		return len(b) == 0
+		if len(b) == 0 {
+			return true
+		}
+		if !reuse {
+			return false
+		}
+		for _, n := range b { // a wildcard bind may share only if every existing bind allows reuse
+			if !n.reuse {
+				return false
+			}
+		}
+		return true
 	}
 
 	// If all addresses for this portDescriptor are already bound, no
 	// address is available.
-	if _, ok := b[anyIPAddress]; ok {
-		return false
+	if n, ok := b[anyIPAddress]; ok {
+		if !reuse {
+			return false
+		}
+		if !n.reuse {
+			return false
+		}
 	}
 
-	if _, ok := b[addr]; ok {
-		return false
+	if n, ok := b[addr]; ok {
+		if !reuse {
+			return false
+		}
+		return n.reuse
 	}
 	return true
 }
@@ -92,17 +116,17 @@
 }
 
 // IsPortAvailable tests if the given port is available on all given protocols.
-func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16) bool {
+func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) bool {
 	s.mu.Lock()
 	defer s.mu.Unlock()
-	return s.isPortAvailableLocked(networks, transport, addr, port)
+	return s.isPortAvailableLocked(networks, transport, addr, port, reuse)
 }
 
-func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16) bool {
+func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) bool {
 	for _, network := range networks {
 		desc := portDescriptor{network, transport, port}
 		if addrs, ok := s.allocatedPorts[desc]; ok {
-			if !addrs.isAvailable(addr) {
+			if !addrs.isAvailable(addr, reuse) {
 				return false
 			}
 		}
@@ -114,14 +138,14 @@
 // reserved by another endpoint. If port is zero, ReservePort will search for
 // an unreserved ephemeral port and reserve it, returning its value in the
 // "port" return value.
-func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16) (reservedPort uint16, err *tcpip.Error) {
+func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) (reservedPort uint16, err *tcpip.Error) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 
 	// If a port is specified, just try to reserve it for all network
 	// protocols.
 	if port != 0 {
-		if !s.reserveSpecificPort(networks, transport, addr, port) {
+		if !s.reserveSpecificPort(networks, transport, addr, port, reuse) {
 			return 0, tcpip.ErrPortInUse
 		}
 		return port, nil
@@ -129,13 +153,13 @@
 
 	// A port wasn't specified, so try to find one.
 	return s.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
-		return s.reserveSpecificPort(networks, transport, addr, p), nil
+		return s.reserveSpecificPort(networks, transport, addr, p, reuse), nil
 	})
 }
 
 // reserveSpecificPort tries to reserve the given port on all given protocols.
-func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16) bool {
-	if !s.isPortAvailableLocked(networks, transport, addr, port) {
+func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) bool {
+	if !s.isPortAvailableLocked(networks, transport, addr, port, reuse) {
 		return false
 	}
 
@@ -147,7 +171,12 @@
 			m = make(bindAddresses)
 			s.allocatedPorts[desc] = m
 		}
-		m[addr] = struct{}{}
+		if n, ok := m[addr]; ok { // isAvailable already ensured both sides allow reuse
+			n.refs++
+			m[addr] = n
+		} else {
+			m[addr] = portNode{reuse: reuse, refs: 1}
+		}
 	}
 
 	return true
@@ -162,7 +191,16 @@
 	for _, network := range networks {
 		desc := portDescriptor{network, transport, port}
 		if m, ok := s.allocatedPorts[desc]; ok {
-			delete(m, addr)
+			n, ok := m[addr]
+			if !ok { // addr was never reserved on this port; nothing to release
+				continue
+			}
+			n.refs--
+			if n.refs == 0 {
+				delete(m, addr)
+			} else {
+				m[addr] = n
+			}
 			if len(m) == 0 {
 				delete(s.allocatedPorts, desc)
 			}
diff --git a/tcpip/ports/ports_test.go b/tcpip/ports/ports_test.go
index 25cb3dc..bee0e6d 100644
--- a/tcpip/ports/ports_test.go
+++ b/tcpip/ports/ports_test.go
@@ -28,67 +28,99 @@
 	fakeIPAddress1 = tcpip.Address("\x08\x08\x08\x09")
 )
 
-func TestPortReservation(t *testing.T) {
-	pm := NewPortManager()
-	net := []tcpip.NetworkProtocolNumber{fakeNetworkNumber}
+type portReserveTestAction struct {
+	port    uint16
+	ip      tcpip.Address
+	want    *tcpip.Error
+	reuse   bool
+	release bool
+}
 
+func TestPortReservation(t *testing.T) {
 	for _, test := range []struct {
-		port uint16
-		ip   tcpip.Address
-		want *tcpip.Error
+		tname   string
+		actions []portReserveTestAction
 	}{
 		{
-			port: 80,
-			ip:   fakeIPAddress,
-			want: nil,
+			tname: "bind to ip",
+			actions: []portReserveTestAction{
+				{port: 80, ip: fakeIPAddress, want: nil},
+				{port: 80, ip: fakeIPAddress1, want: nil},
+				/* N.B. Order of actions matters! */
+				{port: 80, ip: anyIPAddress, want: tcpip.ErrPortInUse},
+				{port: 80, ip: fakeIPAddress, want: tcpip.ErrPortInUse, reuse: true},
+			},
 		},
 		{
-			port: 80,
-			ip:   fakeIPAddress1,
-			want: nil,
-		},
-		{
-			/* N.B. Order of tests matters! */
-			port: 80,
-			ip:   anyIPAddress,
-			want: tcpip.ErrPortInUse,
-		},
-		{
-			port: 22,
-			ip:   anyIPAddress,
-			want: nil,
-		},
-		{
-			port: 22,
-			ip:   fakeIPAddress,
-			want: tcpip.ErrPortInUse,
-		},
-		{
-			port: 0,
-			ip:   fakeIPAddress,
-			want: nil,
-		},
-		{
-			port: 0,
-			ip:   fakeIPAddress,
-			want: nil,
+			tname: "bind to inaddr any",
+			actions: []portReserveTestAction{
+				{port: 22, ip: anyIPAddress, want: nil},
+				{port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse},
+				/* release fakeIPAddress (a no-op, it was never reserved); anyIPAddress is still in use */
+				{port: 22, ip: fakeIPAddress, release: true},
+				{port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse},
+				{port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse, reuse: true},
+				/* Release port 22 from any IP address, then try to reserve fake IP address on 22 */
+				{port: 22, ip: anyIPAddress, want: nil, release: true},
+				{port: 22, ip: fakeIPAddress, want: nil},
+			},
+		}, {
+			tname: "bind to zero port",
+			actions: []portReserveTestAction{
+				{port: 0, ip: fakeIPAddress, want: nil},
+				{port: 0, ip: fakeIPAddress, want: nil},
+				{port: 0, ip: fakeIPAddress, reuse: true, want: nil},
+			},
+		}, {
+			tname: "bind to ip with reuseport",
+			actions: []portReserveTestAction{
+				{port: 25, ip: fakeIPAddress, reuse: true, want: nil},
+				{port: 25, ip: fakeIPAddress, reuse: true, want: nil},
+
+				{port: 25, ip: fakeIPAddress, reuse: false, want: tcpip.ErrPortInUse},
+				{port: 25, ip: anyIPAddress, reuse: false, want: tcpip.ErrPortInUse},
+
+				{port: 25, ip: anyIPAddress, reuse: true, want: nil},
+			},
+		}, {
+			tname: "bind to inaddr any with reuseport",
+			actions: []portReserveTestAction{
+				{port: 24, ip: anyIPAddress, reuse: true, want: nil},
+				{port: 24, ip: anyIPAddress, reuse: true, want: nil},
+
+				{port: 24, ip: anyIPAddress, reuse: false, want: tcpip.ErrPortInUse},
+				{port: 24, ip: fakeIPAddress, reuse: false, want: tcpip.ErrPortInUse},
+
+				{port: 24, ip: fakeIPAddress, reuse: true, want: nil},
+				{port: 24, ip: fakeIPAddress, release: true, want: nil},
+
+				{port: 24, ip: anyIPAddress, release: true},
+				{port: 24, ip: anyIPAddress, reuse: false, want: tcpip.ErrPortInUse},
+
+				{port: 24, ip: anyIPAddress, release: true},
+				{port: 24, ip: anyIPAddress, reuse: false, want: nil},
+			},
 		},
 	} {
-		gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port)
-		if err != test.want {
-			t.Fatalf("ReservePort(.., .., %s, %d) = %v, want %v", test.ip, test.port, err, test.want)
-		}
-		if test.port == 0 && (gotPort == 0 || gotPort < FirstEphemeral) {
-			t.Fatalf("ReservePort(.., .., .., 0) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral)
-		}
-	}
+		t.Run(test.tname, func(t *testing.T) {
+			pm := NewPortManager()
+			net := []tcpip.NetworkProtocolNumber{fakeNetworkNumber}
 
-	// Release port 22 from any IP address, then try to reserve fake IP
-	// address on 22.
-	pm.ReleasePort(net, fakeTransNumber, anyIPAddress, 22)
+			for _, action := range test.actions {
+				if action.release {
+					pm.ReleasePort(net, fakeTransNumber, action.ip, action.port)
+					continue
+				}
+				gotPort, err := pm.ReservePort(net, fakeTransNumber, action.ip, action.port, action.reuse)
+				if err != action.want {
+					t.Fatalf("ReservePort(.., .., %s, %d, %t) = %v, want %v", action.ip, action.port, action.reuse, err, action.want)
+				}
+				if action.port == 0 && (gotPort == 0 || gotPort < FirstEphemeral) {
+					t.Fatalf("ReservePort(.., .., .., 0) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral)
+				}
+			}
+		})
 
-	if port, err := pm.ReservePort(net, fakeTransNumber, fakeIPAddress, 22); port != 22 || err != nil {
-		t.Fatalf("ReservePort(.., .., .., %d) = (port %d, err %v), want (22, nil); failed to reserve port after it should have been released", 22, port, err)
 	}
 }
 
diff --git a/tcpip/stack/stack.go b/tcpip/stack/stack.go
index 89a747a..d1d8dfb 100644
--- a/tcpip/stack/stack.go
+++ b/tcpip/stack/stack.go
@@ -883,9 +883,9 @@
 // transport dispatcher. Received packets that match the provided id will be
 // delivered to the given endpoint; specifying a nic is optional, but
 // nic-specific IDs have precedence over global ones.
-func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error {
+func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error {
 	if nicID == 0 {
-		return s.demux.registerEndpoint(netProtos, protocol, id, ep)
+		return s.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort)
 	}
 
 	s.mu.RLock()
@@ -896,14 +896,14 @@
 		return tcpip.ErrUnknownNICID
 	}
 
-	return nic.demux.registerEndpoint(netProtos, protocol, id, ep)
+	return nic.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort)
 }
 
 // UnregisterTransportEndpoint removes the endpoint with the given id from the
 // stack transport dispatcher.
-func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID) {
+func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) {
 	if nicID == 0 {
-		s.demux.unregisterEndpoint(netProtos, protocol, id)
+		s.demux.unregisterEndpoint(netProtos, protocol, id, ep)
 		return
 	}
 
@@ -912,7 +912,7 @@
 
 	nic := s.nics[nicID]
 	if nic != nil {
-		nic.demux.unregisterEndpoint(netProtos, protocol, id)
+		nic.demux.unregisterEndpoint(netProtos, protocol, id, ep)
 	}
 }
 
diff --git a/tcpip/stack/transport_demuxer.go b/tcpip/stack/transport_demuxer.go
index 5f40727..1d249e0 100644
--- a/tcpip/stack/transport_demuxer.go
+++ b/tcpip/stack/transport_demuxer.go
@@ -15,10 +15,12 @@
 package stack
 
 import (
+	"math/rand"
 	"sync"
 
 	"github.com/google/netstack/tcpip"
 	"github.com/google/netstack/tcpip/buffer"
+	"github.com/google/netstack/tcpip/hash/jenkins"
 	"github.com/google/netstack/tcpip/header"
 )
 
@@ -34,6 +36,23 @@
 	endpoints map[TransportEndpointID]TransportEndpoint
 }
 
+// unregisterEndpoint removes ep registered under id. For a multiPortEndpoint
+// only ep is removed, and id itself is unregistered once no endpoints remain.
+func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint) {
+	eps.mu.Lock()
+	defer eps.mu.Unlock()
+	e, ok := eps.endpoints[id]
+	if !ok {
+		return
+	}
+	if multiPortEp, ok := e.(*multiPortEndpoint); ok {
+		if !multiPortEp.unregisterEndpoint(ep) {
+			return
+		}
+	}
+	delete(eps.endpoints, id)
+}
+
 // transportDemuxer demultiplexes packets targeted at a transport endpoint
 // (i.e., after they've been parsed by the network layer). It does two levels
 // of demultiplexing: first based on the network and transport protocols, then
@@ -57,10 +76,10 @@
 
 // registerEndpoint registers the given endpoint with the dispatcher such that
 // packets that match the endpoint ID are delivered to it.
-func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error {
+func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error {
 	for i, n := range netProtos {
-		if err := d.singleRegisterEndpoint(n, protocol, id, ep); err != nil {
-			d.unregisterEndpoint(netProtos[:i], protocol, id)
+		if err := d.singleRegisterEndpoint(n, protocol, id, ep, reusePort); err != nil {
+			d.unregisterEndpoint(netProtos[:i], protocol, id, ep)
 			return err
 		}
 	}
@@ -68,7 +87,97 @@
 	return nil
 }
 
-func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error {
+// multiPortEndpoint is a container for TransportEndpoints which are bound to
+// the same pair of address and port.
+type multiPortEndpoint struct {
+	mu           sync.RWMutex
+	endpointsArr []TransportEndpoint
+	endpointsMap map[TransportEndpoint]int
+	// seed is a random secret for a jenkins hash.
+	seed uint32
+}
+
+// reciprocalScale scales a value into range [0, n).
+//
+// This is similar to val % n, but faster.
+// See http://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/
+func reciprocalScale(val, n uint32) uint32 {
+	return uint32((uint64(val) * uint64(n)) >> 32)
+}
+
+// selectEndpoint hashes the local and remote addresses and ports of id and
+// uses the hash to pick one of the registered endpoints, so packets with the
+// same id go to the same endpoint while the set of endpoints is unchanged.
+func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEndpoint {
+	ep.mu.RLock()
+	defer ep.mu.RUnlock()
+
+	payload := []byte{
+		byte(id.LocalPort),
+		byte(id.LocalPort >> 8),
+		byte(id.RemotePort),
+		byte(id.RemotePort >> 8),
+	}
+
+	h := jenkins.Sum32(ep.seed)
+	h.Write(payload)
+	h.Write([]byte(id.LocalAddress))
+	h.Write([]byte(id.RemoteAddress))
+	hash := h.Sum32()
+
+	idx := reciprocalScale(hash, uint32(len(ep.endpointsArr)))
+	return ep.endpointsArr[idx]
+}
+
+// HandlePacket implements stack.TransportEndpoint.HandlePacket by forwarding
+// the packet to the endpoint chosen by selectEndpoint.
+func (ep *multiPortEndpoint) HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) {
+	ep.selectEndpoint(id).HandlePacket(r, id, vv)
+}
+
+// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
+func (ep *multiPortEndpoint) HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) {
+	ep.selectEndpoint(id).HandleControlPacket(id, typ, extra, vv)
+}
+
+func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint) {
+	ep.mu.Lock()
+	defer ep.mu.Unlock()
+
+	// A new endpoint is appended to endpointsArr and its index there is
+	// saved in endpointsMap, keyed by the endpoint itself, so that
+	// unregisterEndpoint can remove it from the array in constant time.
+	ep.endpointsMap[t] = len(ep.endpointsArr)
+	ep.endpointsArr = append(ep.endpointsArr, t)
+}
+
+// unregisterEndpoint removes t and returns true if t was the last endpoint, i.e. the multiPortEndpoint itself has to be unregistered.
+func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint) bool {
+	ep.mu.Lock()
+	defer ep.mu.Unlock()
+
+	idx, ok := ep.endpointsMap[t]
+	if !ok {
+		return false
+	}
+	l := len(ep.endpointsArr)
+	if l > 1 {
+		// Move the last endpoint into t's slot; t is deleted afterwards because lastEp may be t itself.
+		lastEp := ep.endpointsArr[l-1]
+		ep.endpointsArr[idx] = lastEp
+		ep.endpointsMap[lastEp] = idx
+		ep.endpointsArr = ep.endpointsArr[0 : l-1]
+		delete(ep.endpointsMap, t)
+		return false
+	}
+	return true
+}
+
+func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error {
+	if id.RemotePort != 0 { // port reuse only applies to endpoints without a remote peer
+		reusePort = false
+	}
+
 	eps, ok := d.protocol[protocolIDs{netProto, protocol}]
 	if !ok {
 		return nil
@@ -77,10 +186,29 @@
 	eps.mu.Lock()
 	defer eps.mu.Unlock()
 
+	var multiPortEp *multiPortEndpoint
 	if _, ok := eps.endpoints[id]; ok {
-		return tcpip.ErrPortInUse
+		if !reusePort {
+			return tcpip.ErrPortInUse
+		}
+		multiPortEp, ok = eps.endpoints[id].(*multiPortEndpoint)
+		if !ok {
+			return tcpip.ErrPortInUse
+		}
 	}
 
+	if reusePort {
+		if multiPortEp == nil {
+			multiPortEp = &multiPortEndpoint{}
+			multiPortEp.endpointsMap = make(map[TransportEndpoint]int)
+			multiPortEp.seed = rand.Uint32()
+			eps.endpoints[id] = multiPortEp
+		}
+
+		multiPortEp.singleRegisterEndpoint(ep)
+
+		return nil
+	}
 	eps.endpoints[id] = ep
 
 	return nil
@@ -88,12 +216,10 @@
 
 // unregisterEndpoint unregisters the endpoint with the given id such that it
 // won't receive any more packets.
-func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID) {
+func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) {
 	for _, n := range netProtos {
 		if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok {
-			eps.mu.Lock()
-			delete(eps.endpoints, id)
-			eps.mu.Unlock()
+			eps.unregisterEndpoint(id, ep)
 		}
 	}
 }
diff --git a/tcpip/stack/transport_test.go b/tcpip/stack/transport_test.go
index 3e86578..63c62e1 100644
--- a/tcpip/stack/transport_test.go
+++ b/tcpip/stack/transport_test.go
@@ -107,7 +107,7 @@
 
 	// Try to register so that we can start receiving packets.
 	f.id.RemoteAddress = addr.Addr
-	err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.id, f)
+	err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.id, f, false)
 	if err != nil {
 		return err
 	}
diff --git a/tcpip/tcpip.go b/tcpip/tcpip.go
index aeba76c..164d1b6 100644
--- a/tcpip/tcpip.go
+++ b/tcpip/tcpip.go
@@ -436,6 +436,10 @@
 // should allow reuse of local address.
 type ReuseAddressOption int
 
+// ReusePortOption is used by SetSockOpt/GetSockOpt to permit multiple sockets
+// to be bound to an identical socket address (SO_REUSEPORT); nonzero enables it.
+type ReusePortOption int
+
 // QuickAckOption is stubbed out in SetSockOpt/GetSockOpt.
 type QuickAckOption int
 
diff --git a/tcpip/transport/ping/endpoint.go b/tcpip/transport/ping/endpoint.go
index efdab37..06f5e73 100644
--- a/tcpip/transport/ping/endpoint.go
+++ b/tcpip/transport/ping/endpoint.go
@@ -100,7 +100,7 @@
 	e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
 	switch e.state {
 	case stateBound, stateConnected:
-		e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id)
+		e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id, e)
 	}
 
 	// Close the receive list and drain it.
@@ -541,14 +541,14 @@
 	if id.LocalPort != 0 {
 		// The endpoint already has a local port, just attempt to
 		// register it.
-		err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e)
+		err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false)
 		return id, err
 	}
 
 	// We need to find a port for the endpoint.
 	_, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
 		id.LocalPort = p
-		err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e)
+		err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false)
 		switch err {
 		case nil:
 			return true, nil
@@ -597,7 +597,7 @@
 	if commit != nil {
 		if err := commit(); err != nil {
 			// Unregister, the commit failed.
-			e.stack.UnregisterTransportEndpoint(addr.NIC, netProtos, e.transProto, id)
+			e.stack.UnregisterTransportEndpoint(addr.NIC, netProtos, e.transProto, id, e)
 			return err
 		}
 	}
diff --git a/tcpip/transport/tcp/accept.go b/tcpip/transport/tcp/accept.go
index 05ebd3f..ecaa132 100644
--- a/tcpip/transport/tcp/accept.go
+++ b/tcpip/transport/tcp/accept.go
@@ -215,7 +215,7 @@
 	n.maybeEnableSACKPermitted(rcvdSynOpts)
 
 	// Register new endpoint so that packets are routed to it.
-	if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.id, n); err != nil {
+	if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.id, n, n.reusePort); err != nil {
 		n.Close()
 		return nil, err
 	}
diff --git a/tcpip/transport/tcp/endpoint.go b/tcpip/transport/tcp/endpoint.go
index 503b86a..6a661f0 100644
--- a/tcpip/transport/tcp/endpoint.go
+++ b/tcpip/transport/tcp/endpoint.go
@@ -162,6 +162,9 @@
 	// sack holds TCP SACK related information for this endpoint.
 	sack SACKInfo
 
+	// reusePort is true if SO_REUSEPORT is enabled; SetSockOpt/GetSockOpt access it under mu.
+	reusePort bool
+
 	// delay enables Nagle's algorithm.
 	//
 	// delay is a boolean (0 is false) and must be accessed atomically.
@@ -416,7 +419,7 @@
 		e.isPortReserved = false
 
 		if e.isRegistered {
-			e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id)
+			e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
 			e.isRegistered = false
 		}
 	}
@@ -453,7 +456,7 @@
 	e.workerCleanup = false
 
 	if e.isRegistered {
-		e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id)
+		e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
 	}
 
 	e.route.Release()
@@ -681,6 +684,12 @@
 		e.mu.Unlock()
 		return nil
 
+	case tcpip.ReusePortOption:
+		e.mu.Lock()
+		e.reusePort = v != 0
+		e.mu.Unlock()
+		return nil
+
 	case tcpip.QuickAckOption:
 		if v == 0 {
 			atomic.StoreUint32(&e.slowAck, 1)
@@ -875,6 +884,17 @@
 		}
 		return nil
 
+	case *tcpip.ReusePortOption:
+		e.mu.RLock()
+		v := e.reusePort
+		e.mu.RUnlock()
+
+		*o = 0
+		if v {
+			*o = 1
+		}
+		return nil
+
 	case *tcpip.QuickAckOption:
 		*o = 1
 		if v := atomic.LoadUint32(&e.slowAck); v != 0 {
@@ -1057,7 +1077,7 @@
 
 	if e.id.LocalPort != 0 {
 		// The endpoint is bound to a port, attempt to register it.
-		err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.id, e)
+		err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.id, e, e.reusePort)
 		if err != nil {
 			return err
 		}
@@ -1071,13 +1091,13 @@
 			if sameAddr && p == e.id.RemotePort {
 				return false, nil
 			}
-			if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.id.LocalAddress, p) {
+			if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.id.LocalAddress, p, false) {
 				return false, nil
 			}
 
 			id := e.id
 			id.LocalPort = p
-			switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e) {
+			switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort) {
 			case nil:
 				e.id = id
 				return true, nil
@@ -1234,7 +1254,7 @@
 	}
 
 	// Register the endpoint.
-	if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e); err != nil {
+	if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.reusePort); err != nil {
 		return err
 	}
 
@@ -1315,7 +1335,7 @@
 		}
 	}
 
-	port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port)
+	port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.reusePort)
 	if err != nil {
 		return err
 	}
diff --git a/tcpip/transport/udp/endpoint.go b/tcpip/transport/udp/endpoint.go
index ea1f6f3..14b3e86 100644
--- a/tcpip/transport/udp/endpoint.go
+++ b/tcpip/transport/udp/endpoint.go
@@ -81,6 +81,7 @@
 	dstPort      uint16
 	v6only       bool
 	multicastTTL uint8
+	reusePort    bool
 
 	// shutdownFlags represent the current shutdown state of the endpoint.
 	shutdownFlags tcpip.ShutdownFlags
@@ -132,7 +133,7 @@
 	ep := newEndpoint(stack, r.NetProto, waiterQueue)
 
 	// Register new endpoint so that packets are routed to it.
-	if err := stack.RegisterTransportEndpoint(r.NICID(), []tcpip.NetworkProtocolNumber{r.NetProto}, ProtocolNumber, id, ep); err != nil {
+	if err := stack.RegisterTransportEndpoint(r.NICID(), []tcpip.NetworkProtocolNumber{r.NetProto}, ProtocolNumber, id, ep, ep.reusePort); err != nil {
 		ep.Close()
 		return nil, err
 	}
@@ -155,7 +156,7 @@
 
 	switch e.state {
 	case stateBound, stateConnected:
-		e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id)
+		e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
 		e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort)
 	}
 
@@ -449,6 +450,12 @@
 				break
 			}
 		}
+
+	case tcpip.ReusePortOption:
+		e.mu.Lock()
+		e.reusePort = v != 0
+		e.mu.Unlock()
+		return nil
 	}
 	return nil
 }
@@ -513,6 +520,17 @@
 		e.mu.Unlock()
 		return nil
 
+	case *tcpip.ReusePortOption:
+		e.mu.RLock()
+		v := e.reusePort
+		e.mu.RUnlock()
+
+		*o = 0
+		if v {
+			*o = 1
+		}
+		return nil
+
 	case *tcpip.KeepaliveEnabledOption:
 		*o = 0
 		return nil
@@ -648,7 +666,7 @@
 
 	// Remove the old registration.
 	if e.id.LocalPort != 0 {
-		e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id)
+		e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
 	}
 
 	e.id = id
@@ -711,14 +729,14 @@
 
 func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) {
 	if e.id.LocalPort == 0 {
-		port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort)
+		port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.reusePort)
 		if err != nil {
 			return id, err
 		}
 		id.LocalPort = port
 	}
 
-	err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e)
+	err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort)
 	if err != nil {
 		e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort)
 	}
@@ -766,7 +784,7 @@
 	if commit != nil {
 		if err := commit(); err != nil {
 			// Unregister, the commit failed.
-			e.stack.UnregisterTransportEndpoint(addr.NIC, netProtos, ProtocolNumber, id)
+			e.stack.UnregisterTransportEndpoint(addr.NIC, netProtos, ProtocolNumber, id, e)
 			e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort)
 			return err
 		}
diff --git a/tcpip/transport/udp/udp_test.go b/tcpip/transport/udp/udp_test.go
index 2dc27ae..46f268c 100644
--- a/tcpip/transport/udp/udp_test.go
+++ b/tcpip/transport/udp/udp_test.go
@@ -16,6 +16,7 @@
 
 import (
 	"bytes"
+	"math"
 	"math/rand"
 	"testing"
 	"time"
@@ -254,6 +255,90 @@
 	return b
 }
 
+// TestBindPortReuse binds several UDP endpoints to the same address and port
+// with ReusePortOption set and checks that incoming packets are spread fairly
+// across them while all packets from a given client port reach one endpoint.
+func TestBindPortReuse(t *testing.T) {
+	c := newDualTestContext(t, defaultMTU)
+	defer c.cleanup()
+	c.createV6Endpoint(false)
+
+	var eps [5]tcpip.Endpoint
+	reusePortOpt := tcpip.ReusePortOption(1)
+
+	// Each endpoint forwards its EventIn notifications here. Buffered so that a
+	// notification raised while closing endpoints at cleanup never blocks forever.
+	pollChannel := make(chan tcpip.Endpoint, len(eps))
+	for i := 0; i < len(eps); i++ {
+		wq := waiter.Queue{}
+		we, ch := waiter.NewChannelEntry(nil)
+		wq.EventRegister(&we, waiter.EventIn)
+		// Defers run LIFO: unregister before closing ch so no late notification
+		// sends on a closed channel; closing ch then stops the forwarder below.
+		defer close(ch)
+		defer wq.EventUnregister(&we)
+
+		var err *tcpip.Error
+		eps[i], err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq)
+		if err != nil {
+			t.Fatalf("NewEndpoint failed: %v", err)
+		}
+		defer eps[i].Close()
+
+		go func(ep tcpip.Endpoint) {
+			for range ch {
+				pollChannel <- ep
+			}
+		}(eps[i])
+
+		if err := eps[i].SetSockOpt(reusePortOpt); err != nil {
+			t.Fatalf("SetSockOpt failed: %v", err)
+		}
+		if err := eps[i].Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}, nil); err != nil {
+			t.Fatalf("ep.Bind(...) failed: %v", err)
+		}
+	}
+
+	npackets, nports := 100000, 10000
+	ports := make(map[uint16]tcpip.Endpoint)
+	stats := make(map[tcpip.Endpoint]int)
+	for i := 0; i < npackets; i++ {
+		// Send a packet from one of nports distinct client ports.
+		port := uint16(i % nports)
+		c.sendV6Packet(newPayload(), &headers{
+			srcPort: testPort + port,
+			dstPort: stackPort,
+		})
+
+		// The receiving endpoint's forwarder reports it; read the packet from it.
+		var addr tcpip.FullAddress
+		ep := <-pollChannel
+		if _, _, err := ep.Read(&addr); err != nil {
+			t.Fatalf("Read failed: %v", err)
+		}
+		stats[ep]++
+
+		// All packets from one client must be handled by the same socket.
+		if prev, ok := ports[port]; !ok {
+			ports[port] = ep
+		} else if prev != ep {
+			t.Fatalf("packet %d from port %d went to endpoint %p, want %p", i, testPort+port, ep, prev)
+		}
+	}
+
+	if len(stats) != len(eps) {
+		t.Fatalf("Only %d(expected %d) sockets received packets", len(stats), len(eps))
+	}
+
+	// The distribution must be fair: each socket gets within 10% of an even share.
+	want := float64(npackets) / float64(len(eps))
+	for ep, got := range stats {
+		if math.Abs(float64(got)-want) > want/10 {
+			t.Fatalf("endpoint %p received %d packets, want within 10%% of %.0f", ep, got, want)
+		}
+	}
+}
+
 func testV4Read(c *testContext) {
 	// Send a packet.
 	payload := newPayload()