[netstack] Update tcpip nic to allow address removal.
- Adds remove methods for addresses and subnets.
- Adds capability to query the "main" address to show the user for
a given NIC when multiple are configured in the stack.
- Fixes string formatting bugs several tcpip objects.
Corresponding changes to netstack in:
Change-Id: Ib0bf3b8739e178bc489424de4fccb2972558b2d0
NET-1160
Test: manual test of adding/removing addresses via ifconfig.
Change-Id: Iccbadebec2fb3d5ad60b354e161be796a39dd306
diff --git a/tcpip/stack/nic.go b/tcpip/stack/nic.go
index 1307848..9be5cc6 100644
--- a/tcpip/stack/nic.go
+++ b/tcpip/stack/nic.go
@@ -72,6 +72,55 @@
n.mu.Unlock()
}
+// Get the primary network endpoint, if there is one; otherwise pick an arbitrary endpoint from the NIC's endpoints.
+func (n *NIC) getMainNICAddress(protocol tcpip.NetworkProtocolNumber) (tcpip.Address, tcpip.Subnet) {
+ n.mu.RLock()
+ defer n.mu.RUnlock()
+
+ var address tcpip.Address
+ var subnet tcpip.Subnet
+
+ // Check for a primary endpoint.
+ var r *referencedNetworkEndpoint
+ list := n.primary[protocol]
+ if list != nil {
+ for e := list.Front(); e != nil; e = e.Next() {
+ ref := e.(*referencedNetworkEndpoint)
+ if ref.holdsInsertRef && ref.tryIncRef() {
+ r = ref
+ break
+ }
+ }
+
+ }
+
+ // If no primary endpoints then check for other endpoints.
+ if r == nil {
+ for _, ref := range n.endpoints {
+ if ref != nil && ref.holdsInsertRef && ref.tryIncRef() {
+ r = ref
+ break
+ }
+ }
+ }
+
+ if r != nil {
+ address = r.ep.ID().LocalAddress
+ r.decRef()
+ }
+
+ // Find the least-constrained matching subnet for the address, if one exists, and return it
+ if address != "" {
+ for _, s := range n.subnets {
+ if s.Contains(address) && !subnet.Contains(s.ID()) {
+ subnet = s
+ }
+ }
+ }
+
+ return address, subnet
+}
+
// primaryEndpoint returns the primary endpoint of n for the given network
// protocol.
func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedNetworkEndpoint {
@@ -182,6 +231,33 @@
n.mu.Unlock()
}
+// RemoveSubnet removes the given subnet from n.
+func (n *NIC) RemoveSubnet(subnet tcpip.Subnet) {
+ n.mu.Lock()
+
+ var filtered []tcpip.Subnet
+ for _, sub := range n.subnets {
+ if sub != subnet {
+ filtered = append(filtered, sub)
+ }
+ }
+
+ n.subnets = filtered
+ n.mu.Unlock()
+ return
+}
+
+func (n *NIC) ContainsSubnet(subnet tcpip.Subnet) bool {
+ subnets := n.Subnets()
+
+ for _, s := range subnets {
+ if s == subnet {
+ return true
+ }
+ }
+ return false
+}
+
// Subnets returns the Subnets associated with this NIC.
func (n *NIC) Subnets() []tcpip.Subnet {
n.mu.RLock()
diff --git a/tcpip/stack/stack.go b/tcpip/stack/stack.go
index 0fee5c8..dc53ae7 100644
--- a/tcpip/stack/stack.go
+++ b/tcpip/stack/stack.go
@@ -322,6 +322,32 @@
return nil
}
+// RemoveSubnet removes the subnet range from the specified NIC.
+func (s *Stack) RemoveSubnet(id tcpip.NICID, subnet tcpip.Subnet) *tcpip.Error {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic := s.nics[id]
+ if nic == nil {
+ return tcpip.ErrUnknownNICID
+ }
+ nic.RemoveSubnet(subnet)
+ return nil
+}
+
+// Returns true if the given subnet is present in the specified NIC..
+func (s *Stack) ContainsSubnet(id tcpip.NICID, subnet tcpip.Subnet) (bool, *tcpip.Error) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic := s.nics[id]
+ if nic == nil {
+ return false, tcpip.ErrUnknownNICID
+ }
+
+ return nic.ContainsSubnet(subnet), nil
+}
+
// RemoveAddress removes an existing network-layer address from the specified
// NIC.
func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error {
@@ -336,6 +362,24 @@
return nic.RemoveAddress(addr)
}
+// Returns the first primary address (and subnet that contains it) for the
+// given NIC and protocol.
+func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.Address, tcpip.Subnet, *tcpip.Error) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ var address tcpip.Address
+ var subnet tcpip.Subnet
+
+ nic := s.nics[id]
+ if nic == nil {
+ return address, subnet, tcpip.ErrUnknownNICID
+ }
+
+ address, subnet = nic.getMainNICAddress(protocol)
+ return address, subnet, nil
+}
+
// FindRoute creates a route to the given destination address, leaving through
// the given nic and local address (if provided).
func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (Route, *tcpip.Error) {
diff --git a/tcpip/stack/stack_test.go b/tcpip/stack/stack_test.go
index d493474..e1e2955 100644
--- a/tcpip/stack/stack_test.go
+++ b/tcpip/stack/stack_test.go
@@ -9,6 +9,7 @@
import (
"math"
+ "strings"
"testing"
"github.com/google/netstack/tcpip"
@@ -677,6 +678,111 @@
}
}
+func TestSubnetAddRemove(t *testing.T) {
+ s := stack.New([]string{"fakeNet"}, nil)
+ id, _ := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, id); err != nil {
+ t.Fatalf("CreateNIC failed: %v", err)
+ }
+
+ addr := tcpip.Address("\x01\x01\x01\x01")
+ mask := tcpip.AddressMask(strings.Repeat("\xff", len(addr)))
+ subnet, err1 := tcpip.NewSubnet(addr, mask)
+
+ if err1 != nil {
+ t.Fatalf("NewSubnet failed: %v", err1)
+ }
+
+ if contained, err := s.ContainsSubnet(1, subnet); err != nil || contained {
+ if contained {
+ t.Fatalf("ContainsSubnet spuriously returns true before adding subnet.")
+ }
+ t.Fatalf("ContainsSubnet returned error %v", err)
+ }
+
+ if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil {
+ t.Fatalf("AddSubnet failed with error: %v", err)
+ }
+
+ if contained, err := s.ContainsSubnet(1, subnet); err != nil || !contained {
+ if !contained {
+ t.Fatalf("ContainsSubnet spuriously returns false after adding subnet.")
+ }
+ t.Fatalf("ContainsSubnet returned error %v", err)
+ }
+
+ if err := s.RemoveSubnet(1, subnet); err != nil {
+ t.Fatalf("RemoveSubnet failed with error: %v", err)
+ }
+
+ if contained, err := s.ContainsSubnet(1, subnet); err != nil || contained {
+ if contained {
+ t.Fatalf("ContainsSubnet spuriously returns true after removing subnet.")
+ }
+ t.Fatalf("ContainsSubnet returned error %v", err)
+ }
+}
+
+func TestGetMainNICAddress(t *testing.T) {
+ s := stack.New([]string{"fakeNet"}, nil)
+ id, _ := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, id); err != nil {
+ t.Fatalf("CreateNIC failed: %v", err)
+ }
+
+ addr := tcpip.Address("\x01\x01\x01\x01")
+ mask := tcpip.AddressMask(strings.Repeat("\xff", len(addr)))
+ subn, _ := tcpip.NewSubnet(addr, mask)
+
+ if err := s.AddAddress(1, fakeNetNumber, addr); err != nil {
+ t.Fatalf("AddAddress failed: %v", err)
+ }
+
+ if err := s.AddSubnet(1, fakeNetNumber, subn); err != nil {
+ t.Fatalf("AddSubnet failed with error: %v", err)
+ }
+
+ // Check that we get the right initial address and subnet
+ address, subnet, err := s.GetMainNICAddress(1, fakeNetNumber)
+
+ if err != nil {
+ t.Fatalf("GetMainNICAddress failed with error: %v", err)
+ }
+
+ if address != addr {
+ t.Fatalf("Expecting address=%s but GetMainNICAddress returned %s", addr, address)
+ }
+
+ if subnet != subn {
+ t.Fatalf("Expecting subnet=%#v but GetMainNICAddress returned %#v", subn, subnet)
+ }
+
+ if err := s.RemoveSubnet(1, subn); err != nil {
+ t.Fatalf("RemoveSubnet failed with error: %v", err)
+ }
+
+ if err := s.RemoveAddress(1, addr); err != nil {
+ t.Fatalf("RemoveAddress failed: %v", err)
+ }
+
+ // Check that we get an empty address and subnet after removal
+ address2, subnet2, err2 := s.GetMainNICAddress(1, fakeNetNumber)
+
+ if err2 != nil {
+ t.Fatalf("GetMainNICAddress failed with error: %v", err2)
+ }
+
+ var emptyAddr tcpip.Address
+ if emptyAddr != address2 {
+ t.Fatalf("Expecting address=%s but GetMainNICAddress returned %s", emptyAddr, address2)
+ }
+
+ var emptySubnet tcpip.Subnet
+ if emptySubnet != subnet2 {
+ t.Fatalf("Expecting subnet=%#v but GetMainNICAddress returned %#v", emptySubnet, subnet2)
+ }
+}
+
func init() {
stack.RegisterNetworkProtocolFactory("fakeNet", func() stack.NetworkProtocol {
return &fakeNetworkProtocol{}
diff --git a/tcpip/tcpip.go b/tcpip/tcpip.go
index 0ee090a..d014c75 100644
--- a/tcpip/tcpip.go
+++ b/tcpip/tcpip.go
@@ -41,6 +41,9 @@
// String implements fmt.Stringer.String.
func (e *Error) String() string {
+ if e == nil {
+ return "<nil>"
+ }
return e.msg
}
@@ -98,7 +101,7 @@
mask AddressMask
}
-func (a Address) mask(m AddressMask) Address {
+func (a Address) Mask(m AddressMask) Address {
out := []byte(a)
for i, _ := range a {
out[i] = a[i] & m[i]
@@ -171,6 +174,11 @@
return s.mask
}
+// String implements fmt.Stringer.String.
+func (s Subnet) String() string {
+ return fmt.Sprintf("{ address=%s, mask=%s }", s.address, Address(s.mask))
+}
+
// NICID is a number that uniquely identifies a NIC.
type NICID int32
@@ -614,7 +622,7 @@
}
mask := CIDRMask(int(ones), 8*len(addr))
- sn, err := NewSubnet(addr.mask(mask), mask)
+ sn, err := NewSubnet(addr.Mask(mask), mask)
return addr, sn, err
}