Merge remote-tracking branch 'upstream/master'
Change-Id: If00b462700e5c00cee1b79404117929d246b3fd4
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 169c356..befdf3c 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -1,31 +1,5 @@
-Want to contribute? Great! First, read this page.
+This repository is a synthetic mirror.
-### Before you contribute
-First, contact us! Tell us what you are working on, we will figure out how
-to set up a contribution system.
-
-Before we can use your code, you must sign the
-[Google Individual Contributor License Agreement]
-(https://cla.developers.google.com/about/google-individual)
-(CLA), which you can do online. The CLA is necessary mainly because you own the
-copyright to your changes, even after your contribution becomes part of our
-codebase, so we need your permission to use and distribute your code. We also
-need to be sure of various other things—for instance that you'll tell us if you
-know that your code infringes on other people's patents. You don't have to sign
-the CLA until after you've submitted your code for review and a member has
-approved it, but you must do it before we can put your code into our codebase.
-Before you start working on a larger contribution, you should get in touch with
-us first through the issue tracker with your idea so that we can help out and
-possibly guide you. Coordinating up front makes it much easier to avoid
-frustration later on.
-
-### Code reviews
-All submissions, including submissions by project members, require review.
-We have not worked out how to handle this yet, please contact the authors
-before sending any contributions.
-
-### The small print
-Contributions made by corporations are covered by a different agreement than
-the one above, the
-[Software Grant and Corporate Contributor License Agreement]
-(https://cla.developers.google.com/about/google-corporate).
+Contributions to netstack must be submitted via
+[gVisor](https://gvisor-review.googlesource.com), where there are complete
+[instructions](https://gvisor.googlesource.com/gvisor/+/master/CONTRIBUTING.md).
diff --git a/LICENSE b/LICENSE
index c8768f0..d645695 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,27 +1,202 @@
-Copyright (c) 2016 The Netstack Authors. All rights reserved.
-Redistribution and use in source and binary forms, with or without
-modification, are permitted provided that the following conditions are
-met:
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
- * Redistributions of source code must retain the above copyright
-notice, this list of conditions and the following disclaimer.
- * Redistributions in binary form must reproduce the above
-copyright notice, this list of conditions and the following disclaimer
-in the documentation and/or other materials provided with the
-distribution.
- * Neither the name of Google Inc. nor the names of its
-contributors may be used to endorse or promote products derived from
-this software without specific prior written permission.
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
-THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/README.md b/README.md
index 10ca511..bafb56d 100644
--- a/README.md
+++ b/README.md
@@ -26,13 +26,9 @@
## Contributions
-We would love to accept contributions, but we have not yet worked
-out how to handle them. Please contact us before sending any pull requests.
-
-Whatever we do decide on will require signing the Google Contributor License.
Please see [CONTRIBUTING.md](CONTRIBUTING.md) for more details.
### Disclaimer
-This is not an official Google product (experimental or otherwise), it
-is just code that happens to be owned by Google.
+This is not an official Google product (experimental or otherwise); it is just
+code that happens to be owned by Google.
diff --git a/dhcp/client.go b/dhcp/client.go
index bb9b09a..dc23ad7 100644
--- a/dhcp/client.go
+++ b/dhcp/client.go
@@ -1,19 +1,28 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package dhcp
import (
"bytes"
"context"
- "crypto/rand"
"fmt"
"sync"
"time"
+ "github.com/google/netstack/rand"
"github.com/google/netstack/tcpip"
- "github.com/google/netstack/tcpip/buffer"
"github.com/google/netstack/tcpip/network/ipv4"
"github.com/google/netstack/tcpip/stack"
"github.com/google/netstack/tcpip/transport/udp"
@@ -36,7 +45,7 @@
// NewClient creates a DHCP client.
//
-// TODO(crawshaw): add s.LinkAddr(nicid) to *stack.Stack.
+// TODO: add s.LinkAddr(nicid) to *stack.Stack.
func NewClient(s *stack.Stack, nicid tcpip.NICID, linkAddr tcpip.LinkAddress, acquiredFunc func(old, new tcpip.Address, cfg Config)) *Client {
return &Client{
stack: s,
@@ -124,13 +133,12 @@
if err != nil {
return Config{}, fmt.Errorf("dhcp: outbound endpoint: %v", err)
}
- err = ep.Bind(tcpip.FullAddress{
+ defer ep.Close()
+ if err := ep.Bind(tcpip.FullAddress{
Addr: "\x00\x00\x00\x00",
Port: ClientPort,
NIC: c.nicid,
- }, nil)
- defer ep.Close()
- if err != nil {
+ }, nil); err != nil {
return Config{}, fmt.Errorf("dhcp: connect failed: %v", err)
}
@@ -138,13 +146,12 @@
if err != nil {
return Config{}, fmt.Errorf("dhcp: inbound endpoint: %v", err)
}
- err = epin.Bind(tcpip.FullAddress{
+ defer epin.Close()
+ if err := epin.Bind(tcpip.FullAddress{
Addr: "\xff\xff\xff\xff",
Port: ClientPort,
NIC: c.nicid,
- }, nil)
- defer epin.Close()
- if err != nil {
+ }, nil); err != nil {
return Config{}, fmt.Errorf("dhcp: connect failed: %v", err)
}
@@ -166,9 +173,10 @@
}
var clientID []byte
if len(c.linkAddr) == 6 {
- clientID = make([]byte, 7)
- clientID[0] = 1 // htype: ARP Ethernet from RFC 1700
- copy(clientID[1:], c.linkAddr)
+ clientID = append(
+ []byte{1}, // RFC 1700: Hardware Type [Ethernet = 1]
+ c.linkAddr...,
+ )
discOpts = append(discOpts, option{optClientID, clientID})
}
h := make(header, headerBaseSize+discOpts.len()+1)
@@ -184,7 +192,10 @@
Port: ServerPort,
NIC: c.nicid,
}
- if _, err := ep.Write(buffer.View(h), serverAddr); err != nil {
+ wopts := tcpip.WriteOptions{
+ To: serverAddr,
+ }
+ if _, err := ep.Write(tcpip.SlicePayload(h), wopts); err != nil {
return Config{}, fmt.Errorf("dhcp discovery write: %v", err)
}
@@ -196,8 +207,8 @@
var opts options
for {
var addr tcpip.FullAddress
- v, e := epin.Read(&addr)
- if e == tcpip.ErrWouldBlock {
+ v, _, err := epin.Read(&addr)
+ if err == tcpip.ErrWouldBlock {
select {
case <-ch:
continue
@@ -207,10 +218,11 @@
}
h = header(v)
var valid bool
- var err error
- opts, valid, err = loadDHCPReply(h, dhcpOFFER, xid[:])
+ var e error
+ opts, valid, e = loadDHCPReply(h, dhcpOFFER, xid[:])
if !valid {
- if err != nil {
+ if e != nil {
+ // TODO: handle all the errors?
// TODO: report malformed server responses
}
continue
@@ -277,15 +289,15 @@
reqOpts = append(reqOpts, option{optClientID, clientID})
}
h.setOptions(reqOpts)
- if _, err := ep.Write([]byte(h), serverAddr); err != nil {
+ if _, err := ep.Write(tcpip.SlicePayload(h), wopts); err != nil {
return Config{}, fmt.Errorf("dhcp discovery write: %v", err)
}
// DHCPACK
for {
var addr tcpip.FullAddress
- v, e := epin.Read(&addr)
- if e == tcpip.ErrWouldBlock {
+ v, _, err := epin.Read(&addr)
+ if err == tcpip.ErrWouldBlock {
select {
case <-ch:
continue
@@ -295,10 +307,11 @@
}
h = header(v)
var valid bool
- var err error
- opts, valid, err = loadDHCPReply(h, dhcpACK, xid[:])
+ var e error
+ opts, valid, e = loadDHCPReply(h, dhcpACK, xid[:])
if !valid {
- if err != nil {
+ if e != nil {
+ // TODO: handle all the errors?
// TODO: report malformed server responses
}
if opts, valid, _ = loadDHCPReply(h, dhcpNAK, xid[:]); valid {
@@ -319,13 +332,13 @@
if !h.isValid() || h.op() != opReply || !bytes.Equal(h.xidbytes(), xid[:]) {
return nil, false, nil
}
- opts, e := h.options()
- if e != nil {
- return nil, false, fmt.Errorf("dhcp ack: %v", e)
+ opts, err = h.options()
+ if err != nil {
+ return nil, false, err
}
- msgtype, e := opts.dhcpMsgType()
- if e != nil {
- return nil, false, fmt.Errorf("dhcp ack: %v", e)
+ msgtype, err := opts.dhcpMsgType()
+ if err != nil {
+ return nil, false, err
}
if msgtype != typ {
return nil, false, nil
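The `clientID` change above builds the body of the DHCP client-identifier option (option 61 in RFC 2132) as a hardware-type byte followed by the link address. The standalone sketch below reproduces that encoding outside netstack. The function name `clientIdentifier` and the plain `string` MAC parameter are illustrative assumptions, not part of the package.

```go
package main

import "fmt"

// htypeEthernet is the ARP hardware type for Ethernet (RFC 1700).
const htypeEthernet = 1

// clientIdentifier (hypothetical helper) returns the body of a DHCP
// client-identifier option: one hardware-type byte followed by the hardware
// address. It returns nil for addresses that are not 6 bytes long, mirroring
// the len(c.linkAddr) == 6 check in the client.
func clientIdentifier(mac string) []byte {
	if len(mac) != 6 {
		return nil
	}
	// append accepts a string as the variadic argument for a []byte slice.
	return append([]byte{htypeEthernet}, mac...)
}

func main() {
	id := clientIdentifier("\x52\x11\x22\x33\x44\x52")
	fmt.Printf("% x\n", id) // 01 52 11 22 33 44 52
}
```

Appending to a one-element literal avoids the manual `make`/`copy` index bookkeeping that the removed lines needed.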
diff --git a/dhcp/dhcp.go b/dhcp/dhcp.go
index 6269166..b36149a 100644
--- a/dhcp/dhcp.go
+++ b/dhcp/dhcp.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
// Package dhcp implements a DHCP client and server as described in RFC 2131.
package dhcp
@@ -20,7 +30,7 @@
ServerAddress tcpip.Address // address of the server
SubnetMask tcpip.AddressMask // client address subnet mask
Gateway tcpip.Address // client default gateway
- DNS []tcpip.Address // client domain name servers
+ DNS []tcpip.Address // client DNS server addresses
LeaseLength time.Duration // length of the address lease
}
@@ -29,7 +39,8 @@
for _, opt := range opts {
b := opt.body
if !opt.code.lenValid(len(b)) {
- return fmt.Errorf("%s bad length: %d", opt.code, len(b))
+ // TODO: s/%v/%s/ when `go vet` is smarter.
+ return fmt.Errorf("%v: bad length: %d", opt.code, len(b))
}
switch opt.code {
case optLeaseTime:
@@ -82,9 +93,9 @@
}
const (
- // ServerPort well-known UDP port number for a DHCP server
+ // ServerPort is the well-known UDP port number for a DHCP server.
ServerPort = 67
- // ClientPort well-known UDP port number for a DHCP client
+ // ClientPort is the well-known UDP port number for a DHCP client.
ClientPort = 68
)
@@ -220,11 +231,12 @@
for _, opt := range opts {
if opt.code == optDHCPMsgType {
if len(opt.body) != 1 {
- return 0, fmt.Errorf("%s: wrong length: %d", optDHCPMsgType, len(opt.body))
+ // TODO: s/%v/%s/ when `go vet` is smarter.
+ return 0, fmt.Errorf("%v: bad length: %d", opt.code, len(opt.body))
}
v := opt.body[0]
if v <= 0 || v >= 8 {
- return 0, fmt.Errorf("%s: unknown value: %d", optDHCPMsgType, v)
+ return 0, fmt.Errorf("%v: unknown value: %d", opt.code, v)
}
return dhcpMsgType(v), nil
}
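The `dhcpMsgType` check accepts the message-type option (option 53) only when its body is exactly one byte holding a value from 1 through 7, and the two failure modes deserve distinct error messages. A minimal sketch of that validation, assuming a raw option body as input; `parseMsgType` is a hypothetical helper, not the package's API:

```go
package main

import "fmt"

// optDHCPMsgType is the DHCP message type option code (option 53).
const optDHCPMsgType = 53

// parseMsgType (hypothetical) validates a message type option body the same
// way dhcpMsgType does in the diff: the body must be exactly one byte, and
// that byte must lie in the range 1..7. A bad length and an unknown value
// produce different errors.
func parseMsgType(body []byte) (byte, error) {
	if len(body) != 1 {
		return 0, fmt.Errorf("option %d: bad length: %d", optDHCPMsgType, len(body))
	}
	v := body[0]
	if v == 0 || v >= 8 {
		return 0, fmt.Errorf("option %d: unknown value: %d", optDHCPMsgType, v)
	}
	return v, nil
}

func main() {
	for _, body := range [][]byte{{2}, {9}, {1, 2}} {
		v, err := parseMsgType(body)
		fmt.Println(v, err)
	}
}
```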
diff --git a/dhcp/dhcp_string.go b/dhcp/dhcp_string.go
index 68b23a4..504de4d 100644
--- a/dhcp/dhcp_string.go
+++ b/dhcp/dhcp_string.go
@@ -1,3 +1,17 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
package dhcp
import (
diff --git a/dhcp/dhcp_test.go b/dhcp/dhcp_test.go
index 2fa584d..082837a 100644
--- a/dhcp/dhcp_test.go
+++ b/dhcp/dhcp_test.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package dhcp
@@ -32,15 +42,11 @@
go func() {
for pkt := range linkEP.C {
- v := make(buffer.View, len(pkt.Header)+len(pkt.Payload))
- copy(v, pkt.Header)
- copy(v[len(pkt.Header):], pkt.Payload)
- vv := v.ToVectorisedView([1]buffer.View{})
- linkEP.Inject(pkt.Proto, &vv)
+ linkEP.Inject(pkt.Proto, buffer.NewVectorisedView(len(pkt.Header)+len(pkt.Payload), []buffer.View{pkt.Header, pkt.Payload}))
}
}()
- s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName})
+ s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName}, stack.Options{})
if err := s.CreateNIC(nicid, id); err != nil {
t.Fatal(err)
@@ -220,8 +226,8 @@
t.Error("failed to decode header")
}
- if op := h.op(); op != opReply {
- t.Errorf("bad opcode: %v expected: %v", op, opReply)
+ if got, want := h.op(), opReply; got != want {
+ t.Errorf("h.op()=%v, want=%v", got, want)
}
if _, err := h.options(); err != nil {
@@ -286,25 +292,23 @@
defer cancel()
c1, c2 := teeConn(newEPConn(serverCtx, wq, ep))
- _, tcpErr := NewServer(serverCtx, c1, []tcpip.Address{"\xc0\xa8\x03\x02"}, Config{
+ if _, err := NewServer(serverCtx, c1, []tcpip.Address{"\xc0\xa8\x03\x02"}, Config{
ServerAddress: "\xc0\xa8\x03\x01",
SubnetMask: "\xff\xff\xff\x00",
Gateway: "\xc0\xa8\x03\xF0",
DNS: []tcpip.Address{"\x08\x08\x08\x08"},
LeaseLength: 30 * time.Minute,
- })
- if tcpErr != nil {
- t.Fatal(tcpErr)
+ }); err != nil {
+ t.Fatal(err)
}
- _, tcpErr = NewServer(serverCtx, c2, []tcpip.Address{"\xc0\xa8\x04\x02"}, Config{
+ if _, err := NewServer(serverCtx, c2, []tcpip.Address{"\xc0\xa8\x04\x02"}, Config{
ServerAddress: "\xc0\xa8\x04\x01",
SubnetMask: "\xff\xff\xff\x00",
Gateway: "\xc0\xa8\x03\xF0",
DNS: []tcpip.Address{"\x08\x08\x08\x08"},
LeaseLength: 30 * time.Minute,
- })
- if tcpErr != nil {
- t.Fatal(tcpErr)
+ }); err != nil {
+ t.Fatal(err)
}
const clientLinkAddr0 = tcpip.LinkAddress("\x52\x11\x22\x33\x44\x52")
diff --git a/dhcp/server.go b/dhcp/server.go
index ad3b938..838a652 100644
--- a/dhcp/server.go
+++ b/dhcp/server.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package dhcp
@@ -68,7 +78,7 @@
func (c *epConn) Read() (buffer.View, tcpip.FullAddress, error) {
for {
var addr tcpip.FullAddress
- v, err := c.ep.Read(&addr)
+ v, _, err := c.ep.Read(&addr)
if err == tcpip.ErrWouldBlock {
select {
case <-c.inCh:
@@ -76,17 +86,19 @@
case <-c.ctx.Done():
return nil, tcpip.FullAddress{}, io.EOF
}
- } else if err != nil {
- return v, addr, fmt.Errorf("dhcp: %v", err)
- } else {
- return v, addr, nil
}
+ if err != nil {
+ return v, addr, fmt.Errorf("read: %v", err)
+ }
+ return v, addr, nil
}
}
func (c *epConn) Write(b []byte, addr *tcpip.FullAddress) error {
- _, err := c.ep.Write(b, addr)
- return fmt.Errorf("dhcp: %v", err)
+ if _, err := c.ep.Write(tcpip.SlicePayload(b), tcpip.WriteOptions{To: addr}); err != nil {
+ return fmt.Errorf("write: %v", err)
+ }
+ return nil
}
func newEPConnServer(ctx context.Context, stack *stack.Stack, addrs []tcpip.Address, cfg Config) (*Server, error) {
@@ -195,7 +207,7 @@
if err != nil {
continue
}
- // TODO(crawshaw): Handle DHCPRELEASE and DHCPDECLINE.
+ // TODO: Handle DHCPRELEASE and DHCPDECLINE.
msgtype, err := opts.dhcpMsgType()
if err != nil {
continue
@@ -222,7 +234,7 @@
case leaseNew:
if len(s.leases) < len(s.addrs) {
// Find an unused address.
- // TODO(crawshaw): avoid building this state on each request.
+ // TODO: avoid building this state on each request.
alloced := make(map[tcpip.Address]bool)
for _, lease := range s.leases {
alloced[lease.addr] = true
@@ -331,7 +343,7 @@
s.mu.Unlock()
if lease.state == leaseNew {
- // TODO(crawshaw): NACK or accept request
+ // TODO: NACK or accept request
return
}
diff --git a/gate/gate.go b/gate/gate.go
new file mode 100644
index 0000000..93808c9
--- /dev/null
+++ b/gate/gate.go
@@ -0,0 +1,134 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package gate provides Gate, a synchronization primitive that tracks usage.
+package gate
+
+import (
+ "sync/atomic"
+)
+
+const (
+ // gateClosed is the bit set in the gate's user count to indicate that
+ // it has been closed. It is the MSB of the 32-bit field; the other 31
+ // bits carry the actual count.
+ gateClosed = 0x80000000
+)
+
+// Gate is a synchronization primitive that allows concurrent goroutines to
+// "enter" it as long as it hasn't been closed yet. Once it's been closed,
+// goroutines cannot enter it anymore, but are allowed to leave, and the closer
+// will be informed when all goroutines have left.
+//
+// Many goroutines are allowed to enter the gate concurrently, but only one is
+// allowed to close it.
+//
+// This is similar to a r/w critical section, except that goroutines "entering"
+// never block: they either enter immediately or fail to enter. The closer will
+// block waiting for all goroutines currently inside the gate to leave.
+//
+// Gate is cheap to use: on x86, an uncontended Enter or Leave performs a
+// single interlocked operation.
+//
+// This is useful, for example, in cases when a goroutine is trying to clean up
+// an object for which multiple goroutines have pointers. In such a case, users
+// would be required to enter and leave the gate, and the cleaner would wait
+// until all users are gone (and no new ones are allowed) before proceeding.
+//
+// Users:
+//
+// if !g.Enter() {
+// // Gate is closed, we can't use the object.
+// return
+// }
+//
+// // Do something with object.
+// [...]
+//
+// g.Leave()
+//
+// Closer:
+//
+// // Prevent new users from using the object, and wait for the existing
+// // ones to complete.
+// g.Close()
+//
+// // Clean up the object.
+// [...]
+//
+type Gate struct {
+ userCount uint32
+ done chan struct{}
+}
+
+// Enter tries to enter the gate. It will succeed if it hasn't been closed yet,
+// in which case the caller must eventually call Leave().
+//
+// This function is thread-safe.
+func (g *Gate) Enter() bool {
+ if g == nil {
+ return false
+ }
+
+ for {
+ v := atomic.LoadUint32(&g.userCount)
+ if v&gateClosed != 0 {
+ return false
+ }
+
+ if atomic.CompareAndSwapUint32(&g.userCount, v, v+1) {
+ return true
+ }
+ }
+}
+
+// Leave leaves the gate. This must only be called after a successful call to
+// Enter(). If the gate has been closed and this is the last one inside the
+// gate, it will notify the closer that the gate is done.
+//
+// This function is thread-safe.
+func (g *Gate) Leave() {
+ for {
+ v := atomic.LoadUint32(&g.userCount)
+ if v&^gateClosed == 0 {
+ panic("leaving a gate with zero usage count")
+ }
+
+ if atomic.CompareAndSwapUint32(&g.userCount, v, v-1) {
+ if v == gateClosed+1 {
+ close(g.done)
+ }
+ return
+ }
+ }
+}
+
+// Close closes the gate for entering and waits until all goroutines currently
+// inside the gate have left before returning.
+//
+// Only one goroutine may call Close at a time; later calls return at once.
+func (g *Gate) Close() {
+ for {
+ v := atomic.LoadUint32(&g.userCount)
+ if v&^gateClosed != 0 && g.done == nil {
+ g.done = make(chan struct{})
+ }
+ if atomic.CompareAndSwapUint32(&g.userCount, v, v|gateClosed) {
+ if v&^gateClosed != 0 {
+ <-g.done
+ }
+ return
+ }
+ }
+}
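Gate keeps all of its state in one uint32, so Enter, Leave and Close can each update it with a single compare-and-swap: one bit records whether the gate is closed and the remaining bits count the goroutines inside. The value of `gateClosed` is defined outside this excerpt; the sketch below assumes it is the high bit (`1 << 31`), and the helpers `decode` and `lastLeaver` are hypothetical names that only restate the masks used above.

```go
package main

import "fmt"

// gateClosed is assumed to be the high bit of the state word; the real
// constant is defined elsewhere in the gate package.
const gateClosed = 1 << 31

// decode splits a state word into the closed flag (v&gateClosed, as in
// Enter) and the user count (v&^gateClosed, as in Leave and Close).
func decode(v uint32) (closed bool, users uint32) {
	return v&gateClosed != 0, v &^ gateClosed
}

// lastLeaver reports whether a Leave that observed v before decrementing
// is the final user of a closed gate, i.e. the one that must close g.done.
func lastLeaver(v uint32) bool {
	return v == gateClosed+1
}

func main() {
	v := uint32(3)                          // open gate, three users inside
	fmt.Println(decode(v))                  // false 3
	v |= gateClosed                         // Close sets the flag, keeps the count
	fmt.Println(decode(v))                  // true 3
	fmt.Println(lastLeaver(v))              // false: other users remain
	fmt.Println(lastLeaver(gateClosed + 1)) // true
}
```

Because the count and the flag live in the same word, a Leave that races with Close can never miss the transition: whichever CAS lands last observes both the closed bit and the final count.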
diff --git a/gate/gate_test.go b/gate/gate_test.go
new file mode 100644
index 0000000..b194415
--- /dev/null
+++ b/gate/gate_test.go
@@ -0,0 +1,189 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gate_test
+
+import (
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/google/netstack/gate"
+)
+
+func TestBasicEnter(t *testing.T) {
+ var g gate.Gate
+
+ if !g.Enter() {
+ t.Fatalf("Failed to enter when it should be allowed")
+ }
+
+ g.Leave()
+
+ g.Close()
+
+ if g.Enter() {
+ t.Fatalf("Allowed to enter when it should fail")
+ }
+}
+
+func enterFunc(t *testing.T, g *gate.Gate, enter, leave, reenter chan struct{}, done1, done2, done3 *sync.WaitGroup) {
+ // Wait until instructed to enter.
+ <-enter
+ if !g.Enter() {
+ t.Errorf("Failed to enter when it should be allowed")
+ }
+
+ done1.Done()
+
+ // Wait until instructed to leave.
+ <-leave
+ g.Leave()
+
+ done2.Done()
+
+ // Wait until instructed to reenter.
+ <-reenter
+ if g.Enter() {
+ t.Errorf("Allowed to enter when it should fail")
+ }
+ done3.Done()
+}
+
+func TestConcurrentEnter(t *testing.T) {
+ var g gate.Gate
+ var done1, done2, done3 sync.WaitGroup
+
+ // Create 1000 worker goroutines.
+ enter := make(chan struct{})
+ leave := make(chan struct{})
+ reenter := make(chan struct{})
+ done1.Add(1000)
+ done2.Add(1000)
+ done3.Add(1000)
+ for i := 0; i < 1000; i++ {
+ go enterFunc(t, &g, enter, leave, reenter, &done1, &done2, &done3)
+ }
+
+ // Tell them all to enter, then leave.
+ close(enter)
+ done1.Wait()
+
+ close(leave)
+ done2.Wait()
+
+ // Close the gate, then have the workers try to enter again.
+ g.Close()
+ close(reenter)
+ done3.Wait()
+}
+
+func closeFunc(g *gate.Gate, done chan struct{}) {
+ g.Close()
+ close(done)
+}
+
+func TestCloseWaits(t *testing.T) {
+ var g gate.Gate
+
+ // Enter 10 times.
+ for i := 0; i < 10; i++ {
+ if !g.Enter() {
+ t.Fatalf("Failed to enter when it should be allowed")
+ }
+ }
+
+ // Launch closer. Check that it doesn't complete.
+ done := make(chan struct{})
+ go closeFunc(&g, done)
+
+ for i := 0; i < 10; i++ {
+ select {
+ case <-done:
+ t.Fatalf("Close function completed too soon")
+ case <-time.After(100 * time.Millisecond):
+ }
+
+ g.Leave()
+ }
+
+ // Now the closer must complete.
+ <-done
+}
+
+func TestMultipleSerialCloses(t *testing.T) {
+ var g gate.Gate
+
+ // Enter 10 times.
+ for i := 0; i < 10; i++ {
+ if !g.Enter() {
+ t.Fatalf("Failed to enter when it should be allowed")
+ }
+ }
+
+ // Launch closer. Check that it doesn't complete.
+ done := make(chan struct{})
+ go closeFunc(&g, done)
+
+ for i := 0; i < 10; i++ {
+ select {
+ case <-done:
+ t.Fatalf("Close function completed too soon")
+ case <-time.After(100 * time.Millisecond):
+ }
+
+ g.Leave()
+ }
+
+ // Now the closer must complete.
+ <-done
+
+ // A second Close should not block.
+ done = make(chan struct{})
+ go closeFunc(&g, done)
+
+ select {
+ case <-done:
+ case <-time.After(2 * time.Second):
+ t.Fatalf("Second Close is blocking")
+ }
+}
+
+func worker(g *gate.Gate, done *sync.WaitGroup) {
+ for {
+ if !g.Enter() {
+ break
+ }
+ g.Leave()
+ }
+ done.Done()
+}
+
+func TestConcurrentAll(t *testing.T) {
+ var g gate.Gate
+ var done sync.WaitGroup
+
+ // Launch 1000 goroutines to concurrently enter/leave.
+ done.Add(1000)
+ for i := 0; i < 1000; i++ {
+ go worker(&g, &done)
+ }
+
+ // Wait for the goroutines to do some work, then close the gate.
+ time.Sleep(2 * time.Second)
+ g.Close()
+
+ // Wait for all of them to complete.
+ done.Wait()
+}
diff --git a/ilist/list.go b/ilist/list.go
index 739575f..4ae02ee 100644
--- a/ilist/list.go
+++ b/ilist/list.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
// Package ilist provides the implementation of intrusive linked lists.
package ilist
@@ -11,12 +21,34 @@
// N.B. When substituted in a template instantiation, Linker doesn't need to
// be an interface, and in most cases won't be.
type Linker interface {
- Next() Linker
- Prev() Linker
- SetNext(Linker)
- SetPrev(Linker)
+ Next() Element
+ Prev() Element
+ SetNext(Element)
+ SetPrev(Element)
}
+// Element is the item that is used at the API level.
+//
+// N.B. Like Linker, this is unlikely to be an interface in most cases.
+type Element interface {
+ Linker
+}
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type ElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (ElementMapper) linkerFor(elem Element) Linker { return elem }
+
// List is an intrusive list. Entries can be added to or removed from the list
// in O(1) time and with no additional memory allocations.
//
@@ -26,9 +58,11 @@
// for e := l.Front(); e != nil; e = e.Next() {
// // do something with e.
// }
+//
+// +stateify savable
type List struct {
- head Linker
- tail Linker
+ head Element
+ tail Element
}
// Reset resets list l to the empty state.
@@ -43,22 +77,22 @@
}
// Front returns the first element of list l or nil.
-func (l *List) Front() Linker {
+func (l *List) Front() Element {
return l.head
}
// Back returns the last element of list l or nil.
-func (l *List) Back() Linker {
+func (l *List) Back() Element {
return l.tail
}
// PushFront inserts the element e at the front of list l.
-func (l *List) PushFront(e Linker) {
- e.SetNext(l.head)
- e.SetPrev(nil)
+func (l *List) PushFront(e Element) {
+ ElementMapper{}.linkerFor(e).SetNext(l.head)
+ ElementMapper{}.linkerFor(e).SetPrev(nil)
if l.head != nil {
- l.head.SetPrev(e)
+ ElementMapper{}.linkerFor(l.head).SetPrev(e)
} else {
l.tail = e
}
@@ -67,12 +101,12 @@
}
// PushBack inserts the element e at the back of list l.
-func (l *List) PushBack(e Linker) {
- e.SetNext(nil)
- e.SetPrev(l.tail)
+func (l *List) PushBack(e Element) {
+ ElementMapper{}.linkerFor(e).SetNext(nil)
+ ElementMapper{}.linkerFor(e).SetPrev(l.tail)
if l.tail != nil {
- l.tail.SetNext(e)
+ ElementMapper{}.linkerFor(l.tail).SetNext(e)
} else {
l.head = e
}
@@ -86,8 +120,8 @@
l.head = m.head
l.tail = m.tail
} else if m.head != nil {
- l.tail.SetNext(m.head)
- m.head.SetPrev(l.tail)
+ ElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ ElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
l.tail = m.tail
}
@@ -97,46 +131,46 @@
}
// InsertAfter inserts e after b.
-func (l *List) InsertAfter(b, e Linker) {
- a := b.Next()
- e.SetNext(a)
- e.SetPrev(b)
- b.SetNext(e)
+func (l *List) InsertAfter(b, e Element) {
+ a := ElementMapper{}.linkerFor(b).Next()
+ ElementMapper{}.linkerFor(e).SetNext(a)
+ ElementMapper{}.linkerFor(e).SetPrev(b)
+ ElementMapper{}.linkerFor(b).SetNext(e)
if a != nil {
- a.SetPrev(e)
+ ElementMapper{}.linkerFor(a).SetPrev(e)
} else {
l.tail = e
}
}
// InsertBefore inserts e before a.
-func (l *List) InsertBefore(a, e Linker) {
- b := a.Prev()
- e.SetNext(a)
- e.SetPrev(b)
- a.SetPrev(e)
+func (l *List) InsertBefore(a, e Element) {
+ b := ElementMapper{}.linkerFor(a).Prev()
+ ElementMapper{}.linkerFor(e).SetNext(a)
+ ElementMapper{}.linkerFor(e).SetPrev(b)
+ ElementMapper{}.linkerFor(a).SetPrev(e)
if b != nil {
- b.SetNext(e)
+ ElementMapper{}.linkerFor(b).SetNext(e)
} else {
l.head = e
}
}
// Remove removes e from l.
-func (l *List) Remove(e Linker) {
- prev := e.Prev()
- next := e.Next()
+func (l *List) Remove(e Element) {
+ prev := ElementMapper{}.linkerFor(e).Prev()
+ next := ElementMapper{}.linkerFor(e).Next()
if prev != nil {
- prev.SetNext(next)
+ ElementMapper{}.linkerFor(prev).SetNext(next)
} else {
l.head = next
}
if next != nil {
- next.SetPrev(prev)
+ ElementMapper{}.linkerFor(next).SetPrev(prev)
} else {
l.tail = prev
}
@@ -145,27 +179,29 @@
// Entry is a default implementation of Linker. Users can add anonymous fields
// of this type to their structs to make them automatically implement the
// methods needed by List.
+//
+// +stateify savable
type Entry struct {
- next Linker
- prev Linker
+ next Element
+ prev Element
}
// Next returns the entry that follows e in the list.
-func (e *Entry) Next() Linker {
+func (e *Entry) Next() Element {
return e.next
}
// Prev returns the entry that precedes e in the list.
-func (e *Entry) Prev() Linker {
+func (e *Entry) Prev() Element {
return e.prev
}
// SetNext assigns 'elem' as the element that follows e in the list.
-func (e *Entry) SetNext(entry Linker) {
- e.next = entry
+func (e *Entry) SetNext(elem Element) {
+ e.next = elem
}
// SetPrev assigns 'elem' as the element that precedes e in the list.
-func (e *Entry) SetPrev(entry Linker) {
- e.prev = entry
+func (e *Entry) SetPrev(elem Element) {
+ e.prev = elem
}
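To make the Linker/Element split concrete, here is a minimal self-contained sketch of how a caller embeds Entry so that its own type becomes a list element, with the ElementMapper indirection left at its default identity mapping. It is a simplified copy (only PushBack and Remove), not the ilist package itself, and `request` is a hypothetical user type.

```go
package main

import "fmt"

// Linker is the set of link accessors the list needs.
type Linker interface {
	Next() Element
	Prev() Element
	SetNext(Element)
	SetPrev(Element)
}

// Element is what callers store; by default it is just a Linker.
type Element interface {
	Linker
}

// ElementMapper is the default identity mapping from Element to Linker.
type ElementMapper struct{}

func (ElementMapper) linkerFor(elem Element) Linker { return elem }

// Entry holds the links; embedding it gives a struct the Linker methods.
type Entry struct {
	next, prev Element
}

func (e *Entry) Next() Element        { return e.next }
func (e *Entry) Prev() Element        { return e.prev }
func (e *Entry) SetNext(elem Element) { e.next = elem }
func (e *Entry) SetPrev(elem Element) { e.prev = elem }

// List is a doubly linked list that allocates nothing per element.
type List struct {
	head, tail Element
}

func (l *List) Front() Element { return l.head }

func (l *List) PushBack(e Element) {
	ElementMapper{}.linkerFor(e).SetNext(nil)
	ElementMapper{}.linkerFor(e).SetPrev(l.tail)
	if l.tail != nil {
		ElementMapper{}.linkerFor(l.tail).SetNext(e)
	} else {
		l.head = e
	}
	l.tail = e
}

func (l *List) Remove(e Element) {
	prev := ElementMapper{}.linkerFor(e).Prev()
	next := ElementMapper{}.linkerFor(e).Next()
	if prev != nil {
		ElementMapper{}.linkerFor(prev).SetNext(next)
	} else {
		l.head = next
	}
	if next != nil {
		ElementMapper{}.linkerFor(next).SetPrev(prev)
	} else {
		l.tail = prev
	}
}

// request is a hypothetical user type; embedding Entry makes *request an Element.
type request struct {
	Entry
	id int
}

func main() {
	var l List
	a, b, c := &request{id: 1}, &request{id: 2}, &request{id: 3}
	l.PushBack(a)
	l.PushBack(b)
	l.PushBack(c)
	l.Remove(b)
	for e := l.Front(); e != nil; e = e.Next() {
		fmt.Println(e.(*request).id) // prints 1, then 3
	}
}
```

Since the links live inside the user's struct, insertion and removal are O(1) and allocation-free; a custom ElementMapper would only be needed if the stored type and the type carrying the links differ.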
diff --git a/rand/rand.go b/rand/rand.go
new file mode 100644
index 0000000..e81f0f5
--- /dev/null
+++ b/rand/rand.go
@@ -0,0 +1,29 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build !linux
+
+// Package rand implements a cryptographically secure pseudorandom number
+// generator.
+package rand
+
+import "crypto/rand"
+
+// Reader is the default reader, crypto/rand.Reader.
+var Reader = rand.Reader
+
+// Read fills b with random bytes using crypto/rand.Read.
+func Read(b []byte) (int, error) {
+ return rand.Read(b)
+}
diff --git a/rand/rand_linux.go b/rand/rand_linux.go
new file mode 100644
index 0000000..37ac076
--- /dev/null
+++ b/rand/rand_linux.go
@@ -0,0 +1,39 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package rand implements a cryptographically secure pseudorandom number
+// generator.
+package rand
+
+import (
+ "io"
+
+ "golang.org/x/sys/unix"
+)
+
+// reader is an io.Reader backed by the getrandom(2) system call.
+type reader struct{}
+
+// Read implements io.Reader.Read.
+func (reader) Read(p []byte) (int, error) {
+ return unix.Getrandom(p, 0)
+}
+
+// Reader is the default reader, backed by getrandom(2).
+var Reader io.Reader = reader{}
+
+// Read fills b completely from the default reader, retrying short reads.
+func Read(b []byte) (int, error) {
+ return io.ReadFull(Reader, b)
+}
diff --git a/sleep/commit_amd64.s b/sleep/commit_amd64.s
index a5b620b..d525e5b 100644
--- a/sleep/commit_amd64.s
+++ b/sleep/commit_amd64.s
@@ -1,3 +1,17 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
#include "textflag.h"
#define preparingG 1
diff --git a/sleep/commit_arm64.s b/sleep/commit_arm64.s
index 9d351d0..8aca31b 100644
--- a/sleep/commit_arm64.s
+++ b/sleep/commit_arm64.s
@@ -1,5 +1,15 @@
-// Copyright 2017 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
// Empty assembly file so empty func definitions work.
diff --git a/sleep/commit_asm.go b/sleep/commit_asm.go
index b7589df..39a55df 100644
--- a/sleep/commit_asm.go
+++ b/sleep/commit_asm.go
@@ -1,6 +1,16 @@
-// Copyright 2017 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
// +build amd64
diff --git a/sleep/commit_noasm.go b/sleep/commit_noasm.go
index 22c734e..584866c 100644
--- a/sleep/commit_noasm.go
+++ b/sleep/commit_noasm.go
@@ -1,6 +1,16 @@
-// Copyright 2017 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
// +build !race
// +build !amd64
diff --git a/sleep/sleep_test.go b/sleep/sleep_test.go
index eba1801..bc17383 100644
--- a/sleep/sleep_test.go
+++ b/sleep/sleep_test.go
@@ -1,3 +1,17 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
package sleep
import (
diff --git a/sleep/sleep_unsafe.go b/sleep/sleep_unsafe.go
index 5ecb7a3..b12cce6 100644
--- a/sleep/sleep_unsafe.go
+++ b/sleep/sleep_unsafe.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
// Package sleep allows goroutines to efficiently sleep on multiple sources of
// notifications (wakers). It offers O(1) complexity, which is different from
diff --git a/tcpip/adapters/gonet/gonet.go b/tcpip/adapters/gonet/gonet.go
index 2094bda..7c40b04 100644
--- a/tcpip/adapters/gonet/gonet.go
+++ b/tcpip/adapters/gonet/gonet.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
// Package gonet provides a Go net package compatible wrapper for a tcpip stack.
package gonet
@@ -136,24 +146,35 @@
*cancelCh = make(chan struct{})
}
+ // Create a new channel if we already closed it because an earlier call set
+ // an already-expired deadline. This cannot race with the timer, which was
+ // already handled above.
+ select {
+ case <-*cancelCh:
+ *cancelCh = make(chan struct{})
+ default:
+ }
+
// "A zero value for t means I/O operations will not time out."
// - net.Conn.SetDeadline
- if !t.IsZero() {
- timeout := t.Sub(time.Now())
- if timeout <= 0 {
- close(*cancelCh)
- return
- }
-
- // Timer.Stop returns whether or not the AfterFunc has started, but
- // does not indicate whether or not it has completed. Make a copy of
- // the cancel channel to prevent this code from racing with the next
- // call of setDeadline replacing *cancelCh.
- ch := *cancelCh
- *timer = time.AfterFunc(timeout, func() {
- close(ch)
- })
+ if t.IsZero() {
+ return
}
+
+ timeout := t.Sub(time.Now())
+ if timeout <= 0 {
+ close(*cancelCh)
+ return
+ }
+
+ // Timer.Stop returns whether or not the AfterFunc has started, but
+ // does not indicate whether or not it has completed. Make a copy of
+ // the cancel channel to prevent this code from racing with the next
+ // call of setDeadline replacing *cancelCh.
+ ch := *cancelCh
+ *timer = time.AfterFunc(timeout, func() {
+ close(ch)
+ })
}
// SetReadDeadline implements net.Conn.SetReadDeadline and
@@ -195,7 +216,7 @@
//
// Lock ordering:
// If both readMu and deadlineTimer.mu are to be used in a single
- // request, readMu must be aquired before deadlineTimer.mu.
+ // request, readMu must be acquired before deadlineTimer.mu.
readMu sync.Mutex
// read contains bytes that have been read from the endpoint,
@@ -257,7 +278,7 @@
// commonRead implements the common logic between net.Conn.Read and
// net.PacketConn.ReadFrom.
func commonRead(ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan struct{}, addr *tcpip.FullAddress, errorer opErrorer) ([]byte, error) {
- read, err := ep.Read(addr)
+ read, _, err := ep.Read(addr)
if err == tcpip.ErrWouldBlock {
// Create wait queue entry that notifies a channel.
@@ -265,7 +286,7 @@
wq.EventRegister(&waitEntry, waiter.EventIn)
defer wq.EventUnregister(&waitEntry)
for {
- read, err = ep.Read(addr)
+ read, _, err = ep.Read(addr)
if err != tcpip.ErrWouldBlock {
break
}
@@ -329,8 +350,7 @@
default:
}
- v := buffer.NewView(len(b))
- copy(v, b)
+ v := buffer.NewViewFromBytes(b)
// We must handle two soft failure conditions simultaneously:
// 1. Write may write nothing and return tcpip.ErrWouldBlock.
@@ -366,14 +386,14 @@
// the notification.
select {
case <-deadline:
- return 0, c.newOpError("write", &timeoutError{})
+ return nbytes, c.newOpError("write", &timeoutError{})
case <-notifyCh:
}
}
}
var n uintptr
- n, err = c.ep.Write(v, nil)
+ n, err = c.ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{})
nbytes += int(n)
v.TrimFront(int(n))
}
@@ -382,7 +402,7 @@
return nbytes, nil
}
- return 0, c.newOpError("write", errors.New(err.String()))
+ return nbytes, c.newOpError("write", errors.New(err.String()))
}
// Close implements net.Conn.Close.
@@ -444,20 +464,19 @@
defer wq.EventUnregister(&waitEntry)
err = ep.Connect(addr)
- for err != nil {
- if err != tcpip.ErrConnectStarted {
- ep.Close()
- return nil, &net.OpError{
- Op: "connect",
- Net: "tcp",
- Addr: fullToTCPAddr(addr),
- Err: errors.New(err.String()),
- }
- }
-
+ if err == tcpip.ErrConnectStarted {
<-notifyCh
err = ep.GetSockOpt(tcpip.ErrorOption{})
}
+ if err != nil {
+ ep.Close()
+ return nil, &net.OpError{
+ Op: "connect",
+ Net: "tcp",
+ Addr: fullToTCPAddr(addr),
+ Err: errors.New(err.String()),
+ }
+ }
return NewConn(&wq, ep), nil
}
@@ -551,7 +570,8 @@
v := buffer.NewView(len(b))
copy(v, b)
- n, err := c.ep.Write(v, &fullAddr)
+ wopts := tcpip.WriteOptions{To: &fullAddr}
+ n, err := c.ep.Write(tcpip.SlicePayload(v), wopts)
if err == tcpip.ErrWouldBlock {
// Create wait queue entry that notifies a channel.
@@ -559,13 +579,13 @@
c.wq.EventRegister(&waitEntry, waiter.EventOut)
defer c.wq.EventUnregister(&waitEntry)
for {
- n, err = c.ep.Write(v, &fullAddr)
+ n, err = c.ep.Write(tcpip.SlicePayload(v), wopts)
if err != tcpip.ErrWouldBlock {
break
}
select {
case <-deadline:
- return 0, c.newRemoteOpError("write", addr, &timeoutError{})
+ return int(n), c.newRemoteOpError("write", addr, &timeoutError{})
case <-notifyCh:
}
}
@@ -575,7 +595,7 @@
return int(n), nil
}
- return 0, c.newRemoteOpError("write", addr, errors.New(err.String()))
+ return int(n), c.newRemoteOpError("write", addr, errors.New(err.String()))
}
// Close implements net.PacketConn.Close.
diff --git a/tcpip/adapters/gonet/gonet_test.go b/tcpip/adapters/gonet/gonet_test.go
index 87a519a..438d3d7 100644
--- a/tcpip/adapters/gonet/gonet_test.go
+++ b/tcpip/adapters/gonet/gonet_test.go
@@ -1,10 +1,21 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package gonet
import (
+ "fmt"
"net"
"reflect"
"strings"
@@ -19,6 +30,7 @@
"github.com/google/netstack/tcpip/transport/tcp"
"github.com/google/netstack/tcpip/transport/udp"
"github.com/google/netstack/waiter"
+ "golang.org/x/net/nettest"
)
const (
@@ -45,7 +57,7 @@
func newLoopbackStack() (*stack.Stack, *tcpip.Error) {
// Create the stack and add a NIC.
- s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{tcp.ProtocolName, udp.ProtocolName})
+ s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{tcp.ProtocolName, udp.ProtocolName}, stack.Options{})
if err := s.CreateNIC(NICID, loopback.New()); err != nil {
return nil, err
@@ -132,20 +144,16 @@
// Give c.Read() a chance to block before closing the connection.
time.AfterFunc(time.Millisecond*50, func() {
- t.Log("c.Close()")
c.Close()
- t.Log("c.Close() ok")
})
buf := make([]byte, 256)
- t.Log("c.Read()")
n, err := c.Read(buf)
got, ok := err.(*net.OpError)
want := tcpip.ErrConnectionAborted
if n != 0 || !ok || got.Err.Error() != want.String() {
t.Errorf("c.Read() = (%d, %v), want (0, OpError(%v))", n, err, want)
}
- t.Logf("c.Read() = %d, %v", n, err)
}()
sender, err := connect(s, addr)
if err != nil {
@@ -188,20 +196,16 @@
// Give c.Read() a chance to block before closing the connection.
time.AfterFunc(time.Millisecond*50, func() {
- t.Log("c.Close()")
c.Close()
- t.Log("c.Close() ok")
})
buf := make([]byte, 256)
- t.Log("c.Read()")
n, e := c.Read(buf)
got, ok := e.(*net.OpError)
want := tcpip.ErrConnectionAborted
if n != 0 || !ok || got.Err.Error() != want.String() {
t.Errorf("c.Read() = (%d, %v), want (0, OpError(%v))", n, e, want)
}
- t.Logf("c.Read() = %d, %v", n, e)
})
s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket)
@@ -244,20 +248,16 @@
c.SetDeadline(time.Now().Add(time.Minute))
// Give c.Read() a chance to block before closing the connection.
time.AfterFunc(time.Millisecond*50, func() {
- t.Log("c.SetDeadline()")
c.SetDeadline(time.Now().Add(time.Millisecond * 10))
- t.Log("c.SetDeadline() ok")
})
buf := make([]byte, 256)
- t.Log("c.Read()")
n, err := c.Read(buf)
got, ok := err.(*net.OpError)
want := "i/o timeout"
if n != 0 || !ok || got.Err == nil || got.Err.Error() != want {
t.Errorf("c.Read() = (%d, %v), want (0, OpError(%s))", n, err, want)
}
- t.Logf("c.Read() = %d, %v", n, err)
}()
sender, err := connect(s, addr)
if err != nil {
@@ -324,10 +324,10 @@
}
}
-func TestTCPConnTransfer(t *testing.T) {
+func makePipe() (c1, c2 net.Conn, stop func(), err error) {
s, e := newLoopbackStack()
if e != nil {
- t.Fatalf("newLoopbackStack() = %v", e)
+ return nil, nil, nil, fmt.Errorf("newLoopbackStack() = %v", e)
}
ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
@@ -336,29 +336,44 @@
l, err := NewListener(s, addr, ipv4.ProtocolNumber)
if err != nil {
- t.Fatal("NewListener:", err)
+ return nil, nil, nil, fmt.Errorf("NewListener: %v", err)
}
- defer func() {
- if err := l.Close(); err != nil {
- t.Error("l.Close():", err)
- }
- }()
- c1, err := DialTCP(s, addr, ipv4.ProtocolNumber)
+ c1, err = DialTCP(s, addr, ipv4.ProtocolNumber)
if err != nil {
- t.Fatal("DialTCP:", err)
+ l.Close()
+ return nil, nil, nil, fmt.Errorf("DialTCP: %v", err)
+ }
+
+ c2, err = l.Accept()
+ if err != nil {
+ l.Close()
+ c1.Close()
+ return nil, nil, nil, fmt.Errorf("l.Accept: %v", err)
+ }
+
+ stop = func() {
+ c1.Close()
+ c2.Close()
+ }
+
+ if err := l.Close(); err != nil {
+ stop()
+ return nil, nil, nil, fmt.Errorf("l.Close(): %v", err)
+ }
+
+ return c1, c2, stop, nil
+}
+
+func TestTCPConnTransfer(t *testing.T) {
+ c1, c2, _, err := makePipe()
+ if err != nil {
+ t.Fatal(err)
}
defer func() {
if err := c1.Close(); err != nil {
t.Error("c1.Close():", err)
}
- }()
-
- c2, err := l.Accept()
- if err != nil {
- t.Fatal("l.Accept:", err)
- }
- defer func() {
if err := c2.Close(); err != nil {
t.Error("c2.Close():", err)
}
@@ -413,3 +428,7 @@
t.Errorf("Got DialTCP() = %v, want = %v", err, tcpip.ErrNoRoute)
}
}
+
+func TestNetTest(t *testing.T) {
+ nettest.TestConn(t, makePipe)
+}
diff --git a/tcpip/buffer/prependable.go b/tcpip/buffer/prependable.go
index fd84585..c5dd281 100644
--- a/tcpip/buffer/prependable.go
+++ b/tcpip/buffer/prependable.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package buffer
@@ -22,6 +32,26 @@
return Prependable{buf: NewView(size), usedIdx: size}
}
+// NewPrependableFromView creates an entirely-used Prependable from a View.
+//
+// NewPrependableFromView takes ownership of v. Because the entire prependable
+// is already used, any later call to Prepend with a positive size finds that
+// size > p.usedIdx and returns nil.
+func NewPrependableFromView(v View) Prependable {
+ return Prependable{buf: v, usedIdx: 0}
+}
+
+// View returns a View of the backing buffer that contains all prepended
+// data so far.
+func (p Prependable) View() View {
+ return p.buf[p.usedIdx:]
+}
+
+// UsedLength returns the number of bytes used so far.
+func (p Prependable) UsedLength() int {
+ return len(p.buf) - p.usedIdx
+}
+
// Prepend reserves the requested space in front of the buffer, returning a
// slice that represents the reserved space.
func (p *Prependable) Prepend(size int) []byte {
@@ -30,24 +60,5 @@
}
p.usedIdx -= size
- return p.buf[p.usedIdx:][:size:size]
-}
-
-// View returns a View of the backing buffer that contains all prepended
-// data so far.
-func (p *Prependable) View() View {
- v := p.buf
- v.TrimFront(p.usedIdx)
- return v
-}
-
-// UsedBytes returns a slice of the backing buffer that contains all prepended
-// data so far.
-func (p *Prependable) UsedBytes() []byte {
- return p.buf[p.usedIdx:]
-}
-
-// UsedLength returns the number of bytes used so far.
-func (p *Prependable) UsedLength() int {
- return len(p.buf) - p.usedIdx
+ return p.View()[:size:size]
}
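Prependable fills its buffer from the back: usedIdx starts at the end and moves toward zero as headers are prepended, so the outermost header ends up first in View. The sketch below reproduces that arithmetic with a plain []byte instead of buffer.View; the `prependable` type is a simplified copy for illustration, and the nil-on-overflow check follows the behaviour described in the NewPrependableFromView comment.

```go
package main

import "fmt"

// prependable mirrors buffer.Prependable: used data lives in buf[usedIdx:].
type prependable struct {
	buf     []byte
	usedIdx int
}

func newPrependable(size int) prependable {
	return prependable{buf: make([]byte, size), usedIdx: size}
}

// Prepend reserves size bytes in front of the used region, or returns nil
// if there is not enough room left.
func (p *prependable) Prepend(size int) []byte {
	if size > p.usedIdx {
		return nil
	}
	p.usedIdx -= size
	return p.buf[p.usedIdx:][:size:size]
}

// View returns the bytes prepended so far.
func (p prependable) View() []byte { return p.buf[p.usedIdx:] }

// UsedLength returns how many bytes have been prepended.
func (p prependable) UsedLength() int { return len(p.buf) - p.usedIdx }

func main() {
	p := newPrependable(8)
	copy(p.Prepend(3), "TCP") // innermost header first
	copy(p.Prepend(2), "IP")  // then the outer header

	fmt.Println(string(p.View()), p.UsedLength()) // IPTCP 5
	fmt.Println(p.Prepend(4) == nil)              // true: only 3 bytes left
}
```

The three-index slice `[:size:size]` caps the returned slice's capacity, so a caller that appends to it cannot overwrite the headers already prepended after it.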
diff --git a/tcpip/buffer/view.go b/tcpip/buffer/view.go
index 46440cc..cea4e36 100644
--- a/tcpip/buffer/view.go
+++ b/tcpip/buffer/view.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
// Package buffer provides the implementation of a buffer view.
package buffer
@@ -14,6 +24,11 @@
return make(View, size)
}
+// NewViewFromBytes allocates a new buffer and copies in the given bytes.
+func NewViewFromBytes(b []byte) View {
+ return append(View(nil), b...)
+}
+
// TrimFront removes the first "count" bytes from the visible section of the
// buffer.
func (v *View) TrimFront(count int) {
@@ -30,15 +45,15 @@
*v = (*v)[:length:length]
}
-// ToVectorisedView transforms a View in a VectorisedView from an
-// already-allocated slice of View.
-func (v *View) ToVectorisedView(views [1]View) VectorisedView {
- views[0] = *v
- return NewVectorisedView(len(*v), views[:])
+// ToVectorisedView returns a VectorisedView containing the receiver.
+func (v View) ToVectorisedView() VectorisedView {
+ return NewVectorisedView(len(v), []View{v})
}
// VectorisedView is a vectorised version of View using non contigous memory.
// It supports all the convenience methods supported by View.
+//
+// +stateify savable
type VectorisedView struct {
views []View
size int
@@ -90,24 +105,14 @@
// Clone returns a clone of this VectorisedView.
// If the buffer argument is large enough to contain all the Views of this VectorisedView,
// the method will avoid allocations and use the buffer to store the Views of the clone.
-func (vv *VectorisedView) Clone(buffer []View) VectorisedView {
- var views []View
- if len(buffer) >= len(vv.views) {
- views = buffer[:len(vv.views)]
- } else {
- views = make([]View, len(vv.views))
- }
- for i, v := range vv.views {
- views[i] = v
- }
- return VectorisedView{views: views, size: vv.size}
+func (vv VectorisedView) Clone(buffer []View) VectorisedView {
+ return VectorisedView{views: append(buffer[:0], vv.views...), size: vv.size}
}
// First returns the first view of the vectorised view.
-// It panics if the vectorised view is empty.
-func (vv *VectorisedView) First() View {
+func (vv VectorisedView) First() View {
if len(vv.views) == 0 {
- panic("vview is empty")
+ return nil
}
return vv.views[0]
}
@@ -121,56 +126,21 @@
vv.views = vv.views[1:]
}
-// SetSize unsafely sets the size of the VectorisedView.
-func (vv *VectorisedView) SetSize(size int) {
- vv.size = size
-}
-
-// SetViews unsafely sets the views of the VectorisedView.
-func (vv *VectorisedView) SetViews(views []View) {
- vv.views = views
-}
-
// Size returns the size in bytes of the entire content stored in the vectorised view.
-func (vv *VectorisedView) Size() int {
+func (vv VectorisedView) Size() int {
return vv.size
}
-// ToView returns the a single view containing the content of the vectorised view.
-func (vv *VectorisedView) ToView() View {
- v := make([]byte, vv.size)
- u := v
- for i := range vv.views {
- n := copy(u, vv.views[i])
- u = u[n:]
+// ToView returns a single view containing the content of the vectorised view.
+func (vv VectorisedView) ToView() View {
+ u := make([]byte, 0, vv.size)
+ for _, v := range vv.views {
+ u = append(u, v...)
}
- return v
+ return u
}
// Views returns the slice containing the all views.
-func (vv *VectorisedView) Views() []View {
+func (vv VectorisedView) Views() []View {
return vv.views
}
-
-// ByteSlice returns a slice containing the all views as a []byte.
-func (vv *VectorisedView) ByteSlice() [][]byte {
- s := make([][]byte, len(vv.views))
- for i := range vv.views {
- s[i] = []byte(vv.views[i])
- }
- return s
-}
-
-// copy returns a deep-copy of the vectorised view.
-// It is an expensive method that should be used only in tests.
-func (vv *VectorisedView) copy() *VectorisedView {
- uu := &VectorisedView{
- views: make([]View, len(vv.views)),
- size: vv.size,
- }
- for i, v := range vv.views {
- uu.views[i] = make(View, len(v))
- copy(uu.views[i], v)
- }
- return uu
-}
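With the receivers switched from *VectorisedView to VectorisedView, Clone is now a shallow copy built with append(buffer[:0], ...). The sketch below uses a hypothetical, trimmed-down mirror of the type (not the package itself) to illustrate two consequences: Clone reuses buffer whenever its capacity suffices, and the clone shares the underlying bytes with the original, because only the slice of views is copied.

```go
package main

// A simplified, hypothetical mirror of tcpip/buffer's VectorisedView, kept
// only to show the value-receiver semantics of Clone, First and ToView.

type View []byte

type VectorisedView struct {
	views []View
	size  int
}

func NewVectorisedView(size int, views []View) VectorisedView {
	return VectorisedView{views: views, size: size}
}

// Clone copies the slice of views (not the bytes they point to) into buffer,
// reusing its backing array when its capacity is large enough.
func (vv VectorisedView) Clone(buffer []View) VectorisedView {
	return VectorisedView{views: append(buffer[:0], vv.views...), size: vv.size}
}

// First returns the first view, or nil when there are none.
func (vv VectorisedView) First() View {
	if len(vv.views) == 0 {
		return nil
	}
	return vv.views[0]
}

// ToView flattens all views into one freshly allocated View.
func (vv VectorisedView) ToView() View {
	u := make([]byte, 0, vv.size)
	for _, v := range vv.views {
		u = append(u, v...)
	}
	return u
}
```

The new Clone decides reuse by cap(buffer), whereas the removed version checked len(buffer). Callers that need independent bytes still need a deep copy, such as the test-only copy helper added in view_test.go.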
diff --git a/tcpip/buffer/view_test.go b/tcpip/buffer/view_test.go
index ff8535b..02c2645 100644
--- a/tcpip/buffer/view_test.go
+++ b/tcpip/buffer/view_test.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
// Package buffer_test contains tests for the VectorisedView type.
package buffer
@@ -10,22 +20,33 @@
"testing"
)
+// copy returns a deep-copy of the vectorised view.
+func (vv VectorisedView) copy() VectorisedView {
+ uu := VectorisedView{
+ views: make([]View, 0, len(vv.views)),
+ size: vv.size,
+ }
+ for _, v := range vv.views {
+ uu.views = append(uu.views, append(View(nil), v...))
+ }
+ return uu
+}
+
// vv is an helper to build VectorisedView from different strings.
-func vv(size int, pieces ...string) *VectorisedView {
+func vv(size int, pieces ...string) VectorisedView {
views := make([]View, len(pieces))
for i, p := range pieces {
views[i] = []byte(p)
}
- vv := NewVectorisedView(size, views)
- return &vv
+ return NewVectorisedView(size, views)
}
var capLengthTestCases = []struct {
comment string
- in *VectorisedView
+ in VectorisedView
length int
- want *VectorisedView
+ want VectorisedView
}{
{
comment: "Simple case",
@@ -78,9 +99,9 @@
var trimFrontTestCases = []struct {
comment string
- in *VectorisedView
+ in VectorisedView
count int
- want *VectorisedView
+ want VectorisedView
}{
{
comment: "Simple case",
@@ -139,7 +160,7 @@
var toViewCases = []struct {
comment string
- in *VectorisedView
+ in VectorisedView
want View
}{
{
@@ -171,7 +192,7 @@
var toCloneCases = []struct {
comment string
- inView *VectorisedView
+ inView VectorisedView
inBuffer []View
}{
{
@@ -203,10 +224,12 @@
func TestToClone(t *testing.T) {
for _, c := range toCloneCases {
- got := c.inView.Clone(c.inBuffer)
- if !reflect.DeepEqual(&got, c.inView) {
- t.Errorf("Test \"%s\" failed when calling Clone(%v) on %v. Got %v. Want %v",
- c.comment, c.inBuffer, c.inView, got, c.inView)
- }
+ t.Run(c.comment, func(t *testing.T) {
+ got := c.inView.Clone(c.inBuffer)
+ if !reflect.DeepEqual(got, c.inView) {
+ t.Fatalf("got (%+v).Clone(%+v) = %+v, want = %+v",
+ c.inView, c.inBuffer, got, c.inView)
+ }
+ })
}
}
diff --git a/tcpip/checker/checker.go b/tcpip/checker/checker.go
index 921cbfe..164dc29 100644
--- a/tcpip/checker/checker.go
+++ b/tcpip/checker/checker.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
// Package checker provides helper functions to check networking packets for
// validity.
@@ -13,6 +23,7 @@
"github.com/google/netstack/tcpip"
"github.com/google/netstack/tcpip/header"
+ "github.com/google/netstack/tcpip/seqnum"
)
// NetworkChecker is a function to check a property of a network packet.
@@ -28,40 +39,52 @@
//
// checker.IPv4(t, b, checker.SrcAddr(x), checker.DstAddr(y))
func IPv4(t *testing.T, b []byte, checkers ...NetworkChecker) {
+ t.Helper()
+
ipv4 := header.IPv4(b)
if !ipv4.IsValid(len(b)) {
- t.Fatalf("Not a valid IPv4 packet")
+ t.Error("Not a valid IPv4 packet")
}
xsum := ipv4.CalculateChecksum()
if xsum != 0 && xsum != 0xffff {
- t.Fatalf("Bad checksum: 0x%x, checksum in packet: 0x%x", xsum, ipv4.Checksum())
+ t.Errorf("Bad checksum: 0x%x, checksum in packet: 0x%x", xsum, ipv4.Checksum())
}
for _, f := range checkers {
f(t, []header.Network{ipv4})
}
+ if t.Failed() {
+ t.FailNow()
+ }
}
// IPv6 checks the validity and properties of the given IPv6 packet. The usage
// is similar to IPv4.
func IPv6(t *testing.T, b []byte, checkers ...NetworkChecker) {
+ t.Helper()
+
ipv6 := header.IPv6(b)
if !ipv6.IsValid(len(b)) {
- t.Fatalf("Not a valid IPv6 packet")
+ t.Error("Not a valid IPv6 packet")
}
for _, f := range checkers {
f(t, []header.Network{ipv6})
}
+ if t.Failed() {
+ t.FailNow()
+ }
}
// SrcAddr creates a checker that checks the source address.
func SrcAddr(addr tcpip.Address) NetworkChecker {
return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
if a := h[0].SourceAddress(); a != addr {
- t.Fatalf("Bad source address, got %v, want %v", a, addr)
+ t.Errorf("Bad source address, got %v, want %v", a, addr)
}
}
}
@@ -69,8 +92,10 @@
// DstAddr creates a checker that checks the destination address.
func DstAddr(addr tcpip.Address) NetworkChecker {
return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
if a := h[0].DestinationAddress(); a != addr {
- t.Fatalf("Bad destination address, got %v, want %v", a, addr)
+ t.Errorf("Bad destination address, got %v, want %v", a, addr)
}
}
}
@@ -94,8 +119,10 @@
// PayloadLen creates a checker that checks the payload length.
func PayloadLen(plen int) NetworkChecker {
return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
if l := len(h[0].Payload()); l != plen {
- t.Fatalf("Bad payload length, got %v, want %v", l, plen)
+ t.Errorf("Bad payload length, got %v, want %v", l, plen)
}
}
}
@@ -103,11 +130,13 @@
// FragmentOffset creates a checker that checks the FragmentOffset field.
func FragmentOffset(offset uint16) NetworkChecker {
return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
// We only do this of IPv4 for now.
switch ip := h[0].(type) {
case header.IPv4:
if v := ip.FragmentOffset(); v != offset {
- t.Fatalf("Bad fragment offset, got %v, want %v", v, offset)
+ t.Errorf("Bad fragment offset, got %v, want %v", v, offset)
}
}
}
@@ -116,11 +145,13 @@
// FragmentFlags creates a checker that checks the fragment flags field.
func FragmentFlags(flags uint8) NetworkChecker {
return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
// We only do this of IPv4 for now.
switch ip := h[0].(type) {
case header.IPv4:
if v := ip.Flags(); v != flags {
- t.Fatalf("Bad fragment offset, got %v, want %v", v, flags)
+ t.Errorf("Bad fragment flags, got %v, want %v", v, flags)
}
}
}
@@ -129,8 +160,10 @@
// TOS creates a checker that checks the TOS field.
func TOS(tos uint8, label uint32) NetworkChecker {
return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
if v, l := h[0].TOS(); v != tos || l != label {
- t.Fatalf("Bad TOS, got (%v, %v), want (%v,%v)", v, l, tos, label)
+ t.Errorf("Bad TOS, got (%v, %v), want (%v,%v)", v, l, tos, label)
}
}
}
@@ -142,8 +175,10 @@
// the bytes added by the IPv6 fragmentation.
func Raw(want []byte) NetworkChecker {
return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
if got := h[len(h)-1].Payload(); !reflect.DeepEqual(got, want) {
- t.Fatalf("Wrong payload, got %v, want %v", got, want)
+ t.Errorf("Wrong payload, got %v, want %v", got, want)
}
}
}
@@ -151,18 +186,23 @@
// IPv6Fragment creates a checker that validates an IPv6 fragment.
func IPv6Fragment(checkers ...NetworkChecker) NetworkChecker {
return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
if p := h[0].TransportProtocol(); p != header.IPv6FragmentHeader {
- t.Fatalf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber)
+ t.Errorf("Bad protocol, got %v, want %v", p, header.IPv6FragmentHeader)
}
ipv6Frag := header.IPv6Fragment(h[0].Payload())
if !ipv6Frag.IsValid() {
- t.Fatalf("Not a valid IPv6 fragment")
+ t.Error("Not a valid IPv6 fragment")
}
for _, f := range checkers {
f(t, []header.Network{h[0], ipv6Frag})
}
+ if t.Failed() {
+ t.FailNow()
+ }
}
}
@@ -170,11 +210,13 @@
// potentially additional transport header fields.
func TCP(checkers ...TransportChecker) NetworkChecker {
return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
first := h[0]
last := h[len(h)-1]
if p := last.TransportProtocol(); p != header.TCPProtocolNumber {
- t.Fatalf("Bad protocol, got %v, want %v", p, header.TCPProtocolNumber)
+ t.Errorf("Bad protocol, got %v, want %v", p, header.TCPProtocolNumber)
}
// Verify the checksum.
@@ -188,13 +230,16 @@
xsum = header.Checksum(tcp, xsum)
if xsum != 0 && xsum != 0xffff {
- t.Fatalf("Bad checksum: 0x%x, checksum in segment: 0x%x", xsum, tcp.Checksum())
+ t.Errorf("Bad checksum: 0x%x, checksum in segment: 0x%x", xsum, tcp.Checksum())
}
// Run the transport checkers.
for _, f := range checkers {
f(t, tcp)
}
+ if t.Failed() {
+ t.FailNow()
+ }
}
}
@@ -202,24 +247,31 @@
// potentially additional transport header fields.
func UDP(checkers ...TransportChecker) NetworkChecker {
return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
last := h[len(h)-1]
if p := last.TransportProtocol(); p != header.UDPProtocolNumber {
- t.Fatalf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber)
+ t.Errorf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber)
}
udp := header.UDP(last.Payload())
for _, f := range checkers {
f(t, udp)
}
+ if t.Failed() {
+ t.FailNow()
+ }
}
}
// SrcPort creates a checker that checks the source port.
func SrcPort(port uint16) TransportChecker {
return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
if p := h.SourcePort(); p != port {
- t.Fatalf("Bad source port, got %v, want %v", p, port)
+ t.Errorf("Bad source port, got %v, want %v", p, port)
}
}
}
@@ -228,7 +280,7 @@
func DstPort(port uint16) TransportChecker {
return func(t *testing.T, h header.Transport) {
if p := h.DestinationPort(); p != port {
- t.Fatalf("Bad destination port, got %v, want %v", p, port)
+ t.Errorf("Bad destination port, got %v, want %v", p, port)
}
}
}
@@ -236,13 +288,15 @@
// SeqNum creates a checker that checks the sequence number.
func SeqNum(seq uint32) TransportChecker {
return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
tcp, ok := h.(header.TCP)
if !ok {
return
}
if s := tcp.SequenceNumber(); s != seq {
- t.Fatalf("Bad sequence number, got %v, want %v", s, seq)
+ t.Errorf("Bad sequence number, got %v, want %v", s, seq)
}
}
}
@@ -250,13 +304,14 @@
// AckNum creates a checker that checks the ack number.
func AckNum(seq uint32) TransportChecker {
return func(t *testing.T, h header.Transport) {
+ t.Helper()
tcp, ok := h.(header.TCP)
if !ok {
return
}
if s := tcp.AckNumber(); s != seq {
- t.Fatalf("Bad ack number, got %v, want %v", s, seq)
+ t.Errorf("Bad ack number, got %v, want %v", s, seq)
}
}
}
@@ -270,7 +325,7 @@
}
if w := tcp.WindowSize(); w != window {
- t.Fatalf("Bad window, got 0x%x, want 0x%x", w, window)
+ t.Errorf("Bad window, got 0x%x, want 0x%x", w, window)
}
}
}
@@ -278,13 +333,15 @@
// TCPFlags creates a checker that checks the tcp flags.
func TCPFlags(flags uint8) TransportChecker {
return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
tcp, ok := h.(header.TCP)
if !ok {
return
}
if f := tcp.Flags(); f != flags {
- t.Fatalf("Bad flags, got 0x%x, want 0x%x", f, flags)
+ t.Errorf("Bad flags, got 0x%x, want 0x%x", f, flags)
}
}
}
@@ -299,7 +356,7 @@
}
if f := tcp.Flags(); (f & mask) != (flags & mask) {
- t.Fatalf("Bad masked flags, got 0x%x, want 0x%x, mask 0x%x", f, flags, mask)
+ t.Errorf("Bad masked flags, got 0x%x, want 0x%x, mask 0x%x", f, flags, mask)
}
}
}
@@ -319,6 +376,7 @@
foundMSS := false
foundWS := false
foundTS := false
+ foundSACKPermitted := false
tsVal := uint32(0)
tsEcr := uint32(0)
for i := 0; i < limit; {
@@ -330,23 +388,26 @@
case header.TCPOptionMSS:
v := uint16(opts[i+2])<<8 | uint16(opts[i+3])
if wantOpts.MSS != v {
- t.Fatalf("Bad MSS: got %v, want %v", v, wantOpts.MSS)
+ t.Errorf("Bad MSS: got %v, want %v", v, wantOpts.MSS)
}
foundMSS = true
i += 4
case header.TCPOptionWS:
if wantOpts.WS < 0 {
- t.Fatalf("WS present when it shouldn't be")
+ t.Error("WS present when it shouldn't be")
}
v := int(opts[i+2])
if v != wantOpts.WS {
- t.Fatalf("Bad WS: got %v, want %v", v, wantOpts.WS)
+ t.Errorf("Bad WS: got %v, want %v", v, wantOpts.WS)
}
foundWS = true
i += 3
case header.TCPOptionTS:
- if i+10 > limit || opts[i+1] != 10 {
- t.Fatalf("bad length %d for TS option, limit: %d", opts[i+1], limit)
+ if i+9 >= limit {
+ t.Fatalf("TS option truncated, option is only: %d bytes, want 10", limit-i)
+ }
+ if opts[i+1] != 10 {
+ t.Errorf("Bad length %d for TS option, limit: %d", opts[i+1], limit)
}
tsVal = binary.BigEndian.Uint32(opts[i+2:])
tsEcr = uint32(0)
@@ -357,26 +418,39 @@
}
foundTS = true
i += 10
+ case header.TCPOptionSACKPermitted:
+ if i+1 >= limit {
+ t.Fatalf("SACKPermitted option truncated, option is only: %d bytes, want 2", limit-i)
+ }
+ if opts[i+1] != 2 {
+ t.Errorf("Bad length %d for SACKPermitted option, limit: %d", opts[i+1], limit)
+ }
+ foundSACKPermitted = true
+ i += 2
+
default:
i += int(opts[i+1])
}
}
if !foundMSS {
- t.Fatalf("MSS option not found. Options: %x", opts)
+ t.Errorf("MSS option not found. Options: %x", opts)
}
if !foundWS && wantOpts.WS >= 0 {
- t.Fatalf("WS option not found. Options: %x", opts)
+ t.Errorf("WS option not found. Options: %x", opts)
}
if wantOpts.TS && !foundTS {
- t.Fatalf("TS option not found. Options: %x", opts)
+ t.Errorf("TS option not found. Options: %x", opts)
}
if foundTS && tsVal == 0 {
- t.Fatalf("TS option specified but the timestamp value is zero")
+ t.Error("TS option specified but the timestamp value is zero")
}
if foundTS && tsEcr == 0 && wantOpts.TSEcr != 0 {
- t.Fatalf("TS option specified but TSEcr is incorrect: got %d, want: %d", tsEcr, wantOpts.TSEcr)
+ t.Errorf("TS option specified but TSEcr is incorrect: got %d, want: %d", tsEcr, wantOpts.TSEcr)
+ }
+ if wantOpts.SACKPermitted && !foundSACKPermitted {
+ t.Errorf("SACKPermitted option not found. Options: %x", opts)
}
}
}
@@ -405,11 +479,11 @@
case header.TCPOptionNOP:
i++
case header.TCPOptionTS:
- if i+10 > limit {
- t.Fatalf("TS option found, but option is truncated, option length: %d, want 10 bytes", limit-i)
+ if i+9 >= limit {
+ t.Fatalf("TS option found, but option is truncated, option length: %d, want 10 bytes", limit-i)
}
if opts[i+1] != 10 {
- t.Fatalf("TS option found, but bad length specified: %d, want: 10", opts[i+1])
+ t.Errorf("TS option found, but bad length specified: %d, want: 10", opts[i+1])
}
tsVal = binary.BigEndian.Uint32(opts[i+2:])
tsEcr = binary.BigEndian.Uint32(opts[i+6:])
@@ -429,13 +503,77 @@
}
if wantTS != foundTS {
- t.Fatalf("TS Option mismatch: got TS= %v, want TS= %v", foundTS, wantTS)
+ t.Errorf("TS Option mismatch: got TS= %v, want TS= %v", foundTS, wantTS)
}
if wantTS && wantTSVal != 0 && wantTSVal != tsVal {
- t.Fatalf("Timestamp value is incorrect: got: %d, want: %d", tsVal, wantTSVal)
+ t.Errorf("Timestamp value is incorrect: got: %d, want: %d", tsVal, wantTSVal)
}
if wantTS && wantTSEcr != 0 && tsEcr != wantTSEcr {
- t.Fatalf("Timestamp Echo Reply is incorrect: got: %d, want: %d", tsEcr, wantTSEcr)
+ t.Errorf("Timestamp Echo Reply is incorrect: got: %d, want: %d", tsEcr, wantTSEcr)
+ }
+ }
+}
+
+// TCPNoSACKBlockChecker creates a checker that verifies that the segment does not
+// contain any SACK blocks in the TCP options.
+func TCPNoSACKBlockChecker() TransportChecker {
+ return TCPSACKBlockChecker(nil)
+}
+
+// TCPSACKBlockChecker creates a checker that verifies that the SACK blocks in
+// the segment's TCP options are exactly the specified sackBlocks.
+func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+ tcp, ok := h.(header.TCP)
+ if !ok {
+ return
+ }
+ var gotSACKBlocks []header.SACKBlock
+
+ opts := []byte(tcp.Options())
+ limit := len(opts)
+ for i := 0; i < limit; {
+ switch opts[i] {
+ case header.TCPOptionEOL:
+ i = limit
+ case header.TCPOptionNOP:
+ i++
+ case header.TCPOptionSACK:
+ if i+2 > limit {
+ // Malformed SACK block.
+ t.Fatalf("malformed SACK option in options: %v", opts)
+ }
+ sackOptionLen := int(opts[i+1])
+ if i+sackOptionLen > limit || (sackOptionLen-2)%8 != 0 {
+ // Malformed SACK block.
+ t.Fatalf("malformed SACK option length in options: %v", opts)
+ }
+ numBlocks := (sackOptionLen - 2) / 8
+ for j := 0; j < numBlocks; j++ {
+ start := binary.BigEndian.Uint32(opts[i+2+j*8:])
+ end := binary.BigEndian.Uint32(opts[i+2+j*8+4:])
+ gotSACKBlocks = append(gotSACKBlocks, header.SACKBlock{
+ Start: seqnum.Value(start),
+ End: seqnum.Value(end),
+ })
+ }
+ i += sackOptionLen
+ default:
+ // We don't recognize this option, just skip over it. A
+ // truncated or malformed option ends the scan.
+ if i+2 > limit {
+ i = limit
+ break
+ }
+ l := int(opts[i+1])
+ if l < 2 || i+l > limit {
+ i = limit
+ break
+ }
+ i += l
+ }
+ }
+
+ if !reflect.DeepEqual(gotSACKBlocks, sackBlocks) {
+ t.Errorf("SACKBlocks are not equal, got: %v, want: %v", gotSACKBlocks, sackBlocks)
}
}
}
@@ -444,7 +582,7 @@
func Payload(want []byte) TransportChecker {
return func(t *testing.T, h header.Transport) {
if got := h.Payload(); !reflect.DeepEqual(got, want) {
- t.Fatalf("Wrong payload, got %v, want %v", got, want)
+ t.Errorf("Wrong payload, got %v, want %v", got, want)
}
}
}
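TCPSACKBlockChecker walks the options area byte by byte. As a standalone sketch of the same walk, the hypothetical parseSACKBlocks below returns an error instead of calling t.Errorf, and uses a local sackBlock type in place of header.SACKBlock. Option kinds 0 (EOL), 1 (NOP) and 5 (SACK) follow RFC 793 and RFC 2018; every other option is skipped using its length byte, and each path through the loop either advances i or returns, so malformed input cannot make it spin.

```go
package main

import (
	"encoding/binary"
	"errors"
)

// Option kinds from RFC 793 and RFC 2018.
const (
	optEOL  = 0
	optNOP  = 1
	optSACK = 5
)

// sackBlock is a hypothetical stand-in for header.SACKBlock.
type sackBlock struct{ Start, End uint32 }

// parseSACKBlocks returns every SACK block found in a TCP options area.
// A SACK option is kind 5, a length byte of 2+8n, then n pairs of 32-bit
// big-endian sequence numbers (left edge, right edge).
func parseSACKBlocks(opts []byte) ([]sackBlock, error) {
	var blocks []sackBlock
	for i := 0; i < len(opts); {
		switch opts[i] {
		case optEOL:
			return blocks, nil
		case optNOP:
			i++
			continue
		}
		// Every other option has a kind byte followed by a length byte.
		if i+2 > len(opts) {
			return nil, errors.New("truncated option header")
		}
		l := int(opts[i+1])
		if l < 2 || i+l > len(opts) {
			return nil, errors.New("bad option length")
		}
		if opts[i] == optSACK {
			if (l-2)%8 != 0 {
				return nil, errors.New("bad SACK option length")
			}
			for j := i + 2; j < i+l; j += 8 {
				blocks = append(blocks, sackBlock{
					Start: binary.BigEndian.Uint32(opts[j:]),
					End:   binary.BigEndian.Uint32(opts[j+4:]),
				})
			}
		}
		i += l
	}
	return blocks, nil
}
```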
diff --git a/tcpip/header/arp.go b/tcpip/header/arp.go
index d057ee2..87f5f24 100644
--- a/tcpip/header/arp.go
+++ b/tcpip/header/arp.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package header
diff --git a/tcpip/header/checksum.go b/tcpip/header/checksum.go
index f0e0c18..6397e3d 100644
--- a/tcpip/header/checksum.go
+++ b/tcpip/header/checksum.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
// Package header provides the implementation of the encoding and decoding of
// network protocol headers.
diff --git a/tcpip/header/eth.go b/tcpip/header/eth.go
index eae2d22..49ddfd6 100644
--- a/tcpip/header/eth.go
+++ b/tcpip/header/eth.go
@@ -1,6 +1,16 @@
-// Copyright 2017 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package header
diff --git a/tcpip/header/gue.go b/tcpip/header/gue.go
index a069fb6..aac4593 100644
--- a/tcpip/header/gue.go
+++ b/tcpip/header/gue.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package header
diff --git a/tcpip/header/icmpv4.go b/tcpip/header/icmpv4.go
index 4787097..5b25e42 100644
--- a/tcpip/header/icmpv4.go
+++ b/tcpip/header/icmpv4.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package header
@@ -20,6 +30,10 @@
// ICMPv4EchoMinimumSize is the minimum size of a valid ICMP echo packet.
ICMPv4EchoMinimumSize = 6
+ // ICMPv4DstUnreachableMinimumSize is the minimum size of a valid ICMP
+ // destination unreachable packet.
+ ICMPv4DstUnreachableMinimumSize = ICMPv4MinimumSize + 4
+
// ICMPv4ProtocolNumber is the ICMP transport protocol number.
ICMPv4ProtocolNumber tcpip.TransportProtocolNumber = 1
)
@@ -42,6 +56,12 @@
ICMPv4InfoReply ICMPv4Type = 16
)
+// Values for ICMP code as defined in RFC 792.
+const (
+ ICMPv4PortUnreachable = 3
+ ICMPv4FragmentationNeeded = 4
+)
+
// Type is the ICMP type field.
func (b ICMPv4) Type() ICMPv4Type { return ICMPv4Type(b[0]) }
@@ -54,12 +74,35 @@
// SetCode sets the ICMP code field.
func (b ICMPv4) SetCode(c byte) { b[1] = c }
+// Checksum is the ICMP checksum field.
+func (b ICMPv4) Checksum() uint16 {
+ return binary.BigEndian.Uint16(b[2:])
+}
+
// SetChecksum sets the ICMP checksum field.
func (b ICMPv4) SetChecksum(checksum uint16) {
binary.BigEndian.PutUint16(b[2:], checksum)
}
-// CalculateChecksum calculates the checksum of the ipv4 header.
-func (b ICMPv4) CalculateChecksum(prev uint16) uint16 {
- return Checksum(b[:], prev)
+// SourcePort implements Transport.SourcePort.
+func (ICMPv4) SourcePort() uint16 {
+ return 0
+}
+
+// DestinationPort implements Transport.DestinationPort.
+func (ICMPv4) DestinationPort() uint16 {
+ return 0
+}
+
+// SetSourcePort implements Transport.SetSourcePort.
+func (ICMPv4) SetSourcePort(uint16) {
+}
+
+// SetDestinationPort implements Transport.SetDestinationPort.
+func (ICMPv4) SetDestinationPort(uint16) {
+}
+
+// Payload implements Transport.Payload.
+func (b ICMPv4) Payload() []byte {
+ return b[ICMPv4MinimumSize:]
}
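This diff drops ICMPv4.CalculateChecksum and adds a Checksum getter, so the checksum is computed over the whole message with a generic Internet checksum, much as the checker above does with header.Checksum. The following sketch assumes nothing from netstack: onesComplementSum is a hypothetical, self-contained RFC 1071 sum, and the two helpers show the usual zero-then-fill pattern for the checksum field at bytes 2-3 and the verification rule (a message summed together with its stored checksum yields 0xffff).

```go
package main

import "encoding/binary"

// onesComplementSum folds b into a 16-bit one's complement sum (RFC 1071).
// An odd trailing byte is treated as the high byte of a zero-padded word.
// The uint32 accumulator is ample for packet-sized inputs.
func onesComplementSum(b []byte, initial uint16) uint16 {
	sum := uint32(initial)
	for len(b) >= 2 {
		sum += uint32(binary.BigEndian.Uint16(b))
		b = b[2:]
	}
	if len(b) == 1 {
		sum += uint32(b[0]) << 8
	}
	for sum > 0xffff {
		sum = (sum >> 16) + (sum & 0xffff)
	}
	return uint16(sum)
}

// setICMPv4Checksum zeroes the checksum field (bytes 2-3), sums the whole
// ICMP message and stores the complement of the sum, big-endian.
func setICMPv4Checksum(msg []byte) {
	binary.BigEndian.PutUint16(msg[2:], 0)
	binary.BigEndian.PutUint16(msg[2:], ^onesComplementSum(msg, 0))
}

// icmpv4ChecksumValid reports whether the stored checksum is consistent:
// summing the message including its checksum must yield 0xffff.
func icmpv4ChecksumValid(msg []byte) bool {
	return onesComplementSum(msg, 0) == 0xffff
}
```

For an echo request {8, 0, 0, 0, 0, 1, 0, 1} (type 8, identifier 1, sequence 1), the words sum to 0x0802, so the stored checksum is its complement, 0xf7fd.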
diff --git a/tcpip/header/icmpv6.go b/tcpip/header/icmpv6.go
index 18d8d4f..e452392 100644
--- a/tcpip/header/icmpv6.go
+++ b/tcpip/header/icmpv6.go
@@ -1,6 +1,16 @@
-// Copyright 2017 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package header
@@ -24,17 +34,25 @@
// neighbor solicitation packet.
ICMPv6NeighborSolicitMinimumSize = ICMPv6MinimumSize + 4 + 16
- // ICMPv6NeighborSolicitMinimumSize is size of a neighbor advertisement.
+ // ICMPv6NeighborAdvertSize is the size of a neighbor advertisement.
ICMPv6NeighborAdvertSize = 32
// ICMPv6EchoMinimumSize is the minimum size of a valid ICMP echo packet.
ICMPv6EchoMinimumSize = 8
+
+ // ICMPv6DstUnreachableMinimumSize is the minimum size of a valid ICMP
+ // destination unreachable packet.
+ ICMPv6DstUnreachableMinimumSize = ICMPv6MinimumSize + 4
+
+ // ICMPv6PacketTooBigMinimumSize is the minimum size of a valid ICMP
+ // packet-too-big packet.
+ ICMPv6PacketTooBigMinimumSize = ICMPv6MinimumSize + 4
)
// ICMPv6Type is the ICMP type field described in RFC 4443 and friends.
type ICMPv6Type byte
-// Typical values of ICMPv6Type defined in RFC 792.
+// Typical values of ICMPv6Type defined in RFC 4443.
const (
ICMPv6DstUnreachable ICMPv6Type = 1
ICMPv6PacketTooBig ICMPv6Type = 2
@@ -52,6 +70,11 @@
ICMPv6RedirectMsg ICMPv6Type = 137
)
+// Values for ICMP code as defined in RFC 4443.
+const (
+ ICMPv6PortUnreachable = 4
+)
+
// Type is the ICMP type field.
func (b ICMPv6) Type() ICMPv6Type { return ICMPv6Type(b[0]) }
@@ -61,7 +84,38 @@
// Code is the ICMP code field. Its meaning depends on the value of Type.
func (b ICMPv6) Code() byte { return b[1] }
+// SetCode sets the ICMP code field.
+func (b ICMPv6) SetCode(c byte) { b[1] = c }
+
+// Checksum is the ICMP checksum field.
+func (b ICMPv6) Checksum() uint16 {
+ return binary.BigEndian.Uint16(b[2:])
+}
+
// SetChecksum calculates and sets the ICMP checksum field.
func (b ICMPv6) SetChecksum(checksum uint16) {
binary.BigEndian.PutUint16(b[2:], checksum)
}
+
+// SourcePort implements Transport.SourcePort.
+func (ICMPv6) SourcePort() uint16 {
+ return 0
+}
+
+// DestinationPort implements Transport.DestinationPort.
+func (ICMPv6) DestinationPort() uint16 {
+ return 0
+}
+
+// SetSourcePort implements Transport.SetSourcePort.
+func (ICMPv6) SetSourcePort(uint16) {
+}
+
+// SetDestinationPort implements Transport.SetDestinationPort.
+func (ICMPv6) SetDestinationPort(uint16) {
+}
+
+// Payload implements Transport.Payload.
+func (b ICMPv6) Payload() []byte {
+ return b[ICMPv6MinimumSize:]
+}
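The new no-op SourcePort, DestinationPort, SetSourcePort and SetDestinationPort methods, plus Payload, let ICMP headers satisfy the port-oriented Transport interface, so generic transport checkers such as checker.Payload can run against them. The sketch below is a hypothetical reduction: transport lists only the methods this diff adds (the real header.Transport may require more), and icmpv6MinimumSize is assumed to be 4 bytes (type, code, checksum).

```go
package main

// transport is a hypothetical, trimmed-down stand-in for header.Transport,
// listing only the methods added to ICMPv4 and ICMPv6 in this diff.
type transport interface {
	SourcePort() uint16
	DestinationPort() uint16
	SetSourcePort(uint16)
	SetDestinationPort(uint16)
	Payload() []byte
}

// Assumed minimum ICMPv6 header size: type, code and a 16-bit checksum.
const icmpv6MinimumSize = 4

// icmpv6 mirrors header.ICMPv6: a byte slice viewed as an ICMPv6 message.
type icmpv6 []byte

// ICMP has no ports, so the port accessors report 0 and the setters are no-ops.
func (icmpv6) SourcePort() uint16         { return 0 }
func (icmpv6) DestinationPort() uint16    { return 0 }
func (icmpv6) SetSourcePort(uint16)       {}
func (icmpv6) SetDestinationPort(uint16)  {}
func (b icmpv6) Payload() []byte          { return b[icmpv6MinimumSize:] }

// Compile-time check that icmpv6 satisfies transport.
var _ transport = icmpv6(nil)
```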
diff --git a/tcpip/header/interfaces.go b/tcpip/header/interfaces.go
index 1ff5c98..02dba78 100644
--- a/tcpip/header/interfaces.go
+++ b/tcpip/header/interfaces.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package header
diff --git a/tcpip/header/ipv4.go b/tcpip/header/ipv4.go
index 63e0a42..15e2bba 100644
--- a/tcpip/header/ipv4.go
+++ b/tcpip/header/ipv4.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package header
@@ -81,15 +91,9 @@
// IPv4ProtocolNumber is IPv4's network protocol number.
IPv4ProtocolNumber tcpip.NetworkProtocolNumber = 0x0800
- // IPv4Version is the version of the ipv4 procotol.
+ // IPv4Version is the version of the ipv4 protocol.
IPv4Version = 4
- // IPv4DefaultTTL is the default time-to-live value for sent packets.
- IPv4DefaultTTL = 65
-
- // IPv4Loopback is the loopback address of the IPv4 procotol.
- IPv4Loopback tcpip.Address = "\x7f\x00\x00\x01"
-
- // IPv4Broadcast is the broadcast address of the IPv4 procotol.
+ // IPv4Broadcast is the broadcast address of the IPv4 protocol.
IPv4Broadcast tcpip.Address = "\xff\xff\xff\xff"
@@ -119,12 +123,12 @@
return (b[versIHL] & 0xf) * 4
}
-// ID returns the value of the identifier field of the the ipv4 header.
+// ID returns the value of the identifier field of the ipv4 header.
func (b IPv4) ID() uint16 {
return binary.BigEndian.Uint16(b[id:])
}
-// Protocol returns the value of the protocol field of the the ipv4 header.
+// Protocol returns the value of the protocol field of the ipv4 header.
func (b IPv4) Protocol() uint8 {
return b[protocol]
}
@@ -262,8 +266,9 @@
return true
}
-// IsV4MulticastAddress determines if the provided address is an IPv4
-// multicast address (range 224.0.0.0 to 239.255.255.255).
+// IsV4MulticastAddress determines if the provided address is an IPv4 multicast
+// address (range 224.0.0.0 to 239.255.255.255). The four most significant bits
+// of the first byte are 1110, i.e. addr[0]&0xf0 == 0xe0.
func IsV4MulticastAddress(addr tcpip.Address) bool {
if len(addr) != IPv4AddressSize {
return false
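The multicast test described in the comment above comes down to a single mask on the first byte. The following minimal sketch assumes addresses are raw 4-byte strings, as with `tcpip.Address`. `isV4Multicast` is a hypothetical name, and the hunk above does not show the body of the netstack function.

```go
package main

import "fmt"

// isV4Multicast reports whether addr (4 raw bytes, like tcpip.Address) lies in
// 224.0.0.0/4, i.e. whether its four most significant bits are 1110.
func isV4Multicast(addr string) bool {
	if len(addr) != 4 {
		return false
	}
	return addr[0]&0xf0 == 0xe0
}

func main() {
	fmt.Println(isV4Multicast("\xe0\x00\x00\x01")) // 224.0.0.1
	fmt.Println(isV4Multicast("\xef\xff\xff\xff")) // 239.255.255.255
	fmt.Println(isV4Multicast("\xf0\x00\x00\x00")) // 240.0.0.0
	fmt.Println(isV4Multicast("\x7f\x00\x00\x01")) // 127.0.0.1
}
```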
diff --git a/tcpip/header/ipv6.go b/tcpip/header/ipv6.go
index 22a45e9..26a96e9 100644
--- a/tcpip/header/ipv6.go
+++ b/tcpip/header/ipv6.go
@@ -1,11 +1,22 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package header
import (
"encoding/binary"
+ "strings"
"github.com/google/netstack/tcpip"
)
@@ -60,19 +71,12 @@
// IPv6ProtocolNumber is IPv6's network protocol number.
IPv6ProtocolNumber tcpip.NetworkProtocolNumber = 0x86dd
- // IPv6Version is the version of the ipv6 procotol.
+ // IPv6Version is the version of the ipv6 protocol.
IPv6Version = 6
- // IPv6DefaultHopLimit is the default hop limit (or TTL) value for
- // sent packets.
- IPv6DefaultHopLimit = 255
-
// IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 2460,
// section 5.
IPv6MinimumMTU = 1280
-
- // IPv6Loopback is the loopback address of the IPv6 procotol.
- IPv6Loopback tcpip.Address = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
)
// PayloadLength returns the value of the "payload length" field of the ipv6
@@ -146,6 +150,11 @@
copy(b[v6DstAddr:v6DstAddr+IPv6AddressSize], addr)
}
+// SetNextHeader sets the value of the "next header" field of the ipv6 header.
+func (b IPv6) SetNextHeader(v uint8) {
+ b[nextHdr] = v
+}
+
// SetChecksum implements Network.SetChecksum. Given that IPv6 doesn't have a
// checksum, it is empty.
func (IPv6) SetChecksum(uint16) {
@@ -182,14 +191,7 @@
return false
}
- const prefix = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff"
- for i := 0; i < len(prefix); i++ {
- if prefix[i] != addr[i] {
- return false
- }
- }
-
- return true
+ return strings.HasPrefix(string(addr), "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff")
}
// IsV6MulticastAddress determines if the provided address is an IPv6
diff --git a/tcpip/header/ipv6_fragment.go b/tcpip/header/ipv6_fragment.go
index a22e5ad..226024e 100644
--- a/tcpip/header/ipv6_fragment.go
+++ b/tcpip/header/ipv6_fragment.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package header
diff --git a/tcpip/header/ipversion_test.go b/tcpip/header/ipversion_test.go
index 9957850..a9ebe0a 100644
--- a/tcpip/header/ipversion_test.go
+++ b/tcpip/header/ipversion_test.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package header_test
diff --git a/tcpip/header/tcp.go b/tcpip/header/tcp.go
index 2c558f7..6dfa5ed 100644
--- a/tcpip/header/tcp.go
+++ b/tcpip/header/tcp.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package header
@@ -8,6 +18,7 @@
"encoding/binary"
"github.com/google/netstack/tcpip"
+ "github.com/google/netstack/tcpip/seqnum"
)
const (
@@ -26,6 +37,10 @@
// MaxWndScale is maximum allowed window scaling, as described in
// RFC 1323, section 2.3, page 11.
MaxWndScale = 14
+
+ // TCPMaxSACKBlocks is the maximum number of SACK blocks that can
+ // be encoded in a TCP option field.
+ TCPMaxSACKBlocks = 4
)
// Flags that may be set in a TCP segment.
@@ -40,11 +55,13 @@
// Options that may be present in a TCP segment.
const (
- TCPOptionEOL = 0
- TCPOptionNOP = 1
- TCPOptionMSS = 2
- TCPOptionWS = 3
- TCPOptionTS = 8
+ TCPOptionEOL = 0
+ TCPOptionNOP = 1
+ TCPOptionMSS = 2
+ TCPOptionWS = 3
+ TCPOptionTS = 8
+ TCPOptionSACKPermitted = 4
+ TCPOptionSACK = 5
)
// TCPFields contains the fields of a TCP packet. It is used to describe the
@@ -97,10 +114,27 @@
// TSEcr is the value of the TSEcr field in the timestamp option.
TSEcr uint32
+
+ // SACKPermitted is true if the SACK-permitted option was provided in the SYN/SYN-ACK.
+ SACKPermitted bool
+}
+
+// SACKBlock represents a single contiguous SACK block.
+//
+// +stateify savable
+type SACKBlock struct {
+ // Start indicates the lowest sequence number in the block.
+ Start seqnum.Value
+
+ // End indicates the sequence number immediately following the last
+ // sequence number of this block.
+ End seqnum.Value
}
// TCPOptions are used to parse and cache the TCP segment options for a non
// syn/syn-ack segment.
+//
+// +stateify savable
type TCPOptions struct {
// TS is true if the TimeStamp option is enabled.
TS bool
@@ -110,6 +144,9 @@
// TSEcr is the value in the TSEcr field of the segment.
TSEcr uint32
+
+ // SACKBlocks are the SACK blocks specified in the segment.
+ SACKBlocks []SACKBlock
}
// TCP represents a TCP header stored in a byte array.
@@ -121,14 +158,6 @@
// TCPProtocolNumber is TCP's transport protocol number.
TCPProtocolNumber tcpip.TransportProtocolNumber = 6
-
- // TCPTimeStampOptionSize is the size of an encoded TCP timestamp
- // option.
- //
- // NOTE: The actual option is 10 bytes but we always include 2
- // NOP options to quad align the timestamp option and hence it's
- // 12 and not 10.
- TCPTimeStampOptionSize = 12
)
// SourcePort returns the "source port" field of the tcp header.
@@ -314,6 +343,12 @@
}
synOpts.TS = true
i += 10
+ case TCPOptionSACKPermitted:
+ if i+2 > limit || opts[i+1] != 2 {
+ return synOpts
+ }
+ synOpts.SACKPermitted = true
+ i += 2
default:
// We don't recognize this option, just skip over it.
@@ -345,13 +380,34 @@
case TCPOptionNOP:
i++
case TCPOptionTS:
- if i+10 > limit || b[i+1] != 10 {
+ if i+10 > limit || (b[i+1] != 10) {
return opts
}
opts.TS = true
opts.TSVal = binary.BigEndian.Uint32(b[i+2:])
opts.TSEcr = binary.BigEndian.Uint32(b[i+6:])
i += 10
+ case TCPOptionSACK:
+ if i+2 > limit {
+ // Malformed SACK block, just return and stop parsing.
+ return opts
+ }
+ sackOptionLen := int(b[i+1])
+ if i+sackOptionLen > limit || (sackOptionLen-2)%8 != 0 {
+ // Malformed SACK block, just return and stop parsing.
+ return opts
+ }
+ numBlocks := (sackOptionLen - 2) / 8
+ opts.SACKBlocks = []SACKBlock{}
+ for j := 0; j < numBlocks; j++ {
+ start := binary.BigEndian.Uint32(b[i+2+j*8:])
+ end := binary.BigEndian.Uint32(b[i+2+j*8+4:])
+ opts.SACKBlocks = append(opts.SACKBlocks, SACKBlock{
+ Start: seqnum.Value(start),
+ End: seqnum.Value(end),
+ })
+ }
+ i += sackOptionLen
default:
// We don't recognize this option, just skip over it.
if i+2 > limit {
@@ -369,16 +425,108 @@
return opts
}
-// EncodeTSOption builds and returns an array containing a TCP
-// timestamp option with the TSVal/TSEcr fields set to the value of
-// tsVal/tsEcr. This function also pads the option with two
-// TCPOptionNOP to make sure it is correctly quad aligned.
-func EncodeTSOption(tsVal, tsEcr uint32) (b [12]byte) {
- b[0] = TCPOptionTS
- b[1] = 10
+// EncodeMSSOption encodes the MSS TCP option with the provided MSS value in
+// the supplied buffer; only the low 16 bits of mss are encoded. If the buffer
+// is not large enough it returns without encoding anything. It returns the
+// number of bytes written to the provided buffer.
+func EncodeMSSOption(mss uint32, b []byte) int {
+ // mssOptionSize is the number of bytes in a valid MSS option.
+ const mssOptionSize = 4
+
+ if len(b) < mssOptionSize {
+ return 0
+ }
+ b[0], b[1], b[2], b[3] = TCPOptionMSS, mssOptionSize, byte(mss>>8), byte(mss)
+ return mssOptionSize
+}
+
+// EncodeWSOption encodes the WS TCP option with the WS value in the
+// provided buffer. If the provided buffer is not large enough then it just
+// returns without encoding anything. It returns the number of bytes written to
+// the provided buffer.
+func EncodeWSOption(ws int, b []byte) int {
+ if len(b) < 3 {
+ return 0
+ }
+ b[0], b[1], b[2] = TCPOptionWS, 3, uint8(ws)
+ return int(b[1])
+}
+
+// EncodeTSOption encodes the provided tsVal and tsEcr values as a TCP timestamp
+// option into the provided buffer. If the buffer is smaller than expected it
+// just returns without encoding anything. It returns the number of bytes
+// written to the provided buffer.
+func EncodeTSOption(tsVal, tsEcr uint32, b []byte) int {
+ if len(b) < 10 {
+ return 0
+ }
+ b[0], b[1] = TCPOptionTS, 10
binary.BigEndian.PutUint32(b[2:], tsVal)
binary.BigEndian.PutUint32(b[6:], tsEcr)
- b[10] = TCPOptionNOP
- b[11] = TCPOptionNOP
- return b
+ return int(b[1])
+}
+
+// EncodeSACKPermittedOption encodes a SACKPermitted option into the provided
+// buffer. If the buffer is smaller than required it just returns without
+// encoding anything. It returns the number of bytes written to the provided
+// buffer.
+func EncodeSACKPermittedOption(b []byte) int {
+ if len(b) < 2 {
+ return 0
+ }
+
+ b[0], b[1] = TCPOptionSACKPermitted, 2
+ return int(b[1])
+}
+
+// EncodeSACKBlocks encodes the provided SACK blocks as a TCP SACK option block
+// in the provided slice. It fits in as many blocks as the buffer allows, up to
+// TCPMaxSACKBlocks. It returns the number of bytes written to the provided
+// buffer.
+func EncodeSACKBlocks(sackBlocks []SACKBlock, b []byte) int {
+ if len(sackBlocks) == 0 {
+ return 0
+ }
+ l := len(sackBlocks)
+ if l > TCPMaxSACKBlocks {
+ l = TCPMaxSACKBlocks
+ }
+ if ll := (len(b) - 2) / 8; ll < l {
+ l = ll
+ }
+ if l == 0 {
+ // There is not enough space in the provided buffer to add
+ // any SACK blocks.
+ return 0
+ }
+ b[0] = TCPOptionSACK
+ b[1] = byte(l*8 + 2)
+ for i := 0; i < l; i++ {
+ binary.BigEndian.PutUint32(b[i*8+2:], uint32(sackBlocks[i].Start))
+ binary.BigEndian.PutUint32(b[i*8+6:], uint32(sackBlocks[i].End))
+ }
+ return int(b[1])
+}
+
+// EncodeNOP writes an explicit NOP option into b and returns 1, or 0 if b is empty.
+func EncodeNOP(b []byte) int {
+ if len(b) == 0 {
+ return 0
+ }
+ b[0] = TCPOptionNOP
+ return 1
+}
+
+// AddTCPOptionPadding adds the required number of TCPOptionNOP to quad align
+// the option buffer. It writes the padding bytes starting at options[offset]
+// and returns the number of padding bytes added. The passed-in options slice
+// must have space for the padding bytes.
+func AddTCPOptionPadding(options []byte, offset int) int {
+ paddingToAdd := -offset & 3
+ // Now add any padding bytes that might be required to quad align the
+ // options.
+ for i := offset; i < offset+paddingToAdd; i++ {
+ options[i] = TCPOptionNOP
+ }
+ return paddingToAdd
}
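The buffer-based encoders replace the old fixed `[12]byte` timestamp encoder. Callers now write options one after another and pad once at the end. The sketch below builds SYN-style options with local copies of the encoders. The lower-case names are hypothetical, the option kinds come from the `TCPOption*` constants in this diff, and MSS 1460 and window scale 7 are arbitrary example values.

```go
package main

import (
	"encoding/binary"
	"fmt"
)

// Option kinds, matching the TCPOption* constants in the diff.
const (
	optNOP           = 1
	optMSS           = 2
	optWS            = 3
	optSACKPermitted = 4
	optTS            = 8
)

// Local copies of the encoders: each writes one option at the start of b and
// returns the number of bytes written, or 0 if b is too small.
func encodeMSS(mss uint16, b []byte) int {
	if len(b) < 4 {
		return 0
	}
	b[0], b[1] = optMSS, 4
	binary.BigEndian.PutUint16(b[2:], mss)
	return 4
}

func encodeWS(ws int, b []byte) int {
	if len(b) < 3 {
		return 0
	}
	b[0], b[1], b[2] = optWS, 3, uint8(ws)
	return 3
}

func encodeTS(tsVal, tsEcr uint32, b []byte) int {
	if len(b) < 10 {
		return 0
	}
	b[0], b[1] = optTS, 10
	binary.BigEndian.PutUint32(b[2:], tsVal)
	binary.BigEndian.PutUint32(b[6:], tsEcr)
	return 10
}

func encodeSACKPermitted(b []byte) int {
	if len(b) < 2 {
		return 0
	}
	b[0], b[1] = optSACKPermitted, 2
	return 2
}

// addPadding mirrors AddTCPOptionPadding: -offset & 3 is the distance from
// offset up to the next multiple of 4.
func addPadding(options []byte, offset int) int {
	n := -offset & 3
	for i := offset; i < offset+n; i++ {
		options[i] = optNOP
	}
	return n
}

func main() {
	opts := make([]byte, 40) // TCP allows at most 40 bytes of options.
	off := encodeMSS(1460, opts)
	off += encodeWS(7, opts[off:])
	off += encodeTS(1, 0, opts[off:])
	off += encodeSACKPermitted(opts[off:])
	off += addPadding(opts, off)
	fmt.Println(off, opts[:off])
}
```

MSS (4) + WS (3) + TS (10) + SACK-permitted (2) is 19 bytes, and `-19 & 3 == 1`. One trailing NOP therefore brings the block to 20 bytes, a multiple of 4, as the data offset field requires.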
diff --git a/tcpip/header/tcp_test.go b/tcpip/header/tcp_test.go
new file mode 100644
index 0000000..262246a
--- /dev/null
+++ b/tcpip/header/tcp_test.go
@@ -0,0 +1,148 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header_test
+
+import (
+ "reflect"
+ "testing"
+
+ "github.com/google/netstack/tcpip/header"
+)
+
+func TestEncodeSACKBlocks(t *testing.T) {
+ testCases := []struct {
+ sackBlocks []header.SACKBlock
+ want []header.SACKBlock
+ bufSize int
+ }{
+ {
+ []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}},
+ []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}},
+ 40,
+ },
+ {
+ []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}},
+ []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}},
+ 30,
+ },
+ {
+ []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}},
+ []header.SACKBlock{{10, 20}, {22, 30}},
+ 20,
+ },
+ {
+ []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}},
+ []header.SACKBlock{{10, 20}},
+ 10,
+ },
+ {
+ []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}},
+ nil,
+ 8,
+ },
+ {
+ []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}},
+ []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}},
+ 60,
+ },
+ }
+ for _, tc := range testCases {
+ b := make([]byte, tc.bufSize)
+ t.Logf("testing: %v", tc)
+ header.EncodeSACKBlocks(tc.sackBlocks, b)
+ opts := header.ParseTCPOptions(b)
+ if got, want := opts.SACKBlocks, tc.want; !reflect.DeepEqual(got, want) {
+ t.Errorf("header.EncodeSACKBlocks(%v, %v), encoded blocks got: %v, want: %v", tc.sackBlocks, b, got, want)
+ }
+ }
+}
+
+func TestTCPParseOptions(t *testing.T) {
+ type tsOption struct {
+ tsVal uint32
+ tsEcr uint32
+ }
+
+ generateOptions := func(tsOpt *tsOption, sackBlocks []header.SACKBlock) []byte {
+ l := 0
+ if tsOpt != nil {
+ l += 10
+ }
+ if len(sackBlocks) != 0 {
+ l += len(sackBlocks)*8 + 2
+ }
+ b := make([]byte, l)
+ offset := 0
+ if tsOpt != nil {
+ offset = header.EncodeTSOption(tsOpt.tsVal, tsOpt.tsEcr, b)
+ }
+ header.EncodeSACKBlocks(sackBlocks, b[offset:])
+ return b
+ }
+
+ testCases := []struct {
+ b []byte
+ want header.TCPOptions
+ }{
+ // Trivial cases.
+ {nil, header.TCPOptions{false, 0, 0, nil}},
+ {[]byte{header.TCPOptionNOP}, header.TCPOptions{false, 0, 0, nil}},
+ {[]byte{header.TCPOptionNOP, header.TCPOptionNOP}, header.TCPOptions{false, 0, 0, nil}},
+ {[]byte{header.TCPOptionEOL}, header.TCPOptions{false, 0, 0, nil}},
+ {[]byte{header.TCPOptionNOP, header.TCPOptionEOL, header.TCPOptionTS, 10, 1, 1}, header.TCPOptions{false, 0, 0, nil}},
+
+ // Test timestamp parsing.
+ {[]byte{header.TCPOptionNOP, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{true, 1, 1, nil}},
+ {[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{true, 1, 1, nil}},
+
+ // Test malformed timestamp option.
+ {[]byte{header.TCPOptionTS, 8, 1, 1}, header.TCPOptions{false, 0, 0, nil}},
+ {[]byte{header.TCPOptionNOP, header.TCPOptionTS, 8, 1, 1}, header.TCPOptions{false, 0, 0, nil}},
+ {[]byte{header.TCPOptionNOP, header.TCPOptionTS, 8, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, nil}},
+
+ // Test SACKBlock parsing.
+ {[]byte{header.TCPOptionSACK, 10, 0, 0, 0, 1, 0, 0, 0, 10}, header.TCPOptions{false, 0, 0, []header.SACKBlock{{1, 10}}}},
+ {[]byte{header.TCPOptionSACK, 18, 0, 0, 0, 1, 0, 0, 0, 10, 0, 0, 0, 11, 0, 0, 0, 12}, header.TCPOptions{false, 0, 0, []header.SACKBlock{{1, 10}, {11, 12}}}},
+
+ // Test malformed SACK option.
+ {[]byte{header.TCPOptionSACK, 0}, header.TCPOptions{false, 0, 0, nil}},
+ {[]byte{header.TCPOptionSACK, 8, 0, 0, 0, 1, 0, 0, 0, 10}, header.TCPOptions{false, 0, 0, nil}},
+ {[]byte{header.TCPOptionSACK, 11, 0, 0, 0, 1, 0, 0, 0, 10, 0, 0, 0, 11, 0, 0, 0, 12}, header.TCPOptions{false, 0, 0, nil}},
+ {[]byte{header.TCPOptionSACK, 17, 0, 0, 0, 1, 0, 0, 0, 10, 0, 0, 0, 11, 0, 0, 0, 12}, header.TCPOptions{false, 0, 0, nil}},
+ {[]byte{header.TCPOptionSACK}, header.TCPOptions{false, 0, 0, nil}},
+ {[]byte{header.TCPOptionSACK, 10}, header.TCPOptions{false, 0, 0, nil}},
+ {[]byte{header.TCPOptionSACK, 10, 0, 0, 0, 1, 0, 0, 0}, header.TCPOptions{false, 0, 0, nil}},
+
+ // Test Timestamp + SACK block parsing.
+ {generateOptions(&tsOption{1, 1}, []header.SACKBlock{{1, 10}, {11, 12}}), header.TCPOptions{true, 1, 1, []header.SACKBlock{{1, 10}, {11, 12}}}},
+ {generateOptions(&tsOption{1, 2}, []header.SACKBlock{{1, 10}, {11, 12}}), header.TCPOptions{true, 1, 2, []header.SACKBlock{{1, 10}, {11, 12}}}},
+ {generateOptions(&tsOption{1, 3}, []header.SACKBlock{{1, 10}, {11, 12}, {13, 14}, {14, 15}, {15, 16}}), header.TCPOptions{true, 1, 3, []header.SACKBlock{{1, 10}, {11, 12}, {13, 14}, {14, 15}}}},
+
+ // Test valid timestamp + malformed SACK block parsing.
+ {[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionSACK}, header.TCPOptions{true, 1, 1, nil}},
+ {[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionSACK, 10}, header.TCPOptions{true, 1, 1, nil}},
+ {[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionSACK, 10, 0, 0, 0}, header.TCPOptions{true, 1, 1, nil}},
+ {[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionSACK, 11, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{true, 1, 1, nil}},
+ {[]byte{header.TCPOptionSACK, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, nil}},
+ {[]byte{header.TCPOptionSACK, 10, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, []header.SACKBlock{{134873088, 65536}}}},
+ {[]byte{header.TCPOptionSACK, 10, 0, 0, 0, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, []header.SACKBlock{{8, 167772160}}}},
+ {[]byte{header.TCPOptionSACK, 11, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, nil}},
+ }
+ for _, tc := range testCases {
+ if got, want := header.ParseTCPOptions(tc.b), tc.want; !reflect.DeepEqual(got, want) {
+ t.Errorf("ParseTCPOptions(%v) = %v, want: %v", tc.b, got, tc.want)
+ }
+ }
+}
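The SACK case in `ParseTCPOptions` accepts an option only if three conditions hold. The length byte must be present, the whole option must fit, and the length minus the 2-byte kind/length prefix must be a multiple of the 8-byte block size. The standalone sketch below applies that validation to a single option. `parseSACK` and `sackBlock` are hypothetical names, and the sketch uses `uint32` where netstack uses `seqnum.Value`. On malformed input it returns `ok == false`, whereas netstack stops parsing and returns the options gathered so far.

```go
package main

import (
	"encoding/binary"
	"fmt"
)

const optSACK = 5

// sackBlock mirrors header.SACKBlock, using plain uint32 instead of seqnum.Value.
type sackBlock struct{ Start, End uint32 }

// parseSACK decodes one SACK option at the start of b, applying the same
// length checks as the ParseTCPOptions case: the length byte must exist, the
// option must fit in b, and (length-2) must be a multiple of 8.
func parseSACK(b []byte) (blocks []sackBlock, ok bool) {
	if len(b) < 2 || b[0] != optSACK {
		return nil, false
	}
	n := int(b[1])
	// Lengths 0 and 1 also fail here, since (n-2)%8 is -2 or -1.
	if n > len(b) || (n-2)%8 != 0 {
		return nil, false
	}
	for j := 0; j < (n-2)/8; j++ {
		start := binary.BigEndian.Uint32(b[2+j*8:])
		end := binary.BigEndian.Uint32(b[2+j*8+4:])
		blocks = append(blocks, sackBlock{start, end})
	}
	return blocks, true
}

func main() {
	b := []byte{optSACK, 18, 0, 0, 0, 1, 0, 0, 0, 10, 0, 0, 0, 11, 0, 0, 0, 12}
	fmt.Println(parseSACK(b))
	fmt.Println(parseSACK([]byte{optSACK, 11, 0, 0, 0, 1, 0, 0, 0, 10, 0}))
}
```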
diff --git a/tcpip/header/udp.go b/tcpip/header/udp.go
index b1427c2..954b8e3 100644
--- a/tcpip/header/udp.go
+++ b/tcpip/header/udp.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package header
diff --git a/tcpip/link/bufwritingchannel/bufwritingchannel.go b/tcpip/link/bufwritingchannel/bufwritingchannel.go
deleted file mode 100644
index 7a51d2d..0000000
--- a/tcpip/link/bufwritingchannel/bufwritingchannel.go
+++ /dev/null
@@ -1,49 +0,0 @@
-// Copyright 2018 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-// Package channel provides the implemention of channel-based data-link layer
-// endpoints. Such endpoints allow injection of inbound packets and store
-// outbound packets in a channel.
-package bufwritingchannel
-
-import (
- "github.com/google/netstack/tcpip"
- "github.com/google/netstack/tcpip/buffer"
- "github.com/google/netstack/tcpip/link/channel"
- "github.com/google/netstack/tcpip/stack"
-)
-
-// Endpoint is link layer endpoint that stores outbound packets in a channel
-// and allows injection of inbound packets.
-type Endpoint struct {
- channel.Endpoint
-}
-
-// New creates a new channel endpoint.
-func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) (tcpip.LinkEndpointID, *Endpoint) {
- _, ce := channel.New(size, mtu, linkAddr)
-
- e := &Endpoint{
- Endpoint: *ce,
- }
-
- return stack.RegisterLinkEndpoint(e), e
-}
-
-func (e *Endpoint) WriteBuffer(_ *stack.Route, payload *buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
- p := channel.PacketInfo{
- Proto: protocol,
- }
-
- if payload != nil {
- p.Payload = payload.ToView()
- }
-
- select {
- case e.C <- p:
- default:
- }
-
- return nil
-}
diff --git a/tcpip/link/channel/channel.go b/tcpip/link/channel/channel.go
index 82f77f1..5606867 100644
--- a/tcpip/link/channel/channel.go
+++ b/tcpip/link/channel/channel.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
-// Package channel provides the implemention of channel-based data-link layer
+// Package channel provides the implementation of channel-based data-link layer
// endpoints. Such endpoints allow injection of inbound packets and store
@@ -56,15 +66,13 @@
}
// Inject injects an inbound packet.
-func (e *Endpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv *buffer.VectorisedView) {
- uu := vv.Clone(nil)
- e.dispatcher.DeliverNetworkPacket(e, e.linkAddr, "", protocol, &uu)
+func (e *Endpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
+ e.InjectLinkAddr(protocol, "", vv)
}
// InjectLinkAddr injects an inbound packet with a remote link address.
-func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, vv *buffer.VectorisedView) {
- uu := vv.Clone(nil)
- e.dispatcher.DeliverNetworkPacket(e, e.linkAddr, remoteLinkAddr, protocol, &uu)
+func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, vv buffer.VectorisedView) {
+ e.dispatcher.DeliverNetworkPacket(e, remoteLinkAddr, "" /* localLinkAddr */, protocol, vv.Clone(nil))
}
// Attach saves the stack network-layer dispatcher for use later when packets
@@ -73,12 +81,22 @@
e.dispatcher = dispatcher
}
+// IsAttached implements stack.LinkEndpoint.IsAttached.
+func (e *Endpoint) IsAttached() bool {
+ return e.dispatcher != nil
+}
+
// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized
// during construction.
func (e *Endpoint) MTU() uint32 {
return e.mtu
}
+// Capabilities implements stack.LinkEndpoint.Capabilities.
+func (*Endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return 0
+}
+
// MaxHeaderLength returns the maximum size of the link layer header. Given it
// doesn't have a header, it just returns 0.
func (*Endpoint) MaxHeaderLength() uint16 {
@@ -91,15 +109,11 @@
}
// WritePacket stores outbound packets into the channel.
-func (e *Endpoint) WritePacket(_ *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
+func (e *Endpoint) WritePacket(_ *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
p := PacketInfo{
- Header: hdr.View(),
- Proto: protocol,
- }
-
- if payload != nil {
- p.Payload = make(buffer.View, len(payload))
- copy(p.Payload, payload)
+ Header: hdr.View(),
+ Proto: protocol,
+ Payload: payload.ToView(),
}
select {
diff --git a/tcpip/link/fdbased/endpoint.go b/tcpip/link/fdbased/endpoint.go
index 0da690a..a9938af 100644
--- a/tcpip/link/fdbased/endpoint.go
+++ b/tcpip/link/fdbased/endpoint.go
@@ -1,6 +1,18 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
-// Package fdbased provides the implemention of data-link layer endpoints
+// Package fdbased provides the implementation of data-link layer endpoints
// backed by boundary-preserving file descriptors (e.g., TUN devices,
@@ -31,35 +43,96 @@
// mtu (maximum transmission unit) is the maximum size of a packet.
mtu uint32
+ // hdrSize specifies the link-layer header size. If set to 0, no header
+ // is added/removed; otherwise an Ethernet header is used.
+ hdrSize int
+
+ // addr is the address of the endpoint.
+ addr tcpip.LinkAddress
+
+ // caps holds the endpoint capabilities.
+ caps stack.LinkEndpointCapabilities
+
// closed is a function to be called when the FD's peer (if any) closes
// its end of the communication pipe.
closed func(*tcpip.Error)
- vv *buffer.VectorisedView
- iovecs []syscall.Iovec
- views []buffer.View
+ iovecs []syscall.Iovec
+ views []buffer.View
+ dispatcher stack.NetworkDispatcher
+
+ // handleLocal indicates whether packets destined to itself should be
+ // handled by the netstack internally (true) or be forwarded to the FD
+ // endpoint (false).
+ handleLocal bool
+}
+
+// Options specify the details about the fd-based endpoint to be created.
+type Options struct {
+ FD int
+ MTU uint32
+ EthernetHeader bool
+ ChecksumOffload bool
+ ClosedFunc func(*tcpip.Error)
+ Address tcpip.LinkAddress
+ SaveRestore bool
+ DisconnectOk bool
+ HandleLocal bool
}
// New creates a new fd-based endpoint.
-func New(fd int, mtu uint32, closed func(*tcpip.Error)) tcpip.LinkEndpointID {
- syscall.SetNonblock(fd, true)
+//
+// Makes opts.FD non-blocking, but does not take ownership of it; the FD must
+// remain open for the lifetime of the returned endpoint.
+func New(opts *Options) tcpip.LinkEndpointID {
+ syscall.SetNonblock(opts.FD, true)
+
+ caps := stack.LinkEndpointCapabilities(0)
+ if opts.ChecksumOffload {
+ caps |= stack.CapabilityChecksumOffload
+ }
+
+ hdrSize := 0
+ if opts.EthernetHeader {
+ hdrSize = header.EthernetMinimumSize
+ caps |= stack.CapabilityResolutionRequired
+ }
+
+ if opts.SaveRestore {
+ caps |= stack.CapabilitySaveRestore
+ }
+
+ if opts.DisconnectOk {
+ caps |= stack.CapabilityDisconnectOk
+ }
e := &endpoint{
- fd: fd,
- mtu: mtu,
- closed: closed,
- views: make([]buffer.View, len(BufConfig)),
- iovecs: make([]syscall.Iovec, len(BufConfig)),
+ fd: opts.FD,
+ mtu: opts.MTU,
+ caps: caps,
+ closed: opts.ClosedFunc,
+ addr: opts.Address,
+ hdrSize: hdrSize,
+ views: make([]buffer.View, len(BufConfig)),
+ iovecs: make([]syscall.Iovec, len(BufConfig)),
+ handleLocal: opts.HandleLocal,
}
- vv := buffer.NewVectorisedView(0, e.views)
- e.vv = &vv
return stack.RegisterLinkEndpoint(e)
}
// Attach launches the goroutine that reads packets from the file descriptor and
// dispatches them via the provided dispatcher.
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
- go e.dispatchLoop(dispatcher)
+ e.dispatcher = dispatcher
+ // Link endpoints are not savable. When transport endpoints are
+ // saved, they stop sending outgoing packets and all incoming packets
+ // are rejected.
+ go e.dispatchLoop()
+}
+
+// IsAttached implements stack.LinkEndpoint.IsAttached.
+func (e *endpoint) IsAttached() bool {
+ return e.dispatcher != nil
}
// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized
@@ -68,26 +141,54 @@
return e.mtu
}
-// MaxHeaderLength returns the maximum size of the header. Given that it
-// doesn't have a header, it just returns 0.
-func (*endpoint) MaxHeaderLength() uint16 {
- return 0
+// Capabilities implements stack.LinkEndpoint.Capabilities.
+func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.caps
+}
+
+// MaxHeaderLength returns the maximum size of the link-layer header.
+func (e *endpoint) MaxHeaderLength() uint16 {
+ return uint16(e.hdrSize)
}
// LinkAddress returns the link address of this endpoint.
-func (*endpoint) LinkAddress() tcpip.LinkAddress {
- return ""
+func (e *endpoint) LinkAddress() tcpip.LinkAddress {
+ return e.addr
}
// WritePacket writes outbound packets to the file descriptor. If it is not
// currently writable, the packet is dropped.
-func (e *endpoint) WritePacket(_ *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
- if payload == nil {
- return rawfile.NonBlockingWrite(e.fd, hdr.UsedBytes())
+func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
+ if e.handleLocal && r.LocalAddress != "" && r.LocalAddress == r.RemoteAddress {
+ views := make([]buffer.View, 1, 1+len(payload.Views()))
+ views[0] = hdr.View()
+ views = append(views, payload.Views()...)
+ vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views)
+ e.dispatcher.DeliverNetworkPacket(e, r.RemoteLinkAddress, r.LocalLinkAddress, protocol, vv)
+ return nil
+ }
+ if e.hdrSize > 0 {
+ // Add ethernet header if needed.
+ eth := header.Ethernet(hdr.Prepend(header.EthernetMinimumSize))
+ ethHdr := &header.EthernetFields{
+ DstAddr: r.RemoteLinkAddress,
+ Type: protocol,
+ }
+ // Preserve the src address if it's set in the route.
+ if r.LocalLinkAddress != "" {
+ ethHdr.SrcAddr = r.LocalLinkAddress
+ } else {
+ ethHdr.SrcAddr = e.addr
+ }
+ eth.Encode(ethHdr)
}
- return rawfile.NonBlockingWrite2(e.fd, hdr.UsedBytes(), payload)
+ if payload.Size() == 0 {
+ return rawfile.NonBlockingWrite(e.fd, hdr.View())
+ }
+
+ return rawfile.NonBlockingWrite2(e.fd, hdr.View(), payload.ToView())
}
func (e *endpoint) capViews(n int, buffers []int) int {
@@ -117,7 +218,7 @@
}
// dispatch reads one packet from the file descriptor and dispatches it.
-func (e *endpoint) dispatch(d stack.NetworkDispatcher, largeV buffer.View) (bool, *tcpip.Error) {
+func (e *endpoint) dispatch(largeV buffer.View) (bool, *tcpip.Error) {
e.allocateViews(BufConfig)
n, err := rawfile.BlockingReadv(e.fd, e.iovecs)
@@ -125,27 +226,37 @@
return false, err
}
- if n <= 0 {
+ if n <= e.hdrSize {
return false, nil
}
- used := e.capViews(n, BufConfig)
- e.vv.SetViews(e.views[:used])
- e.vv.SetSize(n)
-
- // We don't get any indication of what the packet is, so try to guess
- // if it's an IPv4 or IPv6 packet.
- var p tcpip.NetworkProtocolNumber
- switch header.IPVersion(e.views[0]) {
- case header.IPv4Version:
- p = header.IPv4ProtocolNumber
- case header.IPv6Version:
- p = header.IPv6ProtocolNumber
- default:
- return true, nil
+ var (
+ p tcpip.NetworkProtocolNumber
+ remoteLinkAddr, localLinkAddr tcpip.LinkAddress
+ )
+ if e.hdrSize > 0 {
+ eth := header.Ethernet(e.views[0])
+ p = eth.Type()
+ remoteLinkAddr = eth.SourceAddress()
+ localLinkAddr = eth.DestinationAddress()
+ } else {
+ // We don't get any indication of what the packet is, so try to guess
+ // if it's an IPv4 or IPv6 packet.
+ switch header.IPVersion(e.views[0]) {
+ case header.IPv4Version:
+ p = header.IPv4ProtocolNumber
+ case header.IPv6Version:
+ p = header.IPv6ProtocolNumber
+ default:
+ return true, nil
+ }
}
- d.DeliverNetworkPacket(e, e.LinkAddress(), "", p, e.vv)
+ used := e.capViews(n, BufConfig)
+ vv := buffer.NewVectorisedView(n, e.views[:used])
+ vv.TrimFront(e.hdrSize)
+
+ e.dispatcher.DeliverNetworkPacket(e, remoteLinkAddr, localLinkAddr, p, vv)
// Prepare e.views for another packet: release used views.
for i := 0; i < used; i++ {
@@ -157,10 +268,10 @@
// dispatchLoop reads packets from the file descriptor in a loop and dispatches
// them to the network stack.
-func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) *tcpip.Error {
+func (e *endpoint) dispatchLoop() *tcpip.Error {
v := buffer.NewView(header.MaxIPPacketSize)
for {
- cont, err := e.dispatch(d, v)
+ cont, err := e.dispatch(v)
if err != nil || !cont {
if e.closed != nil {
e.closed(err)
@@ -185,9 +296,8 @@
}
// Inject injects an inbound packet.
-func (e *InjectableEndpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv *buffer.VectorisedView) {
- uu := vv.Clone(nil)
- e.dispatcher.DeliverNetworkPacket(e, e.LinkAddress(), "", protocol, &uu)
+func (e *InjectableEndpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
+ e.dispatcher.DeliverNetworkPacket(e, "" /* remoteLinkAddr */, "" /* localLinkAddr */, protocol, vv)
}
// NewInjectable creates a new fd-based InjectableEndpoint.
diff --git a/tcpip/link/fdbased/endpoint_test.go b/tcpip/link/fdbased/endpoint_test.go
index 9ecbc00..1b6eed1 100644
--- a/tcpip/link/fdbased/endpoint_test.go
+++ b/tcpip/link/fdbased/endpoint_test.go
@@ -1,18 +1,290 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
package fdbased
import (
+ "bytes"
+ "fmt"
+ "math/rand"
"reflect"
"syscall"
"testing"
+ "time"
+ "github.com/google/netstack/tcpip"
"github.com/google/netstack/tcpip/buffer"
"github.com/google/netstack/tcpip/header"
+ "github.com/google/netstack/tcpip/stack"
)
+const (
+ mtu = 1500
+ laddr = tcpip.LinkAddress("\x11\x22\x33\x44\x55\x66")
+ raddr = tcpip.LinkAddress("\x77\x88\x99\xaa\xbb\xcc")
+ proto = 10
+)
+
+type packetInfo struct {
+ raddr tcpip.LinkAddress
+ proto tcpip.NetworkProtocolNumber
+ contents buffer.View
+}
+
+type context struct {
+ t *testing.T
+ fds [2]int
+ ep stack.LinkEndpoint
+ ch chan packetInfo
+ done chan struct{}
+}
+
+func newContext(t *testing.T, opt *Options) *context {
+ fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_SEQPACKET, 0)
+ if err != nil {
+ t.Fatalf("Socketpair failed: %v", err)
+ }
+
+ done := make(chan struct{}, 1)
+ opt.ClosedFunc = func(*tcpip.Error) {
+ done <- struct{}{}
+ }
+
+ opt.FD = fds[1]
+ ep := stack.FindLinkEndpoint(New(opt)).(*endpoint)
+
+ c := &context{
+ t: t,
+ fds: fds,
+ ep: ep,
+ ch: make(chan packetInfo, 100),
+ done: done,
+ }
+
+ ep.Attach(c)
+
+ return c
+}
+
+func (c *context) cleanup() {
+ syscall.Close(c.fds[0])
+ <-c.done
+ syscall.Close(c.fds[1])
+}
+
+func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remoteLinkAddr tcpip.LinkAddress, localLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
+ c.ch <- packetInfo{remoteLinkAddr, protocol, vv.ToView()}
+}
+
+func TestNoEthernetProperties(t *testing.T) {
+ c := newContext(t, &Options{MTU: mtu})
+ defer c.cleanup()
+
+ if want, v := uint16(0), c.ep.MaxHeaderLength(); want != v {
+ t.Fatalf("MaxHeaderLength() = %v, want %v", v, want)
+ }
+
+ if want, v := uint32(mtu), c.ep.MTU(); want != v {
+ t.Fatalf("MTU() = %v, want %v", v, want)
+ }
+}
+
+func TestEthernetProperties(t *testing.T) {
+ c := newContext(t, &Options{EthernetHeader: true, MTU: mtu})
+ defer c.cleanup()
+
+ if want, v := uint16(header.EthernetMinimumSize), c.ep.MaxHeaderLength(); want != v {
+ t.Fatalf("MaxHeaderLength() = %v, want %v", v, want)
+ }
+
+ if want, v := uint32(mtu), c.ep.MTU(); want != v {
+ t.Fatalf("MTU() = %v, want %v", v, want)
+ }
+}
+
+func TestAddress(t *testing.T) {
+ addrs := []tcpip.LinkAddress{"", "abc", "def"}
+ for _, a := range addrs {
+ t.Run(fmt.Sprintf("Address: %q", a), func(t *testing.T) {
+ c := newContext(t, &Options{Address: a, MTU: mtu})
+ defer c.cleanup()
+
+ if want, v := a, c.ep.LinkAddress(); want != v {
+ t.Fatalf("LinkAddress() = %v, want %v", v, want)
+ }
+ })
+ }
+}
+
+func TestWritePacket(t *testing.T) {
+ lengths := []int{0, 100, 1000}
+ eths := []bool{true, false}
+
+ for _, eth := range eths {
+ for _, plen := range lengths {
+ t.Run(fmt.Sprintf("Eth=%v,PayloadLen=%v", eth, plen), func(t *testing.T) {
+ c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: eth})
+ defer c.cleanup()
+
+ r := &stack.Route{
+ RemoteLinkAddress: raddr,
+ }
+
+ // Build header.
+ hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()) + 100)
+ b := hdr.Prepend(100)
+ for i := range b {
+ b[i] = uint8(rand.Intn(256))
+ }
+
+ // Build payload and write.
+ payload := make(buffer.View, plen)
+ for i := range payload {
+ payload[i] = uint8(rand.Intn(256))
+ }
+ want := append(hdr.View(), payload...)
+ if err := c.ep.WritePacket(r, hdr, payload.ToVectorisedView(), proto); err != nil {
+ t.Fatalf("WritePacket failed: %v", err)
+ }
+
+ // Read from fd, then compare with what we wrote.
+ b = make([]byte, mtu)
+ n, err := syscall.Read(c.fds[0], b)
+ if err != nil {
+ t.Fatalf("Read failed: %v", err)
+ }
+ b = b[:n]
+ if eth {
+ h := header.Ethernet(b)
+ b = b[header.EthernetMinimumSize:]
+
+ if a := h.SourceAddress(); a != laddr {
+ t.Fatalf("SourceAddress() = %v, want %v", a, laddr)
+ }
+
+ if a := h.DestinationAddress(); a != raddr {
+ t.Fatalf("DestinationAddress() = %v, want %v", a, raddr)
+ }
+
+ if et := h.Type(); et != proto {
+ t.Fatalf("Type() = %v, want %v", et, proto)
+ }
+ }
+ if len(b) != len(want) {
+ t.Fatalf("Read returned %v bytes, want %v", len(b), len(want))
+ }
+ if !bytes.Equal(b, want) {
+ t.Fatalf("Read returned %x, want %x", b, want)
+ }
+ })
+ }
+ }
+}
+
+func TestPreserveSrcAddress(t *testing.T) {
+ baddr := tcpip.LinkAddress("\xcc\xbb\xaa\x77\x88\x99")
+
+ c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: true})
+ defer c.cleanup()
+
+ // Set LocalLinkAddress in route to the value of the bridged address.
+ r := &stack.Route{
+ RemoteLinkAddress: raddr,
+ LocalLinkAddress: baddr,
+ }
+
+ // WritePacket panics given a prependable with anything less than
+ // the minimum size of the ethernet header.
+ hdr := buffer.NewPrependable(header.EthernetMinimumSize)
+ if err := c.ep.WritePacket(r, hdr, buffer.VectorisedView{}, proto); err != nil {
+ t.Fatalf("WritePacket failed: %v", err)
+ }
+
+ // Read from the FD, then compare with what we wrote.
+ b := make([]byte, mtu)
+ n, err := syscall.Read(c.fds[0], b)
+ if err != nil {
+ t.Fatalf("Read failed: %v", err)
+ }
+ b = b[:n]
+ h := header.Ethernet(b)
+
+ if a := h.SourceAddress(); a != baddr {
+ t.Fatalf("SourceAddress() = %v, want %v", a, baddr)
+ }
+}
+
+func TestDeliverPacket(t *testing.T) {
+ lengths := []int{100, 1000}
+ eths := []bool{true, false}
+
+ for _, eth := range eths {
+ for _, plen := range lengths {
+ t.Run(fmt.Sprintf("Eth=%v,PayloadLen=%v", eth, plen), func(t *testing.T) {
+ c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: eth})
+ defer c.cleanup()
+
+ // Build packet.
+ b := make([]byte, plen)
+ all := b
+ for i := range b {
+ b[i] = uint8(rand.Intn(256))
+ }
+
+ if !eth {
+ // So that it looks like an IPv4 packet.
+ b[0] = 0x40
+ } else {
+ hdr := make(header.Ethernet, header.EthernetMinimumSize)
+ hdr.Encode(&header.EthernetFields{
+ SrcAddr: raddr,
+ DstAddr: laddr,
+ Type: proto,
+ })
+ all = append(hdr, b...)
+ }
+
+ // Write packet via the file descriptor.
+ if _, err := syscall.Write(c.fds[0], all); err != nil {
+ t.Fatalf("Write failed: %v", err)
+ }
+
+ // Receive packet through the endpoint.
+ select {
+ case pi := <-c.ch:
+ want := packetInfo{
+ raddr: raddr,
+ proto: proto,
+ contents: b,
+ }
+ if !eth {
+ want.proto = header.IPv4ProtocolNumber
+ want.raddr = ""
+ }
+ if !reflect.DeepEqual(want, pi) {
+ t.Fatalf("Unexpected received packet: %+v, want %+v", pi, want)
+ }
+ case <-time.After(10 * time.Second):
+ t.Fatalf("Timed out waiting for packet")
+ }
+ })
+ }
+ }
+}
+
func TestBufConfigMaxLength(t *testing.T) {
got := 0
for _, i := range BufConfig {
diff --git a/tcpip/link/loopback/loopback.go b/tcpip/link/loopback/loopback.go
index 20e6f9a..a4f3ff8 100644
--- a/tcpip/link/loopback/loopback.go
+++ b/tcpip/link/loopback/loopback.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
// Package loopback provides the implemention of loopback data-link layer
// endpoints. Such endpoints just turn outbound packets into inbound ones.
@@ -32,12 +42,23 @@
e.dispatcher = dispatcher
}
+// IsAttached implements stack.LinkEndpoint.IsAttached.
+func (e *endpoint) IsAttached() bool {
+ return e.dispatcher != nil
+}
+
// MTU implements stack.LinkEndpoint.MTU. It returns a constant that matches the
// linux loopback interface.
func (*endpoint) MTU() uint32 {
return 65536
}
+// Capabilities implements stack.LinkEndpoint.Capabilities. Loopback advertises
+// itself as supporting checksum offload, but in reality checksums are simply omitted.
+func (*endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return stack.CapabilityChecksumOffload | stack.CapabilitySaveRestore | stack.CapabilityLoopback
+}
+
// MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength. Given that the
// loopback interface doesn't have a header, it just returns 0.
func (*endpoint) MaxHeaderLength() uint16 {
@@ -51,18 +72,16 @@
// WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound
// packets to the network-layer dispatcher.
-func (e *endpoint) WritePacket(_ *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
- if len(payload) == 0 {
- // We don't have a payload, so just use the buffer from the
- // header as the full packet.
- v := hdr.View()
- vv := v.ToVectorisedView([1]buffer.View{})
- e.dispatcher.DeliverNetworkPacket(e, "", "", protocol, &vv)
- } else {
- views := []buffer.View{hdr.View(), payload}
- vv := buffer.NewVectorisedView(len(views[0])+len(views[1]), views)
- e.dispatcher.DeliverNetworkPacket(e, "", "", protocol, &vv)
- }
+func (e *endpoint) WritePacket(_ *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
+ views := make([]buffer.View, 1, 1+len(payload.Views()))
+ views[0] = hdr.View()
+ views = append(views, payload.Views()...)
+ vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views)
+
+ // Because we're immediately turning around and writing the packet back to the
+ // rx path, we intentionally don't preserve the remote and local link
+ // addresses from the stack.Route we're passed.
+ e.dispatcher.DeliverNetworkPacket(e, "" /* remoteLinkAddr */, "" /* localLinkAddr */, protocol, vv)
return nil
}
diff --git a/tcpip/link/rawfile/blockingpoll_amd64.s b/tcpip/link/rawfile/blockingpoll_amd64.s
index 88206fc..fc52318 100644
--- a/tcpip/link/rawfile/blockingpoll_amd64.s
+++ b/tcpip/link/rawfile/blockingpoll_amd64.s
@@ -1,10 +1,24 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
#include "textflag.h"
// blockingPoll makes the poll() syscall while calling the version of
// entersyscall that relinquishes the P so that other Gs can run. This is meant
// to be called in cases when the syscall is expected to block.
//
-// func blockingPoll(fds unsafe.Pointer, nfds int, timeout int64) (n int, err syscall.Errno)
+// func blockingPoll(fds *pollEvent, nfds int, timeout int64) (n int, err syscall.Errno)
TEXT ·blockingPoll(SB),NOSPLIT,$0-40
CALL runtime·entersyscallblock(SB)
MOVQ fds+0(FP), DI
diff --git a/tcpip/link/rawfile/blockingpoll_unsafe.go b/tcpip/link/rawfile/blockingpoll_unsafe.go
new file mode 100644
index 0000000..a0a9d4a
--- /dev/null
+++ b/tcpip/link/rawfile/blockingpoll_unsafe.go
@@ -0,0 +1,27 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux,!amd64
+
+package rawfile
+
+import (
+ "syscall"
+ "unsafe"
+)
+
+func blockingPoll(fds *pollEvent, nfds int, timeout int64) (int, syscall.Errno) {
+ n, _, e := syscall.Syscall(syscall.SYS_POLL, uintptr(unsafe.Pointer(fds)), uintptr(nfds), uintptr(timeout))
+ return int(n), e
+}
diff --git a/tcpip/link/rawfile/blockingpoll_unsafe_amd64.go b/tcpip/link/rawfile/blockingpoll_unsafe_amd64.go
new file mode 100644
index 0000000..1f143c0
--- /dev/null
+++ b/tcpip/link/rawfile/blockingpoll_unsafe_amd64.go
@@ -0,0 +1,24 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux,amd64
+
+package rawfile
+
+import (
+ "syscall"
+)
+
+//go:noescape
+func blockingPoll(fds *pollEvent, nfds int, timeout int64) (int, syscall.Errno)
diff --git a/tcpip/link/rawfile/errors.go b/tcpip/link/rawfile/errors.go
index d470676..e8ca48e 100644
--- a/tcpip/link/rawfile/errors.go
+++ b/tcpip/link/rawfile/errors.go
@@ -1,3 +1,19 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
package rawfile
import (
diff --git a/tcpip/link/rawfile/rawfile_unsafe.go b/tcpip/link/rawfile/rawfile_unsafe.go
index aea4569..968c93d 100644
--- a/tcpip/link/rawfile/rawfile_unsafe.go
+++ b/tcpip/link/rawfile/rawfile_unsafe.go
@@ -1,6 +1,18 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
// Package rawfile contains utilities for using the netstack with raw host
// files on Linux hosts.
@@ -13,12 +25,6 @@
"github.com/google/netstack/tcpip"
)
-// TODO: Placed here to avoid breakage caused by coverage
-// instrumentation. Any, even unrelated, changes to this file should ensure
-// that coverage still work. See bug for details.
-//go:noescape
-func blockingPoll(fds unsafe.Pointer, nfds int, timeout int64) (n int, err syscall.Errno)
-
// GetMTU determines the MTU of a network interface device.
func GetMTU(name string) (uint32, error) {
fd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_DGRAM, 0)
@@ -71,16 +77,16 @@
// a writev syscall.
iovec := [...]syscall.Iovec{
{
- Base: (*byte)(unsafe.Pointer(&b1[0])),
+ Base: &b1[0],
Len: uint64(len(b1)),
},
{
- Base: (*byte)(unsafe.Pointer(&b2[0])),
+ Base: &b2[0],
Len: uint64(len(b2)),
},
}
- _, _, e := syscall.RawSyscall(syscall.SYS_WRITEV, uintptr(fd), uintptr(unsafe.Pointer(&iovec[0])), 2)
+ _, _, e := syscall.RawSyscall(syscall.SYS_WRITEV, uintptr(fd), uintptr(unsafe.Pointer(&iovec[0])), uintptr(len(iovec)))
if e != 0 {
return TranslateErrno(e)
}
@@ -88,6 +94,12 @@
return nil
}
+type pollEvent struct {
+ fd int32
+ events int16
+ revents int16
+}
+
// BlockingRead reads from a file descriptor that is set up as non-blocking. If
// no data is available, it will block in a poll() syscall until the file
// descirptor becomes readable.
@@ -98,16 +110,12 @@
return int(n), nil
}
- event := struct {
- fd int32
- events int16
- revents int16
- }{
+ event := pollEvent{
fd: int32(fd),
events: 1, // POLLIN
}
- _, e = blockingPoll(unsafe.Pointer(&event), 1, -1)
+ _, e = blockingPoll(&event, 1, -1)
if e != 0 && e != syscall.EINTR {
return 0, TranslateErrno(e)
}
@@ -124,16 +132,12 @@
return int(n), nil
}
- event := struct {
- fd int32
- events int16
- revents int16
- }{
+ event := pollEvent{
fd: int32(fd),
events: 1, // POLLIN
}
- _, e = blockingPoll(unsafe.Pointer(&event), 1, -1)
+ _, e = blockingPoll(&event, 1, -1)
if e != 0 && e != syscall.EINTR {
return 0, TranslateErrno(e)
}
diff --git a/tcpip/link/sharedmem/pipe/pipe.go b/tcpip/link/sharedmem/pipe/pipe.go
index 1173a60..1a0edba 100644
--- a/tcpip/link/sharedmem/pipe/pipe.go
+++ b/tcpip/link/sharedmem/pipe/pipe.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
// Package pipe implements a shared memory ring buffer on which a single reader
// and a single writer can operate (read/write) concurrently. The ring buffer
diff --git a/tcpip/link/sharedmem/pipe/pipe_test.go b/tcpip/link/sharedmem/pipe/pipe_test.go
index d35e7c9..db0737c 100644
--- a/tcpip/link/sharedmem/pipe/pipe_test.go
+++ b/tcpip/link/sharedmem/pipe/pipe_test.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package pipe
@@ -463,10 +473,8 @@
go func() {
defer wg.Done()
runtime.Gosched()
- total := 0
for i := 0; i < count; i++ {
n := 1 + tr.Intn(80)
- total += n
wb := tx.Push(uint64(n))
for wb == nil {
wb = tx.Push(uint64(n))
diff --git a/tcpip/link/sharedmem/pipe/pipe_unsafe.go b/tcpip/link/sharedmem/pipe/pipe_unsafe.go
index d536abe..480dc4a 100644
--- a/tcpip/link/sharedmem/pipe/pipe_unsafe.go
+++ b/tcpip/link/sharedmem/pipe/pipe_unsafe.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package pipe
diff --git a/tcpip/link/sharedmem/pipe/rx.go b/tcpip/link/sharedmem/pipe/rx.go
index 261e21f..ff778ce 100644
--- a/tcpip/link/sharedmem/pipe/rx.go
+++ b/tcpip/link/sharedmem/pipe/rx.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package pipe
diff --git a/tcpip/link/sharedmem/pipe/tx.go b/tcpip/link/sharedmem/pipe/tx.go
index 374f515..717f5a4 100644
--- a/tcpip/link/sharedmem/pipe/tx.go
+++ b/tcpip/link/sharedmem/pipe/tx.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package pipe
diff --git a/tcpip/link/sharedmem/queue/queue_test.go b/tcpip/link/sharedmem/queue/queue_test.go
index 28f527b..3c14c3d 100644
--- a/tcpip/link/sharedmem/queue/queue_test.go
+++ b/tcpip/link/sharedmem/queue/queue_test.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package queue
diff --git a/tcpip/link/sharedmem/queue/rx.go b/tcpip/link/sharedmem/queue/rx.go
index 2827602..2c2cc7e 100644
--- a/tcpip/link/sharedmem/queue/rx.go
+++ b/tcpip/link/sharedmem/queue/rx.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
// Package queue provides the implementation of transmit and receive queues
// based on shared memory ring buffers.
diff --git a/tcpip/link/sharedmem/queue/tx.go b/tcpip/link/sharedmem/queue/tx.go
index aa26515..daf1469 100644
--- a/tcpip/link/sharedmem/queue/tx.go
+++ b/tcpip/link/sharedmem/queue/tx.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package queue
diff --git a/tcpip/link/sharedmem/rx.go b/tcpip/link/sharedmem/rx.go
index 29e40e1..465af69 100644
--- a/tcpip/link/sharedmem/rx.go
+++ b/tcpip/link/sharedmem/rx.go
@@ -1,6 +1,18 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
package sharedmem
diff --git a/tcpip/link/sharedmem/sharedmem.go b/tcpip/link/sharedmem/sharedmem.go
index 962a449..5a2a0d2 100644
--- a/tcpip/link/sharedmem/sharedmem.go
+++ b/tcpip/link/sharedmem/sharedmem.go
@@ -1,6 +1,18 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
// Package sharedmem provides the implemention of data-link layer endpoints
// backed by shared memory.
@@ -132,17 +144,32 @@
if !e.workerStarted && atomic.LoadUint32(&e.stopRequested) == 0 {
e.workerStarted = true
e.completed.Add(1)
+ // Link endpoints are not savable. When transport endpoints
+ // are saved, they stop sending outgoing packets and all
+ // incoming packets are rejected.
go e.dispatchLoop(dispatcher)
}
e.mu.Unlock()
}
+// IsAttached implements stack.LinkEndpoint.IsAttached.
+func (e *endpoint) IsAttached() bool {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.workerStarted
+}
+
// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized
// during construction.
func (e *endpoint) MTU() uint32 {
return e.mtu - header.EthernetMinimumSize
}
+// Capabilities implements stack.LinkEndpoint.Capabilities.
+func (*endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return 0
+}
+
// MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength. It returns the
// ethernet frame header size.
func (*endpoint) MaxHeaderLength() uint16 {
@@ -157,18 +184,24 @@
// WritePacket writes outbound packets to the file descriptor. If it is not
// currently writable, the packet is dropped.
-func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
+func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
// Add the ethernet header here.
eth := header.Ethernet(hdr.Prepend(header.EthernetMinimumSize))
- eth.Encode(&header.EthernetFields{
+ ethHdr := &header.EthernetFields{
DstAddr: r.RemoteLinkAddress,
- SrcAddr: e.addr,
Type: protocol,
- })
+ }
+ if r.LocalLinkAddress != "" {
+ ethHdr.SrcAddr = r.LocalLinkAddress
+ } else {
+ ethHdr.SrcAddr = e.addr
+ }
+ eth.Encode(ethHdr)
+ v := payload.ToView()
// Transmit the packet.
e.mu.Lock()
- ok := e.tx.transmit(hdr.UsedBytes(), payload)
+ ok := e.tx.transmit(hdr.View(), v)
e.mu.Unlock()
if !ok {
@@ -199,8 +232,6 @@
// Read in a loop until a stop is requested.
var rxb []queue.RxBuffer
- views := []buffer.View{nil}
- vv := buffer.NewVectorisedView(0, views)
for atomic.LoadUint32(&e.stopRequested) == 0 {
var n uint32
rxb, n = e.rx.postAndReceive(rxb, &e.stopRequested)
@@ -222,9 +253,7 @@
// Send packet up the stack.
eth := header.Ethernet(b)
- views[0] = b[header.EthernetMinimumSize:]
- vv.SetSize(int(n) - header.EthernetMinimumSize)
- d.DeliverNetworkPacket(e, e.addr, eth.SourceAddress(), eth.Type(), &vv)
+ d.DeliverNetworkPacket(e, eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), buffer.View(b[header.EthernetMinimumSize:]).ToVectorisedView())
}
// Clean state.
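The new `WritePacket` above prefers the route's `LocalLinkAddress` as the Ethernet source address and falls back to the endpoint's own address when the route leaves it empty. The following is a minimal standalone sketch of that rule plus the 14-byte Ethernet II layout (destination MAC, source MAC, big-endian EtherType). The helpers `chooseSrc` and `encodeEthernet` are hypothetical and are not netstack APIs; addresses are raw 6-byte strings, which is how `tcpip.LinkAddress` values are represented.

```go
package main

import (
	"encoding/binary"
	"fmt"
)

const (
	macLen        = 6
	ethHeaderSize = 14 // dst MAC (6) + src MAC (6) + EtherType (2)
)

// chooseSrc mirrors the fallback in WritePacket: use the route's local
// link address if it is set, otherwise the endpoint's own address.
func chooseSrc(routeLocal, endpointAddr string) string {
	if routeLocal != "" {
		return routeLocal
	}
	return endpointAddr
}

// encodeEthernet builds an Ethernet II header from raw 6-byte addresses.
func encodeEthernet(dst, src string, etherType uint16) []byte {
	b := make([]byte, ethHeaderSize)
	copy(b[0:macLen], dst)
	copy(b[macLen:2*macLen], src)
	binary.BigEndian.PutUint16(b[2*macLen:], etherType)
	return b
}

func main() {
	dst := "\x02\x03\x03\x04\x05\x06"
	epAddr := "\x02\x02\x03\x04\x05\x06"
	routeLocal := "\xfe\xfe\xfe\xfe\xfe\xfe"

	h := encodeEthernet(dst, chooseSrc(routeLocal, epAddr), 0x0800)
	fmt.Printf("%x\n", h) // 020303040506fefefefefefe0800
}
```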
diff --git a/tcpip/link/sharedmem/sharedmem_test.go b/tcpip/link/sharedmem/sharedmem_test.go
index b9650e7..4efdad7 100644
--- a/tcpip/link/sharedmem/sharedmem_test.go
+++ b/tcpip/link/sharedmem/sharedmem_test.go
@@ -1,14 +1,27 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
package sharedmem
import (
+ "bytes"
"io/ioutil"
"math/rand"
"os"
- "reflect"
+ "strings"
"sync"
"syscall"
"testing"
@@ -117,10 +130,10 @@
return c
}
-func (c *testContext) DeliverNetworkPacket(_ stack.LinkEndpoint, _, remoteAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, vv *buffer.VectorisedView) {
+func (c *testContext) DeliverNetworkPacket(_ stack.LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
c.mu.Lock()
c.packets = append(c.packets, packetInfo{
- addr: remoteAddr,
+ addr: remoteLinkAddr,
proto: proto,
vv: vv.Clone(nil),
})
@@ -247,63 +260,120 @@
}
for iters := 1000; iters > 0; iters-- {
- // Prepare and send packet.
- n := rand.Intn(10000)
- hdr := buffer.NewPrependable(n + int(c.ep.MaxHeaderLength()))
- hdrBuf := hdr.Prepend(n)
- randomFill(hdrBuf)
+ func() {
+ // Prepare and send packet.
+ n := rand.Intn(10000)
+ hdr := buffer.NewPrependable(n + int(c.ep.MaxHeaderLength()))
+ hdrBuf := hdr.Prepend(n)
+ randomFill(hdrBuf)
- n = rand.Intn(10000)
- buf := buffer.NewView(n)
- randomFill(buf)
+ n = rand.Intn(10000)
+ buf := buffer.NewView(n)
+ randomFill(buf)
- proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000))
- err := c.ep.WritePacket(&r, &hdr, buf, proto)
- if err != nil {
- t.Fatalf("WritePacket failed: %v", err)
- }
+ proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000))
+ if err := c.ep.WritePacket(&r, hdr, buf.ToVectorisedView(), proto); err != nil {
+ t.Fatalf("WritePacket failed: %v", err)
+ }
- // Receive packet.
- desc := c.txq.tx.Pull()
- pi := queue.DecodeTxPacketHeader(desc)
- contents := make([]byte, 0, pi.Size)
- for i := 0; i < pi.BufferCount; i++ {
- bi := queue.DecodeTxBufferHeader(desc, i)
- contents = append(contents, c.txq.data[bi.Offset:][:bi.Size]...)
- }
- c.txq.tx.Flush()
+ // Receive packet.
+ desc := c.txq.tx.Pull()
+ pi := queue.DecodeTxPacketHeader(desc)
+ if pi.Reserved != 0 {
+ t.Fatalf("Reserved value is non-zero: 0x%x", pi.Reserved)
+ }
+ contents := make([]byte, 0, pi.Size)
+ for i := 0; i < pi.BufferCount; i++ {
+ bi := queue.DecodeTxBufferHeader(desc, i)
+ contents = append(contents, c.txq.data[bi.Offset:][:bi.Size]...)
+ }
+ c.txq.tx.Flush()
- if pi.Reserved != 0 {
- t.Fatalf("Reserved value is non-zero: 0x%x", pi.Reserved)
- }
+ defer func() {
+ // Tell the endpoint about the completion of the write.
+ b := c.txq.rx.Push(8)
+ queue.EncodeTxCompletion(b, pi.ID)
+ c.txq.rx.Flush()
+ }()
- // Check the thernet header.
- ethTemplate := make(header.Ethernet, header.EthernetMinimumSize)
- ethTemplate.Encode(&header.EthernetFields{
- SrcAddr: localLinkAddr,
- DstAddr: remoteLinkAddr,
- Type: proto,
- })
- if got := contents[:header.EthernetMinimumSize]; !reflect.DeepEqual(got, []byte(ethTemplate)) {
- t.Fatalf("Bad ethernet header in packet: got %x, want %x", got, ethTemplate)
- }
+ // Check the ethernet header.
+ ethTemplate := make(header.Ethernet, header.EthernetMinimumSize)
+ ethTemplate.Encode(&header.EthernetFields{
+ SrcAddr: localLinkAddr,
+ DstAddr: remoteLinkAddr,
+ Type: proto,
+ })
+ if got := contents[:header.EthernetMinimumSize]; !bytes.Equal(got, []byte(ethTemplate)) {
+ t.Fatalf("Bad ethernet header in packet: got %x, want %x", got, ethTemplate)
+ }
- // Compare contents skipping the ethernet header added by the
- // endpoint.
- merged := append(hdrBuf, buf...)
- if uint32(len(contents)) < pi.Size {
- t.Fatalf("Sum of buffers is less than packet size: %v < %v", len(contents), pi.Size)
- }
- contents = contents[:pi.Size][header.EthernetMinimumSize:]
+ // Compare contents skipping the ethernet header added by the
+ // endpoint.
+ merged := append(hdrBuf, buf...)
+ if uint32(len(contents)) < pi.Size {
+ t.Fatalf("Sum of buffers is less than packet size: %v < %v", len(contents), pi.Size)
+ }
+ contents = contents[:pi.Size][header.EthernetMinimumSize:]
- if !reflect.DeepEqual(contents, merged) {
- t.Fatalf("Buffers are different: got %x (%v bytes), want %x (%v bytes)", contents, len(contents), merged, len(merged))
- }
+ if !bytes.Equal(contents, merged) {
+ t.Fatalf("Buffers are different: got %x (%v bytes), want %x (%v bytes)", contents, len(contents), merged, len(merged))
+ }
+ }()
+ }
+}
+// TestPreserveSrcAddressInSend calls WritePacket once with LocalLinkAddress
+// set in Route (using much of the same code as TestSimpleSend), then checks
+// that the encoded ethernet header received includes the correct SrcAddr.
+func TestPreserveSrcAddressInSend(t *testing.T) {
+ c := newTestContext(t, 20000, 1500, localLinkAddr)
+ defer c.cleanup()
+
+ newLocalLinkAddress := tcpip.LinkAddress(strings.Repeat("\xfe", 6))
+ // Set both remote and local link address in route.
+ r := stack.Route{
+ RemoteLinkAddress: remoteLinkAddr,
+ LocalLinkAddress: newLocalLinkAddress,
+ }
+
+ // WritePacket panics given a prependable with anything less than
+ // the minimum size of the ethernet header.
+ hdr := buffer.NewPrependable(header.EthernetMinimumSize)
+
+ proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000))
+ if err := c.ep.WritePacket(&r, hdr, buffer.VectorisedView{}, proto); err != nil {
+ t.Fatalf("WritePacket failed: %v", err)
+ }
+
+ // Receive packet.
+ desc := c.txq.tx.Pull()
+ pi := queue.DecodeTxPacketHeader(desc)
+ if pi.Reserved != 0 {
+ t.Fatalf("Reserved value is non-zero: 0x%x", pi.Reserved)
+ }
+ contents := make([]byte, 0, pi.Size)
+ for i := 0; i < pi.BufferCount; i++ {
+ bi := queue.DecodeTxBufferHeader(desc, i)
+ contents = append(contents, c.txq.data[bi.Offset:][:bi.Size]...)
+ }
+ c.txq.tx.Flush()
+
+ defer func() {
// Tell the endpoint about the completion of the write.
b := c.txq.rx.Push(8)
queue.EncodeTxCompletion(b, pi.ID)
c.txq.rx.Flush()
+ }()
+
+ // Check that the ethernet header contains the expected SrcAddr.
+ ethTemplate := make(header.Ethernet, header.EthernetMinimumSize)
+ ethTemplate.Encode(&header.EthernetFields{
+ SrcAddr: newLocalLinkAddress,
+ DstAddr: remoteLinkAddr,
+ Type: proto,
+ })
+ if got := contents[:header.EthernetMinimumSize]; !bytes.Equal(got, []byte(ethTemplate)) {
+ t.Fatalf("Bad ethernet header in packet: got %x, want %x", got, ethTemplate)
}
}
@@ -324,7 +394,8 @@
ids := make(map[uint64]struct{})
for i := queuePipeSize / 40; i > 0; i-- {
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if err := c.ep.WritePacket(&r, &hdr, buf, header.IPv4ProtocolNumber); err != nil {
+
+ if err := c.ep.WritePacket(&r, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil {
t.Fatalf("WritePacket failed unexpectedly: %v", err)
}
@@ -339,8 +410,7 @@
// Next attempt to write must fail.
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- err := c.ep.WritePacket(&r, &hdr, buf, header.IPv4ProtocolNumber)
- if want := tcpip.ErrWouldBlock; err != want {
+ if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != want {
t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want)
}
}
@@ -365,7 +435,7 @@
// Send two packets so that the id slice has at least two slots.
for i := 2; i > 0; i-- {
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if err := c.ep.WritePacket(&r, &hdr, buf, header.IPv4ProtocolNumber); err != nil {
+ if err := c.ep.WritePacket(&r, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil {
t.Fatalf("WritePacket failed unexpectedly: %v", err)
}
}
@@ -385,7 +455,7 @@
ids := make(map[uint64]struct{})
for i := queuePipeSize / 40; i > 0; i-- {
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if err := c.ep.WritePacket(&r, &hdr, buf, header.IPv4ProtocolNumber); err != nil {
+ if err := c.ep.WritePacket(&r, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil {
t.Fatalf("WritePacket failed unexpectedly: %v", err)
}
@@ -400,8 +470,7 @@
// Next attempt to write must fail.
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- err := c.ep.WritePacket(&r, &hdr, buf, header.IPv4ProtocolNumber)
- if want := tcpip.ErrWouldBlock; err != want {
+ if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != want {
t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want)
}
}
@@ -424,7 +493,7 @@
ids := make(map[uint64]struct{})
for i := queueDataSize / bufferSize; i > 0; i-- {
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if err := c.ep.WritePacket(&r, &hdr, buf, header.IPv4ProtocolNumber); err != nil {
+ if err := c.ep.WritePacket(&r, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil {
t.Fatalf("WritePacket failed unexpectedly: %v", err)
}
@@ -440,7 +509,7 @@
// Next attempt to write must fail.
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- err := c.ep.WritePacket(&r, &hdr, buf, header.IPv4ProtocolNumber)
+ err := c.ep.WritePacket(&r, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber)
if want := tcpip.ErrWouldBlock; err != want {
t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want)
}
@@ -465,7 +534,7 @@
// until there is only one buffer left.
for i := queueDataSize/bufferSize - 1; i > 0; i-- {
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if err := c.ep.WritePacket(&r, &hdr, buf, header.IPv4ProtocolNumber); err != nil {
+ if err := c.ep.WritePacket(&r, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil {
t.Fatalf("WritePacket failed unexpectedly: %v", err)
}
@@ -475,20 +544,26 @@
}
// Attempt to write a two-buffer packet. It must fail.
- hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- err := c.ep.WritePacket(&r, &hdr, buffer.NewView(bufferSize), header.IPv4ProtocolNumber)
- if want := tcpip.ErrWouldBlock; err != want {
- t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want)
+ {
+ hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
+ uu := buffer.NewView(bufferSize).ToVectorisedView()
+ if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, hdr, uu, header.IPv4ProtocolNumber); err != want {
+ t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want)
+ }
}
- // Attempt to write a one-buffer packet. It must succeed.
- hdr = buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if err := c.ep.WritePacket(&r, &hdr, buf, header.IPv4ProtocolNumber); err != nil {
- t.Fatalf("WritePacket failed unexpectedly: %v", err)
+ // Attempt to write the one-buffer packet again. It must succeed.
+ {
+ hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
+ if err := c.ep.WritePacket(&r, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil {
+ t.Fatalf("WritePacket failed unexpectedly: %v", err)
+ }
}
}
func pollPull(t *testing.T, p *pipe.Rx, to <-chan time.Time, errStr string) []byte {
+ t.Helper()
+
for {
b := p.Pull()
if b != nil {
@@ -498,7 +573,7 @@
select {
case <-time.After(10 * time.Millisecond):
case <-to:
- t.Fatalf(errStr)
+ t.Fatal(errStr)
}
}
}
@@ -513,8 +588,8 @@
// Check that buffers have been posted.
limit := c.ep.rx.q.PostedBuffersLimit()
- timeout := time.After(2 * time.Second)
for i := uint64(0); i < limit; i++ {
+ timeout := time.After(2 * time.Second)
bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for all buffers to be posted"))
if want := i * bufferSize; want != bi.Offset {
@@ -539,6 +614,7 @@
// Complete random packets 1000 times.
for iters := 1000; iters > 0; iters-- {
+ timeout := time.After(2 * time.Second)
// Prepare a random packet.
shuffle(idx)
n := 1 + rand.Intn(10)
@@ -566,15 +642,14 @@
c.packets = c.packets[:0]
c.mu.Unlock()
- contents = contents[header.EthernetMinimumSize:]
- if !reflect.DeepEqual(contents, rcvd) {
+ if contents := contents[header.EthernetMinimumSize:]; !bytes.Equal(contents, rcvd) {
t.Fatalf("Unexpected buffer contents: got %x, want %x", rcvd, contents)
}
// Check that buffers have been reposted.
for i := range bufs {
bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for buffers to be reposted"))
- if !reflect.DeepEqual(bi, bufs[i]) {
+ if bi != bufs[i] {
t.Fatalf("Unexpected buffer reposted: got %x, want %x", bi, bufs[i])
}
}
@@ -592,15 +667,15 @@
// Receive all posted buffers.
limit := c.ep.rx.q.PostedBuffersLimit()
buffers := make([]queue.RxBuffer, 0, limit)
- timeout := time.After(2 * time.Second)
for i := limit; i > 0; i-- {
+ timeout := time.After(2 * time.Second)
buffers = append(buffers, queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for all buffers")))
}
c.rxq.tx.Flush()
// Check that all buffers are reposted when individually completed.
- timeout = time.After(2 * time.Second)
for i := range buffers {
+ timeout := time.After(2 * time.Second)
// Complete the buffer.
c.pushRxCompletion(buffers[i].Size, buffers[i:][:1])
c.rxq.rx.Flush()
@@ -608,28 +683,26 @@
// Wait for it to be reposted.
bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for buffer to be reposted"))
- if !reflect.DeepEqual(bi, buffers[i]) {
+ if bi != buffers[i] {
t.Fatalf("Different buffer posted: got %v, want %v", bi, buffers[i])
}
}
c.rxq.tx.Flush()
// Check that all buffers are reposted when completed in pairs.
- timeout = time.After(2 * time.Second)
for i := 0; i < len(buffers)/2; i++ {
+ timeout := time.After(2 * time.Second)
// Complete with two buffers.
c.pushRxCompletion(2*bufferSize, buffers[2*i:][:2])
c.rxq.rx.Flush()
syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
// Wait for them to be reposted.
- bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for buffer to be reposted"))
- if !reflect.DeepEqual(bi, buffers[2*i]) {
- t.Fatalf("Different buffer posted: got %v, want %v", bi, buffers[2*i])
- }
- bi = queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for buffer to be reposted"))
- if !reflect.DeepEqual(bi, buffers[2*i+1]) {
- t.Fatalf("Different buffer posted: got %v, want %v", bi, buffers[2*i+1])
+ for j := 0; j < 2; j++ {
+ bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for buffer to be reposted"))
+ if bi != buffers[2*i+j] {
+ t.Fatalf("Different buffer posted: got %v, want %v", bi, buffers[2*i+j])
+ }
}
}
c.rxq.tx.Flush()
diff --git a/tcpip/link/sharedmem/sharedmem_unsafe.go b/tcpip/link/sharedmem/sharedmem_unsafe.go
index 52f93f4..f0be2dc 100644
--- a/tcpip/link/sharedmem/sharedmem_unsafe.go
+++ b/tcpip/link/sharedmem/sharedmem_unsafe.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package sharedmem
diff --git a/tcpip/link/sharedmem/tx.go b/tcpip/link/sharedmem/tx.go
index a2a5576..472899b 100644
--- a/tcpip/link/sharedmem/tx.go
+++ b/tcpip/link/sharedmem/tx.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package sharedmem
diff --git a/tcpip/link/sniffer/pcap.go b/tcpip/link/sniffer/pcap.go
new file mode 100644
index 0000000..04f3d49
--- /dev/null
+++ b/tcpip/link/sniffer/pcap.go
@@ -0,0 +1,66 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sniffer
+
+import "time"
+
+type pcapHeader struct {
+ // MagicNumber is the file magic number.
+ MagicNumber uint32
+
+ // VersionMajor is the major version number.
+ VersionMajor uint16
+
+ // VersionMinor is the minor version number.
+ VersionMinor uint16
+
+ // Thiszone is the GMT to local correction.
+ Thiszone int32
+
+ // Sigfigs is the accuracy of timestamps.
+ Sigfigs uint32
+
+ // Snaplen is the max length of captured packets, in octets.
+ Snaplen uint32
+
+ // Network is the data link type.
+ Network uint32
+}
+
+const pcapPacketHeaderLen = 16
+
+type pcapPacketHeader struct {
+ // Seconds is the timestamp seconds.
+ Seconds uint32
+
+ // Microseconds is the timestamp microseconds.
+ Microseconds uint32
+
+ // IncludedLength is the number of octets of the packet saved in the file.
+ IncludedLength uint32
+
+ // OriginalLength is the actual length of the packet.
+ OriginalLength uint32
+}
+
+func newPCAPPacketHeader(incLen, orgLen uint32) pcapPacketHeader {
+ now := time.Now()
+ return pcapPacketHeader{
+ Seconds: uint32(now.Unix()),
+ Microseconds: uint32(now.Nanosecond() / 1000),
+ IncludedLength: incLen,
+ OriginalLength: orgLen,
+ }
+}
diff --git a/tcpip/link/sniffer/sniffer.go b/tcpip/link/sniffer/sniffer.go
index e710582..12f445c 100644
--- a/tcpip/link/sniffer/sniffer.go
+++ b/tcpip/link/sniffer/sniffer.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
// Package sniffer provides the implementation of data-link layer endpoints that
// wrap another endpoint and logs inbound and outbound packets.
@@ -11,24 +21,39 @@
package sniffer
import (
+ "bytes"
+ "encoding/binary"
"fmt"
+ "io"
+ "os"
"sync/atomic"
-
- "log"
+ "time"
"github.com/google/netstack/tcpip"
"github.com/google/netstack/tcpip/buffer"
"github.com/google/netstack/tcpip/header"
"github.com/google/netstack/tcpip/stack"
+ "log"
)
-// LogPackets is a flag used to enable or disable packet valid values
-// are 0 or 1.
+// LogPackets is a flag used to enable or disable packet logging via the log
+// package. Valid values are 0 or 1.
+//
+// LogPackets must be accessed atomically.
var LogPackets uint32 = 1
+// LogPacketsToFile is a flag used to enable or disable logging packets to a
+// pcap file. Valid values are 0 or 1. A file must have been specified when the
+// sniffer was created for this flag to have effect.
+//
+// LogPacketsToFile must be accessed atomically.
+var LogPacketsToFile uint32 = 1
+
type endpoint struct {
dispatcher stack.NetworkDispatcher
lower stack.LinkEndpoint
+ file *os.File
+ maxPCAPLen uint32
}
// New creates a new sniffer link-layer endpoint. It wraps around another
@@ -39,14 +64,90 @@
})
}
+func zoneOffset() (int32, error) {
+ loc, err := time.LoadLocation("Local")
+ if err != nil {
+ return 0, err
+ }
+ date := time.Date(0, 0, 0, 0, 0, 0, 0, loc)
+ _, offset := date.Zone()
+ return int32(offset), nil
+}
+
+func writePCAPHeader(w io.Writer, maxLen uint32) error {
+ offset, err := zoneOffset()
+ if err != nil {
+ return err
+ }
+ return binary.Write(w, binary.BigEndian, pcapHeader{
+ // From https://wiki.wireshark.org/Development/LibpcapFileFormat
+ MagicNumber: 0xa1b2c3d4,
+
+ VersionMajor: 2,
+ VersionMinor: 4,
+ Thiszone: offset,
+ Sigfigs: 0,
+ Snaplen: maxLen,
+ Network: 101, // LINKTYPE_RAW
+ })
+}
+
+// NewWithFile creates a new sniffer link-layer endpoint. It wraps around
+// another endpoint and logs packets as they traverse the endpoint.
+//
+// Packets can be logged to a file in the pcap format. A sniffer created
+// with this function will not emit packets using the standard log
+// package.
+//
+// snapLen is the maximum amount of a packet to be saved. Packets with a length
+// less than or equal to snapLen will be saved in their entirety. Longer
+// packets will be truncated to snapLen.
+func NewWithFile(lower tcpip.LinkEndpointID, file *os.File, snapLen uint32) (tcpip.LinkEndpointID, error) {
+ if err := writePCAPHeader(file, snapLen); err != nil {
+ return 0, err
+ }
+ return stack.RegisterLinkEndpoint(&endpoint{
+ lower: stack.FindLinkEndpoint(lower),
+ file: file,
+ maxPCAPLen: snapLen,
+ }), nil
+}
+
// DeliverNetworkPacket implements the stack.NetworkDispatcher interface. It is
// called by the link-layer endpoint being wrapped when a packet arrives, and
// logs the packet before forwarding to the actual dispatcher.
-func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, dstLinkAddr, srcLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv *buffer.VectorisedView) {
- if atomic.LoadUint32(&LogPackets) == 1 {
- LogPacket("recv", protocol, vv.First(), nil)
+func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
+ if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil {
+ logPacket("recv", protocol, vv.First())
}
- e.dispatcher.DeliverNetworkPacket(e, dstLinkAddr, srcLinkAddr, protocol, vv)
+ if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 {
+ vs := vv.Views()
+ length := vv.Size()
+ if length > int(e.maxPCAPLen) {
+ length = int(e.maxPCAPLen)
+ }
+
+ buf := bytes.NewBuffer(make([]byte, 0, pcapPacketHeaderLen+length))
+ if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(vv.Size()))); err != nil {
+ panic(err)
+ }
+ for _, v := range vs {
+ if length == 0 {
+ break
+ }
+ if len(v) > length {
+ v = v[:length]
+ }
+ if _, err := buf.Write([]byte(v)); err != nil {
+ panic(err)
+ }
+ length -= len(v)
+ }
+ if _, err := e.file.Write(buf.Bytes()); err != nil {
+ panic(err)
+ }
+ }
+ e.dispatcher.DeliverNetworkPacket(e, remoteLinkAddr, localLinkAddr, protocol, vv)
}
// Attach implements the stack.LinkEndpoint interface. It saves the dispatcher
@@ -57,12 +158,23 @@
e.lower.Attach(e)
}
+// IsAttached implements stack.LinkEndpoint.IsAttached.
+func (e *endpoint) IsAttached() bool {
+ return e.dispatcher != nil
+}
+
// MTU implements stack.LinkEndpoint.MTU. It just forwards the request to the
// lower endpoint.
func (e *endpoint) MTU() uint32 {
return e.lower.MTU()
}
+// Capabilities implements stack.LinkEndpoint.Capabilities. It just forwards the
+// request to the lower endpoint.
+func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.lower.Capabilities()
+}
+
// MaxHeaderLength implements the stack.LinkEndpoint interface. It just forwards
// the request to the lower endpoint.
func (e *endpoint) MaxHeaderLength() uint16 {
@@ -76,15 +188,51 @@
// WritePacket implements the stack.LinkEndpoint interface. It is called by
// higher-level protocols to write packets; it just logs the packet and forwards
// the request to the lower endpoint.
-func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
- if atomic.LoadUint32(&LogPackets) == 1 {
- LogPacket("send", protocol, hdr.UsedBytes(), payload)
+func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
+ if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil {
+ logPacket("send", protocol, hdr.View())
+ }
+ if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 {
+ hdrBuf := hdr.View()
+ length := len(hdrBuf) + payload.Size()
+ if length > int(e.maxPCAPLen) {
+ length = int(e.maxPCAPLen)
+ }
+
+ buf := bytes.NewBuffer(make([]byte, 0, pcapPacketHeaderLen+length))
+ if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(len(hdrBuf)+payload.Size()))); err != nil {
+ panic(err)
+ }
+ if len(hdrBuf) > length {
+ hdrBuf = hdrBuf[:length]
+ }
+ if _, err := buf.Write(hdrBuf); err != nil {
+ panic(err)
+ }
+ length -= len(hdrBuf)
+ if length > 0 {
+ for _, v := range payload.Views() {
+ if len(v) > length {
+ v = v[:length]
+ }
+ n, err := buf.Write(v)
+ if err != nil {
+ panic(err)
+ }
+ length -= n
+ if length == 0 {
+ break
+ }
+ }
+ }
+ if _, err := e.file.Write(buf.Bytes()); err != nil {
+ panic(err)
+ }
}
return e.lower.WritePacket(r, hdr, payload, protocol)
}
-// LogPacket logs the given packet.
-func LogPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b, plb []byte) {
+func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.View) {
// Figure out the network layer info.
var transProto uint8
src := tcpip.Address("unknown")
@@ -161,6 +309,37 @@
log.Printf("%s %s %v -> %v %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code())
return
+ case header.ICMPv6ProtocolNumber:
+ transName = "icmp"
+ icmp := header.ICMPv6(b)
+ icmpType := "unknown"
+ switch icmp.Type() {
+ case header.ICMPv6DstUnreachable:
+ icmpType = "destination unreachable"
+ case header.ICMPv6PacketTooBig:
+ icmpType = "packet too big"
+ case header.ICMPv6TimeExceeded:
+ icmpType = "time exceeded"
+ case header.ICMPv6ParamProblem:
+ icmpType = "param problem"
+ case header.ICMPv6EchoRequest:
+ icmpType = "echo request"
+ case header.ICMPv6EchoReply:
+ icmpType = "echo reply"
+ case header.ICMPv6RouterSolicit:
+ icmpType = "router solicit"
+ case header.ICMPv6RouterAdvert:
+ icmpType = "router advert"
+ case header.ICMPv6NeighborSolicit:
+ icmpType = "neighbor solicit"
+ case header.ICMPv6NeighborAdvert:
+ icmpType = "neighbor advert"
+ case header.ICMPv6RedirectMsg:
+ icmpType = "redirect message"
+ }
+ log.Printf("%s %s %v -> %v %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code())
+ return
+
case header.UDPProtocolNumber:
transName = "udp"
udp := header.UDP(b)
@@ -173,9 +352,19 @@
case header.TCPProtocolNumber:
transName = "tcp"
tcp := header.TCP(b)
+ offset := int(tcp.DataOffset())
+ if offset < header.TCPMinimumSize {
+ details += fmt.Sprintf("invalid packet: tcp data offset too small %d", offset)
+ break
+ }
+ if offset > len(tcp) {
+ details += fmt.Sprintf("invalid packet: tcp data offset %d larger than packet buffer length %d", offset, len(tcp))
+ break
+ }
+
srcPort = tcp.SourcePort()
dstPort = tcp.DestinationPort()
- size -= uint16(tcp.DataOffset())
+ size -= uint16(offset)
// Initialize the TCP flags.
flags := tcp.Flags()
@@ -191,6 +380,7 @@
} else {
details += fmt.Sprintf(" options: %+v", tcp.ParsedOptions())
}
+
default:
log.Printf("%s %v -> %v unknown transport protocol: %d", prefix, src, dst, transProto)
return
diff --git a/tcpip/link/tun/tun_unsafe.go b/tcpip/link/tun/tun_unsafe.go
index 7dfd324..1dec419 100644
--- a/tcpip/link/tun/tun_unsafe.go
+++ b/tcpip/link/tun/tun_unsafe.go
@@ -1,7 +1,20 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// +build linux
+
+// Package tun contains methods to open TAP and TUN devices.
package tun
import (
@@ -12,6 +25,16 @@
// Open opens the specified TUN device, sets it to non-blocking mode, and
// returns its file descriptor.
func Open(name string) (int, error) {
+ return open(name, syscall.IFF_TUN|syscall.IFF_NO_PI)
+}
+
+// OpenTAP opens the specified TAP device, sets it to non-blocking mode, and
+// returns its file descriptor.
+func OpenTAP(name string) (int, error) {
+ return open(name, syscall.IFF_TAP|syscall.IFF_NO_PI)
+}
+
+func open(name string, flags uint16) (int, error) {
fd, err := syscall.Open("/dev/net/tun", syscall.O_RDWR, 0)
if err != nil {
return -1, err
@@ -24,7 +47,7 @@
}
copy(ifr.name[:], name)
- ifr.flags = syscall.IFF_TUN | syscall.IFF_NO_PI
+ ifr.flags = flags
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), syscall.TUNSETIFF, uintptr(unsafe.Pointer(&ifr)))
if errno != 0 {
syscall.Close(fd)
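After this refactor, Open and OpenTAP both go through the shared open helper. They differ only in the ifr_flags value handed to the TUNSETIFF ioctl. The sketch below shows how those flags compose and how the interface name fills the fixed-size name field. The flag values are copied from Linux's `<linux/if_tun.h>` (IFF_TUN = 0x0001, IFF_TAP = 0x0002, IFF_NO_PI = 0x1000) as local constants so the sketch builds on any OS. It never opens /dev/net/tun. `ifreq` here models only the two fields used by the diff, and `newIfreq` is a hypothetical helper that adds a name-length check the original `copy` does not perform.

```go
package main

import (
	"errors"
	"fmt"
)

// Values from <linux/if_tun.h>, duplicated so the sketch builds anywhere.
const (
	iffTUN     uint16 = 0x0001
	iffTAP     uint16 = 0x0002
	iffNoPI    uint16 = 0x1000
	ifNameSize        = 16 // IFNAMSIZ, including the trailing NUL
)

// ifreq models only the fields TUNSETIFF reads; the kernel's struct ifreq is
// 40 bytes on 64-bit Linux, so real code must pad it.
type ifreq struct {
	name  [ifNameSize]byte
	flags uint16
}

// newIfreq (hypothetical) fills the request for a TUN (tap == false) or TAP
// (tap == true) device without the extra packet-information prefix.
func newIfreq(name string, tap bool) (ifreq, error) {
	var ifr ifreq
	if len(name) >= ifNameSize {
		return ifr, errors.New("interface name too long")
	}
	copy(ifr.name[:], name) // remaining bytes stay zero, so the name is NUL-terminated
	ifr.flags = iffTUN | iffNoPI
	if tap {
		ifr.flags = iffTAP | iffNoPI
	}
	return ifr, nil
}

func main() {
	ifr, err := newIfreq("tap0", true)
	fmt.Printf("%#x %v\n", ifr.flags, err) // 0x1002 <nil>
}
```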
diff --git a/tcpip/link/waitable/waitable.go b/tcpip/link/waitable/waitable.go
new file mode 100644
index 0000000..f8ea43f
--- /dev/null
+++ b/tcpip/link/waitable/waitable.go
@@ -0,0 +1,123 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package waitable provides the implementation of data-link layer endpoints
+// that wrap other endpoints and can wait for inflight calls to WritePacket or
+// DeliverNetworkPacket to finish while preventing new ones from starting.
+//
+// Waitable endpoints can be used in the networking stack by calling New(eID) to
+// create a new endpoint, where eID is the ID of the endpoint being wrapped,
+// and then passing the returned ID as an argument to Stack.CreateNIC().
+package waitable
+
+import (
+ "github.com/google/netstack/gate"
+ "github.com/google/netstack/tcpip"
+ "github.com/google/netstack/tcpip/buffer"
+ "github.com/google/netstack/tcpip/stack"
+)
+
+// Endpoint is a waitable link-layer endpoint.
+type Endpoint struct {
+ dispatchGate gate.Gate
+ dispatcher stack.NetworkDispatcher
+
+ writeGate gate.Gate
+ lower stack.LinkEndpoint
+}
+
+// New creates a new waitable link-layer endpoint. It wraps around another
+// endpoint and allows the caller to block new write/dispatch calls and wait for
+// the inflight ones to finish before returning.
+func New(lower tcpip.LinkEndpointID) (tcpip.LinkEndpointID, *Endpoint) {
+ e := &Endpoint{
+ lower: stack.FindLinkEndpoint(lower),
+ }
+ return stack.RegisterLinkEndpoint(e), e
+}
+
+// DeliverNetworkPacket implements stack.NetworkDispatcher.DeliverNetworkPacket.
+// It is called by the link-layer endpoint being wrapped when a packet arrives,
+// and only forwards to the actual dispatcher if WaitDispatch hasn't been
+// called.
+func (e *Endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remoteLinkAddr, localLinkAddress tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
+ if !e.dispatchGate.Enter() {
+ return
+ }
+
+ e.dispatcher.DeliverNetworkPacket(e, remoteLinkAddr, localLinkAddress, protocol, vv)
+ e.dispatchGate.Leave()
+}
+
+// Attach implements stack.LinkEndpoint.Attach. It saves the dispatcher and
+// registers with the lower endpoint as its dispatcher so that "e" is called
+// for inbound packets.
+func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) {
+ e.dispatcher = dispatcher
+ e.lower.Attach(e)
+}
+
+// IsAttached implements stack.LinkEndpoint.IsAttached.
+func (e *Endpoint) IsAttached() bool {
+ return e.dispatcher != nil
+}
+
+// MTU implements stack.LinkEndpoint.MTU. It just forwards the request to the
+// lower endpoint.
+func (e *Endpoint) MTU() uint32 {
+ return e.lower.MTU()
+}
+
+// Capabilities implements stack.LinkEndpoint.Capabilities. It just forwards the
+// request to the lower endpoint.
+func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.lower.Capabilities()
+}
+
+// MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength. It just
+// forwards the request to the lower endpoint.
+func (e *Endpoint) MaxHeaderLength() uint16 {
+ return e.lower.MaxHeaderLength()
+}
+
+// LinkAddress implements stack.LinkEndpoint.LinkAddress. It just forwards the
+// request to the lower endpoint.
+func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
+ return e.lower.LinkAddress()
+}
+
+// WritePacket implements stack.LinkEndpoint.WritePacket. It is called by
+// higher-level protocols to write packets. It only forwards packets to the
+// lower endpoint if WaitWrite hasn't been called.
+func (e *Endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
+ if !e.writeGate.Enter() {
+ return nil
+ }
+
+ err := e.lower.WritePacket(r, hdr, payload, protocol)
+ e.writeGate.Leave()
+ return err
+}
+
+// WaitWrite prevents new calls to WritePacket from reaching the lower endpoint,
+// and waits for inflight ones to finish before returning.
+func (e *Endpoint) WaitWrite() {
+ e.writeGate.Close()
+}
+
+// WaitDispatch prevents new calls to DeliverNetworkPacket from reaching the
+// actual dispatcher, and waits for inflight ones to finish before returning.
+func (e *Endpoint) WaitDispatch() {
+ e.dispatchGate.Close()
+}
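The waitable Endpoint relies on three gate.Gate operations, as used above:

- Enter refuses new callers once the gate is closed.
- Leave marks one caller as finished.
- Close blocks until every successful Enter has been matched by a Leave.

The sketch below is a simplified mutex-and-condition-variable stand-in that only illustrates that contract. It is an assumption about the semantics, not netstack's gate package. `simpleGate` is a hypothetical name, and its zero value is an open gate.

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// simpleGate is a hypothetical stand-in for gate.Gate; its zero value is open.
type simpleGate struct {
	mu       sync.Mutex
	cond     *sync.Cond
	closed   bool
	inflight int
}

// lazyInit creates the condition variable; callers must hold g.mu.
func (g *simpleGate) lazyInit() {
	if g.cond == nil {
		g.cond = sync.NewCond(&g.mu)
	}
}

// Enter reports whether the caller may proceed; it fails once Close was called.
func (g *simpleGate) Enter() bool {
	g.mu.Lock()
	defer g.mu.Unlock()
	if g.closed {
		return false
	}
	g.inflight++
	return true
}

// Leave must be called exactly once after every successful Enter.
func (g *simpleGate) Leave() {
	g.mu.Lock()
	defer g.mu.Unlock()
	g.lazyInit()
	g.inflight--
	if g.inflight == 0 {
		g.cond.Broadcast()
	}
}

// Close rejects future Enter calls and blocks until inflight callers Leave.
func (g *simpleGate) Close() {
	g.mu.Lock()
	defer g.mu.Unlock()
	g.lazyInit()
	g.closed = true
	for g.inflight > 0 {
		g.cond.Wait()
	}
}

func main() {
	var g simpleGate
	if !g.Enter() {
		panic("gate should start open")
	}
	done := make(chan struct{})
	go func() {
		g.Close() // blocks until the inflight call below leaves
		close(done)
	}()
	time.Sleep(10 * time.Millisecond)
	g.Leave()
	<-done
	fmt.Println(g.Enter()) // false
}
```

This is the behavior WaitWrite and WaitDispatch depend on. A packet already inside WritePacket or DeliverNetworkPacket finishes, and later packets are dropped.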
diff --git a/tcpip/link/waitable/waitable_test.go b/tcpip/link/waitable/waitable_test.go
new file mode 100644
index 0000000..e68484f
--- /dev/null
+++ b/tcpip/link/waitable/waitable_test.go
@@ -0,0 +1,159 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package waitable
+
+import (
+ "testing"
+
+ "github.com/google/netstack/tcpip"
+ "github.com/google/netstack/tcpip/buffer"
+ "github.com/google/netstack/tcpip/stack"
+)
+
+type countedEndpoint struct {
+ dispatchCount int
+ writeCount int
+ attachCount int
+
+ mtu uint32
+ capabilities stack.LinkEndpointCapabilities
+ hdrLen uint16
+ linkAddr tcpip.LinkAddress
+
+ dispatcher stack.NetworkDispatcher
+}
+
+func (e *countedEndpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
+ e.dispatchCount++
+}
+
+func (e *countedEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
+ e.attachCount++
+ e.dispatcher = dispatcher
+}
+
+// IsAttached implements stack.LinkEndpoint.IsAttached.
+func (e *countedEndpoint) IsAttached() bool {
+ return e.dispatcher != nil
+}
+
+func (e *countedEndpoint) MTU() uint32 {
+ return e.mtu
+}
+
+func (e *countedEndpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.capabilities
+}
+
+func (e *countedEndpoint) MaxHeaderLength() uint16 {
+ return e.hdrLen
+}
+
+func (e *countedEndpoint) LinkAddress() tcpip.LinkAddress {
+ return e.linkAddr
+}
+
+func (e *countedEndpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
+ e.writeCount++
+ return nil
+}
+
+func TestWaitWrite(t *testing.T) {
+ ep := &countedEndpoint{}
+ _, wep := New(stack.RegisterLinkEndpoint(ep))
+
+ // Write and check that it goes through.
+ wep.WritePacket(nil, buffer.Prependable{}, buffer.VectorisedView{}, 0)
+ if want := 1; ep.writeCount != want {
+ t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want)
+ }
+
+ // Wait on dispatches, then try to write. It must go through.
+ wep.WaitDispatch()
+ wep.WritePacket(nil, buffer.Prependable{}, buffer.VectorisedView{}, 0)
+ if want := 2; ep.writeCount != want {
+ t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want)
+ }
+
+ // Wait on writes, then try to write. It must not go through.
+ wep.WaitWrite()
+ wep.WritePacket(nil, buffer.Prependable{}, buffer.VectorisedView{}, 0)
+ if want := 2; ep.writeCount != want {
+ t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want)
+ }
+}
+
+func TestWaitDispatch(t *testing.T) {
+ ep := &countedEndpoint{}
+ _, wep := New(stack.RegisterLinkEndpoint(ep))
+
+ // Check that attach happens.
+ wep.Attach(ep)
+ if want := 1; ep.attachCount != want {
+ t.Fatalf("Unexpected attachCount: got=%v, want=%v", ep.attachCount, want)
+ }
+
+ // Dispatch and check that it goes through.
+ ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, buffer.VectorisedView{})
+ if want := 1; ep.dispatchCount != want {
+ t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want)
+ }
+
+ // Wait on writes, then try to dispatch. It must go through.
+ wep.WaitWrite()
+ ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, buffer.VectorisedView{})
+ if want := 2; ep.dispatchCount != want {
+ t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want)
+ }
+
+ // Wait on dispatches, then try to dispatch. It must not go through.
+ wep.WaitDispatch()
+ ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, buffer.VectorisedView{})
+ if want := 2; ep.dispatchCount != want {
+ t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want)
+ }
+}
+
+func TestOtherMethods(t *testing.T) {
+ const (
+ mtu = 0xdead
+ capabilities = 0xbeef
+ hdrLen = 0x1234
+ linkAddr = "test address"
+ )
+ ep := &countedEndpoint{
+ mtu: mtu,
+ capabilities: capabilities,
+ hdrLen: hdrLen,
+ linkAddr: linkAddr,
+ }
+ _, wep := New(stack.RegisterLinkEndpoint(ep))
+
+ if v := wep.MTU(); v != mtu {
+ t.Fatalf("Unexpected mtu: got=%v, want=%v", v, mtu)
+ }
+
+ if v := wep.Capabilities(); v != capabilities {
+ t.Fatalf("Unexpected capabilities: got=%v, want=%v", v, capabilities)
+ }
+
+ if v := wep.MaxHeaderLength(); v != hdrLen {
+ t.Fatalf("Unexpected MaxHeaderLength: got=%v, want=%v", v, hdrLen)
+ }
+
+ if v := wep.LinkAddress(); v != linkAddr {
+ t.Fatalf("Unexpected LinkAddress: got=%q, want=%q", v, linkAddr)
+ }
+}
diff --git a/tcpip/network/arp/arp.go b/tcpip/network/arp/arp.go
index a41e475..852d3de 100644
--- a/tcpip/network/arp/arp.go
+++ b/tcpip/network/arp/arp.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
// Package arp implements the ARP network protocol. It is used to resolve
// IPv4 addresses into link-local MAC addresses, and advertises IPv4
@@ -41,8 +51,9 @@
linkAddrCache stack.LinkAddressCache
}
+// DefaultTTL is unused for ARP. It implements stack.NetworkEndpoint.
func (e *endpoint) DefaultTTL() uint8 {
- return 0 // unused for ARP
+ return 0
}
func (e *endpoint) MTU() uint32 {
@@ -54,6 +65,10 @@
return e.nicid
}
+func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.linkEP.Capabilities()
+}
+
func (e *endpoint) ID() *stack.NetworkEndpointID {
return &stack.NetworkEndpointID{ProtocolAddress}
}
@@ -64,11 +79,11 @@
func (e *endpoint) Close() {}
-func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error {
+func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error {
return tcpip.ErrNotSupported
}
-func (e *endpoint) HandlePacket(r *stack.Route, vv *buffer.VectorisedView) {
+func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
v := vv.First()
h := header.ARP(v)
if !h.IsValid() {
@@ -78,7 +93,7 @@
switch h.Op() {
case header.ARPRequest:
localAddr := tcpip.Address(h.ProtocolAddressTarget())
- if e.linkAddrCache.CheckLocalAddress(e.nicid, localAddr) == 0 {
+ if e.linkAddrCache.CheckLocalAddress(e.nicid, header.IPv4ProtocolNumber, localAddr) == 0 {
return // we have no useful answer, ignore the request
}
hdr := buffer.NewPrependable(int(e.linkEP.MaxHeaderLength()) + header.ARPSize)
@@ -88,7 +103,7 @@
copy(pkt.HardwareAddressSender(), r.LocalLinkAddress[:])
copy(pkt.ProtocolAddressSender(), h.ProtocolAddressTarget())
copy(pkt.ProtocolAddressTarget(), h.ProtocolAddressSender())
- e.linkEP.WritePacket(r, &hdr, nil, ProtocolNumber)
+ e.linkEP.WritePacket(r, hdr, buffer.VectorisedView{}, ProtocolNumber)
fallthrough // also fill the cache from requests
case header.ARPReply:
addr := tcpip.Address(h.ProtocolAddressSender())
@@ -121,10 +136,12 @@
}, nil
}
+// LinkAddressProtocol implements stack.LinkAddressResolver.
func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
return header.IPv4ProtocolNumber
}
+// LinkAddressRequest implements stack.LinkAddressResolver.
func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error {
r := &stack.Route{
RemoteLinkAddress: broadcastMAC,
@@ -138,14 +155,27 @@
copy(h.ProtocolAddressSender(), localAddr)
copy(h.ProtocolAddressTarget(), addr)
- return linkEP.WritePacket(r, &hdr, nil, ProtocolNumber)
+ return linkEP.WritePacket(r, hdr, buffer.VectorisedView{}, ProtocolNumber)
}
-// SetOption implements NetworkProtocol.SetOption.
+// ResolveStaticAddress implements stack.LinkAddressResolver.
+func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+ if addr == "\xff\xff\xff\xff" {
+ return broadcastMAC, true
+ }
+ return "", false
+}
+
+// SetOption implements NetworkProtocol.
func (p *protocol) SetOption(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
+// Option implements NetworkProtocol.
+func (p *protocol) Option(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
func init() {
diff --git a/tcpip/network/arp/arp_test.go b/tcpip/network/arp/arp_test.go
index b660c53..7063c8d 100644
--- a/tcpip/network/arp/arp_test.go
+++ b/tcpip/network/arp/arp_test.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package arp_test
@@ -16,6 +26,7 @@
"github.com/google/netstack/tcpip/network/arp"
"github.com/google/netstack/tcpip/network/ipv4"
"github.com/google/netstack/tcpip/stack"
+ "github.com/google/netstack/tcpip/transport/ping"
)
const (
@@ -32,7 +43,7 @@
}
func newTestContext(t *testing.T) *testContext {
- s := stack.New([]string{ipv4.ProtocolName, arp.ProtocolName}, []string{ipv4.PingProtocolName})
+ s := stack.New([]string{ipv4.ProtocolName, arp.ProtocolName}, []string{ping.ProtocolName4}, stack.Options{})
const defaultMTU = 65536
id, linkEP := channel.New(256, defaultMTU, stackLinkAddr)
@@ -85,54 +96,54 @@
copy(h.HardwareAddressSender(), senderMAC)
copy(h.ProtocolAddressSender(), senderIPv4)
- // stackAddr1
- copy(h.ProtocolAddressTarget(), stackAddr1)
- vv := v.ToVectorisedView([1]buffer.View{})
- c.linkEP.Inject(arp.ProtocolNumber, &vv)
- pkt := <-c.linkEP.C
- if pkt.Proto != arp.ProtocolNumber {
- t.Fatalf("stackAddr1: expected ARP response, got network protocol number %v", pkt.Proto)
- }
- rep := header.ARP(pkt.Header)
- if !rep.IsValid() {
- t.Fatalf("stackAddr1: invalid ARP response len(pkt.Header)=%d", len(pkt.Header))
- }
- if tcpip.Address(rep.ProtocolAddressSender()) != stackAddr1 {
- t.Errorf("stackAddr1: expected sender to be set")
- }
- if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr {
- t.Errorf("stackAddr1: expected sender to be stackLinkAddr, got %q", got)
+ inject := func(addr tcpip.Address) {
+ copy(h.ProtocolAddressTarget(), addr)
+ c.linkEP.Inject(arp.ProtocolNumber, v.ToVectorisedView())
}
- // stackAddr2
- copy(h.ProtocolAddressTarget(), stackAddr2)
- vv = v.ToVectorisedView([1]buffer.View{})
- c.linkEP.Inject(arp.ProtocolNumber, &vv)
- pkt = <-c.linkEP.C
- if pkt.Proto != arp.ProtocolNumber {
- t.Fatalf("stackAddr2: expected ARP response, got network protocol number %v", pkt.Proto)
- }
- rep = header.ARP(pkt.Header)
- if !rep.IsValid() {
- t.Fatalf("stackAddr2: invalid ARP response len(pkt.Header)=%d", len(pkt.Header))
- }
- if tcpip.Address(rep.ProtocolAddressSender()) != stackAddr2 {
- t.Errorf("stackAddr2: expected sender to be set")
- }
- if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr {
- t.Errorf("stackAddr2: expected sender to be stackLinkAddr, got %q", got)
+ inject(stackAddr1)
+ {
+ pkt := <-c.linkEP.C
+ if pkt.Proto != arp.ProtocolNumber {
+ t.Fatalf("stackAddr1: expected ARP response, got network protocol number %v", pkt.Proto)
+ }
+ rep := header.ARP(pkt.Header)
+ if !rep.IsValid() {
+ t.Fatalf("stackAddr1: invalid ARP response len(pkt.Header)=%d", len(pkt.Header))
+ }
+ if tcpip.Address(rep.ProtocolAddressSender()) != stackAddr1 {
+ t.Errorf("stackAddr1: expected sender to be set")
+ }
+ if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr {
+ t.Errorf("stackAddr1: expected sender to be stackLinkAddr, got %q", got)
+ }
}
- // stackAddrBad
- copy(h.ProtocolAddressTarget(), stackAddrBad)
- vv = v.ToVectorisedView([1]buffer.View{})
- c.linkEP.Inject(arp.ProtocolNumber, &vv)
+ inject(stackAddr2)
+ {
+ pkt := <-c.linkEP.C
+ if pkt.Proto != arp.ProtocolNumber {
+ t.Fatalf("stackAddr2: expected ARP response, got network protocol number %v", pkt.Proto)
+ }
+ rep := header.ARP(pkt.Header)
+ if !rep.IsValid() {
+ t.Fatalf("stackAddr2: invalid ARP response len(pkt.Header)=%d", len(pkt.Header))
+ }
+ if tcpip.Address(rep.ProtocolAddressSender()) != stackAddr2 {
+ t.Errorf("stackAddr2: expected sender to be set")
+ }
+ if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr {
+ t.Errorf("stackAddr2: expected sender to be stackLinkAddr, got %q", got)
+ }
+ }
+
+ inject(stackAddrBad)
select {
case pkt := <-c.linkEP.C:
t.Errorf("stackAddrBad: unexpected packet sent, Proto=%v", pkt.Proto)
case <-time.After(100 * time.Millisecond):
- // Sleep tests are gross, but this will only
- // potentially fail flakily if there's a bugj
- // If there is no bug this will reliably succeed.
+ // Sleep tests are gross, but this will only potentially flake
+ // if there's a bug. If there is no bug this will reliably
+ // succeed.
}
}
diff --git a/tcpip/network/fragmentation/frag_heap.go b/tcpip/network/fragmentation/frag_heap.go
index 096e598..0d405d9 100644
--- a/tcpip/network/fragmentation/frag_heap.go
+++ b/tcpip/network/fragmentation/frag_heap.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package fragmentation
@@ -13,7 +23,7 @@
type fragment struct {
offset uint16
- vv *buffer.VectorisedView
+ vv buffer.VectorisedView
}
type fragHeap []fragment
@@ -50,7 +60,7 @@
size := curr.vv.Size()
if curr.offset != 0 {
- return buffer.NewVectorisedView(0, nil), fmt.Errorf("offset of the first packet is != 0 (%d)", curr.offset)
+ return buffer.VectorisedView{}, fmt.Errorf("offset of the first packet is != 0 (%d)", curr.offset)
}
for h.Len() > 0 {
@@ -58,7 +68,7 @@
if int(curr.offset) < size {
curr.vv.TrimFront(size - int(curr.offset))
} else if int(curr.offset) > size {
- return buffer.NewVectorisedView(0, nil), fmt.Errorf("packet has a hole, expected offset %d, got %d", size, curr.offset)
+ return buffer.VectorisedView{}, fmt.Errorf("packet has a hole, expected offset %d, got %d", size, curr.offset)
}
size += curr.vv.Size()
views = append(views, curr.vv.Views()...)
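fragHeap.reassemble pops fragments in offset order and follows three rules:

- It fails if the first fragment does not start at offset 0.
- It trims the front of any fragment that overlaps data already collected.
- It fails if a fragment starts beyond the collected data, which means there is a hole.

The sketch below applies the same rules to plain byte slices using container/heap. The types and names are hypothetical, and it does not use netstack's VectorisedView.

```go
package main

import (
	"container/heap"
	"fmt"
)

// frag is a hypothetical fragment: its byte offset and its data.
type frag struct {
	offset int
	data   []byte
}

// fragMinHeap orders fragments by ascending offset.
type fragMinHeap []frag

func (h fragMinHeap) Len() int            { return len(h) }
func (h fragMinHeap) Less(i, j int) bool  { return h[i].offset < h[j].offset }
func (h fragMinHeap) Swap(i, j int)       { h[i], h[j] = h[j], h[i] }
func (h *fragMinHeap) Push(x interface{}) { *h = append(*h, x.(frag)) }
func (h *fragMinHeap) Pop() interface{} {
	old := *h
	x := old[len(old)-1]
	*h = old[:len(old)-1]
	return x
}

// reassemble joins fragments in offset order, discarding overlapping bytes.
func reassemble(frags []frag) ([]byte, error) {
	h := fragMinHeap(append([]frag(nil), frags...)) // copy: do not reorder the caller's slice
	heap.Init(&h)
	var out []byte
	first := true
	for h.Len() > 0 {
		f := heap.Pop(&h).(frag)
		if first && f.offset != 0 {
			return nil, fmt.Errorf("offset of the first fragment is != 0 (%d)", f.offset)
		}
		first = false
		if f.offset > len(out) {
			return nil, fmt.Errorf("packet has a hole, expected offset %d, got %d", len(out), f.offset)
		}
		// Drop the bytes already covered by earlier fragments.
		if trim := len(out) - f.offset; trim > 0 {
			if trim > len(f.data) {
				trim = len(f.data)
			}
			f.data = f.data[trim:]
		}
		out = append(out, f.data...)
	}
	return out, nil
}

func main() {
	got, err := reassemble([]frag{{2, []byte("23")}, {0, []byte("012")}, {4, []byte("45")}})
	fmt.Println(string(got), err) // 012345 <nil>
}
```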
diff --git a/tcpip/network/fragmentation/frag_heap_test.go b/tcpip/network/fragmentation/frag_heap_test.go
index ba074ac..bdae67f 100644
--- a/tcpip/network/fragmentation/frag_heap_test.go
+++ b/tcpip/network/fragmentation/frag_heap_test.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package fragmentation
@@ -15,7 +25,7 @@
var reassambleTestCases = []struct {
comment string
in []fragment
- want *buffer.VectorisedView
+ want buffer.VectorisedView
}{
{
comment: "Non-overlapping in-order",
@@ -77,21 +87,25 @@
func TestReassamble(t *testing.T) {
for _, c := range reassambleTestCases {
- h := (fragHeap)(make([]fragment, 0, 8))
- heap.Init(&h)
- for _, f := range c.in {
- heap.Push(&h, f)
- }
- got, _ := h.reassemble()
-
- if !reflect.DeepEqual(got, *c.want) {
- t.Errorf("Test \"%s\" reassembling failed. Got %v. Want %v", c.comment, got, *c.want)
- }
+ t.Run(c.comment, func(t *testing.T) {
+ h := make(fragHeap, 0, 8)
+ heap.Init(&h)
+ for _, f := range c.in {
+ heap.Push(&h, f)
+ }
+ got, err := h.reassemble()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !reflect.DeepEqual(got, c.want) {
+ t.Errorf("got reassemble(%+v) = %v, want = %v", c.in, got, c.want)
+ }
+ })
}
}
func TestReassambleFailsForNonZeroOffset(t *testing.T) {
- h := (fragHeap)(make([]fragment, 0, 8))
+ h := make(fragHeap, 0, 8)
heap.Init(&h)
heap.Push(&h, fragment{offset: 1, vv: vv(1, "0")})
_, err := h.reassemble()
@@ -101,7 +115,7 @@
}
func TestReassambleFailsForHoles(t *testing.T) {
- h := (fragHeap)(make([]fragment, 0, 8))
+ h := make(fragHeap, 0, 8)
heap.Init(&h)
heap.Push(&h, fragment{offset: 0, vv: vv(1, "0")})
heap.Push(&h, fragment{offset: 2, vv: vv(1, "1")})
diff --git a/tcpip/network/fragmentation/fragmentation.go b/tcpip/network/fragmentation/fragmentation.go
index 3c41636..32a85b8 100644
--- a/tcpip/network/fragmentation/fragmentation.go
+++ b/tcpip/network/fragmentation/fragmentation.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
// Package fragmentation contains the implementation of IP fragmentation.
// It is based on RFC 791 and RFC 815.
@@ -14,17 +24,27 @@
"github.com/google/netstack/tcpip/buffer"
)
-// DefaultReassembleTimeout is based on the reassembling timeout suggest in RFC 791 (4.25 minutes).
-const DefaultReassembleTimeout = 5 * time.Minute
+// DefaultReassembleTimeout is based on the Linux default for net.ipv4.ipfrag_time.
+const DefaultReassembleTimeout = 30 * time.Second
-// MemoryLimit is a suggested value for the limit on the memory used to reassemble packets.
-const MemoryLimit = 8 * 1024 * 1024 // 8MB
+// HighFragThreshold is the threshold at which we start trimming old
+// fragmented packets. Linux uses a default value of 4 MB. See
+// net.ipv4.ipfrag_high_thresh for more information.
+const HighFragThreshold = 4 << 20 // 4MB
+
+// LowFragThreshold is the threshold we trim down to, by dropping the oldest
+// fragmented packets, once HighFragThreshold has been exceeded. It must be
+// sufficiently lower than HighFragThreshold to leave room for newer packets
+// to be reassembled. Linux uses a default value of 3 MB. See
+// net.ipv4.ipfrag_low_thresh for more information.
+const LowFragThreshold = 3 << 20 // 3MB
// Fragmentation is the main structure that other modules
// of the stack should use to implement IP Fragmentation.
type Fragmentation struct {
mu sync.Mutex
- limit int
+ highLimit int
+ lowLimit int
reassemblers map[uint32]*reassembler
rList reassemblerList
size int
@@ -33,24 +53,36 @@
// NewFragmentation creates a new Fragmentation.
//
-// memoryLimit specifies the limit on the memory consumed
+// highMemoryLimit specifies the limit on the memory consumed
// by the fragments stored by Fragmentation (overhead of internal data-structures
// is not accounted). Fragments are dropped when the limit is reached.
//
+// lowMemoryLimit specifies the limit that memory usage is trimmed down to, by
+// dropping fragments, once highMemoryLimit has been exceeded.
+//
// reassemblingTimeout specifes the maximum time allowed to reassemble a packet.
// Fragments are lazily evicted only when a new a packet with an
// already existing fragmentation-id arrives after the timeout.
-func NewFragmentation(memoryLimit int, reassemblingTimeout time.Duration) Fragmentation {
- return Fragmentation{
+func NewFragmentation(highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration) *Fragmentation {
+ if lowMemoryLimit >= highMemoryLimit {
+ lowMemoryLimit = highMemoryLimit
+ }
+
+ if lowMemoryLimit < 0 {
+ lowMemoryLimit = 0
+ }
+
+ return &Fragmentation{
reassemblers: make(map[uint32]*reassembler),
- limit: memoryLimit,
+ highLimit: highMemoryLimit,
+ lowLimit: lowMemoryLimit,
timeout: reassemblingTimeout,
}
}
// Process processes an incoming fragment beloning to an ID
// and returns a complete packet when all the packets belonging to that ID have been received.
-func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv *buffer.VectorisedView) (buffer.VectorisedView, bool) {
+func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool) {
f.mu.Lock()
r, ok := f.reassemblers[id]
if ok && r.tooOld(f.timeout) {
@@ -72,9 +104,14 @@
if done {
f.release(r)
}
- // Evict reassemblers if we are consuming more memory than the limit.
- for f.size > f.limit {
- f.release(f.rList.Back())
+ // Evict reassemblers if we are consuming more memory than highLimit until
+ // we reach lowLimit.
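+ if f.size > f.highLimit {
+ for f.size > f.lowLimit {
+ // Re-read the tail each time rather than following Prev() from an
+ // element that release has already removed from the list.
+ tail := f.rList.Back()
+ if tail == nil {
+ break
+ }
+ f.release(tail)
+ }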
}
f.mu.Unlock()
return res, done
diff --git a/tcpip/network/fragmentation/fragmentation_test.go b/tcpip/network/fragmentation/fragmentation_test.go
index 8ef89fd..627ac70 100644
--- a/tcpip/network/fragmentation/fragmentation_test.go
+++ b/tcpip/network/fragmentation/fragmentation_test.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package fragmentation
@@ -13,19 +23,13 @@
)
// vv is a helper to build VectorisedView from different strings.
-func vv(size int, pieces ...string) *buffer.VectorisedView {
+func vv(size int, pieces ...string) buffer.VectorisedView {
views := make([]buffer.View, len(pieces))
for i, p := range pieces {
views[i] = []byte(p)
}
- vv := buffer.NewVectorisedView(size, views)
- return &vv
-}
-
-func emptyVv() *buffer.VectorisedView {
- vv := buffer.NewVectorisedView(0, nil)
- return &vv
+ return buffer.NewVectorisedView(size, views)
}
type processInput struct {
@@ -33,11 +37,11 @@
first uint16
last uint16
more bool
- vv *buffer.VectorisedView
+ vv buffer.VectorisedView
}
type processOutput struct {
- vv *buffer.VectorisedView
+ vv buffer.VectorisedView
done bool
}
@@ -53,7 +57,7 @@
{id: 0, first: 2, last: 3, more: false, vv: vv(2, "23")},
},
out: []processOutput{
- {vv: emptyVv(), done: false},
+ {vv: buffer.VectorisedView{}, done: false},
{vv: vv(4, "01", "23"), done: true},
},
},
@@ -66,8 +70,8 @@
{id: 0, first: 2, last: 3, more: false, vv: vv(2, "23")},
},
out: []processOutput{
- {vv: emptyVv(), done: false},
- {vv: emptyVv(), done: false},
+ {vv: buffer.VectorisedView{}, done: false},
+ {vv: buffer.VectorisedView{}, done: false},
{vv: vv(4, "ab", "cd"), done: true},
{vv: vv(4, "01", "23"), done: true},
},
@@ -76,22 +80,34 @@
func TestFragmentationProcess(t *testing.T) {
for _, c := range processTestCases {
- f := NewFragmentation(1024, DefaultReassembleTimeout)
- for i, in := range c.in {
- vv, done := f.Process(in.id, in.first, in.last, in.more, in.vv)
- if !reflect.DeepEqual(vv, *(c.out[i].vv)) {
- t.Errorf("Test \"%s\" Process() returned a wrong vv. Got %v. Want %v", c.comment, vv, *(c.out[i].vv))
+ t.Run(c.comment, func(t *testing.T) {
+ f := NewFragmentation(1024, 512, DefaultReassembleTimeout)
+ for i, in := range c.in {
+ vv, done := f.Process(in.id, in.first, in.last, in.more, in.vv)
+ if !reflect.DeepEqual(vv, c.out[i].vv) {
+ t.Errorf("got Process(%d) = %+v, want = %+v", i, vv, c.out[i].vv)
+ }
+ if done != c.out[i].done {
+ t.Errorf("got Process(%d) done = %t, want = %t", i, done, c.out[i].done)
+ }
+ if c.out[i].done {
+ if _, ok := f.reassemblers[in.id]; ok {
+ t.Errorf("Process(%d) did not remove buffer from reassemblers", i)
+ }
+ for n := f.rList.Front(); n != nil; n = n.Next() {
+ if n.id == in.id {
+ t.Errorf("Process(%d) did not remove buffer from rList", i)
+ }
+ }
+ }
}
- if done != c.out[i].done {
- t.Errorf("Test \"%s\" Process() returned a wrong done. Got %t. Want %t", c.comment, done, c.out[i].done)
- }
- }
+ })
}
}
func TestReassemblingTimeout(t *testing.T) {
timeout := time.Millisecond
- f := NewFragmentation(1024, timeout)
+ f := NewFragmentation(1024, 512, timeout)
// Send first fragment with id = 0, first = 0, last = 0, and more = true.
f.Process(0, 0, 0, true, vv(1, "0"))
// Sleep more than the timeout.
@@ -105,22 +121,31 @@
}
func TestMemoryLimits(t *testing.T) {
- f := NewFragmentation(1, DefaultReassembleTimeout)
+ f := NewFragmentation(3, 1, DefaultReassembleTimeout)
// Send first fragment with id = 0.
f.Process(0, 0, 0, true, vv(1, "0"))
- // Send first fragment with id = 1. This should caused id = 0 to be evicted.
- f.Process(1, 0, 0, true, vv(1, "0"))
+ // Send first fragment with id = 1.
+ f.Process(1, 0, 0, true, vv(1, "1"))
+ // Send first fragment with id = 2.
+ f.Process(2, 0, 0, true, vv(1, "2"))
+
+ // Send first fragment with id = 3. Exceeding highLimit should cause the
+ // oldest reassemblers (id = 0, 1 and 2) to be evicted down to lowLimit.
+ f.Process(3, 0, 0, true, vv(1, "3"))
if _, ok := f.reassemblers[0]; ok {
t.Errorf("Memory limits are not respected: id=0 has not been evicted.")
}
- if _, ok := f.reassemblers[1]; !ok {
- t.Errorf("Implementation of memory limits is wrong: id=1 is not present.")
+ if _, ok := f.reassemblers[1]; ok {
+ t.Errorf("Memory limits are not respected: id=1 has not been evicted.")
+ }
+ if _, ok := f.reassemblers[3]; !ok {
+ t.Errorf("Implementation of memory limits is wrong: id=3 is not present.")
}
}
func TestMemoryLimitsIgnoresDuplicates(t *testing.T) {
- f := NewFragmentation(1, DefaultReassembleTimeout)
+ f := NewFragmentation(1, 0, DefaultReassembleTimeout)
// Send first fragment with id = 0.
f.Process(0, 0, 0, true, vv(1, "0"))
// Send the same packet again.
@@ -132,16 +157,3 @@
t.Errorf("Wrong size, duplicates are not handled correctly: got=%d, want=%d.", got, want)
}
}
-
-func TestFragmentationViewsDoNotEscape(t *testing.T) {
- f := NewFragmentation(1024, DefaultReassembleTimeout)
- in := vv(2, "0", "1")
- f.Process(0, 0, 1, true, in)
- // Modify input view.
- in.RemoveFirst()
- got, _ := f.Process(0, 2, 2, false, vv(1, "2"))
- want := vv(3, "0", "1", "2")
- if !reflect.DeepEqual(got, *want) {
- t.Errorf("Process() returned a wrong vv. Got %v. Want %v", got, *want)
- }
-}
diff --git a/tcpip/network/fragmentation/reassembler.go b/tcpip/network/fragmentation/reassembler.go
index 46dd26b..29b1809 100644
--- a/tcpip/network/fragmentation/reassembler.go
+++ b/tcpip/network/fragmentation/reassembler.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package fragmentation
@@ -68,7 +78,7 @@
return used
}
-func (r *reassembler) process(first, last uint16, more bool, vv *buffer.VectorisedView) (buffer.VectorisedView, bool, int) {
+func (r *reassembler) process(first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, int) {
r.mu.Lock()
defer r.mu.Unlock()
consumed := 0
@@ -76,24 +86,22 @@
// A concurrent goroutine might have already reassembled
// the packet and emptied the heap while this goroutine
// was waiting on the mutex. We don't have to do anything in this case.
- return buffer.NewVectorisedView(0, nil), false, consumed
+ return buffer.VectorisedView{}, false, consumed
}
if r.updateHoles(first, last, more) {
// We store the incoming packet only if it filled some holes.
- uu := vv.Clone(nil)
- heap.Push(&r.heap, fragment{offset: first, vv: &uu})
+ heap.Push(&r.heap, fragment{offset: first, vv: vv.Clone(nil)})
consumed = vv.Size()
r.size += consumed
}
// Check if all the holes have been deleted and we are ready to reassemble.
if r.deleted < len(r.holes) {
- return buffer.NewVectorisedView(0, nil), false, consumed
+ return buffer.VectorisedView{}, false, consumed
}
res, err := r.heap.reassemble()
if err != nil {
panic(fmt.Sprintf("reassemble failed with: %v. There is probably a bug in the code handling the holes.", err))
}
- r.done = true
return res, true, consumed
}
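`process` stores a fragment only if `updateHoles` reports that it filled a hole. It treats the packet as complete once `r.deleted` equals `len(r.holes)`. That bookkeeping resembles the RFC 815 hole-descriptor algorithm. The following is a standalone sketch of that algorithm under that assumption. `hole`, `holeTracker`, `newHoleTracker`, `update` and `complete` are hypothetical names, and no payload is stored. Holes are never removed from the slice, only marked deleted, which is what makes the `deleted == len(holes)` completion test work.

```go
package main

import (
	"fmt"
	"math"
)

// hole is a byte range [first, last] of the original datagram that has not
// been received yet (an RFC 815 "hole descriptor").
type hole struct {
	first, last uint16
	deleted     bool
}

// holeTracker is a hypothetical model of the reassembler's hole bookkeeping.
type holeTracker struct {
	holes   []hole
	deleted int
}

func newHoleTracker() *holeTracker {
	// Initially the whole datagram is missing.
	return &holeTracker{holes: []hole{{first: 0, last: math.MaxUint16}}}
}

// update records a fragment covering [first, last] and reports whether it
// filled part of any existing hole. Duplicates return false.
func (t *holeTracker) update(first, last uint16, more bool) bool {
	used := false
	n := len(t.holes) // holes appended below are not revisited in this pass
	for i := 0; i < n; i++ {
		h := t.holes[i]
		if h.deleted || first > h.last || last < h.first {
			continue // no overlap with this hole
		}
		used = true
		t.holes[i].deleted = true
		t.deleted++
		if first > h.first {
			// The part of the hole before the fragment is still missing.
			t.holes = append(t.holes, hole{first: h.first, last: first - 1})
		}
		if last < h.last && more {
			// The part after the fragment is still missing. With more == false
			// this was the final fragment, so nothing lies beyond it.
			t.holes = append(t.holes, hole{first: last + 1, last: h.last})
		}
	}
	return used
}

// complete reports whether every hole has been filled.
func (t *holeTracker) complete() bool { return t.deleted == len(t.holes) }

func main() {
	ht := newHoleTracker()
	// Fragments may arrive out of order: last fragment first.
	fmt.Println(ht.update(2, 3, false), ht.complete()) // true false
	fmt.Println(ht.update(0, 1, true), ht.complete())  // true true
}
```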
diff --git a/tcpip/network/fragmentation/reassembler_list.go b/tcpip/network/fragmentation/reassembler_list.go
index e92a0f2..1222949 100644
--- a/tcpip/network/fragmentation/reassembler_list.go
+++ b/tcpip/network/fragmentation/reassembler_list.go
@@ -1,8 +1,3 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-// Package ilist provides the implementation of intrusive linked lists.
package fragmentation
// List is an intrusive list. Entries can be added to or removed from the list
diff --git a/tcpip/network/fragmentation/reassembler_test.go b/tcpip/network/fragmentation/reassembler_test.go
index b646043..4c13782 100644
--- a/tcpip/network/fragmentation/reassembler_test.go
+++ b/tcpip/network/fragmentation/reassembler_test.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package fragmentation
diff --git a/tcpip/network/hash/hash.go b/tcpip/network/hash/hash.go
index a1f33d6..183d893 100644
--- a/tcpip/network/hash/hash.go
+++ b/tcpip/network/hash/hash.go
@@ -1,15 +1,24 @@
-// Copyright 2017 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
// Package hash contains utility functions for hashing.
package hash
import (
- "crypto/rand"
"encoding/binary"
- "fmt"
+ "github.com/google/netstack/rand"
"github.com/google/netstack/tcpip/header"
)
@@ -19,7 +28,7 @@
func RandN32(n int) []uint32 {
b := make([]byte, 4*n)
if _, err := rand.Read(b); err != nil {
- panic(fmt.Sprintf("unable to get random numbers: %v", err))
+ panic("unable to get random numbers: " + err.Error())
}
r := make([]uint32, n)
for i := range r {
diff --git a/tcpip/network/ip_test.go b/tcpip/network/ip_test.go
index 5ac1efe..f62da52 100644
--- a/tcpip/network/ip_test.go
+++ b/tcpip/network/ip_test.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package ip_test
@@ -10,9 +20,25 @@
"github.com/google/netstack/tcpip"
"github.com/google/netstack/tcpip/buffer"
"github.com/google/netstack/tcpip/header"
+ "github.com/google/netstack/tcpip/link/loopback"
"github.com/google/netstack/tcpip/network/ipv4"
"github.com/google/netstack/tcpip/network/ipv6"
"github.com/google/netstack/tcpip/stack"
+ "github.com/google/netstack/tcpip/transport/tcp"
+ "github.com/google/netstack/tcpip/transport/udp"
+)
+
+const (
+ localIpv4Addr = "\x0a\x00\x00\x01"
+ remoteIpv4Addr = "\x0a\x00\x00\x02"
+ ipv4SubnetAddr = "\x0a\x00\x00\x00"
+ ipv4SubnetMask = "\xff\xff\xff\x00"
+ ipv4Gateway = "\x0a\x00\x00\x03"
+ localIpv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ remoteIpv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ ipv6SubnetAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
+ ipv6SubnetMask = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00"
+ ipv6Gateway = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03"
)
// testObject implements two interfaces: LinkEndpoint and TransportDispatcher.
@@ -30,12 +56,17 @@
srcAddr tcpip.Address
dstAddr tcpip.Address
v4 bool
+ typ stack.ControlType
+ extra uint32
+
+ dataCalls int
+ controlCalls int
}
// checkValues verifies that the transport protocol, data contents, src & dst
// addresses of a packet match what's expected. If any field doesn't match, the
// test fails.
-func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv *buffer.VectorisedView, srcAddr, dstAddr tcpip.Address) {
+func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView, srcAddr, dstAddr tcpip.Address) {
v := vv.ToView()
if protocol != t.protocol {
t.t.Errorf("protocol = %v, want %v", protocol, t.protocol)
@@ -63,19 +94,44 @@
// DeliverTransportPacket is called by network endpoints after parsing incoming
// packets. This is used by the test object to verify that the results of the
// parsing are expected.
-func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, vv *buffer.VectorisedView) {
+func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView) {
t.checkValues(protocol, vv, r.RemoteAddress, r.LocalAddress)
+ t.dataCalls++
+}
+
+// DeliverTransportControlPacket is called by network endpoints after parsing
+// incoming control (ICMP) packets. This is used by the test object to verify
+// that the results of the parsing are expected.
+func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
+ t.checkValues(trans, vv, remote, local)
+ if typ != t.typ {
+ t.t.Errorf("typ = %v, want %v", typ, t.typ)
+ }
+ if extra != t.extra {
+ t.t.Errorf("extra = %v, want %v", extra, t.extra)
+ }
+ t.controlCalls++
}
// Attach is only implemented to satisfy the LinkEndpoint interface.
func (*testObject) Attach(stack.NetworkDispatcher) {}
+// IsAttached implements stack.LinkEndpoint.IsAttached.
+func (*testObject) IsAttached() bool {
+ return true
+}
+
// MTU implements stack.LinkEndpoint.MTU. It just returns a constant that
// matches the linux loopback MTU.
func (*testObject) MTU() uint32 {
return 65536
}
+// Capabilities implements stack.LinkEndpoint.Capabilities.
+func (*testObject) Capabilities() stack.LinkEndpointCapabilities {
+ return 0
+}
+
// MaxHeaderLength is only implemented to satisfy the LinkEndpoint interface.
func (*testObject) MaxHeaderLength() uint16 {
return 0
@@ -89,33 +145,59 @@
// WritePacket is called by network endpoints after producing a packet and
// writing it to the link endpoint. This is used by the test object to verify
// that the produced packet is as expected.
-func (t *testObject) WritePacket(_ *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
+func (t *testObject) WritePacket(_ *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
var prot tcpip.TransportProtocolNumber
var srcAddr tcpip.Address
var dstAddr tcpip.Address
if t.v4 {
- h := header.IPv4(hdr.UsedBytes())
+ h := header.IPv4(hdr.View())
prot = tcpip.TransportProtocolNumber(h.Protocol())
srcAddr = h.SourceAddress()
dstAddr = h.DestinationAddress()
} else {
- h := header.IPv6(hdr.UsedBytes())
+ h := header.IPv6(hdr.View())
prot = tcpip.TransportProtocolNumber(h.NextHeader())
srcAddr = h.SourceAddress()
dstAddr = h.DestinationAddress()
}
- var views [1]buffer.View
- vv := payload.ToVectorisedView(views)
- t.checkValues(prot, &vv, srcAddr, dstAddr)
+ t.checkValues(prot, payload, srcAddr, dstAddr)
return nil
}
+func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
+ s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName, tcp.ProtocolName}, stack.Options{})
+ s.CreateNIC(1, loopback.New())
+ s.AddAddress(1, ipv4.ProtocolNumber, local)
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: ipv4SubnetAddr,
+ Mask: ipv4SubnetMask,
+ Gateway: ipv4Gateway,
+ NIC: 1,
+ }})
+
+ return s.FindRoute(1, local, remote, ipv4.ProtocolNumber)
+}
+
+func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
+ s := stack.New([]string{ipv6.ProtocolName}, []string{udp.ProtocolName, tcp.ProtocolName}, stack.Options{})
+ s.CreateNIC(1, loopback.New())
+ s.AddAddress(1, ipv6.ProtocolNumber, local)
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: ipv6SubnetAddr,
+ Mask: ipv6SubnetMask,
+ Gateway: ipv6Gateway,
+ NIC: 1,
+ }})
+
+ return s.FindRoute(1, local, remote, ipv6.ProtocolNumber)
+}
+
func TestIPv4Send(t *testing.T) {
o := testObject{t: t, v4: true}
proto := ipv4.NewProtocol()
- ep, err := proto.NewEndpoint(1, "\x0a\x00\x00\x01", nil, nil, &o)
+ ep, err := proto.NewEndpoint(1, localIpv4Addr, nil, nil, &o)
if err != nil {
t.Fatalf("NewEndpoint failed: %v", err)
}
@@ -131,15 +213,15 @@
// Issue the write.
o.protocol = 123
- o.srcAddr = "\x0a\x00\x00\x01"
- o.dstAddr = "\x0a\x00\x00\x02"
+ o.srcAddr = localIpv4Addr
+ o.dstAddr = remoteIpv4Addr
o.contents = payload
- r := stack.Route{
- RemoteAddress: o.dstAddr,
- LocalAddress: o.srcAddr,
+ r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr)
+ if err != nil {
+ t.Fatalf("could not find route: %v", err)
}
- if err := ep.WritePacket(&r, &hdr, payload, 123); err != nil {
+ if err := ep.WritePacket(&r, hdr, payload.ToVectorisedView(), 123, 123); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
}
@@ -147,7 +229,7 @@
func TestIPv4Receive(t *testing.T) {
o := testObject{t: t, v4: true}
proto := ipv4.NewProtocol()
- ep, err := proto.NewEndpoint(1, "\x0a\x00\x00\x01", nil, &o, nil)
+ ep, err := proto.NewEndpoint(1, localIpv4Addr, nil, &o, nil)
if err != nil {
t.Fatalf("NewEndpoint failed: %v", err)
}
@@ -160,8 +242,8 @@
TotalLength: uint16(totalLen),
TTL: 20,
Protocol: 10,
- SrcAddr: "\x0a\x00\x00\x02",
- DstAddr: "\x0a\x00\x00\x01",
+ SrcAddr: remoteIpv4Addr,
+ DstAddr: localIpv4Addr,
})
// Make payload be non-zero.
@@ -171,23 +253,113 @@
// Give packet to ipv4 endpoint, dispatcher will validate that it's ok.
o.protocol = 10
- o.srcAddr = "\x0a\x00\x00\x02"
- o.dstAddr = "\x0a\x00\x00\x01"
+ o.srcAddr = remoteIpv4Addr
+ o.dstAddr = localIpv4Addr
o.contents = view[header.IPv4MinimumSize:totalLen]
- r := stack.Route{
- LocalAddress: o.dstAddr,
- RemoteAddress: o.srcAddr,
+ r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr)
+ if err != nil {
+ t.Fatalf("could not find route: %v", err)
}
- var views [1]buffer.View
- vv := view.ToVectorisedView(views)
- ep.HandlePacket(&r, &vv)
+ ep.HandlePacket(&r, view.ToVectorisedView())
+ if o.dataCalls != 1 {
+ t.Fatalf("Bad number of data calls: got %d, want 1", o.dataCalls)
+ }
+}
+
+func TestIPv4ReceiveControl(t *testing.T) {
+ const mtu = 0xbeef - header.IPv4MinimumSize
+ cases := []struct {
+ name string
+ expectedCount int
+ fragmentOffset uint16
+ code uint8
+ expectedTyp stack.ControlType
+ expectedExtra uint32
+ trunc int
+ }{
+ {"FragmentationNeeded", 1, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 0},
+ {"Truncated (10 bytes missing)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 10},
+ {"Truncated (missing IPv4 header)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, header.IPv4MinimumSize + 8},
+ {"Truncated (missing 'extra info')", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 4 + header.IPv4MinimumSize + 8},
+ {"Truncated (missing ICMP header)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, header.ICMPv4DstUnreachableMinimumSize + header.IPv4MinimumSize + 8},
+ {"Port unreachable", 1, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0},
+ {"Non-zero fragment offset", 0, 100, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0},
+ {"Zero-length packet", 0, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv4MinimumSize + header.ICMPv4DstUnreachableMinimumSize + 8},
+ }
+ r := stack.Route{
+ LocalAddress: localIpv4Addr,
+ RemoteAddress: "\x0a\x00\x00\xbb",
+ }
+ for _, c := range cases {
+ t.Run(c.name, func(t *testing.T) {
+ o := testObject{t: t}
+ proto := ipv4.NewProtocol()
+ ep, err := proto.NewEndpoint(1, localIpv4Addr, nil, &o, nil)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+ defer ep.Close()
+
+ const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize + 4
+ view := buffer.NewView(dataOffset + 8)
+
+ // Create the outer IPv4 header.
+ ip := header.IPv4(view)
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: uint16(len(view) - c.trunc),
+ TTL: 20,
+ Protocol: uint8(header.ICMPv4ProtocolNumber),
+ SrcAddr: "\x0a\x00\x00\xbb",
+ DstAddr: localIpv4Addr,
+ })
+
+ // Create the ICMP header.
+ icmp := header.ICMPv4(view[header.IPv4MinimumSize:])
+ icmp.SetType(header.ICMPv4DstUnreachable)
+ icmp.SetCode(c.code)
+ copy(view[header.IPv4MinimumSize+header.ICMPv4MinimumSize:], []byte{0xde, 0xad, 0xbe, 0xef})
+
+ // Create the inner IPv4 header.
+ ip = header.IPv4(view[header.IPv4MinimumSize+header.ICMPv4MinimumSize+4:])
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: 100,
+ TTL: 20,
+ Protocol: 10,
+ FragmentOffset: c.fragmentOffset,
+ SrcAddr: localIpv4Addr,
+ DstAddr: remoteIpv4Addr,
+ })
+
+ // Make payload be non-zero.
+ for i := dataOffset; i < len(view); i++ {
+ view[i] = uint8(i)
+ }
+
+ // Give packet to IPv4 endpoint, dispatcher will validate that
+ // it's ok.
+ o.protocol = 10
+ o.srcAddr = remoteIpv4Addr
+ o.dstAddr = localIpv4Addr
+ o.contents = view[dataOffset:]
+ o.typ = c.expectedTyp
+ o.extra = c.expectedExtra
+
+ vv := view[:len(view)-c.trunc].ToVectorisedView()
+ ep.HandlePacket(&r, vv)
+ if want := c.expectedCount; o.controlCalls != want {
+ t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want)
+ }
+ })
+ }
}
func TestIPv4FragmentationReceive(t *testing.T) {
o := testObject{t: t, v4: true}
proto := ipv4.NewProtocol()
- ep, err := proto.NewEndpoint(1, "\x0a\x00\x00\x01", nil, &o, nil)
+ ep, err := proto.NewEndpoint(1, localIpv4Addr, nil, &o, nil)
if err != nil {
t.Fatalf("NewEndpoint failed: %v", err)
}
@@ -203,8 +375,8 @@
Protocol: 10,
FragmentOffset: 0,
Flags: header.IPv4FlagMoreFragments,
- SrcAddr: "\x0a\x00\x00\x02",
- DstAddr: "\x0a\x00\x00\x01",
+ SrcAddr: remoteIpv4Addr,
+ DstAddr: localIpv4Addr,
})
// Make payload be non-zero.
for i := header.IPv4MinimumSize; i < totalLen; i++ {
@@ -219,8 +391,8 @@
TTL: 20,
Protocol: 10,
FragmentOffset: 24,
- SrcAddr: "\x0a\x00\x00\x02",
- DstAddr: "\x0a\x00\x00\x01",
+ SrcAddr: remoteIpv4Addr,
+ DstAddr: localIpv4Addr,
})
// Make payload be non-zero.
for i := header.IPv4MinimumSize; i < totalLen; i++ {
@@ -229,30 +401,32 @@
// Give packet to ipv4 endpoint, dispatcher will validate that it's ok.
o.protocol = 10
- o.srcAddr = "\x0a\x00\x00\x02"
- o.dstAddr = "\x0a\x00\x00\x01"
+ o.srcAddr = remoteIpv4Addr
+ o.dstAddr = localIpv4Addr
o.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...)
- r := stack.Route{
- LocalAddress: o.dstAddr,
- RemoteAddress: o.srcAddr,
+ r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr)
+ if err != nil {
+ t.Fatalf("could not find route: %v", err)
}
// Send first segment.
- var views1 [1]buffer.View
- vv1 := frag1.ToVectorisedView(views1)
- ep.HandlePacket(&r, &vv1)
+ ep.HandlePacket(&r, frag1.ToVectorisedView())
+ if o.dataCalls != 0 {
+ t.Fatalf("Bad number of data calls: got %d, want 0", o.dataCalls)
+ }
// Send second segment.
- var views2 [1]buffer.View
- vv2 := frag2.ToVectorisedView(views2)
- ep.HandlePacket(&r, &vv2)
+ ep.HandlePacket(&r, frag2.ToVectorisedView())
+ if o.dataCalls != 1 {
+ t.Fatalf("Bad number of data calls: got %d, want 1", o.dataCalls)
+ }
}
func TestIPv6Send(t *testing.T) {
o := testObject{t: t}
proto := ipv6.NewProtocol()
- ep, err := proto.NewEndpoint(1, "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", nil, nil, &o)
+ ep, err := proto.NewEndpoint(1, localIpv6Addr, nil, nil, &o)
if err != nil {
t.Fatalf("NewEndpoint failed: %v", err)
}
@@ -268,15 +442,15 @@
// Issue the write.
o.protocol = 123
- o.srcAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- o.dstAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ o.srcAddr = localIpv6Addr
+ o.dstAddr = remoteIpv6Addr
o.contents = payload
- r := stack.Route{
- RemoteAddress: o.dstAddr,
- LocalAddress: o.srcAddr,
+ r, err := buildIPv6Route(localIpv6Addr, remoteIpv6Addr)
+ if err != nil {
+ t.Fatalf("could not find route: %v", err)
}
- if err := ep.WritePacket(&r, &hdr, payload, 123); err != nil {
+ if err := ep.WritePacket(&r, hdr, payload.ToVectorisedView(), 123, 123); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
}
@@ -284,7 +458,7 @@
func TestIPv6Receive(t *testing.T) {
o := testObject{t: t}
proto := ipv6.NewProtocol()
- ep, err := proto.NewEndpoint(1, "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", nil, &o, nil)
+ ep, err := proto.NewEndpoint(1, localIpv6Addr, nil, &o, nil)
if err != nil {
t.Fatalf("NewEndpoint failed: %v", err)
}
@@ -296,8 +470,8 @@
PayloadLength: uint16(totalLen - header.IPv6MinimumSize),
NextHeader: 10,
HopLimit: 20,
- SrcAddr: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02",
- DstAddr: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ SrcAddr: remoteIpv6Addr,
+ DstAddr: localIpv6Addr,
})
// Make payload be non-zero.
@@ -307,15 +481,124 @@
// Give packet to ipv6 endpoint, dispatcher will validate that it's ok.
o.protocol = 10
- o.srcAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
- o.dstAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ o.srcAddr = remoteIpv6Addr
+ o.dstAddr = localIpv6Addr
o.contents = view[header.IPv6MinimumSize:totalLen]
- r := stack.Route{
- LocalAddress: o.dstAddr,
- RemoteAddress: o.srcAddr,
+ r, err := buildIPv6Route(localIpv6Addr, remoteIpv6Addr)
+ if err != nil {
+ t.Fatalf("could not find route: %v", err)
}
- var views [1]buffer.View
- vv := view.ToVectorisedView(views)
- ep.HandlePacket(&r, &vv)
+
+ ep.HandlePacket(&r, view.ToVectorisedView())
+ if o.dataCalls != 1 {
+ t.Fatalf("Bad number of data calls: got %d, want 1", o.dataCalls)
+ }
+}
+
+func TestIPv6ReceiveControl(t *testing.T) {
+ newUint16 := func(v uint16) *uint16 { return &v }
+
+ const mtu = 0xffff
+ cases := []struct {
+ name string
+ expectedCount int
+ fragmentOffset *uint16
+ typ header.ICMPv6Type
+ code uint8
+ expectedTyp stack.ControlType
+ expectedExtra uint32
+ trunc int
+ }{
+ {"PacketTooBig", 1, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, 0},
+ {"Truncated (10 bytes missing)", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, 10},
+ {"Truncated (missing IPv6 header)", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, header.IPv6MinimumSize + 8},
+ {"Truncated PacketTooBig (missing 'extra info')", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, 4 + header.IPv6MinimumSize + 8},
+ {"Truncated (missing ICMP header)", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, header.ICMPv6PacketTooBigMinimumSize + header.IPv6MinimumSize + 8},
+ {"Port unreachable", 1, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0},
+ {"Truncated DstUnreachable (missing 'extra info')", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 4 + header.IPv6MinimumSize + 8},
+ {"Fragmented, zero offset", 1, newUint16(0), header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0},
+ {"Non-zero fragment offset", 0, newUint16(100), header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0},
+ {"Zero-length packet", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + 8},
+ }
+ r := stack.Route{
+ LocalAddress: localIpv6Addr,
+ RemoteAddress: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa",
+ }
+ for _, c := range cases {
+ t.Run(c.name, func(t *testing.T) {
+ o := testObject{t: t}
+ proto := ipv6.NewProtocol()
+ ep, err := proto.NewEndpoint(1, localIpv6Addr, nil, &o, nil)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ defer ep.Close()
+
+ dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize + 4
+ if c.fragmentOffset != nil {
+ dataOffset += header.IPv6FragmentHeaderSize
+ }
+ view := buffer.NewView(dataOffset + 8)
+
+ // Create the outer IPv6 header.
+ ip := header.IPv6(view)
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(len(view) - header.IPv6MinimumSize - c.trunc),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: 20,
+ SrcAddr: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa",
+ DstAddr: localIpv6Addr,
+ })
+
+ // Create the ICMP header.
+ icmp := header.ICMPv6(view[header.IPv6MinimumSize:])
+ icmp.SetType(c.typ)
+ icmp.SetCode(c.code)
+ copy(view[header.IPv6MinimumSize+header.ICMPv6MinimumSize:], []byte{0xde, 0xad, 0xbe, 0xef})
+
+ // Create the inner IPv6 header.
+ ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6MinimumSize+4:])
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: 100,
+ NextHeader: 10,
+ HopLimit: 20,
+ SrcAddr: localIpv6Addr,
+ DstAddr: remoteIpv6Addr,
+ })
+
+ // Build the fragmentation header if needed.
+ if c.fragmentOffset != nil {
+ ip.SetNextHeader(header.IPv6FragmentHeader)
+ frag := header.IPv6Fragment(view[2*header.IPv6MinimumSize+header.ICMPv6MinimumSize+4:])
+ frag.Encode(&header.IPv6FragmentFields{
+ NextHeader: 10,
+ FragmentOffset: *c.fragmentOffset,
+ M: true,
+ Identification: 0x12345678,
+ })
+ }
+
+ // Make payload be non-zero.
+ for i := dataOffset; i < len(view); i++ {
+ view[i] = uint8(i)
+ }
+
+ // Give packet to IPv6 endpoint, dispatcher will validate that
+ // it's ok.
+ o.protocol = 10
+ o.srcAddr = remoteIpv6Addr
+ o.dstAddr = localIpv6Addr
+ o.contents = view[dataOffset:]
+ o.typ = c.expectedTyp
+ o.extra = c.expectedExtra
+
+ vv := view[:len(view)-c.trunc].ToVectorisedView()
+ ep.HandlePacket(&r, vv)
+ if want := c.expectedCount; o.controlCalls != want {
+ t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want)
+ }
+ })
+ }
}
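The ICMPv6 control test above builds an error message by hand. It contains, in order:

- an outer IPv6 header;
- the ICMPv6 type/code/checksum header;
- a 4-byte field, filled with 0xdeadbeef;
- the embedded IPv6 header of the "original" packet;
- an optional fragment header;
- 8 bytes of payload.

Below is a minimal sketch of that offset arithmetic. It assumes the RFC 8200/RFC 4443 sizes: a 40-byte IPv6 header and an 8-byte fragment header. It also assumes `header.ICMPv6MinimumSize` is 4, which is consistent with the test placing the inner header at `IPv6MinimumSize+ICMPv6MinimumSize+4`. `packetLayout` and `layout` are hypothetical names.

```go
package main

import "fmt"

// Sizes from the IPv6 (RFC 8200) and ICMPv6 (RFC 4443) specifications,
// standing in for header.IPv6MinimumSize, header.ICMPv6MinimumSize (assumed
// to be 4) and header.IPv6FragmentHeaderSize.
const (
	ipv6HeaderSize     = 40 // fixed IPv6 header
	icmpv6HeaderSize   = 4  // type, code, checksum
	icmpv6ErrorField   = 4  // MTU (Packet Too Big) or unused field
	fragmentHeaderSize = 8  // IPv6 fragment extension header
	payloadSize        = 8  // payload bytes the test appends
)

// layout records where each part of the hand-built ICMPv6 error starts.
type layout struct {
	innerIPOffset   int // start of the embedded (original) IPv6 header
	dataOffset      int // start of the original packet's payload
	total           int // total packet length
	outerPayloadLen int // value for the outer header's PayloadLength field
}

// packetLayout mirrors the test's arithmetic; trunc is the number of bytes
// chopped off the end of the packet before it is handed to the endpoint.
func packetLayout(withFragment bool, trunc int) layout {
	l := layout{innerIPOffset: ipv6HeaderSize + icmpv6HeaderSize + icmpv6ErrorField}
	l.dataOffset = l.innerIPOffset + ipv6HeaderSize
	if withFragment {
		l.dataOffset += fragmentHeaderSize
	}
	l.total = l.dataOffset + payloadSize
	l.outerPayloadLen = l.total - ipv6HeaderSize - trunc
	return l
}

func main() {
	fmt.Printf("%+v\n", packetLayout(false, 0))
	fmt.Printf("%+v\n", packetLayout(true, 0))
}
```

Without a fragment header, the payload starts at byte 88, matching the test's `dataOffset`. The fragment header pushes it to 96.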
diff --git a/tcpip/network/ipv4/icmp.go b/tcpip/network/ipv4/icmp.go
index 29e7d89..16947b9 100644
--- a/tcpip/network/ipv4/icmp.go
+++ b/tcpip/network/ipv4/icmp.go
@@ -1,32 +1,61 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package ipv4
import (
- "context"
"encoding/binary"
- "sync"
- "time"
"github.com/google/netstack/tcpip"
"github.com/google/netstack/tcpip/buffer"
"github.com/google/netstack/tcpip/header"
"github.com/google/netstack/tcpip/stack"
- "github.com/google/netstack/waiter"
)
-// PingProtocolName is a pseudo transport protocol used to handle ping replies.
-// Use it when constructing a stack that intends to use ipv4.Ping.
-const PingProtocolName = "icmpv4ping"
+// handleControl handles the case when an ICMP packet contains the headers of
+// the original packet that caused the ICMP one to be sent. This information is
+// used to find out which transport endpoint must be notified about the ICMP
+// packet.
+func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
+ h := header.IPv4(vv.First())
-// PingProtocolNumber is a transport protocol used to
-// transmit and deliver ICMP messages. The ICMP identifier
-// number is used as a port number for multiplexing.
-const PingProtocolNumber tcpip.TransportProtocolNumber = 1
+ // We don't use IsValid() here because ICMP only requires that the IP
+ // header plus 8 bytes of the transport header be included. So it's
+ // likely that it is truncated, which would cause IsValid to return
+ // false.
+ //
+ // Drop packet if it doesn't have the basic IPv4 header or if the
+ // original source address doesn't match the endpoint's address.
+ if len(h) < header.IPv4MinimumSize || h.SourceAddress() != e.id.LocalAddress {
+ return
+ }
-func (e *endpoint) handleICMP(r *stack.Route, vv *buffer.VectorisedView) {
+ hlen := int(h.HeaderLength())
+ if vv.Size() < hlen || h.FragmentOffset() != 0 {
+ // We won't be able to handle this if it doesn't contain the
+ // full IPv4 header, or if it's a fragment not at offset 0
+ // (because it won't have the transport header).
+ return
+ }
+
+ // Skip the ip header, then deliver control message.
+ vv.TrimFront(hlen)
+ p := h.TransportProtocol()
+ e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv)
+}
+
+func (e *endpoint) handleICMP(r *stack.Route, vv buffer.VectorisedView) {
v := vv.First()
if len(v) < header.ICMPv4MinimumSize {
return
@@ -45,10 +74,28 @@
default:
req.r.Release()
}
- case header.ICMPv4EchoReply, header.ICMPv4InfoReply, header.ICMPv4TimestampReply:
- e.dispatcher.DeliverTransportPacket(r, PingProtocolNumber, vv)
+
+ case header.ICMPv4EchoReply:
+ if len(v) < header.ICMPv4EchoMinimumSize {
+ return
+ }
+ e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, vv)
+
+ case header.ICMPv4DstUnreachable:
+ if len(v) < header.ICMPv4DstUnreachableMinimumSize {
+ return
+ }
+ vv.TrimFront(header.ICMPv4DstUnreachableMinimumSize)
+ switch h.Code() {
+ case header.ICMPv4PortUnreachable:
+ e.handleControl(stack.ControlPortUnreachable, 0, vv)
+
+ case header.ICMPv4FragmentationNeeded:
+ mtu := uint32(binary.BigEndian.Uint16(v[header.ICMPv4DstUnreachableMinimumSize-2:]))
+ e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv)
+ }
}
- // TODO(crawshaw): Handle other ICMP types.
+ // TODO: Handle other ICMP types.
}
type echoRequest struct {
@@ -58,366 +105,20 @@
func (e *endpoint) echoReplier() {
for req := range e.echoRequests {
- sendICMPv4(&req.r, header.ICMPv4EchoReply, 0, req.v)
+ sendPing4(&req.r, 0, req.v)
req.r.Release()
}
}
-func sendICMPv4(r *stack.Route, typ header.ICMPv4Type, code byte, data buffer.View) *tcpip.Error {
- hdr := buffer.NewPrependable(header.ICMPv4MinimumSize + int(r.MaxHeaderLength()))
+func sendPing4(r *stack.Route, code byte, data buffer.View) *tcpip.Error {
+ hdr := buffer.NewPrependable(header.ICMPv4EchoMinimumSize + int(r.MaxHeaderLength()))
- icmpv4 := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
- icmpv4.SetType(typ)
+ icmpv4 := header.ICMPv4(hdr.Prepend(header.ICMPv4EchoMinimumSize))
+ icmpv4.SetType(header.ICMPv4EchoReply)
icmpv4.SetCode(code)
+ copy(icmpv4[header.ICMPv4MinimumSize:], data)
+ data = data[header.ICMPv4EchoMinimumSize-header.ICMPv4MinimumSize:]
icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0)))
- return r.WritePacket(&hdr, data, header.ICMPv4ProtocolNumber, r.DefaultTTL())
-}
-
-// A Pinger can send echo requests to an address.
-type Pinger struct {
- Stack *stack.Stack
- NICID tcpip.NICID
- Addr tcpip.Address
- LocalAddr tcpip.Address // optional
- Wait time.Duration // if zero, defaults to 1 second
- Count uint16 // if zero, defaults to MaxUint16
-}
-
-type pingerEndpoint struct {
- stack *stack.Stack
- pktCh chan buffer.View
-}
-
-func (e *pingerEndpoint) Close() {
- close(e.pktCh)
-}
-
-func (e *pingerEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv *buffer.VectorisedView) {
- select {
- case e.pktCh <- vv.ToView():
- default:
- }
-}
-
-// Ping sends echo requests to an ICMPv4 endpoint.
-// Responses are streamed to the channel ch.
-func (p *Pinger) Ping(ctx context.Context, ch chan<- PingReply) *tcpip.Error {
- count := p.Count
- if count == 0 {
- count = 1<<16 - 1
- }
- wait := p.Wait
- if wait == 0 {
- wait = 1 * time.Second
- }
-
- r, err := p.Stack.FindRoute(p.NICID, p.LocalAddr, p.Addr, ProtocolNumber)
- if err != nil {
- return err
- }
-
- netProtos := []tcpip.NetworkProtocolNumber{ProtocolNumber}
- ep := &pingerEndpoint{
- stack: p.Stack,
- pktCh: make(chan buffer.View, 1),
- }
- id := stack.TransportEndpointID{
- LocalAddress: r.LocalAddress,
- RemoteAddress: p.Addr,
- }
-
- _, err = p.Stack.PickEphemeralPort(func(port uint16) (bool, *tcpip.Error) {
- id.LocalPort = port
- err := p.Stack.RegisterTransportEndpoint(p.NICID, netProtos, PingProtocolNumber, id, ep)
- switch err {
- case nil:
- return true, nil
- case tcpip.ErrPortInUse:
- return false, nil
- default:
- return false, err
- }
- })
- if err != nil {
- return err
- }
- defer p.Stack.UnregisterTransportEndpoint(p.NICID, netProtos, PingProtocolNumber, id)
-
- v := buffer.NewView(4)
- binary.BigEndian.PutUint16(v[0:], id.LocalPort)
-
- start := time.Now()
-
- done := make(chan struct{})
- go func(count int) {
- loop:
- for ; count > 0; count-- {
- select {
- case v := <-ep.pktCh:
- seq := binary.BigEndian.Uint16(v[header.ICMPv4MinimumSize+2:])
- ch <- PingReply{
- Duration: time.Since(start) - time.Duration(seq)*wait,
- SeqNumber: seq,
- }
- case <-ctx.Done():
- break loop
- }
- }
- close(done)
- }(int(count))
- defer func() { <-done }()
-
- t := time.NewTicker(wait)
- defer t.Stop()
- for seq := uint16(0); seq < count; seq++ {
- select {
- case <-t.C:
- case <-ctx.Done():
- return nil
- }
- binary.BigEndian.PutUint16(v[2:], seq)
- sent := time.Now()
- if err := sendICMPv4(&r, header.ICMPv4Echo, 0, v); err != nil {
- ch <- PingReply{
- Error: err,
- Duration: time.Since(sent),
- SeqNumber: seq,
- }
- }
- }
- return nil
-}
-
-// PingReply summarizes an ICMP echo reply.
-type PingReply struct {
- Error *tcpip.Error // reports any errors sending a ping request
- Duration time.Duration
- SeqNumber uint16
-}
-
-type endpointState int
-
-const (
- stateInitial endpointState = iota
- stateConnected
- stateClosed
-)
-
-type pingEndpoint struct {
- stack *stack.Stack
- netProto tcpip.NetworkProtocolNumber
- waiterQueue *waiter.Queue
-
- mu sync.RWMutex
- pktCh chan buffer.View
- state endpointState
- route stack.Route
- nic tcpip.NICID
- id stack.TransportEndpointID
-}
-
-func (e *pingEndpoint) Close() {
- e.mu.Lock()
- defer e.mu.Unlock()
- if e.state == stateClosed {
- return
- }
- if e.state == stateConnected {
- netProtos := []tcpip.NetworkProtocolNumber{ProtocolNumber}
- e.stack.UnregisterTransportEndpoint(e.nic, netProtos, PingProtocolNumber, e.id)
- e.route.Release()
- }
- close(e.pktCh)
- e.state = stateClosed
-}
-
-func (e *pingEndpoint) Read(a *tcpip.FullAddress) (buffer.View, *tcpip.Error) {
- select {
- case v := <-e.pktCh:
- return v, nil
- default:
- return buffer.View{}, tcpip.ErrWouldBlock
- }
-}
-
-func (e *pingEndpoint) Write(v buffer.View, to *tcpip.FullAddress) (uintptr, *tcpip.Error) {
- e.mu.Lock()
- defer e.mu.Unlock()
- switch state := e.state; state {
- case stateInitial:
- if to == nil {
- return 0, tcpip.ErrNotSupported
- } else if err := e.bindLocked(*to, nil); err != nil {
- return 0, err
- }
- case stateConnected:
- if to != nil {
- prev := tcpip.FullAddress{
- NIC: e.nic,
- Addr: e.id.RemoteAddress,
- Port: e.id.RemotePort,
- }
-
- if prev != *to {
- return 0, tcpip.ErrAlreadyConnected
- }
- }
- default:
- return 0, tcpip.ErrClosedForSend
- }
-
- if len(v) < header.ICMPv4MinimumSize {
- return 0, tcpip.ErrNotSupported
- }
-
- hdr := buffer.NewPrependable(header.ICMPv4MinimumSize + int(e.route.MaxHeaderLength()))
- icmpv4 := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
- copy(icmpv4, v[:header.ICMPv4MinimumSize])
- icmpv4.SetCode(0)
- data := v[header.ICMPv4MinimumSize:]
- // Overwrite the ID with the port number
- binary.BigEndian.PutUint16(data[0:], e.id.LocalPort)
- // Overwrite the checksum of the packet
- icmpv4.SetChecksum(0)
- chksum := header.ICMPv4(data).CalculateChecksum(icmpv4.CalculateChecksum(0))
- icmpv4.SetChecksum(^chksum)
-
- if err := e.route.WritePacket(&hdr, data, header.ICMPv4ProtocolNumber, e.route.DefaultTTL()); err != nil {
- return 0, err
- }
- return uintptr(len(v)), nil
-}
-
-func (e *pingEndpoint) Peek(data [][]byte) (uintptr, *tcpip.Error) {
- return 0, tcpip.ErrNotSupported
-}
-
-// SetOption implements TransportProtocol.SetOption.
-func (p *pingProtocol) SetOption(option interface{}) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
-}
-
-func init() {
- stack.RegisterTransportProtocolFactory(PingProtocolName, func() stack.TransportProtocol {
- return &pingProtocol{}
- })
-}
-
-func (e *pingEndpoint) Connect(address tcpip.FullAddress) *tcpip.Error {
- return tcpip.ErrNotSupported
-}
-
-func (e *pingEndpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
- e.Close()
- return nil
-}
-
-func (e *pingEndpoint) Listen(backlog int) *tcpip.Error {
- return tcpip.ErrNotSupported
-}
-
-func (e *pingEndpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
- return nil, nil, tcpip.ErrNotSupported
-}
-
-func (e *pingEndpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error {
- e.mu.Lock()
- defer e.mu.Unlock()
- return e.bindLocked(addr, commit)
-}
-
-func (e *pingEndpoint) bindLocked(to tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error {
- if e.state != stateInitial {
- return tcpip.ErrAlreadyConnected
- }
- r, err := e.stack.FindRoute(to.NIC, "", to.Addr, e.netProto)
- if err != nil {
- return err
- }
-
- netProtos := []tcpip.NetworkProtocolNumber{e.netProto}
- id := stack.TransportEndpointID{
- LocalAddress: r.LocalAddress,
- RemoteAddress: to.Addr,
- }
-
- _, err = e.stack.PickEphemeralPort(func(port uint16) (bool, *tcpip.Error) {
- id.LocalPort = port
- err := e.stack.RegisterTransportEndpoint(to.NIC, netProtos, PingProtocolNumber, id, e)
- switch err {
- case nil:
- return true, nil
- case tcpip.ErrPortInUse:
- return false, nil
- default:
- return false, err
- }
- })
-
- if commit != nil {
- if err := commit(); err != nil {
- e.stack.UnregisterTransportEndpoint(to.NIC, netProtos, PingProtocolNumber, id)
- r.Release()
- return err
- }
- }
-
- e.state = stateConnected
- e.route = r
- e.nic = to.NIC
- e.id = id
- return nil
-}
-
-func (e *pingEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
- return tcpip.FullAddress{}, tcpip.ErrNotSupported
-}
-
-func (e *pingEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
- return tcpip.FullAddress{}, tcpip.ErrNotSupported
-}
-
-func (e *pingEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
- return 0
-}
-
-func (e *pingEndpoint) SetSockOpt(opt interface{}) *tcpip.Error {
- return tcpip.ErrNotSupported
-}
-
-func (e *pingEndpoint) GetSockOpt(opt interface{}) *tcpip.Error {
- return tcpip.ErrNotSupported
-}
-
-func (e *pingEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv *buffer.VectorisedView) {
- select {
- case e.pktCh <- vv.ToView():
- e.waiterQueue.Notify(waiter.EventIn)
- default:
- }
-}
-
-type pingProtocol struct{}
-
-func (p *pingProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
- return &pingEndpoint{
- stack: stack,
- netProto: netProto,
- waiterQueue: waiterQueue,
- pktCh: make(chan buffer.View, 10),
- }, nil
-}
-
-func (*pingProtocol) Number() tcpip.TransportProtocolNumber { return PingProtocolNumber }
-
-func (*pingProtocol) MinimumPacketSize() int { return header.ICMPv4EchoMinimumSize }
-
-func (*pingProtocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
- ident := binary.BigEndian.Uint16(v[4:])
- return 0, ident, nil
-}
-
-func (*pingProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *buffer.VectorisedView) bool {
- return true
+ return r.WritePacket(hdr, data.ToVectorisedView(), header.ICMPv4ProtocolNumber, r.DefaultTTL())
}
diff --git a/tcpip/network/ipv4/icmp_test.go b/tcpip/network/ipv4/icmp_test.go
deleted file mode 100644
index 3a04f60..0000000
--- a/tcpip/network/ipv4/icmp_test.go
+++ /dev/null
@@ -1,336 +0,0 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package ipv4_test
-
-import (
- "context"
- "encoding/binary"
- "testing"
- "time"
-
- "github.com/google/netstack/tcpip"
- "github.com/google/netstack/tcpip/buffer"
- "github.com/google/netstack/tcpip/header"
- "github.com/google/netstack/tcpip/link/channel"
- "github.com/google/netstack/tcpip/link/sniffer"
- "github.com/google/netstack/tcpip/network/ipv4"
- "github.com/google/netstack/tcpip/stack"
- "github.com/google/netstack/waiter"
-)
-
-const stackAddr = "\x0a\x00\x00\x01"
-
-type testContext struct {
- t *testing.T
- linkEP *channel.Endpoint
- s *stack.Stack
-}
-
-func newTestContext(t *testing.T) *testContext {
- s := stack.New([]string{ipv4.ProtocolName}, []string{ipv4.PingProtocolName})
-
- const defaultMTU = 65536
- id, linkEP := channel.New(256, defaultMTU, "")
- if testing.Verbose() {
- id = sniffer.New(id)
- }
- if err := s.CreateNIC(1, id); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
- }
-
- if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
- }
-
- s.SetRouteTable([]tcpip.Route{{
- Destination: "\x00\x00\x00\x00",
- Mask: "\x00\x00\x00\x00",
- Gateway: "",
- NIC: 1,
- }})
-
- return &testContext{
- t: t,
- s: s,
- linkEP: linkEP,
- }
-}
-
-func (c *testContext) cleanup() {
- close(c.linkEP.C)
-}
-
-func (c *testContext) loopback() {
- go func() {
- for pkt := range c.linkEP.C {
- v := make(buffer.View, len(pkt.Header)+len(pkt.Payload))
- copy(v, pkt.Header)
- copy(v[len(pkt.Header):], pkt.Payload)
- vv := v.ToVectorisedView([1]buffer.View{})
- c.linkEP.Inject(pkt.Proto, &vv)
- }
- }()
-}
-
-func TestEcho(t *testing.T) {
- c := newTestContext(t)
- defer c.cleanup()
- c.loopback()
-
- ch := make(chan ipv4.PingReply, 1)
- p := ipv4.Pinger{
- Stack: c.s,
- NICID: 1,
- Addr: stackAddr,
- Wait: 10 * time.Millisecond,
- Count: 1, // one ping only
- }
- if err := p.Ping(context.Background(), ch); err != nil {
- t.Fatalf("icmp.Ping failed: %v", err)
- }
-
- ping := <-ch
- if ping.Error != nil {
- t.Errorf("bad ping response: %v", ping.Error)
- }
-}
-
-func TestEchoSequence(t *testing.T) {
- c := newTestContext(t)
- defer c.cleanup()
- c.loopback()
-
- const numPings = 3
- ch := make(chan ipv4.PingReply, numPings)
- p := ipv4.Pinger{
- Stack: c.s,
- NICID: 1,
- Addr: stackAddr,
- Wait: 10 * time.Millisecond,
- Count: numPings,
- }
- if err := p.Ping(context.Background(), ch); err != nil {
- t.Fatalf("icmp.Ping failed: %v", err)
- }
-
- for i := uint16(0); i < numPings; i++ {
- ping := <-ch
- if ping.Error != nil {
- t.Errorf("i=%d bad ping response: %v", i, ping.Error)
- }
- if ping.SeqNumber != i {
- t.Errorf("SeqNumber=%d, want %d", ping.SeqNumber, i)
- }
- }
-}
-
-const (
- stackAddr0 = "\x0a\x00\x00\x02"
- stackAddr1 = "\x0a\x00\x00\x03"
- linkAddr0 = tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06")
- linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f")
-)
-
-type testEndpointContext struct {
- t *testing.T
- s *stack.Stack
-
- linkEP0 *channel.Endpoint
- linkEP1 *channel.Endpoint
-
- icmpCh chan header.ICMPv4
-}
-
-func (c *testEndpointContext) cleanup() {
- close(c.linkEP0.C)
- close(c.linkEP1.C)
-}
-
-func newTestEndpointContext(t *testing.T) *testEndpointContext {
- c := &testEndpointContext{
- t: t,
- s: stack.New([]string{ipv4.ProtocolName}, []string{ipv4.PingProtocolName}),
- icmpCh: make(chan header.ICMPv4, 10),
- }
-
- const defaultMTU = 65536
- id0, linkEP := channel.New(256, defaultMTU, linkAddr0)
- c.linkEP0 = linkEP
- if testing.Verbose() {
- id0 = sniffer.New(id0)
- }
- if err := c.s.CreateNIC(1, id0); err != nil {
- t.Fatalf("CreateNIC s: %v", err)
- }
- id1, linkEP := channel.New(256, defaultMTU, linkAddr1)
- c.linkEP1 = linkEP
- if testing.Verbose() {
- id1 = sniffer.New(id1)
- }
- if err := c.s.CreateNIC(2, id1); err != nil {
- t.Fatalf("CreateNIC s: %v", err)
- }
- if err := c.s.AddAddress(2, ipv4.ProtocolNumber, stackAddr0); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
- }
- if err := c.s.AddAddress(1, ipv4.ProtocolNumber, stackAddr1); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
- }
- c.s.SetRouteTable([]tcpip.Route{
- {
- Destination: stackAddr0,
- Mask: "\xFF\xFF\xFF\xFF",
- Gateway: "",
- NIC: 1,
- },
- {
- Destination: stackAddr1,
- Mask: "\xFF\xFF\xFF\xFF",
- Gateway: "",
- NIC: 2,
- },
- })
-
- go c.routePackets(c.linkEP0.C, c.linkEP1)
- go c.routePackets(c.linkEP1.C, c.linkEP0)
- return c
-}
-
-func (c *testEndpointContext) countPacket(pkt channel.PacketInfo) {
- if pkt.Proto != header.IPv4ProtocolNumber {
- c.t.Fatalf("Received non IPV4 packet: 0x%x", pkt.Proto)
- }
- ipv4 := header.IPv4(pkt.Header)
- c.icmpCh <- header.ICMPv4(append(pkt.Header[ipv4.HeaderLength():], pkt.Payload...))
-}
-
-func (c *testEndpointContext) routePackets(ch <-chan channel.PacketInfo, ep *channel.Endpoint) {
- for pkt := range ch {
- c.countPacket(pkt)
- v := buffer.View(append(pkt.Header, pkt.Payload...))
- vs := []buffer.View{v}
- vv := buffer.NewVectorisedView(len(v), vs)
- ep.InjectLinkAddr(pkt.Proto, ep.LinkAddress(), &vv)
- }
-}
-
-type callbackStub struct {
- f func(e *waiter.Entry)
-}
-
-func (c *callbackStub) Callback(e *waiter.Entry) {
- c.f(e)
-}
-
-func TestEndpoints(t *testing.T) {
- c := newTestEndpointContext(t)
- defer c.cleanup()
-
- wq0 := &waiter.Queue{}
- ep0, err := c.s.NewEndpoint(ipv4.PingProtocolNumber, ipv4.ProtocolNumber, wq0)
- if err != nil {
- c.t.Fatalf("NewEndpoint failed: %v", err)
- }
- defer ep0.Close()
- wq1 := &waiter.Queue{}
- ep1, err := c.s.NewEndpoint(ipv4.PingProtocolNumber, ipv4.ProtocolNumber, wq1)
- if err != nil {
- c.t.Fatalf("NewEndpoint failed: %v", err)
- }
- defer ep1.Close()
-
- if err := ep0.Bind(tcpip.FullAddress{NIC: 1, Addr: stackAddr0}, nil); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
- }
- if err := ep1.Bind(tcpip.FullAddress{NIC: 2, Addr: stackAddr1}, nil); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
- }
-
- echos := 64
-
- ping := func(wq *waiter.Queue, ep tcpip.Endpoint, data []byte) {
- outPkt := make([]byte, header.ICMPv4MinimumSize+4+len(data))
- icmpv4 := header.ICMPv4(outPkt[:header.ICMPv4MinimumSize])
- icmpv4.SetType(header.ICMPv4Echo)
- copy(outPkt[header.ICMPv4MinimumSize+4:], data)
- binary.BigEndian.PutUint16(outPkt[header.ICMPv4MinimumSize:], 0)
-
- for seqno := uint16(1); seqno <= uint16(echos); seqno++ {
- binary.BigEndian.PutUint16(outPkt[header.ICMPv4MinimumSize+2:], seqno)
-
- // We need to register with the waiter queue before we try writing, since
- // the notification that the endpoint received a response may arrive immediately.
- ready := make(chan struct{})
- e := waiter.Entry{Callback: &callbackStub{func(*waiter.Entry) { close(ready) }}}
- wq.EventRegister(&e, waiter.EventIn)
- n, err := ep.Write(buffer.View(outPkt), nil)
- if err != nil {
- c.t.Fatalf("Write failed: %v\n", err)
- } else if n != uintptr(len(outPkt)) {
- c.t.Fatalf("Write was short: %v\n", n)
- }
-
- // Avoid reading until we have something to read
- select {
- case <-time.After(1 * time.Second):
- c.t.Fatalf("Timed out waiting for socket to be readable")
- case <-ready:
- }
- wq.EventUnregister(&e)
- inPkt, err := ep.Read(nil)
- if err != nil {
- c.t.Fatalf("Read failed: %v\n", err)
- }
-
- // Verify the contents of the packet we just read.
- var icmp header.ICMPv4 = []byte(inPkt)
- if icmp.Type() != header.ICMPv4EchoReply {
- c.t.Fatalf("Unexpected packet type: %d", icmp.Type())
- }
- inSeqno := binary.BigEndian.Uint16(inPkt[header.ICMPv4MinimumSize+2 : header.ICMPv4MinimumSize+4])
- if inSeqno != seqno {
- c.t.Fatalf("Unexpected sequence number: %d", inSeqno)
- }
- outData := outPkt[header.ICMPv4EchoMinimumSize:]
- inData := inPkt[header.ICMPv4EchoMinimumSize:]
- if len(outData) != len(inData) {
- c.t.Fatalf("Read packet of unexpected length: %d\n", len(inData))
- }
- for i := range inData {
- if inData[i] != outData[i] {
- c.t.Fatalf("Data mismatch")
- }
- }
- }
- }
-
- data := []byte{0xaa, 0xab, 0xac}
- go ping(wq0, ep0, data)
- data = []byte{0xad, 0xae, 0xaf}
- go ping(wq1, ep1, data)
-
- ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
- defer cancel()
-
- stats := make(map[header.ICMPv4Type]int)
- for {
- select {
- case <-ctx.Done():
- t.Errorf("Timeout waiting for ICMP, got: %#+v", stats)
- return
- case icmp := <-c.icmpCh:
- if icmp.Type() != header.ICMPv4Echo && icmp.Type() != header.ICMPv4EchoReply {
- c.t.Fatalf("Unexpected type: %d", icmp.Type())
- }
- stats[icmp.Type()]++
- if stats[icmp.Type()] > echos*2 {
- c.t.Fatalf("Too many (%d) packets of type %d", stats[icmp.Type()], icmp.Type())
- }
- if len(stats) == 2 && stats[header.ICMPv4Echo] == echos*2 && stats[header.ICMPv4EchoReply] == echos*2 {
- return
- }
- }
- }
-}
diff --git a/tcpip/network/ipv4/ipv4.go b/tcpip/network/ipv4/ipv4.go
index 96673e5..1c329fa 100644
--- a/tcpip/network/ipv4/ipv4.go
+++ b/tcpip/network/ipv4/ipv4.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
// Package ipv4 contains the implementation of the ipv4 network protocol. To use
// it in the networking stack, this package must be added to the project, and
@@ -45,7 +55,7 @@
linkEP stack.LinkEndpoint
dispatcher stack.TransportDispatcher
echoRequests chan echoRequest
- fragmentation fragmentation.Fragmentation
+ fragmentation *fragmentation.Fragmentation
}
func newEndpoint(nicid tcpip.NICID, addr tcpip.Address, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) *endpoint {
@@ -54,28 +64,30 @@
linkEP: linkEP,
dispatcher: dispatcher,
echoRequests: make(chan echoRequest, 10),
- fragmentation: fragmentation.NewFragmentation(fragmentation.MemoryLimit, fragmentation.DefaultReassembleTimeout),
+ fragmentation: fragmentation.NewFragmentation(fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout),
}
copy(e.address[:], addr)
- e.id = stack.NetworkEndpointID{tcpip.Address(e.address[:])}
+ e.id = stack.NetworkEndpointID{LocalAddress: tcpip.Address(e.address[:])}
go e.echoReplier()
return e
}
+// DefaultTTL is the default time-to-live value for this endpoint.
func (e *endpoint) DefaultTTL() uint8 {
- return header.IPv4DefaultTTL
+ return 255
}
// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
// the network layer max header length.
func (e *endpoint) MTU() uint32 {
- lmtu := e.linkEP.MTU()
- if lmtu > maxTotalSize {
- lmtu = maxTotalSize
- }
- return lmtu - uint32(e.MaxHeaderLength())
+ return calculateMTU(e.linkEP.MTU())
+}
+
+// Capabilities implements stack.NetworkEndpoint.Capabilities.
+func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.linkEP.Capabilities()
}
// NICID returns the ID of the NIC this endpoint belongs to.
@@ -95,9 +107,9 @@
}
// WritePacket writes a packet to the given destination address and protocol.
-func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error {
+func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error {
ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
- length := uint16(hdr.UsedLength() + len(payload))
+ length := uint16(hdr.UsedLength() + payload.Size())
id := uint32(0)
if length > header.IPv4MaximumHeaderSize+8 {
// Packets of 68 bytes or less are required by RFC 791 to not be
@@ -114,14 +126,14 @@
DstAddr: r.RemoteAddress,
})
ip.SetChecksum(^ip.CalculateChecksum())
- atomic.AddUint64(&r.MutableStats().IP.PacketsSent, 1)
+ r.Stats().IP.PacketsSent.Increment()
return e.linkEP.WritePacket(r, hdr, payload, ProtocolNumber)
}
// HandlePacket is called by the link layer when new ipv4 packets arrive for
// this endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, vv *buffer.VectorisedView) {
+func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
h := header.IPv4(vv.First())
if !h.IsValid(vv.Size()) {
return
@@ -136,18 +148,18 @@
if more || h.FragmentOffset() != 0 {
// The packet is a fragment, let's try to reassemble it.
last := h.FragmentOffset() + uint16(vv.Size()) - 1
- tt, ready := e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, vv)
+ var ready bool
+ vv, ready = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, vv)
if !ready {
return
}
- vv = &tt
}
- p := tcpip.TransportProtocolNumber(h.Protocol())
+ p := h.TransportProtocol()
if p == header.ICMPv4ProtocolNumber {
e.handleICMP(r, vv)
return
}
- atomic.AddUint64(&r.MutableStats().IP.PacketsDelivered, 1)
+ r.Stats().IP.PacketsDelivered.Increment()
e.dispatcher.DeliverTransportPacket(r, p, vv)
}
@@ -192,6 +204,20 @@
return tcpip.ErrUnknownProtocolOption
}
+// Option implements NetworkProtocol.Option.
+func (p *protocol) Option(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+// calculateMTU calculates the network-layer payload MTU from the link-layer
+// payload MTU.
+func calculateMTU(mtu uint32) uint32 {
+ if mtu > maxTotalSize {
+ mtu = maxTotalSize
+ }
+ return mtu - header.IPv4MinimumSize
+}
+
// hashRoute calculates a hash value for the given route. It uses the source &
// destination address, the transport protocol number, and a random initial
// value (generated once on initialization) to generate the hash.
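`calculateMTU` is now shared by `MTU()` and the Fragmentation Needed handler in `handleICMP`. Below is a runnable copy with the package constants inlined. This sketch assumes `maxTotalSize` is 0xffff (the largest Total Length a 16-bit field can carry) and `header.IPv4MinimumSize` is 20.

```go
package main

import "fmt"

const (
	// maxTotalSize is assumed to be 0xffff, the largest value the 16-bit
	// IPv4 Total Length field can encode.
	maxTotalSize      = 0xffff
	ipv4MinHeaderSize = 20 // stands in for header.IPv4MinimumSize
)

// calculateMTU reproduces the helper from the diff: clamp the link MTU to
// what IPv4 can express, then subtract the option-less IPv4 header.
func calculateMTU(mtu uint32) uint32 {
	if mtu > maxTotalSize {
		mtu = maxTotalSize
	}
	return mtu - ipv4MinHeaderSize
}

func main() {
	for _, link := range []uint32{576, 1500, 65536} {
		fmt.Printf("link MTU %d -> IPv4 payload MTU %d\n", link, calculateMTU(link))
	}
}
```

A 1500-byte Ethernet MTU yields 1480 bytes for the network-layer payload, and anything above 65535 is clamped first. Like the original, the function assumes the link MTU is at least 20 bytes. A smaller value would wrap around because the arithmetic is unsigned.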
diff --git a/tcpip/network/ipv4/ipv4_test.go b/tcpip/network/ipv4/ipv4_test.go
index c827ecf..33ea8a8 100644
--- a/tcpip/network/ipv4/ipv4_test.go
+++ b/tcpip/network/ipv4/ipv4_test.go
@@ -1,6 +1,16 @@
-// Copyright 2017 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package ipv4_test
@@ -18,7 +28,7 @@
)
func TestExcludeBroadcast(t *testing.T) {
- s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName})
+ s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName}, stack.Options{})
const defaultMTU = 65536
id, _ := channel.New(256, defaultMTU, "")
@@ -43,40 +53,40 @@
NIC: 1,
}})
- var wq waiter.Queue
- ep, e := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
- if e != nil {
- t.Fatal(e)
- }
-
randomAddr := tcpip.FullAddress{NIC: 1, Addr: "\x0a\x00\x00\x01", Port: 53}
- // Cannot connect using a broadcast address as the source.
- e = ep.Connect(randomAddr)
- if e == nil {
- t.Error(`connect succeeded, expected "no route"`)
- } else if e != tcpip.ErrNoRoute {
- t.Error(`connect failed with %v, expected "no route"`, e)
- }
+ var wq waiter.Queue
+ t.Run("WithoutPrimaryAddress", func(t *testing.T) {
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ep.Close()
- // However, we can bind to a broadcast address to listen.
- e = ep.Bind(tcpip.FullAddress{Addr: header.IPv4Broadcast, Port: 53, NIC: 1}, nil)
- if e != nil {
- t.Errorf("bind failed: %v", e)
- }
+ // Cannot connect using a broadcast address as the source.
+ if err := ep.Connect(randomAddr); err != tcpip.ErrNoRoute {
+ t.Errorf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrNoRoute)
+ }
- ep.Close()
- ep, e = s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
- if e != nil {
- t.Fatal(e)
- }
+ // However, we can bind to a broadcast address to listen.
+ if err := ep.Bind(tcpip.FullAddress{Addr: header.IPv4Broadcast, Port: 53, NIC: 1}, nil); err != nil {
+ t.Errorf("Bind failed: %v", err)
+ }
+ })
- // Add a valid primary endpoint address, now we can connect.
- if err := s.AddAddress(1, ipv4.ProtocolNumber, "\x0a\x00\x00\x02"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
- }
- e = ep.Connect(randomAddr)
- if e != nil {
- t.Errorf("connect failed: %v", e)
- }
+ t.Run("WithPrimaryAddress", func(t *testing.T) {
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ep.Close()
+
+ // Add a valid primary endpoint address, now we can connect.
+ if err := s.AddAddress(1, ipv4.ProtocolNumber, "\x0a\x00\x00\x02"); err != nil {
+ t.Fatalf("AddAddress failed: %v", err)
+ }
+ if err := ep.Connect(randomAddr); err != nil {
+ t.Errorf("Connect failed: %v", err)
+ }
+ })
}
diff --git a/tcpip/network/ipv6/icmp.go b/tcpip/network/ipv6/icmp.go
index dd3a348..3679df5 100644
--- a/tcpip/network/ipv6/icmp.go
+++ b/tcpip/network/ipv6/icmp.go
@@ -1,12 +1,21 @@
-// Copyright 2017 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package ipv6
import (
"encoding/binary"
- "log"
"github.com/google/netstack/tcpip"
"github.com/google/netstack/tcpip/buffer"
@@ -14,15 +23,46 @@
"github.com/google/netstack/tcpip/stack"
)
-const (
- ndpSolicitedFlag = 1 << 6
- ndpOverrideFlag = 1 << 5
+// handleControl handles the case when an ICMP packet contains the headers of
+// the original packet that caused the ICMP one to be sent. This information is
+// used to find out which transport endpoint must be notified about the ICMP
+// packet.
+func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
+ h := header.IPv6(vv.First())
- ndpOptSrcLinkAddr = 1
- ndpOptDstLinkAddr = 2
-)
+ // We don't use IsValid() here because ICMP only requires that up to
+ // 1280 bytes of the original packet be included. So it's likely that it
+ // is truncated, which would cause IsValid to return false.
+ //
+ // Drop packet if it doesn't have the basic IPv6 header or if the
+ // original source address doesn't match the endpoint's address.
+ if len(h) < header.IPv6MinimumSize || h.SourceAddress() != e.id.LocalAddress {
+ return
+ }
-func (e *endpoint) handleICMP(r *stack.Route, vv *buffer.VectorisedView) {
+ // Skip the IP header, then handle the fragmentation header if there
+ // is one.
+ vv.TrimFront(header.IPv6MinimumSize)
+ p := h.TransportProtocol()
+ if p == header.IPv6FragmentHeader {
+ f := header.IPv6Fragment(vv.First())
+ if !f.IsValid() || f.FragmentOffset() != 0 {
+ // We can't handle fragments that aren't at offset 0
+ // because they don't have the transport headers.
+ return
+ }
+
+ // Skip fragmentation header and find out the actual protocol
+ // number.
+ vv.TrimFront(header.IPv6FragmentHeaderSize)
+ p = f.TransportProtocol()
+ }
+
+ // Deliver the control packet to the transport endpoint.
+ e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv)
+}
+
+func (e *endpoint) handleICMP(r *stack.Route, vv buffer.VectorisedView) {
v := vv.First()
if len(v) < header.ICMPv6MinimumSize {
return
@@ -30,25 +70,43 @@
h := header.ICMPv6(v)
switch h.Type() {
+ case header.ICMPv6PacketTooBig:
+ if len(v) < header.ICMPv6PacketTooBigMinimumSize {
+ return
+ }
+ vv.TrimFront(header.ICMPv6PacketTooBigMinimumSize)
+ mtu := binary.BigEndian.Uint32(v[header.ICMPv6MinimumSize:])
+ e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv)
+
+ case header.ICMPv6DstUnreachable:
+ if len(v) < header.ICMPv6DstUnreachableMinimumSize {
+ return
+ }
+ vv.TrimFront(header.ICMPv6DstUnreachableMinimumSize)
+ switch h.Code() {
+ case header.ICMPv6PortUnreachable:
+ e.handleControl(stack.ControlPortUnreachable, 0, vv)
+ }
+
case header.ICMPv6NeighborSolicit:
if len(v) < header.ICMPv6NeighborSolicitMinimumSize {
return
}
targetAddr := tcpip.Address(v[8 : 8+16])
- if e.linkAddrCache.CheckLocalAddress(e.nicid, targetAddr) == 0 {
- return // we have no useful answer, ignore the request
+ if e.linkAddrCache.CheckLocalAddress(e.nicid, ProtocolNumber, targetAddr) == 0 {
+ // We don't have a useful answer; the best we can do is ignore the request.
+ return
}
hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize)
pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
pkt.SetType(header.ICMPv6NeighborAdvert)
- pkt[4] = ndpSolicitedFlag | ndpOverrideFlag
- copy(pkt[8:24], v[8:])
- pkt[24] = ndpOptDstLinkAddr
- pkt[25] = 1 // address length
- copy(pkt[26:], r.LocalLinkAddress[:])
- r.LocalAddress = targetAddr
- pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, nil))
- r.WritePacket(&hdr, nil, header.ICMPv6ProtocolNumber, r.DefaultTTL())
+ pkt[icmpV6FlagOffset] = ndpSolicitedFlag | ndpOverrideFlag
+ copy(pkt[icmpV6OptOffset-len(targetAddr):], targetAddr)
+ pkt[icmpV6OptOffset] = ndpOptDstLinkAddr
+ pkt[icmpV6LengthOffset] = 1
+ copy(pkt[icmpV6LengthOffset+1:], r.LocalLinkAddress[:])
+ pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+ r.WritePacket(hdr, buffer.VectorisedView{}, header.ICMPv6ProtocolNumber, r.DefaultTTL())
e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress)
@@ -56,34 +114,64 @@
if len(v) < header.ICMPv6NeighborAdvertSize {
return
}
- e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress)
+ targetAddr := tcpip.Address(v[8 : 8+16])
+ e.linkAddrCache.AddLinkAddress(e.nicid, targetAddr, r.RemoteLinkAddress)
+ if targetAddr != r.RemoteAddress {
+ e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress)
+ }
case header.ICMPv6EchoRequest:
if len(v) < header.ICMPv6EchoMinimumSize {
return
}
- data := v[header.ICMPv6EchoMinimumSize:]
- hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize)
+ vv.TrimFront(header.ICMPv6EchoMinimumSize)
+ hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6EchoMinimumSize)
pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize))
copy(pkt, h)
pkt.SetType(header.ICMPv6EchoReply)
- pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, data))
- r.WritePacket(&hdr, data, header.ICMPv6ProtocolNumber, r.DefaultTTL())
- default:
- log.Printf("got ICMPv6: type=%v, code=%v, len(v)=%d", h.Type(), h.Code(), len(v))
+ pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, vv))
+ r.WritePacket(hdr, vv, header.ICMPv6ProtocolNumber, r.DefaultTTL())
+
+ case header.ICMPv6EchoReply:
+ if len(v) < header.ICMPv6EchoMinimumSize {
+ return
+ }
+ e.dispatcher.DeliverTransportPacket(r, header.ICMPv6ProtocolNumber, vv)
+
}
- // TODO case header.ICMPv6EchoReply
+}
+
+const (
+ ndpSolicitedFlag = 1 << 6
+ ndpOverrideFlag = 1 << 5
+
+ ndpOptSrcLinkAddr = 1
+ ndpOptDstLinkAddr = 2
+
+ icmpV6FlagOffset = 4
+ icmpV6OptOffset = 24
+ icmpV6LengthOffset = 25
+)
+
+// solicitedNodeAddr computes the solicited-node multicast address.
+// This is used for NDP. Described in RFC 4291.
+func solicitedNodeAddr(addr tcpip.Address) tcpip.Address {
+ const solicitedNodeMulticastPrefix = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff"
+ return solicitedNodeMulticastPrefix + addr[len(addr)-3:]
}
var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
+var _ stack.LinkAddressResolver = (*protocol)(nil)
+
+// LinkAddressProtocol implements stack.LinkAddressResolver.
func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
return header.IPv6ProtocolNumber
}
+// LinkAddressRequest implements stack.LinkAddressResolver.
func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error {
- // Solicited-Node multicast address, used for NDP. Described in RFC 4291.
- snaddr := "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff" + addr[len(addr)-3:]
+ snaddr := solicitedNodeAddr(addr)
r := &stack.Route{
LocalAddress: localAddr,
RemoteAddress: snaddr,
@@ -92,35 +180,43 @@
hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize)
pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
pkt.SetType(header.ICMPv6NeighborSolicit)
- copy(pkt[8:24], addr)
- pkt[24] = ndpOptSrcLinkAddr
- pkt[25] = 1 // address length
- copy(pkt[26:], linkEP.LinkAddress())
- pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, nil))
+ copy(pkt[icmpV6OptOffset-len(addr):], addr)
+ pkt[icmpV6OptOffset] = ndpOptSrcLinkAddr
+ pkt[icmpV6LengthOffset] = 1
+ copy(pkt[icmpV6LengthOffset+1:], linkEP.LinkAddress())
+ pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
length := uint16(hdr.UsedLength())
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
PayloadLength: length,
NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: 255,
+ HopLimit: defaultIPv6HopLimit,
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
- return linkEP.WritePacket(r, &hdr, nil, ProtocolNumber)
+ return linkEP.WritePacket(r, hdr, buffer.VectorisedView{}, ProtocolNumber)
}
-func icmpChecksum(h header.ICMPv6, src, dst tcpip.Address, data []byte) uint16 {
+// ResolveStaticAddress implements stack.LinkAddressResolver.
+func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+ return "", false
+}
+
+func icmpChecksum(h header.ICMPv6, src, dst tcpip.Address, vv buffer.VectorisedView) uint16 {
// Calculate the IPv6 pseudo-header upper-layer checksum.
xsum := header.Checksum([]byte(src), 0)
xsum = header.Checksum([]byte(dst), xsum)
var upperLayerLength [4]byte
- binary.BigEndian.PutUint32(upperLayerLength[:], uint32(len(h)+len(data)))
+ binary.BigEndian.PutUint32(upperLayerLength[:], uint32(len(h)+vv.Size()))
xsum = header.Checksum(upperLayerLength[:], xsum)
xsum = header.Checksum([]byte{0, 0, 0, uint8(header.ICMPv6ProtocolNumber)}, xsum)
- xsum = header.Checksum(data, xsum)
+ for _, v := range vv.Views() {
+ xsum = header.Checksum(v, xsum)
+ }
+ // h[2:4] is the checksum itself, set it aside to avoid checksumming the checksum.
h2, h3 := h[2], h[3]
h[2], h[3] = 0, 0
xsum = ^header.Checksum(h, xsum)
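The reworked `icmpChecksum` first sums the IPv6 pseudo-header (source, destination, upper-layer length, next header) and then every view of the payload. It sets the checksum field aside while summing the ICMPv6 header. Below is a minimal standalone sketch of that computation (RFC 4443, section 2.3) using only the standard library. `onesComplementSum` and `icmpv6Checksum` are hypothetical names, not netstack APIs. The sketch assumes the message is one contiguous buffer of at least 4 bytes, so only its final byte can be an odd leftover. It does not reproduce netstack's multi-view chaining.

```go
package main

import "encoding/binary"

// onesComplementSum adds data to a running 32-bit accumulator as 16-bit
// big-endian words. A trailing odd byte is padded with a zero low byte.
func onesComplementSum(data []byte, sum uint32) uint32 {
	for i := 0; i+1 < len(data); i += 2 {
		sum += uint32(binary.BigEndian.Uint16(data[i:]))
	}
	if len(data)%2 == 1 {
		sum += uint32(data[len(data)-1]) << 8
	}
	return sum
}

// icmpv6Checksum computes the ICMPv6 checksum of msg (header plus payload,
// at least 4 bytes) over the IPv6 pseudo-header. The checksum field at
// msg[2:4] is treated as zero, mirroring how icmpChecksum sets h[2:4] aside.
func icmpv6Checksum(src, dst [16]byte, msg []byte) uint16 {
	sum := onesComplementSum(src[:], 0)
	sum = onesComplementSum(dst[:], sum)

	// Pseudo-header tail: 32-bit upper-layer length, 3 zero bytes, next header.
	var tail [8]byte
	binary.BigEndian.PutUint32(tail[0:4], uint32(len(msg)))
	tail[7] = 58 // ICMPv6 next-header value
	sum = onesComplementSum(tail[:], sum)

	// Work on a copy so the caller's checksum bytes are left untouched.
	m := append([]byte(nil), msg...)
	m[2], m[3] = 0, 0
	sum = onesComplementSum(m, sum)

	// Fold carries back into the low 16 bits, then complement.
	for sum>>16 != 0 {
		sum = (sum & 0xffff) + (sum >> 16)
	}
	return ^uint16(sum)
}
```

A receiver can verify a packet by summing the pseudo-header and the message with the stored checksum included. A valid packet folds to `0xffff`.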
diff --git a/tcpip/network/ipv6/icmp_test.go b/tcpip/network/ipv6/icmp_test.go
index b59a20f..3a5d030 100644
--- a/tcpip/network/ipv6/icmp_test.go
+++ b/tcpip/network/ipv6/icmp_test.go
@@ -1,11 +1,22 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package ipv6
import (
"context"
+ "runtime"
"strings"
"testing"
"time"
@@ -16,6 +27,8 @@
"github.com/google/netstack/tcpip/link/channel"
"github.com/google/netstack/tcpip/link/sniffer"
"github.com/google/netstack/tcpip/stack"
+ "github.com/google/netstack/tcpip/transport/ping"
+ "github.com/google/netstack/waiter"
)
const (
@@ -23,9 +36,33 @@
linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f")
)
+// linkLocalAddr computes the default IPv6 link-local address from
+// a link-layer (MAC) address.
+func linkLocalAddr(linkAddr tcpip.LinkAddress) tcpip.Address {
+ // Convert a 48-bit MAC to an EUI-64 and then prepend the
+ // link-local header, FE80::.
+ //
+ // The conversion is very nearly:
+ // aa:bb:cc:dd:ee:ff => FE80::Aabb:ccFF:FEdd:eeff
+ // Note the capital A. The conversion aa->Aa involves a bit flip.
+ lladdrb := [16]byte{
+ 0: 0xFE,
+ 1: 0x80,
+ 8: linkAddr[0] ^ 2,
+ 9: linkAddr[1],
+ 10: linkAddr[2],
+ 11: 0xFF,
+ 12: 0xFE,
+ 13: linkAddr[3],
+ 14: linkAddr[4],
+ 15: linkAddr[5],
+ }
+ return tcpip.Address(lladdrb[:])
+}
+
var (
- lladdr0 = LinkLocalAddr(linkAddr0)
- lladdr1 = LinkLocalAddr(linkAddr1)
+ lladdr0 = linkLocalAddr(linkAddr0)
+ lladdr1 = linkLocalAddr(linkAddr1)
)
type testContext struct {
@@ -39,17 +76,27 @@
icmpCh chan header.ICMPv6Type
}
+type endpointWithResolutionCapability struct {
+ stack.LinkEndpoint
+}
+
+func (e endpointWithResolutionCapability) Capabilities() stack.LinkEndpointCapabilities {
+ return e.LinkEndpoint.Capabilities() | stack.CapabilityResolutionRequired
+}
+
func newTestContext(t *testing.T) *testContext {
c := &testContext{
t: t,
- s0: stack.New([]string{ProtocolName}, nil),
- s1: stack.New([]string{ProtocolName}, nil),
+ s0: stack.New([]string{ProtocolName}, []string{ping.ProtocolName6}, stack.Options{}),
+ s1: stack.New([]string{ProtocolName}, []string{ping.ProtocolName6}, stack.Options{}),
icmpCh: make(chan header.ICMPv6Type, 10),
}
const defaultMTU = 65536
- id0, linkEP0 := channel.New(256, defaultMTU, linkAddr0)
+ _, linkEP0 := channel.New(256, defaultMTU, linkAddr0)
c.linkEP0 = linkEP0
+ wrappedEP0 := endpointWithResolutionCapability{LinkEndpoint: linkEP0}
+ id0 := stack.RegisterLinkEndpoint(wrappedEP0)
if testing.Verbose() {
id0 = sniffer.New(id0)
}
@@ -59,29 +106,38 @@
if err := c.s0.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
t.Fatalf("AddAddress lladdr0: %v", err)
}
- if err := c.s0.AddAddress(1, ProtocolNumber, SolicitedNodeAddr(lladdr0)); err != nil {
+ if err := c.s0.AddAddress(1, ProtocolNumber, solicitedNodeAddr(lladdr0)); err != nil {
t.Fatalf("AddAddress sn lladdr0: %v", err)
}
- id1, linkEP1 := channel.New(256, defaultMTU, linkAddr1)
+ _, linkEP1 := channel.New(256, defaultMTU, linkAddr1)
c.linkEP1 = linkEP1
+ wrappedEP1 := endpointWithResolutionCapability{LinkEndpoint: linkEP1}
+ id1 := stack.RegisterLinkEndpoint(wrappedEP1)
if err := c.s1.CreateNIC(1, id1); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
if err := c.s1.AddAddress(1, ProtocolNumber, lladdr1); err != nil {
t.Fatalf("AddAddress lladdr1: %v", err)
}
- if err := c.s1.AddAddress(1, ProtocolNumber, SolicitedNodeAddr(lladdr1)); err != nil {
+ if err := c.s1.AddAddress(1, ProtocolNumber, solicitedNodeAddr(lladdr1)); err != nil {
t.Fatalf("AddAddress sn lladdr1: %v", err)
}
- routeTable := []tcpip.Route{{
- Destination: tcpip.Address(strings.Repeat("\x00", 16)),
- Mask: tcpip.Address(strings.Repeat("\x00", 16)),
- NIC: 1,
- }}
- c.s0.SetRouteTable(routeTable)
- c.s1.SetRouteTable(routeTable)
+ c.s0.SetRouteTable(
+ []tcpip.Route{{
+ Destination: lladdr1,
+ Mask: tcpip.Address(strings.Repeat("\xff", 16)),
+ NIC: 1,
+ }},
+ )
+ c.s1.SetRouteTable(
+ []tcpip.Route{{
+ Destination: lladdr0,
+ Mask: tcpip.Address(strings.Repeat("\xff", 16)),
+ NIC: 1,
+ }},
+ )
go c.routePackets(linkEP0.C, linkEP1)
go c.routePackets(linkEP1.C, linkEP0)
@@ -90,7 +146,7 @@
}
func (c *testContext) countPacket(pkt channel.PacketInfo) {
- if pkt.Proto != header.IPv6ProtocolNumber {
+ if pkt.Proto != ProtocolNumber {
return
}
ipv6 := header.IPv6(pkt.Header)
@@ -109,7 +165,7 @@
views := []buffer.View{pkt.Header, pkt.Payload}
size := len(pkt.Header) + len(pkt.Payload)
vv := buffer.NewVectorisedView(size, views)
- ep.InjectLinkAddr(pkt.Proto, ep.LinkAddress(), &vv)
+ ep.InjectLinkAddr(pkt.Proto, ep.LinkAddress(), vv)
}
}
@@ -121,7 +177,7 @@
func TestLinkResolution(t *testing.T) {
c := newTestContext(t)
defer c.cleanup()
- r, err := c.s0.FindRoute(1, lladdr0, lladdr1, header.IPv6ProtocolNumber)
+ r, err := c.s0.FindRoute(1, lladdr0, lladdr1, ProtocolNumber)
if err != nil {
t.Fatal(err)
}
@@ -130,8 +186,14 @@
hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6EchoMinimumSize)
pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize))
pkt.SetType(header.ICMPv6EchoRequest)
- pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, nil))
- if err := r.WritePacket(&hdr, nil, header.ICMPv6ProtocolNumber, header.IPv6DefaultHopLimit); err != nil {
+ pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+ payload := tcpip.SlicePayload(hdr.View())
+
+ // We can't send our payload directly over the route because that
+ // doesn't provoke NDP discovery.
+ var wq waiter.Queue
+ ep, err := c.s0.NewEndpoint(header.ICMPv6ProtocolNumber, ProtocolNumber, &wq)
+ if err != nil {
t.Fatal(err)
}
@@ -140,6 +202,20 @@
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
+ for {
+ if ctx.Err() != nil {
+ break
+ }
+ if _, err := ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: 1, Addr: lladdr1}}); err == tcpip.ErrNoLinkAddress {
+ // There's something asynchronous going on; yield to let it do its thing.
+ runtime.Gosched()
+ } else if err == nil {
+ break
+ } else {
+ t.Fatal(err)
+ }
+ }
+
stats := make(map[header.ICMPv6Type]int)
for {
select {
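The test helper `linkLocalAddr` builds a modified EUI-64 interface identifier. It inserts `ff:fe` between the third and fourth octets of the MAC and flips the universal/local bit (`0x02`) of the first octet, which is what `linkAddr[0] ^ 2` does. The result is placed behind the `fe80::/64` prefix. Here is a standalone sketch of the same layout with the standard `net` package. `linkLocalFromMAC` is a hypothetical name, and the sketch handles only 48-bit MACs.

```go
package main

import "net"

// linkLocalFromMAC derives the modified EUI-64 link-local address (fe80::/64)
// for a 48-bit MAC, using the same byte layout as linkLocalAddr in the test.
func linkLocalFromMAC(mac net.HardwareAddr) net.IP {
	if len(mac) != 6 {
		return nil // only 48-bit MACs are handled in this sketch
	}
	ip := make(net.IP, net.IPv6len)
	ip[0], ip[1] = 0xfe, 0x80 // link-local prefix; bytes 2..7 stay zero
	ip[8] = mac[0] ^ 0x02      // flip the universal/local bit
	ip[9], ip[10] = mac[1], mac[2]
	ip[11], ip[12] = 0xff, 0xfe // EUI-64 filler between OUI and NIC bytes
	ip[13], ip[14], ip[15] = mac[3], mac[4], mac[5]
	return ip
}
```

With the test's `linkAddr1` (`0a:0b:0c:0d:0e:0f`), this sketch produces `fe80::80b:cff:fe0d:e0f`.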
diff --git a/tcpip/network/ipv6/ipv6.go b/tcpip/network/ipv6/ipv6.go
index 314f5f0..2458efa 100644
--- a/tcpip/network/ipv6/ipv6.go
+++ b/tcpip/network/ipv6/ipv6.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
// Package ipv6 contains the implementation of the ipv6 network protocol. To use
// it in the networking stack, this package must be added to the project, and
@@ -11,8 +21,6 @@
package ipv6
import (
- "sync/atomic"
-
"github.com/google/netstack/tcpip"
"github.com/google/netstack/tcpip/buffer"
"github.com/google/netstack/tcpip/header"
@@ -29,31 +37,29 @@
// maxTotalSize is maximum size that can be encoded in the 16-bit
// PayloadLength field of the ipv6 header.
maxPayloadSize = 0xffff
-)
-type address [header.IPv6AddressSize]byte
+ // defaultIPv6HopLimit is the default hop limit for IPv6 packets
+ // sent by Netstack.
+ defaultIPv6HopLimit = 255
+)
type endpoint struct {
nicid tcpip.NICID
id stack.NetworkEndpointID
- address address
linkEP stack.LinkEndpoint
linkAddrCache stack.LinkAddressCache
dispatcher stack.TransportDispatcher
}
+// DefaultTTL is the default hop limit for this endpoint.
func (e *endpoint) DefaultTTL() uint8 {
- return header.IPv6DefaultHopLimit
+ return 255
}
// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
// the network layer max header length.
func (e *endpoint) MTU() uint32 {
- mtu := e.linkEP.MTU() - uint32(e.MaxHeaderLength())
- if mtu <= maxPayloadSize {
- return mtu
- }
- return maxPayloadSize
+ return calculateMTU(e.linkEP.MTU())
}
// NICID returns the ID of the NIC this endpoint belongs to.
@@ -66,6 +72,11 @@
return &e.id
}
+// Capabilities implements stack.NetworkEndpoint.Capabilities.
+func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.linkEP.Capabilities()
+}
+
// MaxHeaderLength returns the maximum length needed by ipv6 headers (and
// underlying protocols).
func (e *endpoint) MaxHeaderLength() uint16 {
@@ -73,27 +84,24 @@
}
// WritePacket writes a packet to the given destination address and protocol.
-func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error {
- length := uint16(hdr.UsedLength())
- if payload != nil {
- length += uint16(len(payload))
- }
+func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error {
+ length := uint16(hdr.UsedLength() + payload.Size())
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
PayloadLength: length,
NextHeader: uint8(protocol),
HopLimit: ttl,
- SrcAddr: r.LocalAddress,
+ SrcAddr: e.id.LocalAddress,
DstAddr: r.RemoteAddress,
})
- atomic.AddUint64(&r.MutableStats().IP.PacketsSent, 1)
+ r.Stats().IP.PacketsSent.Increment()
return e.linkEP.WritePacket(r, hdr, payload, ProtocolNumber)
}
// HandlePacket is called by the link layer when new ipv6 packets arrive for
// this endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, vv *buffer.VectorisedView) {
+func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
h := header.IPv6(vv.First())
if !h.IsValid(vv.Size()) {
return
@@ -101,12 +109,14 @@
vv.TrimFront(header.IPv6MinimumSize)
vv.CapLength(int(h.PayloadLength()))
- p := tcpip.TransportProtocolNumber(h.NextHeader())
+
+ p := h.TransportProtocol()
if p == header.ICMPv6ProtocolNumber {
e.handleICMP(r, vv)
return
}
- atomic.AddUint64(&r.MutableStats().IP.PacketsDelivered, 1)
+
+ r.Stats().IP.PacketsDelivered.Increment()
e.dispatcher.DeliverTransportPacket(r, p, vv)
}
@@ -141,15 +151,13 @@
// NewEndpoint creates a new ipv6 endpoint.
func (p *protocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
- e := &endpoint{
+ return &endpoint{
nicid: nicid,
+ id: stack.NetworkEndpointID{LocalAddress: addr},
linkEP: linkEP,
linkAddrCache: linkAddrCache,
dispatcher: dispatcher,
- }
- copy(e.address[:], addr)
- e.id = stack.NetworkEndpointID{tcpip.Address(e.address[:])}
- return e, nil
+ }, nil
}
// SetOption implements NetworkProtocol.SetOption.
@@ -157,38 +165,23 @@
return tcpip.ErrUnknownProtocolOption
}
+// Option implements NetworkProtocol.Option.
+func (p *protocol) Option(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+// calculateMTU calculates the network-layer payload MTU based on the link-layer
+// payload MTU.
+func calculateMTU(mtu uint32) uint32 {
+ mtu -= header.IPv6MinimumSize
+ if mtu <= maxPayloadSize {
+ return mtu
+ }
+ return maxPayloadSize
+}
+
func init() {
stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol {
return &protocol{}
})
}
-
-// LinkLocalAddr computes the default IPv6 link-local address from
-// a link-layer (MAC) address.
-func LinkLocalAddr(linkAddr tcpip.LinkAddress) tcpip.Address {
- // Convert a 48-bit MAC to an EUI-64 and then prepend the
- // link-local header, FE80::.
- //
- // The conversion is very nearly:
- // aa:bb:cc:dd:ee:ff => FE80::Aabb:ccFF:FEdd:eeff
- // Note the capital A. The conversion aa->Aa involves a bit flip.
- lladdrb := [16]byte{
- 0: 0xFE,
- 1: 0x80,
- 8: linkAddr[0] ^ 2,
- 9: linkAddr[1],
- 10: linkAddr[2],
- 11: 0xFF,
- 12: 0xFE,
- 13: linkAddr[3],
- 14: linkAddr[4],
- 15: linkAddr[5],
- }
- return tcpip.Address(lladdrb[:])
-}
-
-// SolicitedNodeAddr computes the solicited-node multicast address.
-// This is used for NDP. Described in RFC 4291.
-func SolicitedNodeAddr(addr tcpip.Address) tcpip.Address {
- return "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff" + addr[len(addr)-3:]
-}
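This diff removes the exported `SolicitedNodeAddr` from ipv6.go, and icmp.go now provides it as the unexported `solicitedNodeAddr`. Per RFC 4291, section 2.7.1, the solicited-node multicast address is the prefix `ff02::1:ff00:0/104` followed by the low-order 24 bits of the unicast address. The following sketch shows the same construction with `net.IP`. `solicitedNode` is a hypothetical name, and it returns nil for anything that is not a 16-byte IPv6 address.

```go
package main

import "net"

// solicitedNode returns the solicited-node multicast address of an IPv6
// unicast address: ff02::1:ff00:0/104 plus the address's low 24 bits.
// IPv4 (including IPv4-mapped) and malformed inputs yield nil.
func solicitedNode(addr net.IP) net.IP {
	if len(addr) != net.IPv6len || addr.To4() != nil {
		return nil
	}
	sn := net.IP{0xff, 0x02, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01, 0xff, 0, 0, 0}
	copy(sn[13:], addr[13:]) // keep only the last three bytes
	return sn
}
```

For the link-local address `fe80::80b:cff:fe0d:e0f`, the sketch gives `ff02::1:ff0d:e0f`. This is the kind of address `newTestContext` registers with `AddAddress` so that neighbor solicitations reach the NIC.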
diff --git a/tcpip/ports/ports.go b/tcpip/ports/ports.go
index 1f20f92..9c5bdab 100644
--- a/tcpip/ports/ports.go
+++ b/tcpip/ports/ports.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
// Package ports provides PortManager that manages allocating, reserving and releasing ports.
package ports
@@ -14,10 +24,10 @@
)
const (
- // firstEphemeral is the first ephemeral port.
- firstEphemeral uint16 = 16000
+ // FirstEphemeral is the first ephemeral port.
+ FirstEphemeral = 16000
- anyIPAddress = tcpip.Address("")
+ anyIPAddress tcpip.Address = ""
)
type portDescriptor struct {
@@ -63,11 +73,11 @@
// is suitable for its needs, and stopping when a port is found or an error
// occurs.
func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) {
- count := uint16(math.MaxUint16 - firstEphemeral + 1)
+ count := uint16(math.MaxUint16 - FirstEphemeral + 1)
offset := uint16(rand.Int31n(int32(count)))
for i := uint16(0); i < count; i++ {
- port = firstEphemeral + (offset+i)%count
+ port = FirstEphemeral + (offset+i)%count
ok, err := testPort(port)
if err != nil {
return 0, err
@@ -81,18 +91,37 @@
return 0, tcpip.ErrNoPortAvailable
}
+// IsPortAvailable tests if the given port is available on all given protocols.
+func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16) bool {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ return s.isPortAvailableLocked(networks, transport, addr, port)
+}
+
+func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16) bool {
+ for _, network := range networks {
+ desc := portDescriptor{network, transport, port}
+ if addrs, ok := s.allocatedPorts[desc]; ok {
+ if !addrs.isAvailable(addr) {
+ return false
+ }
+ }
+ }
+ return true
+}
+
// ReservePort marks a port/IP combination as reserved so that it cannot be
// reserved by another endpoint. If port is zero, ReservePort will search for
// an unreserved ephemeral port and reserve it, returning its value in the
// "port" return value.
-func (s *PortManager) ReservePort(network []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16) (reservedPort uint16, err *tcpip.Error) {
+func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16) (reservedPort uint16, err *tcpip.Error) {
s.mu.Lock()
defer s.mu.Unlock()
// If a port is specified, just try to reserve it for all network
// protocols.
if port != 0 {
- if !s.reserveSpecificPort(network, transport, addr, port) {
+ if !s.reserveSpecificPort(networks, transport, addr, port) {
return 0, tcpip.ErrPortInUse
}
return port, nil
@@ -100,26 +129,19 @@
// A port wasn't specified, so try to find one.
return s.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
- return s.reserveSpecificPort(network, transport, addr, p), nil
+ return s.reserveSpecificPort(networks, transport, addr, p), nil
})
}
// reserveSpecificPort tries to reserve the given port on all given protocols.
-func (s *PortManager) reserveSpecificPort(network []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16) bool {
- // Check that the port is available on all network protocols.
- desc := portDescriptor{0, transport, port}
- for _, n := range network {
- desc.network = n
- if addrs, ok := s.allocatedPorts[desc]; ok {
- if !addrs.isAvailable(addr) {
- return false
- }
- }
+func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16) bool {
+ if !s.isPortAvailableLocked(networks, transport, addr, port) {
+ return false
}
// Reserve port on all network protocols.
- for _, n := range network {
- desc.network = n
+ for _, network := range networks {
+ desc := portDescriptor{network, transport, port}
m, ok := s.allocatedPorts[desc]
if !ok {
m = make(bindAddresses)
@@ -133,31 +155,17 @@
// ReleasePort releases the reservation on a port/IP combination so that it can
// be reserved by other endpoints.
-func (s *PortManager) ReleasePort(network []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16) {
+func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16) {
s.mu.Lock()
defer s.mu.Unlock()
- for _, n := range network {
- desc := portDescriptor{n, transport, port}
- m := s.allocatedPorts[desc]
- delete(m, addr)
- if len(m) == 0 {
- delete(s.allocatedPorts, desc)
+ for _, network := range networks {
+ desc := portDescriptor{network, transport, port}
+ if m, ok := s.allocatedPorts[desc]; ok {
+ delete(m, addr)
+ if len(m) == 0 {
+ delete(s.allocatedPorts, desc)
+ }
}
}
}
-
-// IsPortReserved tests if the given port is reserved on any of given protocols.
-func (s *PortManager) IsPortReserved(network []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16) bool {
- s.mu.Lock()
- defer s.mu.Unlock()
-
- for _, n := range network {
- desc := portDescriptor{n, transport, port}
- if _, ok := s.allocatedPorts[desc]; ok {
- return true
- }
- }
-
- return false
-}
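The reworked `reserveSpecificPort` calls `isPortAvailableLocked` for every network protocol before it touches `allocatedPorts`. As a result, a conflict on one protocol leaves no partial reservation behind. `ReleasePort` now skips descriptors that were never allocated. The following sketch reproduces that check-then-reserve pattern with a hypothetical `portTable` type and plain integer protocol numbers. The conflict rule is an assumption about `bindAddresses.isAvailable`, whose body is not shown in this diff: an empty (wildcard) address conflicts with any existing binding, and any address conflicts with an existing wildcard.

```go
package main

import "sync"

// portKey plays the role of portDescriptor: one entry per
// (network protocol, transport protocol, port) triple.
type portKey struct {
	network   uint32
	transport uint32
	port      uint16
}

// portTable is a simplified, hypothetical stand-in for PortManager.
// The empty string is the wildcard ("any address").
type portTable struct {
	mu    sync.Mutex
	inUse map[portKey]map[string]struct{}
}

// availableLocked reports whether addr:port is free on every network.
// Callers must hold t.mu.
func (t *portTable) availableLocked(networks []uint32, transport uint32, addr string, port uint16) bool {
	for _, n := range networks {
		for bound := range t.inUse[portKey{n, transport, port}] {
			if bound == "" || addr == "" || bound == addr {
				return false
			}
		}
	}
	return true
}

// reserve checks all networks first, so a failure changes nothing.
func (t *portTable) reserve(networks []uint32, transport uint32, addr string, port uint16) bool {
	t.mu.Lock()
	defer t.mu.Unlock()
	if !t.availableLocked(networks, transport, addr, port) {
		return false
	}
	if t.inUse == nil {
		t.inUse = make(map[portKey]map[string]struct{})
	}
	for _, n := range networks {
		k := portKey{n, transport, port}
		if t.inUse[k] == nil {
			t.inUse[k] = make(map[string]struct{})
		}
		t.inUse[k][addr] = struct{}{}
	}
	return true
}

// release drops the reservation and deletes entries that become empty,
// ignoring keys that were never reserved.
func (t *portTable) release(networks []uint32, transport uint32, addr string, port uint16) {
	t.mu.Lock()
	defer t.mu.Unlock()
	for _, n := range networks {
		k := portKey{n, transport, port}
		if m, ok := t.inUse[k]; ok {
			delete(m, addr)
			if len(m) == 0 {
				delete(t.inUse, k)
			}
		}
	}
}
```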
diff --git a/tcpip/ports/ports_test.go b/tcpip/ports/ports_test.go
index 8ce7a3b..6e75c2e 100644
--- a/tcpip/ports/ports_test.go
+++ b/tcpip/ports/ports_test.go
@@ -1,6 +1,17 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
package ports
import (
@@ -67,8 +78,8 @@
if err != test.want {
t.Fatalf("ReservePort(.., .., %s, %d) = %v, want %v", test.ip, test.port, err, test.want)
}
- if test.port == 0 && (gotPort == 0 || gotPort < firstEphemeral) {
- t.Fatalf("ReservePort(.., .., .., 0) = %d, want port number >= %d to be picked", gotPort, firstEphemeral)
+ if test.port == 0 && (gotPort == 0 || gotPort < FirstEphemeral) {
+ t.Fatalf("ReservePort(.., .., .., 0) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral)
}
}
@@ -107,17 +118,17 @@
{
name: "only-port-16042-available",
f: func(port uint16) (bool, *tcpip.Error) {
- if port == firstEphemeral+42 {
+ if port == FirstEphemeral+42 {
return true, nil
}
return false, nil
},
- wantPort: firstEphemeral + 42,
+ wantPort: FirstEphemeral + 42,
},
{
name: "only-port-under-16000-available",
f: func(port uint16) (bool, *tcpip.Error) {
- if port < firstEphemeral {
+ if port < FirstEphemeral {
return true, nil
}
return false, nil
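`PickEphemeralPort` starts at a random offset inside `[FirstEphemeral, 65535]` and probes each candidate in turn, wrapping around. In the code above, `offset+i` is computed in `uint16`. Both terms can approach 49,535, so the sum can exceed 65,535 and wrap, which may cause some ports to be probed twice and others never. The minimal sketch below uses hypothetical names and only the standard library. It does that addition in `uint32` so that each port in the range is tried exactly once. The random source is injected so the start point can be pinned in tests; in real code the argument would be `rand.Int31n` from `math/rand`.

```go
package main

import (
	"errors"
	"math"
)

// firstEphemeral mirrors ports.FirstEphemeral in the diff.
const firstEphemeral = 16000

var errNoPortAvailable = errors.New("no port available")

// pickEphemeralPort probes every port in [firstEphemeral, math.MaxUint16]
// exactly once, starting at offset randInt31n(count) and wrapping around.
func pickEphemeralPort(randInt31n func(int32) int32, testPort func(uint16) (bool, error)) (uint16, error) {
	count := uint32(math.MaxUint16 - firstEphemeral + 1)
	offset := uint32(randInt31n(int32(count)))
	for i := uint32(0); i < count; i++ {
		// uint32 arithmetic: offset+i may exceed 65535 without wrapping.
		port := uint16(firstEphemeral + (offset+i)%count)
		ok, err := testPort(port)
		if err != nil {
			return 0, err
		}
		if ok {
			return port, nil
		}
	}
	return 0, errNoPortAvailable
}
```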
diff --git a/tcpip/sample/tun_tcp_connect/main.go b/tcpip/sample/tun_tcp_connect/main.go
index 1391d90..3f2c78e 100644
--- a/tcpip/sample/tun_tcp_connect/main.go
+++ b/tcpip/sample/tun_tcp_connect/main.go
@@ -1,6 +1,18 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
// This sample creates a stack with TCP and IPv4 protocols on top of a TUN
// device, and connects to a peer. Similar to "nc <address> <port>". While the
@@ -68,7 +80,7 @@
v.CapLength(n)
for len(v) > 0 {
- n, err := ep.Write(v, nil)
+ n, err := ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{})
if err != nil {
fmt.Println("Write failed:", err)
return
@@ -113,7 +125,7 @@
// Create the stack with ipv4 and tcp protocols, then add a tun-based
// NIC and ipv4 address.
- s := stack.New([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName})
+ s := stack.New([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{})
mtu, err := rawfile.GetMTU(tunName)
if err != nil {
@@ -125,7 +137,7 @@
log.Fatal(err)
}
- linkID := fdbased.New(fd, mtu, nil)
+ linkID := fdbased.New(&fdbased.Options{FD: fd, MTU: mtu})
if err := s.CreateNIC(1, sniffer.New(linkID)); err != nil {
log.Fatal(err)
}
@@ -183,7 +195,7 @@
// connection from its side.
wq.EventRegister(&waitEntry, waiter.EventIn)
for {
- v, err := ep.Read(nil)
+ v, _, err := ep.Read(nil)
if err != nil {
if err == tcpip.ErrClosedForReceive {
break
diff --git a/tcpip/sample/tun_tcp_echo/main.go b/tcpip/sample/tun_tcp_echo/main.go
index 92989df..3952809 100644
--- a/tcpip/sample/tun_tcp_echo/main.go
+++ b/tcpip/sample/tun_tcp_echo/main.go
@@ -1,6 +1,18 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
// This sample creates a stack with TCP and IPv4 protocols on top of a TUN
// device, and listens on a port. Data received by the server in the accepted
@@ -8,6 +20,7 @@
package main
import (
+ "flag"
"log"
"math/rand"
"net"
@@ -20,6 +33,7 @@
"github.com/google/netstack/tcpip/link/fdbased"
"github.com/google/netstack/tcpip/link/rawfile"
"github.com/google/netstack/tcpip/link/tun"
+ "github.com/google/netstack/tcpip/network/arp"
"github.com/google/netstack/tcpip/network/ipv4"
"github.com/google/netstack/tcpip/network/ipv6"
"github.com/google/netstack/tcpip/stack"
@@ -27,6 +41,9 @@
"github.com/google/netstack/waiter"
)
+var tap = flag.Bool("tap", false, "use tap instead of tun")
+var mac = flag.String("mac", "aa:00:01:01:01:01", "mac address to use in tap device")
+
func echo(wq *waiter.Queue, ep tcpip.Endpoint) {
defer ep.Close()
@@ -37,7 +54,7 @@
defer wq.EventUnregister(&waitEntry)
for {
- v, err := ep.Read(nil)
+ v, _, err := ep.Read(nil)
if err != nil {
if err == tcpip.ErrWouldBlock {
<-notifyCh
@@ -47,21 +64,28 @@
return
}
- ep.Write(v, nil)
+ ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{})
}
}
func main() {
- if len(os.Args) != 4 {
+ flag.Parse()
+ if len(flag.Args()) != 3 {
log.Fatal("Usage: ", os.Args[0], " <tun-device> <local-address> <local-port>")
}
- tunName := os.Args[1]
- addrName := os.Args[2]
- portName := os.Args[3]
+ tunName := flag.Arg(0)
+ addrName := flag.Arg(1)
+ portName := flag.Arg(2)
rand.Seed(time.Now().UnixNano())
+ // Parse the mac address.
+ maddr, err := net.ParseMAC(*mac)
+ if err != nil {
+ log.Fatalf("Bad MAC address: %v", *mac)
+ }
+
// Parse the IP address. Support both ipv4 and ipv6.
parsedAddr := net.ParseIP(addrName)
if parsedAddr == nil {
@@ -87,19 +111,29 @@
// Create the stack with ip and tcp protocols, then add a tun-based
// NIC and address.
- s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{tcp.ProtocolName})
+ s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName, arp.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{})
mtu, err := rawfile.GetMTU(tunName)
if err != nil {
log.Fatal(err)
}
- fd, err := tun.Open(tunName)
+ var fd int
+ if *tap {
+ fd, err = tun.OpenTAP(tunName)
+ } else {
+ fd, err = tun.Open(tunName)
+ }
if err != nil {
log.Fatal(err)
}
- linkID := fdbased.New(fd, mtu, nil)
+ linkID := fdbased.New(&fdbased.Options{
+ FD: fd,
+ MTU: mtu,
+ EthernetHeader: *tap,
+ Address: tcpip.LinkAddress(maddr),
+ })
if err := s.CreateNIC(1, linkID); err != nil {
log.Fatal(err)
}
@@ -108,6 +142,10 @@
log.Fatal(err)
}
+ if err := s.AddAddress(1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
+ log.Fatal(err)
+ }
+
// Add default route.
s.SetRouteTable([]tcpip.Route{
{
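The echo sample now accepts a `-tap` flag and a `-mac` address in addition to its three positional arguments.

The following stdlib-only sketch shows that argument handling with the netstack calls left out. `parseArgs` and `config` are hypothetical names. The sketch uses its own `flag.FlagSet` so it can be exercised without touching `os.Args`.

```go
package main

import (
	"flag"
	"fmt"
	"net"
)

// config holds the parsed command line; field names are illustrative.
type config struct {
	tap      bool
	mac      net.HardwareAddr
	tunName  string
	addrName string
	portName string
}

// parseArgs mirrors the sample's flag handling: two optional flags followed
// by exactly three positional arguments.
func parseArgs(args []string) (*config, error) {
	fs := flag.NewFlagSet("tun_tcp_echo", flag.ContinueOnError)
	tap := fs.Bool("tap", false, "use tap instead of tun")
	mac := fs.String("mac", "aa:00:01:01:01:01", "mac address to use in tap device")
	if err := fs.Parse(args); err != nil {
		return nil, err
	}
	if fs.NArg() != 3 {
		return nil, fmt.Errorf("usage: [-tap] [-mac addr] <tun-device> <local-address> <local-port>")
	}
	maddr, err := net.ParseMAC(*mac)
	if err != nil {
		return nil, fmt.Errorf("bad MAC address: %v", *mac)
	}
	return &config{
		tap:      *tap,
		mac:      maddr,
		tunName:  fs.Arg(0),
		addrName: fs.Arg(1),
		portName: fs.Arg(2),
	}, nil
}

func main() {
	cfg, err := parseArgs([]string{"-tap", "tap0", "192.168.1.2", "8080"})
	if err != nil {
		fmt.Println(err)
		return
	}
	fmt.Println(cfg.tap, cfg.mac, cfg.tunName, cfg.addrName, cfg.portName)
	// true aa:00:01:01:01:01 tap0 192.168.1.2 8080
}
```

The sample then passes `tcpip.LinkAddress(maddr)` as the link address. Because `tcpip.LinkAddress` is a string type, that conversion copies the six raw MAC bytes into a string.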
diff --git a/tcpip/seqnum/seqnum.go b/tcpip/seqnum/seqnum.go
index f689be9..e507d02 100644
--- a/tcpip/seqnum/seqnum.go
+++ b/tcpip/seqnum/seqnum.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
// Package seqnum defines the types and methods for TCP sequence numbers such
// that they fit in 32-bit words and work properly when overflows occur.
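Package seqnum promises comparisons that "work properly when overflows occur". That guarantee typically rests on serial-number arithmetic: two 32-bit values are compared by the sign of their wrapped difference.

The following is a minimal sketch, assuming a `Value uint32` representation. The method names are illustrative and may differ from the package's actual API.

```go
package main

import "fmt"

// Value is a 32-bit TCP sequence number (assumed representation).
type Value uint32

// LessThan reports whether v comes before w in sequence space. The subtraction
// wraps modulo 2^32, and reinterpreting it as int32 yields the signed distance,
// which is correct as long as the two values are less than 2^31 apart.
func (v Value) LessThan(w Value) bool {
	return int32(v-w) < 0
}

// InRange reports whether v lies in the half-open interval [a, b),
// also accounting for wraparound.
func (v Value) InRange(a, b Value) bool {
	return v-a < b-a
}

func main() {
	near := Value(0xFFFFFFF0) // just before the wrap
	past := Value(0x00000010) // just after the wrap

	fmt.Println(near.LessThan(past))          // true: past is 32 bytes later
	fmt.Println(past.LessThan(near))          // false
	fmt.Println(Value(0).InRange(near, past)) // true
}
```

A plain `v < w` comparison would get both `LessThan` results backwards here, because `near` is numerically larger than `past` even though it precedes it on the wire.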
diff --git a/tcpip/stack/bridge.go b/tcpip/stack/bridge.go
deleted file mode 100644
index 8aeba1c..0000000
--- a/tcpip/stack/bridge.go
+++ /dev/null
@@ -1,74 +0,0 @@
-package stack
-
-import (
- "sync/atomic"
-
- "github.com/google/netstack/tcpip"
- "github.com/google/netstack/tcpip/buffer"
-)
-
-// Bridge wraps the NetworkDispatcher of each NIC's linkEP with one
-// bridging dispatcher that delivers packets sent to one to all NICs,
-// returning the bridging dispatcher.
-func (s *Stack) Bridge(nicIDs []tcpip.NICID) (*bridge, *tcpip.Error) {
- s.mu.Lock()
- defer s.mu.Unlock()
-
- nics := make([]*NIC, len(nicIDs))
- b := bridge{nics}
-
- for i, id := range nicIDs {
- nic, ok := s.nics[id]
- if !ok {
- return nil, tcpip.ErrUnknownNICID
- }
- if _, ok := nic.linkEP.(BufferWritingLinkEndpoint); !ok {
- // TODO(stijlist): port all link endpoints in netstack to implement WriteBuffer
- return nil, tcpip.ErrBadLinkEndpoint
- }
-
- nics[i] = nic
- }
-
- return &b, nil
-}
-
-// A no-frills bridge. Perlman chapter 3.
-type bridge struct {
- bridged []*NIC
-}
-
-func (b bridge) Enable() {
- for _, nic := range b.bridged {
- nic.linkEP.Attach(b)
- }
-}
-
-func (b bridge) DeliverNetworkPacket(rxEP LinkEndpoint, dstLinkAddr, srcLinkAddr tcpip.LinkAddress, protoNum tcpip.NetworkProtocolNumber, vv *buffer.VectorisedView) {
- for _, nic := range b.bridged {
- protocol, ok := nic.stack.networkProtocols[protoNum]
- if !ok {
- atomic.AddUint64(&nic.stack.stats.UnknownProtocolRcvdPackets, 1)
- continue
- }
- if len(vv.First()) < protocol.MinimumPacketSize() {
- atomic.AddUint64(&nic.stack.stats.MalformedRcvdPackets, 1)
- continue
- }
-
- // TODO: if nic.linkEP.LinkAddress() == dstLinkAddr || multicast
- if nic.linkEP == rxEP {
- // Packet was intended for this NIC, delegate to default NetworkDispatcher.
- nic.DeliverNetworkPacket(rxEP, dstLinkAddr, srcLinkAddr, protoNum, vv)
- } else {
- src, dst := protocol.ParseAddresses(vv.First())
- r := makeRoute(protoNum, src, dst, nic.primaryEndpoint(protoNum))
- // Preserve the link-layer source address (usually, these would be set by the link layer in WritePacket)
- r.LocalLinkAddress = srcLinkAddr
- r.RemoteLinkAddress = dstLinkAddr
- if ep, ok := nic.linkEP.(BufferWritingLinkEndpoint); ok {
- ep.WriteBuffer(&r, vv, protoNum)
- }
- }
- }
-}
diff --git a/tcpip/stack/bridge_test.go b/tcpip/stack/bridge_test.go
deleted file mode 100644
index 7cdc863..0000000
--- a/tcpip/stack/bridge_test.go
+++ /dev/null
@@ -1,219 +0,0 @@
-package stack
-
-import (
- "testing"
- "time"
-
- "github.com/google/netstack/tcpip"
- "github.com/google/netstack/tcpip/buffer"
- "github.com/google/netstack/tcpip/header"
- "github.com/google/netstack/tcpip/link/bufwritingchannel"
- "github.com/google/netstack/tcpip/link/channel"
- "github.com/google/netstack/tcpip/network/ipv4"
- "github.com/google/netstack/tcpip/stack"
- "github.com/google/netstack/tcpip/transport/tcp"
- "github.com/google/netstack/tcpip/transport/udp"
- "github.com/google/netstack/waiter"
-)
-
-func TestOneWayBridgeSeparateStacks(t *testing.T) {
- s1, s1eps, s1nics, err := newStack([]string{ipv4.ProtocolName}, []string{udp.ProtocolName}, 1)
- if err != nil {
- t.Fatalf("newStack error: %s", err)
- }
- s1EP := s1eps[0]
- s1.EnableNIC(s1nics[0])
- s1.AddAddress(s1nics[0], header.IPv4ProtocolNumber, tcpip.Parse("192.168.42.10"))
- s1.SetRouteTable([]tcpip.Route{
- {
- Destination: tcpip.Parse("10.0.0.1"),
- Mask: tcpip.Parse("255.255.255.255"),
- NIC: s1nics[0],
- },
- })
- var wq1 waiter.Queue
- txEP1, err := s1.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq1)
-
- s2, s2eps, s2nics, err := newStack([]string{ipv4.ProtocolName}, []string{udp.ProtocolName}, 1)
- s2EP := s2eps[0]
- s2.EnableNIC(s2nics[0])
- s2.AddAddress(s2nics[0], header.IPv4ProtocolNumber, tcpip.Parse("10.0.0.1"))
- var wq2 waiter.Queue
- txEP2, err := s2.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq2)
- err = txEP2.Bind(tcpip.FullAddress{Addr: tcpip.Parse("10.0.0.1"), Port: 8080}, nil)
- if err != nil {
- t.Fatalf("error in bind: %s", err)
- }
-
- sb, eps, nicIDs, err := newStack([]string{ipv4.ProtocolName}, []string{udp.ProtocolName}, 2)
- bEP1 := eps[0]
- bEP2 := eps[1]
-
- b, tcpipErr := sb.Bridge(nicIDs)
- if tcpipErr != nil {
- t.Fatalf("failed during bridge setup: %s", err)
- }
- b.Enable()
-
- // bEP1 and bEP2 don't need to be linked, since they're bridged.
- go link(s1EP, bEP1)
- go link(bEP2, s2EP)
-
- addr := tcpip.FullAddress{Addr: tcpip.Parse("10.0.0.1"), Port: 8080}
- _, err = txEP1.Write(buffer.View("hello"), &addr)
- if err != nil {
- t.Fatalf("failed to write: %s \n%+v", err, s1.Stats())
- }
-
- // TODO(stijlist): use waitqueue from txEP2 instead of sleeping
- <-time.After(10 * time.Millisecond)
-
- recvd, err := txEP2.Read(&addr)
- if err != nil {
- t.Fatalf("failed to read: %s\n%+v", err, s2.Stats())
- }
-
- payload := string(recvd)
- if payload != "hello" {
- t.Errorf("want hello, got %s", payload)
- }
-}
-
-func TestTwoWayBridgeSeparateStacks(t *testing.T) {
- s1, s1eps, s1nics, err := newStack([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName}, 1)
- if err != nil {
- t.Fatalf("newStack error: %s", err)
- }
- s1EP := s1eps[0]
- s1.EnableNIC(s1nics[0])
- s1.AddAddress(s1nics[0], header.IPv4ProtocolNumber, tcpip.Parse("192.168.42.10"))
- s1.SetRouteTable([]tcpip.Route{
- {
- Destination: tcpip.Parse("10.0.0.1"),
- Mask: tcpip.Parse("255.255.255.255"),
- NIC: s1nics[0],
- },
- })
-
- var wq1 waiter.Queue
- txEP1, err := s1.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq1)
- if err != nil {
- t.Fatalf("could not create endpoint in s1: %s", err)
- }
-
- s2, s2eps, s2nics, err := newStack([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName}, 1)
- s2EP := s2eps[0]
- s2.EnableNIC(s2nics[0])
- s2.AddAddress(s2nics[0], header.IPv4ProtocolNumber, tcpip.Parse("10.0.0.1"))
- s2.SetRouteTable([]tcpip.Route{
- {
- Destination: tcpip.Parse("192.168.42.10"),
- Mask: tcpip.Parse("255.255.255.255"),
- NIC: s2nics[0],
- },
- })
-
- var wq2 waiter.Queue
- txEP2, err := s2.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq2)
- if err != nil {
- t.Fatalf("could not create endpoint in s2: %s", err)
- }
-
- err = txEP2.Bind(tcpip.FullAddress{Addr: tcpip.Parse("10.0.0.1"), Port: 8080}, nil)
- if err != nil {
- t.Fatalf("error in bind: %s", err)
- }
-
- err = txEP2.Listen(1)
- if err != nil {
- t.Fatalf("error in listen: %s", err)
- }
-
- sb, eps, nicIDs, err := newStack([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName}, 2)
- bEP1 := eps[0]
- bEP2 := eps[1]
-
- b, tcpipErr := sb.Bridge(nicIDs)
- if tcpipErr != nil {
- t.Fatalf("failed during bridge setup: %s", err)
- }
- b.Enable()
-
- go link(s1EP, bEP1)
- go link(bEP1, s1EP)
-
- go link(s2EP, bEP2)
- go link(bEP2, s2EP)
-
- addr := tcpip.FullAddress{Addr: tcpip.Parse("10.0.0.1"), Port: 8080}
- err = txEP1.Connect(addr)
- if err != tcpip.ErrConnectStarted {
- t.Fatalf("failed to connect: %s \n%+v", err, s1.Stats())
- }
-
- // TODO(stijlist): use waitqueue returned from accept instead of sleeping
- <-time.After(10 * time.Millisecond)
-
- readingEP, _, err := txEP2.Accept()
-
- _, err = txEP1.Write(buffer.View("hello"), &addr)
- if err != nil {
- t.Fatalf("failed to write: %s \n%+v", err, s1.Stats())
- }
-
- <-time.After(10 * time.Millisecond)
-
- recvd, err := readingEP.Read(&addr)
- if err != nil {
- t.Fatalf("failed to read: %s\ns1 stats: %+v\n\ns2 stats: %+v", err, s1.Stats(), s2.Stats())
- }
-
- payload := string(recvd)
- if payload != "hello" {
- t.Errorf("want hello, got %s", payload)
- }
-}
-
-// Loop forever, injecting `a`'s packets into `b`.
-func link(a, b *bufwritingchannel.Endpoint) {
- for x := range a.C {
- b.Inject(unpacketInfo(x))
- }
-}
-
-var ni uint = 0
-
-func newStack(netProtos []string, transProtos []string, numEndpoints int) (s *stack.Stack, eps []*bufwritingchannel.Endpoint, nicIDs []tcpip.NICID, err *tcpip.Error) {
- s = stack.New(netProtos, transProtos)
- for i := 0; i < numEndpoints; i++ {
- id, ep := bufwritingchannel.New(1, 100, newLinkAddress())
- nicid := tcpip.NICID(ni)
- ni++
- err = s.CreateDisabledNIC(nicid, id)
- if err != nil {
- return
- }
- eps = append(eps, ep)
- nicIDs = append(nicIDs, nicid)
- }
- return
-}
-
-var li byte = 0
-
-func newLinkAddress() tcpip.LinkAddress {
- l := tcpip.LinkAddress([]byte{li, li, li, li, li, li, li})
- li++
- return l
-}
-
-func unpacketInfo(p channel.PacketInfo) (tcpip.NetworkProtocolNumber, *buffer.VectorisedView) {
- n := p.Proto
- var vv buffer.VectorisedView
- if p.Header != nil {
- vv = buffer.NewVectorisedView(len(p.Header)+len(p.Payload), []buffer.View{p.Header, p.Payload})
- } else {
- vv = buffer.NewVectorisedView(len(p.Payload), []buffer.View{p.Payload})
- }
- return n, &vv
-}
diff --git a/tcpip/stack/linkaddrcache.go b/tcpip/stack/linkaddrcache.go
index 73caf3c..c543230 100644
--- a/tcpip/stack/linkaddrcache.go
+++ b/tcpip/stack/linkaddrcache.go
@@ -1,14 +1,25 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package stack
import (
- "context"
+ "fmt"
"sync"
"time"
+ "github.com/google/netstack/sleep"
"github.com/google/netstack/tcpip"
)
@@ -17,104 +28,296 @@
// linkAddrCache is a fixed-sized cache mapping IP addresses to link addresses.
//
// The entries are stored in a ring buffer, oldest entry replaced first.
+//
+// This struct is safe for concurrent use.
type linkAddrCache struct {
+ // ageLimit is how long a cache entry is valid for.
ageLimit time.Duration
- mu sync.RWMutex
+ // resolutionTimeout is the amount of time to wait for a link request to
+ // resolve an address.
+ resolutionTimeout time.Duration
+
+ // resolutionAttempts is the number of attempts made to resolve an address
+ // before failing.
+ resolutionAttempts int
+
+ mu sync.Mutex
cache map[tcpip.FullAddress]*linkAddrEntry
next int // array index of next available entry
entries [linkAddrCacheSize]linkAddrEntry
- waiters map[tcpip.FullAddress]map[chan tcpip.LinkAddress]struct{}
+}
+
+// entryState controls the state of a single entry in the cache.
+type entryState int
+
+const (
+ // incomplete means that there is an outstanding request to resolve the
+ // address. This is the initial state.
+ incomplete entryState = iota
+ // ready means that the address has been resolved and can be used.
+ ready
+ // failed means that address resolution timed out and the address
+ // could not be resolved.
+ failed
+ // expired means that the cache entry has expired and the address must be
+ // resolved again.
+ expired
+)
+
+// String implements Stringer.
+func (s entryState) String() string {
+ switch s {
+ case incomplete:
+ return "incomplete"
+ case ready:
+ return "ready"
+ case failed:
+ return "failed"
+ case expired:
+ return "expired"
+ default:
+ return fmt.Sprintf("invalid entryState: %d", s)
+ }
}
// A linkAddrEntry is an entry in the linkAddrCache.
+// This struct is thread-compatible.
type linkAddrEntry struct {
addr tcpip.FullAddress
linkAddr tcpip.LinkAddress
expiration time.Time
+ s entryState
+
+ // wakers is a set of waiters for the address resolution result. Any time
+ // the state transitions out of 'incomplete', these waiters are notified.
+ wakers map[*sleep.Waker]struct{}
+
+ cancel chan struct{}
}
-func (c *linkAddrCache) valid(e *linkAddrEntry) bool {
- return time.Now().Before(e.expiration)
+func (e *linkAddrEntry) state() entryState {
+ if e.s != expired && time.Now().After(e.expiration) {
+ // Force the transition to ensure waiters are notified.
+ e.changeState(expired)
+ }
+ return e.s
+}
+
+func (e *linkAddrEntry) changeState(ns entryState) {
+ if e.s == ns {
+ return
+ }
+
+ // Validate state transition.
+ switch e.s {
+ case incomplete:
+ // All transitions are valid.
+ case ready, failed:
+ if ns != expired {
+ panic(fmt.Sprintf("invalid state transition from %v to %v", e.s, ns))
+ }
+ case expired:
+ // Terminal state.
+ panic(fmt.Sprintf("invalid state transition from %v to %v", e.s, ns))
+ default:
+ panic(fmt.Sprintf("invalid state: %v", e.s))
+ }
+
+ // Notify whoever is waiting on address resolution when transitioning
+ // out of 'incomplete'.
+ if e.s == incomplete {
+ for w := range e.wakers {
+ w.Assert()
+ }
+ e.wakers = nil
+ }
+ e.s = ns
+}
+
+func (e *linkAddrEntry) addWaker(w *sleep.Waker) {
+ e.wakers[w] = struct{}{}
+}
+
+func (e *linkAddrEntry) removeWaker(w *sleep.Waker) {
+ delete(e.wakers, w)
}
// add adds a k -> v mapping to the cache.
func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) {
c.mu.Lock()
defer c.mu.Unlock()
+
entry := c.cache[k]
- if entry != nil && entry.linkAddr == v && c.valid(entry) {
- return // Keep existing entry.
+ if entry != nil {
+ s := entry.state()
+ if s != expired && entry.linkAddr == v {
+ // Disregard repeated calls.
+ return
+ }
+ // Check if entry is waiting for address resolution.
+ if s == incomplete {
+ entry.linkAddr = v
+ } else {
+ // Otherwise create a new entry to replace it.
+ entry = c.makeAndAddEntry(k, v)
+ }
+ } else {
+ entry = c.makeAndAddEntry(k, v)
}
- // Take next entry.
- entry = &c.entries[c.next]
+
+ entry.changeState(ready)
+}
+
+// makeAndAddEntry is a helper function that creates a new entry, adds it to
+// the cache map, and evicts the oldest entry as needed.
+func (c *linkAddrCache) makeAndAddEntry(k tcpip.FullAddress, v tcpip.LinkAddress) *linkAddrEntry {
+ // Take over the next entry.
+ entry := &c.entries[c.next]
if c.cache[entry.addr] == entry {
delete(c.cache, entry.addr)
}
+
+ // Mark the soon-to-be-replaced entry as expired, just in case there is
+ // someone waiting for address resolution on it.
+ entry.changeState(expired)
+ if entry.cancel != nil {
+ entry.cancel <- struct{}{}
+ }
+
*entry = linkAddrEntry{
addr: k,
linkAddr: v,
expiration: time.Now().Add(c.ageLimit),
+ wakers: make(map[*sleep.Waker]struct{}),
+ cancel: make(chan struct{}, 1),
}
+
c.cache[k] = entry
c.next++
if c.next == len(c.entries) {
c.next = 0
}
- for ch := range c.waiters[k] {
- ch <- v
- }
+ return entry
}
// get reports any known link address for k.
-func (c *linkAddrCache) get(k tcpip.FullAddress, timeout time.Duration) (linkAddr tcpip.LinkAddress) {
- c.mu.RLock()
- if entry, found := c.cache[k]; found && c.valid(entry) {
- linkAddr = entry.linkAddr
- }
- c.mu.RUnlock()
- if linkAddr != "" || timeout == 0 {
- return linkAddr
- }
- c.mu.Lock()
- if entry, found := c.cache[k]; found && c.valid(entry) { // check again
- c.mu.Unlock()
- return entry.linkAddr
- }
- ch := make(chan tcpip.LinkAddress, 1)
- m := c.waiters[k]
- if m == nil {
- m = make(map[chan tcpip.LinkAddress]struct{})
- c.waiters[k] = m
- }
- m[ch] = struct{}{}
- c.mu.Unlock()
-
- ctx, cancel := context.WithTimeout(context.Background(), timeout)
- defer func() {
- cancel()
- c.mu.Lock()
- m := c.waiters[k]
- delete(m, ch)
- if len(m) == 0 {
- delete(c.waiters, k)
+func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, waker *sleep.Waker) (tcpip.LinkAddress, *tcpip.Error) {
+ if linkRes != nil {
+ if addr, ok := linkRes.ResolveStaticAddress(k.Addr); ok {
+ return addr, nil
}
- c.mu.Unlock()
- }()
+ }
- select {
- case linkAddr := <-ch:
- return linkAddr
- case <-ctx.Done():
- return ""
+ c.mu.Lock()
+ entry := c.cache[k]
+ if entry == nil || entry.state() == expired {
+ c.mu.Unlock()
+ if linkRes == nil {
+ return "", tcpip.ErrNoLinkAddress
+ }
+ c.startAddressResolution(k, linkRes, localAddr, linkEP, waker)
+ return "", tcpip.ErrWouldBlock
+ }
+ defer c.mu.Unlock()
+
+ switch s := entry.state(); s {
+ case expired:
+ // It's possible that the entry expired between the state() call above and
+ // here; in that case it's safe to consider it ready.
+ fallthrough
+ case ready:
+ return entry.linkAddr, nil
+ case failed:
+ return "", tcpip.ErrNoLinkAddress
+ case incomplete:
+ // Address resolution is still in progress.
+ entry.addWaker(waker)
+ return "", tcpip.ErrWouldBlock
+ default:
+ panic(fmt.Sprintf("invalid cache entry state: %d", s))
}
}
-func newLinkAddrCache(ageLimit time.Duration) *linkAddrCache {
- c := &linkAddrCache{
- ageLimit: ageLimit,
- cache: make(map[tcpip.FullAddress]*linkAddrEntry, linkAddrCacheSize),
- waiters: make(map[tcpip.FullAddress]map[chan tcpip.LinkAddress]struct{}),
+// removeWaker removes a waker previously added through get().
+func (c *linkAddrCache) removeWaker(k tcpip.FullAddress, waker *sleep.Waker) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ if entry := c.cache[k]; entry != nil {
+ entry.removeWaker(waker)
}
- return c
+}
+
+func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, waker *sleep.Waker) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ // Look up again with lock held to ensure entry wasn't added by someone else.
+ if e := c.cache[k]; e != nil && e.state() != expired {
+ return
+ }
+
+ // Add 'incomplete' entry in the cache to mark that resolution is in progress.
+ e := c.makeAndAddEntry(k, "")
+ e.addWaker(waker)
+
+ go func() {
+ for i := 0; ; i++ {
+ // Send link request, then wait for the timeout limit and check
+ // whether the request succeeded.
+ linkRes.LinkAddressRequest(k.Addr, localAddr, linkEP)
+ c.mu.Lock()
+ cancel := e.cancel
+ c.mu.Unlock()
+
+ select {
+ case <-time.After(c.resolutionTimeout):
+ if stop := c.checkLinkRequest(k, i); stop {
+ return
+ }
+ case <-cancel:
+ return
+ }
+ }
+ }()
+}
+
+// checkLinkRequest checks whether the previous attempt to resolve the address has
+// succeeded and marks the entry accordingly (e.g. ready or failed). It returns true
+// if the request can stop, or false if another request should be sent.
+func (c *linkAddrCache) checkLinkRequest(k tcpip.FullAddress, attempt int) bool {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ entry, ok := c.cache[k]
+ if !ok {
+ // Entry was evicted from the cache.
+ return true
+ }
+
+ switch s := entry.state(); s {
+ case ready, failed, expired:
+ // Entry was made ready by resolver or failed. Either way we're done.
+ return true
+ case incomplete:
+ if attempt+1 >= c.resolutionAttempts {
+ // Max number of retries reached, mark entry as failed.
+ entry.changeState(failed)
+ return true
+ }
+ // No response yet, need to send another ARP request.
+ return false
+ default:
+ panic(fmt.Sprintf("invalid cache entry state: %d", s))
+ }
+}
+
+func newLinkAddrCache(ageLimit, resolutionTimeout time.Duration, resolutionAttempts int) *linkAddrCache {
+ return &linkAddrCache{
+ ageLimit: ageLimit,
+ resolutionTimeout: resolutionTimeout,
+ resolutionAttempts: resolutionAttempts,
+ cache: make(map[tcpip.FullAddress]*linkAddrEntry, linkAddrCacheSize),
+ }
}
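The new cache entries follow a small state machine:

- `incomplete` may move to `ready`, `failed`, or `expired`.
- `ready` and `failed` may only move to `expired`.
- `expired` is terminal.

Waiters are woken once, when the entry leaves `incomplete`.

The sketch below reproduces only that transition logic with the standard library, so it can be run in isolation. It is not the netstack code. A buffered `chan struct{}` serves as a hypothetical stand-in for `sleep.Waker`, and `changeState` returns an error instead of panicking so the invalid case can be shown.

```go
package main

import "fmt"

type entryState int

const (
	incomplete entryState = iota
	ready
	failed
	expired
)

func (s entryState) String() string {
	switch s {
	case incomplete:
		return "incomplete"
	case ready:
		return "ready"
	case failed:
		return "failed"
	case expired:
		return "expired"
	default:
		return fmt.Sprintf("invalid entryState: %d", int(s))
	}
}

// entry is a simplified cache entry; waiters stand in for sleep.Waker.
type entry struct {
	s       entryState
	waiters []chan struct{}
}

// validTransition encodes the rules enforced by changeState in the diff.
func validTransition(from, to entryState) bool {
	switch from {
	case incomplete:
		return true
	case ready, failed:
		return to == expired
	default: // expired is terminal
		return false
	}
}

// changeState applies a transition, notifying waiters when leaving incomplete.
func (e *entry) changeState(ns entryState) error {
	if e.s == ns {
		return nil
	}
	if !validTransition(e.s, ns) {
		return fmt.Errorf("invalid state transition from %v to %v", e.s, ns)
	}
	if e.s == incomplete {
		for _, w := range e.waiters {
			// Non-blocking send: like Waker.Assert, repeated notifications collapse.
			select {
			case w <- struct{}{}:
			default:
			}
		}
		e.waiters = nil
	}
	e.s = ns
	return nil
}

func main() {
	w := make(chan struct{}, 1)
	e := &entry{s: incomplete, waiters: []chan struct{}{w}}

	fmt.Println(e.changeState(ready)) // <nil>
	<-w                               // receives: the waiter was notified
	fmt.Println(e.s)                  // ready
	fmt.Println(e.changeState(failed))
	// invalid state transition from ready to failed
	fmt.Println(e.changeState(expired)) // <nil>
	fmt.Println(e.s)                    // expired
}
```

Notifying only on the way out of `incomplete` means a waiter is woken once, when a result (address, failure, or eviction) is available. It then calls `get` again to find out which one it got.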
diff --git a/tcpip/stack/linkaddrcache_test.go b/tcpip/stack/linkaddrcache_test.go
index 4ff71cf..bae66b6 100644
--- a/tcpip/stack/linkaddrcache_test.go
+++ b/tcpip/stack/linkaddrcache_test.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package stack
@@ -10,6 +20,7 @@
"testing"
"time"
+ "github.com/google/netstack/sleep"
"github.com/google/netstack/tcpip"
)
@@ -20,6 +31,55 @@
var testaddrs []testaddr
+type testLinkAddressResolver struct {
+ cache *linkAddrCache
+ delay time.Duration
+}
+
+func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error {
+ go func() {
+ if r.delay > 0 {
+ time.Sleep(r.delay)
+ }
+ r.fakeRequest(addr)
+ }()
+ return nil
+}
+
+func (r *testLinkAddressResolver) fakeRequest(addr tcpip.Address) {
+ for _, ta := range testaddrs {
+ if ta.addr.Addr == addr {
+ r.cache.add(ta.addr, ta.linkAddr)
+ break
+ }
+ }
+}
+
+func (*testLinkAddressResolver) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+ if addr == "broadcast" {
+ return "mac_broadcast", true
+ }
+ return "", false
+}
+
+func (*testLinkAddressResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
+ return 1
+}
+
+func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressResolver) (tcpip.LinkAddress, *tcpip.Error) {
+ w := sleep.Waker{}
+ s := sleep.Sleeper{}
+ s.AddWaker(&w, 123)
+ defer s.Done()
+
+ for {
+ if got, err := c.get(addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock {
+ return got, err
+ }
+ s.Fetch(true)
+ }
+}
+
func init() {
for i := 0; i < 4*linkAddrCacheSize; i++ {
addr := fmt.Sprintf("Addr%06d", i)
@@ -31,32 +91,40 @@
}
func TestCacheOverflow(t *testing.T) {
- c := newLinkAddrCache(1<<63 - 1)
+ c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
for i := len(testaddrs) - 1; i >= 0; i-- {
e := testaddrs[i]
c.add(e.addr, e.linkAddr)
- if got, want := c.get(e.addr, 0), e.linkAddr; got != want {
- t.Errorf("insert %d, c.get(%q)=%q, want %q", i, string(e.addr.Addr), got, want)
+ got, err := c.get(e.addr, nil, "", nil, nil)
+ if err != nil {
+ t.Errorf("insert %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err)
+ }
+ if got != e.linkAddr {
+ t.Errorf("insert %d, c.get(%q)=%q, want %q", i, string(e.addr.Addr), got, e.linkAddr)
}
}
// Expect to find at least half of the most recent entries.
for i := 0; i < linkAddrCacheSize/2; i++ {
e := testaddrs[i]
- if got, want := c.get(e.addr, 0), e.linkAddr; got != want {
- t.Errorf("check %d, c.get(%q)=%q, want %q", i, string(e.addr.Addr), got, want)
+ got, err := c.get(e.addr, nil, "", nil, nil)
+ if err != nil {
+ t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err)
+ }
+ if got != e.linkAddr {
+ t.Errorf("check %d, c.get(%q)=%q, want %q", i, string(e.addr.Addr), got, e.linkAddr)
}
}
// The earliest entries should no longer be in the cache.
for i := len(testaddrs) - 1; i >= len(testaddrs)-linkAddrCacheSize; i-- {
e := testaddrs[i]
- if got := c.get(e.addr, 0); got != "" {
- t.Errorf("check %d, c.get(%q)=%q, want no entry", i, string(e.addr.Addr), got)
+ if _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
+ t.Errorf("check %d, c.get(%q), got error: %v, want: error ErrNoLinkAddress", i, string(e.addr.Addr), err)
}
}
}
func TestCacheConcurrent(t *testing.T) {
- c := newLinkAddrCache(1<<63 - 1)
+ c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
var wg sync.WaitGroup
for r := 0; r < 16; r++ {
@@ -64,7 +132,7 @@
go func() {
for _, e := range testaddrs {
c.add(e.addr, e.linkAddr)
- c.get(e.addr, 0) // make work for gotsan
+ c.get(e.addr, nil, "", nil, nil) // make work for gotsan
}
wg.Done()
}()
@@ -75,36 +143,124 @@
// can fit in the cache, so our eviction strategy requires that
// the last entry be present and the first be missing.
e := testaddrs[len(testaddrs)-1]
- if got, want := c.get(e.addr, 0), e.linkAddr; got != want {
- t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, want)
+ got, err := c.get(e.addr, nil, "", nil, nil)
+ if err != nil {
+ t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
}
+ if got != e.linkAddr {
+ t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
+ }
+
e = testaddrs[0]
- if got := c.get(e.addr, 0); got != "" {
- t.Errorf("c.get(%q)=%q, want no entry", string(e.addr.Addr), got)
+ if _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
+ t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
}
}
func TestCacheAgeLimit(t *testing.T) {
- c := newLinkAddrCache(1 * time.Millisecond)
+ c := newLinkAddrCache(1*time.Millisecond, 1*time.Second, 3)
e := testaddrs[0]
c.add(e.addr, e.linkAddr)
time.Sleep(50 * time.Millisecond)
- if got := c.get(e.addr, 0); got != "" {
- t.Errorf("c.get(%q)=%q, want no stale entry", string(e.addr.Addr), got)
+ if _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
+ t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
}
}
func TestCacheReplace(t *testing.T) {
- c := newLinkAddrCache(1 * time.Millisecond)
+ c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
e := testaddrs[0]
l2 := e.linkAddr + "2"
c.add(e.addr, e.linkAddr)
- if got := c.get(e.addr, 0); got != e.linkAddr {
+ got, err := c.get(e.addr, nil, "", nil, nil)
+ if err != nil {
+ t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
+ }
+ if got != e.linkAddr {
t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
}
+
c.add(e.addr, l2)
- if got := c.get(e.addr, 0); got != l2 {
+ got, err = c.get(e.addr, nil, "", nil, nil)
+ if err != nil {
+ t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
+ }
+ if got != l2 {
t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, l2)
}
+}
+func TestCacheResolution(t *testing.T) {
+ c := newLinkAddrCache(1<<63-1, 250*time.Millisecond, 1)
+ linkRes := &testLinkAddressResolver{cache: c}
+ for i, ta := range testaddrs {
+ got, err := getBlocking(c, ta.addr, linkRes)
+ if err != nil {
+ t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(ta.addr.Addr), got, err)
+ }
+ if got != ta.linkAddr {
+ t.Errorf("check %d, c.get(%q)=%q, want %q", i, string(ta.addr.Addr), got, ta.linkAddr)
+ }
+ }
+
+ // Check that once resolved, the address stays in the cache and never returns ErrWouldBlock.
+ for i := 0; i < 10; i++ {
+ e := testaddrs[len(testaddrs)-1]
+ got, err := c.get(e.addr, linkRes, "", nil, nil)
+ if err != nil {
+ t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
+ }
+ if got != e.linkAddr {
+ t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
+ }
+ }
+}
+
+func TestCacheResolutionFailed(t *testing.T) {
+ c := newLinkAddrCache(1<<63-1, 10*time.Millisecond, 5)
+ linkRes := &testLinkAddressResolver{cache: c}
+
+ // First, sanity check that resolution is working...
+ e := testaddrs[0]
+ got, err := getBlocking(c, e.addr, linkRes)
+ if err != nil {
+ t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
+ }
+ if got != e.linkAddr {
+ t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
+ }
+
+ e.addr.Addr += "2"
+ if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress {
+ t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
+ }
+}
+
+func TestCacheResolutionTimeout(t *testing.T) {
+ resolverDelay := 50 * time.Millisecond
+ expiration := resolverDelay / 2
+ c := newLinkAddrCache(expiration, 1*time.Millisecond, 3)
+ linkRes := &testLinkAddressResolver{cache: c, delay: resolverDelay}
+
+ e := testaddrs[0]
+ if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress {
+ t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
+ }
+}
+
+// TestStaticResolution checks that static link addresses are resolved immediately and don't
+// send resolution requests.
+func TestStaticResolution(t *testing.T) {
+ c := newLinkAddrCache(1<<63-1, time.Millisecond, 1)
+ linkRes := &testLinkAddressResolver{cache: c, delay: time.Minute}
+
+ addr := tcpip.Address("broadcast")
+ want := tcpip.LinkAddress("mac_broadcast")
+ got, err := c.get(tcpip.FullAddress{Addr: addr}, linkRes, "", nil, nil)
+ if err != nil {
+ t.Errorf("c.get(%q)=%q, got error: %v", string(addr), string(got), err)
+ }
+ if got != want {
+ t.Errorf("c.get(%q)=%q, want %q", string(addr), string(got), string(want))
+ }
}
diff --git a/tcpip/stack/nic.go b/tcpip/stack/nic.go
index 9be5cc6..e576d1f 100644
--- a/tcpip/stack/nic.go
+++ b/tcpip/stack/nic.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package stack
@@ -20,18 +30,20 @@
type NIC struct {
stack *Stack
id tcpip.NICID
+ name string
linkEP LinkEndpoint
demux *transportDemuxer
mu sync.RWMutex
+ spoofing bool
promiscuous bool
primary map[tcpip.NetworkProtocolNumber]*ilist.List
endpoints map[NetworkEndpointID]*referencedNetworkEndpoint
subnets []tcpip.Subnet
}
-// PrimaryEndpointBehavior specifies how a new address should behave as a primary endpoint.
+// PrimaryEndpointBehavior is an enumeration of an endpoint's primacy behavior.
type PrimaryEndpointBehavior int
const (
@@ -39,19 +51,22 @@
// endpoint for new connections with no local address. This is the
// default when calling NIC.AddAddress.
CanBePrimaryEndpoint PrimaryEndpointBehavior = iota
+
// FirstPrimaryEndpoint indicates the endpoint should be the first
// primary endpoint considered. If there are multiple endpoints with
// this behavior, the most recently-added one will be first.
FirstPrimaryEndpoint
+
// NeverPrimaryEndpoint indicates the endpoint should never be a
// primary endpoint.
NeverPrimaryEndpoint
)
-func newNIC(stack *Stack, id tcpip.NICID, ep LinkEndpoint) *NIC {
+func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint) *NIC {
return &NIC{
stack: stack,
id: id,
+ name: name,
linkEP: ep,
demux: newTransportDemuxer(stack),
primary: make(map[tcpip.NetworkProtocolNumber]*ilist.List),
@@ -72,18 +87,28 @@
n.mu.Unlock()
}
-// Get the primary network endpoint, if there is one; otherwise pick an arbitrary endpoint from the NIC's endpoints.
-func (n *NIC) getMainNICAddress(protocol tcpip.NetworkProtocolNumber) (tcpip.Address, tcpip.Subnet) {
+func (n *NIC) isPromiscuousMode() bool {
+ n.mu.RLock()
+ rv := n.promiscuous
+ n.mu.RUnlock()
+ return rv
+}
+
+// setSpoofing enables or disables address spoofing.
+func (n *NIC) setSpoofing(enable bool) {
+ n.mu.Lock()
+ n.spoofing = enable
+ n.mu.Unlock()
+}
+
+func (n *NIC) getMainNICAddress(protocol tcpip.NetworkProtocolNumber) (tcpip.Address, tcpip.Subnet, *tcpip.Error) {
n.mu.RLock()
defer n.mu.RUnlock()
- var address tcpip.Address
- var subnet tcpip.Subnet
+ var r *referencedNetworkEndpoint
// Check for a primary endpoint.
- var r *referencedNetworkEndpoint
- list := n.primary[protocol]
- if list != nil {
+ if list, ok := n.primary[protocol]; ok {
for e := list.Front(); e != nil; e = e.Next() {
ref := e.(*referencedNetworkEndpoint)
if ref.holdsInsertRef && ref.tryIncRef() {
@@ -97,28 +122,29 @@
// If no primary endpoints then check for other endpoints.
if r == nil {
for _, ref := range n.endpoints {
- if ref != nil && ref.holdsInsertRef && ref.tryIncRef() {
+ if ref.holdsInsertRef && ref.tryIncRef() {
r = ref
break
}
}
}
- if r != nil {
- address = r.ep.ID().LocalAddress
- r.decRef()
+ if r == nil {
+ return "", tcpip.Subnet{}, tcpip.ErrNoLinkAddress
}
- // Find the least-constrained matching subnet for the address, if one exists, and return it
- if address != "" {
- for _, s := range n.subnets {
- if s.Contains(address) && !subnet.Contains(s.ID()) {
- subnet = s
- }
+ address := r.ep.ID().LocalAddress
+ r.decRef()
+
+ // Find the least-constrained matching subnet for the address, if one
+ // exists, and return it.
+ var subnet tcpip.Subnet
+ for _, s := range n.subnets {
+ if s.Contains(address) && !subnet.Contains(s.ID()) {
+ subnet = s
}
}
-
- return address, subnet
+ return address, subnet, nil
}
// primaryEndpoint returns the primary endpoint of n for the given network
@@ -134,6 +160,11 @@
for e := list.Front(); e != nil; e = e.Next() {
r := e.(*referencedNetworkEndpoint)
+ // TODO: allow broadcast address when SO_BROADCAST is set.
+ switch r.ep.ID().LocalAddress {
+ case header.IPv4Broadcast, header.IPv4Any:
+ continue
+ }
if r.tryIncRef() {
return r
}
@@ -143,15 +174,33 @@
}
// findEndpoint finds the endpoint, if any, with the given address.
-func (n *NIC) findEndpoint(address tcpip.Address) *referencedNetworkEndpoint {
- n.mu.RLock()
- defer n.mu.RUnlock()
+func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint {
+ id := NetworkEndpointID{address}
- ref := n.endpoints[NetworkEndpointID{address}]
- if ref == nil || !ref.tryIncRef() {
- return nil
+ n.mu.RLock()
+ ref := n.endpoints[id]
+ if ref != nil && !ref.tryIncRef() {
+ ref = nil
+ }
+ spoofing := n.spoofing
+ n.mu.RUnlock()
+
+ if ref != nil || !spoofing {
+ return ref
}
+ // Try again with the lock in exclusive mode. If we still can't get the
+ // endpoint, create a new "temporary" endpoint. It will only exist while
+ // there's a route through it.
+ n.mu.Lock()
+ ref = n.endpoints[id]
+ if ref == nil || !ref.tryIncRef() {
+ ref, _ = n.addAddressLocked(protocol, address, peb, true)
+ if ref != nil {
+ ref.holdsInsertRef = false
+ }
+ }
+ n.mu.Unlock()
return ref
}
@@ -183,10 +232,12 @@
protocol: protocol,
holdsInsertRef: true,
}
- if linkRes := n.stack.linkAddrResolvers[protocol]; linkRes != nil {
- ref.linkRes = linkRes
- ref.linkCache = n.stack
- ref.linkEP = n.linkEP
+
+ // Set up cache if link address resolution exists for this protocol.
+ if n.linkEP.Capabilities()&CapabilityResolutionRequired != 0 {
+ if _, ok := n.stack.linkAddrResolvers[protocol]; ok {
+ ref.linkCache = n.stack
+ }
}
n.endpoints[id] = ref
@@ -213,7 +264,8 @@
return n.AddAddressWithOptions(protocol, addr, CanBePrimaryEndpoint)
}
-// AddAddressWithOptions is the same as AddAddress, but allows you to specify whether they new endpoint can be primary or not.
+// AddAddressWithOptions is the same as AddAddress, but allows you to specify
+// whether the new endpoint can be primary or not.
func (n *NIC) AddAddressWithOptions(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, peb PrimaryEndpointBehavior) *tcpip.Error {
// Add the endpoint.
n.mu.Lock()
@@ -223,6 +275,20 @@
return err
}
+// Addresses returns the addresses associated with this NIC.
+func (n *NIC) Addresses() []tcpip.ProtocolAddress {
+ n.mu.RLock()
+ defer n.mu.RUnlock()
+ addrs := make([]tcpip.ProtocolAddress, 0, len(n.endpoints))
+ for nid, ep := range n.endpoints {
+ addrs = append(addrs, tcpip.ProtocolAddress{
+ Protocol: ep.protocol,
+ Address: nid.LocalAddress,
+ })
+ }
+ return addrs
+}
+
// AddSubnet adds a new subnet to n, so that it starts accepting packets
// targeted at the given address and network protocol.
func (n *NIC) AddSubnet(protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) {
@@ -235,22 +301,21 @@
func (n *NIC) RemoveSubnet(subnet tcpip.Subnet) {
n.mu.Lock()
- var filtered []tcpip.Subnet
+ // Filter in place, reusing the slice's underlying array.
+ tmp := n.subnets[:0]
for _, sub := range n.subnets {
if sub != subnet {
- filtered = append(filtered, sub)
+ tmp = append(tmp, sub)
}
}
+ n.subnets = tmp
- n.subnets = filtered
n.mu.Unlock()
- return
}
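The new `RemoveSubnet` filters `n.subnets` in place. It appends survivors to a zero-length slice that shares the original backing array, so no new allocation is needed.

The sketch below applies that idiom to a string slice. `removeAll` is a hypothetical helper.

```go
package main

import "fmt"

// removeAll deletes every element equal to target from s, reusing s's
// backing array instead of allocating a new slice. Callers must use the
// returned slice.
func removeAll(s []string, target string) []string {
	out := s[:0] // zero length, same backing array
	for _, v := range s {
		if v != target {
			out = append(out, v)
		}
	}
	return out
}

func main() {
	subnets := []string{"10.0.0.0/8", "192.168.0.0/24", "10.0.0.0/8"}
	subnets = removeAll(subnets, "10.0.0.0/8")
	fmt.Println(subnets, cap(subnets)) // [192.168.0.0/24] 3
}
```

Overwriting earlier elements is safe because the write index never passes the read index. Elements past the new length still hold their old values. For slices of pointers, it may be worth clearing that tail so the garbage collector can reclaim what they reference.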
+// ContainsSubnet reports whether this NIC contains the given subnet.
func (n *NIC) ContainsSubnet(subnet tcpip.Subnet) bool {
- subnets := n.Subnets()
-
- for _, s := range subnets {
+ for _, s := range n.Subnets() {
if s == subnet {
return true
}
@@ -326,129 +391,126 @@
// Note that the ownership of the slice backing vv is retained by the caller.
// This rule applies only to the slice itself, not to the items of the slice;
// the ownership of the items is not retained by the caller.
-func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, _, remoteLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv *buffer.VectorisedView) {
+func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
netProto, ok := n.stack.networkProtocols[protocol]
if !ok {
- atomic.AddUint64(&n.stack.stats.IP.PacketsDiscarded, 1)
- atomic.AddUint64(&n.stack.stats.UnknownProtocolRcvdPackets, 1)
+ n.stack.stats.UnknownProtocolRcvdPackets.Increment()
return
}
if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber {
- atomic.AddUint64(&n.stack.stats.IP.PacketsReceived, 1)
+ n.stack.stats.IP.PacketsReceived.Increment()
}
if len(vv.First()) < netProto.MinimumPacketSize() {
- atomic.AddUint64(&n.stack.stats.IP.PacketsDiscarded, 1)
- atomic.AddUint64(&n.stack.stats.MalformedRcvdPackets, 1)
+ n.stack.stats.MalformedRcvdPackets.Increment()
return
}
src, dst := netProto.ParseAddresses(vv.First())
+
+ if ref := n.getRef(protocol, dst); ref != nil {
+ r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref)
+ r.RemoteLinkAddress = remoteLinkAddr
+ ref.ep.HandlePacket(&r, vv)
+ ref.decRef()
+ return
+ }
+
+ // This NIC doesn't care about the packet. Find a NIC that cares about the
+ // packet and forward it to the NIC.
+ //
+ // TODO: Should we be forwarding the packet even if promiscuous?
+ if n.stack.Forwarding() {
+ r, err := n.stack.FindRoute(0, "", dst, protocol)
+ if err != nil {
+ n.stack.stats.IP.InvalidAddressesReceived.Increment()
+ return
+ }
+ defer r.Release()
+
+ r.LocalLinkAddress = n.linkEP.LinkAddress()
+ r.RemoteLinkAddress = remoteLinkAddr
+
+ // Found a NIC.
+ n := r.ref.nic
+ n.mu.RLock()
+ ref, ok := n.endpoints[NetworkEndpointID{dst}]
+ n.mu.RUnlock()
+ if ok && ref.tryIncRef() {
+ ref.ep.HandlePacket(&r, vv)
+ ref.decRef()
+ } else {
+ // n doesn't have a destination endpoint.
+ // Send the packet out of n.
+ hdr := buffer.NewPrependableFromView(vv.First())
+ vv.RemoveFirst()
+ n.linkEP.WritePacket(&r, hdr, vv, protocol)
+ }
+ return
+ }
+
+ n.stack.stats.IP.InvalidAddressesReceived.Increment()
+}
+
+func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint {
id := NetworkEndpointID{dst}
n.mu.RLock()
- ref := n.endpoints[id]
- if ref != nil && !ref.tryIncRef() {
- ref = nil
+ if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() {
+ n.mu.RUnlock()
+ return ref
}
+
promiscuous := n.promiscuous
- subnets := n.subnets
+ // Check if the packet is for a subnet this NIC cares about.
+ if !promiscuous {
+ for _, sn := range n.subnets {
+ if sn.Contains(dst) {
+ promiscuous = true
+ break
+ }
+ }
+ }
n.mu.RUnlock()
-
- if ref == nil {
- // Check if the packet is for a subnet this NIC cares about.
- if !promiscuous {
- for _, sn := range subnets {
- if sn.Contains(dst) {
- promiscuous = true
- break
- }
- }
- }
- if promiscuous {
- // Try again with the lock in exclusive mode. If we still can't
- // get the endpoint, create a new "temporary" one. It will only
- // exist while there's a route through it.
- n.mu.Lock()
- ref = n.endpoints[id]
- if ref == nil || !ref.tryIncRef() {
- ref, _ = n.addAddressLocked(protocol, dst, CanBePrimaryEndpoint, true)
- if ref != nil {
- ref.holdsInsertRef = false
- }
- }
+ if promiscuous {
+ // Try again with the lock in exclusive mode. If we still can't
+ // get the endpoint, create a new "temporary" one. It will only
+ // exist while there's a route through it.
+ n.mu.Lock()
+ if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() {
n.mu.Unlock()
+ return ref
+ }
+ ref, err := n.addAddressLocked(protocol, dst, CanBePrimaryEndpoint, true)
+ n.mu.Unlock()
+ if err == nil {
+ ref.holdsInsertRef = false
+ return ref
}
}
- if ref == nil {
- // This NIC doesn't care the packet. Find a NIC that cares about the packet and
- // forward it to the NIC.
- // TODO: Should forward the packet even if 'promiscuous' is enabled?
- if n.stack.Forwarding() {
- r, err := n.stack.FindRoute(0, "", dst, protocol)
- if err != nil {
- // Can't find a NIC.
- atomic.AddUint64(&n.stack.stats.IP.InvalidAddressesReceived, 1)
- return
- }
- defer r.Release()
-
- // Found a NIC.
- n2 := r.ref.nic
- n2.mu.RLock()
- ref := n2.endpoints[id]
- if ref != nil && !ref.tryIncRef() {
- ref = nil
- }
- n2.mu.RUnlock()
-
- r.LocalLinkAddress = n2.linkEP.LinkAddress()
- r.RemoteLinkAddress = remoteLinkAddr
-
- if ref == nil {
- // n2 doesn't have a destination endpoint.
- // Send the packet out of n2.
- if ep, ok := n2.linkEP.(BufferWritingLinkEndpoint); ok {
- ep.WriteBuffer(&r, vv, protocol)
- }
- } else {
- ref.ep.HandlePacket(&r, vv)
- ref.decRef()
- }
- return
- } else {
- atomic.AddUint64(&n.stack.stats.IP.InvalidAddressesReceived, 1)
- return
- }
- }
-
- r := makeRoute(protocol, dst, src, ref)
- r.LocalLinkAddress = linkEP.LinkAddress()
- r.RemoteLinkAddress = remoteLinkAddr
- ref.ep.HandlePacket(&r, vv)
- ref.decRef()
+ return nil
}
// DeliverTransportPacket delivers the packets to the appropriate transport
// protocol endpoint.
-func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv *buffer.VectorisedView) {
+func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView) {
state, ok := n.stack.transportProtocols[protocol]
if !ok {
- atomic.AddUint64(&n.stack.stats.UnknownProtocolRcvdPackets, 1)
+ n.stack.stats.UnknownProtocolRcvdPackets.Increment()
return
}
transProto := state.proto
if len(vv.First()) < transProto.MinimumPacketSize() {
- atomic.AddUint64(&n.stack.stats.MalformedRcvdPackets, 1)
+ n.stack.stats.MalformedRcvdPackets.Increment()
return
}
srcPort, dstPort, err := transProto.ParsePorts(vv.First())
if err != nil {
- atomic.AddUint64(&n.stack.stats.MalformedRcvdPackets, 1)
+ n.stack.stats.MalformedRcvdPackets.Increment()
return
}
@@ -470,7 +532,38 @@
// We could not find an appropriate destination for this packet, so
// deliver it to the global handler.
if !transProto.HandleUnknownDestinationPacket(r, id, vv) {
- atomic.AddUint64(&n.stack.stats.MalformedRcvdPackets, 1)
+ n.stack.stats.MalformedRcvdPackets.Increment()
+ }
+}
+
+// DeliverTransportControlPacket delivers control packets to the appropriate
+// transport protocol endpoint.
+func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView) {
+ state, ok := n.stack.transportProtocols[trans]
+ if !ok {
+ return
+ }
+
+ transProto := state.proto
+
+ // ICMPv4 only guarantees that 8 bytes of the transport protocol will
+ // be present in the payload. We know that the ports are within the
+ // first 8 bytes for all known transport protocols.
+ if len(vv.First()) < 8 {
+ return
+ }
+
+ srcPort, dstPort, err := transProto.ParsePorts(vv.First())
+ if err != nil {
+ return
+ }
+
+ id := TransportEndpointID{srcPort, local, dstPort, remote}
+ if n.demux.deliverControlPacket(net, trans, typ, extra, vv, id) {
+ return
+ }
+ if n.stack.demux.deliverControlPacket(net, trans, typ, extra, vv, id) {
+ return
}
}
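`DeliverTransportControlPacket` needs only 8 bytes of the quoted transport header. ICMPv4 errors are only guaranteed to carry the first 64 bits of the offending datagram. TCP and UDP both store their 16-bit source and destination ports big-endian in bytes 0 to 3, so those 8 bytes are enough.

The sketch below shows that parse. `portsFromICMPPayload` and `errTooShort` are hypothetical names.

```go
package main

import (
	"encoding/binary"
	"errors"
	"fmt"
)

var errTooShort = errors.New("payload shorter than 8 bytes")

// portsFromICMPPayload extracts the source and destination ports from the
// start of a TCP or UDP header quoted inside an ICMPv4 error message.
func portsFromICMPPayload(b []byte) (src, dst uint16, err error) {
	if len(b) < 8 {
		return 0, 0, errTooShort
	}
	return binary.BigEndian.Uint16(b[0:2]), binary.BigEndian.Uint16(b[2:4]), nil
}

func main() {
	quoted := []byte{0x1f, 0x90, 0x00, 0x35, 0, 0, 0, 0}
	src, dst, err := portsFromICMPPayload(quoted)
	fmt.Println(src, dst, err) // 8080 53 <nil>
}
```

The quoted header belongs to a packet this host sent, so its source port is the local port. That is why `DeliverTransportControlPacket` builds the `TransportEndpointID` as `{srcPort, local, dstPort, remote}`.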
@@ -481,13 +574,14 @@
type referencedNetworkEndpoint struct {
ilist.Entry
- refs int32
- ep NetworkEndpoint
- nic *NIC
- protocol tcpip.NetworkProtocolNumber
- linkRes LinkAddressResolver
+ refs int32
+ ep NetworkEndpoint
+ nic *NIC
+ protocol tcpip.NetworkProtocolNumber
+
+ // linkCache is set if link address resolution is enabled for this
+ // protocol. Set to nil otherwise.
linkCache LinkAddressCache
- linkEP LinkEndpoint
// holdsInsertRef is protected by the NIC's mutex. It indicates whether
// the reference count is biased by 1 due to the insertion of the
diff --git a/tcpip/stack/registration.go b/tcpip/stack/registration.go
index 6ec9d22..c6572bf 100644
--- a/tcpip/stack/registration.go
+++ b/tcpip/stack/registration.go
@@ -1,13 +1,23 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package stack
import (
"sync"
- "time"
+ "github.com/google/netstack/sleep"
"github.com/google/netstack/tcpip"
"github.com/google/netstack/tcpip/buffer"
"github.com/google/netstack/waiter"
@@ -21,6 +31,8 @@
}
// TransportEndpointID is the identifier of a transport layer protocol endpoint.
+//
+// +stateify savable
type TransportEndpointID struct {
// LocalPort is the local port associated with the endpoint.
LocalPort uint16
@@ -37,12 +49,26 @@
RemoteAddress tcpip.Address
}
+// ControlType is the type of network control message.
+type ControlType int
+
+// The following are the allowed ControlType values.
+const (
+ ControlPacketTooBig ControlType = iota
+ ControlPortUnreachable
+ ControlUnknown
+)
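A network protocol that receives an ICMP error has to translate it into one of these `ControlType` values. The sketch below shows one hypothetical mapping for ICMPv4. It assumes that Destination Unreachable (type 3) with code 3 maps to port unreachable. It also assumes that code 4 (fragmentation needed, RFC 1191) maps to packet too big, with the next-hop MTU passed as the `extra` value.

The local `controlType` constants and `classifyICMPv4` only mirror the shape of the stack's types. This is not the netstack ipv4 implementation.

```go
package main

import (
	"encoding/binary"
	"fmt"
)

// controlType mirrors the shape of stack.ControlType for this sketch only.
type controlType int

const (
	controlPacketTooBig controlType = iota
	controlPortUnreachable
	controlUnknown
)

// ICMPv4 type and code values from RFC 792 and RFC 1191.
const (
	icmpDstUnreachable      = 3
	codePortUnreachable     = 3
	codeFragmentationNeeded = 4
)

// classifyICMPv4 maps an ICMPv4 error (type, code, and the 4-byte field that
// follows the checksum) to a control type plus an extra value. For
// "fragmentation needed", RFC 1191 places the next-hop MTU in the low 16
// bits of that field.
func classifyICMPv4(typ, code uint8, rest [4]byte) (controlType, uint32) {
	if typ != icmpDstUnreachable {
		return controlUnknown, 0
	}
	switch code {
	case codePortUnreachable:
		return controlPortUnreachable, 0
	case codeFragmentationNeeded:
		return controlPacketTooBig, uint32(binary.BigEndian.Uint16(rest[2:4]))
	default:
		return controlUnknown, 0
	}
}

func main() {
	ct, mtu := classifyICMPv4(3, 4, [4]byte{0, 0, 0x05, 0xdc})
	fmt.Println(ct == controlPacketTooBig, mtu) // true 1500
}
```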
+
// TransportEndpoint is the interface that needs to be implemented by transport
// protocol (e.g., tcp, udp) endpoints that can handle packets.
type TransportEndpoint interface {
// HandlePacket is called by the stack when new packets arrive to
// this transport endpoint.
- HandlePacket(r *Route, id TransportEndpointID, vv *buffer.VectorisedView)
+ HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView)
+
+ // HandleControlPacket is called by the stack when new control (e.g.,
+ // ICMP) packets arrive to this transport endpoint.
+ HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView)
}
// TransportProtocol is the interface that needs to be implemented by transport
@@ -69,27 +95,36 @@
//
// The return value indicates whether the packet was well-formed (for
// stats purposes only).
- HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, vv *buffer.VectorisedView) bool
+ HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) bool
// SetOption allows enabling/disabling protocol specific features.
// SetOption returns an error if the option is not supported or the
// provided option value is invalid.
SetOption(option interface{}) *tcpip.Error
+
+ // Option allows retrieving protocol specific option values.
+ // Option returns an error if the option is not supported or the
+ // provided option value is invalid.
+ Option(option interface{}) *tcpip.Error
}
// TransportDispatcher contains the methods used by the network stack to deliver
// packets to the appropriate transport endpoint after it has been handled by
// the network layer.
type TransportDispatcher interface {
- // DeliverTransportPacket delivers the packets to the appropriate
+ // DeliverTransportPacket delivers packets to the appropriate
// transport protocol endpoint.
- DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv *buffer.VectorisedView)
+ DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView)
+
+ // DeliverTransportControlPacket delivers control packets to the
+ // appropriate transport protocol endpoint.
+ DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView)
}
// NetworkEndpoint is the interface that needs to be implemented by endpoints
// of network layer protocols (e.g., ipv4, ipv6).
type NetworkEndpoint interface {
- // Default TTL is the default time-to-live value (or hop limit, in ipv6)
+ // DefaultTTL is the default time-to-live value (or hop limit, in ipv6)
// for this endpoint.
DefaultTTL() uint8
@@ -98,6 +133,10 @@
// minus the network endpoint max header length.
MTU() uint32
+ // Capabilities returns the set of capabilities supported by the
+ // underlying link-layer endpoint.
+ Capabilities() LinkEndpointCapabilities
+
// MaxHeaderLength returns the maximum size the network (and lower
// level layers combined) headers can have. Higher levels use this
// information to reserve space in the front of the packets they're
@@ -106,7 +145,7 @@
// WritePacket writes a packet to the given destination address and
// protocol.
- WritePacket(r *Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error
+ WritePacket(r *Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error
// ID returns the network protocol endpoint ID.
ID() *NetworkEndpointID
@@ -116,7 +155,7 @@
// HandlePacket is called by the link layer when new packets arrive to
// this network endpoint.
- HandlePacket(r *Route, vv *buffer.VectorisedView)
+ HandlePacket(r *Route, vv buffer.VectorisedView)
// Close is called when the endpoint is removed from a stack.
Close()
@@ -144,6 +183,11 @@
// SetOption returns an error if the option is not supported or the
// provided option value is invalid.
SetOption(option interface{}) *tcpip.Error
+
+ // Option allows retrieving protocol specific option values.
+ // Option returns an error if the option is not supported or the
+ // provided option value is invalid.
+ Option(option interface{}) *tcpip.Error
}
// NetworkDispatcher contains the methods used by the network stack to deliver
@@ -152,16 +196,21 @@
type NetworkDispatcher interface {
// DeliverNetworkPacket finds the appropriate network protocol
// endpoint and hands the packet over for further processing.
- DeliverNetworkPacket(linkEP LinkEndpoint, dstLinkAddr, srcLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv *buffer.VectorisedView)
+ DeliverNetworkPacket(linkEP LinkEndpoint, dstLinkAddr, srcLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView)
}
-type BufferWritingLinkEndpoint interface {
- // WriteBuffer writes a packet with the given protocol through the given route.
- // It doesn't distinguish between headers and payload, since the link endpoint
- // ultimately just copies both into the transmit buffer.
- // TODO(stijlist): update WritePacket to this API, port callers
- WriteBuffer(r *Route, payload *buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error
-}
+// LinkEndpointCapabilities is the type associated with the capabilities
+// supported by a link-layer endpoint. It is a set of bitfields.
+type LinkEndpointCapabilities uint
+
+// The following are the supported link endpoint capabilities.
+const (
+ CapabilityChecksumOffload LinkEndpointCapabilities = 1 << iota
+ CapabilityResolutionRequired
+ CapabilitySaveRestore
+ CapabilityDisconnectOk
+ CapabilityLoopback
+)
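`LinkEndpointCapabilities` is a bit set. Capabilities combine with `|` and are tested with `&`, as the endpoint-creation code in `nic.go` does with `CapabilityResolutionRequired`.

The sketch below uses a local copy of the constants. `linkCaps` and `has` are hypothetical names.

```go
package main

import "fmt"

// linkCaps copies the bit layout of LinkEndpointCapabilities for this sketch.
type linkCaps uint

const (
	capChecksumOffload    linkCaps = 1 << iota // 1
	capResolutionRequired                      // 2
	capSaveRestore                             // 4
	capDisconnectOk                            // 8
	capLoopback                                // 16
)

// has reports whether every bit in want is set in c.
func (c linkCaps) has(want linkCaps) bool { return c&want == want }

func main() {
	caps := capChecksumOffload | capResolutionRequired
	fmt.Println(caps.has(capResolutionRequired)) // true
	fmt.Println(caps.has(capLoopback))           // false
}
```

`has` checks that every requested bit is set. For a single-bit mask this is equivalent to the `caps&bit != 0` test used in `nic.go`.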
// LinkEndpoint is the interface implemented by data link layer protocols (e.g.,
// ethernet, loopback, raw) and used by network layer protocols to send packets
@@ -173,6 +222,10 @@
// includes the maximum size of an IP packet.
MTU() uint32
+ // Capabilities returns the set of capabilities supported by the
+ // endpoint.
+ Capabilities() LinkEndpointCapabilities
+
// MaxHeaderLength returns the maximum size the data link (and
// lower level layers combined) headers can have. Higher levels use this
// information to reserve space in the front of the packets they're
@@ -185,11 +238,19 @@
// WritePacket writes a packet with the given protocol through the given
// route.
- WritePacket(r *Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.NetworkProtocolNumber) *tcpip.Error
+ //
+ // To participate in transparent bridging, a LinkEndpoint implementation
+ // should call eth.Encode with header.EthernetFields.SrcAddr set to
+ // r.LocalLinkAddress if it is provided.
+ WritePacket(r *Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error
// Attach attaches the data link layer endpoint to the network-layer
// dispatcher of the stack.
Attach(dispatcher NetworkDispatcher)
+
+ // IsAttached returns whether a NetworkDispatcher is attached to the
+ // endpoint.
+ IsAttached() bool
}
// A LinkAddressResolver is an extension to a NetworkProtocol that
@@ -202,6 +263,13 @@
// endpoint to call AddLinkAddress.
LinkAddressRequest(addr, localAddr tcpip.Address, linkEP LinkEndpoint) *tcpip.Error
+ // ResolveStaticAddress attempts to resolve addr without sending
+ // requests. It either resolves the address immediately, returning the
+ // link address and true, or returns the empty LinkAddress and false.
+ //
+ // It can be used to resolve broadcast addresses for example.
+ ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool)
+
// LinkAddressProtocol returns the network protocol of the
// addresses this resolver can resolve.
LinkAddressProtocol() tcpip.NetworkProtocolNumber
@@ -210,14 +278,21 @@
// A LinkAddressCache caches link addresses.
type LinkAddressCache interface {
// CheckLocalAddress determines if the given local address exists, and if it
// does, returns the id of the NIC it's bound to. Returns 0 if the address
// does not exist.
- CheckLocalAddress(nicid tcpip.NICID, addr tcpip.Address) tcpip.NICID
+ CheckLocalAddress(nicid tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID
// AddLinkAddress adds a link address to the cache.
AddLinkAddress(nicid tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress)
- GetLinkAddress(nicid tcpip.NICID, addr tcpip.Address, timeout time.Duration) tcpip.LinkAddress
+ // GetLinkAddress looks up the cache to translate address to link address (e.g. IP -> MAC).
+ // On a cache miss, if the LinkEndpoint requires address resolution and a
+ // LinkAddressResolver is registered for the network protocol, the cache starts
+ // resolving the address and returns ErrWouldBlock. The waker is notified when
+ // address resolution completes (successfully or not).
+ GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, w *sleep.Waker) (tcpip.LinkAddress, *tcpip.Error)
+
+ // RemoveWaker removes a waker that has been added in GetLinkAddress().
+ RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker)
}
// TransportProtocolFactory functions are used by the stack to instantiate
diff --git a/tcpip/stack/route.go b/tcpip/stack/route.go
index 476c2c6..5219af6 100644
--- a/tcpip/stack/route.go
+++ b/tcpip/stack/route.go
@@ -1,13 +1,21 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package stack
import (
- "sync/atomic"
- "time"
-
+ "github.com/google/netstack/sleep"
"github.com/google/netstack/tcpip"
"github.com/google/netstack/tcpip/buffer"
"github.com/google/netstack/tcpip/header"
@@ -42,12 +50,13 @@
// makeRoute initializes a new route. It takes ownership of the provided
// reference to a network endpoint.
-func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, ref *referencedNetworkEndpoint) Route {
+func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, localLinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint) Route {
return Route{
- NetProto: netProto,
- LocalAddress: localAddr,
- RemoteAddress: remoteAddr,
- ref: ref,
+ NetProto: netProto,
+ LocalAddress: localAddr,
+ LocalLinkAddress: localLinkAddr,
+ RemoteAddress: remoteAddr,
+ ref: ref,
}
}
@@ -61,10 +70,9 @@
return r.ref.ep.MaxHeaderLength()
}
-// MutableStats returns a mutable copy of the referenced endpoint's Stats
-// struct, for stats updates where we only have a reference to the Route.
-func (r *Route) MutableStats() *tcpip.Stats {
- return r.ref.nic.stack.MutableStats()
+// Stats returns a mutable copy of current stats.
+func (r *Route) Stats() tcpip.Stats {
+ return r.ref.nic.stack.Stats()
}
// PseudoHeaderChecksum forwards the call to the network endpoint's
@@ -73,30 +81,60 @@
return header.PseudoHeaderChecksum(protocol, r.LocalAddress, r.RemoteAddress)
}
-func isLoopback(addr tcpip.Address) bool {
- return (len(addr) == 4 && addr[0] == 127) || addr == header.IPv6Loopback
+// Capabilities returns the link-layer capabilities of the route.
+func (r *Route) Capabilities() LinkEndpointCapabilities {
+ return r.ref.ep.Capabilities()
+}
+
+// Resolve attempts to resolve the link address if necessary. It returns
+// ErrWouldBlock if address resolution requires blocking, e.g. waiting for an
+// ARP reply. The waker is notified when address resolution completes, whether
+// it succeeds or not.
+func (r *Route) Resolve(waker *sleep.Waker) *tcpip.Error {
+ if !r.IsResolutionRequired() {
+ // Nothing to do if there is no cache (which performs resolution on a
+ // cache miss) or the link address is already known.
+ return nil
+ }
+
+ nextAddr := r.NextHop
+ if nextAddr == "" {
+ // Local link address is already known.
+ if r.RemoteAddress == r.LocalAddress {
+ r.RemoteLinkAddress = r.LocalLinkAddress
+ return nil
+ }
+ nextAddr = r.RemoteAddress
+ }
+ linkAddr, err := r.ref.linkCache.GetLinkAddress(r.ref.nic.ID(), nextAddr, r.LocalAddress, r.NetProto, waker)
+ if err != nil {
+ return err
+ }
+ r.RemoteLinkAddress = linkAddr
+ return nil
+}
+
+// RemoveWaker removes a waker that has been added in Resolve().
+func (r *Route) RemoveWaker(waker *sleep.Waker) {
+ nextAddr := r.NextHop
+ if nextAddr == "" {
+ nextAddr = r.RemoteAddress
+ }
+ r.ref.linkCache.RemoveWaker(r.ref.nic.ID(), nextAddr, waker)
+}
+
+// IsResolutionRequired returns true if Resolve() must be called to resolve
+// the link address before this route can be written to.
+func (r *Route) IsResolutionRequired() bool {
+ return r.ref.linkCache != nil && r.RemoteLinkAddress == ""
}
// WritePacket writes the packet through the given route.
-func (r *Route) WritePacket(hdr *buffer.Prependable, payload buffer.View, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error {
- if r.RemoteLinkAddress == "" && r.ref.linkRes != nil && !isLoopback(r.RemoteAddress) {
- nextAddr := r.NextHop
- if nextAddr == "" {
- nextAddr = r.RemoteAddress
- }
-
- nicid := r.ref.nic.ID()
- r.RemoteLinkAddress = r.ref.linkCache.GetLinkAddress(nicid, nextAddr, 0)
- if r.RemoteLinkAddress == "" {
- r.ref.linkRes.LinkAddressRequest(nextAddr, r.LocalAddress, r.ref.linkEP)
- r.RemoteLinkAddress = r.ref.linkCache.GetLinkAddress(nicid, nextAddr, 250*time.Millisecond)
- }
- if r.RemoteLinkAddress == "" {
- atomic.AddUint64(&r.MutableStats().IP.OutgoingPacketErrors, 1)
- return tcpip.ErrNoLinkAddress
- }
+func (r *Route) WritePacket(hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error {
+ err := r.ref.ep.WritePacket(r, hdr, payload, protocol, ttl)
+ if err == tcpip.ErrNoRoute {
+ r.Stats().IP.OutgoingPacketErrors.Increment()
}
- return r.ref.ep.WritePacket(r, hdr, payload, protocol, ttl)
+ return err
}
// DefaultTTL returns the default TTL of the underlying network endpoint.
diff --git a/tcpip/stack/stack.go b/tcpip/stack/stack.go
index dc53ae7..bf63080 100644
--- a/tcpip/stack/stack.go
+++ b/tcpip/stack/stack.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
// Package stack provides the glue between networking protocols and the
// consumers of the networking stack.
@@ -16,18 +26,248 @@
import (
"sync"
- "sync/atomic"
"time"
+ "github.com/google/netstack/sleep"
"github.com/google/netstack/tcpip"
"github.com/google/netstack/tcpip/buffer"
+ "github.com/google/netstack/tcpip/header"
"github.com/google/netstack/tcpip/ports"
+ "github.com/google/netstack/tcpip/seqnum"
"github.com/google/netstack/waiter"
)
+const (
+ // ageLimit is set to the same cache stale time used in Linux.
+ ageLimit = 1 * time.Minute
+ // resolutionTimeout is set to the same ARP timeout used in Linux.
+ resolutionTimeout = 1 * time.Second
+ // resolutionAttempts is set to the same ARP retries used in Linux.
+ resolutionAttempts = 3
+)
+
type transportProtocolState struct {
proto TransportProtocol
- defaultHandler func(*Route, TransportEndpointID, *buffer.VectorisedView) bool
+ defaultHandler func(*Route, TransportEndpointID, buffer.VectorisedView) bool
+}
+
+// TCPProbeFunc is the expected function type for a TCP probe function to be
+// passed to stack.AddTCPProbe.
+type TCPProbeFunc func(s TCPEndpointState)
+
+// TCPCubicState is used to hold a copy of the internal cubic state when the
+// TCPProbeFunc is invoked.
+type TCPCubicState struct {
+ WLastMax float64
+ WMax float64
+ T time.Time
+ TimeSinceLastCongestion time.Duration
+ C float64
+ K float64
+ Beta float64
+ WC float64
+ WEst float64
+}
+
+// TCPEndpointID is the unique 4-tuple that identifies a given endpoint.
+type TCPEndpointID struct {
+ // LocalPort is the local port associated with the endpoint.
+ LocalPort uint16
+
+ // LocalAddress is the local [network layer] address associated with
+ // the endpoint.
+ LocalAddress tcpip.Address
+
+ // RemotePort is the remote port associated with the endpoint.
+ RemotePort uint16
+
+ // RemoteAddress is the remote [network layer] address associated with
+ // the endpoint.
+ RemoteAddress tcpip.Address
+}
+
+// TCPFastRecoveryState holds a copy of the internal fast recovery state of a
+// TCP endpoint.
+type TCPFastRecoveryState struct {
+ // Active if true indicates the endpoint is in fast recovery.
+ Active bool
+
+ // First is the first unacknowledged sequence number being recovered.
+ First seqnum.Value
+
+ // Last is the 'recover' sequence number that indicates the point at
+ // which we should exit recovery barring any timeouts etc.
+ Last seqnum.Value
+
+ // MaxCwnd is the maximum value to which we are permitted to grow the
+ // congestion window during recovery. This is set at the time we enter
+ // recovery.
+ MaxCwnd int
+}
+
+// TCPReceiverState holds a copy of the internal state of the receiver for
+// a given TCP endpoint.
+type TCPReceiverState struct {
+ // RcvNxt is the TCP variable RCV.NXT.
+ RcvNxt seqnum.Value
+
+ // RcvAcc is the TCP variable RCV.ACC.
+ RcvAcc seqnum.Value
+
+ // RcvWndScale is the window scaling to use for inbound segments.
+ RcvWndScale uint8
+
+ // PendingBufUsed is the number of bytes pending in the receive
+ // queue.
+ PendingBufUsed seqnum.Size
+
+ // PendingBufSize is the size of the socket receive buffer.
+ PendingBufSize seqnum.Size
+}
+
+// TCPSenderState holds a copy of the internal state of the sender for
+// a given TCP Endpoint.
+type TCPSenderState struct {
+ // LastSendTime is the time at which we sent the last segment.
+ LastSendTime time.Time
+
+ // DupAckCount is the number of duplicate ACKs received.
+ DupAckCount int
+
+ // SndCwnd is the size of the sending congestion window in packets.
+ SndCwnd int
+
+ // Ssthresh is the slow start threshold in packets.
+ Ssthresh int
+
+ // SndCAAckCount is the number of packets consumed in congestion
+ // avoidance mode.
+ SndCAAckCount int
+
+ // Outstanding is the number of packets in flight.
+ Outstanding int
+
+ // SndWnd is the send window size in bytes.
+ SndWnd seqnum.Size
+
+ // SndUna is the oldest unacknowledged sequence number (SND.UNA).
+ SndUna seqnum.Value
+
+ // SndNxt is the sequence number of the next segment to be sent.
+ SndNxt seqnum.Value
+
+ // RTTMeasureSeqNum is the sequence number being used for the latest RTT
+ // measurement.
+ RTTMeasureSeqNum seqnum.Value
+
+ // RTTMeasureTime is the time when the RTTMeasureSeqNum was sent.
+ RTTMeasureTime time.Time
+
+ // Closed indicates that the caller has closed the endpoint for sending.
+ Closed bool
+
+ // SRTT is the smoothed round-trip time as defined in section 2 of
+ // RFC 6298.
+ SRTT time.Duration
+
+ // RTO is the retransmit timeout as defined in section 2 of RFC 6298.
+ RTO time.Duration
+
+ // RTTVar is the round-trip time variation as defined in section 2 of
+ // RFC 6298.
+ RTTVar time.Duration
+
+ // SRTTInited, if true, indicates that a valid RTT measurement has been
+ // completed.
+ SRTTInited bool
+
+ // MaxPayloadSize is the maximum size of the payload of a given segment.
+ // It is initialized on demand.
+ MaxPayloadSize int
+
+ // SndWndScale is the number of bits to shift left when reading the send
+ // window size from a segment.
+ SndWndScale uint8
+
+ // MaxSentAck is the highest acknowledgement number sent so far.
+ MaxSentAck seqnum.Value
+
+ // FastRecovery holds the fast recovery state for the endpoint.
+ FastRecovery TCPFastRecoveryState
+
+ // Cubic holds the state related to CUBIC congestion control.
+ Cubic TCPCubicState
+}
+
+// TCPSACKInfo holds TCP SACK related information for a given TCP endpoint.
+type TCPSACKInfo struct {
+ // Blocks is the list of SACK blocks currently received by the
+ // TCP endpoint.
+ Blocks []header.SACKBlock
+}
+
+// TCPEndpointState is a copy of the internal state of a TCP endpoint.
+type TCPEndpointState struct {
+ // ID is a copy of the TransportEndpointID for the endpoint.
+ ID TCPEndpointID
+
+ // SegTime denotes the absolute time when this segment was received.
+ SegTime time.Time
+
+ // RcvBufSize is the size of the receive socket buffer for the endpoint.
+ RcvBufSize int
+
+ // RcvBufUsed is the number of bytes actually held in the receive socket
+ // buffer for the endpoint.
+ RcvBufUsed int
+
+ // RcvClosed if true, indicates the endpoint has been closed for reading.
+ RcvClosed bool
+
+ // SendTSOk is used to indicate when the TS Option has been negotiated.
+ // When SendTSOk is true every non-RST segment should carry a TS as per
+ // RFC7323#section-1.1.
+ SendTSOk bool
+
+ // RecentTS is the timestamp that should be sent in the TSEcr field of
+ // the timestamp for future segments sent by the endpoint. This field is
+ // updated if required when a new segment is received by this endpoint.
+ RecentTS uint32
+
+ // TSOffset is a randomized offset added to the value of the TSVal field
+ // in the timestamp option.
+ TSOffset uint32
+
+ // SACKPermitted is set to true if the peer sends the TCPSACKPermitted
+ // option in the SYN/SYN-ACK.
+ SACKPermitted bool
+
+ // SACK holds TCP SACK related information for this endpoint.
+ SACK TCPSACKInfo
+
+ // SndBufSize is the size of the socket send buffer.
+ SndBufSize int
+
+ // SndBufUsed is the number of bytes held in the socket send buffer.
+ SndBufUsed int
+
+ // SndClosed indicates that the endpoint has been closed for sends.
+ SndClosed bool
+
+ // SndBufInQueue is the number of bytes in the send queue.
+ SndBufInQueue seqnum.Size
+
+ // PacketTooBigCount is used to notify the main protocol routine how
+ // many times a "packet too big" control packet is received.
+ PacketTooBigCount int
+
+ // SndMTU is the smallest MTU seen in the control packets received.
+ SndMTU int
+
+ // Receiver holds variables related to the TCP receiver for the endpoint.
+ Receiver TCPReceiverState
+
+ // Sender holds state related to the TCP Sender for the endpoint.
+ Sender TCPSenderState
}
// Stack is a networking stack, with all supported protocols, NICs, and route
@@ -43,17 +283,34 @@
linkAddrCache *linkAddrCache
+ mu sync.RWMutex
+ nics map[tcpip.NICID]*NIC
forwarding bool
- mu sync.RWMutex
- nics map[tcpip.NICID]*NIC
-
// route is the route table passed in by the user via SetRouteTable(),
// it is used by FindRoute() to build a route for a specific
// destination.
routeTable []tcpip.Route
*ports.PortManager
+
+ // If not nil, then any new endpoints will have this probe function
+ // invoked every time they receive a TCP segment.
+ tcpProbeFunc TCPProbeFunc
+
+ // clock is used to generate user-visible times.
+ clock tcpip.Clock
+}
+
+// Options contains optional Stack configuration.
+type Options struct {
+ // Clock is an optional clock source used for timestamping packets.
+ //
+ // If no Clock is specified, the clock source will be time.Now.
+ Clock tcpip.Clock
+
+ // Stats are optional statistic counters.
+ Stats tcpip.Stats
}
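The optional-clock pattern in `Options` (a nil `Clock` falls back to a wall-clock source inside `New`) is what lets tests inject deterministic time. The sketch below assumes a one-method `clock` interface as a hypothetical reduction of `tcpip.Clock`; `stdClock`, `fixedClock`, `options`, and `stackSketch` are illustrative names, not netstack types.

```go
package main

import (
	"fmt"
	"time"
)

// clock is a hypothetical, reduced stand-in for tcpip.Clock.
type clock interface {
	NowNanoseconds() int64
}

// stdClock reads the wall clock, playing the role of the default clock.
type stdClock struct{}

func (stdClock) NowNanoseconds() int64 { return time.Now().UnixNano() }

// fixedClock always reports the same instant, which is handy in tests.
type fixedClock struct{ ns int64 }

func (c fixedClock) NowNanoseconds() int64 { return c.ns }

// options mirrors only the Clock field of stack.Options.
type options struct {
	Clock clock
}

type stackSketch struct {
	clock clock
}

// newStack applies the same defaulting rule as stack.New: a nil Clock
// falls back to the standard wall clock.
func newStack(opts options) *stackSketch {
	c := opts.Clock
	if c == nil {
		c = stdClock{}
	}
	return &stackSketch{clock: c}
}

// NowNanoseconds delegates to the configured clock, like Stack.NowNanoseconds.
func (s *stackSketch) NowNanoseconds() int64 { return s.clock.NowNanoseconds() }

func main() {
	fmt.Println(newStack(options{Clock: fixedClock{ns: 42}}).NowNanoseconds())
	fmt.Println(newStack(options{}).NowNanoseconds() > 0)
}
```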
// New allocates a new networking stack with only the requested networking and
@@ -63,15 +320,21 @@
// SetNetworkProtocolOption/SetTransportProtocolOption methods provided by the
// stack. Please refer to individual protocol implementations as to what options
// are supported.
-func New(network []string, transport []string) *Stack {
+func New(network []string, transport []string, opts Options) *Stack {
+ clock := opts.Clock
+ if clock == nil {
+ clock = &tcpip.StdClock{}
+ }
s := &Stack{
transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver),
nics: make(map[tcpip.NICID]*NIC),
- linkAddrCache: newLinkAddrCache(1 * time.Minute),
+ linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts),
PortManager: ports.NewPortManager(),
+ clock: clock,
+ stats: opts.Stats.FillIn(),
}
// Add specified network protocols.
@@ -117,6 +380,23 @@
return netProto.SetOption(option)
}
+// NetworkProtocolOption allows retrieving individual protocol level option
+// values. This method returns an error if the protocol is not supported or
+// the option is not supported by the protocol implementation. For example:
+// var v ipv4.MyOption
+// err := s.NetworkProtocolOption(tcpip.IPv4ProtocolNumber, &v)
+// if err != nil {
+// ...
+// }
+func (s *Stack) NetworkProtocolOption(network tcpip.NetworkProtocolNumber, option interface{}) *tcpip.Error {
+ netProto, ok := s.networkProtocols[network]
+ if !ok {
+ return tcpip.ErrUnknownProtocol
+ }
+ return netProto.Option(option)
+}
+
// SetTransportProtocolOption allows configuring individual protocol level
// options. This method returns an error if the protocol is not supported or
// option is not supported by the protocol implementation or the provided value
@@ -129,71 +409,59 @@
return transProtoState.proto.SetOption(option)
}
+// TransportProtocolOption allows retrieving individual protocol level option
+// values. This method returns an error if the protocol is not supported or
+// the option is not supported by the protocol implementation. For example:
+// var v tcp.SACKEnabled
+// if err := s.TransportProtocolOption(tcpip.TCPProtocolNumber, &v); err != nil {
+// ...
+// }
+func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, option interface{}) *tcpip.Error {
+ transProtoState, ok := s.transportProtocols[transport]
+ if !ok {
+ return tcpip.ErrUnknownProtocol
+ }
+ return transProtoState.proto.Option(option)
+}
+
// SetTransportProtocolHandler sets the per-stack default handler for the given
// protocol.
//
// It must be called only during initialization of the stack. Changing it as the
// stack is operating is not supported.
-func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, *buffer.VectorisedView) bool) {
+func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, buffer.VectorisedView) bool) {
state := s.transportProtocols[p]
if state != nil {
state.defaultHandler = h
}
}
-// Stats returns a snapshot of the current stats.
-//
-// NOTE: The underlying stats are updated using atomic instructions as a result
-// the snapshot returned does not represent the value of all the stats at any
-// single given point of time.
-// TODO: Make stats available in sentry for debugging/diag.
-func (s *Stack) Stats() tcpip.Stats {
- return tcpip.Stats{
- UnknownProtocolRcvdPackets: atomic.LoadUint64(&s.stats.UnknownProtocolRcvdPackets),
- MalformedRcvdPackets: atomic.LoadUint64(&s.stats.MalformedRcvdPackets),
- DroppedPackets: atomic.LoadUint64(&s.stats.DroppedPackets),
- IP: tcpip.IPStats{
- PacketsReceived: atomic.LoadUint64(&s.stats.IP.PacketsReceived),
- InvalidAddressesReceived: atomic.LoadUint64(&s.stats.IP.InvalidAddressesReceived),
- PacketsDiscarded: atomic.LoadUint64(&s.stats.IP.PacketsDiscarded),
- PacketsDelivered: atomic.LoadUint64(&s.stats.IP.PacketsDelivered),
- PacketsSent: atomic.LoadUint64(&s.stats.IP.PacketsSent),
- OutgoingPacketErrors: atomic.LoadUint64(&s.stats.IP.OutgoingPacketErrors),
- },
- TCP: tcpip.TCPStats{
- ActiveConnectionOpenings: atomic.LoadUint64(&s.stats.TCP.ActiveConnectionOpenings),
- PassiveConnectionOpenings: atomic.LoadUint64(&s.stats.TCP.PassiveConnectionOpenings),
- FailedConnectionAttempts: atomic.LoadUint64(&s.stats.TCP.FailedConnectionAttempts),
- ValidSegmentsReceived: atomic.LoadUint64(&s.stats.TCP.ValidSegmentsReceived),
- InvalidSegmentsReceived: atomic.LoadUint64(&s.stats.TCP.InvalidSegmentsReceived),
- SegmentsSent: atomic.LoadUint64(&s.stats.TCP.SegmentsSent),
- ResetsSent: atomic.LoadUint64(&s.stats.TCP.ResetsSent),
- },
- UDP: tcpip.UDPStats{
- PacketsReceived: atomic.LoadUint64(&s.stats.UDP.PacketsReceived),
- UnknownPortErrors: atomic.LoadUint64(&s.stats.UDP.UnknownPortErrors),
- ReceiveBufferErrors: atomic.LoadUint64(&s.stats.UDP.ReceiveBufferErrors),
- MalformedPacketsReceived: atomic.LoadUint64(&s.stats.UDP.MalformedPacketsReceived),
- PacketsSent: atomic.LoadUint64(&s.stats.UDP.PacketsSent),
- },
- }
+// NowNanoseconds implements tcpip.Clock.NowNanoseconds.
+func (s *Stack) NowNanoseconds() int64 {
+ return s.clock.NowNanoseconds()
}
-// MutableStats returns a mutable copy of the current stats.
+// Stats returns a mutable copy of the current stats.
//
// This is not generally exported via the public interface, but is available
// internally.
-func (s *Stack) MutableStats() *tcpip.Stats {
- return &s.stats
+func (s *Stack) Stats() tcpip.Stats {
+ return s.stats
}
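Returning `tcpip.Stats` by value can still be described as a "mutable copy" because its counters are held by pointer, so copying the struct copies pointers to shared counters. The sketch below illustrates that idea with hypothetical types (`statCounter`, `ipStats`, `stats`, and `fillIn` stand in for `tcpip.StatCounter`, `tcpip.IPStats`, `tcpip.Stats`, and `Stats.FillIn`); it is an assumption-labelled model, not the tcpip implementation.

```go
package main

import (
	"fmt"
	"sync/atomic"
)

// statCounter is a hypothetical stand-in for tcpip.StatCounter: a counter
// that is safe to update from multiple goroutines.
type statCounter struct{ count uint64 }

func (c *statCounter) Increment() { atomic.AddUint64(&c.count, 1) }

func (c *statCounter) Value() uint64 { return atomic.LoadUint64(&c.count) }

// ipStats holds a pointer, so copying the struct shares the counter.
type ipStats struct {
	OutgoingPacketErrors *statCounter
}

type stats struct {
	IP ipStats
}

// fillIn allocates any counter left nil, similar in spirit to Stats.FillIn.
func (s stats) fillIn() stats {
	if s.IP.OutgoingPacketErrors == nil {
		s.IP.OutgoingPacketErrors = &statCounter{}
	}
	return s
}

type stackSketch struct{ stats stats }

// Stats returns the struct by value; callers still update shared counters.
func (s *stackSketch) Stats() stats { return s.stats }

func main() {
	s := &stackSketch{stats: stats{}.fillIn()}
	s.Stats().IP.OutgoingPacketErrors.Increment()
	s.Stats().IP.OutgoingPacketErrors.Increment()
	fmt.Println(s.Stats().IP.OutgoingPacketErrors.Value())
}
```

This is also why `Route.WritePacket` can call `r.Stats().IP.OutgoingPacketErrors.Increment()` on a returned value and still update the stack-wide counter.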
// SetForwarding enables or disables the packet forwarding between NICs.
func (s *Stack) SetForwarding(enable bool) {
+ // TODO: Expose via /proc/sys/net/ipv4/ip_forward.
+ s.mu.Lock()
s.forwarding = enable
+ s.mu.Unlock()
}
// Forwarding returns if the packet forwarding between NICs is enabled.
func (s *Stack) Forwarding() bool {
+ // TODO: Expose via /proc/sys/net/ipv4/ip_forward.
+ s.mu.RLock()
+ defer s.mu.RUnlock()
return s.forwarding
}
@@ -225,7 +493,7 @@
// createNIC creates a NIC with the provided id and link-layer endpoint, and
// optionally enable it.
-func (s *Stack) createNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID, enabled bool) *tcpip.Error {
+func (s *Stack) createNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID, enabled bool) *tcpip.Error {
ep := FindLinkEndpoint(linkEP)
if ep == nil {
return tcpip.ErrBadLinkEndpoint
@@ -239,7 +507,7 @@
return tcpip.ErrDuplicateNICID
}
- n := newNIC(s, id, ep)
+ n := newNIC(s, id, name, ep)
s.nics[id] = n
if enabled {
@@ -251,14 +519,26 @@
// CreateNIC creates a NIC with the provided id and link-layer endpoint.
func (s *Stack) CreateNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID) *tcpip.Error {
- return s.createNIC(id, linkEP, true)
+ return s.createNIC(id, "", linkEP, true)
+}
+
+// CreateNamedNIC creates a NIC with the provided id and link-layer endpoint,
+// and a human-readable name.
+func (s *Stack) CreateNamedNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error {
+ return s.createNIC(id, name, linkEP, true)
}
// CreateDisabledNIC creates a NIC with the provided id and link-layer endpoint,
// but leaves it disabled. Stack.EnableNIC must be called before the link-layer
// endpoint starts delivering packets to it.
func (s *Stack) CreateDisabledNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID) *tcpip.Error {
- return s.createNIC(id, linkEP, false)
+ return s.createNIC(id, "", linkEP, false)
+}
+
+// CreateDisabledNamedNIC is a combination of CreateNamedNIC and
+// CreateDisabledNIC.
+func (s *Stack) CreateDisabledNamedNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error {
+ return s.createNIC(id, name, linkEP, false)
}
// EnableNIC enables the given NIC so that the link-layer endpoint can start
@@ -290,12 +570,65 @@
return nics
}
+// NICInfo captures the name and addresses assigned to a NIC.
+type NICInfo struct {
+ Name string
+ LinkAddress tcpip.LinkAddress
+ ProtocolAddresses []tcpip.ProtocolAddress
+
+ // Flags indicate the state of the NIC.
+ Flags NICStateFlags
+
+ // MTU is the maximum transmission unit.
+ MTU uint32
+}
+
+// NICInfo returns a map of NICIDs to their associated information.
+func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nics := make(map[tcpip.NICID]NICInfo)
+ for id, nic := range s.nics {
+ flags := NICStateFlags{
+ Up: true, // Netstack interfaces are always up.
+ Running: nic.linkEP.IsAttached(),
+ Promiscuous: nic.isPromiscuousMode(),
+ Loopback: nic.linkEP.Capabilities()&CapabilityLoopback != 0,
+ }
+ nics[id] = NICInfo{
+ Name: nic.name,
+ LinkAddress: nic.linkEP.LinkAddress(),
+ ProtocolAddresses: nic.Addresses(),
+ Flags: flags,
+ MTU: nic.linkEP.MTU(),
+ }
+ }
+ return nics
+}
+
+// NICStateFlags holds information about the state of a NIC.
+type NICStateFlags struct {
+ // Up indicates whether the interface is running.
+ Up bool
+
+ // Running indicates whether resources are allocated.
+ Running bool
+
+ // Promiscuous indicates whether the interface is in promiscuous mode.
+ Promiscuous bool
+
+ // Loopback indicates whether the interface is a loopback.
+ Loopback bool
+}
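`NICInfo` derives the `Loopback` flag by masking the link endpoint's capability bits (`Capabilities()&CapabilityLoopback != 0`) and always reports `Up`. A minimal sketch of that derivation, assuming a hypothetical `capabilities` bitmask whose bit positions are illustrative only and do not reproduce `LinkEndpointCapabilities`.

```go
package main

import "fmt"

// capabilities is a hypothetical stand-in for LinkEndpointCapabilities.
// The bit positions here are illustrative, not netstack's actual values.
type capabilities uint

const (
	capChecksumOffload capabilities = 1 << iota
	capResolutionRequired
	capLoopback
)

// nicFlags mirrors the fields of NICStateFlags.
type nicFlags struct {
	Up, Running, Promiscuous, Loopback bool
}

// flagsFor derives flags the way Stack.NICInfo does: Up is always true,
// Running follows attachment, and Loopback tests a single capability bit.
func flagsFor(caps capabilities, attached, promiscuous bool) nicFlags {
	return nicFlags{
		Up:          true,
		Running:     attached,
		Promiscuous: promiscuous,
		Loopback:    caps&capLoopback != 0,
	}
}

func main() {
	fmt.Printf("%+v\n", flagsFor(capLoopback|capChecksumOffload, true, false))
	fmt.Printf("%+v\n", flagsFor(capResolutionRequired, false, true))
}
```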
+
// AddAddress adds a new network-layer address to the specified NIC.
func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
return s.AddAddressWithOptions(id, protocol, addr, CanBePrimaryEndpoint)
}
-// AddAddressWithOptions is the same as AddAddress, but allows you to specify whether they new endpoint can be primary or not.
+// AddAddressWithOptions is the same as AddAddress, but allows you to specify
+// whether the new endpoint can be primary or not.
func (s *Stack) AddAddressWithOptions(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, peb PrimaryEndpointBehavior) *tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -313,13 +646,12 @@
s.mu.RLock()
defer s.mu.RUnlock()
- nic := s.nics[id]
- if nic == nil {
- return tcpip.ErrUnknownNICID
+ if nic, ok := s.nics[id]; ok {
+ nic.AddSubnet(protocol, subnet)
+ return nil
}
- nic.AddSubnet(protocol, subnet)
- return nil
+ return tcpip.ErrUnknownNICID
}
// RemoveSubnet removes the subnet range from the specified NIC.
@@ -327,25 +659,25 @@
s.mu.RLock()
defer s.mu.RUnlock()
- nic := s.nics[id]
- if nic == nil {
- return tcpip.ErrUnknownNICID
+ if nic, ok := s.nics[id]; ok {
+ nic.RemoveSubnet(subnet)
+ return nil
}
- nic.RemoveSubnet(subnet)
- return nil
+
+ return tcpip.ErrUnknownNICID
}
-// Returns true if the given subnet is present in the specified NIC..
+// ContainsSubnet reports whether the specified NIC contains the specified
+// subnet.
func (s *Stack) ContainsSubnet(id tcpip.NICID, subnet tcpip.Subnet) (bool, *tcpip.Error) {
s.mu.RLock()
defer s.mu.RUnlock()
- nic := s.nics[id]
- if nic == nil {
- return false, tcpip.ErrUnknownNICID
+ if nic, ok := s.nics[id]; ok {
+ return nic.ContainsSubnet(subnet), nil
}
- return nic.ContainsSubnet(subnet), nil
+ return false, tcpip.ErrUnknownNICID
}
// RemoveAddress removes an existing network-layer address from the specified
@@ -354,30 +686,26 @@
s.mu.RLock()
defer s.mu.RUnlock()
- nic := s.nics[id]
- if nic == nil {
- return tcpip.ErrUnknownNICID
+ if nic, ok := s.nics[id]; ok {
+ return nic.RemoveAddress(addr)
}
- return nic.RemoveAddress(addr)
+ return tcpip.ErrUnknownNICID
}
-// Returns the first primary address (and subnet that contains it) for the
-// given NIC and protocol.
+// GetMainNICAddress returns the first primary address (and the subnet that
+// contains it) for the given NIC and protocol. Returns an arbitrary endpoint's
+// address if no primary addresses exist. Returns an error if the NIC doesn't
+// exist or has no endpoints.
func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.Address, tcpip.Subnet, *tcpip.Error) {
s.mu.RLock()
defer s.mu.RUnlock()
- var address tcpip.Address
- var subnet tcpip.Subnet
-
- nic := s.nics[id]
- if nic == nil {
- return address, subnet, tcpip.ErrUnknownNICID
+ if nic, ok := s.nics[id]; ok {
+ return nic.getMainNICAddress(protocol)
}
- address, subnet = nic.getMainNICAddress(protocol)
- return address, subnet, nil
+ return "", tcpip.Subnet{}, tcpip.ErrUnknownNICID
}
// FindRoute creates a route to the given destination address, leaving through
@@ -387,7 +715,7 @@
defer s.mu.RUnlock()
for i := range s.routeTable {
- if id != 0 && id != s.routeTable[i].NIC || !s.routeTable[i].Match(remoteAddr) {
+ if (id != 0 && id != s.routeTable[i].NIC) || (len(remoteAddr) != 0 && !s.routeTable[i].Match(remoteAddr)) {
continue
}
@@ -398,16 +726,21 @@
var ref *referencedNetworkEndpoint
if len(localAddr) != 0 {
- ref = nic.findEndpoint(localAddr)
+ ref = nic.findEndpoint(netProto, localAddr, CanBePrimaryEndpoint)
} else {
ref = nic.primaryEndpoint(netProto)
}
-
if ref == nil {
continue
}
- r := makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, ref)
+ if len(remoteAddr) == 0 {
+ // If no remote address was provided, the returned route
+ // refers back to the endpoint's own local address.
+ remoteAddr = ref.ep.ID().LocalAddress
+ }
+
+ r := makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref)
r.NextHop = s.routeTable[i].Gateway
return r, nil
}
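The updated filter in `FindRoute` skips a table entry when a nonzero NIC ID does not match, or when a non-empty remote address falls outside the entry's subnet; an empty remote address now matches any entry. The sketch below models only that filter plus a masked subnet match, under stated assumptions: `routeEntry` is a hypothetical simplification of `tcpip.Route` (raw byte-string addresses and masks), and the real loop's NIC and endpoint lookups, which can also cause an entry to be skipped, are omitted.

```go
package main

import "fmt"

// routeEntry is a hypothetical simplification of tcpip.Route; addresses
// and masks are raw byte strings, as tcpip.Address is.
type routeEntry struct {
	Destination string
	Mask        string
	Gateway     string
	NIC         int
}

// match reports whether addr lies in the entry's destination subnet.
func (r routeEntry) match(addr string) bool {
	if len(addr) != len(r.Destination) || len(r.Mask) != len(addr) {
		return false
	}
	for i := 0; i < len(addr); i++ {
		if addr[i]&r.Mask[i] != r.Destination[i] {
			return false
		}
	}
	return true
}

// findEntry applies FindRoute's filter. The first surviving entry wins,
// so table order encodes precedence.
func findEntry(table []routeEntry, id int, remote string) (routeEntry, bool) {
	for _, e := range table {
		if (id != 0 && id != e.NIC) || (len(remote) != 0 && !e.match(remote)) {
			continue
		}
		return e, true
	}
	return routeEntry{}, false
}

func main() {
	table := []routeEntry{
		{Destination: "\x0a\x00\x00\x00", Mask: "\xff\x00\x00\x00", NIC: 1},                          // 10.0.0.0/8
		{Destination: "\x00\x00\x00\x00", Mask: "\x00\x00\x00\x00", Gateway: "\x0a\x00\x00\x01", NIC: 2}, // default
	}
	e, ok := findEntry(table, 0, "\x08\x08\x08\x08")
	fmt.Println(ok, e.NIC)
}
```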
@@ -425,7 +758,7 @@
// CheckLocalAddress determines if the given local address exists, and if it
// does, returns the id of the NIC it's bound to. Returns 0 if the address
// does not exist.
-func (s *Stack) CheckLocalAddress(nicid tcpip.NICID, addr tcpip.Address) tcpip.NICID {
+func (s *Stack) CheckLocalAddress(nicid tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -436,7 +769,7 @@
return 0
}
- ref := nic.findEndpoint(addr)
+ ref := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint)
if ref == nil {
return 0
}
@@ -448,7 +781,7 @@
// Go through all the NICs.
for _, nic := range s.nics {
- ref := nic.findEndpoint(addr)
+ ref := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint)
if ref != nil {
ref.decRef()
return nic.id
@@ -473,21 +806,54 @@
return nil
}
+// SetSpoofing enables or disables address spoofing in the given NIC, allowing
+// endpoints to bind to any address in the NIC.
+func (s *Stack) SetSpoofing(nicID tcpip.NICID, enable bool) *tcpip.Error {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic := s.nics[nicID]
+ if nic == nil {
+ return tcpip.ErrUnknownNICID
+ }
+
+ nic.setSpoofing(enable)
+
+ return nil
+}
+
// AddLinkAddress adds a link address to the stack link cache.
func (s *Stack) AddLinkAddress(nicid tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) {
fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr}
s.linkAddrCache.add(fullAddr, linkAddr)
- // TODO(crawshaw): provide a way for a
- // transport endpoint to receive a signal that AddLinkAddress
- // for a particular address has been called.
+ // TODO: provide a way for a transport endpoint to receive a signal
+ // that AddLinkAddress for a particular address has been called.
}
-func (s *Stack) GetLinkAddress(nicid tcpip.NICID, addr tcpip.Address, timeout time.Duration) tcpip.LinkAddress {
- if addr == "\xff\xff\xff\xff" {
- return "\xff\xff\xff\xff\xff\xff"
+// GetLinkAddress implements LinkAddressCache.GetLinkAddress.
+func (s *Stack) GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, waker *sleep.Waker) (tcpip.LinkAddress, *tcpip.Error) {
+ s.mu.RLock()
+ nic := s.nics[nicid]
+ if nic == nil {
+ s.mu.RUnlock()
+ return "", tcpip.ErrUnknownNICID
}
+ s.mu.RUnlock()
+
fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr}
- return s.linkAddrCache.get(fullAddr, timeout)
+ linkRes := s.linkAddrResolvers[protocol]
+ return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.linkEP, waker)
+}
+
+// RemoveWaker implements LinkAddressCache.RemoveWaker.
+func (s *Stack) RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ if nic := s.nics[nicid]; nic != nil {
+ fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr}
+ s.linkAddrCache.removeWaker(fullAddr, waker)
+ }
}
// RegisterTransportEndpoint registers the given endpoint with the stack
@@ -527,17 +893,6 @@
}
}
-// JoinGroup joins the given multicast group on the given NIC.
-// TODO: notify network of subscription via igmp protocol
-func (s *Stack) JoinGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error {
- return s.AddAddressWithOptions(nicID, protocol, multicastAddr, NeverPrimaryEndpoint)
-}
-
-// LeaveGroup leaves the given multicast group on the given NIC.
-func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error {
- return s.RemoveAddress(nicID, multicastAddr)
-}
-
// NetworkProtocolInstance returns the protocol instance in the stack for the
// specified network protocol. This method is public for protocol implementers
// and tests to use.
@@ -557,3 +912,51 @@
}
return nil
}
+
+// AddTCPProbe installs a probe function that will be invoked on every segment
+// received by a given TCP endpoint. The probe function is passed a copy of the
+// TCP endpoint state.
+//
+// NOTE: TCPProbe is added only to endpoints created after this call. Endpoints
+// created prior to this call will not call the probe function.
+//
+// Further, installing two different probes back to back can result in some
+// endpoints calling the first one and some the second one. There is no
+// guarantee provided on which probe will be invoked. Ideally this should only
+// be called once per stack.
+func (s *Stack) AddTCPProbe(probe TCPProbeFunc) {
+ s.mu.Lock()
+ s.tcpProbeFunc = probe
+ s.mu.Unlock()
+}
+
+// GetTCPProbe returns the TCPProbeFunc if installed with AddTCPProbe, nil
+// otherwise.
+func (s *Stack) GetTCPProbe() TCPProbeFunc {
+ s.mu.Lock()
+ p := s.tcpProbeFunc
+ s.mu.Unlock()
+ return p
+}
+
+// RemoveTCPProbe removes an installed TCP probe.
+//
+// NOTE: This only ensures that endpoints created after this call do not
+// have a probe attached. Endpoints already created will continue to invoke
+// the probe that was installed when they were created.
+func (s *Stack) RemoveTCPProbe() {
+ s.mu.Lock()
+ s.tcpProbeFunc = nil
+ s.mu.Unlock()
+}
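The NOTE comments on `AddTCPProbe` and `RemoveTCPProbe` follow from endpoints reading the probe once, at creation. The sketch below models that snapshot behaviour under stated assumptions: `probeFunc` is a hypothetical reduction of `TCPProbeFunc` that receives only a segment count, and `endpoint`/`newEndpoint` are illustrative, not the TCP endpoint implementation. It uses `RLock` for the read path, a choice that differs from (but is compatible with) the `Lock` used by `GetTCPProbe` above.

```go
package main

import (
	"fmt"
	"sync"
)

// probeFunc is a hypothetical stand-in for TCPProbeFunc.
type probeFunc func(segments int)

type stackSketch struct {
	mu    sync.RWMutex
	probe probeFunc
}

func (s *stackSketch) AddTCPProbe(p probeFunc) {
	s.mu.Lock()
	s.probe = p
	s.mu.Unlock()
}

func (s *stackSketch) RemoveTCPProbe() {
	s.mu.Lock()
	s.probe = nil
	s.mu.Unlock()
}

func (s *stackSketch) GetTCPProbe() probeFunc {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return s.probe
}

// endpoint captures the probe once at creation, so later Add/Remove calls
// on the stack do not affect it.
type endpoint struct {
	probe    probeFunc
	segments int
}

func newEndpoint(s *stackSketch) *endpoint {
	return &endpoint{probe: s.GetTCPProbe()}
}

func (e *endpoint) handleSegment() {
	e.segments++
	if e.probe != nil {
		e.probe(e.segments)
	}
}

func main() {
	s := &stackSketch{}
	calls := 0
	s.AddTCPProbe(func(int) { calls++ })
	withProbe := newEndpoint(s)
	s.RemoveTCPProbe()
	withoutProbe := newEndpoint(s)

	withProbe.handleSegment()
	withoutProbe.handleSegment()
	fmt.Println("probe calls:", calls)
}
```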
+
+// JoinGroup joins the given multicast group on the given NIC.
+func (s *Stack) JoinGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error {
+ // TODO: notify network of subscription via igmp protocol.
+ return s.AddAddressWithOptions(nicID, protocol, multicastAddr, NeverPrimaryEndpoint)
+}
+
+// LeaveGroup leaves the given multicast group on the given NIC.
+func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error {
+ return s.RemoveAddress(nicID, multicastAddr)
+}
diff --git a/tcpip/stack/stack_test.go b/tcpip/stack/stack_test.go
index e1e2955..05c2ca2 100644
--- a/tcpip/stack/stack_test.go
+++ b/tcpip/stack/stack_test.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
// Package stack_test contains tests for the stack. It is in its own package so
// that the tests can also validate that all definitions needed to implement
@@ -14,7 +24,6 @@
"github.com/google/netstack/tcpip"
"github.com/google/netstack/tcpip/buffer"
- "github.com/google/netstack/tcpip/header"
"github.com/google/netstack/tcpip/link/channel"
"github.com/google/netstack/tcpip/stack"
)
@@ -23,6 +32,10 @@
fakeNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32
fakeNetHeaderLen = 12
+ // fakeControlProtocol is used for control packets that represent
+ // destination port unreachable.
+ fakeControlProtocol tcpip.TransportProtocolNumber = 2
+
// defaultMTU is the MTU, in bytes, used throughout the tests, except
// where another value is explicitly used. It is chosen to match the MTU
// of loopback interfaces on linux systems.
@@ -52,15 +65,15 @@
return f.nicid
}
+func (*fakeNetworkEndpoint) DefaultTTL() uint8 {
+ return 123
+}
+
func (f *fakeNetworkEndpoint) ID() *stack.NetworkEndpointID {
return &f.id
}
-func (f *fakeNetworkEndpoint) DefaultTTL() uint8 {
- return header.IPv4DefaultTTL
-}
-
-func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, vv *buffer.VectorisedView) {
+func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
// Increment the received packet count in the protocol descriptor.
f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++
@@ -68,6 +81,18 @@
b := vv.First()
vv.TrimFront(fakeNetHeaderLen)
+ // Handle control packets.
+ if b[2] == uint8(fakeControlProtocol) {
+ nb := vv.First()
+ if len(nb) < fakeNetHeaderLen {
+ return
+ }
+
+ vv.TrimFront(fakeNetHeaderLen)
+ f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, vv)
+ return
+ }
+
// Dispatch the packet to the transport protocol.
f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), vv)
}
@@ -80,7 +105,11 @@
return 0
}
-func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.TransportProtocolNumber, _ uint8) *tcpip.Error {
+func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return f.linkEP.Capabilities()
+}
+
+func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, _ uint8) *tcpip.Error {
// Increment the sent packet count in the protocol descriptor.
f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++
@@ -148,11 +177,21 @@
}
}
+func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error {
+ switch v := option.(type) {
+ case *fakeNetGoodOption:
+ *v = fakeNetGoodOption(f.opts.good)
+ return nil
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
func TestNetworkReceive(t *testing.T) {
// Create a stack with the fake network protocol, one nic, and two
// addresses attached to it: 1 & 2.
id, linkEP := channel.New(10, defaultMTU, "")
- s := stack.New([]string{"fakeNet"}, nil)
+ s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
if err := s.CreateNIC(1, id); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
@@ -166,15 +205,12 @@
}
fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
- var views [1]buffer.View
- // Allocate the buffer containing the packet that will be injected into
- // the stack.
+
buf := buffer.NewView(30)
// Make sure packet with wrong address is not delivered.
buf[0] = 3
- vv := buf.ToVectorisedView(views)
- linkEP.Inject(fakeNetNumber, &vv)
+ linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 0 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0)
}
@@ -184,8 +220,7 @@
// Make sure packet is delivered to first endpoint.
buf[0] = 1
- vv = buf.ToVectorisedView(views)
- linkEP.Inject(fakeNetNumber, &vv)
+ linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
@@ -195,8 +230,7 @@
// Make sure packet is delivered to second endpoint.
buf[0] = 2
- vv = buf.ToVectorisedView(views)
- linkEP.Inject(fakeNetNumber, &vv)
+ linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
@@ -205,8 +239,7 @@
}
// Make sure packet is not delivered if protocol number is wrong.
- vv = buf.ToVectorisedView(views)
- linkEP.Inject(fakeNetNumber-1, &vv)
+ linkEP.Inject(fakeNetNumber-1, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
@@ -216,8 +249,7 @@
// Make sure packet that is too small is dropped.
buf.CapLength(2)
- vv = buf.ToVectorisedView(views)
- linkEP.Inject(fakeNetNumber, &vv)
+ linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
@@ -234,8 +266,7 @@
defer r.Release()
hdr := buffer.NewPrependable(int(r.MaxHeaderLength()))
- err = r.WritePacket(&hdr, nil, fakeTransNumber, r.DefaultTTL())
- if err != nil {
+ if err := r.WritePacket(hdr, buffer.VectorisedView{}, fakeTransNumber, 123); err != nil {
t.Errorf("WritePacket failed: %v", err)
return
}
@@ -246,7 +277,7 @@
// address: 1. The route table sends all packets through the only
// existing nic.
id, linkEP := channel.New(10, defaultMTU, "")
- s := stack.New([]string{"fakeNet"}, nil)
+ s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
if err := s.CreateNIC(1, id); err != nil {
t.Fatalf("NewNIC failed: %v", err)
}
@@ -268,7 +299,7 @@
// Create a stack with the fake network protocol, two nics, and two
// addresses per nic, the first nic has odd address, the second one has
// even addresses.
- s := stack.New([]string{"fakeNet"}, nil)
+ s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
id1, linkEP1 := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, id1); err != nil {
@@ -347,7 +378,7 @@
// Create a stack with the fake network protocol, two nics, and two
// addresses per nic, the first nic has odd address, the second one has
// even addresses.
- s := stack.New([]string{"fakeNet"}, nil)
+ s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
id1, _ := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, id1); err != nil {
@@ -411,7 +442,7 @@
}
func TestAddressRemoval(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil)
+ s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
id, linkEP := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, id); err != nil {
@@ -422,16 +453,14 @@
t.Fatalf("AddAddress failed: %v", err)
}
- var views [1]buffer.View
- buf := buffer.NewView(30)
-
fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
+ buf := buffer.NewView(30)
+
// Write a packet, and check that it gets delivered.
fakeNet.packetCount[1] = 0
buf[0] = 1
- vv := buf.ToVectorisedView(views)
- linkEP.Inject(fakeNetNumber, &vv)
+ linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
@@ -442,8 +471,7 @@
t.Fatalf("RemoveAddress failed: %v", err)
}
- vv = buf.ToVectorisedView(views)
- linkEP.Inject(fakeNetNumber, &vv)
+ linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
@@ -455,7 +483,7 @@
}
func TestDelayedRemovalDueToRoute(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil)
+ s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
id, linkEP := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, id); err != nil {
@@ -472,14 +500,12 @@
fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
- var views [1]buffer.View
buf := buffer.NewView(30)
// Write a packet, and check that it gets delivered.
fakeNet.packetCount[1] = 0
buf[0] = 1
- vv := buf.ToVectorisedView(views)
- linkEP.Inject(fakeNetNumber, &vv)
+ linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
@@ -490,8 +516,7 @@
t.Fatalf("FindRoute failed: %v", err)
}
- vv = buf.ToVectorisedView(views)
- linkEP.Inject(fakeNetNumber, &vv)
+ linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 2 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 2)
}
@@ -502,8 +527,7 @@
t.Fatalf("RemoveAddress failed: %v", err)
}
- vv = buf.ToVectorisedView(views)
- linkEP.Inject(fakeNetNumber, &vv)
+ linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 3 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 3)
}
@@ -515,15 +539,14 @@
// Release the route, then check that packet is not deliverable anymore.
r.Release()
- vv = buf.ToVectorisedView(views)
- linkEP.Inject(fakeNetNumber, &vv)
+ linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 3 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 3)
}
}
func TestPromiscuousMode(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil)
+ s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
id, linkEP := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, id); err != nil {
@@ -536,15 +559,13 @@
fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
- var views [1]buffer.View
buf := buffer.NewView(30)
// Write a packet, and check that it doesn't get delivered as we don't
// have a matching endpoint.
fakeNet.packetCount[1] = 0
buf[0] = 1
- vv := buf.ToVectorisedView(views)
- linkEP.Inject(fakeNetNumber, &vv)
+ linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 0 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0)
}
@@ -554,8 +575,7 @@
t.Fatalf("SetPromiscuousMode failed: %v", err)
}
- vv = buf.ToVectorisedView(views)
- linkEP.Inject(fakeNetNumber, &vv)
+ linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
@@ -572,16 +592,58 @@
t.Fatalf("SetPromiscuousMode failed: %v", err)
}
- vv = buf.ToVectorisedView(views)
- linkEP.Inject(fakeNetNumber, &vv)
+ linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
}
+func TestAddressSpoofing(t *testing.T) {
+ srcAddr := tcpip.Address("\x01")
+ dstAddr := tcpip.Address("\x02")
+
+ s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+
+ id, _ := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, id); err != nil {
+ t.Fatalf("CreateNIC failed: %v", err)
+ }
+
+ if err := s.AddAddress(1, fakeNetNumber, dstAddr); err != nil {
+ t.Fatalf("AddAddress failed: %v", err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {"\x00", "\x00", "\x00", 1},
+ })
+
+ // With address spoofing disabled, FindRoute does not permit an address
+ // that was not added to the NIC to be used as the source.
+ r, err := s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber)
+ if err == nil {
+ t.Errorf("FindRoute succeeded with route %+v when it should have failed", r)
+ }
+
+ // With address spoofing enabled, FindRoute permits any address to be used
+ // as the source.
+ if err := s.SetSpoofing(1, true); err != nil {
+ t.Fatalf("SetSpoofing failed: %v", err)
+ }
+ r, err = s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber)
+ if err != nil {
+ t.Fatalf("FindRoute failed: %v", err)
+ }
+ if r.LocalAddress != srcAddr {
+ t.Errorf("Route has wrong local address: got %v, want %v", r.LocalAddress, srcAddr)
+ }
+ if r.RemoteAddress != dstAddr {
+ t.Errorf("Route has wrong remote address: got %v, want %v", r.RemoteAddress, dstAddr)
+ }
+}
+
// Set the subnet, then check that packet is delivered.
func TestSubnetAcceptsMatchingPacket(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil)
+ s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
id, linkEP := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, id); err != nil {
@@ -594,8 +656,8 @@
fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
- var views [1]buffer.View
buf := buffer.NewView(30)
+
buf[0] = 1
fakeNet.packetCount[1] = 0
subnet, err := tcpip.NewSubnet(tcpip.Address("\x00"), tcpip.AddressMask("\xF0"))
@@ -606,8 +668,7 @@
t.Fatalf("AddSubnet failed: %v", err)
}
- vv := buf.ToVectorisedView(views)
- linkEP.Inject(fakeNetNumber, &vv)
+ linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
@@ -615,7 +676,7 @@
// Set destination outside the subnet, then check it doesn't get delivered.
func TestSubnetRejectsNonmatchingPacket(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil)
+ s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
id, linkEP := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, id); err != nil {
@@ -628,8 +689,8 @@
fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
- var views [1]buffer.View
buf := buffer.NewView(30)
+
buf[0] = 1
fakeNet.packetCount[1] = 0
subnet, err := tcpip.NewSubnet(tcpip.Address("\x10"), tcpip.AddressMask("\xF0"))
@@ -639,15 +700,14 @@
if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil {
t.Fatalf("AddSubnet failed: %v", err)
}
- vv := buf.ToVectorisedView(views)
- linkEP.Inject(fakeNetNumber, &vv)
+ linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 0 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0)
}
}
-func TestSetOption(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, []string{})
+func TestNetworkOptions(t *testing.T) {
+ s := stack.New([]string{"fakeNet"}, []string{}, stack.Options{})
// Try an unsupported network protocol.
if err := s.SetNetworkProtocolOption(tcpip.NetworkProtocolNumber(99999), fakeNetGoodOption(false)); err != tcpip.ErrUnknownProtocol {
@@ -656,21 +716,29 @@
testCases := []struct {
option interface{}
- want *tcpip.Error
+ wantErr *tcpip.Error
verifier func(t *testing.T, p stack.NetworkProtocol)
}{
{fakeNetGoodOption(true), nil, func(t *testing.T, p stack.NetworkProtocol) {
+ t.Helper()
fakeNet := p.(*fakeNetworkProtocol)
if fakeNet.opts.good != true {
t.Fatalf("fakeNet.opts.good = false, want = true")
}
+ var v fakeNetGoodOption
+ if err := s.NetworkProtocolOption(fakeNetNumber, &v); err != nil {
+ t.Fatalf("s.NetworkProtocolOption(fakeNetNumber, &v) = %v, want = nil, where v is option %T", err, v)
+ }
+ if v != true {
+ t.Fatalf("s.NetworkProtocolOption(fakeNetNumber, &v) returned v = %v, want = true", v)
+ }
}},
{fakeNetBadOption(true), tcpip.ErrUnknownProtocolOption, nil},
{fakeNetInvalidValueOption(1), tcpip.ErrInvalidOptionValue, nil},
}
for _, tc := range testCases {
- if got := s.SetNetworkProtocolOption(fakeNetNumber, tc.option); tc.want != got {
- t.Errorf("s.SetOption(fakeNet, %v) = %v, want = %v", tc.option, got, tc.want)
+ if got := s.SetNetworkProtocolOption(fakeNetNumber, tc.option); got != tc.wantErr {
+ t.Errorf("s.SetNetworkProtocolOption(fakeNet, %v) = %v, want = %v", tc.option, got, tc.wantErr)
}
if tc.verifier != nil {
tc.verifier(t, s.NetworkProtocolInstance(fakeNetNumber))
@@ -679,7 +747,7 @@
}
func TestSubnetAddRemove(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil)
+ s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
id, _ := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, id); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
@@ -687,99 +755,90 @@
addr := tcpip.Address("\x01\x01\x01\x01")
mask := tcpip.AddressMask(strings.Repeat("\xff", len(addr)))
- subnet, err1 := tcpip.NewSubnet(addr, mask)
-
- if err1 != nil {
- t.Fatalf("NewSubnet failed: %v", err1)
+ subnet, err := tcpip.NewSubnet(addr, mask)
+ if err != nil {
+ t.Fatalf("NewSubnet failed: %v", err)
}
- if contained, err := s.ContainsSubnet(1, subnet); err != nil || contained {
- if contained {
- t.Fatalf("ContainsSubnet spuriously returns true before adding subnet.")
- }
- t.Fatalf("ContainsSubnet returned error %v", err)
+ if contained, err := s.ContainsSubnet(1, subnet); err != nil {
+ t.Fatalf("ContainsSubnet failed: %v", err)
+ } else if contained {
+ t.Fatal("got s.ContainsSubnet(...) = true, want = false")
}
if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil {
- t.Fatalf("AddSubnet failed with error: %v", err)
+ t.Fatalf("AddSubnet failed: %v", err)
}
- if contained, err := s.ContainsSubnet(1, subnet); err != nil || !contained {
- if !contained {
- t.Fatalf("ContainsSubnet spuriously returns false after adding subnet.")
- }
- t.Fatalf("ContainsSubnet returned error %v", err)
+ if contained, err := s.ContainsSubnet(1, subnet); err != nil {
+ t.Fatalf("ContainsSubnet failed: %v", err)
+ } else if !contained {
+ t.Fatal("got s.ContainsSubnet(...) = false, want = true")
}
if err := s.RemoveSubnet(1, subnet); err != nil {
- t.Fatalf("RemoveSubnet failed with error: %v", err)
+ t.Fatalf("RemoveSubnet failed: %v", err)
}
- if contained, err := s.ContainsSubnet(1, subnet); err != nil || contained {
- if contained {
- t.Fatalf("ContainsSubnet spuriously returns true after removing subnet.")
- }
- t.Fatalf("ContainsSubnet returned error %v", err)
+ if contained, err := s.ContainsSubnet(1, subnet); err != nil {
+ t.Fatalf("ContainsSubnet failed: %v", err)
+ } else if contained {
+ t.Fatal("got s.ContainsSubnet(...) = true, want = false")
}
}
func TestGetMainNICAddress(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil)
+ s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
id, _ := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, id); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
- addr := tcpip.Address("\x01\x01\x01\x01")
- mask := tcpip.AddressMask(strings.Repeat("\xff", len(addr)))
- subn, _ := tcpip.NewSubnet(addr, mask)
+ for _, tc := range []struct {
+ name string
+ address tcpip.Address
+ }{
+ {"IPv4", "\x01\x01\x01\x01"},
+ {"IPv6", "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"},
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ address := tc.address
+ mask := tcpip.AddressMask(strings.Repeat("\xff", len(address)))
+ subnet, err := tcpip.NewSubnet(address, mask)
+ if err != nil {
+ t.Fatalf("NewSubnet failed: %v", err)
+ }
- if err := s.AddAddress(1, fakeNetNumber, addr); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
- }
+ if err := s.AddAddress(1, fakeNetNumber, address); err != nil {
+ t.Fatalf("AddAddress failed: %v", err)
+ }
- if err := s.AddSubnet(1, fakeNetNumber, subn); err != nil {
- t.Fatalf("AddSubnet failed with error: %v", err)
- }
+ if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil {
+ t.Fatalf("AddSubnet failed: %v", err)
+ }
- // Check that we get the right initial address and subnet
- address, subnet, err := s.GetMainNICAddress(1, fakeNetNumber)
+ // Check that we get the right initial address and subnet.
+ if gotAddress, gotSubnet, err := s.GetMainNICAddress(1, fakeNetNumber); err != nil {
+ t.Fatalf("GetMainNICAddress failed: %v", err)
+ } else if gotAddress != address {
+ t.Fatalf("got GetMainNICAddress = (%v, ...), want = (%v, ...)", gotAddress, address)
+ } else if gotSubnet != subnet {
+ t.Fatalf("got GetMainNICAddress = (..., %v), want = (..., %v)", gotSubnet, subnet)
+ }
- if err != nil {
- t.Fatalf("GetMainNICAddress failed with error: %v", err)
- }
+ if err := s.RemoveSubnet(1, subnet); err != nil {
+ t.Fatalf("RemoveSubnet failed: %v", err)
+ }
- if address != addr {
- t.Fatalf("Expecting address=%s but GetMainNICAddress returned %s", addr, address)
- }
+ if err := s.RemoveAddress(1, address); err != nil {
+ t.Fatalf("RemoveAddress failed: %v", err)
+ }
- if subnet != subn {
- t.Fatalf("Expecting subnet=%#v but GetMainNICAddress returned %#v", subn, subnet)
- }
-
- if err := s.RemoveSubnet(1, subn); err != nil {
- t.Fatalf("RemoveSubnet failed with error: %v", err)
- }
-
- if err := s.RemoveAddress(1, addr); err != nil {
- t.Fatalf("RemoveAddress failed: %v", err)
- }
-
- // Check that we get an empty address and subnet after removal
- address2, subnet2, err2 := s.GetMainNICAddress(1, fakeNetNumber)
-
- if err2 != nil {
- t.Fatalf("GetMainNICAddress failed with error: %v", err2)
- }
-
- var emptyAddr tcpip.Address
- if emptyAddr != address2 {
- t.Fatalf("Expecting address=%s but GetMainNICAddress returned %s", emptyAddr, address2)
- }
-
- var emptySubnet tcpip.Subnet
- if emptySubnet != subnet2 {
- t.Fatalf("Expecting subnet=%#v but GetMainNICAddress returned %#v", emptySubnet, subnet2)
+ // Check that we get an error after removal.
+ if _, _, err := s.GetMainNICAddress(1, fakeNetNumber); err != tcpip.ErrNoLinkAddress {
+ t.Fatalf("got s.GetMainNICAddress(...) = %v, want = %v", err, tcpip.ErrNoLinkAddress)
+ }
+ })
}
}
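The fake protocols in this test now implement an `Option` getter alongside `SetOption`. The getter switches on *pointer* option types so that it can write the current value back through the pointer. This is why the test passes `&v` to `s.NetworkProtocolOption`. Below is a minimal standalone sketch of that pattern. `protocol`, `goodOption`, `badOption` and `errUnknownOption` are hypothetical stand-ins for the fake types and `tcpip.ErrUnknownProtocolOption`.

```go
package main

import (
	"errors"
	"fmt"
)

// errUnknownOption stands in for tcpip.ErrUnknownProtocolOption.
var errUnknownOption = errors.New("unknown option")

// Hypothetical option types, analogous to fakeNetGoodOption/fakeNetBadOption.
type goodOption bool
type badOption bool

type protocol struct {
	good bool
}

// SetOption receives option values and stores the ones it understands.
func (p *protocol) SetOption(option interface{}) error {
	switch v := option.(type) {
	case goodOption:
		p.good = bool(v)
		return nil
	default:
		return errUnknownOption
	}
}

// Option receives a pointer so it can write the current value back.
// A non-pointer value falls through to the default case.
func (p *protocol) Option(option interface{}) error {
	switch v := option.(type) {
	case *goodOption:
		*v = goodOption(p.good)
		return nil
	default:
		return errUnknownOption
	}
}

func main() {
	p := &protocol{}
	if err := p.SetOption(goodOption(true)); err != nil {
		fmt.Println("SetOption:", err)
		return
	}
	var v goodOption
	err := p.Option(&v)
	fmt.Println(v, err)        // true <nil>
	fmt.Println(p.Option(v))   // unknown option (a value, not a pointer)
	fmt.Println(p.SetOption(badOption(true))) // unknown option
}
```

With this design, passing `v` instead of `&v` is not a compile error. It only shows up as `ErrUnknownProtocolOption` at run time, which is worth checking for in tests like `TestNetworkOptions`.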
diff --git a/tcpip/stack/transport_demuxer.go b/tcpip/stack/transport_demuxer.go
index 6a1a3e5..b41e439 100644
--- a/tcpip/stack/transport_demuxer.go
+++ b/tcpip/stack/transport_demuxer.go
@@ -1,12 +1,21 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package stack
import (
"sync"
- "sync/atomic"
"github.com/google/netstack/tcpip"
"github.com/google/netstack/tcpip/buffer"
@@ -36,7 +45,7 @@
func newTransportDemuxer(stack *Stack) *transportDemuxer {
d := &transportDemuxer{protocol: make(map[protocolIDs]*transportEndpoints)}
- // Add each network and and transport pair to the demuxer.
+ // Add each network and transport pair to the demuxer.
for netProto := range stack.networkProtocols {
for proto := range stack.transportProtocols {
d.protocol[protocolIDs{netProto, proto}] = &transportEndpoints{endpoints: make(map[TransportEndpointID]TransportEndpoint)}
@@ -91,29 +100,59 @@
// deliverPacket attempts to deliver the given packet. Returns true if it found
// an endpoint, false otherwise.
-func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv *buffer.VectorisedView, id TransportEndpointID) bool {
+func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView, id TransportEndpointID) bool {
eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
if !ok {
return false
}
eps.mu.RLock()
- b := d.deliverPacketLocked(r, eps, vv, id)
+ ep := d.findEndpointLocked(eps, vv, id)
eps.mu.RUnlock()
- // UDP packet could not be delivered to an unknown destination port
- if !b && protocol == header.UDPProtocolNumber {
- atomic.AddUint64(&r.MutableStats().UDP.UnknownPortErrors, 1)
+ // Fail if we didn't find one.
+ if ep == nil {
+ // UDP packet could not be delivered to an unknown destination port.
+ if protocol == header.UDPProtocolNumber {
+ r.Stats().UDP.UnknownPortErrors.Increment()
+ }
+ return false
}
- return b
+ // Deliver the packet.
+ ep.HandlePacket(r, id, vv)
+
+ return true
}
-func (d *transportDemuxer) deliverPacketLocked(r *Route, eps *transportEndpoints, vv *buffer.VectorisedView, id TransportEndpointID) bool {
+// deliverControlPacket attempts to deliver the given control packet. Returns
+// true if it found an endpoint, false otherwise.
+func (d *transportDemuxer) deliverControlPacket(net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView, id TransportEndpointID) bool {
+ eps, ok := d.protocol[protocolIDs{net, trans}]
+ if !ok {
+ return false
+ }
+
+ // Try to find the endpoint.
+ eps.mu.RLock()
+ ep := d.findEndpointLocked(eps, vv, id)
+ eps.mu.RUnlock()
+
+ // Fail if we didn't find one.
+ if ep == nil {
+ return false
+ }
+
+ // Deliver the packet.
+ ep.HandleControlPacket(id, typ, extra, vv)
+
+ return true
+}
+
+func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) TransportEndpoint {
// Try to find a match with the id as provided.
if ep := eps.endpoints[id]; ep != nil {
- ep.HandlePacket(r, id, vv)
- return true
+ return ep
}
// Try to find a match with the id minus the local address.
@@ -121,8 +160,7 @@
nid.LocalAddress = ""
if ep := eps.endpoints[nid]; ep != nil {
- ep.HandlePacket(r, id, vv)
- return true
+ return ep
}
// Try to find a match with the id minus the remote part.
@@ -130,16 +168,10 @@
nid.RemoteAddress = ""
nid.RemotePort = 0
if ep := eps.endpoints[nid]; ep != nil {
- ep.HandlePacket(r, id, vv)
- return true
+ return ep
}
// Try to find a match with only the local port.
nid.LocalAddress = ""
- if ep := eps.endpoints[nid]; ep != nil {
- ep.HandlePacket(r, id, vv)
- return true
- }
-
- return false
+ return eps.endpoints[nid]
}
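`deliverPacket` and `deliverControlPacket` now share `findEndpointLocked`. That function only looks up an endpoint while `eps.mu` is read-locked. The callers release the lock before invoking `HandlePacket` or `HandleControlPacket`. The lookup tries progressively less specific keys.

Below is a minimal sketch of that fallback order using a plain map. `endpointID` mirrors the four fields of `stack.TransportEndpointID`, and endpoints are represented by strings for illustration. Resetting `nid` to the full `id` before the remote-part lookup is an assumption, since that line falls outside the hunk.

```go
package main

import "fmt"

// endpointID mirrors the fields of stack.TransportEndpointID.
type endpointID struct {
	LocalPort     uint16
	LocalAddress  string
	RemotePort    uint16
	RemoteAddress string
}

// findEndpoint tries, in order: the exact id, the id without the local
// address, the id without the remote part, and finally the local port only.
func findEndpoint(eps map[endpointID]string, id endpointID) (string, bool) {
	if ep, ok := eps[id]; ok {
		return ep, true
	}

	nid := id
	nid.LocalAddress = ""
	if ep, ok := eps[nid]; ok {
		return ep, true
	}

	nid = id // assumed: start again from the full id
	nid.RemoteAddress = ""
	nid.RemotePort = 0
	if ep, ok := eps[nid]; ok {
		return ep, true
	}

	nid.LocalAddress = ""
	ep, ok := eps[nid]
	return ep, ok
}

func main() {
	eps := map[endpointID]string{
		{LocalPort: 80}: "listener",
		{LocalPort: 80, LocalAddress: "10.0.0.1", RemotePort: 5000, RemoteAddress: "10.0.0.2"}: "connected",
	}
	ep, _ := findEndpoint(eps, endpointID{80, "10.0.0.1", 5000, "10.0.0.2"})
	fmt.Println(ep) // connected
	ep, _ = findEndpoint(eps, endpointID{80, "10.0.0.1", 6000, "10.0.0.3"})
	fmt.Println(ep) // listener
}
```

Returning the endpoint instead of delivering inside the lookup means the demuxer's read lock is not held while an endpoint processes the packet. It also lets the data and control paths reuse the same matching rules.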
diff --git a/tcpip/stack/transport_test.go b/tcpip/stack/transport_test.go
index f910ff6..c44b683 100644
--- a/tcpip/stack/transport_test.go
+++ b/tcpip/stack/transport_test.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package stack_test
@@ -46,26 +56,29 @@
return mask
}
-func (*fakeTransportEndpoint) Read(*tcpip.FullAddress) (buffer.View, *tcpip.Error) {
- return buffer.View{}, nil
+func (*fakeTransportEndpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+ return buffer.View{}, tcpip.ControlMessages{}, nil
}
-func (f *fakeTransportEndpoint) Write(v buffer.View, _ *tcpip.FullAddress) (uintptr, *tcpip.Error) {
+func (f *fakeTransportEndpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, *tcpip.Error) {
if len(f.route.RemoteAddress) == 0 {
return 0, tcpip.ErrNoRoute
}
hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength()))
- err := f.route.WritePacket(&hdr, v, fakeTransNumber, f.route.DefaultTTL())
+ v, err := p.Get(p.Size())
if err != nil {
return 0, err
}
+ if err := f.route.WritePacket(hdr, buffer.View(v).ToVectorisedView(), fakeTransNumber, 123); err != nil {
+ return 0, err
+ }
return uintptr(len(v)), nil
}
-func (f *fakeTransportEndpoint) Peek([][]byte) (uintptr, *tcpip.Error) {
- return 0, nil
+func (f *fakeTransportEndpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) {
+ return 0, tcpip.ControlMessages{}, nil
}
// SetSockOpt sets a socket option. Currently not supported.
@@ -135,11 +148,16 @@
return tcpip.FullAddress{}, nil
}
-func (f *fakeTransportEndpoint) HandlePacket(*stack.Route, stack.TransportEndpointID, *buffer.VectorisedView) {
+func (f *fakeTransportEndpoint) HandlePacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) {
// Increment the number of received packets.
f.proto.packetCount++
}
+func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, buffer.VectorisedView) {
+ // Increment the number of received control packets.
+ f.proto.controlCount++
+}
+
type fakeTransportGoodOption bool
type fakeTransportBadOption bool
@@ -153,8 +171,9 @@
// fakeTransportProtocol is a transport-layer protocol descriptor. It
// aggregates the number of packets received via endpoints of this protocol.
type fakeTransportProtocol struct {
- packetCount int
- opts fakeTransportProtocolOptions
+ packetCount int
+ controlCount int
+ opts fakeTransportProtocolOptions
}
func (*fakeTransportProtocol) Number() tcpip.TransportProtocolNumber {
@@ -173,7 +192,7 @@
return 0, 0, nil
}
-func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *buffer.VectorisedView) bool {
+func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) bool {
return true
}
@@ -189,9 +208,19 @@
}
}
+func (f *fakeTransportProtocol) Option(option interface{}) *tcpip.Error {
+ switch v := option.(type) {
+ case *fakeTransportGoodOption:
+ *v = fakeTransportGoodOption(f.opts.good)
+ return nil
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
func TestTransportReceive(t *testing.T) {
id, linkEP := channel.New(10, defaultMTU, "")
- s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"})
+ s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{})
if err := s.CreateNIC(1, id); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
@@ -215,15 +244,13 @@
fakeTrans := s.TransportProtocolInstance(fakeTransNumber).(*fakeTransportProtocol)
- var views [1]buffer.View
// Create buffer that will hold the packet.
buf := buffer.NewView(30)
// Make sure packet with wrong protocol is not delivered.
buf[0] = 1
buf[2] = 0
- vv := buf.ToVectorisedView(views)
- linkEP.Inject(fakeNetNumber, &vv)
+ linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeTrans.packetCount != 0 {
t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0)
}
@@ -232,8 +259,7 @@
buf[0] = 1
buf[1] = 3
buf[2] = byte(fakeTransNumber)
- vv = buf.ToVectorisedView(views)
- linkEP.Inject(fakeNetNumber, &vv)
+ linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeTrans.packetCount != 0 {
t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0)
}
@@ -242,16 +268,77 @@
buf[0] = 1
buf[1] = 2
buf[2] = byte(fakeTransNumber)
- vv = buf.ToVectorisedView(views)
- linkEP.Inject(fakeNetNumber, &vv)
+ linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeTrans.packetCount != 1 {
t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 1)
}
}
+func TestTransportControlReceive(t *testing.T) {
+ id, linkEP := channel.New(10, defaultMTU, "")
+ s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{})
+ if err := s.CreateNIC(1, id); err != nil {
+ t.Fatalf("CreateNIC failed: %v", err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}})
+
+ if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
+ t.Fatalf("AddAddress failed: %v", err)
+ }
+
+ // Create endpoint and connect to remote address.
+ wq := waiter.Queue{}
+ ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ if err := ep.Connect(tcpip.FullAddress{0, "\x02", 0}); err != nil {
+ t.Fatalf("Connect failed: %v", err)
+ }
+
+ fakeTrans := s.TransportProtocolInstance(fakeTransNumber).(*fakeTransportProtocol)
+
+ // Create buffer that will hold the control packet.
+ buf := buffer.NewView(2*fakeNetHeaderLen + 30)
+
+ // Outer packet contains the control protocol number.
+ buf[0] = 1
+ buf[1] = 0xfe
+ buf[2] = uint8(fakeControlProtocol)
+
+ // Make sure packet with wrong protocol is not delivered.
+ buf[fakeNetHeaderLen+0] = 0
+ buf[fakeNetHeaderLen+1] = 1
+ buf[fakeNetHeaderLen+2] = 0
+ linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
+ if fakeTrans.controlCount != 0 {
+ t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0)
+ }
+
+ // Make sure packet from the wrong source is not delivered.
+ buf[fakeNetHeaderLen+0] = 3
+ buf[fakeNetHeaderLen+1] = 1
+ buf[fakeNetHeaderLen+2] = byte(fakeTransNumber)
+ linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
+ if fakeTrans.controlCount != 0 {
+ t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0)
+ }
+
+ // Make sure packet is delivered.
+ buf[fakeNetHeaderLen+0] = 2
+ buf[fakeNetHeaderLen+1] = 1
+ buf[fakeNetHeaderLen+2] = byte(fakeTransNumber)
+ linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
+ if fakeTrans.controlCount != 1 {
+ t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 1)
+ }
+}
+
func TestTransportSend(t *testing.T) {
id, _ := channel.New(10, defaultMTU, "")
- s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"})
+ s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{})
if err := s.CreateNIC(1, id); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
@@ -275,7 +362,7 @@
// Create buffer that will hold the payload.
view := buffer.NewView(30)
- _, err = ep.Write(view, nil)
+ _, err = ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{})
if err != nil {
t.Fatalf("write failed: %v", err)
}
@@ -287,8 +374,8 @@
}
}
-func TestTransportSetOption(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"})
+func TestTransportOptions(t *testing.T) {
+ s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{})
// Try an unsupported transport protocol.
if err := s.SetTransportProtocolOption(tcpip.TransportProtocolNumber(99999), fakeTransportGoodOption(false)); err != tcpip.ErrUnknownProtocol {
@@ -297,21 +384,30 @@
testCases := []struct {
option interface{}
- want *tcpip.Error
+ wantErr *tcpip.Error
verifier func(t *testing.T, p stack.TransportProtocol)
}{
{fakeTransportGoodOption(true), nil, func(t *testing.T, p stack.TransportProtocol) {
+ t.Helper()
fakeTrans := p.(*fakeTransportProtocol)
if fakeTrans.opts.good != true {
t.Fatalf("fakeTrans.opts.good = false, want = true")
}
+ var v fakeTransportGoodOption
+ if err := s.TransportProtocolOption(fakeTransNumber, &v); err != nil {
+ t.Fatalf("s.TransportProtocolOption(fakeTransNumber, &v) = %v, want = nil, where v is option %T", err, v)
+ }
+ if v != true {
+ t.Fatalf("s.TransportProtocolOption(fakeTransNumber, &v) returned v = %v, want = true", v)
+ }
+
}},
{fakeTransportBadOption(true), tcpip.ErrUnknownProtocolOption, nil},
{fakeTransportInvalidValueOption(1), tcpip.ErrInvalidOptionValue, nil},
}
for _, tc := range testCases {
- if got := s.SetTransportProtocolOption(fakeTransNumber, tc.option); tc.want != got {
- t.Errorf("s.SetOption(fakeTrans, %v) = %v, want = %v", tc.option, got, tc.want)
+ if got := s.SetTransportProtocolOption(fakeTransNumber, tc.option); got != tc.wantErr {
+ t.Errorf("s.SetTransportProtocolOption(fakeTrans, %v) = %v, want = %v", tc.option, got, tc.wantErr)
}
if tc.verifier != nil {
tc.verifier(t, s.TransportProtocolInstance(fakeTransNumber))
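The updated tests now pass `tcpip.SlicePayload(view)` and `tcpip.WriteOptions{}` to `ep.Write`. The sketch below shows roughly what the new `Payload` interface in tcpip.go means for an endpoint. It re-declares local copies of `Payload` and `SlicePayload` so it compiles without netstack, and it reduces `Error` to a stand-in type. `consume` is a hypothetical helper that stands in for an endpoint asking for only as many bytes as it can buffer. None of these local names are netstack API.

```go
package main

import "fmt"

// Error is a stand-in for *tcpip.Error so this sketch compiles on its own
// (assumption: the real type carries more state).
type Error struct{ msg string }

// Payload mirrors the interface added to tcpip.go in this change.
type Payload interface {
	// Get returns exactly min(size, Size()) bytes.
	Get(size int) ([]byte, *Error)
	// Size returns the payload size.
	Size() int
}

// SlicePayload mirrors tcpip.SlicePayload: a Payload backed by a byte slice.
type SlicePayload []byte

// Get clamps size to the payload length and returns a prefix of the
// underlying slice without copying it.
func (s SlicePayload) Get(size int) ([]byte, *Error) {
	if size > s.Size() {
		size = s.Size()
	}
	return s[:size], nil
}

// Size returns len(s).
func (s SlicePayload) Size() int { return len(s) }

// consume is a hypothetical endpoint-side helper: it requests at most limit
// bytes, the way an endpoint with a bounded send buffer might.
func consume(p Payload, limit int) (int, *Error) {
	b, err := p.Get(limit)
	if err != nil {
		return 0, err
	}
	return len(b), nil
}

func main() {
	view := make([]byte, 30) // like buffer.NewView(30) in TestTransportSend
	var p Payload = SlicePayload(view)
	n, _ := consume(p, 16)
	fmt.Println(n, p.Size()) // 16 30
	n, _ = consume(p, 100)
	fmt.Println(n) // 30
}
```

`Get` hands back a subslice rather than a copy, so an endpoint can end up holding the caller's bytes. This matches the ownership note on `Endpoint.Write`.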
diff --git a/tcpip/tcpip.go b/tcpip/tcpip.go
index d014c75..ec6c5af 100644
--- a/tcpip/tcpip.go
+++ b/tcpip/tcpip.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
// Package tcpip provides the interfaces and related types that users of the
// tcpip stack will use in order to create endpoints used to send and receive
@@ -21,8 +31,11 @@
import (
"errors"
"fmt"
+ "reflect"
"strconv"
"strings"
+ "sync"
+ "sync/atomic"
"time"
"github.com/google/netstack/tcpip/buffer"
@@ -31,22 +44,26 @@
// Error represents an error in the netstack error space. Using a special type
// ensures that errors outside of this space are not accidentally introduced.
+//
+// Note: to support save / restore, it is important that all tcpip errors have
+// distinct error messages.
type Error struct {
msg string
- // IgnoreStats determines whether this error type should be included in
- // failure counts in tcpip.Stats structs.
- IgnoreStats bool
+ ignoreStats bool
}
// String implements fmt.Stringer.String.
func (e *Error) String() string {
- if e == nil {
- return "<nil>"
- }
return e.msg
}
+// IgnoreStats indicates whether this error type should be excluded from
+// failure counts in tcpip.Stats structs.
+func (e *Error) IgnoreStats() bool {
+ return e.ignoreStats
+}
+
// Errors that can be returned by the network stack.
var (
ErrUnknownProtocol = &Error{msg: "unknown protocol"}
@@ -56,38 +73,63 @@
ErrDuplicateAddress = &Error{msg: "duplicate address"}
ErrNoRoute = &Error{msg: "no route"}
ErrBadLinkEndpoint = &Error{msg: "bad link layer endpoint"}
- ErrAlreadyBound = &Error{msg: "endpoint already bound", IgnoreStats: true}
+ ErrAlreadyBound = &Error{msg: "endpoint already bound", ignoreStats: true}
ErrInvalidEndpointState = &Error{msg: "endpoint is in invalid state"}
- ErrAlreadyConnecting = &Error{msg: "endpoint is already connecting", IgnoreStats: true}
- ErrAlreadyConnected = &Error{msg: "endpoint is already connected", IgnoreStats: true}
+ ErrAlreadyConnecting = &Error{msg: "endpoint is already connecting", ignoreStats: true}
+ ErrAlreadyConnected = &Error{msg: "endpoint is already connected", ignoreStats: true}
ErrNoPortAvailable = &Error{msg: "no ports are available"}
ErrPortInUse = &Error{msg: "port is in use"}
ErrBadLocalAddress = &Error{msg: "bad local address"}
ErrClosedForSend = &Error{msg: "endpoint is closed for send"}
ErrClosedForReceive = &Error{msg: "endpoint is closed for receive"}
- ErrWouldBlock = &Error{msg: "operation would block", IgnoreStats: true}
+ ErrWouldBlock = &Error{msg: "operation would block", ignoreStats: true}
ErrConnectionRefused = &Error{msg: "connection was refused"}
ErrTimeout = &Error{msg: "operation timed out"}
ErrAborted = &Error{msg: "operation aborted"}
- ErrConnectStarted = &Error{msg: "connection attempt started", IgnoreStats: true}
+ ErrConnectStarted = &Error{msg: "connection attempt started", ignoreStats: true}
ErrDestinationRequired = &Error{msg: "destination address is required"}
ErrNotSupported = &Error{msg: "operation not supported"}
ErrQueueSizeNotSupported = &Error{msg: "queue size querying not supported"}
ErrNotConnected = &Error{msg: "endpoint not connected"}
ErrConnectionReset = &Error{msg: "connection reset by peer"}
ErrConnectionAborted = &Error{msg: "connection aborted"}
- ErrNoLinkAddress = &Error{msg: "no remote link address"}
ErrNoSuchFile = &Error{msg: "no such file"}
ErrInvalidOptionValue = &Error{msg: "invalid option value specified"}
+ ErrNoLinkAddress = &Error{msg: "no remote link address"}
+ ErrBadAddress = &Error{msg: "bad address"}
+ ErrNetworkUnreachable = &Error{msg: "network is unreachable"}
)
// Errors related to Subnet
var (
errSubnetLengthMismatch = errors.New("subnet length of address and mask differ")
errSubnetAddressMasked = errors.New("subnet address has bits set outside the mask")
- errInvalidCIDRNotation = errors.New("CIDR notation invalid")
)
+// ErrSaveRejection indicates a failed save due to unsupported networking state.
+// This type of error is only used for save logic.
+type ErrSaveRejection struct {
+ Err error
+}
+
+// Error returns a sensible description of the save rejection error.
+func (e ErrSaveRejection) Error() string {
+ return "save rejected due to unsupported networking state: " + e.Err.Error()
+}
+
+// A Clock provides the current time.
+//
+// Times returned by a Clock should always be used for application-visible
+// time, but never for netstack internal timekeeping.
+type Clock interface {
+ // NowNanoseconds returns the current real time as a number of
+ // nanoseconds since the Unix epoch.
+ NowNanoseconds() int64
+
+ // NowMonotonic returns a monotonic time value.
+ NowMonotonic() int64
+}
+
// Address is a byte slice cast as a string that represents the address of a
// network node. Or, in the case of unix endpoints, it may represent a path.
type Address string
@@ -101,14 +143,6 @@
mask AddressMask
}
-func (a Address) Mask(m AddressMask) Address {
- out := []byte(a)
- for i, _ := range a {
- out[i] = a[i] & m[i]
- }
- return Address(out)
-}
-
// NewSubnet creates a new Subnet, checking that the address and mask are the same length.
func NewSubnet(a Address, m AddressMask) (Subnet, error) {
if len(a) != len(m) {
@@ -169,16 +203,10 @@
}
// Mask returns the subnet mask.
-// Getter instead of exported field to avoid rename before gVisor merge.
func (s *Subnet) Mask() AddressMask {
return s.mask
}
-// String implements fmt.Stringer.String.
-func (s Subnet) String() string {
- return fmt.Sprintf("{ address=%s, mask=%s }", s.address, Address(s.mask))
-}
-
// NICID is a number that uniquely identifies a NIC.
type NICID int32
@@ -195,6 +223,8 @@
// FullAddress represents a full transport node address, as required by the
// Connect() and Bind() methods.
+//
+// +stateify savable
type FullAddress struct {
// NIC is the ID of the NIC this address refers to.
//
@@ -210,6 +240,45 @@
Port uint16
}
+// Payload provides an interface around data that is being sent to an endpoint.
+// This allows the endpoint to request the amount of data it needs based on
+// internal buffers without exposing them. 'p.Get(p.Size())' reads all the data.
+type Payload interface {
+ // Get returns a slice containing exactly 'min(size, p.Size())' bytes.
+ Get(size int) ([]byte, *Error)
+
+ // Size returns the payload size.
+ Size() int
+}
+
+// SlicePayload implements Payload on top of slices for convenience.
+type SlicePayload []byte
+
+// Get implements Payload.
+func (s SlicePayload) Get(size int) ([]byte, *Error) {
+ if size > s.Size() {
+ size = s.Size()
+ }
+ return s[:size], nil
+}
+
+// Size implements Payload.
+func (s SlicePayload) Size() int {
+ return len(s)
+}
+
+// A ControlMessages contains socket control messages for IP sockets.
+//
+// +stateify savable
+type ControlMessages struct {
+ // HasTimestamp indicates whether Timestamp is valid/set.
+ HasTimestamp bool
+
+ // Timestamp is the time (in ns) that the last packet used to create
+ // the read data was received.
+ Timestamp int64
+}
+
// Endpoint is the interface implemented by transport protocols (e.g., tcp, udp)
// that exposes functionality like read, write, connect, etc. to users of the
// networking stack.
@@ -219,33 +288,44 @@
Close()
// Read reads data from the endpoint and optionally returns the sender.
- // This method does not block if there is no data pending.
- // It will also either return an error or data, never both.
- Read(*FullAddress) (buffer.View, *Error)
+ //
+ // This method does not block if there is no data pending. It will also
+ // either return an error or data, never both.
+ //
+ // A timestamp (in ns) is optionally returned in ControlMessages;
+ // ControlMessages.HasTimestamp is false if none was available.
+ Read(*FullAddress) (buffer.View, ControlMessages, *Error)
- // Write writes data to the endpoint's peer, or the provided address if
- // one is specified. This method does not block if the data cannot be
- // written.
+ // Write writes data to the endpoint's peer. This method does not block if
+ // the data cannot be written.
+ //
+ // Unlike io.Writer.Write, Endpoint.Write transfers ownership of any bytes
+ // successfully written to the Endpoint. That is, if a call to
+ // Write(SlicePayload(data)) returns (n, err), it may retain data[:n], and
+ // the caller should not use data[:n] after Write returns.
//
// Note that unlike io.Writer.Write, it is not an error for Write to
// perform a partial write.
- Write(buffer.View, *FullAddress) (uintptr, *Error)
+ Write(Payload, WriteOptions) (uintptr, *Error)
// Peek reads data without consuming it from the endpoint.
//
// This method does not block if there is no data pending.
- Peek([][]byte) (uintptr, *Error)
+ //
+ // A timestamp (in ns) is optionally returned in ControlMessages;
+ // ControlMessages.HasTimestamp is false if none was available.
+ Peek([][]byte) (uintptr, ControlMessages, *Error)
// Connect connects the endpoint to its peer. Specifying a NIC is
// optional.
//
// There are three classes of return values:
// nil -- the attempt to connect succeeded.
- // ErrConnectStarted -- the connect attempt started but hasn't
- // completed yet. In this case, the actual result will
- // become available via GetSockOpt(ErrorOption) when
- // the endpoint becomes writable. (This mimics the
- // connect(2) syscall behavior.)
+ // ErrConnectStarted/ErrAlreadyConnecting -- the connect attempt started
+ // but hasn't completed yet. In this case, the caller must call Connect
+ // or GetSockOpt(ErrorOption) when the endpoint becomes writable to
+ // get the actual result. The first call to Connect after the socket has
+ // connected returns nil. Calling Connect again results in ErrAlreadyConnected.
// Anything else -- the attempt to connect failed.
Connect(address FullAddress) *Error
@@ -291,6 +371,19 @@
GetSockOpt(opt interface{}) *Error
}
+// WriteOptions contains options for Endpoint.Write.
+type WriteOptions struct {
+ // If To is not nil, write to the given address instead of the endpoint's
+ // peer.
+ To *FullAddress
+
+ // More has the same semantics as Linux's MSG_MORE.
+ More bool
+
+ // EndOfRecord has the same semantics as Linux's MSG_EOR.
+ EndOfRecord bool
+}
+
// ErrorOption is used in GetSockOpt to specify that the last error reported by
// the endpoint should be cleared and returned.
type ErrorOption struct{}
@@ -330,12 +423,24 @@
// Only supported on Unix sockets.
type PasscredOption int
+// TimestampOption is used by SetSockOpt/GetSockOpt to specify whether
+// SO_TIMESTAMP socket control messages are enabled.
+type TimestampOption int
+
+// TCPInfoOption is used by GetSockOpt to expose TCP statistics.
+//
+// TODO: Add and populate stat fields.
+type TCPInfoOption struct {
+ RTT time.Duration
+ RTTVar time.Duration
+}
+
// KeepaliveEnabledOption is used by SetSockOpt/GetSockOpt to specify whether
// TCP keepalive is enabled for this socket.
type KeepaliveEnabledOption int
-// KeepaliveIdleOption is used by SetSockOpt/GetSockOpt to specify the time
-// a connection must remain idle before the first TCP keepalive packet is sent.
+// KeepaliveIdleOption is used by SetSockOpt/GetSockOpt to specify the time a
+// connection must remain idle before the first TCP keepalive packet is sent.
// Once this time is reached, KeepaliveIntervalOption is used instead.
type KeepaliveIdleOption time.Duration
@@ -343,9 +448,9 @@
// interval between sending TCP keepalive packets.
type KeepaliveIntervalOption time.Duration
-// KeepaliveCountOption is used by SetSockOpt/GetSockOpt to specify the
-// number of un-ACKed TCP keepalives that will be sent before the connection
-// is closed.
+// KeepaliveCountOption is used by SetSockOpt/GetSockOpt to specify the number
+// of un-ACKed TCP keepalives that will be sent before the connection is
+// closed.
type KeepaliveCountOption int
// MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default
@@ -353,7 +458,7 @@
type MulticastTTLOption uint8
// MembershipOption is used by SetSockOpt/GetSockOpt as an argument to
-// AddMembershipOption and RemoveMembershipOption
+// AddMembershipOption and RemoveMembershipOption.
type MembershipOption struct {
NIC NICID
InterfaceAddr Address
@@ -370,12 +475,6 @@
// the given interface address.
type RemoveMembershipOption MembershipOption
-// InfoOption is used by GetSockOpt to query various metrics about the socket.
-type InfoOption struct {
- Rtt time.Duration
- Rttvar time.Duration
-}
-
// Route is a row in the routing table. It specifies through which NIC (and
// gateway) sets of packets should be routed. A row is considered viable if the
// masked target address matches the destination adddress in the row.
@@ -419,81 +518,116 @@
// NetworkProtocolNumber is the number of a network protocol.
type NetworkProtocolNumber uint32
+// A StatCounter keeps track of a statistic.
+type StatCounter struct {
+ count uint64
+}
+
+// Increment adds one to the counter.
+func (s *StatCounter) Increment() {
+ s.IncrementBy(1)
+}
+
+// Value returns the current value of the counter.
+func (s *StatCounter) Value() uint64 {
+ return atomic.LoadUint64(&s.count)
+}
+
+// IncrementBy increments the counter by v.
+func (s *StatCounter) IncrementBy(v uint64) {
+ atomic.AddUint64(&s.count, v)
+}
+
// IPStats collects IP-specific stats (both v4 and v6).
type IPStats struct {
// PacketsReceived is the total number of IP packets received from the link
// layer in nic.DeliverNetworkPacket.
- PacketsReceived uint64
+ PacketsReceived *StatCounter
+
// InvalidAddressesReceived is the total number of IP packets received
// with an unknown or invalid destination address.
- InvalidAddressesReceived uint64
- // PacketsDiscarded is the total number of IP packets received from the link
- // layer but not delivered to the transport layer.
- PacketsDiscarded uint64
+ InvalidAddressesReceived *StatCounter
+
// PacketsDelivered is the total number of incoming IP packets that
// are successfully delivered to the transport layer via HandlePacket.
- PacketsDelivered uint64
+ PacketsDelivered *StatCounter
+
// PacketsSent is the total number of IP packets sent via WritePacket.
- PacketsSent uint64
+ PacketsSent *StatCounter
+
// OutgoingPacketErrors is the total number of IP packets which failed
// to write to a link-layer endpoint.
- OutgoingPacketErrors uint64
+ OutgoingPacketErrors *StatCounter
}
// TCPStats collects TCP-specific stats.
type TCPStats struct {
// ActiveConnectionOpenings is the number of connections opened successfully
// via Connect.
- ActiveConnectionOpenings uint64
+ ActiveConnectionOpenings *StatCounter
+
// PassiveConnectionOpenings is the number of connections opened
// successfully via Listen.
- PassiveConnectionOpenings uint64
+ PassiveConnectionOpenings *StatCounter
+
// FailedConnectionAttempts is the number of calls to Connect or Listen
// (active and passive openings, respectively) that end in an error.
- FailedConnectionAttempts uint64
+ FailedConnectionAttempts *StatCounter
+
// ValidSegmentsReceived is the number of TCP segments received that the
// transport layer successfully parsed.
- ValidSegmentsReceived uint64
+ ValidSegmentsReceived *StatCounter
+
// InvalidSegmentsReceived is the number of TCP segments received that
// the transport layer could not parse.
- InvalidSegmentsReceived uint64
+ InvalidSegmentsReceived *StatCounter
+
// SegmentsSent is the number of TCP segments sent.
- SegmentsSent uint64
+ SegmentsSent *StatCounter
+
// ResetsSent is the number of TCP resets sent.
- ResetsSent uint64
+ ResetsSent *StatCounter
+
// ResetsReceived is the number of TCP resets received.
- ResetsReceived uint64
+ ResetsReceived *StatCounter
}
// UDPStats collects UDP-specific stats.
type UDPStats struct {
- // PacketsReceived is the number of UDP datagrams received via HandlePacket.
- PacketsReceived uint64
- // UnknownPortErrors is the number of incoming UDP datagrams dropped because
- // they did not have a known destination port.
- UnknownPortErrors uint64
- // ReceiveBufferErrors is the number of incoming UDP datagrams dropped due to the
- // receiving buffer being in an invalid state.
- ReceiveBufferErrors uint64
- // MalformedPacketsReceived is the number of incoming UDP datagrams dropped due to
- // the UDP header being in a malformed state.
- MalformedPacketsReceived uint64
+ // PacketsReceived is the number of UDP datagrams received via
+ // HandlePacket.
+ PacketsReceived *StatCounter
+
+ // UnknownPortErrors is the number of incoming UDP datagrams dropped
+ // because they did not have a known destination port.
+ UnknownPortErrors *StatCounter
+
+ // ReceiveBufferErrors is the number of incoming UDP datagrams dropped
+ // due to the receiving buffer being in an invalid state.
+ ReceiveBufferErrors *StatCounter
+
+ // MalformedPacketsReceived is the number of incoming UDP datagrams
+ // dropped due to the UDP header being in a malformed state.
+ MalformedPacketsReceived *StatCounter
+
// PacketsSent is the number of UDP datagrams sent via sendUDP.
- PacketsSent uint64
+ PacketsSent *StatCounter
}
// Stats holds statistics about the networking stack.
+//
+// All fields are optional.
type Stats struct {
- // UnkownProtocolRcvdPackets is the number of packets received by the
+ // UnknownProtocolRcvdPackets is the number of packets received by the
// stack that were for an unknown or unsupported protocol.
- UnknownProtocolRcvdPackets uint64
+ UnknownProtocolRcvdPackets *StatCounter
// MalformedRcvPackets is the number of packets received by the stack
// that were deemed malformed.
- MalformedRcvdPackets uint64
+ MalformedRcvdPackets *StatCounter
// DroppedPackets is the number of packets dropped due to full queues.
- DroppedPackets uint64
+ DroppedPackets *StatCounter
// IP breaks out IP-specific stats (both v4 and v6).
IP IPStats
@@ -505,6 +639,28 @@
UDP UDPStats
}
+func fillIn(v reflect.Value) {
+ for i := 0; i < v.NumField(); i++ {
+ f := v.Field(i)
+ switch f.Kind() {
+ case reflect.Ptr:
+ if s, ok := f.Addr().Interface().(**StatCounter); ok {
+ if *s == nil {
+ *s = &StatCounter{}
+ }
+ }
+ case reflect.Struct:
+ fillIn(f)
+ }
+ }
+}
+
+// FillIn returns a copy of s with nil fields initialized to new StatCounters.
+func (s Stats) FillIn() Stats {
+ fillIn(reflect.ValueOf(&s).Elem())
+ return s
+}
+
// String implements the fmt.Stringer interface.
func (a Address) String() string {
switch len(a) {
@@ -523,39 +679,35 @@
}
}
- var b []byte
+ var b strings.Builder
for i := 0; i < len(a); i += 2 {
if i == start {
- b = append(b, "::"...)
+ b.WriteString("::")
i = end
if end >= len(a) {
break
}
} else if i > 0 {
- b = append(b, ':')
+ b.WriteByte(':')
}
v := uint16(a[i+0])<<8 | uint16(a[i+1])
- b = appendHex(b, v)
+ if v == 0 {
+ b.WriteByte('0')
+ } else {
+ const digits = "0123456789abcdef"
+ for i := uint(3); i < 4; i-- {
+ if v := v >> (i * 4); v != 0 {
+ b.WriteByte(digits[v&0xf])
+ }
+ }
+ }
}
- return string(b)
+ return b.String()
default:
return fmt.Sprintf("%x", []byte(a))
}
}
-func appendHex(b []byte, v uint16) []byte {
- if v == 0 {
- return append(b, '0')
- }
- const digits = "0123456789abcdef"
- for i := uint(3); i < 4; i-- {
- if v := v >> (i * 4); v != 0 {
- b = append(b, digits[v&0xf])
- }
- }
- return b
-}
-
// To4 converts the IPv4 address to a 4-byte representation.
// If the address is not an IPv4 address, To4 returns "".
func (a Address) To4() Address {
@@ -585,150 +737,6 @@
return true
}
-// Copied from pkg Net to avoid taking a dependency.
-func CIDRMask(ones, bits int) AddressMask {
- // header.IPv4AddressSize, header.IPv6AddressSize
- if bits != 8*4 && bits != 8*16 {
- return AddressMask("")
- }
- if ones < 0 || ones > bits {
- return AddressMask("")
- }
- l := bits / 8
- m := make([]byte, l)
- n := uint(ones)
- for i := 0; i < l; i++ {
- if n >= 8 {
- m[i] = 0xff
- n -= 8
- continue
- }
- m[i] = ^byte(0xff >> n)
- n = 0
- }
- return AddressMask(m)
-}
-
-func ParseCIDR(subnet string) (Address, Subnet, error) {
- split := strings.Split(subnet, "/")
- if len(split) != 2 {
- return Address(""), Subnet{}, errInvalidCIDRNotation
- }
- addr := Parse(split[0])
- ones, err := strconv.ParseInt(split[1], 10, 8)
-
- if err != nil {
- return Address(""), Subnet{}, err
- }
-
- mask := CIDRMask(int(ones), 8*len(addr))
- sn, err := NewSubnet(addr.Mask(mask), mask)
- return addr, sn, err
-}
-
-// Parse parses the string representation of an IPv4 or IPv6 address.
-func Parse(src string) Address {
- for i := 0; i < len(src); i++ {
- switch src[i] {
- case '.':
- return parseIP4(src)
- case ':':
- return parseIP6(src)
- }
- }
- return ""
-}
-
-func parseIP4(src string) Address {
- var addr [4]byte
- _, err := fmt.Sscanf(src, "%d.%d.%d.%d", &addr[0], &addr[1], &addr[2], &addr[3])
- if err != nil {
- return ""
- }
- return Address(addr[:])
-}
-
-func parseIP6(src string) (res Address) {
- a := make([]byte, 0, 16) // cap(a) is constant throughout
- expansion := -1 // index of '::' expansion in a
-
- if len(src) >= 2 && src[:2] == "::" {
- if len(src) == 2 {
- return Address(a[:cap(a)])
- }
- expansion = 0
- src = src[2:]
- }
-
- for len(a) < cap(a) && len(src) > 0 {
- var x uint16
- var ok bool
- x, src, ok = parseHex(src)
- if !ok {
- return ""
- }
- a = append(a, uint8(x>>8), uint8(x))
-
- if len(src) == 0 {
- break
- }
-
- // Next is either ":..." or "::[...]".
- if src[0] != ':' || len(src) == 1 {
- return ""
- }
- src = src[1:]
- if src[0] == ':' {
- if expansion >= 0 {
- return "" // only one expansion allowed
- }
- expansion = len(a)
- src = src[1:]
- }
- }
- if len(src) != 0 {
- return ""
- }
-
- if missing := cap(a) - len(a); missing > 0 {
- if expansion < 0 {
- return ""
- }
- a = a[:cap(a)]
- copy(a[expansion+missing:], a[expansion:])
- for i := 0; i < missing; i++ {
- a[i+expansion] = 0
- }
- }
-
- return Address(a)
-}
-
-func parseHex(src string) (x uint16, remaining string, ok bool) {
- if len(src) == 0 {
- return 0, src, false
- }
-loop:
- for len(src) > 0 {
- v := src[0]
- switch {
- case '0' <= v && v <= '9':
- v = v - '0'
- case 'a' <= v && v <= 'f':
- v = v - 'a' + 10
- case 'A' <= v && v <= 'F':
- v = v - 'A' + 10
- case v == ':':
- break loop
- default:
- return 0, src, false
- }
- src = src[1:]
- x = (x << 4) | uint16(v)
- }
- return x, src, true
-}
-
// LinkAddress is a byte slice cast as a string that represents a link address.
// It is typically a 6-byte MAC address.
type LinkAddress string
@@ -742,3 +750,69 @@
return fmt.Sprintf("%x", []byte(a))
}
}
+
+// ParseMACAddress parses an IEEE 802 address.
+//
+// It must be in the format aa:bb:cc:dd:ee:ff or aa-bb-cc-dd-ee-ff.
+func ParseMACAddress(s string) (LinkAddress, error) {
+ parts := strings.FieldsFunc(s, func(c rune) bool {
+ return c == ':' || c == '-'
+ })
+ if len(parts) != 6 {
+ return "", fmt.Errorf("inconsistent parts: %s", s)
+ }
+ addr := make([]byte, 0, len(parts))
+ for _, part := range parts {
+ u, err := strconv.ParseUint(part, 16, 8)
+ if err != nil {
+ return "", fmt.Errorf("invalid hex digits: %s", s)
+ }
+ addr = append(addr, byte(u))
+ }
+ return LinkAddress(addr), nil
+}
+
+// ProtocolAddress is an address and the network protocol it is associated
+// with.
+type ProtocolAddress struct {
+ // Protocol is the protocol of the address.
+ Protocol NetworkProtocolNumber
+
+ // Address is a network address.
+ Address Address
+}
+
+// danglingEndpointsMu protects access to danglingEndpoints.
+var danglingEndpointsMu sync.Mutex
+
+// danglingEndpoints tracks all dangling endpoints no longer owned by the app.
+var danglingEndpoints = make(map[Endpoint]struct{})
+
+// GetDanglingEndpoints returns all dangling endpoints.
+func GetDanglingEndpoints() []Endpoint {
+ danglingEndpointsMu.Lock()
+ es := make([]Endpoint, 0, len(danglingEndpoints))
+ for e := range danglingEndpoints {
+ es = append(es, e)
+ }
+ danglingEndpointsMu.Unlock()
+ return es
+}
+
+// AddDanglingEndpoint adds a dangling endpoint.
+func AddDanglingEndpoint(e Endpoint) {
+ danglingEndpointsMu.Lock()
+ danglingEndpoints[e] = struct{}{}
+ danglingEndpointsMu.Unlock()
+}
+
+// DeleteDanglingEndpoint removes a dangling endpoint.
+func DeleteDanglingEndpoint(e Endpoint) {
+ danglingEndpointsMu.Lock()
+ delete(danglingEndpoints, e)
+ danglingEndpointsMu.Unlock()
+}
+
+// AsyncLoading is the global barrier for asynchronous endpoint loading
+// activities.
+var AsyncLoading sync.WaitGroup
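The stats structs now hold `*StatCounter` fields, and `Stats` is documented as "All fields are optional". `FillIn` uses reflection to allocate any nil counters. The sketch below copies `StatCounter` and the `fillIn` walk into a standalone program so you can see the behaviour. Its `Stats` and `UDPStats` are trimmed-down, hypothetical versions of the real structs. One property shows up clearly: a counter the caller already supplied is kept rather than replaced. That presumably lets a caller share its own counters with the stack, but this is an inference and the diff does not state it.

```go
package main

import (
	"fmt"
	"reflect"
	"sync/atomic"
)

// StatCounter mirrors tcpip.StatCounter: an atomically updated uint64.
type StatCounter struct{ count uint64 }

func (s *StatCounter) Increment()           { s.IncrementBy(1) }
func (s *StatCounter) IncrementBy(v uint64) { atomic.AddUint64(&s.count, v) }
func (s *StatCounter) Value() uint64        { return atomic.LoadUint64(&s.count) }

// UDPStats and Stats are hypothetical, trimmed-down versions of the structs
// in tcpip.go, kept small for the example.
type UDPStats struct {
	PacketsReceived *StatCounter
	PacketsSent     *StatCounter
}

type Stats struct {
	DroppedPackets *StatCounter
	UDP            UDPStats
}

// fillIn walks an addressable struct value, allocating every nil
// *StatCounter field and recursing into nested structs.
func fillIn(v reflect.Value) {
	for i := 0; i < v.NumField(); i++ {
		f := v.Field(i)
		switch f.Kind() {
		case reflect.Ptr:
			if s, ok := f.Addr().Interface().(**StatCounter); ok && *s == nil {
				*s = &StatCounter{}
			}
		case reflect.Struct:
			fillIn(f)
		}
	}
}

// FillIn returns a copy of s with nil counters allocated; non-nil counters
// are left pointing at the caller's StatCounter.
func (s Stats) FillIn() Stats {
	fillIn(reflect.ValueOf(&s).Elem())
	return s
}

func main() {
	mine := &StatCounter{}
	stats := Stats{DroppedPackets: mine}.FillIn()
	stats.UDP.PacketsReceived.Increment()
	stats.DroppedPackets.IncrementBy(3)
	fmt.Println(stats.UDP.PacketsReceived.Value(), mine.Value()) // 1 3
}
```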
diff --git a/tcpip/tcpip_test.go b/tcpip/tcpip_test.go
index 4425757..9b20c74 100644
--- a/tcpip/tcpip_test.go
+++ b/tcpip/tcpip_test.go
@@ -1,12 +1,21 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package tcpip
import (
- "reflect"
- "strings"
+ "net"
"testing"
)
@@ -111,32 +120,6 @@
}
}
-// Copied from pkg Net (ip_test.go).
-var parseCIDRTests = []struct {
- in string
- address Address
- subnet Subnet
- err error
-}{
- {"135.104.0.0/32", Parse("135.104.0.0"), Subnet{address: Parse("135.104.0.0"), mask: AddressMask(Parse("255.255.255.255"))}, nil},
- {"0.0.0.0/24", Parse("0.0.0.0"), Subnet{address: Parse("0.0.0.0"), mask: AddressMask(Parse("255.255.255.0"))}, nil},
- {"135.104.0.0/24", Parse("135.104.0.0"), Subnet{address: Parse("135.104.0.0"), mask: AddressMask(Parse("255.255.255.0"))}, nil},
- {"135.104.0.1/32", Parse("135.104.0.1"), Subnet{address: Parse("135.104.0.1"), mask: AddressMask(Parse("255.255.255.255"))}, nil},
- {"135.104.0.1/24", Parse("135.104.0.1"), Subnet{address: Parse("135.104.0.0"), mask: AddressMask(Parse("255.255.255.0"))}, nil},
-}
-
-func TestParseCIDR(t *testing.T) {
- for _, tt := range parseCIDRTests {
- address, subnet, err := ParseCIDR(tt.in)
- if !reflect.DeepEqual(err, tt.err) {
- t.Errorf("ParseCIDR(%q) = %v, %v; want %s, %+v", tt.in, address, subnet, []byte(tt.address), []byte(tt.subnet.address))
- }
- if err == nil && !(tt.address == address) || !(tt.subnet.address == subnet.address) || !reflect.DeepEqual(subnet.Mask(), tt.subnet.Mask()) {
- t.Errorf("ParseCIDR(%q) = %s, {%s, %s}; want %s, {%s, %s}", tt.in, address, subnet.address, subnet.Mask(), tt.address, tt.subnet.address, tt.subnet.Mask())
- }
- }
-}
-
func TestRouteMatch(t *testing.T) {
tests := []struct {
d Address
@@ -157,42 +140,56 @@
}
}
-func TestParse(t *testing.T) {
- tests := []struct {
- txt string
- addr Address
- }{
- {"::", Address(strings.Repeat("\x00", 16))},
- {"8::", Address("\x00\x08" + strings.Repeat("\x00", 14))},
- {"::8a", Address(strings.Repeat("\x00", 14) + "\x00\x8a")},
- {"fe80::1234:5678", "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x12\x34\x56\x78"},
- {"fe80::b097:c9ff:fe02:477", "\xfe\x80\x00\x00\x00\x00\x00\x00\xb0\x97\xc9\xff\xfe\x02\x04\x77"},
- {"a:b:c:d:1:2:3:4", "\x00\x0a\x00\x0b\x00\x0c\x00\x0d\x00\x01\x00\x02\x00\x03\x00\x04"},
- {"a:b:c::2:3:4", "\x00\x0a\x00\x0b\x00\x0c\x00\x00\x00\x00\x00\x02\x00\x03\x00\x04"},
- {"000a:000b:000c::", "\x00\x0a\x00\x0b\x00\x0c\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"},
- {"0000:0000:0000::0001", Address(strings.Repeat("\x00", 15) + "\x01")},
- {"0:0::1", Address(strings.Repeat("\x00", 15) + "\x01")},
- }
-
- for _, test := range tests {
- got := Parse(test.txt)
- if got != test.addr {
- t.Errorf("Parse(%v)=%v, want %v", test.txt, got, test.addr)
- }
- }
-}
-
func TestAddressString(t *testing.T) {
- tests := []string{
- "a:b:c::2:3:4",
- "8::",
- "fe80::5054:ff:fe12:3456",
+ for _, want := range []string{
+ // Taken from stdlib.
+ "2001:db8::123:12:1",
+ "2001:db8::1",
+ "2001:db8:0:1:0:1:0:1",
+ "2001:db8:1:0:1:0:1:0",
+ "2001::1:0:0:1",
+ "2001:db8:0:0:1::",
+ "2001:db8::1:0:0:1",
+ "2001:db8::a:b:c:d",
+
+ // Leading zeros.
"::1",
- }
- for _, want := range tests {
- addr := Parse(want)
+ // Trailing zeros.
+ "8::",
+ // No zeros.
+ "1:1:1:1:1:1:1:1",
+ // Longer sequence is after other zeros, but not at the end.
+ "1:0:0:1::1",
+ // Longer sequence is at the beginning, shorter sequence is at
+ // the end.
+ "::1:1:1:0:0",
+ // Longer sequence is not at the beginning, shorter sequence is
+ // at the end.
+ "1::1:1:0:0",
+ // Longer sequence is at the beginning, shorter sequence is not
+ // at the end.
+ "::1:1:0:0:1",
+ // Neither sequence is at an end, longer is after shorter.
+ "1:0:0:1::1",
+ // Shorter sequence is at the beginning, longer sequence is not
+ // at the end.
+ "0:0:1:1::1",
+ // Shorter sequence is at the beginning, longer sequence is at
+ // the end.
+ "0:0:1:1:1::",
+ // Short sequences at both ends, longer one in the middle.
+ "0:1:1::1:1:0",
+ // Short sequences at both ends, longer one in the middle.
+ "0:1::1:0:0",
+ // Short sequences at both ends, longer one in the middle.
+ "0:0:1::1:0",
+ // Longer sequence surrounded by shorter sequences, but none at
+ // the end.
+ "1:0:1::1:0:1",
+ } {
+ addr := Address(net.ParseIP(want))
if got := addr.String(); got != want {
- t.Errorf("Address(%x).String()=%q, want %q", addr, got, want)
+ t.Errorf("Address(%x).String() = '%s', want = '%s'", addr, got, want)
}
}
}
diff --git a/tcpip/time.s b/tcpip/time.s
new file mode 100644
index 0000000..8aca31b
--- /dev/null
+++ b/tcpip/time.s
@@ -0,0 +1,15 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Empty assembly file so empty func definitions work.
diff --git a/tcpip/time_unsafe.go b/tcpip/time_unsafe.go
new file mode 100644
index 0000000..2102e96
--- /dev/null
+++ b/tcpip/time_unsafe.go
@@ -0,0 +1,43 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build go1.9
+// +build !go1.12
+
+package tcpip
+
+import (
+ _ "time" // Used with go:linkname.
+ _ "unsafe" // Required for go:linkname.
+)
+
+// StdClock implements Clock with the time package.
+type StdClock struct{}
+
+var _ Clock = (*StdClock)(nil)
+
+//go:linkname now time.now
+func now() (sec int64, nsec int32, mono int64)
+
+// NowNanoseconds implements Clock.NowNanoseconds.
+func (*StdClock) NowNanoseconds() int64 {
+ sec, nsec, _ := now()
+ return sec*1e9 + int64(nsec)
+}
+
+// NowMonotonic implements Clock.NowMonotonic.
+func (*StdClock) NowMonotonic() int64 {
+ _, _, mono := now()
+ return mono
+}
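time_unsafe.go uses `//go:linkname` to read the wall clock and the monotonic clock in a single runtime call. Its build tags restrict it to Go 1.9 through 1.11. Below is a minimal sketch of a portable alternative that uses only the public `time` API. It assumes callers need a non-decreasing monotonic value, not the runtime's exact origin. `portableClock` is a hypothetical name, and the `Clock` interface is restated here only for illustration.

```go
package main

import (
	"fmt"
	"time"
)

// Clock mirrors the two methods StdClock implements; the interface shape
// is assumed here for illustration.
type Clock interface {
	NowNanoseconds() int64
	NowMonotonic() int64
}

// portableClock is a hypothetical linkname-free alternative to StdClock.
type portableClock struct {
	start time.Time // carries a monotonic clock reading
}

var _ Clock = (*portableClock)(nil)

func newPortableClock() *portableClock {
	return &portableClock{start: time.Now()}
}

// NowNanoseconds returns wall-clock time in nanoseconds since the Unix epoch.
func (*portableClock) NowNanoseconds() int64 {
	return time.Now().UnixNano()
}

// NowMonotonic returns nanoseconds elapsed since the clock was created.
// time.Since subtracts monotonic readings, so wall-clock adjustments do not
// affect it; unlike time.now's mono value, its origin is clock creation.
func (c *portableClock) NowMonotonic() int64 {
	return int64(time.Since(c.start))
}

func main() {
	c := newPortableClock()
	a := c.NowMonotonic()
	time.Sleep(time.Millisecond)
	b := c.NowMonotonic()
	fmt.Println(b > a, c.NowNanoseconds() > 0)
}
```

The linkname version avoids building a `time.Time` on every call. This sketch gives that up in exchange for portability across Go versions.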
diff --git a/tcpip/transport/ping/endpoint.go b/tcpip/transport/ping/endpoint.go
new file mode 100644
index 0000000..db7b146
--- /dev/null
+++ b/tcpip/transport/ping/endpoint.go
@@ -0,0 +1,713 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ping
+
+import (
+ "encoding/binary"
+ "sync"
+
+ "github.com/google/netstack/sleep"
+ "github.com/google/netstack/tcpip"
+ "github.com/google/netstack/tcpip/buffer"
+ "github.com/google/netstack/tcpip/header"
+ "github.com/google/netstack/tcpip/stack"
+ "github.com/google/netstack/waiter"
+)
+
+// +stateify savable
+type pingPacket struct {
+ pingPacketEntry
+ senderAddress tcpip.FullAddress
+ data buffer.VectorisedView
+ timestamp int64
+ hasTimestamp bool
+ // views is used as backing storage for data's views when it is large
+ // enough to hold them, which avoids an allocation.
+ views [8]buffer.View
+}
+
+type endpointState int
+
+const (
+ stateInitial endpointState = iota
+ stateBound
+ stateConnected
+ stateClosed
+)
+
+// endpoint represents a ping endpoint. This struct serves as the interface
+// between users of the endpoint and the protocol implementation; it is legal to
+// have concurrent goroutines make calls into the endpoint, because such calls
+// are properly synchronized.
+type endpoint struct {
+ // The following fields are initialized at creation time and do not
+ // change throughout the lifetime of the endpoint.
+ stack *stack.Stack
+ netProto tcpip.NetworkProtocolNumber
+ transProto tcpip.TransportProtocolNumber
+ waiterQueue *waiter.Queue
+
+ // The following fields are used to manage the receive queue, and are
+ // protected by rcvMu.
+ rcvMu sync.Mutex
+ rcvReady bool
+ rcvList pingPacketList
+ rcvBufSizeMax int
+ rcvBufSize int
+ rcvClosed bool
+ rcvTimestamp bool
+
+ // The following fields are protected by the mu mutex.
+ mu sync.RWMutex
+ sndBufSize int
+ // shutdownFlags represent the current shutdown state of the endpoint.
+ shutdownFlags tcpip.ShutdownFlags
+ id stack.TransportEndpointID
+ state endpointState
+ bindNICID tcpip.NICID
+ bindAddr tcpip.Address
+ regNICID tcpip.NICID
+ route stack.Route
+}
+
+func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
+ return &endpoint{
+ stack: stack,
+ netProto: netProto,
+ transProto: transProto,
+ waiterQueue: waiterQueue,
+ rcvBufSizeMax: 32 * 1024,
+ sndBufSize: 32 * 1024,
+ }
+}
+
+// Close puts the endpoint in a closed state and frees all resources
+// associated with it.
+func (e *endpoint) Close() {
+ e.mu.Lock()
+ e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
+ switch e.state {
+ case stateBound, stateConnected:
+ e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id)
+ }
+
+ // Close the receive list and drain it.
+ e.rcvMu.Lock()
+ e.rcvClosed = true
+ e.rcvBufSize = 0
+ for !e.rcvList.Empty() {
+ p := e.rcvList.Front()
+ e.rcvList.Remove(p)
+ }
+ e.rcvMu.Unlock()
+
+ e.route.Release()
+
+ // Update the state.
+ e.state = stateClosed
+
+ e.mu.Unlock()
+
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+}
+
+// Read reads data from the endpoint. This method does not block if
+// there is no data pending.
+func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+ e.rcvMu.Lock()
+
+ if e.rcvList.Empty() {
+ err := tcpip.ErrWouldBlock
+ if e.rcvClosed {
+ err = tcpip.ErrClosedForReceive
+ }
+ e.rcvMu.Unlock()
+ return buffer.View{}, tcpip.ControlMessages{}, err
+ }
+
+ p := e.rcvList.Front()
+ e.rcvList.Remove(p)
+ e.rcvBufSize -= p.data.Size()
+ ts := e.rcvTimestamp
+
+ e.rcvMu.Unlock()
+
+ if addr != nil {
+ *addr = p.senderAddress
+ }
+
+ if ts && !p.hasTimestamp {
+ // Linux uses the current time.
+ p.timestamp = e.stack.NowNanoseconds()
+ }
+
+ return p.data.ToView(), tcpip.ControlMessages{HasTimestamp: ts, Timestamp: p.timestamp}, nil
+}
+
+// prepareForWrite prepares the endpoint for sending data. In particular, it
+// binds it if it's still in the initial state. To do so, it must first
+// reacquire the mutex in exclusive mode.
+//
+// Returns true for retry if preparation should be retried.
+func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpip.Error) {
+ switch e.state {
+ case stateInitial:
+ case stateConnected:
+ return false, nil
+
+ case stateBound:
+ if to == nil {
+ return false, tcpip.ErrDestinationRequired
+ }
+ return false, nil
+ default:
+ return false, tcpip.ErrInvalidEndpointState
+ }
+
+ e.mu.RUnlock()
+ defer e.mu.RLock()
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // The state changed when we released the shared lock and re-acquired
+ // it in exclusive mode. Try again.
+ if e.state != stateInitial {
+ return true, nil
+ }
+
+ // The state is still 'initial', so try to bind the endpoint.
+ if err := e.bindLocked(tcpip.FullAddress{}, nil); err != nil {
+ return false, err
+ }
+
+ return true, nil
+}
+
+// Write writes data to the endpoint's peer. This method does not block
+// if the data cannot be written.
+func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, *tcpip.Error) {
+ // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
+ if opts.More {
+ return 0, tcpip.ErrInvalidOptionValue
+ }
+
+ to := opts.To
+
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ // If we've shutdown with SHUT_WR we are in an invalid state for sending.
+ if e.shutdownFlags&tcpip.ShutdownWrite != 0 {
+ return 0, tcpip.ErrClosedForSend
+ }
+
+ // Prepare for write.
+ for {
+ retry, err := e.prepareForWrite(to)
+ if err != nil {
+ return 0, err
+ }
+
+ if !retry {
+ break
+ }
+ }
+
+ var route *stack.Route
+ if to == nil {
+ route = &e.route
+
+ if route.IsResolutionRequired() {
+ // Promote lock to exclusive if using a shared route, given that it may
+ // need to change in Route.Resolve() call below.
+ e.mu.RUnlock()
+ defer e.mu.RLock()
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // Recheck state after lock was re-acquired.
+ if e.state != stateConnected {
+ return 0, tcpip.ErrInvalidEndpointState
+ }
+ }
+ } else {
+ // Reject destination address if it goes through a different
+ // NIC than the endpoint was bound to.
+ nicid := to.NIC
+ if e.bindNICID != 0 {
+ if nicid != 0 && nicid != e.bindNICID {
+ return 0, tcpip.ErrNoRoute
+ }
+
+ nicid = e.bindNICID
+ }
+
+ toCopy := *to
+ to = &toCopy
+ netProto, err := e.checkV4Mapped(to, true)
+ if err != nil {
+ return 0, err
+ }
+
+ // Find a route to the destination.
+ r, err := e.stack.FindRoute(nicid, e.bindAddr, to.Addr, netProto)
+ if err != nil {
+ return 0, err
+ }
+ defer r.Release()
+
+ route = &r
+ }
+
+ if route.IsResolutionRequired() {
+ waker := &sleep.Waker{}
+ if err := route.Resolve(waker); err != nil {
+ if err == tcpip.ErrWouldBlock {
+ // Link address needs to be resolved. Resolution was triggered in the
+ // background. Better luck next time.
+ //
+ // TODO: queue up the request and send after link address
+ // is resolved.
+ route.RemoveWaker(waker)
+ return 0, tcpip.ErrNoLinkAddress
+ }
+ return 0, err
+ }
+ }
+
+ v, err := p.Get(p.Size())
+ if err != nil {
+ return 0, err
+ }
+
+ switch e.netProto {
+ case header.IPv4ProtocolNumber:
+ err = sendPing4(route, e.id.LocalPort, v)
+
+ case header.IPv6ProtocolNumber:
+ err = sendPing6(route, e.id.LocalPort, v)
+ }
+
+ return uintptr(len(v)), err
+}
+
+// Peek only returns data from a single datagram, so do nothing here.
+func (e *endpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) {
+ return 0, tcpip.ControlMessages{}, nil
+}
+
+// SetSockOpt sets a socket option. Only TimestampOption is supported; others are ignored.
+func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ switch v := opt.(type) {
+ case tcpip.TimestampOption:
+ e.rcvMu.Lock()
+ e.rcvTimestamp = v != 0
+ e.rcvMu.Unlock()
+ }
+ return nil
+}
+
+// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
+func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+ switch o := opt.(type) {
+ case tcpip.ErrorOption:
+ return nil
+
+ case *tcpip.SendBufferSizeOption:
+ e.mu.Lock()
+ *o = tcpip.SendBufferSizeOption(e.sndBufSize)
+ e.mu.Unlock()
+ return nil
+
+ case *tcpip.ReceiveBufferSizeOption:
+ e.rcvMu.Lock()
+ *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSizeMax)
+ e.rcvMu.Unlock()
+ return nil
+
+ case *tcpip.ReceiveQueueSizeOption:
+ e.rcvMu.Lock()
+ if e.rcvList.Empty() {
+ *o = 0
+ } else {
+ p := e.rcvList.Front()
+ *o = tcpip.ReceiveQueueSizeOption(p.data.Size())
+ }
+ e.rcvMu.Unlock()
+ return nil
+
+ case *tcpip.TimestampOption:
+ e.rcvMu.Lock()
+ *o = 0
+ if e.rcvTimestamp {
+ *o = 1
+ }
+ e.rcvMu.Unlock()
+ }
+
+ return tcpip.ErrUnknownProtocolOption
+}
+
+func sendPing4(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error {
+ if len(data) < header.ICMPv4EchoMinimumSize {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ // Set the ident. Sequence number is provided by the user.
+ binary.BigEndian.PutUint16(data[header.ICMPv4MinimumSize:], ident)
+
+ hdr := buffer.NewPrependable(header.ICMPv4EchoMinimumSize + int(r.MaxHeaderLength()))
+
+ icmpv4 := header.ICMPv4(hdr.Prepend(header.ICMPv4EchoMinimumSize))
+ copy(icmpv4, data)
+ data = data[header.ICMPv4EchoMinimumSize:]
+
+ // Linux performs these basic checks.
+ if icmpv4.Type() != header.ICMPv4Echo || icmpv4.Code() != 0 {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ icmpv4.SetChecksum(0)
+ icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0)))
+
+ return r.WritePacket(hdr, data.ToVectorisedView(), header.ICMPv4ProtocolNumber, r.DefaultTTL())
+}
+
+func sendPing6(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error {
+ if len(data) < header.ICMPv6EchoMinimumSize {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ // Set the ident. Sequence number is provided by the user.
+ binary.BigEndian.PutUint16(data[header.ICMPv6MinimumSize:], ident)
+
+ hdr := buffer.NewPrependable(header.ICMPv6EchoMinimumSize + int(r.MaxHeaderLength()))
+
+ icmpv6 := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize))
+ copy(icmpv6, data)
+ data = data[header.ICMPv6EchoMinimumSize:]
+
+ if icmpv6.Type() != header.ICMPv6EchoRequest || icmpv6.Code() != 0 {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ icmpv6.SetChecksum(0)
+ icmpv6.SetChecksum(^header.Checksum(icmpv6, header.Checksum(data, 0)))
+
+ return r.WritePacket(hdr, data.ToVectorisedView(), header.ICMPv6ProtocolNumber, r.DefaultTTL())
+}
+
+func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
+ netProto := e.netProto
+ if header.IsV4MappedAddress(addr.Addr) {
+ return 0, tcpip.ErrNoRoute
+ }
+
+ // Fail if we're bound to an address length different from the one we're
+ // checking.
+ if l := len(e.id.LocalAddress); !allowMismatch && l != 0 && l != len(addr.Addr) {
+ return 0, tcpip.ErrInvalidEndpointState
+ }
+
+ return netProto, nil
+}
+
+// Connect connects the endpoint to its peer. Specifying a NIC is optional.
+func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ nicid := addr.NIC
+ localPort := uint16(0)
+ switch e.state {
+ case stateBound, stateConnected:
+ localPort = e.id.LocalPort
+ if e.bindNICID == 0 {
+ break
+ }
+
+ if nicid != 0 && nicid != e.bindNICID {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ nicid = e.bindNICID
+ default:
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ netProto, err := e.checkV4Mapped(&addr, false)
+ if err != nil {
+ return err
+ }
+
+ // Find a route to the desired destination.
+ r, err := e.stack.FindRoute(nicid, e.bindAddr, addr.Addr, netProto)
+ if err != nil {
+ return err
+ }
+ defer r.Release()
+
+ id := stack.TransportEndpointID{
+ LocalAddress: r.LocalAddress,
+ LocalPort: localPort,
+ RemoteAddress: r.RemoteAddress,
+ }
+
+ // Unlike UDP, a ping endpoint is tied to a single network protocol
+ // (ICMPv4 or ICMPv6), so only that protocol is registered with the
+ // stack.
+ netProtos := []tcpip.NetworkProtocolNumber{netProto}
+
+ id, err = e.registerWithStack(nicid, netProtos, id)
+ if err != nil {
+ return err
+ }
+
+ e.id = id
+ e.route = r.Clone()
+ e.regNICID = nicid
+
+ e.state = stateConnected
+
+ e.rcvMu.Lock()
+ e.rcvReady = true
+ e.rcvMu.Unlock()
+
+ return nil
+}
+
+// ConnectEndpoint is not supported.
+func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error {
+ return tcpip.ErrInvalidEndpointState
+}
+
+// Shutdown closes the read and/or write end of the endpoint connection
+// to its peer.
+func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.shutdownFlags |= flags
+
+ if e.state != stateConnected {
+ return tcpip.ErrNotConnected
+ }
+
+ if flags&tcpip.ShutdownRead != 0 {
+ e.rcvMu.Lock()
+ wasClosed := e.rcvClosed
+ e.rcvClosed = true
+ e.rcvMu.Unlock()
+
+ if !wasClosed {
+ e.waiterQueue.Notify(waiter.EventIn)
+ }
+ }
+
+ return nil
+}
+
+// Listen is not supported by ping endpoints; it always fails.
+func (*endpoint) Listen(int) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+// Accept is not supported by ping endpoints; it always fails.
+func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+ return nil, nil, tcpip.ErrNotSupported
+}
+
+func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) {
+ if id.LocalPort != 0 {
+ // The endpoint already has a local port, just attempt to
+ // register it.
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e)
+ return id, err
+ }
+
+ // We need to find a port for the endpoint.
+ _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
+ id.LocalPort = p
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e)
+ switch err {
+ case nil:
+ return true, nil
+ case tcpip.ErrPortInUse:
+ return false, nil
+ default:
+ return false, err
+ }
+ })
+
+ return id, err
+}
+
+func (e *endpoint) bindLocked(addr tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error {
+ // Don't allow binding once endpoint is not in the initial state
+ // anymore.
+ if e.state != stateInitial {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ netProto, err := e.checkV4Mapped(&addr, false)
+ if err != nil {
+ return err
+ }
+
+ // A ping endpoint is bound to a single network protocol; unlike UDP,
+ // netProtos is not expanded to cover both IPv4 and IPv6 when binding
+ // to a wildcard address.
+ netProtos := []tcpip.NetworkProtocolNumber{netProto}
+
+ if len(addr.Addr) != 0 {
+ // A local address was specified, verify that it's valid.
+ if e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) == 0 {
+ return tcpip.ErrBadLocalAddress
+ }
+ }
+
+ id := stack.TransportEndpointID{
+ LocalPort: addr.Port,
+ LocalAddress: addr.Addr,
+ }
+ id, err = e.registerWithStack(addr.NIC, netProtos, id)
+ if err != nil {
+ return err
+ }
+ if commit != nil {
+ if err := commit(); err != nil {
+ // Unregister, the commit failed.
+ e.stack.UnregisterTransportEndpoint(addr.NIC, netProtos, e.transProto, id)
+ return err
+ }
+ }
+
+ e.id = id
+ e.regNICID = addr.NIC
+
+ // Mark endpoint as bound.
+ e.state = stateBound
+
+ e.rcvMu.Lock()
+ e.rcvReady = true
+ e.rcvMu.Unlock()
+
+ return nil
+}
+
+// Bind binds the endpoint to a specific local address and port.
+// Specifying a NIC is optional.
+func (e *endpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ err := e.bindLocked(addr, commit)
+ if err != nil {
+ return err
+ }
+
+ e.bindNICID = addr.NIC
+ e.bindAddr = addr.Addr
+
+ return nil
+}
+
+// GetLocalAddress returns the address to which the endpoint is bound.
+func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ return tcpip.FullAddress{
+ NIC: e.regNICID,
+ Addr: e.id.LocalAddress,
+ Port: e.id.LocalPort,
+ }, nil
+}
+
+// GetRemoteAddress returns the address to which the endpoint is connected.
+func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ if e.state != stateConnected {
+ return tcpip.FullAddress{}, tcpip.ErrNotConnected
+ }
+
+ return tcpip.FullAddress{
+ NIC: e.regNICID,
+ Addr: e.id.RemoteAddress,
+ Port: e.id.RemotePort,
+ }, nil
+}
+
+// Readiness returns the current readiness of the endpoint. For example, if
+// waiter.EventIn is set, the endpoint is immediately readable.
+func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
+ // The endpoint is always writable.
+ result := waiter.EventOut & mask
+
+ // Determine if the endpoint is readable if requested.
+ if (mask & waiter.EventIn) != 0 {
+ e.rcvMu.Lock()
+ if !e.rcvList.Empty() || e.rcvClosed {
+ result |= waiter.EventIn
+ }
+ e.rcvMu.Unlock()
+ }
+
+ return result
+}
+
+// HandlePacket is called by the stack when new packets arrive to this transport
+// endpoint.
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) {
+ e.rcvMu.Lock()
+
+ // Drop the packet if our buffer is currently full.
+ if !e.rcvReady || e.rcvClosed || e.rcvBufSize >= e.rcvBufSizeMax {
+ e.rcvMu.Unlock()
+ return
+ }
+
+ wasEmpty := e.rcvBufSize == 0
+
+ // Push new packet into receive list and increment the buffer size.
+ pkt := &pingPacket{
+ senderAddress: tcpip.FullAddress{
+ NIC: r.NICID(),
+ Addr: id.RemoteAddress,
+ },
+ }
+ pkt.data = vv.Clone(pkt.views[:])
+ e.rcvList.PushBack(pkt)
+ e.rcvBufSize += vv.Size()
+
+ if e.rcvTimestamp {
+ pkt.timestamp = e.stack.NowNanoseconds()
+ pkt.hasTimestamp = true
+ }
+
+ e.rcvMu.Unlock()
+
+ // Notify any waiters that there's data to be read now.
+ if wasEmpty {
+ e.waiterQueue.Notify(waiter.EventIn)
+ }
+}
+
+// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
+func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
+}
diff --git a/tcpip/transport/ping/ping_packet_list.go b/tcpip/transport/ping/ping_packet_list.go
new file mode 100644
index 0000000..e3db86f
--- /dev/null
+++ b/tcpip/transport/ping/ping_packet_list.go
@@ -0,0 +1,154 @@
+package ping
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+type pingPacketList struct {
+ head *pingPacket
+ tail *pingPacket
+}
+
+// Reset resets list l to the empty state.
+func (l *pingPacketList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+func (l *pingPacketList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+func (l *pingPacketList) Front() *pingPacket {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+func (l *pingPacketList) Back() *pingPacket {
+ return l.tail
+}
+
+// PushFront inserts the element e at the front of list l.
+func (l *pingPacketList) PushFront(e *pingPacket) {
+ e.SetNext(l.head)
+ e.SetPrev(nil)
+
+ if l.head != nil {
+ l.head.SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+func (l *pingPacketList) PushBack(e *pingPacket) {
+ e.SetNext(nil)
+ e.SetPrev(l.tail)
+
+ if l.tail != nil {
+ l.tail.SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+func (l *pingPacketList) PushBackList(m *pingPacketList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ l.tail.SetNext(m.head)
+ m.head.SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+func (l *pingPacketList) InsertAfter(b, e *pingPacket) {
+ a := b.Next()
+ e.SetNext(a)
+ e.SetPrev(b)
+ b.SetNext(e)
+
+ if a != nil {
+ a.SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+func (l *pingPacketList) InsertBefore(a, e *pingPacket) {
+ b := a.Prev()
+ e.SetNext(a)
+ e.SetPrev(b)
+ a.SetPrev(e)
+
+ if b != nil {
+ b.SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+func (l *pingPacketList) Remove(e *pingPacket) {
+ prev := e.Prev()
+ next := e.Next()
+
+ if prev != nil {
+ prev.SetNext(next)
+ } else {
+ l.head = next
+ }
+
+ if next != nil {
+ next.SetPrev(prev)
+ } else {
+ l.tail = prev
+ }
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+type pingPacketEntry struct {
+ next *pingPacket
+ prev *pingPacket
+}
+
+// Next returns the entry that follows e in the list.
+func (e *pingPacketEntry) Next() *pingPacket {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+func (e *pingPacketEntry) Prev() *pingPacket {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+func (e *pingPacketEntry) SetNext(entry *pingPacket) {
+ e.next = entry
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+func (e *pingPacketEntry) SetPrev(entry *pingPacket) {
+ e.prev = entry
+}
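The endpoint uses pingPacketList as a FIFO receive queue. HandlePacket calls PushBack, while Read and Close take the Front element and Remove it. Each pingPacket embeds its own next/prev pointers, so queueing a packet allocates nothing beyond the packet itself. Below is a minimal hand-written sketch of the same idea, reduced to those operations. `packet` and `packetList` are hypothetical names. The real file appears to be generated from a generic list template.

```go
package main

import "fmt"

// packet embeds its own link fields, the way pingPacket embeds
// pingPacketEntry, so queueing it needs no separate node allocation.
type packet struct {
	next, prev *packet
	payload    string
}

// packetList is a minimal intrusive doubly linked list.
type packetList struct {
	head, tail *packet
}

func (l *packetList) Empty() bool { return l.head == nil }

func (l *packetList) Front() *packet { return l.head }

// PushBack appends e in O(1) time without allocating.
func (l *packetList) PushBack(e *packet) {
	e.next, e.prev = nil, l.tail
	if l.tail != nil {
		l.tail.next = e
	} else {
		l.head = e
	}
	l.tail = e
}

// Remove unlinks e in O(1) time and clears its link fields.
func (l *packetList) Remove(e *packet) {
	if e.prev != nil {
		e.prev.next = e.next
	} else {
		l.head = e.next
	}
	if e.next != nil {
		e.next.prev = e.prev
	} else {
		l.tail = e.prev
	}
	e.next, e.prev = nil, nil
}

func main() {
	var q packetList
	for _, s := range []string{"a", "b", "c"} {
		q.PushBack(&packet{payload: s})
	}
	// Drain in FIFO order, as endpoint.Read does with rcvList.
	for !q.Empty() {
		p := q.Front()
		q.Remove(p)
		fmt.Println(p.payload)
	}
}
```

The generated Remove leaves the removed element's pointers untouched. This sketch clears them, which makes accidental reuse of a stale link easier to spot.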
diff --git a/tcpip/transport/ping/protocol.go b/tcpip/transport/ping/protocol.go
new file mode 100644
index 0000000..1079c31
--- /dev/null
+++ b/tcpip/transport/ping/protocol.go
@@ -0,0 +1,124 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package ping contains the implementation of the ICMP and IPv6-ICMP transport
+// protocols for use in ping. To use it in the networking stack, this package
+// must be imported by the project, and the protocols must be
+// activated on the stack by passing ping.ProtocolName4 (or "ping4") and/or
+// ping.ProtocolName6 (or "ping6") as one of the transport protocols when
+// calling stack.New(). Then endpoints can be created by passing
+// ping.ProtocolNumber4 or ping.ProtocolNumber6 as the transport protocol number
+// when calling Stack.NewEndpoint().
+package ping
+
+import (
+ "encoding/binary"
+ "fmt"
+
+ "github.com/google/netstack/tcpip"
+ "github.com/google/netstack/tcpip/buffer"
+ "github.com/google/netstack/tcpip/header"
+ "github.com/google/netstack/tcpip/stack"
+ "github.com/google/netstack/waiter"
+)
+
+const (
+ // ProtocolName4 is the string representation of the ICMPv4 ping protocol name.
+ ProtocolName4 = "ping4"
+
+ // ProtocolNumber4 is the ICMP protocol number.
+ ProtocolNumber4 = header.ICMPv4ProtocolNumber
+
+ // ProtocolName6 is the string representation of the ICMPv6 ping protocol name.
+ ProtocolName6 = "ping6"
+
+ // ProtocolNumber6 is the IPv6-ICMP protocol number.
+ ProtocolNumber6 = header.ICMPv6ProtocolNumber
+)
+
+type protocol struct {
+ number tcpip.TransportProtocolNumber
+}
+
+// Number returns the transport protocol number (ICMPv4 or ICMPv6).
+func (p *protocol) Number() tcpip.TransportProtocolNumber {
+ return p.number
+}
+
+func (p *protocol) netProto() tcpip.NetworkProtocolNumber {
+ switch p.number {
+ case ProtocolNumber4:
+ return header.IPv4ProtocolNumber
+ case ProtocolNumber6:
+ return header.IPv6ProtocolNumber
+ }
+ panic(fmt.Sprint("unknown protocol number: ", p.number))
+}
+
+// NewEndpoint creates a new ping endpoint.
+func (p *protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ if netProto != p.netProto() {
+ return nil, tcpip.ErrUnknownProtocol
+ }
+ return newEndpoint(stack, netProto, p.number, waiterQueue), nil
+}
+
+// MinimumPacketSize returns the minimum valid ping packet size.
+func (p *protocol) MinimumPacketSize() int {
+ switch p.number {
+ case ProtocolNumber4:
+ return header.ICMPv4EchoMinimumSize
+ case ProtocolNumber6:
+ return header.ICMPv6EchoMinimumSize
+ }
+ panic(fmt.Sprint("unknown protocol number: ", p.number))
+}
+
+// ParsePorts returns the "ports" of the given ping packet: the source is
+// always 0 and the destination is the echo identifier.
+func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
+ switch p.number {
+ case ProtocolNumber4:
+ return 0, binary.BigEndian.Uint16(v[header.ICMPv4MinimumSize:]), nil
+ case ProtocolNumber6:
+ return 0, binary.BigEndian.Uint16(v[header.ICMPv6MinimumSize:]), nil
+ }
+ panic(fmt.Sprint("unknown protocol number: ", p.number))
+}
+
+// HandleUnknownDestinationPacket handles packets targeted at this protocol but
+// that don't match any existing endpoint.
+func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) bool {
+ return true
+}
+
+// SetOption implements TransportProtocol.SetOption.
+func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+// Option implements TransportProtocol.Option.
+func (p *protocol) Option(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+func init() {
+ stack.RegisterTransportProtocolFactory(ProtocolName4, func() stack.TransportProtocol {
+ return &protocol{ProtocolNumber4}
+ })
+
+ stack.RegisterTransportProtocolFactory(ProtocolName6, func() stack.TransportProtocol {
+ return &protocol{ProtocolNumber6}
+ })
+}
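ICMP has no ports, so ParsePorts reports 0 as the source and the echo identifier as the destination. sendPing4 and sendPing6 write the endpoint's local port into that identifier field. An echo reply that carries the same identifier is therefore demultiplexed back to the endpoint that sent the request, assuming the peer echoes the identifier unchanged as RFC 792 requires. Below is a minimal sketch of that lookup. `parsePorts` and the `endpoints` map are hypothetical. The sketch checks the length itself instead of relying on the caller to have checked MinimumPacketSize first.

```go
package main

import (
	"encoding/binary"
	"errors"
	"fmt"
)

// identOffset is where the echo identifier sits in an ICMP echo message:
// after the 1-byte type, 1-byte code and 2-byte checksum.
const identOffset = 4

// parsePorts mirrors what protocol.ParsePorts returns: no source "port",
// and the echo identifier as the destination "port".
func parsePorts(v []byte) (src, dst uint16, err error) {
	if len(v) < identOffset+2 {
		return 0, 0, errors.New("packet too short for echo identifier")
	}
	return 0, binary.BigEndian.Uint16(v[identOffset:]), nil
}

func main() {
	// Hypothetical registry keyed by the local port each endpoint bound.
	endpoints := map[uint16]string{0x1234: "endpoint A", 0x5678: "endpoint B"}

	reply := []byte{0, 0, 0, 0, 0x56, 0x78, 0, 1} // echo reply, ident 0x5678, seq 1
	_, dst, err := parsePorts(reply)
	if err != nil {
		fmt.Println(err)
		return
	}
	fmt.Println("deliver to", endpoints[dst])
}
```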
diff --git a/tcpip/transport/queue/queue.go b/tcpip/transport/queue/queue.go
index b9cd3d1..5fa8798 100644
--- a/tcpip/transport/queue/queue.go
+++ b/tcpip/transport/queue/queue.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
// Package queue provides the implementation of buffer queue
// and interface of queue entry with Length method.
@@ -23,6 +33,8 @@
}
// Queue is a buffer queue.
+//
+// +stateify savable
type Queue struct {
ReaderQueue *waiter.Queue
WriterQueue *waiter.Queue
@@ -157,6 +169,8 @@
// QueuedSize returns the number of bytes currently in the queue, that is, the
// number of readable bytes.
func (q *Queue) QueuedSize() int64 {
+ q.mu.Lock()
+ defer q.mu.Unlock()
return q.used
}
diff --git a/tcpip/transport/tcp/accept.go b/tcpip/transport/tcp/accept.go
index 28a0b9d..f1fcd13 100644
--- a/tcpip/transport/tcp/accept.go
+++ b/tcpip/transport/tcp/accept.go
@@ -1,18 +1,28 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package tcp
import (
- "crypto/rand"
- "crypto/sha1"
"encoding/binary"
"hash"
"io"
"sync"
"time"
+ "crypto/sha1"
+ "github.com/google/netstack/rand"
"github.com/google/netstack/sleep"
"github.com/google/netstack/tcpip"
"github.com/google/netstack/tcpip/header"
@@ -68,7 +78,8 @@
// to go above a threshold.
var synRcvdCount struct {
sync.Mutex
- value uint64
+ value uint64
+ pending sync.WaitGroup
}
// listenContext is used by a listening endpoint to store state used while
@@ -102,6 +113,7 @@
return false
}
+ synRcvdCount.pending.Add(1)
synRcvdCount.value++
return true
@@ -115,6 +127,7 @@
defer synRcvdCount.Unlock()
synRcvdCount.value--
+ synRcvdCount.pending.Done()
}
// newListenContext creates a new listen context.
@@ -199,6 +212,7 @@
n.rcvBufSize = int(l.rcvWnd)
n.maybeEnableTimestamp(rcvdSynOpts)
+ n.maybeEnableSACKPermitted(rcvdSynOpts)
// Register new endpoint so that packets are routed to it.
if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.id, n); err != nil {
@@ -318,7 +332,7 @@
WS: -1,
}
// When syn cookies are in use we enable timestamp only
- // if the ack specifies the timestmap option assuming
+ // if the ack specifies the timestamp option assuming
// that the other end did in fact negotiate the
// timestamp option in the original SYN.
if s.parsedOptions.TS {
@@ -348,13 +362,17 @@
// to the endpoint.
e.mu.Lock()
e.state = stateClosed
+
+ // Do cleanup if needed.
+ e.completeWorkerLocked()
+
+ if e.drainDone != nil {
+ close(e.drainDone)
+ }
e.mu.Unlock()
// Notify waiters that the endpoint is shutdown.
e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut)
-
- // Do cleanup if needed.
- e.completeWorker()
}()
e.mu.Lock()
@@ -373,6 +391,16 @@
if n&notifyClose != 0 {
return nil
}
+ if n&notifyDrain != 0 {
+ for !e.segmentQueue.empty() {
+ s := e.segmentQueue.dequeue()
+ e.handleListenSegment(ctx, s)
+ s.decRef()
+ }
+ synRcvdCount.pending.Wait()
+ close(e.drainDone)
+ <-e.undrain
+ }
case wakerForNewSegment:
// Process at most maxSegmentsPerWake segments.
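The accept.go hunks above pair the `synRcvdCount` counter with a `sync.WaitGroup`. Each successful increment calls `pending.Add(1)` and each decrement calls `pending.Done()`. A draining listener then calls `pending.Wait()`, so no SYN-RCVD handshake is still in flight when it closes `drainDone`. The following is a minimal, self-contained sketch of that counter-plus-WaitGroup pattern. The `handshakeCounter` type, its `inc`/`dec` methods and the `threshold` field are hypothetical names used for illustration, not netstack APIs.

```go
package main

import (
	"fmt"
	"sync"
)

// handshakeCounter is a hypothetical stand-in for synRcvdCount: a bounded
// counter of in-flight handshakes plus a WaitGroup that lets a drainer wait
// until every counted handshake has finished.
type handshakeCounter struct {
	sync.Mutex
	value     uint64
	threshold uint64
	pending   sync.WaitGroup
}

// inc reserves a slot for a new handshake. It returns false if the
// threshold has been reached, in which case the caller must not call dec.
func (c *handshakeCounter) inc() bool {
	c.Lock()
	defer c.Unlock()
	if c.value >= c.threshold {
		return false
	}
	c.pending.Add(1)
	c.value++
	return true
}

// dec releases a slot previously reserved by a successful inc.
func (c *handshakeCounter) dec() {
	c.Lock()
	defer c.Unlock()
	c.value--
	c.pending.Done()
}

func main() {
	c := &handshakeCounter{threshold: 2}
	var started []bool
	for i := 0; i < 3; i++ {
		started = append(started, c.inc())
	}
	fmt.Println(started) // [true true false]

	// Simulate the two accepted handshakes completing on other goroutines.
	for i := 0; i < 2; i++ {
		go c.dec()
	}

	// A drainer blocks here until every counted handshake has called dec.
	c.pending.Wait()
	c.Lock()
	fmt.Println(c.value) // 0
	c.Unlock()
}
```

Because `pending.Add(1)` only runs when the counter is actually incremented, every `Add` is matched by exactly one `Done` in `dec`, and the WaitGroup can never go negative.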
diff --git a/tcpip/transport/tcp/connect.go b/tcpip/transport/tcp/connect.go
index 3368cbf..f2d3a69 100644
--- a/tcpip/transport/tcp/connect.go
+++ b/tcpip/transport/tcp/connect.go
@@ -1,14 +1,24 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package tcp
import (
- "crypto/rand"
- "sync/atomic"
+ "sync"
"time"
+ "github.com/google/netstack/rand"
"github.com/google/netstack/sleep"
"github.com/google/netstack/tcpip"
"github.com/google/netstack/tcpip/buffer"
@@ -40,6 +50,12 @@
wakerForNotification = iota
wakerForNewSegment
wakerForResend
+ wakerForResolution
+)
+
+const (
+ // Maximum space available for options, in bytes (the 60-byte maximum
+ // TCP header minus the 20-byte minimum header).
+ maxOptionSize = 40
)
// handshake holds the state used during a TCP 3-way handshake.
@@ -150,7 +166,7 @@
// incoming segment acknowledges something not yet sent. The
// connection remains in the same state.
ack := s.sequenceNumber.Add(s.logicalLen())
- h.ep.sendRaw(nil, flagRst|flagAck, s.ackNumber, ack, 0)
+ h.ep.sendRaw(buffer.VectorisedView{}, flagRst|flagAck, s.ackNumber, ack, 0)
return false
}
@@ -182,9 +198,12 @@
// Parse the SYN options.
rcvSynOpts := parseSynSegmentOptions(s)
- // Remember if the Timetstamp option was negotiated.
+ // Remember if the Timestamp option was negotiated.
h.ep.maybeEnableTimestamp(&rcvSynOpts)
+ // Remember if the SACKPermitted option was negotiated.
+ h.ep.maybeEnableSACKPermitted(&rcvSynOpts)
+
// Remember the sequence we'll ack from now on.
h.ackNum = s.sequenceNumber + 1
h.flags |= flagAck
@@ -195,7 +214,7 @@
// and the handshake is completed.
if s.flagIsSet(flagAck) {
h.state = handshakeCompleted
- h.ep.sendRaw(nil, flagAck, h.iss+1, h.ackNum, h.rcvWnd>>h.effectiveRcvWndScale())
+ h.ep.sendRaw(buffer.VectorisedView{}, flagAck, h.iss+1, h.ackNum, h.rcvWnd>>h.effectiveRcvWndScale())
return nil
}
@@ -208,6 +227,11 @@
TS: rcvSynOpts.TS,
TSVal: h.ep.timestamp(),
TSEcr: h.ep.recentTS,
+
+ // We only send SACKPermitted if the other side indicated it
+ // permits SACK. This is not explicitly defined in the RFC but
+ // this is the behaviour implemented by Linux.
+ SACKPermitted: rcvSynOpts.SACKPermitted,
}
sendSynTCP(&s.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
@@ -239,7 +263,7 @@
if s.flagIsSet(flagAck) {
seq = s.ackNumber
}
- h.ep.sendRaw(nil, flagRst|flagAck, seq, ack, 0)
+ h.ep.sendRaw(buffer.VectorisedView{}, flagRst|flagAck, seq, ack, 0)
if !h.active {
return tcpip.ErrInvalidEndpointState
@@ -249,10 +273,11 @@
return err
}
synOpts := header.TCPSynOptions{
- WS: h.rcvWndScale,
- TS: h.ep.sendTSOk,
- TSVal: h.ep.timestamp(),
- TSEcr: h.ep.recentTS,
+ WS: h.rcvWndScale,
+ TS: h.ep.sendTSOk,
+ TSVal: h.ep.timestamp(),
+ TSEcr: h.ep.recentTS,
+ SACKPermitted: h.ep.sackPermitted,
}
sendSynTCP(&s.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
return nil
@@ -266,7 +291,7 @@
// not carry a timestamp option then the segment must be dropped
// as per https://tools.ietf.org/html/rfc7323#section-3.2.
if h.ep.sendTSOk && !s.parsedOptions.TS {
- atomic.AddUint64(&h.ep.stack.MutableStats().DroppedPackets, 1)
+ h.ep.stack.Stats().DroppedPackets.Increment()
return nil
}
@@ -280,6 +305,21 @@
return nil
}
+func (h *handshake) handleSegment(s *segment) *tcpip.Error {
+ h.sndWnd = s.window
+ if !s.flagIsSet(flagSyn) && h.sndWndScale > 0 {
+ h.sndWnd <<= uint8(h.sndWndScale)
+ }
+
+ switch h.state {
+ case handshakeSynRcvd:
+ return h.synRcvdState(s)
+ case handshakeSynSent:
+ return h.synSentState(s)
+ }
+ return nil
+}
+
// processSegments goes through the segment queue and processes up to
// maxSegmentsPerWake (if they're available).
func (h *handshake) processSegments() *tcpip.Error {
@@ -289,18 +329,7 @@
return nil
}
- h.sndWnd = s.window
- if !s.flagIsSet(flagSyn) && h.sndWndScale > 0 {
- h.sndWnd <<= uint8(h.sndWndScale)
- }
-
- var err *tcpip.Error
- switch h.state {
- case handshakeSynRcvd:
- err = h.synRcvdState(s)
- case handshakeSynSent:
- err = h.synSentState(s)
- }
+ err := h.handleSegment(s)
s.decRef()
if err != nil {
return err
@@ -323,8 +352,50 @@
return nil
}
+func (h *handshake) resolveRoute() *tcpip.Error {
+ // Set up the wakers.
+ s := sleep.Sleeper{}
+ resolutionWaker := &sleep.Waker{}
+ s.AddWaker(resolutionWaker, wakerForResolution)
+ s.AddWaker(&h.ep.notificationWaker, wakerForNotification)
+ defer s.Done()
+
+ // The initial action is to resolve the route.
+ index := wakerForResolution
+ for {
+ switch index {
+ case wakerForResolution:
+ if err := h.ep.route.Resolve(resolutionWaker); err != tcpip.ErrWouldBlock {
+ // Either success (err == nil) or failure.
+ return err
+ }
+ // Resolution not completed. Keep trying...
+
+ case wakerForNotification:
+ n := h.ep.fetchNotifications()
+ if n&notifyClose != 0 {
+ h.ep.route.RemoveWaker(resolutionWaker)
+ return tcpip.ErrAborted
+ }
+ if n&notifyDrain != 0 {
+ close(h.ep.drainDone)
+ <-h.ep.undrain
+ }
+ }
+
+ // Wait for notification.
+ index, _ = s.Fetch(true)
+ }
+}
+
// execute executes the TCP 3-way handshake.
func (h *handshake) execute() *tcpip.Error {
+ if h.ep.route.IsResolutionRequired() {
+ if err := h.resolveRoute(); err != nil {
+ return err
+ }
+ }
+
// Initialize the resend timer.
resendWaker := sleep.Waker{}
timeOut := time.Duration(time.Second)
@@ -340,19 +411,29 @@
s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment)
defer s.Done()
+ var sackEnabled SACKEnabled
+ if err := h.ep.stack.TransportProtocolOption(ProtocolNumber, &sackEnabled); err != nil {
+ // If the stack returns an error when checking the SACKEnabled
+ // status, default to disabling SACK negotiation.
+ sackEnabled = false
+ }
+
// Send the initial SYN segment and loop until the handshake is
// completed.
synOpts := header.TCPSynOptions{
- WS: h.rcvWndScale,
- TS: true,
- TSVal: h.ep.timestamp(),
- TSEcr: h.ep.recentTS,
+ WS: h.rcvWndScale,
+ TS: true,
+ TSVal: h.ep.timestamp(),
+ TSEcr: h.ep.recentTS,
+ SACKPermitted: bool(sackEnabled),
}
// Execute is also called in a listen context so we want to make sure we
- // only send the TS option when we received the TS in the initial SYN.
+ // only send the TS/SACK option when we received the TS/SACK in the
+ // initial SYN.
if h.state == handshakeSynRcvd {
synOpts.TS = h.ep.sendTSOk
+ synOpts.SACKPermitted = h.ep.sackPermitted && bool(sackEnabled)
}
sendSynTCP(&h.ep.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
for h.state != handshakeCompleted {
@@ -370,6 +451,21 @@
if n&notifyClose != 0 {
return tcpip.ErrAborted
}
+ if n&notifyDrain != 0 {
+ for !h.ep.segmentQueue.empty() {
+ s := h.ep.segmentQueue.dequeue()
+ err := h.handleSegment(s)
+ s.decRef()
+ if err != nil {
+ return err
+ }
+ if h.state == handshakeCompleted {
+ return nil
+ }
+ }
+ close(h.ep.drainDone)
+ <-h.ep.undrain
+ }
case wakerForNewSegment:
if err := h.processSegments(); err != nil {
@@ -390,35 +486,91 @@
return synOpts
}
-func sendSynTCP(r *stack.Route, id stack.TransportEndpointID, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts header.TCPSynOptions) *tcpip.Error {
- // The MSS in opts is ignored as this function is called from many
- // places and we don't want every call point being embedded with the MSS
- // calculation. So we just do it here and ignore the MSS value passed in
- // the opts.
- mss := r.MTU() - header.TCPMinimumSize
- options := []byte{
- // Initialize the MSS option.
- header.TCPOptionMSS, 4, byte(mss >> 8), byte(mss),
- }
-
- if opts.TS {
- tsOpt := header.EncodeTSOption(opts.TSVal, opts.TSEcr)
- options = append(options, tsOpt[:]...)
- }
-
- // NOTE: a WS of zero is a valid value and it indicates a scale of 1.
- if opts.WS >= 0 {
- // Initialize the WS option.
- options = append(options,
- header.TCPOptionWS, 3, uint8(opts.WS), header.TCPOptionNOP)
- }
-
- return sendTCPWithOptions(r, id, nil, r.DefaultTTL(), flags, seq, ack, rcvWnd, options)
+var optionPool = sync.Pool{
+ New: func() interface{} {
+ return make([]byte, maxOptionSize)
+ },
}
-// sendTCPWithOptions sends a TCP segment with the provided options via the
-// provided network endpoint and under the provided identity.
-func sendTCPWithOptions(r *stack.Route, id stack.TransportEndpointID, data buffer.View, ttl uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte) *tcpip.Error {
+func getOptions() []byte {
+ return optionPool.Get().([]byte)
+}
+
+func putOptions(options []byte) {
+ // Reslice to full capacity.
+ optionPool.Put(options[0:cap(options)])
+}
+
+func makeSynOptions(opts header.TCPSynOptions) []byte {
+ // Emulate linux option order. This is as follows:
+ //
+ // if md5: NOP NOP MD5SIG 18 md5sig(16)
+ // if mss: MSS 4 mss(2)
+ // if ts and sack_advertise:
+ // SACK 2 TIMESTAMP 2 timestamp(8)
+ // elif ts: NOP NOP TIMESTAMP 10 timestamp(8)
+ // elif sack: NOP NOP SACK 2
+ // if wscale: NOP WINDOW 3 ws(1)
+ // if sack_blocks: NOP NOP SACK (2 + (#blocks * 8))
+ // [for each block] start_seq(4) end_seq(4)
+ // if fastopen_cookie:
+ // if exp: EXP (4 + len(cookie)) FASTOPEN_MAGIC(2)
+ // else: FASTOPEN (2 + len(cookie))
+ // cookie(variable) [padding to four bytes]
+ //
+ options := getOptions()
+
+ // Always encode the mss.
+ offset := header.EncodeMSSOption(uint32(opts.MSS), options)
+
+ // Special ordering is required here. If both TS and SACK are enabled,
+ // then the SACKPermitted option immediately precedes TS, with no
+ // padding. If only one of them is enabled, it is preceded by two NOPs.
+ if opts.TS && opts.SACKPermitted {
+ offset += header.EncodeSACKPermittedOption(options[offset:])
+ offset += header.EncodeTSOption(opts.TSVal, opts.TSEcr, options[offset:])
+ } else if opts.TS {
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeTSOption(opts.TSVal, opts.TSEcr, options[offset:])
+ } else if opts.SACKPermitted {
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeSACKPermittedOption(options[offset:])
+ }
+
+ // Initialize the WS option.
+ if opts.WS >= 0 {
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeWSOption(opts.WS, options[offset:])
+ }
+
+ // Pad to the end; note that this never applies unless we add a
+ // fastopen option, so we always expect the offset to remain the same.
+ if delta := header.AddTCPOptionPadding(options, offset); delta != 0 {
+ panic("unexpected option encoding")
+ }
+
+ return options[:offset]
+}
+
+func sendSynTCP(r *stack.Route, id stack.TransportEndpointID, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts header.TCPSynOptions) *tcpip.Error {
+ // If opts.MSS is zero, it is computed here from the route MTU, since
+ // this function is called from many places and we don't want every
+ // call site to embed the MSS calculation.
+ if opts.MSS == 0 {
+ opts.MSS = uint16(r.MTU() - header.TCPMinimumSize)
+ }
+
+ options := makeSynOptions(opts)
+ err := sendTCP(r, id, buffer.VectorisedView{}, r.DefaultTTL(), flags, seq, ack, rcvWnd, options)
+ putOptions(options)
+ return err
+}
+
+// sendTCP sends a TCP segment with the provided options via the provided
+// network endpoint and under the provided identity.
+func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte) *tcpip.Error {
optLen := len(opts)
// Allocate a buffer for the TCP header.
hdr := buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen)
@@ -440,64 +592,33 @@
})
copy(tcp[header.TCPMinimumSize:], opts)
- length := uint16(hdr.UsedLength())
- xsum := r.PseudoHeaderChecksum(ProtocolNumber)
- if data != nil {
- length += uint16(len(data))
- xsum = header.Checksum(data, xsum)
+ // Only calculate the checksum if offloading isn't supported.
+ if r.Capabilities()&stack.CapabilityChecksumOffload == 0 {
+ length := uint16(hdr.UsedLength() + data.Size())
+ xsum := r.PseudoHeaderChecksum(ProtocolNumber)
+ for _, v := range data.Views() {
+ xsum = header.Checksum(v, xsum)
+ }
+
+ tcp.SetChecksum(^tcp.CalculateChecksum(xsum, length))
}
- tcp.SetChecksum(^tcp.CalculateChecksum(xsum, length))
-
- atomic.AddUint64(&r.MutableStats().TCP.SegmentsSent, 1)
+ r.Stats().TCP.SegmentsSent.Increment()
if (flags & flagRst) != 0 {
- atomic.AddUint64(&r.MutableStats().TCP.ResetsSent, 1)
+ r.Stats().TCP.ResetsSent.Increment()
}
- return r.WritePacket(&hdr, data, ProtocolNumber, ttl)
+ return r.WritePacket(hdr, data, ProtocolNumber, ttl)
}
-// sendTCP sends a TCP segment via the provided network endpoint and under the
-// provided identity.
-func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.View, ttl uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error {
- // Allocate a buffer for the TCP header.
- hdr := buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()))
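+// makeOptions builds the options for a non-SYN segment: the timestamp option
+// (if timestamps were negotiated) followed by any SACK blocks (if SACK is
+// permitted). The returned slice comes from optionPool and should be
+// released with putOptions.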
+func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte {
+ options := getOptions()
+ offset := 0
- if rcvWnd > 0xffff {
- rcvWnd = 0xffff
- }
-
- // Initialize the header.
- tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize))
- tcp.Encode(&header.TCPFields{
- SrcPort: id.LocalPort,
- DstPort: id.RemotePort,
- SeqNum: uint32(seq),
- AckNum: uint32(ack),
- DataOffset: header.TCPMinimumSize,
- Flags: flags,
- WindowSize: uint16(rcvWnd),
- })
-
- length := uint16(hdr.UsedLength())
- xsum := r.PseudoHeaderChecksum(ProtocolNumber)
- if data != nil {
- length += uint16(len(data))
- xsum = header.Checksum(data, xsum)
- }
-
- tcp.SetChecksum(^tcp.CalculateChecksum(xsum, length))
-
- atomic.AddUint64(&r.MutableStats().TCP.SegmentsSent, 1)
- if (flags & flagRst) != 0 {
- atomic.AddUint64(&r.MutableStats().TCP.ResetsSent, 1)
- }
-
- return r.WritePacket(&hdr, data, ProtocolNumber, ttl)
-}
-
-// sendRaw sends a TCP segment to the endpoint's peer.
-func (e *endpoint) sendRaw(data buffer.View, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error {
+ // N.B. the ordering here matches the ordering used by Linux internally
+ // and described in the makeSynOptions function. Cases that are only
+ // needed before the connection is established are omitted here.
if e.sendTSOk {
// Embed the timestamp if timestamp has been enabled.
//
@@ -511,13 +632,37 @@
// timestamp clock.
//
// Ref: https://tools.ietf.org/html/rfc7323#section-5.4.
- options := header.EncodeTSOption(e.timestamp(), uint32(e.recentTS))
- return sendTCPWithOptions(&e.route, e.id, data, e.route.DefaultTTL(), flags, seq, ack, rcvWnd, options[:])
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeTSOption(e.timestamp(), uint32(e.recentTS), options[offset:])
}
- return sendTCP(&e.route, e.id, data, e.route.DefaultTTL(), flags, seq, ack, rcvWnd)
+ if e.sackPermitted && len(sackBlocks) > 0 {
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeSACKBlocks(sackBlocks, options[offset:])
+ }
+
+ // We expect the above to produce an aligned offset.
+ if delta := header.AddTCPOptionPadding(options, offset); delta != 0 {
+ panic("unexpected option encoding")
+ }
+
+ return options[:offset]
}
-func (e *endpoint) handleWrite() bool {
+// sendRaw sends a TCP segment to the endpoint's peer.
+func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error {
+ var sackBlocks []header.SACKBlock
+ if e.state == stateConnected && e.rcv.pendingBufSize > 0 && (flags&flagAck != 0) {
+ sackBlocks = e.sack.Blocks[:e.sack.NumBlocks]
+ }
+ options := e.makeOptions(sackBlocks)
+ err := sendTCP(&e.route, e.id, data, e.route.DefaultTTL(), flags, seq, ack, rcvWnd, options)
+ putOptions(options)
+ return err
+}
+
+func (e *endpoint) handleWrite() *tcpip.Error {
// Move packets from send queue to send list. The queue is accessible
// from other goroutines and protected by the send mutex, while the send
// list is only accessible from the handler goroutine, so it needs no
@@ -541,47 +686,42 @@
// Push out any new packets.
e.snd.sendData()
- return true
+ return nil
}
-func (e *endpoint) handleClose() bool {
+func (e *endpoint) handleClose() *tcpip.Error {
// Drain the send queue.
e.handleWrite()
// Mark send side as closed.
e.snd.closed = true
- return true
+ return nil
}
-// resetConnection sends a RST segment and puts the endpoint in an error state
-// with the given error code.
-// This method must only be called from the protocol goroutine.
-func (e *endpoint) resetConnection(err *tcpip.Error) {
- e.sendRaw(nil, flagAck|flagRst, e.snd.sndUna, e.rcv.rcvNxt, 0)
+// resetConnectionLocked sends a RST segment and puts the endpoint in an error
+// state with the given error code. This method must only be called from the
+// protocol goroutine.
+func (e *endpoint) resetConnectionLocked(err *tcpip.Error) {
+ e.sendRaw(buffer.VectorisedView{}, flagAck|flagRst, e.snd.sndUna, e.rcv.rcvNxt, 0)
- e.mu.Lock()
e.state = stateError
e.hardError = err
- e.mu.Unlock()
}
-// completeWorker is called by the worker goroutine when it's about to exit. It
-// marks the worker as completed and performs cleanup work if requested by
-// Close().
-func (e *endpoint) completeWorker() {
- e.mu.Lock()
- defer e.mu.Unlock()
-
+// completeWorkerLocked is called by the worker goroutine when it's about to
+// exit. It marks the worker as completed and performs cleanup work if requested
+// by Close().
+func (e *endpoint) completeWorkerLocked() {
e.workerRunning = false
if e.workerCleanup {
- e.cleanup()
+ e.cleanupLocked()
}
}
// handleSegments pulls segments from the queue and processes them. It returns
-// true if the protocol loop should continue, false otherwise.
-func (e *endpoint) handleSegments() bool {
+// nil if the protocol loop should continue, and an error otherwise.
+func (e *endpoint) handleSegments() *tcpip.Error {
checkRequeue := true
for i := 0; i < maxSegmentsPerWake; i++ {
s := e.segmentQueue.dequeue()
@@ -590,6 +730,11 @@
break
}
+ // Invoke the TCP probe if installed.
+ if e.probe != nil {
+ e.probe(e.completeState())
+ }
+
if s.flagIsSet(flagRst) {
if e.rcv.acceptable(s.sequenceNumber, 0) {
// RFC 793, page 37 states that "in all states
@@ -597,11 +742,7 @@
// validated by checking their SEQ-fields." So
// we only process it if it's acceptable.
s.decRef()
- e.mu.Lock()
- e.state = stateError
- e.hardError = tcpip.ErrConnectionReset
- e.mu.Unlock()
- return false
+ return tcpip.ErrConnectionReset
}
} else if s.flagIsSet(flagAck) {
// Patch the window size in the segment according to the
@@ -613,7 +754,7 @@
// must be dropped as per
// https://tools.ietf.org/html/rfc7323#section-3.2.
if e.sendTSOk && !s.parsedOptions.TS {
- atomic.AddUint64(&e.stack.MutableStats().DroppedPackets, 1)
+ e.stack.Stats().DroppedPackets.Increment()
s.decRef()
continue
}
@@ -640,66 +781,63 @@
e.resetKeepaliveTimer(true)
- return true
+ return nil
}
-// keepaliveTimerExpired is called when the keepaliveTimer fires. We send
-// TCP keepalive packets periodically when the connection is idle. If we
-// don't hear from the other side after a number of tries, we terminate
-// the connection.
-func (e *endpoint) keepaliveTimerExpired() bool {
- e.keepaliveMu.Lock()
- if !e.keepaliveEnabled || !e.keepaliveTimer.checkExpiration() {
- e.keepaliveMu.Unlock()
- return true
+// keepaliveTimerExpired is called when the keepaliveTimer fires. We send TCP
+// keepalive packets periodically when the connection is idle. If we don't hear
+// from the other side after a number of tries, we terminate the connection.
+func (e *endpoint) keepaliveTimerExpired() *tcpip.Error {
+ e.keepalive.Lock()
+ if !e.keepalive.enabled || !e.keepalive.timer.checkExpiration() {
+ e.keepalive.Unlock()
+ return nil
}
- if e.keepalivesUnacked >= e.keepaliveCount {
- e.keepaliveMu.Unlock()
- e.resetConnection(tcpip.ErrConnectionReset)
- return false
+ if e.keepalive.unacked >= e.keepalive.count {
+ e.keepalive.Unlock()
+ return tcpip.ErrConnectionReset
}
// RFC1122 4.2.3.6: TCP keepalive is a dataless ACK with
// seg.seq = snd.nxt-1.
- e.keepalivesUnacked++
- e.keepaliveMu.Unlock()
- e.snd.sendSegment(nil, flagAck, e.snd.sndNxt-1)
+ e.keepalive.unacked++
+ e.keepalive.Unlock()
+ e.snd.sendSegment(buffer.VectorisedView{}, flagAck, e.snd.sndNxt-1)
e.resetKeepaliveTimer(false)
- return true
+ return nil
}
// resetKeepaliveTimer restarts or stops the keepalive timer, depending on
// whether it is enabled for this endpoint.
func (e *endpoint) resetKeepaliveTimer(receivedData bool) {
- e.keepaliveMu.Lock()
- defer e.keepaliveMu.Unlock()
+ e.keepalive.Lock()
+ defer e.keepalive.Unlock()
if receivedData {
- e.keepalivesUnacked = 0
+ e.keepalive.unacked = 0
}
// Start the keepalive timer IFF it's enabled and there is no pending
// data to send.
- if e.keepaliveEnabled && e.snd != nil && e.snd.sndUna == e.snd.sndNxt {
- if e.keepalivesUnacked > 0 {
- e.keepaliveTimer.enable(e.keepaliveInterval)
- } else {
- e.keepaliveTimer.enable(e.keepaliveIdle)
- }
+ if !e.keepalive.enabled || e.snd == nil || e.snd.sndUna != e.snd.sndNxt {
+ e.keepalive.timer.disable()
+ return
+ }
+ if e.keepalive.unacked > 0 {
+ e.keepalive.timer.enable(e.keepalive.interval)
} else {
- e.keepaliveTimer.disable()
+ e.keepalive.timer.enable(e.keepalive.idle)
}
}
// protocolMainLoop is the main loop of the TCP protocol. It runs in its own
// goroutine and is responsible for sending segments and handling received
// segments.
-func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error {
+func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
var closeTimer *time.Timer
var closeWaker sleep.Waker
- defer func() {
- e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut)
- e.completeWorker()
+ epilogue := func() {
+ // e.mu is expected to be held upon entering this section.
if e.snd != nil {
e.snd.resendTimer.cleanup()
@@ -708,9 +846,20 @@
if closeTimer != nil {
closeTimer.Stop()
}
- }()
- if !passive {
+ e.completeWorkerLocked()
+
+ if e.drainDone != nil {
+ close(e.drainDone)
+ }
+
+ e.mu.Unlock()
+
+ // When the protocol loop exits we should wake up our waiters.
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+ }
+
+ if handshake {
// This is an active connection, so we must initiate the 3-way
// handshake, and then inform potential waiters about its
// completion.
@@ -726,7 +875,8 @@
e.mu.Lock()
e.state = stateError
e.hardError = err
- e.mu.Unlock()
+ // Lock released below.
+ epilogue()
return err
}
@@ -741,14 +891,18 @@
e.rcvListMu.Unlock()
}
- // Initialize the keepalive timer.
- e.keepaliveTimer.init(&e.keepaliveWaker)
- defer e.keepaliveTimer.cleanup()
+ e.keepalive.timer.init(&e.keepalive.waker)
+ defer e.keepalive.timer.cleanup()
// Tell waiters that the endpoint is connected and writable.
e.mu.Lock()
e.state = stateConnected
+ drained := e.drainDone != nil
e.mu.Unlock()
+ if drained {
+ close(e.drainDone)
+ <-e.undrain
+ }
e.waiterQueue.Notify(waiter.EventOut)
@@ -756,7 +910,7 @@
// wakes up.
funcs := []struct {
w *sleep.Waker
- f func() bool
+ f func() *tcpip.Error
}{
{
w: &e.sndWaker,
@@ -772,28 +926,26 @@
},
{
w: &closeWaker,
- f: func() bool {
- e.resetConnection(tcpip.ErrConnectionAborted)
- return false
+ f: func() *tcpip.Error {
+ return tcpip.ErrConnectionAborted
},
},
{
w: &e.snd.resendWaker,
- f: func() bool {
+ f: func() *tcpip.Error {
if !e.snd.retransmitTimerExpired() {
- e.resetConnection(tcpip.ErrTimeout)
- return false
+ return tcpip.ErrTimeout
}
- return true
+ return nil
},
},
{
- w: &e.keepaliveWaker,
+ w: &e.keepalive.waker,
f: e.keepaliveTimerExpired,
},
{
w: &e.notificationWaker,
- f: func() bool {
+ f: func() *tcpip.Error {
n := e.fetchNotifications()
if n&notifyNonZeroReceiveWindow != 0 {
e.rcv.nonZeroWindow()
@@ -803,19 +955,44 @@
e.rcv.pendingBufSize = seqnum.Size(e.receiveBufferSize())
}
+ if n&notifyMTUChanged != 0 {
+ e.sndBufMu.Lock()
+ count := e.packetTooBigCount
+ e.packetTooBigCount = 0
+ mtu := e.sndMTU
+ e.sndBufMu.Unlock()
+
+ e.snd.updateMaxPayloadSize(mtu, count)
+ }
+
+ if n&notifyReset != 0 {
+ e.mu.Lock()
+ e.resetConnectionLocked(tcpip.ErrConnectionAborted)
+ e.mu.Unlock()
+ }
if n&notifyClose != 0 && closeTimer == nil {
- // Reset the connection 60 seconds after the
+ // Reset the connection 3 seconds after the
// endpoint has been closed.
- closeTimer = time.AfterFunc(60*time.Second, func() {
+ closeTimer = time.AfterFunc(3*time.Second, func() {
closeWaker.Assert()
})
}
+ if n&notifyDrain != 0 {
+ for !e.segmentQueue.empty() {
+ if err := e.handleSegments(); err != nil {
+ return err
+ }
+ }
+ close(e.drainDone)
+ <-e.undrain
+ }
+
if n&notifyKeepaliveChanged != 0 {
e.resetKeepaliveTimer(true)
}
- return true
+ return nil
},
},
}
@@ -826,21 +1003,50 @@
s.AddWaker(funcs[i].w, i)
}
+ // The following assertions and notifications are needed for restored
+ // endpoints. Freshly created endpoints have empty state and should
+ // not trigger any of them.
+ e.segmentQueue.mu.Lock()
+ if !e.segmentQueue.list.Empty() {
+ e.newSegmentWaker.Assert()
+ }
+ e.segmentQueue.mu.Unlock()
+
+ e.rcvListMu.Lock()
+ if !e.rcvList.Empty() {
+ e.waiterQueue.Notify(waiter.EventIn)
+ }
+ e.rcvListMu.Unlock()
+
+ e.mu.RLock()
+ if e.workerCleanup {
+ e.notifyProtocolGoroutine(notifyClose)
+ }
+ e.mu.RUnlock()
+
// Main loop. Handle segments until both send and receive ends of the
// connection have completed.
for !e.rcv.closed || !e.snd.closed || e.snd.sndUna != e.snd.sndNxtList {
e.workMu.Unlock()
v, _ := s.Fetch(true)
e.workMu.Lock()
- if !funcs[v].f() {
+ if err := funcs[v].f(); err != nil {
+ e.mu.Lock()
+ e.resetConnectionLocked(err)
+ // Lock released below.
+ epilogue()
+
return nil
}
}
// Mark endpoint as closed.
e.mu.Lock()
- e.state = stateClosed
- e.mu.Unlock()
+ if e.state != stateError {
+ e.state = stateClosed
+ }
+ // Lock released below.
+ epilogue()
return nil
}
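makeSynOptions emulates the order in which Linux lays out SYN options and relies on that order always producing a length that is a multiple of four bytes (hence the panic if AddTCPOptionPadding reports a non-zero delta). The sketch below reproduces that layout using raw option kinds from RFC 793 (NOP, MSS), RFC 2018 (SACK-permitted) and RFC 7323 (window scale, timestamps) instead of netstack's `header.Encode*` helpers. The `synOptions` function and its parameters are hypothetical and only illustrate the byte layout; they are not part of netstack.

```go
package main

import (
	"encoding/binary"
	"fmt"
)

// TCP option kinds (RFC 793, RFC 2018, RFC 7323).
const (
	optNOP           = 1
	optMSS           = 2
	optWS            = 3
	optSACKPermitted = 4
	optTS            = 8
)

// synOptions is a hypothetical re-implementation of the layout that
// makeSynOptions emulates: MSS first, then SACK-permitted and/or timestamps
// (with NOP padding when only one of them is present), then NOP + window
// scale. A negative ws disables window scaling.
func synOptions(mss uint16, ts, sack bool, tsVal, tsEcr uint32, ws int) []byte {
	b := []byte{optMSS, 4, 0, 0}
	binary.BigEndian.PutUint16(b[2:], mss)

	// tsOpt encodes the 10-byte timestamp option.
	tsOpt := func() []byte {
		o := make([]byte, 10)
		o[0], o[1] = optTS, 10
		binary.BigEndian.PutUint32(o[2:], tsVal)
		binary.BigEndian.PutUint32(o[6:], tsEcr)
		return o
	}

	switch {
	case ts && sack:
		// SACK-permitted (2 bytes) takes the place of the NOP NOP padding.
		b = append(b, optSACKPermitted, 2)
		b = append(b, tsOpt()...)
	case ts:
		b = append(b, optNOP, optNOP)
		b = append(b, tsOpt()...)
	case sack:
		b = append(b, optNOP, optNOP, optSACKPermitted, 2)
	}

	if ws >= 0 {
		b = append(b, optNOP, optWS, 3, byte(ws))
	}
	return b
}

func main() {
	full := synOptions(1460, true, true, 1, 0, 7)
	fmt.Println(len(full), len(full)%4 == 0) // 20 true
	fmt.Printf("% x\n", full[:6])            // 02 04 05 b4 04 02
}
```

Every combination (MSS alone, with TS and/or SACK-permitted, with or without window scaling) comes out at a multiple of four bytes between 4 and 20, well under the 40-byte option limit, which is why no trailing padding is ever expected for SYN segments.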
diff --git a/tcpip/transport/tcp/cubic.go b/tcpip/transport/tcp/cubic.go
new file mode 100644
index 0000000..8cea416
--- /dev/null
+++ b/tcpip/transport/tcp/cubic.go
@@ -0,0 +1,233 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "math"
+ "time"
+)
+
+// cubicState stores the variables related to TCP CUBIC congestion
+// control algorithm state.
+//
+// See: https://tools.ietf.org/html/rfc8312.
+type cubicState struct {
+ // wLastMax is the previous wMax value.
+ wLastMax float64
+
+ // wMax is the value of the congestion window at the
+ // time of last congestion event.
+ wMax float64
+
+ // t denotes the time when the current congestion avoidance
+ // was entered.
+ t time.Time
+
+ // numCongestionEvents tracks the number of congestion events since last
+ // RTO.
+ numCongestionEvents int
+
+ // c is the cubic constant as specified in RFC8312. It's fixed at 0.4, as
+ // recommended by the RFC.
+ c float64
+
+ // k is the time period that the window function W_cubic(t) (see wC
+ // below) takes to increase the current window size to W_max if there
+ // are no further congestion events. It is calculated using the
+ // following equation:
+ //
+ // K = cubic_root(W_max*(1-beta_cubic)/C) (Eq. 2)
+ k float64
+
+ // beta is the CUBIC multiplicative decrease factor. That is, when a
+ // congestion event is detected, CUBIC reduces its cwnd to
+ // W_cubic(0)=W_max*beta_cubic.
+ beta float64
+
+ // wC is the window computed by CUBIC at time t. It's calculated using the
+ // formula:
+ //
+ // W_cubic(t) = C*(t-K)^3 + W_max (Eq. 1)
+ wC float64
+
+ // wEst is the window computed by CUBIC at time t+RTT, i.e.
+ // W_cubic(t+RTT).
+ wEst float64
+
+ s *sender
+}
+
+// newCubicCC returns a partially initialized cubic state with the constants
+// beta and c set and t set to current time.
+func newCubicCC(s *sender) *cubicState {
+ return &cubicState{
+ t: time.Now(),
+ beta: 0.7,
+ c: 0.4,
+ s: s,
+ }
+}
+
+// enterCongestionAvoidance is used to initialize cubic in cases where we exit
+// SlowStart without a real congestion event taking place. This can happen when
+// a connection goes back to slow start due to a retransmit and we exceed the
+// previously lowered ssThresh without experiencing packet loss.
+//
+// Refer: https://tools.ietf.org/html/rfc8312#section-4.8
+func (c *cubicState) enterCongestionAvoidance() {
+ // See: https://tools.ietf.org/html/rfc8312#section-4.7 &
+ // https://tools.ietf.org/html/rfc8312#section-4.8
+ if c.numCongestionEvents == 0 {
+ c.k = 0
+ c.t = time.Now()
+ c.wLastMax = c.wMax
+ c.wMax = float64(c.s.sndCwnd)
+ }
+}
+
+// updateSlowStart updates the congestion window as per the slow-start
+// algorithm used by NewReno. If, after adjusting the congestion window, we
+// cross ssThresh, it returns the number of packets that must be consumed in
+// congestion avoidance mode; otherwise it returns 0.
+func (c *cubicState) updateSlowStart(packetsAcked int) int {
+ // Don't let the congestion window cross into the congestion
+ // avoidance range.
+ newcwnd := c.s.sndCwnd + packetsAcked
+ enterCA := false
+ if newcwnd >= c.s.sndSsthresh {
+ newcwnd = c.s.sndSsthresh
+ c.s.sndCAAckCount = 0
+ enterCA = true
+ }
+
+ packetsAcked -= newcwnd - c.s.sndCwnd
+ c.s.sndCwnd = newcwnd
+ if enterCA {
+ c.enterCongestionAvoidance()
+ }
+ return packetsAcked
+}
+
+// Update updates cubic's internal state variables. It must be called on every
+// ACK received.
+// Refer: https://tools.ietf.org/html/rfc8312#section-4
+func (c *cubicState) Update(packetsAcked int) {
+ if c.s.sndCwnd < c.s.sndSsthresh {
+ packetsAcked = c.updateSlowStart(packetsAcked)
+ if packetsAcked == 0 {
+ return
+ }
+ } else {
+ c.s.rtt.Lock()
+ srtt := c.s.rtt.srtt
+ c.s.rtt.Unlock()
+ c.s.sndCwnd = c.getCwnd(packetsAcked, c.s.sndCwnd, srtt)
+ }
+}
+
+// cubicCwnd computes the CUBIC congestion window C*t^3 + W_max (Eq. 1), where
+// t is the time in seconds since the last congestion event minus K; callers
+// pass the already offset value.
+func (c *cubicState) cubicCwnd(t float64) float64 {
+ return c.c*math.Pow(t, 3.0) + c.wMax
+}
+
+// getCwnd returns the current congestion window as computed by CUBIC.
+// Refer: https://tools.ietf.org/html/rfc8312#section-4
+func (c *cubicState) getCwnd(packetsAcked, sndCwnd int, srtt time.Duration) int {
+ elapsed := time.Since(c.t).Seconds()
+
+ // Compute the window as per Cubic after 'elapsed' time
+ // since last congestion event.
+ c.wC = c.cubicCwnd(elapsed - c.k)
+
+ // Compute the TCP friendly estimate of the congestion window.
+ c.wEst = c.wMax*c.beta + (3.0*((1.0-c.beta)/(1.0+c.beta)))*(elapsed/srtt.Seconds())
+
+ // Make sure in the TCP friendly region CUBIC performs at least
+ // as well as Reno.
+ if c.wC < c.wEst && float64(sndCwnd) < c.wEst {
+ // TCP Friendly region of cubic.
+ return int(c.wEst)
+ }
+
+ // In Concave/Convex region of CUBIC, calculate what CUBIC window
+ // will be after 1 RTT and use that to grow congestion window
+ // for every ack.
+ tEst := (time.Since(c.t) + srtt).Seconds()
+ wtRtt := c.cubicCwnd(tEst - c.k)
+ // As per section 4.3, for each received ACK cwnd must be incremented
+ // by (W_cubic(t+RTT) - cwnd)/cwnd.
+ cwnd := float64(sndCwnd)
+ for i := 0; i < packetsAcked; i++ {
+ // Concave/Convex regions of cubic have the same formulas.
+ // See: https://tools.ietf.org/html/rfc8312#section-4.3
+ cwnd += (wtRtt - cwnd) / cwnd
+ }
+ return int(cwnd)
+}
+
+// HandleNDupAcks implements congestionControl.HandleNDupAcks.
+func (c *cubicState) HandleNDupAcks() {
+ // See: https://tools.ietf.org/html/rfc8312#section-4.5
+ c.numCongestionEvents++
+ c.t = time.Now()
+ c.wLastMax = c.wMax
+ c.wMax = float64(c.s.sndCwnd)
+
+ c.fastConvergence()
+ c.reduceSlowStartThreshold()
+}
+
+// HandleRTOExpired implements congestionControl.HandleRTOExpired.
+func (c *cubicState) HandleRTOExpired() {
+ // See: https://tools.ietf.org/html/rfc8312#section-4.6
+ c.t = time.Now()
+ c.numCongestionEvents = 0
+ c.wLastMax = c.wMax
+ c.wMax = float64(c.s.sndCwnd)
+
+ c.fastConvergence()
+
+ // We lost a packet, so reduce ssthresh.
+ c.reduceSlowStartThreshold()
+
+ // Reduce the congestion window to 1, i.e., enter slow-start. Per
+ // RFC 5681, page 7, we must use 1 regardless of the value of the
+ // initial congestion window.
+ c.s.sndCwnd = 1
+}
+
+// fastConvergence implements the logic for Fast Convergence algorithm as
+// described in https://tools.ietf.org/html/rfc8312#section-4.6.
+func (c *cubicState) fastConvergence() {
+ if c.wMax < c.wLastMax {
+ c.wLastMax = c.wMax
+ c.wMax = c.wMax * (1.0 + c.beta) / 2.0
+ } else {
+ c.wLastMax = c.wMax
+ }
+ // Recompute k as wMax may have changed.
+ c.k = math.Cbrt(c.wMax * (1 - c.beta) / c.c)
+}
+
+// PostRecovery implements congestionControl.PostRecovery.
+func (c *cubicState) PostRecovery() {
+ c.t = time.Now()
+}
+
+// reduceSlowStartThreshold sets sndSsthresh to max(sndCwnd*beta, 2) as
+// described in https://tools.ietf.org/html/rfc8312#section-4.7.
+func (c *cubicState) reduceSlowStartThreshold() {
+ c.s.sndSsthresh = int(math.Max(float64(c.s.sndCwnd)*c.beta, 2.0))
+}
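The CUBIC arithmetic in cubic.go is easier to check in isolation. The following is a minimal standalone sketch, not netstack code: the helper names `cubicK`, `wCubic`, `wEst` and `approxEqual` and the constant names are hypothetical, the constant values are the RFC 8312 ones that `newCubicCC` sets (C = 0.4, beta_cubic = 0.7), and all sender state, locking and slow-start handling are omitted. It evaluates Eq. 2 (K), Eq. 1 (W_cubic) and the TCP-friendly estimate used in `getCwnd`, and illustrates that W_cubic(0) = W_max*beta_cubic and W_cubic(K) = W_max, as the field comments describe.

```go
package main

import (
	"fmt"
	"math"
)

// Hypothetical constant names; the values are the RFC 8312 ones that
// newCubicCC uses.
const (
	cubicC    = 0.4 // scaling constant C
	cubicBeta = 0.7 // multiplicative decrease factor beta_cubic
)

// cubicK returns K = cbrt(W_max*(1-beta_cubic)/C) (Eq. 2): the number of
// seconds W_cubic needs to grow back to W_max after a congestion event.
func cubicK(wMax float64) float64 {
	return math.Cbrt(wMax * (1 - cubicBeta) / cubicC)
}

// wCubic returns W_cubic(t) = C*(t-K)^3 + W_max (Eq. 1), where t is the
// number of seconds since the last congestion event.
func wCubic(t, k, wMax float64) float64 {
	return cubicC*math.Pow(t-k, 3) + wMax
}

// wEst returns the TCP-friendly estimate
// W_est(t) = W_max*beta_cubic + 3*(1-beta_cubic)/(1+beta_cubic) * t/RTT,
// with t and rtt in seconds.
func wEst(t, rtt, wMax float64) float64 {
	return wMax*cubicBeta + 3*(1-cubicBeta)/(1+cubicBeta)*(t/rtt)
}

// approxEqual compares floats with a small absolute tolerance.
func approxEqual(a, b float64) bool {
	return math.Abs(a-b) < 1e-9
}

func main() {
	const wMax = 100.0 // window (in packets) at the last congestion event
	k := cubicK(wMax)
	fmt.Printf("K = %.3f s\n", k)
	// Right after the event the window restarts at W_max*beta_cubic.
	fmt.Printf("W_cubic(0) = %.1f\n", wCubic(0, k, wMax))
	// After K seconds it is back at W_max.
	fmt.Printf("W_cubic(K) = %.1f\n", wCubic(k, k, wMax))
	// Reno-like estimate after 1 s with a 100 ms RTT.
	fmt.Printf("W_est(1s) = %.3f\n", wEst(1, 0.1, wMax))
}
```

Because the derivative of W_cubic is zero at t = K, growth slows as the window approaches the previous W_max and accelerates again beyond it; this is the concave/convex behavior that `getCwnd` relies on when it projects the window one RTT ahead.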
diff --git a/tcpip/transport/tcp/dual_stack_test.go b/tcpip/transport/tcp/dual_stack_test.go
index b512d19..89bccae 100644
--- a/tcpip/transport/tcp/dual_stack_test.go
+++ b/tcpip/transport/tcp/dual_stack_test.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package tcp_test
diff --git a/tcpip/transport/tcp/endpoint.go b/tcpip/transport/tcp/endpoint.go
index d9ff04d..069c7e0 100644
--- a/tcpip/transport/tcp/endpoint.go
+++ b/tcpip/transport/tcp/endpoint.go
@@ -1,16 +1,26 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package tcp
import (
- "crypto/rand"
"math"
"sync"
"sync/atomic"
"time"
+ "github.com/google/netstack/rand"
"github.com/google/netstack/sleep"
"github.com/google/netstack/tcpip"
"github.com/google/netstack/tcpip/buffer"
@@ -38,17 +48,32 @@
notifyNonZeroReceiveWindow = 1 << iota
notifyReceiveWindowChanged
notifyClose
+ notifyMTUChanged
+ notifyDrain
+ notifyReset
notifyKeepaliveChanged
)
-// DefaultBufferSize is the default size of the receive and send buffers.
-const DefaultBufferSize = 208 * 1024
+// SACKInfo holds TCP SACK related information for a given endpoint.
+//
+// +stateify savable
+type SACKInfo struct {
+ // Blocks is the array of SACK blocks tracked for this endpoint; it
+ // holds at most MaxSACKBlocks entries.
+ Blocks [MaxSACKBlocks]header.SACKBlock
+
+ // NumBlocks is the number of valid SACK blocks stored in the
+ // Blocks array above.
+ NumBlocks int
+}
// endpoint represents a TCP endpoint. This struct serves as the interface
// between users of the endpoint and the protocol implementation; it is legal to
// have concurrent goroutines make calls into the endpoint, they are properly
// synchronized. The protocol implementation, however, runs in a single
// goroutine.
+//
+// +stateify savable
type endpoint struct {
// workMu is used to arbitrate which goroutine may perform protocol
// work. Only the main protocol goroutine is expected to call Lock() on
@@ -71,8 +96,10 @@
// protocol goroutine adds ready-for-delivery segments to rcvList,
// which are returned by Read() calls to users.
//
- // Once the peer has closed the its send side, rcvClosed is set to true
+ // Once the peer has closed its send side, rcvClosed is set to true
// to indicate to users that no more data is coming.
+ //
+ // rcvListMu can be taken after the endpoint mu below.
rcvListMu sync.Mutex
rcvList segmentList
rcvClosed bool
@@ -80,14 +107,15 @@
rcvBufUsed int
// The following fields are protected by the mutex.
- mu sync.RWMutex
- id stack.TransportEndpointID
- state endpointState
- isPortReserved bool
- isRegistered bool
- boundNICID tcpip.NICID
- route stack.Route
- v6only bool
+ mu sync.RWMutex
+ id stack.TransportEndpointID
+ state endpointState
+ isPortReserved bool
+ isRegistered bool
+ boundNICID tcpip.NICID
+ route stack.Route
+ v6only bool
+ isConnectNotified bool
// effectiveNetProtos contains the network protocols actually in use. In
// most cases it will only contain "netProto", but in cases like IPv6
@@ -99,7 +127,7 @@
// hardError is meaningful only when state is stateError, it stores the
// error to be returned when read/write syscalls are called and the
- // endpoint is in this state.
+ // endpoint is in this state. hardError is protected by mu.
hardError *tcpip.Error
// workerRunning specifies if a worker goroutine is running.
@@ -124,6 +152,16 @@
// TSVal field in the timestamp option.
tsOffset uint32
+ // shutdownFlags represents the current shutdown state of the endpoint.
+ shutdownFlags tcpip.ShutdownFlags
+
+ // sackPermitted is set to true if the peer sends the TCPSACKPermitted
+ // option in the SYN/SYN-ACK.
+ sackPermitted bool
+
+ // sack holds TCP SACK related information for this endpoint.
+ sack SACKInfo
+
// The options below aren't implemented, but we remember the user
// settings because applications expect to be able to set/query these
// options.
@@ -140,15 +178,28 @@
// protocol goroutine is signaled via sndWaker.
//
// When the send side is closed, the protocol goroutine is notified via
- // sndCloseWaker, and sndBufSize is set to -1.
+ // sndCloseWaker, and sndClosed is set to true.
sndBufMu sync.Mutex
sndBufSize int
sndBufUsed int
+ sndClosed bool
sndBufInQueue seqnum.Size
sndQueue segmentList
sndWaker sleep.Waker
sndCloseWaker sleep.Waker
+ // cc stores the name of the congestion control algorithm to use for
+ // this endpoint.
+ cc CongestionControlOption
+
+ // The following are used when a "packet too big" control packet is
+ // received. They are protected by sndBufMu. They are used to
+ // communicate to the main protocol goroutine how many such control
+ // messages have been received since the last notification was processed
+ // and what the smallest MTU seen was.
+ packetTooBigCount int
+ sndMTU int
+
// newSegmentWaker is used to indicate to the protocol goroutine that
// it needs to wake up and handle new segments queued to it.
newSegmentWaker sleep.Waker
@@ -161,18 +212,11 @@
// goroutine what it was notified; this is only accessed atomically.
notifyFlags uint32
- // The following fields manage TCP keepalive state. When the connection
- // is idle (no data sent or received) for keepaliveIdle, we start
- // sending keepalives every keepaliveInterval. If we send
- // keepaliveCount without hearing a response, the connection is closed.
- keepaliveMu sync.Mutex
- keepaliveEnabled bool
- keepaliveIdle time.Duration
- keepaliveInterval time.Duration
- keepaliveCount int
- keepalivesUnacked int
- keepaliveTimer timer
- keepaliveWaker sleep.Waker
+ // keepalive manages TCP keepalive state. When the connection is idle
+ // (no data sent or received) for keepalive.idle, we start sending
+ // keepalives every keepalive.interval. If we send keepalive.count
+ // without hearing a response, the connection is closed.
+ keepalive keepalive
// acceptedChan is used by a listening endpoint protocol goroutine to
// send newly accepted connections to the endpoint so that they can be
@@ -183,6 +227,35 @@
// therefore don't need locks to protect them.
rcv *receiver
snd *sender
+
+ // The goroutine drain completion notification channel.
+ drainDone chan struct{}
+
+ // The goroutine undrain notification channel.
+ undrain chan struct{}
+
+ // probe, if not nil, is invoked on every received segment. It is passed
+ // a copy of the current state of the endpoint.
+ probe stack.TCPProbeFunc
+
+ // The following are only used to assist the restore run to re-connect.
+ bindAddress tcpip.Address
+ connectingAddress tcpip.Address
+}
+
+// keepalive is a synchronization wrapper used to appease stateify. See the
+// comment in endpoint, where it is used.
+//
+// +stateify savable
+type keepalive struct {
+ sync.Mutex
+ enabled bool
+ idle time.Duration
+ interval time.Duration
+ count int
+ unacked int
+ timer timer
+ waker sleep.Waker
}
func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
@@ -192,13 +265,36 @@
waiterQueue: waiterQueue,
rcvBufSize: DefaultBufferSize,
sndBufSize: DefaultBufferSize,
- noDelay: true,
+ sndMTU: int(math.MaxInt32),
+ noDelay: false,
reuseAddr: true,
- // Linux defaults.
- keepaliveIdle: 2 * time.Hour,
- keepaliveInterval: 75 * time.Second,
- keepaliveCount: 9,
+ keepalive: keepalive{
+ // Linux defaults.
+ idle: 2 * time.Hour,
+ interval: 75 * time.Second,
+ count: 9,
+ },
}
+
+ var ss SendBufferSizeOption
+ if err := stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
+ e.sndBufSize = ss.Default
+ }
+
+ var rs ReceiveBufferSizeOption
+ if err := stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
+ e.rcvBufSize = rs.Default
+ }
+
+ var cs CongestionControlOption
+ if err := stack.TransportProtocolOption(ProtocolNumber, &cs); err == nil {
+ e.cc = cs
+ }
+
+ if p := stack.GetTCPProbe(); p != nil {
+ e.probe = p
+ }
+
e.segmentQueue.setLimit(2 * e.rcvBufSize)
e.workMu.Init()
e.workMu.Lock()
@@ -234,7 +330,7 @@
// Determine if the endpoint is writable if requested.
if (mask & waiter.EventOut) != 0 {
e.sndBufMu.Lock()
- if e.sndBufSize < 0 || e.sndBufUsed <= e.sndBufSize {
+ if e.sndClosed || e.sndBufUsed < e.sndBufSize {
result |= waiter.EventOut
}
e.sndBufMu.Unlock()
@@ -285,13 +381,7 @@
// if we're connected, or stop accepting if we're listening.
e.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead)
- // While we hold the lock, determine if the cleanup should happen
- // inline or if we should tell the worker (if any) to do the cleanup.
e.mu.Lock()
- worker := e.workerRunning
- if worker {
- e.workerCleanup = true
- }
// We always release ports inline so that they are immediately available
// for reuse after Close() is called. If also registered, it means this
@@ -307,58 +397,70 @@
}
}
- e.mu.Unlock()
-
- // Now that we don't hold the lock anymore, either perform the local
- // cleanup or kick the worker to make sure it knows it needs to cleanup.
- if !worker {
- e.cleanup()
+ // Either perform the local cleanup or kick the worker to make sure it
+ // knows it needs to cleanup.
+ tcpip.AddDanglingEndpoint(e)
+ if !e.workerRunning {
+ e.cleanupLocked()
} else {
+ e.workerCleanup = true
e.notifyProtocolGoroutine(notifyClose)
}
+
+ e.mu.Unlock()
}
-// cleanup frees all resources associated with the endpoint. It is called after
-// Close() is called and the worker goroutine (if any) is done with its work.
-func (e *endpoint) cleanup() {
+// cleanupLocked frees all resources associated with the endpoint. It is called
+// after Close() is called and the worker goroutine (if any) is done with its
+// work.
+func (e *endpoint) cleanupLocked() {
// Close all endpoints that might have been accepted by TCP but not by
// the client.
if e.acceptedChan != nil {
close(e.acceptedChan)
for n := range e.acceptedChan {
- n.resetConnection(tcpip.ErrConnectionAborted)
+ n.mu.Lock()
+ n.resetConnectionLocked(tcpip.ErrConnectionAborted)
+ n.mu.Unlock()
n.Close()
}
+ e.acceptedChan = nil
}
+ e.workerCleanup = false
if e.isRegistered {
e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id)
}
e.route.Release()
+ tcpip.DeleteDanglingEndpoint(e)
}
// Read reads data from the endpoint.
-func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, *tcpip.Error) {
+func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
e.mu.RLock()
-
// The endpoint can be read if it's connected, or if it's already closed
- // but has some pending unread data.
- if s := e.state; s != stateConnected && s != stateClosed {
+ // but has some pending unread data. Also note that a RST being received
+ // would cause the state to become stateError so we should allow the
+ // reads to proceed before returning an ECONNRESET.
+ e.rcvListMu.Lock()
+ bufUsed := e.rcvBufUsed
+ if s := e.state; s != stateConnected && s != stateClosed && bufUsed == 0 {
+ e.rcvListMu.Unlock()
+ he := e.hardError
e.mu.RUnlock()
if s == stateError {
- return buffer.View{}, e.hardError
+ return buffer.View{}, tcpip.ControlMessages{}, he
}
- return buffer.View{}, tcpip.ErrInvalidEndpointState
+ return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState
}
- e.rcvListMu.Lock()
v, err := e.readLocked()
e.rcvListMu.Unlock()
e.mu.RUnlock()
- return v, err
+ return v, tcpip.ControlMessages{}, err
}
func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
@@ -390,9 +492,10 @@
}
// Write writes data to the endpoint's peer.
-func (e *endpoint) Write(v buffer.View, to *tcpip.FullAddress) (uintptr, *tcpip.Error) {
+func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, *tcpip.Error) {
// Linux completely ignores any address passed to sendto(2) for TCP sockets
- // (without the MSG_FASTOPEN flag).
+ // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More
+ // and opts.EndOfRecord are also ignored.
e.mu.RLock()
defer e.mu.RUnlock()
@@ -408,31 +511,41 @@
}
// Nothing to do if the buffer is empty.
- if len(v) == 0 {
+ if p.Size() == 0 {
return 0, nil
}
- s := newSegmentFromView(&e.route, e.id, v)
-
e.sndBufMu.Lock()
// Check if the connection has already been closed for sends.
- if e.sndBufSize < 0 {
+ if e.sndClosed {
e.sndBufMu.Unlock()
- s.decRef()
return 0, tcpip.ErrClosedForSend
}
- // Check if we're already over the limit.
- if e.sndBufUsed > e.sndBufSize {
+ // Check against the limit.
+ avail := e.sndBufSize - e.sndBufUsed
+ if avail <= 0 {
e.sndBufMu.Unlock()
- s.decRef()
return 0, tcpip.ErrWouldBlock
}
+ v, perr := p.Get(avail)
+ if perr != nil {
+ e.sndBufMu.Unlock()
+ return 0, perr
+ }
+
+ var err *tcpip.Error
+ if p.Size() > avail {
+ err = tcpip.ErrWouldBlock
+ }
+ l := len(v)
+ s := newSegmentFromView(&e.route, e.id, v)
+
// Add data to the send queue.
- e.sndBufUsed += len(v)
- e.sndBufInQueue += seqnum.Size(len(v))
+ e.sndBufUsed += l
+ e.sndBufInQueue += seqnum.Size(l)
e.sndQueue.PushBack(s)
e.sndBufMu.Unlock()
@@ -445,14 +558,13 @@
// Let the protocol goroutine do the work.
e.sndWaker.Assert()
}
-
- return uintptr(len(v)), nil
+ return uintptr(l), err
}
// Peek reads data without consuming it from the endpoint.
//
// This method does not block if there is no data pending.
-func (e *endpoint) Peek(vec [][]byte) (uintptr, *tcpip.Error) {
+func (e *endpoint) Peek(vec [][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
@@ -460,9 +572,9 @@
// but has some pending unread data.
if s := e.state; s != stateConnected && s != stateClosed {
if s == stateError {
- return 0, e.hardError
+ return 0, tcpip.ControlMessages{}, e.hardError
}
- return 0, tcpip.ErrInvalidEndpointState
+ return 0, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState
}
e.rcvListMu.Lock()
@@ -470,9 +582,9 @@
if e.rcvBufUsed == 0 {
if e.rcvClosed || e.state != stateConnected {
- return 0, tcpip.ErrClosedForReceive
+ return 0, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive
}
- return 0, tcpip.ErrWouldBlock
+ return 0, tcpip.ControlMessages{}, tcpip.ErrWouldBlock
}
// Make a copy of vec so we can modify the slide headers.
@@ -488,7 +600,7 @@
for len(v) > 0 {
if len(vec) == 0 {
- return num, nil
+ return num, tcpip.ControlMessages{}, nil
}
if len(vec[0]) == 0 {
vec = vec[1:]
@@ -503,7 +615,7 @@
}
}
- return num, nil
+ return num, tcpip.ControlMessages{}, nil
}
// zeroReceiveWindow checks if the receive window to be announced now would be
@@ -534,6 +646,19 @@
return nil
case tcpip.ReceiveBufferSizeOption:
+ // Make sure the receive buffer size is within the min and max
+ // allowed.
+ var rs ReceiveBufferSizeOption
+ size := int(v)
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
+ if size < rs.Min {
+ size = rs.Min
+ }
+ if size > rs.Max {
+ size = rs.Max
+ }
+ }
+
mask := uint32(notifyReceiveWindowChanged)
e.rcvListMu.Lock()
@@ -544,27 +669,47 @@
if e.rcv != nil {
scale = e.rcv.rcvWndScale
}
- if v>>scale == 0 {
- v = 1 << scale
+ if size>>scale == 0 {
+ size = 1 << scale
}
- // Make sure 2*v doesn't overflow.
- if int(v) > math.MaxInt32/2 {
- v = math.MaxInt32 / 2
+ // Make sure 2*size doesn't overflow.
+ if size > math.MaxInt32/2 {
+ size = math.MaxInt32 / 2
}
wasZero := e.zeroReceiveWindow(scale)
- e.rcvBufSize = int(v)
+ e.rcvBufSize = size
if wasZero && !e.zeroReceiveWindow(scale) {
mask |= notifyNonZeroReceiveWindow
}
e.rcvListMu.Unlock()
- e.segmentQueue.setLimit(2 * int(v))
+ e.segmentQueue.setLimit(2 * size)
e.notifyProtocolGoroutine(mask)
return nil
+ case tcpip.SendBufferSizeOption:
+ // Make sure the send buffer size is within the min and max
+ // allowed.
+ size := int(v)
+ var ss SendBufferSizeOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
+ if size < ss.Min {
+ size = ss.Min
+ }
+ if size > ss.Max {
+ size = ss.Max
+ }
+ }
+
+ e.sndBufMu.Lock()
+ e.sndBufSize = size
+ e.sndBufMu.Unlock()
+
+ return nil
+
case tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
if e.netProto != header.IPv6ProtocolNumber {
@@ -582,27 +727,27 @@
e.v6only = v != 0
case tcpip.KeepaliveEnabledOption:
- e.keepaliveMu.Lock()
- e.keepaliveEnabled = v != 0
- e.keepaliveMu.Unlock()
+ e.keepalive.Lock()
+ e.keepalive.enabled = v != 0
+ e.keepalive.Unlock()
e.notifyProtocolGoroutine(notifyKeepaliveChanged)
case tcpip.KeepaliveIdleOption:
- e.keepaliveMu.Lock()
- e.keepaliveIdle = time.Duration(v)
- e.keepaliveMu.Unlock()
+ e.keepalive.Lock()
+ e.keepalive.idle = time.Duration(v)
+ e.keepalive.Unlock()
e.notifyProtocolGoroutine(notifyKeepaliveChanged)
case tcpip.KeepaliveIntervalOption:
- e.keepaliveMu.Lock()
- e.keepaliveInterval = time.Duration(v)
- e.keepaliveMu.Unlock()
+ e.keepalive.Lock()
+ e.keepalive.interval = time.Duration(v)
+ e.keepalive.Unlock()
e.notifyProtocolGoroutine(notifyKeepaliveChanged)
case tcpip.KeepaliveCountOption:
- e.keepaliveMu.Lock()
- e.keepaliveCount = int(v)
- e.keepaliveMu.Unlock()
+ e.keepalive.Lock()
+ e.keepalive.count = int(v)
+ e.keepalive.Unlock()
e.notifyProtocolGoroutine(notifyKeepaliveChanged)
}
@@ -644,7 +789,7 @@
case *tcpip.ReceiveBufferSizeOption:
e.rcvListMu.Lock()
- *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSize * 2)
+ *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSize)
e.rcvListMu.Unlock()
return nil
@@ -695,23 +840,24 @@
}
return nil
- case *tcpip.InfoOption:
+ case *tcpip.TCPInfoOption:
+ *o = tcpip.TCPInfoOption{}
e.mu.RLock()
- if e.snd != nil {
- o.Rtt = e.snd.srtt
- o.Rttvar = e.snd.rttvar
- } else {
- o.Rtt = 0
- o.Rttvar = 0
- }
+ snd := e.snd
e.mu.RUnlock()
+ if snd != nil {
+ snd.rtt.Lock()
+ o.RTT = snd.rtt.srtt
+ o.RTTVar = snd.rtt.rttvar
+ snd.rtt.Unlock()
+ }
return nil
case *tcpip.KeepaliveEnabledOption:
- e.mu.RLock()
- v := e.keepaliveEnabled
- e.mu.RUnlock()
+ e.keepalive.Lock()
+ v := e.keepalive.enabled
+ e.keepalive.Unlock()
*o = 0
if v {
@@ -719,19 +865,19 @@
}
case *tcpip.KeepaliveIdleOption:
- e.keepaliveMu.Lock()
- *o = tcpip.KeepaliveIdleOption(e.keepaliveIdle)
- e.keepaliveMu.Unlock()
+ e.keepalive.Lock()
+ *o = tcpip.KeepaliveIdleOption(e.keepalive.idle)
+ e.keepalive.Unlock()
case *tcpip.KeepaliveIntervalOption:
- e.keepaliveMu.Lock()
- *o = tcpip.KeepaliveIntervalOption(e.keepaliveInterval)
- e.keepaliveMu.Unlock()
+ e.keepalive.Lock()
+ *o = tcpip.KeepaliveIntervalOption(e.keepalive.interval)
+ e.keepalive.Unlock()
case *tcpip.KeepaliveCountOption:
- e.keepaliveMu.Lock()
- *o = tcpip.KeepaliveCountOption(e.keepaliveCount)
- e.keepaliveMu.Unlock()
+ e.keepalive.Lock()
+ *o = tcpip.KeepaliveCountOption(e.keepalive.count)
+ e.keepalive.Unlock()
}
@@ -755,7 +901,7 @@
// Fail if we're bound to an address length different from the one we're
// checking.
- if l := len(e.id.LocalAddress); l != 0 && l != len(addr.Addr) {
+ if l := len(e.id.LocalAddress); l != 0 && len(addr.Addr) != 0 && l != len(addr.Addr) {
return 0, tcpip.ErrInvalidEndpointState
}
@@ -763,15 +909,27 @@
}
// Connect connects the endpoint to its peer.
-func (e *endpoint) Connect(addr tcpip.FullAddress) (err *tcpip.Error) {
+func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+ return e.connect(addr, true, true)
+}
+
+// connect connects the endpoint to its peer. In the normal (non-save/restore)
+// case, the new connection is expected to run the main goroutine and perform
+// the handshake. When restoring previously connected endpoints, both ends are
+// created passively (so no new handshake is performed); stack-accepted
+// connections not yet accepted by the application are restored without
+// running the main goroutine here.
+func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (err *tcpip.Error) {
e.mu.Lock()
defer e.mu.Unlock()
defer func() {
- if !err.IgnoreStats {
- atomic.AddUint64(&e.stack.MutableStats().TCP.FailedConnectionAttempts, 1)
+ if err != nil && !err.IgnoreStats() {
+ e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
}
}()
+ connectingAddr := addr.Addr
+
netProto, err := e.checkV4Mapped(&addr)
if err != nil {
return err
@@ -802,9 +960,17 @@
return tcpip.ErrAlreadyConnecting
case stateConnected:
- // The endpoint is already connected.
+ // The endpoint is already connected. If the caller hasn't been notified yet, return success.
+ if !e.isConnectNotified {
+ e.isConnectNotified = true
+ return nil
+ }
+ // Otherwise return that it's already connected.
return tcpip.ErrAlreadyConnected
+ case stateError:
+ return e.hardError
+
default:
return tcpip.ErrInvalidEndpointState
}
@@ -820,7 +986,7 @@
netProtos := []tcpip.NetworkProtocolNumber{netProto}
e.id.LocalAddress = r.LocalAddress
- e.id.RemoteAddress = addr.Addr
+ e.id.RemoteAddress = r.RemoteAddress
e.id.RemotePort = addr.Port
if e.id.LocalPort != 0 {
@@ -831,28 +997,30 @@
}
} else {
// The endpoint doesn't have a local port yet, so try to get
- // one.
- _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
- e.id.LocalPort = p
- // If the local address and the remote address are the same address, the local port
- // has to be different from the remote port.
- if e.id.LocalAddress == e.id.RemoteAddress && e.id.LocalPort == e.id.RemotePort {
+ // one. Make sure that it isn't one that will result in the same
+ // address/port for both local and remote (otherwise this
+ // endpoint would be trying to connect to itself).
+ sameAddr := e.id.LocalAddress == e.id.RemoteAddress
+ if _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
+ if sameAddr && p == e.id.RemotePort {
return false, nil
}
- if e.stack.IsPortReserved(netProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort) {
+ if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.id.LocalAddress, p) {
return false, nil
}
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.id, e)
- switch err {
+
+ id := e.id
+ id.LocalPort = p
+ switch err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e); err {
case nil:
+ e.id = id
return true, nil
case tcpip.ErrPortInUse:
return false, nil
default:
return false, err
}
- })
- if err != nil {
+ }); err != nil {
return err
}
}
@@ -870,10 +1038,29 @@
e.route = r.Clone()
e.boundNICID = nicid
e.effectiveNetProtos = netProtos
- e.workerRunning = true
+ e.connectingAddress = connectingAddr
- atomic.AddUint64(&e.stack.MutableStats().TCP.ActiveConnectionOpenings, 1)
- go e.protocolMainLoop(false)
+ // Connect in the restore phase does not perform a handshake. Restore its
+ // connection settings here.
+ if !handshake {
+ e.segmentQueue.mu.Lock()
+ for _, l := range []segmentList{e.segmentQueue.list, e.sndQueue, e.snd.writeList} {
+ for s := l.Front(); s != nil; s = s.Next() {
+ s.id = e.id
+ s.route = r.Clone()
+ e.sndWaker.Assert()
+ }
+ }
+ e.segmentQueue.mu.Unlock()
+ e.snd.updateMaxPayloadSize(int(e.route.MTU()), 0)
+ e.state = stateConnected
+ }
+
+ if run {
+ e.workerRunning = true
+ e.stack.Stats().TCP.ActiveConnectionOpenings.Increment()
+ go e.protocolMainLoop(handshake)
+ }
return tcpip.ErrConnectStarted
}
@@ -888,14 +1075,28 @@
func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
+ e.shutdownFlags |= flags
switch e.state {
case stateConnected:
// Close for write.
- if (flags & tcpip.ShutdownWrite) != 0 {
+ if (e.shutdownFlags & tcpip.ShutdownWrite) != 0 {
+ if (e.shutdownFlags & tcpip.ShutdownRead) != 0 {
+ // We're fully closed; if we have unread data, we need to abort
+ // the connection with a RST.
+ e.rcvListMu.Lock()
+ rcvBufUsed := e.rcvBufUsed
+ e.rcvListMu.Unlock()
+
+ if rcvBufUsed > 0 {
+ e.notifyProtocolGoroutine(notifyReset)
+ return nil
+ }
+ }
+
e.sndBufMu.Lock()
- if e.sndBufSize < 0 {
+ if e.sndClosed {
// Already closed.
e.sndBufMu.Unlock()
break
@@ -907,7 +1108,7 @@
e.sndBufInQueue++
// Mark endpoint as closed.
- e.sndBufSize = -1
+ e.sndClosed = true
e.sndBufMu.Unlock()
@@ -922,7 +1123,7 @@
}
default:
- return tcpip.ErrInvalidEndpointState
+ return tcpip.ErrNotConnected
}
return nil
@@ -934,8 +1135,8 @@
e.mu.Lock()
defer e.mu.Unlock()
defer func() {
- if err != nil && !err.IgnoreStats {
- atomic.AddUint64(&e.stack.MutableStats().TCP.FailedConnectionAttempts, 1)
+ if err != nil && !err.IgnoreStats() {
+ e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
}
}()
@@ -949,6 +1150,9 @@
if len(e.acceptedChan) > backlog {
return tcpip.ErrInvalidEndpointState
}
+ if cap(e.acceptedChan) == backlog {
+ return nil
+ }
origChan := e.acceptedChan
e.acceptedChan = make(chan *endpoint, backlog)
close(origChan)
@@ -970,11 +1174,14 @@
e.isRegistered = true
e.state = stateListen
- e.acceptedChan = make(chan *endpoint, backlog)
+ if e.acceptedChan == nil {
+ e.acceptedChan = make(chan *endpoint, backlog)
+ }
e.workerRunning = true
- atomic.AddUint64(&e.stack.MutableStats().TCP.PassiveConnectionOpenings, 1)
- go e.protocolListenLoop(seqnum.Size(e.receiveBufferAvailable()))
+ e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
+ go e.protocolListenLoop(
+ seqnum.Size(e.receiveBufferAvailable()))
return nil
}
@@ -984,7 +1191,7 @@
func (e *endpoint) startAcceptedLoop(waiterQueue *waiter.Queue) {
e.waiterQueue = waiterQueue
e.workerRunning = true
- go e.protocolMainLoop(true)
+ go e.protocolMainLoop(false)
}
// Accept returns a new endpoint if a peer has established a connection
@@ -1014,7 +1221,7 @@
}
// Bind binds the endpoint to a specific local port and optionally address.
-func (e *endpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) (retErr *tcpip.Error) {
+func (e *endpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) (err *tcpip.Error) {
e.mu.Lock()
defer e.mu.Unlock()
@@ -1025,6 +1232,7 @@
return tcpip.ErrAlreadyBound
}
+ e.bindAddress = addr.Addr
netProto, err := e.checkV4Mapped(&addr)
if err != nil {
return err
@@ -1041,7 +1249,6 @@
}
}
- // Reserve the port.
port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port)
if err != nil {
return err
@@ -1053,7 +1260,7 @@
// Any failures beyond this point must remove the port registration.
defer func() {
- if retErr != nil {
+ if err != nil {
e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port)
e.isPortReserved = false
e.effectiveNetProtos = nil
@@ -1066,7 +1273,7 @@
// If an address is specified, we must ensure that it's one of our
// local addresses.
if len(addr.Addr) != 0 {
- nic := e.stack.CheckLocalAddress(addr.NIC, addr.Addr)
+ nic := e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr)
if nic == 0 {
return tcpip.ErrBadLocalAddress
}
@@ -1119,18 +1326,18 @@
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv *buffer.VectorisedView) {
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) {
s := newSegment(r, id, vv)
if !s.parse() {
- atomic.AddUint64(&e.stack.MutableStats().MalformedRcvdPackets, 1)
- atomic.AddUint64(&e.stack.MutableStats().TCP.InvalidSegmentsReceived, 1)
+ e.stack.Stats().MalformedRcvdPackets.Increment()
+ e.stack.Stats().TCP.InvalidSegmentsReceived.Increment()
s.decRef()
return
}
- atomic.AddUint64(&e.stack.MutableStats().TCP.ValidSegmentsReceived, 1)
+ e.stack.Stats().TCP.ValidSegmentsReceived.Increment()
if (s.flags & flagRst) != 0 {
- atomic.AddUint64(&e.stack.MutableStats().TCP.ResetsReceived, 1)
+ e.stack.Stats().TCP.ResetsReceived.Increment()
}
// Send packet to worker goroutine.
@@ -1138,18 +1345,36 @@
e.newSegmentWaker.Assert()
} else {
// The queue is full, so we drop the segment.
- atomic.AddUint64(&e.stack.MutableStats().DroppedPackets, 1)
+ e.stack.Stats().DroppedPackets.Increment()
s.decRef()
}
}
+// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
+func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
+ switch typ {
+ case stack.ControlPacketTooBig:
+ e.sndBufMu.Lock()
+ e.packetTooBigCount++
+ if v := int(extra); v < e.sndMTU {
+ e.sndMTU = v
+ }
+ e.sndBufMu.Unlock()
+
+ e.notifyProtocolGoroutine(notifyMTUChanged)
+ }
+}
+
// updateSndBufferUsage is called by the protocol goroutine when room opens up
// in the send buffer. The number of newly available bytes is v.
func (e *endpoint) updateSndBufferUsage(v int) {
e.sndBufMu.Lock()
- notify := e.sndBufUsed > e.sndBufSize
+ notify := e.sndBufUsed >= e.sndBufSize>>1
e.sndBufUsed -= v
- notify = notify && e.sndBufUsed <= e.sndBufSize
+ // We only notify when usage drops from at least half of sndBufSize
+ // to below half. This ensures that we don't wake up writers to
+ // queue just 1-2 segments and go back to sleep.
+ notify = notify && e.sndBufUsed < e.sndBufSize>>1
e.sndBufMu.Unlock()
if notify {
@@ -1249,3 +1474,109 @@
// randomized per connection basis. But for now this is sufficient.
return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
}
+
+// maybeEnableSACKPermitted marks the SACKPermitted option enabled for this endpoint
+// if the SYN options indicate that the SACK option was negotiated and the TCP
+// stack is configured to enable TCP SACK option.
+func (e *endpoint) maybeEnableSACKPermitted(synOpts *header.TCPSynOptions) {
+ var v SACKEnabled
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &v); err != nil {
+ // Stack doesn't support SACK, so just return.
+ return
+ }
+ if bool(v) && synOpts.SACKPermitted {
+ e.sackPermitted = true
+ }
+}
+
+// completeState makes a full copy of the endpoint's state and returns it. This
+// is used before invoking the probe. The state returned may not be fully
+// consistent if there are intervening syscalls while the state is being copied.
+func (e *endpoint) completeState() stack.TCPEndpointState {
+ var s stack.TCPEndpointState
+ s.SegTime = time.Now()
+
+ // Copy EndpointID.
+ e.mu.Lock()
+ s.ID = stack.TCPEndpointID(e.id)
+ e.mu.Unlock()
+
+ // Copy endpoint rcv state.
+ e.rcvListMu.Lock()
+ s.RcvBufSize = e.rcvBufSize
+ s.RcvBufUsed = e.rcvBufUsed
+ s.RcvClosed = e.rcvClosed
+ e.rcvListMu.Unlock()
+
+ // Endpoint TCP Option state.
+ s.SendTSOk = e.sendTSOk
+ s.RecentTS = e.recentTS
+ s.TSOffset = e.tsOffset
+ s.SACKPermitted = e.sackPermitted
+ s.SACK.Blocks = make([]header.SACKBlock, e.sack.NumBlocks)
+ copy(s.SACK.Blocks, e.sack.Blocks[:e.sack.NumBlocks])
+
+ // Copy endpoint send state.
+ e.sndBufMu.Lock()
+ s.SndBufSize = e.sndBufSize
+ s.SndBufUsed = e.sndBufUsed
+ s.SndClosed = e.sndClosed
+ s.SndBufInQueue = e.sndBufInQueue
+ s.PacketTooBigCount = e.packetTooBigCount
+ s.SndMTU = e.sndMTU
+ e.sndBufMu.Unlock()
+
+ // Copy receiver state.
+ s.Receiver = stack.TCPReceiverState{
+ RcvNxt: e.rcv.rcvNxt,
+ RcvAcc: e.rcv.rcvAcc,
+ RcvWndScale: e.rcv.rcvWndScale,
+ PendingBufUsed: e.rcv.pendingBufUsed,
+ PendingBufSize: e.rcv.pendingBufSize,
+ }
+
+ // Copy sender state.
+ s.Sender = stack.TCPSenderState{
+ LastSendTime: e.snd.lastSendTime,
+ DupAckCount: e.snd.dupAckCount,
+ FastRecovery: stack.TCPFastRecoveryState{
+ Active: e.snd.fr.active,
+ First: e.snd.fr.first,
+ Last: e.snd.fr.last,
+ MaxCwnd: e.snd.fr.maxCwnd,
+ },
+ SndCwnd: e.snd.sndCwnd,
+ Ssthresh: e.snd.sndSsthresh,
+ SndCAAckCount: e.snd.sndCAAckCount,
+ Outstanding: e.snd.outstanding,
+ SndWnd: e.snd.sndWnd,
+ SndUna: e.snd.sndUna,
+ SndNxt: e.snd.sndNxt,
+ RTTMeasureSeqNum: e.snd.rttMeasureSeqNum,
+ RTTMeasureTime: e.snd.rttMeasureTime,
+ Closed: e.snd.closed,
+ RTO: e.snd.rto,
+ SRTTInited: e.snd.srttInited,
+ MaxPayloadSize: e.snd.maxPayloadSize,
+ SndWndScale: e.snd.sndWndScale,
+ MaxSentAck: e.snd.maxSentAck,
+ }
+ e.snd.rtt.Lock()
+ s.Sender.SRTT = e.snd.rtt.srtt
+ e.snd.rtt.Unlock()
+
+ if cubic, ok := e.snd.cc.(*cubicState); ok {
+ s.Sender.Cubic = stack.TCPCubicState{
+ WMax: cubic.wMax,
+ WLastMax: cubic.wLastMax,
+ T: cubic.t,
+ TimeSinceLastCongestion: time.Since(cubic.t),
+ C: cubic.c,
+ K: cubic.k,
+ Beta: cubic.beta,
+ WC: cubic.wC,
+ WEst: cubic.wEst,
+ }
+ }
+ return s
+}
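The `updateSndBufferUsage` change in the endpoint.go hunk above replaces the old wakeup condition with a half-buffer threshold. Writers are now notified only when usage falls from at least half of `sndBufSize` to below half. The following is a minimal standalone sketch of just that threshold logic. It assumes a hypothetical `sendBuffer` type with no locking and no waiter queue, and it only counts how often writers would be woken. It is not the netstack endpoint.

```go
package main

import "fmt"

// sendBuffer is a hypothetical, simplified model of the endpoint fields
// sndBufSize and sndBufUsed. It is not the netstack type.
type sendBuffer struct {
	size    int // capacity in bytes (plays the role of sndBufSize)
	used    int // bytes currently queued (plays the role of sndBufUsed)
	wakeups int // number of times writers would have been notified
}

// release frees v bytes and mirrors the half-buffer rule: it reports true
// only when usage crosses from >= size/2 to < size/2.
func (b *sendBuffer) release(v int) bool {
	notify := b.used >= b.size>>1
	b.used -= v
	notify = notify && b.used < b.size>>1
	if notify {
		b.wakeups++
	}
	return notify
}

func main() {
	// Start with a full 1000-byte buffer and drain it 100 bytes at a time.
	b := &sendBuffer{size: 1000, used: 1000}
	for i := 0; i < 8; i++ {
		notified := b.release(100)
		fmt.Printf("used=%d notified=%v\n", b.used, notified)
	}
	fmt.Println("wakeups:", b.wakeups)
}
```

In this run, only the release that takes usage from 500 to 400 bytes reports a wakeup. Writers are therefore woken once, with half the buffer free, rather than after every small drain.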
diff --git a/tcpip/transport/tcp/forwarder.go b/tcpip/transport/tcp/forwarder.go
index f9e9ea0..35296c6 100644
--- a/tcpip/transport/tcp/forwarder.go
+++ b/tcpip/transport/tcp/forwarder.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package tcp
@@ -53,7 +63,7 @@
//
// This function is expected to be passed as an argument to the
// stack.SetTransportProtocolHandler function.
-func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv *buffer.VectorisedView) bool {
+func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) bool {
s := newSegment(r, id, vv)
defer s.decRef()
@@ -105,7 +115,7 @@
return r.segment.id
}
-// Complete completes the request, and optinally sends a RST segment back to the
+// Complete completes the request, and optionally sends a RST segment back to the
// sender.
func (r *ForwarderRequest) Complete(sendReset bool) {
r.mu.Lock()
@@ -143,11 +153,12 @@
f := r.forwarder
ep, err := f.listen.createEndpointAndPerformHandshake(r.segment, &header.TCPSynOptions{
- MSS: r.synOptions.MSS,
- WS: r.synOptions.WS,
- TS: r.synOptions.TS,
- TSVal: r.synOptions.TSVal,
- TSEcr: r.synOptions.TSEcr,
+ MSS: r.synOptions.MSS,
+ WS: r.synOptions.WS,
+ TS: r.synOptions.TS,
+ TSVal: r.synOptions.TSVal,
+ TSEcr: r.synOptions.TSEcr,
+ SACKPermitted: r.synOptions.SACKPermitted,
})
if err != nil {
return nil, err
diff --git a/tcpip/transport/tcp/protocol.go b/tcpip/transport/tcp/protocol.go
index d291958..09796c4 100644
--- a/tcpip/transport/tcp/protocol.go
+++ b/tcpip/transport/tcp/protocol.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
// Package tcp contains the implementation of the TCP transport protocol. To use
// it in the networking stack, this package must be added to the project, and
@@ -11,6 +21,9 @@
package tcp
import (
+ "strings"
+ "sync"
+
"github.com/google/netstack/tcpip"
"github.com/google/netstack/tcpip/buffer"
"github.com/google/netstack/tcpip/header"
@@ -25,9 +38,58 @@
// ProtocolNumber is the tcp protocol number.
ProtocolNumber = header.TCPProtocolNumber
+
+ // minBufferSize is the smallest size of a receive or send buffer.
+ minBufferSize = 4 << 10 // 4096 bytes.
+
+ // DefaultBufferSize is the default size of the receive and send buffers.
+ DefaultBufferSize = 1 << 20 // 1MB
+
+ // maxBufferSize is the largest size a receive or send buffer can grow to.
+ maxBufferSize = 4 << 20 // 4MB
)
-type protocol struct{}
+// SACKEnabled option can be used to enable SACK support in the TCP
+// protocol. See: https://tools.ietf.org/html/rfc2018.
+type SACKEnabled bool
+
+// SendBufferSizeOption allows the default, min and max send buffer sizes for
+// TCP endpoints to be queried or configured.
+type SendBufferSizeOption struct {
+ Min int
+ Default int
+ Max int
+}
+
+// ReceiveBufferSizeOption allows the default, min and max receive buffer size
+// for TCP endpoints to be queried or configured.
+type ReceiveBufferSizeOption struct {
+ Min int
+ Default int
+ Max int
+}
+
+const (
+ ccReno = "reno"
+ ccCubic = "cubic"
+)
+
+// CongestionControlOption sets the current congestion control algorithm.
+type CongestionControlOption string
+
+// AvailableCongestionControlOption returns the supported congestion control
+// algorithms.
+type AvailableCongestionControlOption string
+
+type protocol struct {
+ mu sync.Mutex
+ sackEnabled bool
+ sendBufferSize SendBufferSizeOption
+ recvBufferSize ReceiveBufferSizeOption
+ congestionControl string
+ availableCongestionControl []string
+ allowedCongestionControl []string
+}
// Number returns the tcp protocol number.
func (*protocol) Number() tcpip.TransportProtocolNumber {
@@ -58,7 +120,7 @@
// a reset is sent in response to any incoming segment except another reset. In
// particular, SYNs addressed to a non-existent connection are rejected by this
// means."
-func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, vv *buffer.VectorisedView) bool {
+func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) bool {
s := newSegment(r, id, vv)
defer s.decRef()
@@ -85,16 +147,93 @@
ack := s.sequenceNumber.Add(s.logicalLen())
- sendTCP(&s.route, s.id, nil, s.route.DefaultTTL(), flagRst|flagAck, seq, ack, 0)
+ sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), flagRst|flagAck, seq, ack, 0, nil)
}
// SetOption implements TransportProtocol.SetOption.
func (p *protocol) SetOption(option interface{}) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+ switch v := option.(type) {
+ case SACKEnabled:
+ p.mu.Lock()
+ p.sackEnabled = bool(v)
+ p.mu.Unlock()
+ return nil
+
+ case SendBufferSizeOption:
+ if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max {
+ return tcpip.ErrInvalidOptionValue
+ }
+ p.mu.Lock()
+ p.sendBufferSize = v
+ p.mu.Unlock()
+ return nil
+
+ case ReceiveBufferSizeOption:
+ if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max {
+ return tcpip.ErrInvalidOptionValue
+ }
+ p.mu.Lock()
+ p.recvBufferSize = v
+ p.mu.Unlock()
+ return nil
+
+ case CongestionControlOption:
+ for _, c := range p.availableCongestionControl {
+ if string(v) == c {
+ p.mu.Lock()
+ p.congestionControl = string(v)
+ p.mu.Unlock()
+ return nil
+ }
+ }
+ return tcpip.ErrInvalidOptionValue
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// Option implements TransportProtocol.Option.
+func (p *protocol) Option(option interface{}) *tcpip.Error {
+ switch v := option.(type) {
+ case *SACKEnabled:
+ p.mu.Lock()
+ *v = SACKEnabled(p.sackEnabled)
+ p.mu.Unlock()
+ return nil
+
+ case *SendBufferSizeOption:
+ p.mu.Lock()
+ *v = p.sendBufferSize
+ p.mu.Unlock()
+ return nil
+
+ case *ReceiveBufferSizeOption:
+ p.mu.Lock()
+ *v = p.recvBufferSize
+ p.mu.Unlock()
+ return nil
+ case *CongestionControlOption:
+ p.mu.Lock()
+ *v = CongestionControlOption(p.congestionControl)
+ p.mu.Unlock()
+ return nil
+ case *AvailableCongestionControlOption:
+ p.mu.Lock()
+ *v = AvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " "))
+ p.mu.Unlock()
+ return nil
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
}
func init() {
stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol {
- return &protocol{}
+ return &protocol{
+ sendBufferSize: SendBufferSizeOption{minBufferSize, DefaultBufferSize, maxBufferSize},
+ recvBufferSize: ReceiveBufferSizeOption{minBufferSize, DefaultBufferSize, maxBufferSize},
+ congestionControl: ccReno,
+ availableCongestionControl: []string{ccReno, ccCubic},
+ }
})
}
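The new `SetOption` in protocol.go above uses a type switch on the option value. It validates each option before storing it under the protocol mutex. Buffer sizes must satisfy `Min > 0` and `Min <= Default <= Max`, and a congestion control name must appear in the available list. The following is a minimal standalone sketch of that pattern. It assumes hypothetical stand-in types and error values (`sendBufferSize`, `congestionControl`, `options`, `errInvalidOptionValue`, `errUnknownOption`) rather than the real `tcpip` types.

```go
package main

import (
	"errors"
	"fmt"
	"sync"
)

// Hypothetical stand-ins for tcpip.ErrInvalidOptionValue and
// tcpip.ErrUnknownProtocolOption.
var (
	errInvalidOptionValue = errors.New("invalid option value")
	errUnknownOption      = errors.New("unknown protocol option")
)

// sendBufferSize mirrors the shape of SendBufferSizeOption.
type sendBufferSize struct {
	Min, Default, Max int
}

// congestionControl mirrors CongestionControlOption.
type congestionControl string

// options is a hypothetical, simplified stand-in for the protocol struct.
type options struct {
	mu        sync.Mutex
	sndBuf    sendBufferSize
	cc        string
	available []string
}

// setOption validates and stores an option, dispatching on its dynamic type.
func (o *options) setOption(option interface{}) error {
	switch v := option.(type) {
	case sendBufferSize:
		// Require Min > 0 and Min <= Default <= Max.
		if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max {
			return errInvalidOptionValue
		}
		o.mu.Lock()
		o.sndBuf = v
		o.mu.Unlock()
		return nil
	case congestionControl:
		o.mu.Lock()
		defer o.mu.Unlock()
		// Only algorithms in the available list may be selected.
		for _, c := range o.available {
			if string(v) == c {
				o.cc = c
				return nil
			}
		}
		return errInvalidOptionValue
	default:
		return errUnknownOption
	}
}

func main() {
	o := &options{
		sndBuf:    sendBufferSize{4 << 10, 1 << 20, 4 << 20},
		cc:        "reno",
		available: []string{"reno", "cubic"},
	}
	fmt.Println(o.setOption(sendBufferSize{Min: 0, Default: 1, Max: 2})) // Min must be > 0
	fmt.Println(o.setOption(congestionControl("cubic")))                 // accepted
	fmt.Println(o.setOption(congestionControl("bbr")))                   // not available
	fmt.Println(o.setOption(42))                                         // unknown option type
	fmt.Println(o.cc)
}
```

Unsupported option types fall through to the default case. Callers can then distinguish an option that is not understood from a value that is out of range, which is the same split the diff makes between `ErrUnknownProtocolOption` and `ErrInvalidOptionValue`.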
diff --git a/tcpip/transport/tcp/rcv.go b/tcpip/transport/tcp/rcv.go
index 3418201..a96a57a 100644
--- a/tcpip/transport/tcp/rcv.go
+++ b/tcpip/transport/tcp/rcv.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package tcp
@@ -12,6 +22,8 @@
// receiver holds the state necessary to receive TCP segments and turn them
// into a stream of bytes.
+//
+// +stateify savable
type receiver struct {
ep *endpoint
@@ -83,7 +95,7 @@
r.ep.snd.sendAck()
}
-// consumeSegment attemps to consume a segment that was received by r. The
+// consumeSegment attempts to consume a segment that was received by r. The
// segment may have just been received or may have been received earlier but
// wasn't ready to be consumed then.
//
@@ -116,6 +128,11 @@
// Update the segment that we're expecting to consume.
r.rcvNxt = segSeq.Add(segLen)
+
+ // Trim SACK Blocks to remove any SACK information that covers
+ // sequence numbers that have been consumed.
+ TrimSACKBlockList(&r.ep.sack, r.rcvNxt)
+
if s.flagIsSet(flagFin) {
r.rcvNxt++
@@ -173,6 +190,8 @@
heap.Push(&r.pendingRcvdSegments, s)
}
+ UpdateSACKBlocks(&r.ep.sack, segSeq, segSeq.Add(segLen), r.rcvNxt)
+
// Immediately send an ack so that the peer knows it may
// have to retransmit.
r.ep.snd.sendAck()
diff --git a/tcpip/transport/tcp/reno.go b/tcpip/transport/tcp/reno.go
new file mode 100644
index 0000000..feb5932
--- /dev/null
+++ b/tcpip/transport/tcp/reno.go
@@ -0,0 +1,103 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+// renoState stores the variables related to the TCP NewReno congestion
+// control algorithm.
+//
+// +stateify savable
+type renoState struct {
+ s *sender
+}
+
+// newRenoCC initializes the state for the NewReno congestion control algorithm.
+func newRenoCC(s *sender) *renoState {
+ return &renoState{s: s}
+}
+
+// updateSlowStart updates the congestion window per the slow-start algorithm
+// used by NewReno. If the adjusted window reaches ssthresh, it returns the
+// number of acked packets left over, which must be consumed in congestion
+// avoidance mode; otherwise it returns 0.
+func (r *renoState) updateSlowStart(packetsAcked int) int {
+ // Don't let the congestion window cross into the congestion
+ // avoidance range.
+ newcwnd := r.s.sndCwnd + packetsAcked
+ if newcwnd >= r.s.sndSsthresh {
+ newcwnd = r.s.sndSsthresh
+ r.s.sndCAAckCount = 0
+ }
+
+ packetsAcked -= newcwnd - r.s.sndCwnd
+ r.s.sndCwnd = newcwnd
+ return packetsAcked
+}
+
+// updateCongestionAvoidance updates the congestion window in congestion
+// avoidance mode as described in RFC 5681, section 3.1.
+func (r *renoState) updateCongestionAvoidance(packetsAcked int) {
+ // Consume the packets in congestion avoidance mode.
+ r.s.sndCAAckCount += packetsAcked
+ if r.s.sndCAAckCount >= r.s.sndCwnd {
+ r.s.sndCwnd += r.s.sndCAAckCount / r.s.sndCwnd
+ r.s.sndCAAckCount = r.s.sndCAAckCount % r.s.sndCwnd
+ }
+}
+
+// reduceSlowStartThreshold reduces the slow-start threshold per RFC 5681,
+// page 6, eq. 4. It is called when we detect congestion in the network.
+func (r *renoState) reduceSlowStartThreshold() {
+ r.s.sndSsthresh = r.s.outstanding / 2
+ if r.s.sndSsthresh < 2 {
+ r.s.sndSsthresh = 2
+ }
+
+}
+
+// Update implements congestionControl.Update. It updates the congestion
+// state based on the number of packets that were acknowledged by the most
+// recent cumulative ACK.
+func (r *renoState) Update(packetsAcked int) {
+ if r.s.sndCwnd < r.s.sndSsthresh {
+ packetsAcked = r.updateSlowStart(packetsAcked)
+ if packetsAcked == 0 {
+ return
+ }
+ }
+ r.updateCongestionAvoidance(packetsAcked)
+}
+
+// HandleNDupAcks implements congestionControl.HandleNDupAcks.
+func (r *renoState) HandleNDupAcks() {
+ // A retransmit was triggered due to nDupAckThreshold
+ // being hit. Reduce our slow start threshold.
+ r.reduceSlowStartThreshold()
+}
+
+// HandleRTOExpired implements congestionControl.HandleRTOExpired.
+func (r *renoState) HandleRTOExpired() {
+ // We lost a packet, so reduce ssthresh.
+ r.reduceSlowStartThreshold()
+
+ // Reduce the congestion window to 1, i.e., enter slow-start. Per
+ // RFC 5681, page 7, we must use 1 regardless of the value of the
+ // initial congestion window.
+ r.s.sndCwnd = 1
+}
+
+// PostRecovery implements congestionControl.PostRecovery.
+func (r *renoState) PostRecovery() {
+ // noop.
+}
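The new reno.go above splits NewReno window growth into two phases. In slow start, the window grows by one packet per acked packet, capped at ssthresh, and any leftover acks spill into congestion avoidance. In congestion avoidance, it grows by about one packet per window of acks. On an RTO, ssthresh becomes `max(outstanding/2, 2)` and the window drops to 1. The following is a minimal standalone sketch that folds `renoState` and the sender fields it touches into one hypothetical `renoSender` type so the window evolution can be traced. It omits duplicate-ACK handling and fast recovery, which live in the sender.

```go
package main

import (
	"fmt"
	"math"
)

// renoSender is a hypothetical, simplified model holding only the sender
// fields that renoState touches. All counts are in packets.
type renoSender struct {
	cwnd, ssthresh, caAckCount, outstanding int
}

// update mirrors renoState.Update: slow start up to ssthresh, with any
// leftover acks consumed in congestion avoidance.
func (s *renoSender) update(packetsAcked int) {
	if s.cwnd < s.ssthresh {
		newCwnd := s.cwnd + packetsAcked
		if newCwnd >= s.ssthresh {
			newCwnd = s.ssthresh
			s.caAckCount = 0
		}
		packetsAcked -= newCwnd - s.cwnd
		s.cwnd = newCwnd
		if packetsAcked == 0 {
			return
		}
	}
	// Congestion avoidance: roughly +1 packet per cwnd acked packets.
	s.caAckCount += packetsAcked
	if s.caAckCount >= s.cwnd {
		s.cwnd += s.caAckCount / s.cwnd
		s.caAckCount %= s.cwnd
	}
}

// rtoExpired mirrors renoState.HandleRTOExpired.
func (s *renoSender) rtoExpired() {
	s.ssthresh = s.outstanding / 2
	if s.ssthresh < 2 {
		s.ssthresh = 2
	}
	s.cwnd = 1
}

func main() {
	// Initial window of 10 packets and an effectively unbounded ssthresh.
	s := &renoSender{cwnd: 10, ssthresh: math.MaxInt32}
	for i := 0; i < 3; i++ {
		s.update(s.cwnd) // one full window acked per round trip
		fmt.Println("slow start cwnd:", s.cwnd)
	}
	s.outstanding = s.cwnd
	s.rtoExpired()
	fmt.Println("after RTO: cwnd", s.cwnd, "ssthresh", s.ssthresh)
	for i := 0; i < 7; i++ {
		s.update(s.cwnd)
		fmt.Println("cwnd:", s.cwnd)
	}
}
```

The trace doubles from 10 to 80 packets, then collapses to 1 after the RTO (ssthresh 40). It regrows 2, 4, 8, 16, 32, hits the 40-packet ssthresh with 24 acks left over for congestion avoidance, and then creeps to 41.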
diff --git a/tcpip/transport/tcp/sack.go b/tcpip/transport/tcp/sack.go
new file mode 100644
index 0000000..9d7947e
--- /dev/null
+++ b/tcpip/transport/tcp/sack.go
@@ -0,0 +1,99 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "github.com/google/netstack/tcpip/header"
+ "github.com/google/netstack/tcpip/seqnum"
+)
+
+const (
+ // MaxSACKBlocks is the maximum number of SACK blocks stored
+ // at receiver side.
+ MaxSACKBlocks = 6
+)
+
+// UpdateSACKBlocks updates the list of SACK blocks to include the segment
+// specified by segStart->segEnd. If the segment was delivered out of order,
+// the first block in sack.Blocks always includes the segment identified by
+// segStart->segEnd.
+func UpdateSACKBlocks(sack *SACKInfo, segStart seqnum.Value, segEnd seqnum.Value, rcvNxt seqnum.Value) {
+ newSB := header.SACKBlock{Start: segStart, End: segEnd}
+ if sack.NumBlocks == 0 {
+ sack.Blocks[0] = newSB
+ sack.NumBlocks = 1
+ return
+ }
+ var n = 0
+ for i := 0; i < sack.NumBlocks; i++ {
+ start, end := sack.Blocks[i].Start, sack.Blocks[i].End
+ if end.LessThanEq(start) || start.LessThanEq(rcvNxt) {
+ // Discard any invalid blocks where end is before start
+ // and discard any sack blocks that are before rcvNxt as
+ // those have already been acked.
+ continue
+ }
+ if newSB.Start.LessThanEq(end) && start.LessThanEq(newSB.End) {
+ // Merge this SACK block into newSB and discard this SACK
+ // block.
+ if start.LessThan(newSB.Start) {
+ newSB.Start = start
+ }
+ if newSB.End.LessThan(end) {
+ newSB.End = end
+ }
+ } else {
+ // Save this block.
+ sack.Blocks[n] = sack.Blocks[i]
+ n++
+ }
+ }
+ if rcvNxt.LessThan(newSB.Start) {
+ // If this was an out of order segment then make sure that the
+ // first SACK block is the one that includes the segment.
+ //
+ // See the first bullet point in
+ // https://tools.ietf.org/html/rfc2018#section-4
+ if n == MaxSACKBlocks {
+ // If the number of SACK blocks is equal to
+ // MaxSACKBlocks then discard the last SACK block.
+ n--
+ }
+ for i := n - 1; i >= 0; i-- {
+ sack.Blocks[i+1] = sack.Blocks[i]
+ }
+ sack.Blocks[0] = newSB
+ n++
+ }
+ sack.NumBlocks = n
+}
+
+// TrimSACKBlockList updates the SACK block list by removing any block that
+// ends at or before rcvNxt and trimming any block that starts before rcvNxt.
+func TrimSACKBlockList(sack *SACKInfo, rcvNxt seqnum.Value) {
+ n := 0
+ for i := 0; i < sack.NumBlocks; i++ {
+ if sack.Blocks[i].End.LessThanEq(rcvNxt) {
+ continue
+ }
+ if sack.Blocks[i].Start.LessThan(rcvNxt) {
+ // Shrink this SACK block.
+ sack.Blocks[i].Start = rcvNxt
+ }
+ sack.Blocks[n] = sack.Blocks[i]
+ n++
+ }
+ sack.NumBlocks = n
+}
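The new sack.go above maintains the receiver's SACK blocks as follows:

- An out-of-order segment is merged with any overlapping or adjacent block.
- The block containing the newest segment moves to the front (RFC 2018, section 4).
- The oldest block is dropped once `MaxSACKBlocks` is reached.
- `TrimSACKBlockList` discards or shrinks blocks as `rcvNxt` advances.

The following is a minimal standalone port of both functions for experimentation. It assumes a hypothetical `seq` type that stands in for `seqnum.Value`, using 32-bit wraparound comparison, and hypothetical `sackBlock`/`sackInfo` types in place of `header.SACKBlock` and `SACKInfo`.

```go
package main

import "fmt"

// seq is a hypothetical stand-in for seqnum.Value: a 32-bit sequence
// number compared with wraparound (serial number) arithmetic.
type seq uint32

// LessThan reports whether v is before w, modulo 2^32.
func (v seq) LessThan(w seq) bool { return int32(v-w) < 0 }

// LessThanEq reports whether v is before or equal to w, modulo 2^32.
func (v seq) LessThanEq(w seq) bool { return v == w || v.LessThan(w) }

const maxSACKBlocks = 6

type sackBlock struct{ Start, End seq } // End is exclusive

type sackInfo struct {
	Blocks    [maxSACKBlocks]sackBlock
	NumBlocks int
}

// update ports UpdateSACKBlocks: merge the new segment, then put it first.
func update(sack *sackInfo, segStart, segEnd, rcvNxt seq) {
	newSB := sackBlock{segStart, segEnd}
	if sack.NumBlocks == 0 {
		sack.Blocks[0] = newSB
		sack.NumBlocks = 1
		return
	}
	n := 0
	for i := 0; i < sack.NumBlocks; i++ {
		start, end := sack.Blocks[i].Start, sack.Blocks[i].End
		if end.LessThanEq(start) || start.LessThanEq(rcvNxt) {
			continue // invalid, or already covered by the cumulative ACK
		}
		if newSB.Start.LessThanEq(end) && start.LessThanEq(newSB.End) {
			// Overlapping or adjacent: absorb this block into newSB.
			if start.LessThan(newSB.Start) {
				newSB.Start = start
			}
			if newSB.End.LessThan(end) {
				newSB.End = end
			}
		} else {
			sack.Blocks[n] = sack.Blocks[i]
			n++
		}
	}
	if rcvNxt.LessThan(newSB.Start) {
		if n == maxSACKBlocks {
			n-- // drop the oldest (last) block
		}
		copy(sack.Blocks[1:n+1], sack.Blocks[:n]) // copy handles overlap
		sack.Blocks[0] = newSB
		n++
	}
	sack.NumBlocks = n
}

// trim ports TrimSACKBlockList.
func trim(sack *sackInfo, rcvNxt seq) {
	n := 0
	for i := 0; i < sack.NumBlocks; i++ {
		b := sack.Blocks[i]
		if b.End.LessThanEq(rcvNxt) {
			continue // fully acknowledged
		}
		if b.Start.LessThan(rcvNxt) {
			b.Start = rcvNxt // partially acknowledged: shrink
		}
		sack.Blocks[n] = b
		n++
	}
	sack.NumBlocks = n
}

func main() {
	var s sackInfo
	rcvNxt := seq(1000) // bytes 1000-1999 are missing
	update(&s, 2000, 2100, rcvNxt)
	update(&s, 3000, 3100, rcvNxt)
	fmt.Println(s.Blocks[:s.NumBlocks]) // [{3000 3100} {2000 2100}]
	update(&s, 2100, 2200, rcvNxt)      // adjacent to {2000 2100}: merged
	fmt.Println(s.Blocks[:s.NumBlocks]) // [{2000 2200} {3000 3100}]
	trim(&s, 2200)                      // the hole is filled up to 2200
	fmt.Println(s.Blocks[:s.NumBlocks]) // [{3000 3100}]
}
```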
diff --git a/tcpip/transport/tcp/segment.go b/tcpip/transport/tcp/segment.go
index b573add..be3ff99 100644
--- a/tcpip/transport/tcp/segment.go
+++ b/tcpip/transport/tcp/segment.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package tcp
@@ -26,6 +36,8 @@
// segment represents a TCP segment. It holds the payload and parsed TCP segment
// information, and can be added to intrusive lists.
// segment is mostly immutable, the only field allowed to change is viewToDeliver.
+//
+// +stateify savable
type segment struct {
segmentEntry
refCnt int32
@@ -48,7 +60,7 @@
options []byte
}
-func newSegment(r *stack.Route, id stack.TransportEndpointID, vv *buffer.VectorisedView) *segment {
+func newSegment(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) *segment {
s := &segment{
refCnt: 1,
id: id,
diff --git a/tcpip/transport/tcp/segment_heap.go b/tcpip/transport/tcp/segment_heap.go
index 137ddbd..e3a3405 100644
--- a/tcpip/transport/tcp/segment_heap.go
+++ b/tcpip/transport/tcp/segment_heap.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package tcp
diff --git a/tcpip/transport/tcp/segment_queue.go b/tcpip/transport/tcp/segment_queue.go
index ecab9a6..ca4c757 100644
--- a/tcpip/transport/tcp/segment_queue.go
+++ b/tcpip/transport/tcp/segment_queue.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package tcp
@@ -11,6 +21,8 @@
)
// segmentQueue is a bounded, thread-safe queue of TCP segments.
+//
+// +stateify savable
type segmentQueue struct {
mu sync.Mutex
list segmentList
diff --git a/tcpip/transport/tcp/snd.go b/tcpip/transport/tcp/snd.go
index b4bf2dc..bbc8ccc 100644
--- a/tcpip/transport/tcp/snd.go
+++ b/tcpip/transport/tcp/snd.go
@@ -1,11 +1,22 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package tcp
import (
"math"
+ "sync"
"time"
"github.com/google/netstack/sleep"
@@ -16,14 +27,41 @@
)
const (
- // minRTO is the minimium allowed value for the retransmit timeout.
+ // minRTO is the minimum allowed value for the retransmit timeout.
minRTO = 200 * time.Millisecond
- // initalCwnd is the initial congestion window.
- initialCwnd = 10
+ // InitialCwnd is the initial congestion window.
+ InitialCwnd = 10
+
+ // nDupAckThreshold is the number of duplicate ACKs required
+ // before fast retransmit is entered.
+ nDupAckThreshold = 3
)
+// congestionControl is an interface that must be implemented by any supported
+// congestion control algorithm.
+type congestionControl interface {
+ // HandleNDupAcks is invoked when sender.dupAckCount >= nDupAckThreshold
+ // just before entering fast retransmit.
+ HandleNDupAcks()
+
+ // HandleRTOExpired is invoked when the retransmit timer expires.
+ HandleRTOExpired()
+
+ // Update is invoked when processing inbound ACKs. It is passed the
+ // number of packets that were acknowledged by the most recent
+ // cumulative acknowledgement.
+ Update(packetsAcked int)
+
+ // PostRecovery is invoked when the sender is exiting a fast retransmit/
+ // recovery phase. This provides congestion control algorithms a way
+ // to adjust their state when exiting recovery.
+ PostRecovery()
+}
+
// sender holds the state necessary to send TCP segments.
+//
+// +stateify savable
type sender struct {
ep *endpoint
@@ -79,11 +117,10 @@
resendTimer timer
resendWaker sleep.Waker
- // srtt, rttvar & rto are the "smoothed round-trip time", "round-trip
- // time variation" and "retransmit timeout", as defined in section 2 of
- // RFC 6298.
- srtt time.Duration
- rttvar time.Duration
+ // rtt.srtt, rtt.rttvar, and rto are the "smoothed round-trip time",
+ // "round-trip time variation" and "retransmit timeout", as defined in
+ // section 2 of RFC 6298.
+ rtt rtt
rto time.Duration
srttInited bool
@@ -97,9 +134,25 @@
// maxSentAck is the maxium acknowledgement actually sent.
maxSentAck seqnum.Value
+
+ // cc is the congestion control algorithm in use for this sender.
+ cc congestionControl
+}
+
+// rtt is a synchronization wrapper used to appease stateify. See the comment
+// in sender, where it is used.
+//
+// +stateify savable
+type rtt struct {
+ sync.Mutex
+
+ srtt time.Duration
+ rttvar time.Duration
}
// fastRecovery holds information related to fast recovery from a packet loss.
+//
+// +stateify savable
type fastRecovery struct {
// active whether the endpoint is in fast recovery. The following fields
// are only meaningful when active is true.
@@ -120,7 +173,7 @@
func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint16, sndWndScale int) *sender {
s := &sender{
ep: ep,
- sndCwnd: initialCwnd,
+ sndCwnd: InitialCwnd,
sndSsthresh: math.MaxInt64,
sndWnd: sndWnd,
sndUna: iss + 1,
@@ -131,50 +184,143 @@
lastSendTime: time.Now(),
maxPayloadSize: int(mss),
maxSentAck: irs + 1,
+ fr: fastRecovery{
+ // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 1.
+ last: iss,
+ },
}
+ s.cc = s.initCongestionControl(ep.cc)
+
// A negative sndWndScale means that no scaling is in use, otherwise we
// store the scaling value.
if sndWndScale > 0 {
s.sndWndScale = uint8(sndWndScale)
}
- m := int(ep.route.MTU()) - header.TCPMinimumSize
- // Adjust the maxPayloadsize to account for the timestamp option.
- if ep.sendTSOk {
- m -= header.TCPTimeStampOptionSize
- }
- if m < s.maxPayloadSize {
- s.maxPayloadSize = m
- }
+ s.updateMaxPayloadSize(int(ep.route.MTU()), 0)
s.resendTimer.init(&s.resendWaker)
return s
}
+func (s *sender) initCongestionControl(congestionControlName CongestionControlOption) congestionControl {
+ switch congestionControlName {
+ case ccCubic:
+ return newCubicCC(s)
+ case ccReno:
+ fallthrough
+ default:
+ return newRenoCC(s)
+ }
+}
+
+// updateMaxPayloadSize updates the maximum payload size based on the given
+// MTU. If this is in response to "packet too big" control packets (indicated
+// by the count argument), it also reduces the number of outstanding packets and
+// attempts to retransmit the first packet above the MTU size.
+func (s *sender) updateMaxPayloadSize(mtu, count int) {
+ m := mtu - header.TCPMinimumSize
+
+ // Calculate the maximum option size.
+ var maxSackBlocks [header.TCPMaxSACKBlocks]header.SACKBlock
+ options := s.ep.makeOptions(maxSackBlocks[:])
+ m -= len(options)
+ putOptions(options)
+
+ // We don't adjust up for now.
+ if m >= s.maxPayloadSize {
+ return
+ }
+
+ // Make sure we can transmit at least one byte.
+ if m <= 0 {
+ m = 1
+ }
+
+ s.maxPayloadSize = m
+
+ s.outstanding -= count
+ if s.outstanding < 0 {
+ s.outstanding = 0
+ }
+
+ // Rewind writeNext to the first segment exceeding the MTU. Do nothing
+ // if it is already before such a packet.
+ for seg := s.writeList.Front(); seg != nil; seg = seg.Next() {
+ if seg == s.writeNext {
+ // We got to writeNext before we could find a segment
+ // exceeding the MTU.
+ break
+ }
+
+ if seg.data.Size() > m {
+ // We found a segment exceeding the MTU. Rewind
+ // writeNext and try to retransmit it.
+ s.writeNext = seg
+ break
+ }
+ }
+
+ // Since we likely reduced the number of outstanding packets, we may be
+ // ready to send some more.
+ s.sendData()
+}
+
// sendAck sends an ACK segment.
func (s *sender) sendAck() {
- s.sendSegment(nil, flagAck, s.sndNxt)
+ s.sendSegment(buffer.VectorisedView{}, flagAck, s.sndNxt)
}
// updateRTO updates the retransmit timeout when a new roud-trip time is
// available. This is done in accordance with section 2 of RFC 6298.
func (s *sender) updateRTO(rtt time.Duration) {
+ s.rtt.Lock()
if !s.srttInited {
- s.rttvar = rtt / 2
- s.srtt = rtt
+ s.rtt.rttvar = rtt / 2
+ s.rtt.srtt = rtt
s.srttInited = true
} else {
- diff := s.srtt - rtt
+ diff := s.rtt.srtt - rtt
if diff < 0 {
diff = -diff
}
- s.rttvar = (3*s.rttvar + diff) / 4
- s.srtt = (7*s.srtt + rtt) / 8
+ // Use RFC6298 standard algorithm to update rttvar and srtt when
+ // no timestamps are available.
+ if !s.ep.sendTSOk {
+ s.rtt.rttvar = (3*s.rtt.rttvar + diff) / 4
+ s.rtt.srtt = (7*s.rtt.srtt + rtt) / 8
+ } else {
+ // When we are taking RTT measurements of every ACK then
+ // we need to use a modified method as specified in
+ // https://tools.ietf.org/html/rfc7323#appendix-G
+ if s.outstanding == 0 {
+ s.rtt.Unlock()
+ return
+ }
+ // Netstack measures the congestion window and in-flight data
+ // in packets rather than bytes, as Linux also does for cwnd
+ // and inflight. In practice this approximation works as
+ // expected.
+ expectedSamples := math.Ceil(float64(s.outstanding) / 2)
+
+ // alpha & beta values are the original values as recommended in
+ // https://tools.ietf.org/html/rfc6298#section-2.3.
+ const alpha = 0.125
+ const beta = 0.25
+
+ alphaPrime := alpha / expectedSamples
+ betaPrime := beta / expectedSamples
+ rttVar := (1-betaPrime)*s.rtt.rttvar.Seconds() + betaPrime*diff.Seconds()
+ srtt := (1-alphaPrime)*s.rtt.srtt.Seconds() + alphaPrime*rtt.Seconds()
+ s.rtt.rttvar = time.Duration(rttVar * float64(time.Second))
+ s.rtt.srtt = time.Duration(srtt * float64(time.Second))
+ }
}
- s.rto = s.srtt + 4*s.rttvar
+ s.rto = s.rtt.srtt + 4*s.rtt.rttvar
+ s.rtt.Unlock()
if s.rto < minRTO {
s.rto = minRTO
}
@@ -188,16 +334,7 @@
// Resend the segment.
if seg := s.writeList.Front(); seg != nil {
- s.sendSegment(&seg.data, seg.flags, seg.sequenceNumber)
- }
-}
-
-// reduceSlowStartThreshold reduces the slow-start threshold per RFC 5681,
-// page 6, eq. 4. It is called when we detect congestion in the network.
-func (s *sender) reduceSlowStartThreshold() {
- s.sndSsthresh = s.outstanding / 2
- if s.sndSsthresh < 2 {
- s.sndSsthresh = 2
+ s.sendSegment(seg.data, seg.flags, seg.sequenceNumber)
}
}
@@ -222,17 +359,18 @@
s.rto *= 2
if s.fr.active {
- // We were attempting fast recovery but were not successfull.
+ // We were attempting fast recovery but were not successful.
// Leave the state. We don't need to update ssthresh because it
// has already been updated when entered fast-recovery.
s.leaveFastRecovery()
- } else {
- // We lost a packet, so reduce ssthresh.
- s.reduceSlowStartThreshold()
}
- // Reduce the congestion window to 1, i.e., enter slow-start.
- s.sndCwnd = 1
+ // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 4.
+ // We store the highest sequence number transmitted in cases where
+ // we were not in fast recovery.
+ s.fr.last = s.sndNxt - 1
+
+ s.cc.HandleRTOExpired()
// Mark the next segment to be sent as the first unacknowledged one and
// start sending again. Set the number of outstanding packets to 0 so
@@ -257,8 +395,8 @@
// transmission if the TCP has not sent data in the interval exceeding
// the retrasmission timeout."
if !s.fr.active && time.Now().Sub(s.lastSendTime) > s.rto {
- if s.sndCwnd > initialCwnd {
- s.sndCwnd = initialCwnd
+ if s.sndCwnd > InitialCwnd {
+ s.sndCwnd = InitialCwnd
}
}
@@ -277,11 +415,17 @@
var segEnd seqnum.Value
if seg.data.Size() == 0 {
- // We're sending a FIN.
+ if s.writeList.Back() != seg {
+ panic("FIN segments must be the final segment in the write list.")
+ }
seg.flags = flagAck | flagFin
segEnd = seg.sequenceNumber.Add(1)
} else {
// We're sending a non-FIN segment.
+ if seg.flags&flagFin != 0 {
+ panic("Netstack queues FIN segments without data.")
+ }
+
if !seg.sequenceNumber.LessThan(end) {
break
}
@@ -304,7 +448,7 @@
segEnd = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size()))
}
- s.sendSegment(&seg.data, seg.flags, seg.sequenceNumber)
+ s.sendSegment(seg.data, seg.flags, seg.sequenceNumber)
// Update sndNxt if we actually sent new data (as opposed to
// retransmitting some previously sent data).
@@ -327,26 +471,33 @@
}
func (s *sender) enterFastRecovery() {
+ s.fr.active = true
// Save state to reflect we're now in fast recovery.
- s.reduceSlowStartThreshold()
- s.sndCwnd = s.sndSsthresh
+ // See: https://tools.ietf.org/html/rfc5681#section-3.2 Step 3.
+ // We inflate cwnd by 3 to account for the 3 segments that triggered
+ // the 3 duplicate ACKs and are no longer in flight.
+ s.sndCwnd = s.sndSsthresh + 3
s.fr.first = s.sndUna
s.fr.last = s.sndNxt - 1
s.fr.maxCwnd = s.sndCwnd + s.outstanding
- s.fr.active = true
}
func (s *sender) leaveFastRecovery() {
s.fr.active = false
+ s.fr.first = 0
+ s.fr.last = s.sndNxt - 1
+ s.fr.maxCwnd = 0
+ s.dupAckCount = 0
- // Deflate cwnd. It had been artifically inflated when new dups arrived.
+ // Deflate cwnd. It had been artificially inflated when new dups arrived.
s.sndCwnd = s.sndSsthresh
+ s.cc.PostRecovery()
}
// checkDuplicateAck is called when an ack is received. It manages the state
// related to duplicate acks and determines if a retransmit is needed according
// to the rules in RFC 6582 (NewReno).
-func (s *sender) checkDuplicateAck(seg *segment) bool {
+func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) {
ack := seg.ackNumber
if s.fr.active {
// We are in fast recovery mode. Ignore the ack if it's out of
@@ -355,7 +506,7 @@
return false
}
- // Leave fast recovery if it acknowleges all the data covered by
+ // Leave fast recovery if it acknowledges all the data covered by
// this fast recovery session.
if s.fr.last.LessThan(ack) {
s.leaveFastRecovery()
@@ -386,6 +537,7 @@
//
// N.B. The retransmit timer will be reset by the caller.
s.fr.first = ack
+ s.dupAckCount = 0
return true
}
@@ -398,50 +550,33 @@
return false
}
- // Enter fast recovery when we reach 3 dups.
s.dupAckCount++
- if s.dupAckCount != 3 {
+ // Do not enter fast recovery until we reach nDupAckThreshold.
+ if s.dupAckCount < nDupAckThreshold {
return false
}
+ // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 2
+ //
+ // We only do the check here; updating last to the highest
+ // sequence number transmitted so far is done when
+ // enterFastRecovery is invoked.
+ if !s.fr.last.LessThan(seg.ackNumber) {
+ s.dupAckCount = 0
+ return false
+ }
+
+ s.cc.HandleNDupAcks()
s.enterFastRecovery()
s.dupAckCount = 0
return true
}
-// updateCwnd updates the congestion window based on the number of packets that
-// were acknowledged.
-func (s *sender) updateCwnd(packetsAcked int) {
- if s.sndCwnd < s.sndSsthresh {
- // Don't let the congestion window cross into the congestion
- // avoidance range.
- newcwnd := s.sndCwnd + packetsAcked
- if newcwnd >= s.sndSsthresh {
- newcwnd = s.sndSsthresh
- s.sndCAAckCount = 0
- }
-
- packetsAcked -= newcwnd - s.sndCwnd
- s.sndCwnd = newcwnd
- if packetsAcked == 0 {
- // We've consumed all ack'd packets.
- return
- }
- }
-
- // Consume the packets in congestion avoidance mode.
- s.sndCAAckCount += packetsAcked
- if s.sndCAAckCount >= s.sndCwnd {
- s.sndCwnd += s.sndCAAckCount / s.sndCwnd
- s.sndCAAckCount = s.sndCAAckCount % s.sndCwnd
- }
-}
-
// handleRcvdSegment is called when a segment is received; it is responsible for
// updating the send-related state.
func (s *sender) handleRcvdSegment(seg *segment) {
// Check if we can extract an RTT measurement from this ack.
- if s.rttMeasureSeqNum.LessThan(seg.ackNumber) {
+ if !s.ep.sendTSOk && s.rttMeasureSeqNum.LessThan(seg.ackNumber) {
s.updateRTO(time.Now().Sub(s.rttMeasureTime))
s.rttMeasureSeqNum = s.sndNxt
}
@@ -458,16 +593,31 @@
// Ignore ack if it doesn't acknowledge any new data.
ack := seg.ackNumber
if (ack - 1).InRange(s.sndUna, s.sndNxt) {
+ s.dupAckCount = 0
// When an ack is received we must reset the timer. We stop it
// here and it will be restarted later if needed.
s.resendTimer.disable()
+ // See: https://tools.ietf.org/html/rfc1323#section-3.3.
+ // Specifically we should only update the RTO using TSEcr if the
+ // following condition holds:
+ //
+ // A TSecr value received in a segment is used to update the
+ // averaged RTT measurement only if the segment acknowledges
+ // some new data, i.e., only if it advances the left edge of
+ // the send window.
+ if s.ep.sendTSOk && seg.parsedOptions.TSEcr != 0 {
+ // TSVal/Ecr values sent by Netstack are at a millisecond
+ // granularity.
+ elapsed := time.Duration(s.ep.timestamp()-seg.parsedOptions.TSEcr) * time.Millisecond
+ s.updateRTO(elapsed)
+ }
// Remove all acknowledged data from the write list.
acked := s.sndUna.Size(ack)
s.sndUna = ack
ackLeft := acked
- originalOutsanding := s.outstanding
+ originalOutstanding := s.outstanding
for ackLeft > 0 {
// We use logicalLen here because we can have FIN
// segments (which are always at the end of list) that
@@ -492,9 +642,11 @@
// Update the send buffer usage and notify potential waiters.
s.ep.updateSndBufferUsage(int(acked))
- // Update the congestion window based on the number of
- // acknowledged packets.
- s.updateCwnd(originalOutsanding - s.outstanding)
+ // If we are not in fast recovery then update the congestion
+ // window based on the number of acknowledged packets.
+ if !s.fr.active {
+ s.cc.Update(originalOutstanding - s.outstanding)
+ }
// It is possible for s.outstanding to drop below zero if we get
// a retransmit timeout, reset outstanding to zero but later
@@ -519,7 +671,7 @@
// sendSegment sends a new segment containing the given payload, flags and
// sequence number.
-func (s *sender) sendSegment(data *buffer.VectorisedView, flags byte, seq seqnum.Value) *tcpip.Error {
+func (s *sender) sendSegment(data buffer.VectorisedView, flags byte, seq seqnum.Value) *tcpip.Error {
s.lastSendTime = time.Now()
if seq == s.rttMeasureSeqNum {
s.rttMeasureTime = s.lastSendTime
@@ -530,13 +682,5 @@
// Remember the max sent ack.
s.maxSentAck = rcvNxt
- if data == nil {
- return s.ep.sendRaw(nil, flags, seq, rcvNxt, rcvWnd)
- }
-
- if len(data.Views()) > 1 {
- panic("send path does not support views with multiple buffers")
- }
-
- return s.ep.sendRaw(data.First(), flags, seq, rcvNxt, rcvWnd)
+ return s.ep.sendRaw(data, flags, seq, rcvNxt, rcvWnd)
}
diff --git a/tcpip/transport/tcp/tcp_sack_test.go b/tcpip/transport/tcp/tcp_sack_test.go
new file mode 100644
index 0000000..b1ec006
--- /dev/null
+++ b/tcpip/transport/tcp/tcp_sack_test.go
@@ -0,0 +1,350 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_test
+
+import (
+ "fmt"
+ "reflect"
+ "testing"
+
+ "github.com/google/netstack/tcpip/header"
+ "github.com/google/netstack/tcpip/seqnum"
+ "github.com/google/netstack/tcpip/transport/tcp"
+ "github.com/google/netstack/tcpip/transport/tcp/testing/context"
+)
+
+// createConnectedWithSACKPermittedOption creates and connects c.ep with the
+// SACKPermitted option enabled if the stack in the context has the SACK support
+// enabled.
+func createConnectedWithSACKPermittedOption(c *context.Context) *context.RawEndpoint {
+ return c.CreateConnectedWithOptions(header.TCPSynOptions{SACKPermitted: c.SACKEnabled()})
+}
+
+func setStackSACKPermitted(t *testing.T, c *context.Context, enable bool) {
+ t.Helper()
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enable)); err != nil {
+ t.Fatalf("c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(%v)) = %v", enable, err)
+ }
+}
+
+// TestSackPermittedConnect establishes a connection with the SACK option
+// enabled.
+func TestSackPermittedConnect(t *testing.T) {
+ for _, sackEnabled := range []bool{false, true} {
+ t.Run(fmt.Sprintf("stack.sackEnabled: %v", sackEnabled), func(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ setStackSACKPermitted(t, c, sackEnabled)
+ rep := createConnectedWithSACKPermittedOption(c)
+ data := []byte{1, 2, 3}
+
+ rep.SendPacket(data, nil)
+ savedSeqNum := rep.NextSeqNum
+ rep.VerifyACKNoSACK()
+
+ // Make an out of order packet and send it.
+ rep.NextSeqNum += 3
+ sackBlocks := []header.SACKBlock{
+ {rep.NextSeqNum, rep.NextSeqNum.Add(seqnum.Size(len(data)))},
+ }
+ rep.SendPacket(data, nil)
+
+ // Restore the saved sequence number so that the
+ // VerifyXXX calls use the right sequence number for
+ // checking ACK numbers.
+ rep.NextSeqNum = savedSeqNum
+ if sackEnabled {
+ rep.VerifyACKHasSACK(sackBlocks)
+ } else {
+ rep.VerifyACKNoSACK()
+ }
+
+ // Send the missing segment.
+ rep.SendPacket(data, nil)
+ // The ACK should contain the cumulative ACK for all 9
+ // bytes sent and no SACK blocks.
+ rep.NextSeqNum += 3
+ // Check that no SACK block is returned in the ACK.
+ rep.VerifyACKNoSACK()
+ })
+ }
+}
+
+// TestSackDisabledConnect establishes a connection with the SACK option
+// disabled and verifies that no SACKs are sent for out of order segments.
+func TestSackDisabledConnect(t *testing.T) {
+ for _, sackEnabled := range []bool{false, true} {
+ t.Run(fmt.Sprintf("sackEnabled: %v", sackEnabled), func(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ setStackSACKPermitted(t, c, sackEnabled)
+
+ rep := c.CreateConnectedWithOptions(header.TCPSynOptions{})
+
+ data := []byte{1, 2, 3}
+
+ rep.SendPacket(data, nil)
+ savedSeqNum := rep.NextSeqNum
+ rep.VerifyACKNoSACK()
+
+ // Make an out of order packet and send it.
+ rep.NextSeqNum += 3
+ rep.SendPacket(data, nil)
+
+ // The ACK should contain the older sequence number and
+ // no SACK blocks.
+ rep.NextSeqNum = savedSeqNum
+ rep.VerifyACKNoSACK()
+
+ // Send the missing segment.
+ rep.SendPacket(data, nil)
+ // The ACK should contain the cumulative ACK for all 9
+ // bytes sent and no SACK blocks.
+ rep.NextSeqNum += 3
+ // Check that no SACK block is returned in the ACK.
+ rep.VerifyACKNoSACK()
+ })
+ }
+}
+
+// TestSackPermittedAccept accepts and establishes a connection with the
+// SACKPermitted option enabled if the connection request specifies the
+// SACKPermitted option. In case of SYN cookies SACK should be disabled as we
+// don't encode the SACK information in the cookie.
+func TestSackPermittedAccept(t *testing.T) {
+ type testCase struct {
+ cookieEnabled bool
+ sackPermitted bool
+ wndScale int
+ wndSize uint16
+ }
+
+ testCases := []testCase{
+ // Fields: {cookieEnabled, sackPermitted, wndScale, wndSize}.
+ {true, false, -1, 0xffff}, // When cookie is used window scaling is disabled.
+ {false, true, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default).
+ }
+ savedSynCountThreshold := tcp.SynRcvdCountThreshold
+ defer func() {
+ tcp.SynRcvdCountThreshold = savedSynCountThreshold
+ }()
+ for _, tc := range testCases {
+ t.Run(fmt.Sprintf("test: %#v", tc), func(t *testing.T) {
+ if tc.cookieEnabled {
+ tcp.SynRcvdCountThreshold = 0
+ } else {
+ tcp.SynRcvdCountThreshold = savedSynCountThreshold
+ }
+ for _, sackEnabled := range []bool{false, true} {
+ t.Run(fmt.Sprintf("test stack.sackEnabled: %v", sackEnabled), func(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+ setStackSACKPermitted(t, c, sackEnabled)
+
+ rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, SACKPermitted: tc.sackPermitted})
+ // Now verify no SACK blocks are
+ // received when sack is disabled.
+ data := []byte{1, 2, 3}
+ rep.SendPacket(data, nil)
+ rep.VerifyACKNoSACK()
+
+ savedSeqNum := rep.NextSeqNum
+
+ // Make an out of order packet and send
+ // it.
+ rep.NextSeqNum += 3
+ sackBlocks := []header.SACKBlock{
+ {rep.NextSeqNum, rep.NextSeqNum.Add(seqnum.Size(len(data)))},
+ }
+ rep.SendPacket(data, nil)
+
+ // The ACK should contain the older
+ // sequence number.
+ rep.NextSeqNum = savedSeqNum
+ if sackEnabled && tc.sackPermitted {
+ rep.VerifyACKHasSACK(sackBlocks)
+ } else {
+ rep.VerifyACKNoSACK()
+ }
+
+ // Send the missing segment.
+ rep.SendPacket(data, nil)
+ // The ACK should contain the cumulative
+ // ACK for all 9 bytes sent and no SACK
+ // blocks.
+ rep.NextSeqNum += 3
+ // Check that no SACK block is returned
+ // in the ACK.
+ rep.VerifyACKNoSACK()
+ })
+ }
+ })
+ }
+}
+
+// TestSackDisabledAccept accepts and establishes a connection with
+// the SACKPermitted option disabled and verifies that no SACKs are
+// sent for out of order packets.
+func TestSackDisabledAccept(t *testing.T) {
+ type testCase struct {
+ cookieEnabled bool
+ wndScale int
+ wndSize uint16
+ }
+
+ testCases := []testCase{
+ // Fields: {cookieEnabled, wndScale, wndSize}.
+ {true, -1, 0xffff}, // When cookie is used window scaling is disabled.
+ {false, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default).
+ }
+ savedSynCountThreshold := tcp.SynRcvdCountThreshold
+ defer func() {
+ tcp.SynRcvdCountThreshold = savedSynCountThreshold
+ }()
+ for _, tc := range testCases {
+ t.Run(fmt.Sprintf("test: %#v", tc), func(t *testing.T) {
+ if tc.cookieEnabled {
+ tcp.SynRcvdCountThreshold = 0
+ } else {
+ tcp.SynRcvdCountThreshold = savedSynCountThreshold
+ }
+ for _, sackEnabled := range []bool{false, true} {
+ t.Run(fmt.Sprintf("test: sackEnabled: %v", sackEnabled), func(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+ setStackSACKPermitted(t, c, sackEnabled)
+
+ rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS})
+
+ // Now verify no SACK blocks are
+ // received when sack is disabled.
+ data := []byte{1, 2, 3}
+ rep.SendPacket(data, nil)
+ rep.VerifyACKNoSACK()
+ savedSeqNum := rep.NextSeqNum
+
+ // Make an out of order packet and send
+ // it.
+ rep.NextSeqNum += 3
+ rep.SendPacket(data, nil)
+
+ // The ACK should contain the older
+ // sequence number and no SACK blocks.
+ rep.NextSeqNum = savedSeqNum
+ rep.VerifyACKNoSACK()
+
+ // Send the missing segment.
+ rep.SendPacket(data, nil)
+ // The ACK should contain the cumulative
+ // ACK for all 9 bytes sent and no SACK
+ // blocks.
+ rep.NextSeqNum += 3
+ // Check that no SACK block is returned
+ // in the ACK.
+ rep.VerifyACKNoSACK()
+ })
+ }
+ })
+ }
+}
+
+func TestUpdateSACKBlocks(t *testing.T) {
+ testCases := []struct {
+ segStart seqnum.Value
+ segEnd seqnum.Value
+ rcvNxt seqnum.Value
+ sackBlocks []header.SACKBlock
+ updated []header.SACKBlock
+ }{
+ // Trivial cases where current SACK block list is empty and we
+ // have an out of order delivery.
+ {10, 11, 2, []header.SACKBlock{}, []header.SACKBlock{{10, 11}}},
+ {10, 12, 2, []header.SACKBlock{}, []header.SACKBlock{{10, 12}}},
+ {10, 20, 2, []header.SACKBlock{}, []header.SACKBlock{{10, 20}}},
+
+ // Cases where current SACK block list is not empty and we have
+ // an out of order delivery. Tests that the updated SACK block
+ // list has the first block as the one that contains the new
+ // SACK block representing the segment that was just delivered.
+ {10, 11, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 11}, {12, 20}}},
+ {24, 30, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{24, 30}, {12, 20}}},
+ {24, 30, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{24, 30}, {12, 20}, {32, 40}}},
+
+ // Ensure that we only retain tcp.MaxSACKBlocks blocks and drop
+ // the oldest one if adding a new block would exceed
+ // tcp.MaxSACKBlocks.
+ {24, 30, 9,
+ []header.SACKBlock{{12, 20}, {32, 40}, {42, 50}, {52, 60}, {62, 70}, {72, 80}},
+ []header.SACKBlock{{24, 30}, {12, 20}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}},
+
+ // Cases where segment extends an existing SACK block.
+ {10, 12, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 20}}},
+ {10, 22, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 22}}},
+ {10, 22, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 22}}},
+ {15, 22, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{12, 22}}},
+ {15, 25, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{12, 25}}},
+ {11, 25, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{11, 25}}},
+ {10, 12, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{10, 20}, {32, 40}}},
+ {10, 22, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{10, 22}, {32, 40}}},
+ {10, 22, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{10, 22}, {32, 40}}},
+ {15, 22, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{12, 22}, {32, 40}}},
+ {15, 25, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{12, 25}, {32, 40}}},
+ {11, 25, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{11, 25}, {32, 40}}},
+
+ // Cases where segment contains rcvNxt.
+ {10, 20, 15, []header.SACKBlock{{20, 30}, {40, 50}}, []header.SACKBlock{{40, 50}}},
+ }
+
+ for _, tc := range testCases {
+ var sack tcp.SACKInfo
+ copy(sack.Blocks[:], tc.sackBlocks)
+ sack.NumBlocks = len(tc.sackBlocks)
+ tcp.UpdateSACKBlocks(&sack, tc.segStart, tc.segEnd, tc.rcvNxt)
+ if got, want := sack.Blocks[:sack.NumBlocks], tc.updated; !reflect.DeepEqual(got, want) {
+ t.Errorf("UpdateSACKBlocks(%v, %v, %v, %v), got: %v, want: %v", tc.sackBlocks, tc.segStart, tc.segEnd, tc.rcvNxt, got, want)
+ }
+
+ }
+}
+
+func TestTrimSackBlockList(t *testing.T) {
+ testCases := []struct {
+ rcvNxt seqnum.Value
+ sackBlocks []header.SACKBlock
+ trimmed []header.SACKBlock
+ }{
+ // Simple cases where we trim whole entries.
+ {2, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}},
+ {21, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{22, 30}, {32, 40}}},
+ {31, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{32, 40}}},
+ {40, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{}},
+ // Cases where we need to update a block.
+ {12, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{12, 20}, {22, 30}, {32, 40}}},
+ {23, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{23, 30}, {32, 40}}},
+ {33, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{33, 40}}},
+ {41, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{}},
+ }
+ for _, tc := range testCases {
+ var sack tcp.SACKInfo
+ copy(sack.Blocks[:], tc.sackBlocks)
+ sack.NumBlocks = len(tc.sackBlocks)
+ tcp.TrimSACKBlockList(&sack, tc.rcvNxt)
+ if got, want := sack.Blocks[:sack.NumBlocks], tc.trimmed; !reflect.DeepEqual(got, want) {
+ t.Errorf("TrimSACKBlockList(%v, %v), got: %v, want: %v", tc.sackBlocks, tc.rcvNxt, got, want)
+ }
+ }
+}
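`TestTrimSackBlockList` defines the trimming rule. A block whose end is at or below `rcvNxt` is dropped. A block that `rcvNxt` falls inside has its start advanced to `rcvNxt`. The minimal Go sketch below implements that rule and assumes a hypothetical `sackBlock` type with plain `uint32` edges. Unlike `seqnum.Value`, these edges ignore sequence-number wraparound. The sketch is not netstack's `TrimSACKBlockList`.

```go
package main

// sackBlock is a hypothetical stand-in for header.SACKBlock. Edges are plain
// uint32 values, so sequence-number wraparound is ignored in this sketch.
type sackBlock struct {
	start, end uint32
}

// trimSACKBlocks returns the blocks still above rcvNxt, preserving order:
// blocks that rcvNxt has fully passed are dropped, and a block that rcvNxt
// falls inside is clipped so that it starts at rcvNxt.
func trimSACKBlocks(blocks []sackBlock, rcvNxt uint32) []sackBlock {
	out := make([]sackBlock, 0, len(blocks))
	for _, b := range blocks {
		if b.end <= rcvNxt {
			continue // Fully covered by the cumulative ACK.
		}
		if b.start < rcvNxt {
			b.start = rcvNxt // Partially covered; clip the left edge.
		}
		out = append(out, b)
	}
	return out
}
```

The loop modifies a copy of each element, so the input slice is left unchanged. The function under test instead updates the `SACKInfo` it is handed by pointer.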
diff --git a/tcpip/transport/tcp/tcp_segment_list.go b/tcpip/transport/tcp/tcp_segment_list.go
index 0bbd21b..22491b0 100644
--- a/tcpip/transport/tcp/tcp_segment_list.go
+++ b/tcpip/transport/tcp/tcp_segment_list.go
@@ -1,8 +1,3 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-// Package ilist provides the implementation of intrusive linked lists.
package tcp
// List is an intrusive list. Entries can be added to or removed from the list
diff --git a/tcpip/transport/tcp/tcp_test.go b/tcpip/transport/tcp/tcp_test.go
index b2c27a9..d08af9d 100644
--- a/tcpip/transport/tcp/tcp_test.go
+++ b/tcpip/transport/tcp/tcp_test.go
@@ -1,11 +1,23 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package tcp_test
import (
"bytes"
+ "fmt"
+ "math"
"testing"
"time"
@@ -13,8 +25,13 @@
"github.com/google/netstack/tcpip/buffer"
"github.com/google/netstack/tcpip/checker"
"github.com/google/netstack/tcpip/header"
+ "github.com/google/netstack/tcpip/link/loopback"
+ "github.com/google/netstack/tcpip/link/sniffer"
"github.com/google/netstack/tcpip/network/ipv4"
+ "github.com/google/netstack/tcpip/network/ipv6"
+ "github.com/google/netstack/tcpip/ports"
"github.com/google/netstack/tcpip/seqnum"
+ "github.com/google/netstack/tcpip/stack"
"github.com/google/netstack/tcpip/transport/tcp"
"github.com/google/netstack/tcpip/transport/tcp/testing/context"
"github.com/google/netstack/waiter"
@@ -38,7 +55,7 @@
var wq waiter.Queue
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
if err != nil {
- t.Fatalf("NeEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %v", err)
}
// Register for notification, then start connection attempt.
@@ -46,9 +63,8 @@
wq.EventRegister(&waitEntry, waiter.EventOut)
defer wq.EventUnregister(&waitEntry)
- err = ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
- if err != tcpip.ErrConnectStarted {
- t.Fatalf("Unexpected return value from Connect: %v", err)
+ if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
+ t.Fatalf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted)
}
// Close the connection, wait for completion.
@@ -56,19 +72,21 @@
// Wait for ep to become writable.
<-notifyCh
- err = ep.GetSockOpt(tcpip.ErrorOption{})
+ if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != tcpip.ErrAborted {
+ t.Fatalf("got ep.GetSockOpt(tcpip.ErrorOption{}) = %v, want = %v", err, tcpip.ErrAborted)
+ }
}
func TestConnectIncrementActiveConnection(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- stats := c.Stack().MutableStats()
- expected := stats.TCP.ActiveConnectionOpenings + 1
+ stats := c.Stack().Stats()
+ want := stats.TCP.ActiveConnectionOpenings.Value() + 1
c.CreateConnected(789, 30000, nil)
- if actual := stats.TCP.ActiveConnectionOpenings; actual != expected {
- t.Errorf("Expected ActiveConnectionOpenings to be %d, got %d", expected, actual)
+ if got := stats.TCP.ActiveConnectionOpenings.Value(); got != want {
+ t.Errorf("got stats.TCP.ActiveConnectionOpenings.Value() = %v, want = %v", got, want)
}
}
@@ -76,12 +94,12 @@
c := context.New(t, defaultMTU)
defer c.Cleanup()
- stats := c.Stack().MutableStats()
- expected := stats.TCP.FailedConnectionAttempts
+ stats := c.Stack().Stats()
+ want := stats.TCP.FailedConnectionAttempts.Value()
c.CreateConnected(789, 30000, nil)
- if actual := stats.TCP.FailedConnectionAttempts; actual != expected {
- t.Errorf("Expected FailedConnectionAttempts to be %d, got %d", expected, actual)
+ if got := stats.TCP.FailedConnectionAttempts.Value(); got != want {
+ t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = %v", got, want)
}
}
@@ -89,18 +107,20 @@
c := context.New(t, defaultMTU)
defer c.Cleanup()
- stats := c.Stack().MutableStats()
+ stats := c.Stack().Stats()
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
c.EP = ep
- expected := stats.TCP.FailedConnectionAttempts + 1
+ want := stats.TCP.FailedConnectionAttempts.Value() + 1
- err = c.EP.Connect(tcpip.FullAddress{NIC: 2, Addr: context.TestAddr, Port: context.TestPort})
- if err != tcpip.ErrNoRoute {
- t.Errorf("Expected call to Connect after Bind to result in ErrNoRoute: %v", err)
+ if err := c.EP.Connect(tcpip.FullAddress{NIC: 2, Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrNoRoute {
+ t.Errorf("got c.EP.Connect(...) = %v, want = %v", err, tcpip.ErrNoRoute)
}
- if actual := stats.TCP.FailedConnectionAttempts; actual != expected {
- t.Fatalf("Expected FailedConnectionAttempts to be %d, got %d", expected, actual)
+ if got := stats.TCP.FailedConnectionAttempts.Value(); got != want {
+ t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = %v", got, want)
}
}
@@ -108,21 +128,22 @@
c := context.New(t, defaultMTU)
defer c.Cleanup()
- stats := c.Stack().MutableStats()
- expected := stats.TCP.PassiveConnectionOpenings + 1
+ stats := c.Stack().Stats()
+ want := stats.TCP.PassiveConnectionOpenings.Value() + 1
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
-
- err = ep.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}, nil)
if err != nil {
- t.Fatalf("Err in Bind: %v", err)
- }
- err = ep.Listen(1)
- if err != nil {
- t.Fatalf("Err in Listen: %v", err)
+ t.Fatalf("NewEndpoint failed: %v", err)
}
- if actual := stats.TCP.PassiveConnectionOpenings; actual != expected {
- t.Fatalf("Expected PassiveConnectionOpenings to be %d, got %d", expected, actual)
+ if err := ep.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}, nil); err != nil {
+ t.Fatalf("Bind failed: %v", err)
+ }
+ if err := ep.Listen(1); err != nil {
+ t.Fatalf("Listen failed: %v", err)
+ }
+
+ if got := stats.TCP.PassiveConnectionOpenings.Value(); got != want {
+ t.Errorf("got stats.TCP.PassiveConnectionOpenings.Value() = %v, want = %v", got, want)
}
}
@@ -130,18 +151,20 @@
c := context.New(t, defaultMTU)
defer c.Cleanup()
- stats := c.Stack().MutableStats()
+ stats := c.Stack().Stats()
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
c.EP = ep
- expected := stats.TCP.FailedConnectionAttempts + 1
+ want := stats.TCP.FailedConnectionAttempts.Value() + 1
- err = ep.Listen(1)
- if err != tcpip.ErrInvalidEndpointState {
- t.Fatalf("Expected call to Listen without Bind to result in ErrInvalidEndpointState: %v", err)
+ if err := ep.Listen(1); err != tcpip.ErrInvalidEndpointState {
+ t.Errorf("got ep.Listen(1) = %v, want = %v", err, tcpip.ErrInvalidEndpointState)
}
- if actual := stats.TCP.FailedConnectionAttempts; actual != expected {
- t.Fatalf("Expected FailedConnectionAttempts to be %d, got %d", expected, actual)
+ if got := stats.TCP.FailedConnectionAttempts.Value(); got != want {
+ t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = %v", got, want)
}
}
@@ -149,26 +172,26 @@
c := context.New(t, defaultMTU)
defer c.Cleanup()
- stats := c.Stack().MutableStats()
+ stats := c.Stack().Stats()
// SYN and ACK
- expected := stats.TCP.SegmentsSent + 2
+ want := stats.TCP.SegmentsSent.Value() + 2
c.CreateConnected(789, 30000, nil)
- if actual := stats.TCP.SegmentsSent; actual != expected {
- t.Fatalf("Expected SegmentsSent to be %d, got %d", expected, actual)
+ if got := stats.TCP.SegmentsSent.Value(); got != want {
+ t.Errorf("got stats.TCP.SegmentsSent.Value() = %v, want = %v", got, want)
}
}
func TestTCPResetsSentIncrement(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- stats := c.Stack().MutableStats()
+ stats := c.Stack().Stats()
wq := &waiter.Queue{}
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
if err != nil {
t.Fatalf("NewEndpoint failed: %v", err)
}
- expected := stats.TCP.SegmentsSent + 1
+ want := stats.TCP.SegmentsSent.Value() + 1
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}, nil); err != nil {
t.Fatalf("Bind failed: %v", err)
@@ -206,8 +229,8 @@
c.SendPacket(nil, ackHeaders)
c.GetPacket()
- if actual := stats.TCP.ResetsSent; actual != expected {
- t.Fatalf("Expected SegmentsSent to be %d, got %d", expected, actual)
+ if got := stats.TCP.ResetsSent.Value(); got != want {
+ t.Errorf("got stats.TCP.ResetsSent.Value() = %v, want = %v", got, want)
}
}
@@ -215,8 +238,8 @@
c := context.New(t, defaultMTU)
defer c.Cleanup()
- stats := c.Stack().MutableStats()
- expected := stats.TCP.ResetsReceived + 1
+ stats := c.Stack().Stats()
+ want := stats.TCP.ResetsReceived.Value() + 1
ackNum := seqnum.Value(789)
rcvWnd := seqnum.Size(30000)
c.CreateConnected(ackNum, rcvWnd, nil)
@@ -230,8 +253,8 @@
Flags: header.TCPFlagRst,
})
- if actual := stats.TCP.ResetsReceived; actual != expected {
- t.Fatalf("Expected ResetsReceived to be %d, got %d", expected, actual)
+ if got := stats.TCP.ResetsReceived.Value(); got != want {
+ t.Errorf("got stats.TCP.ResetsReceived.Value() = %v, want = %v", got, want)
}
}
@@ -318,8 +341,8 @@
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
- if _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("Unexpected error from Read: %v", err)
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
}
data := []byte{1, 2, 3}
@@ -340,13 +363,13 @@
}
// Receive data.
- v, err := c.EP.Read(nil)
+ v, _, err := c.EP.Read(nil)
if err != nil {
- t.Fatalf("Unexpected error from Read: %v", err)
+ t.Fatalf("Read failed: %v", err)
}
- if bytes.Compare(data, v) != 0 {
- t.Fatalf("Data is different: expected %v, got %v", data, v)
+ if !bytes.Equal(data, v) {
+ t.Fatalf("got data = %v, want = %v", v, data)
}
// Check that ACK is received.
@@ -370,8 +393,8 @@
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
- if _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("Unexpected error from Read: %v", err)
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
}
// Send second half of data first, with seqnum 3 ahead of expected.
@@ -397,8 +420,8 @@
// Wait 200ms and check that no data has been received.
time.Sleep(200 * time.Millisecond)
- if _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("Unexpected error from Read: %v", err)
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
}
// Send the first 3 bytes now.
@@ -414,7 +437,7 @@
// Receive data.
read := make([]byte, 0, 6)
for len(read) < len(data) {
- v, err := c.EP.Read(nil)
+ v, _, err := c.EP.Read(nil)
if err != nil {
if err == tcpip.ErrWouldBlock {
// Wait for receive to be notified.
@@ -425,15 +448,15 @@
}
continue
}
- t.Fatalf("Unexpected error from Read: %v", err)
+ t.Fatalf("Read failed: %v", err)
}
read = append(read, v...)
}
// Check that we received the data in proper order.
- if bytes.Compare(data, read) != 0 {
- t.Fatalf("Data is different: expected %v, got %v", data, read)
+ if !bytes.Equal(data, read) {
+ t.Fatalf("got data = %v, want = %v", read, data)
}
// Check that the whole data is acknowledged.
@@ -455,8 +478,8 @@
opt := tcpip.ReceiveBufferSizeOption(10)
c.CreateConnected(789, 30000, &opt)
- if _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("Unexpected error from Read: %v", err)
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
}
// Send 100 packets before the actual one that is expected.
@@ -522,6 +545,148 @@
)
}
+func TestRstOnCloseWithUnreadData(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, nil)
+
+ we, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&we, waiter.EventIn)
+ defer c.WQ.EventUnregister(&we)
+
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ }
+
+ data := []byte{1, 2, 3}
+ c.SendPacket(data, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+
+ // Wait for receive to be notified.
+ select {
+ case <-ch:
+ case <-time.After(3 * time.Second):
+ t.Fatalf("Timed out waiting for data to arrive")
+ }
+
+ // Check that the ACK is received; this happens regardless of the read.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(790+len(data))),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+
+ // Now that we know we have unread data, let's just close the connection
+ // and verify that netstack sends an RST rather than a FIN.
+ c.EP.Close()
+
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
+ // We shouldn't consume a sequence number on RST.
+ checker.SeqNum(uint32(c.IRS)+1),
+ ))
+
+ // This final ACK should be ignored because an ACK on a reset doesn't
+ // mean anything.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: seqnum.Value(790 + len(data)),
+ AckNum: c.IRS.Add(seqnum.Size(2)),
+ RcvWnd: 30000,
+ })
+}
+
+func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, nil)
+
+ we, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&we, waiter.EventIn)
+ defer c.WQ.EventUnregister(&we)
+
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ }
+
+ data := []byte{1, 2, 3}
+ c.SendPacket(data, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+
+ // Wait for receive to be notified.
+ select {
+ case <-ch:
+ case <-time.After(3 * time.Second):
+ t.Fatalf("Timed out waiting for data to arrive")
+ }
+
+ // Check that the ACK is received; this happens regardless of the read.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(790+len(data))),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+
+ // Cause a FIN to be generated.
+ c.EP.Shutdown(tcpip.ShutdownWrite)
+
+ // Make sure we get the FIN, but DON'T ACK IT.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
+ checker.SeqNum(uint32(c.IRS)+1),
+ ))
+
+ // Cause an RST to be generated by closing the read end now, since we
+ // have unread data.
+ c.EP.Shutdown(tcpip.ShutdownRead)
+
+ // Make sure we get the RST.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
+ // We shouldn't consume a sequence number on RST.
+ checker.SeqNum(uint32(c.IRS)+1),
+ ))
+
+ // The ACK to the FIN should now be rejected since the connection has been
+ // closed by an RST.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: seqnum.Value(790 + len(data)),
+ AckNum: c.IRS.Add(seqnum.Size(2)),
+ RcvWnd: 30000,
+ })
+}
+
func TestFullWindowReceive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -533,9 +698,9 @@
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
- _, err := c.EP.Read(nil)
+ _, _, err := c.EP.Read(nil)
if err != tcpip.ErrWouldBlock {
- t.Fatalf("Unexpected error from Read: %v", err)
+ t.Fatalf("Read failed: %v", err)
}
// Fill up the window.
@@ -568,13 +733,13 @@
)
// Receive data and check it.
- v, err := c.EP.Read(nil)
+ v, _, err := c.EP.Read(nil)
if err != nil {
- t.Fatalf("Unexpected error from Read: %v", err)
+ t.Fatalf("Read failed: %v", err)
}
- if bytes.Compare(data, v) != 0 {
- t.Fatalf("Data is different: expected %v, got %v", data, v)
+ if !bytes.Equal(data, v) {
+ t.Fatalf("got data = %v, want = %v", v, data)
}
// Check that we get an ACK for the newly non-zero window.
@@ -606,9 +771,8 @@
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
- _, err := c.EP.Read(nil)
- if err != tcpip.ErrWouldBlock {
- t.Fatalf("Unexpected error from Read: %v", err)
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
}
// Send 3 bytes, check that the peer acknowledges them.
@@ -670,16 +834,16 @@
// Receive data and check it.
read := make([]byte, 0, 10)
for len(read) < len(data) {
- v, err := c.EP.Read(nil)
+ v, _, err := c.EP.Read(nil)
if err != nil {
- t.Fatalf("Unexpected error from Read: %v", err)
+ t.Fatalf("Read failed: %v", err)
}
read = append(read, v...)
}
- if bytes.Compare(data, read) != 0 {
- t.Fatalf("Data is different: expected %v, got %v", data, read)
+ if !bytes.Equal(data, read) {
+ t.Fatalf("got data = %v, want = %v", read, data)
}
// Check that we get an ACK for the newly non-zero window, which is the
@@ -705,8 +869,8 @@
view := buffer.NewView(len(data))
copy(view, data)
- if _, err := c.EP.Write(view, nil); err != nil {
- t.Fatalf("Unexpected error from Write: %v", err)
+ if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %v", err)
}
// Check that data is received.
@@ -721,8 +885,8 @@
),
)
- if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; bytes.Compare(data, p) != 0 {
- t.Fatalf("Data is different: expected %v, got %v", data, p)
+ if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) {
+ t.Fatalf("got data = %v, want = %v", p, data)
}
// Acknowledge the data.
@@ -746,9 +910,9 @@
view := buffer.NewView(len(data))
copy(view, data)
- _, err := c.EP.Write(view, nil)
+ _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{})
if err != nil {
- t.Fatalf("Unexpected error from Write: %v", err)
+ t.Fatalf("Write failed: %v", err)
}
// Since the window is currently zero, check that no packet is received.
@@ -776,8 +940,8 @@
),
)
- if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; bytes.Compare(data, p) != 0 {
- t.Fatalf("Data is different: expected %v, got %v", data, p)
+ if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) {
+ t.Fatalf("got data = %v, want = %v", p, data)
}
// Acknowledge the data.
@@ -807,8 +971,8 @@
view := buffer.NewView(len(data))
copy(view, data)
- if _, err := c.EP.Write(view, nil); err != nil {
- t.Fatalf("Unexpected error from Write: %v", err)
+ if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %v", err)
}
// Check that data is received, and that advertised window is 0xbfff,
@@ -840,8 +1004,8 @@
view := buffer.NewView(len(data))
copy(view, data)
- if _, err := c.EP.Write(view, nil); err != nil {
- t.Fatalf("Unexpected error from Write: %v", err)
+ if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %v", err)
}
// Check that data is received, and that advertised window is 0xffff,
@@ -913,8 +1077,8 @@
view := buffer.NewView(len(data))
copy(view, data)
- if _, err := c.EP.Write(view, nil); err != nil {
- t.Fatalf("Unexpected error from Write: %v", err)
+ if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %v", err)
}
// Check that data is received, and that advertised window is 0xbfff,
@@ -986,8 +1150,8 @@
view := buffer.NewView(len(data))
copy(view, data)
- if _, err := c.EP.Write(view, nil); err != nil {
- t.Fatalf("Unexpected error from Write: %v", err)
+ if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %v", err)
}
// Check that data is received, and that advertised window is 0xffff,
@@ -1073,9 +1237,9 @@
}
// Read some data. An ack should be sent in response to that.
- v, err := c.EP.Read(nil)
+ v, _, err := c.EP.Read(nil)
if err != nil {
- t.Fatalf("Unexpected error from Read: %v", err)
+ t.Fatalf("Read failed: %v", err)
}
checker.IPv4(t, c.GetPacket(),
@@ -1101,8 +1265,8 @@
view := buffer.NewView(len(data))
copy(view, data)
- if _, err := c.EP.Write(view, nil); err != nil {
- t.Fatalf("Unexpected error from Write: %v", err)
+ if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %v", err)
}
// Check that data is received in chunks.
@@ -1123,8 +1287,8 @@
)
pdata := data[bytesReceived : bytesReceived+payloadLen]
- if p := tcp.Payload(); bytes.Compare(pdata, p) != 0 {
- t.Fatalf("Data is different: expected %v, got %v", pdata, p)
+ if p := tcp.Payload(); !bytes.Equal(pdata, p) {
+ t.Fatalf("got data = %v, want = %v", p, pdata)
}
bytesReceived += payloadLen
var options []byte
@@ -1132,8 +1296,9 @@
// If timestamp option is enabled, echo back the timestamp and increment
// the TSEcr value included in the packet and send that back as the TSVal.
parsedOpts := tcp.ParsedOptions()
- tsOpt := header.EncodeTSOption(parsedOpts.TSEcr+1, parsedOpts.TSVal)
- options = append(options, tsOpt[:]...)
+ tsOpt := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP}
+ header.EncodeTSOption(parsedOpts.TSEcr+1, parsedOpts.TSVal, tsOpt[2:])
+ options = tsOpt[:]
}
// Acknowledge the data.
c.SendPacket(nil, &context.Headers{
@@ -1342,9 +1507,8 @@
c.WQ.EventRegister(&we, waiter.EventOut)
defer c.WQ.EventUnregister(&we)
- err = c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
- if err != tcpip.ErrConnectStarted {
- t.Fatalf("Unexpected return value from Connect: %v", err)
+ if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
+ t.Fatalf("got c.EP.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted)
}
// Receive SYN packet.
@@ -1397,9 +1561,8 @@
// Wait for connection to be established.
select {
case <-ch:
- err = c.EP.GetSockOpt(tcpip.ErrorOption{})
- if err != nil {
- t.Fatalf("Unexpected error when connecting: %v", err)
+ if err := c.EP.GetSockOpt(tcpip.ErrorOption{}); err != nil {
+ t.Fatalf("GetSockOpt failed: %v", err)
}
case <-time.After(1 * time.Second):
t.Fatalf("Timed out waiting for connection")
@@ -1455,9 +1618,7 @@
loop:
for {
- switch _, err := c.EP.Read(nil); err {
- case nil:
- t.Fatalf("Unexpected success.")
+ switch _, _, err := c.EP.Read(nil); err {
case tcpip.ErrWouldBlock:
select {
case <-ch:
@@ -1467,7 +1628,7 @@
case tcpip.ErrConnectionReset:
break loop
default:
- t.Fatalf("Unexpected error: want %v, got %v", tcpip.ErrConnectionReset, err)
+ t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrConnectionReset)
}
}
}
@@ -1492,9 +1653,8 @@
// Try to write.
view := buffer.NewView(10)
- _, err := c.EP.Write(view, nil)
- if err != tcpip.ErrConnectionReset {
- t.Fatalf("Unexpected error from Write: want %v, got %v", tcpip.ErrConnectionReset, err)
+ if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != tcpip.ErrConnectionReset {
+ t.Fatalf("got c.EP.Write(...) = %v, want = %v", err, tcpip.ErrConnectionReset)
}
}
@@ -1506,7 +1666,7 @@
// Shutdown immediately, check that we get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Unexpected error from Shutdown: %v", err)
+ t.Fatalf("Shutdown failed: %v", err)
}
checker.IPv4(t, c.GetPacket(),
@@ -1549,7 +1709,7 @@
// Shutdown immediately, check that we get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Unexpected error from Shutdown: %v", err)
+ t.Fatalf("Shutdown failed: %v", err)
}
checker.IPv4(t, c.GetPacket(),
@@ -1603,8 +1763,8 @@
// Write something out, and have it acknowledged.
view := buffer.NewView(10)
- if _, err := c.EP.Write(view, nil); err != nil {
- t.Fatalf("Unexpected error from Write: %v", err)
+ if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %v", err)
}
next := uint32(c.IRS) + 1
@@ -1630,7 +1790,7 @@
// Shutdown, check that we get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Unexpected error from Shutdown: %v", err)
+ t.Fatalf("Shutdown failed: %v", err)
}
checker.IPv4(t, c.GetPacket(),
@@ -1672,36 +1832,41 @@
c.CreateConnected(789, 30000, nil)
- // Write something out but don't ACK it yet.
+ // Write enough segments to fill the congestion window before ACK'ing
+ // any of them.
view := buffer.NewView(10)
- if _, err := c.EP.Write(view, nil); err != nil {
- t.Fatalf("Unexpected error from Write: %v", err)
+ for i := tcp.InitialCwnd; i > 0; i-- {
+ if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %v", err)
+ }
}
next := uint32(c.IRS) + 1
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(len(view)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
- ),
- )
- next += uint32(len(view))
+ for i := tcp.InitialCwnd; i > 0; i-- {
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(len(view)+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(next),
+ checker.AckNum(790),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+ next += uint32(len(view))
+ }
// Shutdown the connection, check that the FIN segment isn't sent
// because the congestion window doesn't allow it. Wait until a
// retransmit is received.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Unexpected error from Shutdown: %v", err)
+ t.Fatalf("Shutdown failed: %v", err)
}
checker.IPv4(t, c.GetPacket(),
checker.PayloadLen(len(view)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next-uint32(len(view))),
+ checker.SeqNum(uint32(c.IRS)+1),
checker.AckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
@@ -1757,8 +1922,8 @@
// Write something out, and acknowledge it to get cwnd to 2.
view := buffer.NewView(10)
- if _, err := c.EP.Write(view, nil); err != nil {
- t.Fatalf("Unexpected error from Write: %v", err)
+ if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %v", err)
}
next := uint32(c.IRS) + 1
@@ -1783,8 +1948,8 @@
})
// Write new data, but don't acknowledge it.
- if _, err := c.EP.Write(view, nil); err != nil {
- t.Fatalf("Unexpected error from Write: %v", err)
+ if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %v", err)
}
checker.IPv4(t, c.GetPacket(),
@@ -1800,7 +1965,7 @@
// Shutdown the connection, check that we do get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Unexpected error from Shutdown: %v", err)
+ t.Fatalf("Shutdown failed: %v", err)
}
checker.IPv4(t, c.GetPacket(),
@@ -1844,8 +2009,8 @@
// Write something out, and acknowledge it to get cwnd to 2. Also send
// FIN from the test side.
view := buffer.NewView(10)
- if _, err := c.EP.Write(view, nil); err != nil {
- t.Fatalf("Unexpected error from Write: %v", err)
+ if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %v", err)
}
next := uint32(c.IRS) + 1
@@ -1881,8 +2046,8 @@
)
// Write new data, but don't acknowledge it.
- if _, err := c.EP.Write(view, nil); err != nil {
- t.Fatalf("Unexpected error from Write: %v", err)
+ if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %v", err)
}
checker.IPv4(t, c.GetPacket(),
@@ -1898,7 +2063,7 @@
// Shutdown the connection, check that we do get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Unexpected error from Shutdown: %v", err)
+ t.Fatalf("Shutdown failed: %v", err)
}
checker.IPv4(t, c.GetPacket(),
@@ -1944,18 +2109,18 @@
c.CreateConnected(789, 30000, nil)
const iterations = 7
- data := buffer.NewView(maxPayload * (1 << (iterations + 1)))
+ data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1)))
for i := range data {
data[i] = byte(i)
}
// Write all the data in one shot. Packets will only be written at the
// MTU size though.
- if _, err := c.EP.Write(data, nil); err != nil {
- t.Fatalf("Unexpected error from Write: %v", err)
+ if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %v", err)
}
- expected := 1
+ expected := tcp.InitialCwnd
bytesRead := 0
for i := 0; i < iterations; i++ {
// Read all packets expected on this iteration. Don't
@@ -1986,22 +2151,22 @@
c.CreateConnected(789, 30000, nil)
const iterations = 7
- data := buffer.NewView(2 * maxPayload * (1 << (iterations + 1)))
+ data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
for i := range data {
data[i] = byte(i)
}
// Write all the data in one shot. Packets will only be written at the
// MTU size though.
- if _, err := c.EP.Write(data, nil); err != nil {
- t.Fatalf("Unexpected error from Write: %v", err)
+ if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %v", err)
}
// Do slow start for a few iterations.
- expected := 1
+ expected := tcp.InitialCwnd
bytesRead := 0
for i := 0; i < iterations; i++ {
- expected = 1 << uint(i)
+ expected = tcp.InitialCwnd << uint(i)
if i > 0 {
// Acknowledge all the data received so far if not on
// first iteration.
@@ -2018,7 +2183,7 @@
// Check we don't receive any more packets on this iteration.
// The timeout can't be too high or we'll trigger a timeout.
- c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond)
+ c.CheckNoPacketTimeout("More packets received than expected for this cwnd (slow start phase).", 50*time.Millisecond)
}
// Don't acknowledge the first packet of the last packet train. Let's
@@ -2056,7 +2221,7 @@
// Check we don't receive any more packets on this iteration.
// The timeout can't be too high or we'll trigger a timeout.
- c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond)
+ c.CheckNoPacketTimeout("More packets received than expected for this cwnd (congestion avoidance phase).", 50*time.Millisecond)
// Acknowledge all the data received so far.
c.SendAck(790, bytesRead)
@@ -2067,6 +2232,130 @@
}
}
+// cubicCwnd returns an estimate of a cubic window given the original
+// cwnd (origCwnd), wMax, the last congestion event time, and sRTT.
+func cubicCwnd(origCwnd int, wMax int, congEventTime time.Time, sRTT time.Duration) int {
+ cwnd := float64(origCwnd)
+ // We wait 50ms between each iteration so sRTT as computed by cubic
+ // should be close to 50ms.
+ elapsed := (time.Since(congEventTime) + sRTT).Seconds()
+ k := math.Cbrt(float64(wMax) * 0.3 / 0.7)
+ wtRTT := 0.4*math.Pow(elapsed-k, 3) + float64(wMax)
+ cwnd += (wtRTT - cwnd) / cwnd
+ return int(cwnd)
+}
+
+func TestCubicCongestionAvoidance(t *testing.T) {
+ maxPayload := 10
+ c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
+ defer c.Cleanup()
+
+ enableCUBIC(t, c)
+
+ c.CreateConnected(789, 30000, nil)
+
+ const iterations = 7
+ data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
+
+ for i := range data {
+ data[i] = byte(i)
+ }
+
+ // Write all the data in one shot. Packets will only be written at the
+ // MTU size though.
+ if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %v", err)
+ }
+
+ // Do slow start for a few iterations.
+ expected := tcp.InitialCwnd
+ bytesRead := 0
+ for i := 0; i < iterations; i++ {
+ expected = tcp.InitialCwnd << uint(i)
+ if i > 0 {
+ // Acknowledge all the data received so far if not on
+ // first iteration.
+ c.SendAck(790, bytesRead)
+ }
+
+ // Read all packets expected on this iteration. Don't
+ // acknowledge any of them just yet, so that we can measure the
+ // congestion window.
+ for j := 0; j < expected; j++ {
+ c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
+ bytesRead += maxPayload
+ }
+
+ // Check we don't receive any more packets on this iteration.
+ // The timeout can't be too high or we'll trigger a timeout.
+ c.CheckNoPacketTimeout("More packets received than expected for this cwnd (during slow-start phase).", 50*time.Millisecond)
+ }
+
+ // Don't acknowledge the first packet of the last packet train. Let's
+ // wait for it to time out, which will trigger a restart of slow
+ // start and initialization of ssthresh to cwnd * 0.7.
+ rtxOffset := bytesRead - maxPayload*expected
+ c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)
+
+ // Acknowledge all pending data.
+ c.SendAck(790, bytesRead)
+
+ // Store away the time we sent the ACK. Assuming a 200ms RTO, we
+ // estimate that the sender will time out 200ms from now and go
+ // back into slow start.
+ packetDropTime := time.Now().Add(200 * time.Millisecond)
+
+ // This part is tricky: when the timeout happened, we had "expected"
+ // packets pending, cwnd reset to 1, and ssthresh set to expected * 0.7.
+ // By acknowledging "expected" packets, the slow-start part will
+ // increase cwnd to ssthresh (expected * 0.7), essentially putting the
+ // connection straight into congestion avoidance.
+ wMax := expected
+ // Lower expected as per cubic spec after a congestion event.
+ expected = int(float64(expected) * 0.7)
+ cwnd := expected
+ for i := 0; i < iterations; i++ {
+ // CUBIC grows the window independently of ACKs: window growth
+ // is a function of the time elapsed since the last congestion
+ // event. As a result, the congestion window does not grow
+ // deterministically in response to ACKs.
+ //
+ // We need to roughly estimate the sender's cwnd based on the
+ // estimated time of the last congestion event (packetDropTime).
+ cwnd := cubicCwnd(cwnd, wMax, packetDropTime, 50*time.Millisecond)
+
+ packetsExpected := cwnd
+ for j := 0; j < packetsExpected; j++ {
+ c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
+ bytesRead += maxPayload
+ }
+ t.Logf("expected packets received, next trying to receive any extra packets that may come")
+
+ // If our estimate was correct there should be no more pending packets.
+ // We attempt to read a packet a few times with a short sleep in between
+ // to ensure that we don't see the sender send any unexpected packets.
+ unexpectedPackets := 0
+ for {
+ gotPacket := c.ReceiveNonBlockingAndCheckPacket(data, bytesRead, maxPayload)
+ if !gotPacket {
+ break
+ }
+ bytesRead += maxPayload
+ unexpectedPackets++
+ time.Sleep(1 * time.Millisecond)
+ }
+ if unexpectedPackets != 0 {
+ t.Fatalf("received %d unexpected packets for iteration %d", unexpectedPackets, i)
+ }
+ // Check we don't receive any more packets on this iteration.
+ // The timeout can't be too high or we'll trigger a timeout.
+ c.CheckNoPacketTimeout("More packets received than expected for this cwnd (congestion avoidance phase).", 5*time.Millisecond)
+
+ // Acknowledge all the data received so far.
+ c.SendAck(790, bytesRead)
+ }
+}
+
func DisabledTestFastRecovery(t *testing.T) {
maxPayload := 10
c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
@@ -2075,22 +2364,22 @@
c.CreateConnected(789, 30000, nil)
const iterations = 7
- data := buffer.NewView(2 * maxPayload * (1 << (iterations + 1)))
+ data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
for i := range data {
data[i] = byte(i)
}
// Write all the data in one shot. Packets will only be written at the
// MTU size though.
- if _, err := c.EP.Write(data, nil); err != nil {
- t.Fatalf("Unexpected error from Write: %v", err)
+ if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %v", err)
}
// Do slow start for a few iterations.
- expected := 1
+ expected := tcp.InitialCwnd
bytesRead := 0
for i := 0; i < iterations; i++ {
- expected = 1 << uint(i)
+ expected = tcp.InitialCwnd << uint(i)
if i > 0 {
// Acknowledge all the data received so far if not on
// first iteration.
@@ -2110,16 +2399,28 @@
c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond)
}
- // Send 10 duplicate acks. This should force an immediate retransmit of
- // the pending packet, and inflation of cwnd to expected/2+7.
+ // Send 3 duplicate acks. This should force an immediate retransmit of
+ // the pending packet and put the sender into fast recovery.
rtxOffset := bytesRead - maxPayload*expected
- for i := 0; i < 10; i++ {
+ for i := 0; i < 3; i++ {
c.SendAck(790, rtxOffset)
}
// Receive the retransmitted packet.
c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)
+ // Now send 7 more duplicate ACKs. Each of these should inflate the
+ // window by 1 and cause the sender to send an extra packet.
+ for i := 0; i < 7; i++ {
+ c.SendAck(790, rtxOffset)
+ }
+
+ recover := bytesRead
+
+ // Ensure no new packets arrive.
+ c.CheckNoPacketTimeout("More packets received than expected during recovery after dupacks for this cwnd.",
+ 50*time.Millisecond)
+
// Acknowledge half of the pending data.
rtxOffset = bytesRead - expected*maxPayload/2
c.SendAck(790, rtxOffset)
@@ -2127,24 +2428,37 @@
// Receive the retransmit due to partial ack.
c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)
- // This part is tricky: when the retransmit happened, we had "expected"
- // packets pending, cwnd reset to expected/2, and ssthresh set to
- // expected/2. By acknowledging expected/2 packets, 7 new packets are
- // allowed to be sent immediately.
- for j := 0; j < 7; j++ {
+ // Receive the 10 extra packets that should have been released due to
+ // the congestion window inflation in recovery.
+ for i := 0; i < 10; i++ {
c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
bytesRead += maxPayload
}
- c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond)
+ // A partial ACK during recovery should reduce the congestion window by
+ // the number of packets acked. Since we had "expected" packets outstanding
+ // before sending the partial ACK and we acked expected/2, the cwnd and
+ // outstanding count should be expected/2 + 7, which means the sender
+ // should not send any more packets until we ACK this one.
+ c.CheckNoPacketTimeout("More packets received than expected during recovery after partial ack for this cwnd.",
+ 50*time.Millisecond)
- // Acknowledge all pending data.
- c.SendAck(790, bytesRead)
+ // Acknowledge all pending data to recover point.
+ c.SendAck(790, recover)
- // Now the inflation is removed, so cwnd is expected/2. But since we've
- // received expected+7 packets since cwnd changed, it must now be set
- // expected/2 + 2, given that floor((expected+7)/(expected/2)) == 2.
- expected = expected/2 + 2
+ // At this point, the cwnd should reset to expected/2 and there are 10
+ // packets outstanding.
+ //
+ // NOTE: Technically netstack is incorrect in that we adjust the cwnd on
+ // the same segment that takes us out of recovery. Because of that, the
+ // actual cwnd at exit of recovery will be expected/2 + 1, as we acked a
+ // cwnd's worth of packets, which increases the cwnd by a further 1 in
+ // congestion avoidance.
+ //
+ // In the first iteration, since there are 10 packets outstanding, we
+ // would expect to get expected/2 + 1 - 10 packets. Subsequent
+ // iterations will send us expected/2 + 1 + 1 (per iteration).
+ expected = expected/2 + 1 - 10
for i := 0; i < iterations; i++ {
// Read all packets expected on this iteration. Don't
// acknowledge any of them just yet, so that we can measure the
@@ -2156,13 +2470,19 @@
// Check we don't receive any more packets on this iteration.
// The timeout can't be too high or we'll trigger a timeout.
- c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond)
+ c.CheckNoPacketTimeout(fmt.Sprintf("More packets received (after deflation) than expected %d for this cwnd.", expected), 50*time.Millisecond)
// Acknowledge all the data received so far.
c.SendAck(790, bytesRead)
// In congestion avoidance, the packet trains increase by 1 in
// each iteration.
+ if i == 0 {
+ // After the first iteration we expect to get the full
+ // congestion window worth of packets in every
+ // iteration.
+ expected += 10
+ }
expected++
}
}
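The packet-train arithmetic in the post-recovery comments of this test can be checked with a small standalone model. This is a sketch of what the test expects, not netstack's congestion-control code: `postRecoveryTrains` is a hypothetical helper, and the loss-time window of 40 packets used in `main` is an arbitrary example value.

```go
package main

import "fmt"

// postRecoveryTrains models the number of packets the test expects on each
// iteration after fast recovery ends. cwndAtLoss is the "expected" value the
// test held when the loss was simulated; outstanding is the number of packets
// still unacknowledged when recovery exits (10 in the test).
//
// Following the test's comments: the cwnd on exit is cwndAtLoss/2 + 1, so the
// first train is that window minus the packets already in flight. After that,
// each fully acknowledged window grows the cwnd by one packet (congestion
// avoidance), so every later train is one packet longer than the previous.
func postRecoveryTrains(cwndAtLoss, outstanding, iterations int) []int {
	cwnd := cwndAtLoss/2 + 1
	trains := make([]int, 0, iterations)
	for i := 0; i < iterations; i++ {
		if i == 0 {
			trains = append(trains, cwnd-outstanding)
		} else {
			trains = append(trains, cwnd)
		}
		cwnd++
	}
	return trains
}

func main() {
	// Hypothetical window of 40 packets at the time of the loss.
	fmt.Println(postRecoveryTrains(40, 10, 4))
}
```

With a 40-packet window this prints `[11 22 23 24]`, which matches the test loop: it starts at `expected/2 + 1 - 10` and, after the first iteration, adds the 10 previously outstanding packets plus one.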
@@ -2175,25 +2495,27 @@
c.CreateConnected(789, 30000, nil)
const iterations = 7
- data := buffer.NewView(maxPayload * (1 << (iterations + 1)))
+ data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1)))
for i := range data {
data[i] = byte(i)
}
// Write all the data in two shots. Packets will only be written at the
// MTU size though.
- if _, err := c.EP.Write(data[:len(data)/2], nil); err != nil {
- t.Fatalf("Unexpected error from Write: %v", err)
+ half := data[:len(data)/2]
+ if _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %v", err)
}
- if _, err := c.EP.Write(data[len(data)/2:], nil); err != nil {
- t.Fatalf("Unexpected error from Write: %v", err)
+ half = data[len(data)/2:]
+ if _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %v", err)
}
// Do slow start for a few iterations.
- expected := 1
+ expected := tcp.InitialCwnd
bytesRead := 0
for i := 0; i < iterations; i++ {
- expected = 1 << uint(i)
+ expected = tcp.InitialCwnd << uint(i)
if i > 0 {
// Acknowledge all the data received so far if not on
// first iteration.
@@ -2221,7 +2543,7 @@
rtxOffset = bytesRead - expected*maxPayload/2
c.SendAck(790, rtxOffset)
- // Receive the remaining data, making sure that acknowledge data is not
+ // Receive the remaining data, making sure that acknowledged data is not
// retransmitted.
for offset := rtxOffset; offset < len(data); offset += maxPayload {
c.ReceiveAndCheckPacket(data, offset, maxPayload)
@@ -2283,8 +2605,8 @@
// Send some data. Check that it's capped by the window size.
view := buffer.NewView(65535)
- if _, err := c.EP.Write(view, nil); err != nil {
- t.Fatalf("Unexpected error from Write: %v", err)
+ if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %v", err)
}
// Check that only data that fits in the scaled window is sent.
@@ -2317,8 +2639,8 @@
c := context.New(t, defaultMTU)
defer c.Cleanup()
c.CreateConnected(789, 30000, nil)
- stats := c.Stack().MutableStats()
- expected := stats.TCP.ValidSegmentsReceived + 1
+ stats := c.Stack().Stats()
+ want := stats.TCP.ValidSegmentsReceived.Value() + 1
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
@@ -2329,8 +2651,8 @@
RcvWnd: 30000,
})
- if actual := stats.TCP.ValidSegmentsReceived; actual != expected {
- t.Fatalf("Expected ValidSegmentsReceived to be %d, got %d", expected, actual)
+ if got := stats.TCP.ValidSegmentsReceived.Value(); got != want {
+ t.Errorf("got stats.TCP.ValidSegmentsReceived.Value() = %v, want = %v", got, want)
}
}
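The stats changes above stop reading plain struct fields and instead call `Value()` on counters. A minimal sketch of that shape, assuming an atomic `uint64` counter, shows why the test snapshots the value before sending a segment and compares it afterwards. The `statCounter` type is hypothetical and only mirrors the `Value()` calls used in these tests.

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// statCounter is a hypothetical stand-in for a stack statistic: an atomic
// uint64 that is incremented and read through methods rather than accessed
// as a plain struct field.
type statCounter struct {
	count uint64
}

// Increment adds one to the counter; safe for concurrent use.
func (s *statCounter) Increment() { atomic.AddUint64(&s.count, 1) }

// Value returns the current count; safe for concurrent use.
func (s *statCounter) Value() uint64 { return atomic.LoadUint64(&s.count) }

func main() {
	var valid statCounter

	// Snapshot before the event, like `want := ...Value() + 1` in the test.
	want := valid.Value() + 1

	// The "stack" updates the counter from another goroutine.
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		valid.Increment()
	}()
	wg.Wait()

	if got := valid.Value(); got != want {
		fmt.Printf("got %d, want %d\n", got, want)
		return
	}
	fmt.Println("ok")
}
```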
@@ -2338,9 +2660,9 @@
c := context.New(t, defaultMTU)
defer c.Cleanup()
c.CreateConnected(789, 30000, nil)
- stats := c.Stack().MutableStats()
- expected := stats.TCP.InvalidSegmentsReceived + 1
- seg := c.BuildSegment(nil, &context.Headers{
+ stats := c.Stack().Stats()
+ want := stats.TCP.InvalidSegmentsReceived.Value() + 1
+ vv := c.BuildSegment(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
@@ -2348,15 +2670,14 @@
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
- vv := &seg
- tcpbuf := vv.ByteSlice()[0][header.IPv4MinimumSize:]
- // 12 is the TCP header data offset
- tcpbuf[12] = header.TCPMinimumSize - 1
+ tcpbuf := vv.First()[header.IPv4MinimumSize:]
+ // Byte 12 holds the TCP data offset, in 32-bit words, in its upper 4 bits.
+ tcpbuf[12] = ((header.TCPMinimumSize - 1) / 4) << 4
c.SendSegment(vv)
- if actual := stats.TCP.InvalidSegmentsReceived; actual != expected {
- t.Fatalf("Expected InvalidSegmentsReceived to be %d, got %d", expected, actual)
+ if got := stats.TCP.InvalidSegmentsReceived.Value(); got != want {
+ t.Errorf("got stats.TCP.InvalidSegmentsReceived.Value() = %v, want = %v", got, want)
}
}
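The data-offset fix matters because byte 12 of the TCP header stores the header length in 32-bit words in its upper four bits, not as a raw byte count. A small sketch (the helper names are hypothetical) shows that the new expression encodes a 16-byte header, below the 20-byte minimum, so the segment is invalid for the intended reason. The old raw value 19 (0x13) instead decoded to a 4-byte header and also set bits in the low nibble, which holds the reserved and NS bits.

```go
package main

import "fmt"

// tcpMinimumSize is the minimum TCP header length in bytes.
const tcpMinimumSize = 20

// encodeDataOffset returns the value of byte 12 of a TCP header for a header
// of hdrLen bytes: the length in 32-bit words goes in the upper 4 bits and the
// lower 4 bits (reserved/NS) are left zero.
func encodeDataOffset(hdrLen int) byte {
	return byte(hdrLen/4) << 4
}

// decodeDataOffset recovers the header length in bytes from byte 12.
func decodeDataOffset(b byte) int {
	return int(b>>4) * 4
}

func main() {
	b := encodeDataOffset(tcpMinimumSize - 1)
	fmt.Printf("byte 12 = %#x, header length = %d bytes\n", b, decodeDataOffset(b))
}
```

This prints `byte 12 = 0x40, header length = 16 bytes`.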
@@ -2420,13 +2741,13 @@
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
- if _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("Unexpected error from Read: %v", err)
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
}
// Shutdown immediately for write, check that we get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Unexpected error from Shutdown: %v", err)
+ t.Fatalf("Shutdown failed: %v", err)
}
checker.IPv4(t, c.GetPacket(),
@@ -2472,34 +2793,34 @@
// Check that peek works.
peekBuf := make([]byte, 10)
- n, err := c.EP.Peek([][]byte{peekBuf})
+ n, _, err := c.EP.Peek([][]byte{peekBuf})
if err != nil {
- t.Fatalf("Unexpected error from Peek: %v", err)
+ t.Fatalf("Peek failed: %v", err)
}
peekBuf = peekBuf[:n]
- if bytes.Compare(data, peekBuf) != 0 {
- t.Fatalf("Data is different: expected %v, got %v", data, peekBuf)
+ if !bytes.Equal(data, peekBuf) {
+ t.Fatalf("got data = %v, want = %v", peekBuf, data)
}
// Receive data.
- v, err := c.EP.Read(nil)
+ v, _, err := c.EP.Read(nil)
if err != nil {
- t.Fatalf("Unexpected error from Read: %v", err)
+ t.Fatalf("Read failed: %v", err)
}
- if bytes.Compare(data, v) != 0 {
- t.Fatalf("Data is different: expected %v, got %v", data, v)
+ if !bytes.Equal(data, v) {
+ t.Fatalf("got data = %v, want = %v", v, data)
}
// Now that we drained the queue, check that functions fail with the
// right error code.
- if _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive {
- t.Fatalf("Unexpected return from Read: got %v, want %v", err, tcpip.ErrClosedForReceive)
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive {
+ t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrClosedForReceive)
}
- if _, err := c.EP.Peek([][]byte{peekBuf}); err != tcpip.ErrClosedForReceive {
- t.Fatalf("Unexpected return from Peek: got %v, want %v", err, tcpip.ErrClosedForReceive)
+ if _, _, err := c.EP.Peek([][]byte{peekBuf}); err != tcpip.ErrClosedForReceive {
+ t.Fatalf("got c.EP.Peek(...) = %v, want = %v", err, tcpip.ErrClosedForReceive)
}
}
@@ -2537,9 +2858,8 @@
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}, nil); err != nil {
t.Fatalf("Bind failed: %v", err)
}
- err = c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
- if err != tcpip.ErrConnectStarted {
- t.Fatalf("Unexpected return value from Connect: %v", err)
+ if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
+ t.Fatalf("got c.EP.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted)
}
c.EP.Close()
@@ -2560,8 +2880,7 @@
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}, nil); err != nil {
t.Fatalf("Bind failed: %v", err)
}
- err = c.EP.Listen(10)
- if err != nil {
+ if err := c.EP.Listen(10); err != nil {
t.Fatalf("Listen failed: %v", err)
}
c.EP.Close()
@@ -2573,12 +2892,566 @@
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}, nil); err != nil {
t.Fatalf("Bind failed: %v", err)
}
- err = c.EP.Listen(10)
- if err != nil {
+ if err := c.EP.Listen(10); err != nil {
t.Fatalf("Listen failed: %v", err)
}
}
+func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
+ t.Helper()
+
+ var s tcpip.ReceiveBufferSizeOption
+ if err := ep.GetSockOpt(&s); err != nil {
+ t.Fatalf("GetSockOpt failed: %v", err)
+ }
+
+ if int(s) != v {
+ t.Fatalf("got receive buffer size = %v, want = %v", s, v)
+ }
+}
+
+func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
+ t.Helper()
+
+ var s tcpip.SendBufferSizeOption
+ if err := ep.GetSockOpt(&s); err != nil {
+ t.Fatalf("GetSockOpt failed: %v", err)
+ }
+
+ if int(s) != v {
+ t.Fatalf("got send buffer size = %v, want = %v", s, v)
+ }
+}
+
+func TestDefaultBufferSizes(t *testing.T) {
+ s := stack.New([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{})
+
+ // Check the default values.
+ ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+ defer func() {
+ if ep != nil {
+ ep.Close()
+ }
+ }()
+
+ checkSendBufferSize(t, ep, tcp.DefaultBufferSize)
+ checkRecvBufferSize(t, ep, tcp.DefaultBufferSize)
+
+ // Change the default send buffer size.
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, tcp.DefaultBufferSize * 2, tcp.DefaultBufferSize * 20}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ }
+
+ ep.Close()
+ ep, err = s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ checkSendBufferSize(t, ep, tcp.DefaultBufferSize*2)
+ checkRecvBufferSize(t, ep, tcp.DefaultBufferSize)
+
+ // Change the default receive buffer size.
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, tcp.DefaultBufferSize * 3, tcp.DefaultBufferSize * 30}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ }
+
+ ep.Close()
+ ep, err = s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ checkSendBufferSize(t, ep, tcp.DefaultBufferSize*2)
+ checkRecvBufferSize(t, ep, tcp.DefaultBufferSize*3)
+}
+
+func TestMinMaxBufferSizes(t *testing.T) {
+ s := stack.New([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{})
+
+ // Check the default values.
+ ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+ defer ep.Close()
+
+ // Change the min/max values for send/receive.
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{200, tcp.DefaultBufferSize * 2, tcp.DefaultBufferSize * 20}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ }
+
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{300, tcp.DefaultBufferSize * 3, tcp.DefaultBufferSize * 30}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ }
+
+ // Set values below the min.
+ if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(199)); err != nil {
+ t.Fatalf("SetSockOpt failed: %v", err)
+ }
+
+ checkRecvBufferSize(t, ep, 200)
+
+ if err := ep.SetSockOpt(tcpip.SendBufferSizeOption(299)); err != nil {
+ t.Fatalf("SetSockOpt failed: %v", err)
+ }
+
+ checkSendBufferSize(t, ep, 300)
+
+ // Set values above the max.
+ if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(1 + tcp.DefaultBufferSize*20)); err != nil {
+ t.Fatalf("SetSockOpt failed: %v", err)
+ }
+
+ checkRecvBufferSize(t, ep, tcp.DefaultBufferSize*20)
+
+ if err := ep.SetSockOpt(tcpip.SendBufferSizeOption(1 + tcp.DefaultBufferSize*30)); err != nil {
+ t.Fatalf("SetSockOpt failed: %v", err)
+ }
+
+ checkSendBufferSize(t, ep, tcp.DefaultBufferSize*30)
+}
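TestMinMaxBufferSizes expects `SetSockOpt` to clamp requested buffer sizes into the `{min, default, max}` range configured on the stack. The clamping rule can be sketched as below; `bufferSizeRange` and `clamp` are hypothetical names, and the `1 << 20` default is an assumed value used only to build an example range.

```go
package main

import "fmt"

// bufferSizeRange mirrors the {min, default, max} triple the tests pass to
// SetTransportProtocolOption. The type and field names are illustrative.
type bufferSizeRange struct {
	Min, Default, Max int
}

// clamp limits v to [r.Min, r.Max], the behavior TestMinMaxBufferSizes
// checks after SetSockOpt.
func (r bufferSizeRange) clamp(v int) int {
	if v < r.Min {
		return r.Min
	}
	if v > r.Max {
		return r.Max
	}
	return v
}

func main() {
	// Assumed default, used only to build the example range.
	const defaultBufferSize = 1 << 20
	rcv := bufferSizeRange{Min: 200, Default: defaultBufferSize * 2, Max: defaultBufferSize * 20}

	// Raised to Min: prints 200.
	fmt.Println(rcv.clamp(199))
	// Lowered to Max: prints 20971520.
	fmt.Println(rcv.clamp(1 + defaultBufferSize*20))
	// Already in range: prints 4096.
	fmt.Println(rcv.clamp(4096))
}
```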
+
+func makeStack() (*stack.Stack, *tcpip.Error) {
+ s := stack.New([]string{
+ ipv4.ProtocolName,
+ ipv6.ProtocolName,
+ }, []string{tcp.ProtocolName}, stack.Options{})
+
+ id := loopback.New()
+ if testing.Verbose() {
+ id = sniffer.New(id)
+ }
+
+ if err := s.CreateNIC(1, id); err != nil {
+ return nil, err
+ }
+
+ for _, ct := range []struct {
+ number tcpip.NetworkProtocolNumber
+ address tcpip.Address
+ }{
+ {ipv4.ProtocolNumber, context.StackAddr},
+ {ipv6.ProtocolNumber, context.StackV6Addr},
+ } {
+ if err := s.AddAddress(1, ct.number, ct.address); err != nil {
+ return nil, err
+ }
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: "\x00\x00\x00\x00",
+ Mask: "\x00\x00\x00\x00",
+ Gateway: "",
+ NIC: 1,
+ },
+ {
+ Destination: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+ Mask: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+ Gateway: "",
+ NIC: 1,
+ },
+ })
+
+ return s, nil
+}
+
+func TestSelfConnect(t *testing.T) {
+ // This test ensures that intentional self-connects work. In particular,
+ // it checks that if an endpoint binds to say 127.0.0.1:1000 then
+ // connects to 127.0.0.1:1000, then it will be connected to itself, and
+ // is able to send and receive data through the same endpoint.
+ s, err := makeStack()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ var wq waiter.Queue
+ ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+ defer ep.Close()
+
+ if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}, nil); err != nil {
+ t.Fatalf("Bind failed: %v", err)
+ }
+
+ // Register for notification, then start connection attempt.
+ waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&waitEntry, waiter.EventOut)
+ defer wq.EventUnregister(&waitEntry)
+
+ if err := ep.Connect(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != tcpip.ErrConnectStarted {
+ t.Fatalf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted)
+ }
+
+ <-notifyCh
+ if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != nil {
+ t.Fatalf("Connect failed: %v", err)
+ }
+
+ // Write something.
+ data := []byte{1, 2, 3}
+ view := buffer.NewView(len(data))
+ copy(view, data)
+ if _, err := ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %v", err)
+ }
+
+ // Read back what was written.
+ wq.EventUnregister(&waitEntry)
+ wq.EventRegister(&waitEntry, waiter.EventIn)
+ rd, _, err := ep.Read(nil)
+ if err != nil {
+ if err != tcpip.ErrWouldBlock {
+ t.Fatalf("Read failed: %v", err)
+ }
+ <-notifyCh
+ rd, _, err = ep.Read(nil)
+ if err != nil {
+ t.Fatalf("Read failed: %v", err)
+ }
+ }
+
+ if !bytes.Equal(data, rd) {
+ t.Fatalf("got data = %v, want = %v", rd, data)
+ }
+}
+
+func TestConnectAvoidsBoundPorts(t *testing.T) {
+ addressTypes := func(t *testing.T, network string) []string {
+ switch network {
+ case "ipv4":
+ return []string{"v4"}
+ case "ipv6":
+ return []string{"v6"}
+ case "dual":
+ return []string{"v6", "mapped"}
+ default:
+ t.Fatalf("unknown network: '%s'", network)
+ }
+
+ panic("unreachable")
+ }
+
+ address := func(t *testing.T, addressType string, isAny bool) tcpip.Address {
+ switch addressType {
+ case "v4":
+ if isAny {
+ return ""
+ }
+ return context.StackAddr
+ case "v6":
+ if isAny {
+ return ""
+ }
+ return context.StackV6Addr
+ case "mapped":
+ if isAny {
+ return context.V4MappedWildcardAddr
+ }
+ return context.StackV4MappedAddr
+ default:
+ t.Fatalf("unknown address type: '%s'", addressType)
+ }
+
+ panic("unreachable")
+ }
+ // This test ensures that Endpoint.Connect doesn't select already-bound ports.
+ networks := []string{"ipv4", "ipv6", "dual"}
+ for _, exhaustedNetwork := range networks {
+ t.Run(fmt.Sprintf("exhaustedNetwork=%s", exhaustedNetwork), func(t *testing.T) {
+ for _, exhaustedAddressType := range addressTypes(t, exhaustedNetwork) {
+ t.Run(fmt.Sprintf("exhaustedAddressType=%s", exhaustedAddressType), func(t *testing.T) {
+ for _, isAny := range []bool{false, true} {
+ t.Run(fmt.Sprintf("isAny=%t", isAny), func(t *testing.T) {
+ for _, candidateNetwork := range networks {
+ t.Run(fmt.Sprintf("candidateNetwork=%s", candidateNetwork), func(t *testing.T) {
+ for _, candidateAddressType := range addressTypes(t, candidateNetwork) {
+ t.Run(fmt.Sprintf("candidateAddressType=%s", candidateAddressType), func(t *testing.T) {
+ s, err := makeStack()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ var wq waiter.Queue
+ var eps []tcpip.Endpoint
+ defer func() {
+ for _, ep := range eps {
+ ep.Close()
+ }
+ }()
+ makeEP := func(network string) tcpip.Endpoint {
+ var networkProtocolNumber tcpip.NetworkProtocolNumber
+ switch network {
+ case "ipv4":
+ networkProtocolNumber = ipv4.ProtocolNumber
+ case "ipv6", "dual":
+ networkProtocolNumber = ipv6.ProtocolNumber
+ default:
+ t.Fatalf("unknown network: '%s'", network)
+ }
+ ep, err := s.NewEndpoint(tcp.ProtocolNumber, networkProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+ eps = append(eps, ep)
+ switch network {
+ case "ipv4":
+ case "ipv6":
+ if err := ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil {
+ t.Fatalf("SetSockOpt(V6OnlyOption(1)) failed: %v", err)
+ }
+ case "dual":
+ if err := ep.SetSockOpt(tcpip.V6OnlyOption(0)); err != nil {
+ t.Fatalf("SetSockOpt(V6OnlyOption(0)) failed: %v", err)
+ }
+ default:
+ t.Fatalf("unknown network: '%s'", network)
+ }
+ return ep
+ }
+
+ var v4reserved, v6reserved bool
+ switch exhaustedAddressType {
+ case "v4", "mapped":
+ v4reserved = true
+ case "v6":
+ v6reserved = true
+ // Dual-stack sockets bound to the v6 wildcard address reserve the
+ // port on v4 as well.
+ if isAny {
+ switch exhaustedNetwork {
+ case "ipv6":
+ case "dual":
+ v4reserved = true
+ default:
+ t.Fatalf("unknown network: '%s'", exhaustedNetwork)
+ }
+ }
+ default:
+ t.Fatalf("unknown address type: '%s'", exhaustedAddressType)
+ }
+ var collides bool
+ switch candidateAddressType {
+ case "v4", "mapped":
+ collides = v4reserved
+ case "v6":
+ collides = v6reserved
+ default:
+ t.Fatalf("unknown address type: '%s'", candidateAddressType)
+ }
+
+ for i := ports.FirstEphemeral; i <= math.MaxUint16; i++ {
+ if err := makeEP(exhaustedNetwork).Bind(tcpip.FullAddress{Addr: address(t, exhaustedAddressType, isAny), Port: uint16(i)}, nil); err != nil {
+ t.Fatalf("Bind(%d) failed: %v", i, err)
+ }
+ }
+ want := tcpip.ErrConnectStarted
+ if collides {
+ want = tcpip.ErrNoPortAvailable
+ }
+ if err := makeEP(candidateNetwork).Connect(tcpip.FullAddress{Addr: address(t, candidateAddressType, false), Port: 31337}); err != want {
+ t.Fatalf("got ep.Connect(..) = %v, want = %v", err, want)
+ }
+ })
+ }
+ })
+ }
+ })
+ }
+ })
+ }
+ })
+ }
+}
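The expected outcome in TestConnectAvoidsBoundPorts comes from a small truth table: which address families the exhausting binds reserve, and which family the candidate `Connect` needs. The sketch below restates the test's switch statements as plain functions. `reservedFamilies` and `collides` are hypothetical helpers; unlike the test, they return `false` for unknown types instead of failing.

```go
package main

import "fmt"

// reservedFamilies reports which address families the exhausting binds
// reserve, following the switch in the test: v4 and v4-mapped binds reserve
// v4; v6 binds reserve v6, and a dual-stack socket bound to the v6 wildcard
// reserves v4 as well.
func reservedFamilies(exhaustedNetwork, exhaustedAddressType string, isAny bool) (v4, v6 bool) {
	switch exhaustedAddressType {
	case "v4", "mapped":
		v4 = true
	case "v6":
		v6 = true
		if isAny && exhaustedNetwork == "dual" {
			v4 = true
		}
	}
	return v4, v6
}

// collides reports whether a Connect using candidateAddressType should fail
// because every ephemeral port in its family is already bound.
func collides(exhaustedNetwork, exhaustedAddressType string, isAny bool, candidateAddressType string) bool {
	v4, v6 := reservedFamilies(exhaustedNetwork, exhaustedAddressType, isAny)
	switch candidateAddressType {
	case "v4", "mapped":
		return v4
	case "v6":
		return v6
	}
	return false
}

func main() {
	fmt.Println(collides("dual", "v6", true, "v4"))      // dual-stack v6 wildcard also blocks v4
	fmt.Println(collides("ipv6", "v6", true, "v4"))      // v6-only wildcard leaves v4 free
	fmt.Println(collides("ipv4", "v4", false, "mapped")) // mapped addresses use the v4 space
}
```

This prints `true`, `false`, `true`.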
+
+func TestPathMTUDiscovery(t *testing.T) {
+ // This test verifies that the stack retransmits packets after it
+ // receives an ICMP packet indicating that the path MTU has been exceeded.
+ c := context.New(t, 1500)
+ defer c.Cleanup()
+
+ // Create new connection with MSS of 1460.
+ const maxPayload = 1500 - header.TCPMinimumSize - header.IPv4MinimumSize
+ c.CreateConnectedWithRawOptions(789, 30000, nil, []byte{
+ header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
+ })
+
+ // Send 3200 bytes of data.
+ const writeSize = 3200
+ data := buffer.NewView(writeSize)
+ for i := range data {
+ data[i] = byte(i)
+ }
+
+ if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %v", err)
+ }
+
+ receivePackets := func(c *context.Context, sizes []int, which int, seqNum uint32) []byte {
+ var ret []byte
+ for i, size := range sizes {
+ p := c.GetPacket()
+ if i == which {
+ ret = p
+ }
+ checker.IPv4(t, p,
+ checker.PayloadLen(size+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(seqNum),
+ checker.AckNum(790),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+ seqNum += uint32(size)
+ }
+ return ret
+ }
+
+ // Receive three packets.
+ sizes := []int{maxPayload, maxPayload, writeSize - 2*maxPayload}
+ first := receivePackets(c, sizes, 0, uint32(c.IRS)+1)
+
+ // Send an ICMPv4 "fragmentation needed" message back to netstack.
+ const newMTU = 1200
+ const newMaxPayload = newMTU - header.IPv4MinimumSize - header.TCPMinimumSize
+ mtu := []byte{0, 0, newMTU / 256, newMTU % 256}
+ c.SendICMPPacket(header.ICMPv4DstUnreachable, header.ICMPv4FragmentationNeeded, mtu, first, newMTU)
+
+ // Check the retransmitted packets: none may exceed the new max payload.
+ sizes = []int{newMaxPayload, maxPayload - newMaxPayload, newMaxPayload, maxPayload - newMaxPayload, writeSize - 2*maxPayload}
+ receivePackets(c, sizes, -1, uint32(c.IRS)+1)
+}
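TestPathMTUDiscovery hard-codes two calculations: the 4-byte ICMPv4 field that follows the checksum (16 unused bits, then the next-hop MTU, per RFC 1191), and the segment sizes expected once every previously sent segment is cut down to the new maximum payload. The sketch below reproduces both. `nextHopMTUField` and `resegment` are hypothetical helpers that model the test's expectations, not netstack's retransmit code.

```go
package main

import (
	"encoding/binary"
	"fmt"
)

const (
	ipv4MinHeader = 20 // bytes
	tcpMinHeader  = 20 // bytes
)

// nextHopMTUField builds the 4 bytes that follow the ICMPv4 checksum in a
// "fragmentation needed" message: 16 unused bits, then the next-hop MTU.
func nextHopMTUField(mtu uint16) []byte {
	b := make([]byte, 4)
	binary.BigEndian.PutUint16(b[2:], mtu)
	return b
}

// resegment splits each previously sent segment so that no piece exceeds
// maxPayload, preserving order.
func resegment(sizes []int, maxPayload int) []int {
	if maxPayload <= 0 {
		return nil
	}
	var out []int
	for _, s := range sizes {
		for s > maxPayload {
			out = append(out, maxPayload)
			s -= maxPayload
		}
		out = append(out, s)
	}
	return out
}

func main() {
	const newMTU = 1200
	newMaxPayload := newMTU - ipv4MinHeader - tcpMinHeader // 1160

	fmt.Println(nextHopMTUField(newMTU))
	// The test writes 3200 bytes with an MSS of 1460: 1460, 1460, 280.
	fmt.Println(resegment([]int{1460, 1460, 280}, newMaxPayload))
}
```

This prints `[0 0 4 176]` (the same bytes as `{0, 0, newMTU / 256, newMTU % 256}`) and `[1160 300 1160 300 280]`, matching the `sizes` slice the test checks after the ICMP message.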
+
+func TestTCPEndpointProbe(t *testing.T) {
+ c := context.New(t, 1500)
+ defer c.Cleanup()
+
+ invoked := make(chan struct{})
+ c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) {
+ // Validate that the endpoint ID is what we expect.
+ //
+ // We don't validate every field, only do a basic sanity check. The
+ // probe runs outside the test goroutine, so use t.Errorf, not t.Fatalf.
+ if got, want := state.ID.LocalAddress, tcpip.Address(context.StackAddr); got != want {
+ t.Errorf("got LocalAddress: %q, want: %q", got, want)
+ }
+ if got, want := state.ID.LocalPort, c.Port; got != want {
+ t.Errorf("got LocalPort: %d, want: %d", got, want)
+ }
+ if got, want := state.ID.RemoteAddress, tcpip.Address(context.TestAddr); got != want {
+ t.Errorf("got RemoteAddress: %q, want: %q", got, want)
+ }
+ if got, want := state.ID.RemotePort, uint16(context.TestPort); got != want {
+ t.Errorf("got RemotePort: %d, want: %d", got, want)
+ }
+
+ invoked <- struct{}{}
+ })
+
+ c.CreateConnected(789, 30000, nil)
+
+ data := []byte{1, 2, 3}
+ c.SendPacket(data, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+
+ select {
+ case <-invoked:
+ case <-time.After(100 * time.Millisecond):
+ t.Fatal("TCP probe function was not called")
+ }
+}
+
+func TestSetCongestionControl(t *testing.T) {
+ testCases := []struct {
+ cc tcp.CongestionControlOption
+ mustPass bool
+ }{
+ {"reno", true},
+ {"cubic", true},
+ }
+
+ for _, tc := range testCases {
+ t.Run(fmt.Sprintf("SetTransportProtocolOption(.., %v)", tc.cc), func(t *testing.T) {
+ c := context.New(t, 1500)
+ defer c.Cleanup()
+
+ s := c.Stack()
+
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tc.cc); err != nil && tc.mustPass {
+ t.Fatalf("s.SetTransportProtocolOption(%v, %v) = %v, want = nil", tcp.ProtocolNumber, tc.cc, err)
+ }
+
+ var cc tcp.CongestionControlOption
+ if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil {
+ t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err)
+ }
+ if got, want := cc, tc.cc; got != want {
+ t.Fatalf("got congestion control: %v, want: %v", got, want)
+ }
+ })
+ }
+}
+
+func TestAvailableCongestionControl(t *testing.T) {
+ c := context.New(t, 1500)
+ defer c.Cleanup()
+
+ s := c.Stack()
+
+ // Query permitted congestion control algorithms.
+ var aCC tcp.AvailableCongestionControlOption
+ if err := s.TransportProtocolOption(tcp.ProtocolNumber, &aCC); err != nil {
+ t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &aCC, err)
+ }
+ if got, want := aCC, tcp.AvailableCongestionControlOption("reno cubic"); got != want {
+ t.Fatalf("got tcp.AvailableCongestionControlOption: %v, want: %v", got, want)
+ }
+}
+
+func TestSetAvailableCongestionControl(t *testing.T) {
+ c := context.New(t, 1500)
+ defer c.Cleanup()
+
+ s := c.Stack()
+
+ // Setting AvailableCongestionControlOption should fail.
+ aCC := tcp.AvailableCongestionControlOption("xyz")
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &aCC); err == nil {
+ t.Fatalf("s.SetTransportProtocolOption(%v, %v) = nil, want non-nil", tcp.ProtocolNumber, &aCC)
+ }
+
+ // Verify that we still get the expected list of congestion control options.
+ var cc tcp.AvailableCongestionControlOption
+ if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil {
+ t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err)
+ }
+ if got, want := cc, tcp.AvailableCongestionControlOption("reno cubic"); got != want {
+ t.Fatalf("got tcp.AvailableCongestionControlOption: %v, want: %v", got, want)
+ }
+}
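The congestion-control tests expect the available list to read `"reno cubic"` and expect setting that list itself to fail. A caller that wants to check a name against the space-separated list before setting it could do something like the following. `validCongestionControl` is a hypothetical helper, not part of netstack's API.

```go
package main

import (
	"fmt"
	"strings"
)

// validCongestionControl reports whether name appears in a space-separated
// list of algorithms such as "reno cubic". Hypothetical helper.
func validCongestionControl(available, name string) bool {
	for _, alg := range strings.Fields(available) {
		if alg == name {
			return true
		}
	}
	return false
}

func main() {
	const available = "reno cubic"
	for _, name := range []string{"reno", "cubic", "xyz"} {
		fmt.Printf("%s: %t\n", name, validCongestionControl(available, name))
	}
}
```

This prints `reno: true`, `cubic: true` and `xyz: false`.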
+
+func enableCUBIC(t *testing.T, c *context.Context) {
+ t.Helper()
+ opt := tcp.CongestionControlOption("cubic")
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt); err != nil {
+ t.Fatalf("c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, %v) = %v", opt, err)
+ }
+}
+
func TestKeepalive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -2615,15 +3488,15 @@
}
// Check that the connection is still alive.
- if _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("Unexpected error from Read: %v", err)
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
}
// Send some data and wait before ACKing it. Keepalives should be disabled
// during this period.
view := buffer.NewView(3)
- if _, err := c.EP.Write(view, nil); err != nil {
- t.Fatalf("Unexpected error from Write: %v", err)
+ if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %v", err)
}
next := uint32(c.IRS) + 1
@@ -2637,7 +3510,8 @@
),
)
- // Wait for the packet to be retransmitted. Verify that no keepalives were sent.
+ // Wait for the packet to be retransmitted. Verify that no keepalives
+ // were sent.
checker.IPv4(t, c.GetPacket(),
checker.PayloadLen(len(view)+header.TCPMinimumSize),
checker.TCP(
@@ -2647,7 +3521,7 @@
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagPsh),
),
)
- c.CheckNoPacket("Keepalive packet received while unACKed data is pending.")
+ c.CheckNoPacket("Keepalive packet received while unACKed data is pending")
next += uint32(len(view))
@@ -2685,7 +3559,7 @@
),
)
- if _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset {
- t.Fatalf("Unexpected error from Read: %v", err)
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset {
+ t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrConnectionReset)
}
}
diff --git a/tcpip/transport/tcp/tcp_timestamp_test.go b/tcpip/transport/tcp/tcp_timestamp_test.go
index 0e1a333..64aefc1 100644
--- a/tcpip/transport/tcp/tcp_timestamp_test.go
+++ b/tcpip/transport/tcp/tcp_timestamp_test.go
@@ -1,3 +1,17 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
package tcp_test
import (
@@ -6,6 +20,7 @@
"testing"
"time"
+ "github.com/google/netstack/tcpip"
"github.com/google/netstack/tcpip/buffer"
"github.com/google/netstack/tcpip/checker"
"github.com/google/netstack/tcpip/header"
@@ -90,7 +105,7 @@
// There should be 5 views to read and each of them should
// contain the same data.
for i := 0; i < 5; i++ {
- got, err := c.EP.Read(nil)
+ got, _, err := c.EP.Read(nil)
if err != nil {
t.Fatalf("Unexpected error from Read: %v", err)
}
@@ -132,7 +147,7 @@
view := buffer.NewView(len(data))
copy(view, data)
- if _, err := c.EP.Write(view, nil); err != nil {
+ if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Unexpected error from Write: %v", err)
}
@@ -167,7 +182,7 @@
wndSize uint16
}{
{true, -1, 0xffff}, // When cookie is used window scaling is disabled.
- {false, 2, 0xd000},
+ {false, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default).
}
for _, tc := range testCases {
timeStampEnabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize)
@@ -195,7 +210,7 @@
view := buffer.NewView(len(data))
copy(view, data)
- if _, err := c.EP.Write(view, nil); err != nil {
+ if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Unexpected error from Write: %v", err)
}
@@ -224,7 +239,7 @@
wndSize uint16
}{
{true, -1, 0xffff}, // When cookie is used window scaling is disabled.
- {false, 2, 0xd000},
+ {false, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default).
}
for _, tc := range testCases {
timeStampDisabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize)
@@ -252,8 +267,8 @@
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
- stk := c.Stack()
- droppedPackets := stk.Stats().DroppedPackets
+ droppedPacketsStat := c.Stack().Stats().DroppedPackets
+ droppedPackets := droppedPacketsStat.Value()
data := []byte{1, 2, 3}
// Save the sequence number as we will reset it later down
// in the test.
@@ -268,11 +283,11 @@
}
// Assert that DroppedPackets was incremented by 1.
- if got, want := stk.Stats().DroppedPackets, droppedPackets+1; got != want {
+ if got, want := droppedPacketsStat.Value(), droppedPackets+1; got != want {
t.Fatalf("incorrect number of dropped packets, got: %v, want: %v", got, want)
}
- droppedPackets = stk.Stats().DroppedPackets
+ droppedPackets = droppedPacketsStat.Value()
// Reset the sequence number so that the other endpoint accepts
// this segment and does not treat it like an out of order delivery.
rep.NextSeqNum = savedSeqNum
@@ -286,12 +301,12 @@
}
// Assert that DroppedPackets was not incremented.
- if got, want := stk.Stats().DroppedPackets, droppedPackets; got != want {
+ if got, want := droppedPacketsStat.Value(), droppedPackets; got != want {
t.Fatalf("incorrect number of dropped packets, got: %v, want: %v", got, want)
}
// Issue a read; we should receive the data.
- got, err := c.EP.Read(nil)
+ got, _, err := c.EP.Read(nil)
if err != nil {
t.Fatalf("Unexpected error from Read: %v", err)
}
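The updated timestamp test cases expect a window scale of 5 with a window field of `0x8000` for the 1 MiB default (previously 2 with `0xd000`, that is, a 212992-byte buffer). Picking the smallest shift that makes the buffer fit in the 16-bit window field reproduces both pairs. This is a sketch under that assumption and is not necessarily how netstack chooses the scale; `windowScaleFor` is a hypothetical helper, and the cap of 14 is the RFC 7323 limit.

```go
package main

import "fmt"

// maxWndScale is the largest window scale shift allowed by RFC 7323.
const maxWndScale = 14

// windowScaleFor returns the smallest shift that lets bufSize be advertised
// in the 16-bit window field, along with the resulting field value.
func windowScaleFor(bufSize int) (scale int, field uint16) {
	for bufSize>>uint(scale) > 0xffff && scale < maxWndScale {
		scale++
	}
	v := bufSize >> uint(scale)
	if v > 0xffff {
		v = 0xffff // buffer too large even at the maximum shift
	}
	return scale, uint16(v)
}

func main() {
	s, w := windowScaleFor(1 << 20)
	fmt.Printf("scale=%d field=%#x\n", s, w)

	s, w = windowScaleFor(0xd000 << 2)
	fmt.Printf("scale=%d field=%#x\n", s, w)
}
```

This prints `scale=5 field=0x8000` and `scale=2 field=0xd000`.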
diff --git a/tcpip/transport/tcp/testing/context/context.go b/tcpip/transport/tcp/testing/context/context.go
index c6a0bee..4bfd237 100644
--- a/tcpip/transport/tcp/testing/context/context.go
+++ b/tcpip/transport/tcp/testing/context/context.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
// Package context provides a test context for use in tcp tests. It also
// provides helper methods to assert/check certain behaviours.
@@ -83,7 +93,7 @@
// AckNum represents the acknowledgement number field in the TCP header.
AckNum seqnum.Value
- // Flags are the TCP flags in the the TCP header.
+ // Flags are the TCP flags in the TCP header.
Flags int
// RcvWnd is the window to be advertised in the ReceiveWindow field of
@@ -129,9 +139,20 @@
// New allocates and initializes a test context containing a new
// stack and a link-layer endpoint.
func New(t *testing.T, mtu uint32) *Context {
- s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{tcp.ProtocolName})
+ s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{})
- id, linkEP := channel.New(256, mtu, "")
+ // Allow minimum send/receive buffer sizes to be 1 during tests.
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, tcp.DefaultBufferSize, tcp.DefaultBufferSize * 10}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ }
+
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, tcp.DefaultBufferSize, tcp.DefaultBufferSize * 10}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ }
+
+ // Some of the congestion control tests send up to 640 packets, so we
+ // set the channel size to 1000.
+ id, linkEP := channel.New(1000, mtu, "")
if testing.Verbose() {
id = sniffer.New(id)
}
@@ -184,9 +205,11 @@
// CheckNoPacketTimeout verifies that no packet is received during the time
// specified by wait.
func (c *Context) CheckNoPacketTimeout(errMsg string, wait time.Duration) {
+ c.t.Helper()
+
select {
case <-c.linkEP.C:
- c.t.Fatalf(errMsg)
+ c.t.Fatal(errMsg)
case <-time.After(wait):
}
@@ -221,6 +244,57 @@
return nil
}
+// GetPacketNonBlocking reads a packet from the link layer endpoint
+// and verifies that it is an IPv4 packet with the expected source
+// and destination address. If no packet is available it will return
+// nil immediately.
+func (c *Context) GetPacketNonBlocking() []byte {
+ select {
+ case p := <-c.linkEP.C:
+ if p.Proto != ipv4.ProtocolNumber {
+ c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
+ }
+ b := make([]byte, len(p.Header)+len(p.Payload))
+ copy(b, p.Header)
+ copy(b[len(p.Header):], p.Payload)
+
+ checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr))
+ return b
+ default:
+ return nil
+ }
+}
+
+// SendICMPPacket builds and sends an ICMPv4 packet via the link layer endpoint.
+func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byte, maxTotalSize int) {
+ // Allocate a buffer for data and headers.
+ buf := buffer.NewView(header.IPv4MinimumSize + header.ICMPv4MinimumSize + len(p1) + len(p2))
+ if len(buf) > maxTotalSize {
+ buf = buf[:maxTotalSize]
+ }
+
+ ip := header.IPv4(buf)
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: uint16(len(buf)),
+ TTL: 65,
+ Protocol: uint8(header.ICMPv4ProtocolNumber),
+ SrcAddr: TestAddr,
+ DstAddr: StackAddr,
+ })
+ ip.SetChecksum(^ip.CalculateChecksum())
+
+ icmp := header.ICMPv4(buf[header.IPv4MinimumSize:])
+ icmp.SetType(typ)
+ icmp.SetCode(code)
+
+ copy(icmp[header.ICMPv4MinimumSize:], p1)
+ copy(icmp[header.ICMPv4MinimumSize+len(p1):], p2)
+
+ // Inject packet.
+ c.linkEP.Inject(ipv4.ProtocolNumber, buf.ToVectorisedView())
+}
+
// BuildSegment builds a TCP segment based on the given Headers and payload.
func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView {
// Allocate a buffer for data and headers.
@@ -263,23 +337,19 @@
t.SetChecksum(^t.CalculateChecksum(xsum, length))
// Inject packet.
- var views [1]buffer.View
- vv := buf.ToVectorisedView(views)
-
- return vv
+ return buf.ToVectorisedView()
}
// SendSegment sends a TCP segment that has already been built and written to a
// buffer.VectorisedView.
-func (c *Context) SendSegment(s *buffer.VectorisedView) {
+func (c *Context) SendSegment(s buffer.VectorisedView) {
c.linkEP.Inject(ipv4.ProtocolNumber, s)
}
// SendPacket builds and sends a TCP segment(with the provided payload & TCP
// headers) in an IPv4 packet via the link layer endpoint.
func (c *Context) SendPacket(payload []byte, h *Headers) {
- vv := c.BuildSegment(payload, h)
- c.linkEP.Inject(ipv4.ProtocolNumber, &vv)
+ c.linkEP.Inject(ipv4.ProtocolNumber, c.BuildSegment(payload, h))
}
// SendAck sends an ACK packet.
@@ -315,6 +385,32 @@
}
}
+// ReceiveNonBlockingAndCheckPacket reads a packet from the link layer endpoint
+// and verifies that the packet payload matches the slice of data indicated
+// by offset & size. It returns true if a packet was received and processed,
+// and false immediately if none was available.
+func (c *Context) ReceiveNonBlockingAndCheckPacket(data []byte, offset, size int) bool {
+ b := c.GetPacketNonBlocking()
+ if b == nil {
+ return false
+ }
+ checker.IPv4(c.t, b,
+ checker.PayloadLen(size+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(TestPort),
+ checker.SeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))),
+ checker.AckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+
+ pdata := data[offset:][:size]
+ if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(pdata, p) {
+ c.t.Fatalf("Data is different: expected %v, got %v", pdata, p)
+ }
+ return true
+}
+
// CreateV6Endpoint creates and initializes c.ep as a IPv6 Endpoint. If v6Only
// is true then it sets the IP_V6ONLY option on the socket to make it a IPv6
// only endpoint instead of a default dual stack socket.
@@ -396,9 +492,7 @@
t.SetChecksum(^t.CalculateChecksum(xsum, length))
// Inject packet.
- var views [1]buffer.View
- vv := buf.ToVectorisedView(views)
- c.linkEP.Inject(ipv6.ProtocolNumber, &vv)
+ c.linkEP.Inject(ipv6.ProtocolNumber, buf.ToVectorisedView())
}
// CreateConnected creates a connected TCP endpoint.
@@ -494,15 +588,17 @@
WndSize seqnum.Size
RecentTS uint32 // Stores the latest timestamp to echo back.
TSVal uint32 // TSVal stores the last timestamp sent by this endpoint.
+
+ // SACKPermitted is true if the SACKPermitted option was negotiated for this endpoint.
+ SACKPermitted bool
}
// SendPacketWithTS embeds the provided tsVal in the Timestamp option
// for the packet to be sent out.
func (r *RawEndpoint) SendPacketWithTS(payload []byte, tsVal uint32) {
r.TSVal = tsVal
- // Increment TSVal by 1 from the value sent in the SYN and echo the
- // TSVal in the SYN-ACK in the TSEcr field.
- tsOpt := header.EncodeTSOption(r.TSVal, r.RecentTS)
+ tsOpt := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP}
+ header.EncodeTSOption(r.TSVal, r.RecentTS, tsOpt[2:])
r.SendPacket(payload, tsOpt[:])
}
@@ -541,6 +637,27 @@
r.RecentTS = opts.TSVal
}
+// VerifyACKNoSACK verifies that the ACK does not contain a SACK block.
+func (r *RawEndpoint) VerifyACKNoSACK() {
+ r.VerifyACKHasSACK(nil)
+}
+
+// VerifyACKHasSACK verifies that the ACK contains the specified SACKBlocks.
+func (r *RawEndpoint) VerifyACKHasSACK(sackBlocks []header.SACKBlock) {
+ // Read the ACK and verify that the SACK blocks in its TCP options
+ // match sackBlocks (no SACK blocks at all when sackBlocks is nil).
+ ackPacket := r.C.GetPacket()
+ checker.IPv4(r.C.t, ackPacket,
+ checker.TCP(
+ checker.DstPort(r.SrcPort),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.SeqNum(uint32(r.AckNum)),
+ checker.AckNum(uint32(r.NextSeqNum)),
+ checker.TCPSACKBlockChecker(sackBlocks),
+ ),
+ )
+}
+
// CreateConnectedWithOptions creates and connects c.ep with the specified TCP
// options enabled and returns a RawEndpoint which represents the other end of
// the connection.
@@ -573,9 +690,10 @@
checker.DstPort(TestPort),
checker.TCPFlags(header.TCPFlagSyn),
checker.TCPSynOptions(header.TCPSynOptions{
- MSS: uint16(c.linkEP.MTU() - header.IPv4MinimumSize - header.TCPMinimumSize),
- TS: true,
- WS: defaultWindowScale,
+ MSS: uint16(c.linkEP.MTU() - header.IPv4MinimumSize - header.TCPMinimumSize),
+ TS: true,
+ WS: defaultWindowScale,
+ SACKPermitted: c.SACKEnabled(),
}),
),
)
@@ -583,11 +701,16 @@
synOptions := header.ParseSynOptions(tcpSeg.Options(), false)
// Build options w/ tsVal to be sent in the SYN-ACK.
- var synAckOptions []byte
+ synAckOptions := make([]byte, 40)
+ offset := 0
if wantOptions.TS {
- tsOpt := header.EncodeTSOption(wantOptions.TSVal, synOptions.TSVal)
- synAckOptions = append(synAckOptions, tsOpt[:]...)
+ offset += header.EncodeTSOption(wantOptions.TSVal, synOptions.TSVal, synAckOptions[offset:])
}
+ if wantOptions.SACKPermitted {
+ offset += header.EncodeSACKPermittedOption(synAckOptions[offset:])
+ }
+
+ offset += header.AddTCPOptionPadding(synAckOptions, offset)
// Build SYN-ACK.
c.IRS = seqnum.Value(tcpSeg.SequenceNumber())
@@ -599,7 +722,7 @@
SeqNum: iss,
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
- TCPOpts: synAckOptions[:],
+ TCPOpts: synAckOptions[:offset],
})
// Read ACK.
@@ -645,22 +768,26 @@
c.TimeStampEnabled = true
return &RawEndpoint{
- C: c,
- SrcPort: tcpSeg.DestinationPort(),
- DstPort: tcpSeg.SourcePort(),
- Flags: header.TCPFlagAck | header.TCPFlagPsh,
- NextSeqNum: iss + 1,
- AckNum: c.IRS.Add(1),
- WndSize: 30000,
- RecentTS: ackOptions.TSVal,
- TSVal: wantOptions.TSVal,
+ C: c,
+ SrcPort: tcpSeg.DestinationPort(),
+ DstPort: tcpSeg.SourcePort(),
+ Flags: header.TCPFlagAck | header.TCPFlagPsh,
+ NextSeqNum: iss + 1,
+ AckNum: c.IRS.Add(1),
+ WndSize: 30000,
+ RecentTS: ackOptions.TSVal,
+ TSVal: wantOptions.TSVal,
+ SACKPermitted: wantOptions.SACKPermitted,
}
}
// AcceptWithOptions initializes a listening endpoint and connects to it with the
// provided options enabled. It also verifies that the SYN-ACK has the expected
// values for the provided options.
-func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOptions) {
+//
+// The function returns a RawEndpoint representing the other end of the accepted
+// endpoint.
+func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOptions) *RawEndpoint {
// Create EP and start listening.
wq := &waiter.Queue{}
ep, err := c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
@@ -677,7 +804,7 @@
c.t.Fatalf("Listen failed: %v", err)
}
- c.PassiveConnectWithOptions(100, wndScale, synOptions)
+ rep := c.PassiveConnectWithOptions(100, wndScale, synOptions)
// Try to accept the connection.
we, ch := waiter.NewChannelEntry(nil)
@@ -698,6 +825,7 @@
c.t.Fatalf("Timed out waiting for accept")
}
}
+ return rep
}
// PassiveConnect just disables WindowScaling and delegates the call to
@@ -720,21 +848,30 @@
// wndScale is the expected window scale in the SYN-ACK and synOptions.WS is the
// value of the window scaling option to be sent in the SYN. If synOptions.WS >
// 0 then we send the WindowScale option.
-func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions header.TCPSynOptions) {
- opts := []byte{
- header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
- }
+func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions header.TCPSynOptions) *RawEndpoint {
+ opts := make([]byte, 40)
+ offset := 0
+ offset += header.EncodeMSSOption(uint32(maxPayload), opts)
if synOptions.WS >= 0 {
- opts = append(opts, []byte{
- header.TCPOptionWS, 3, byte(synOptions.WS), header.TCPOptionNOP,
- }...)
+ offset += header.EncodeWSOption(synOptions.WS, opts[offset:])
}
if synOptions.TS {
- tsOpt := header.EncodeTSOption(synOptions.TSVal, synOptions.TSEcr)
- opts = append(opts, tsOpt[:]...)
+ offset += header.EncodeTSOption(synOptions.TSVal, synOptions.TSEcr, opts[offset:])
}
+ if synOptions.SACKPermitted {
+ offset += header.EncodeSACKPermittedOption(opts[offset:])
+ }
+
+ // Add any NOP padding required to quad-align the options; an already
+ // aligned offset needs none.
+ paddingToAdd := (4 - offset%4) % 4
+ for i := offset; i < offset+paddingToAdd; i++ {
+ opts[i] = header.TCPOptionNOP
+ }
+ offset += paddingToAdd
+
// Send a SYN request.
iss := seqnum.Value(testInitialSequenceNumber)
c.SendPacket(nil, &Headers{
@@ -743,10 +880,11 @@
Flags: header.TCPFlagSyn,
SeqNum: iss,
RcvWnd: 30000,
- TCPOpts: opts,
+ TCPOpts: opts[:offset],
})
- // Receive the SYN-ACK reply. Make sure MSS is present.
+ // Receive the SYN-ACK reply. Make sure MSS and other expected options
+ // are present.
b := c.GetPacket()
tcp := header.TCP(header.IPv4(b).Payload())
c.IRS = seqnum.Value(tcp.SequenceNumber())
@@ -756,7 +894,7 @@
checker.DstPort(TestPort),
checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
checker.AckNum(uint32(iss) + 1),
- checker.TCPSynOptions(header.TCPSynOptions{MSS: synOptions.MSS, WS: wndScale}),
+ checker.TCPSynOptions(header.TCPSynOptions{MSS: synOptions.MSS, WS: wndScale, SACKPermitted: synOptions.SACKPermitted && c.SACKEnabled()}),
}
// If TS option was enabled in the original SYN then add a checker to
@@ -784,18 +922,43 @@
ackHeaders.RcvWnd = rcvWnd >> byte(synOptions.WS)
}
+ parsedOpts := tcp.ParsedOptions()
if synOptions.TS {
// Echo the tsVal back to the peer in the tsEcr field of the
// timestamp option.
- opts := tcp.ParsedOptions()
// Increment TSVal by 1 from the value sent in the SYN and echo
// the TSVal in the SYN-ACK in the TSEcr field.
- tsOpt := header.EncodeTSOption(synOptions.TSVal+1, opts.TSVal)
- ackHeaders.TCPOpts = tsOpt[:]
+ opts := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP}
+ header.EncodeTSOption(synOptions.TSVal+1, parsedOpts.TSVal, opts[2:])
+ ackHeaders.TCPOpts = opts[:]
}
// Send ACK.
c.SendPacket(nil, ackHeaders)
c.Port = StackPort
+
+ return &RawEndpoint{
+ C: c,
+ SrcPort: TestPort,
+ DstPort: StackPort,
+ Flags: header.TCPFlagPsh | header.TCPFlagAck,
+ NextSeqNum: iss + 1,
+ AckNum: c.IRS + 1,
+ WndSize: rcvWnd,
+ SACKPermitted: synOptions.SACKPermitted && c.SACKEnabled(),
+ RecentTS: parsedOpts.TSVal,
+ TSVal: synOptions.TSVal + 1,
+ }
+}
+
+// SACKEnabled returns true if the TCP Protocol option SACKEnabled is set to true
+// for the Stack in the context.
+func (c *Context) SACKEnabled() bool {
+ var v tcp.SACKEnabled
+ if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &v); err != nil {
+ // The stack doesn't support the SACK option, so report it as disabled.
+ return false
+ }
+ return bool(v)
}
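The option builders above pad the TCP options area with NOPs so that its length is a multiple of four bytes, because the TCP data offset field counts 32-bit words. The standalone sketch below illustrates that arithmetic; `padOptions` is a hypothetical helper written for illustration (the diff itself uses `header.AddTCPOptionPadding` for this job at one call site), and the sketch assumes the caller's buffer has room for up to three padding bytes.

```go
package main

import "fmt"

// tcpOptionNOP is the TCP "no operation" option kind (RFC 793).
const tcpOptionNOP = 1

// padOptions is a hypothetical helper: it writes NOP bytes after the first
// offset bytes of opts until the used length is a multiple of 4, and returns
// the number of padding bytes written. An aligned offset gets no padding.
// It assumes len(opts) >= offset+3.
func padOptions(opts []byte, offset int) int {
	pad := (4 - offset%4) % 4
	for i := offset; i < offset+pad; i++ {
		opts[i] = tcpOptionNOP
	}
	return pad
}

func main() {
	opts := make([]byte, 40) // 40 bytes is the maximum TCP options length.
	offset := 0

	// MSS option: kind 2, length 4, 16-bit big-endian value.
	mss := 1460
	opts[0], opts[1], opts[2], opts[3] = 2, 4, byte(mss>>8), byte(mss)
	offset += 4

	// SACK-permitted option: kind 4, length 2.
	opts[offset], opts[offset+1] = 4, 2
	offset += 2

	offset += padOptions(opts, offset)
	fmt.Println(offset, opts[:offset]) // prints "8 [2 4 5 180 4 2 1 1]"
}
```

Computing the pad as `(4 - offset%4) % 4` rather than `4 - offset%4` avoids appending four redundant NOPs when the options already end on a word boundary.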
diff --git a/tcpip/transport/tcp/timer.go b/tcpip/transport/tcp/timer.go
index b6cf990..7ee7ed0 100644
--- a/tcpip/transport/tcp/timer.go
+++ b/tcpip/transport/tcp/timer.go
@@ -1,6 +1,16 @@
-// Copyright 2017 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package tcp
@@ -112,7 +122,7 @@
}
}
-// enabled returns true if the timer is currenlty enabled, false otherwise.
+// enabled returns true if the timer is currently enabled, false otherwise.
func (t *timer) enabled() bool {
return t.state == timerStateEnabled
}
diff --git a/tcpip/transport/tcpconntrack/tcp_conntrack.go b/tcpip/transport/tcpconntrack/tcp_conntrack.go
index 955a708..0208479 100644
--- a/tcpip/transport/tcpconntrack/tcp_conntrack.go
+++ b/tcpip/transport/tcpconntrack/tcp_conntrack.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
// Package tcpconntrack implements a TCP connection tracking object. It allows
// users with access to a segment stream to figure out when a connection is
@@ -51,9 +61,12 @@
// firstFin holds a pointer to the first stream to send a FIN.
firstFin *stream
+
+ // state is the current state of the connection.
+ state Result
}
-// Init initalizes the state of the TCB according to the initial SYN.
+// Init initializes the state of the TCB according to the initial SYN.
func (t *TCB) Init(initialSyn header.TCP) {
t.handlerInbound = synSentStateInbound
t.handlerOutbound = synSentStateOutbound
@@ -69,18 +82,43 @@
t.inbound.una = 0
t.inbound.nxt = 0
t.inbound.end = seqnum.Value(initialSyn.WindowSize())
+ t.state = ResultConnecting
}
// UpdateStateInbound updates the state of the TCB based on the supplied inbound
// segment.
func (t *TCB) UpdateStateInbound(tcp header.TCP) Result {
- return t.handlerInbound(t, tcp)
+ st := t.handlerInbound(t, tcp)
+ if st != ResultDrop {
+ t.state = st
+ }
+ return st
}
// UpdateStateOutbound updates the state of the TCB based on the supplied
// outbound segment.
func (t *TCB) UpdateStateOutbound(tcp header.TCP) Result {
- return t.handlerOutbound(t, tcp)
+ st := t.handlerOutbound(t, tcp)
+ if st != ResultDrop {
+ t.state = st
+ }
+ return st
+}
+
+// IsAlive returns true as long as the connection is in the established
+// (alive) or connecting state.
+func (t *TCB) IsAlive() bool {
+ return !t.inbound.rstSeen && !t.outbound.rstSeen && (!t.inbound.closed() || !t.outbound.closed())
+}
+
+// OutboundSendSequenceNumber returns the snd.NXT for the outbound stream.
+func (t *TCB) OutboundSendSequenceNumber() seqnum.Value {
+ return t.outbound.nxt
+}
+
+// InboundSendSequenceNumber returns the snd.NXT for the inbound stream.
+func (t *TCB) InboundSendSequenceNumber() seqnum.Value {
+ return t.inbound.nxt
}
// adapResult modifies the supplied "Result" according to the state of the TCB;
@@ -118,6 +156,7 @@
// implicitly acceptable).
if flags&header.TCPFlagRst != 0 {
if ackPresent {
+ t.inbound.rstSeen = true
return ResultReset
}
return ResultConnecting
@@ -185,6 +224,7 @@
flags := tcp.Flags()
if flags&header.TCPFlagRst != 0 {
+ inbound.rstSeen = true
return ResultReset
}
@@ -260,6 +300,9 @@
// fin is the sequence number of the FIN. It is only valid after finSeen
// is set to true.
fin seqnum.Value
+
+ // rstSeen indicates if a RST has already been sent on this stream.
+ rstSeen bool
}
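With these changes the TCB remembers the last non-drop result and whether either side has sent a RST, which is what `IsAlive` consults. The simplified model below sketches that bookkeeping; the `result`, `stream`, and `tcb` types are hypothetical stand-ins rather than the real `tcpconntrack` types, and `stream.closed()` is reduced to a boolean `finAcked` field as an assumption.

```go
package main

import "fmt"

// result is a hypothetical stand-in for tcpconntrack.Result.
type result int

const (
	resultDrop result = iota
	resultConnecting
	resultAlive
	resultReset
)

// stream is a simplified stand-in; finAcked replaces stream.closed().
type stream struct {
	finAcked bool
	rstSeen  bool
}

type tcb struct {
	inbound, outbound stream
	state             result
}

// update records st as the current state unless the segment was dropped,
// mirroring UpdateStateInbound/UpdateStateOutbound in the diff.
func (c *tcb) update(st result) result {
	if st != resultDrop {
		c.state = st
	}
	return st
}

// isAlive reports true while no RST has been seen in either direction and
// at least one direction is not yet fully closed.
func (c *tcb) isAlive() bool {
	return !c.inbound.rstSeen && !c.outbound.rstSeen &&
		(!c.inbound.finAcked || !c.outbound.finAcked)
}

func main() {
	c := &tcb{state: resultConnecting}
	c.update(resultAlive)
	c.update(resultDrop) // a dropped segment leaves the state untouched
	fmt.Println(c.state == resultAlive, c.isAlive()) // prints "true true"

	c.inbound.rstSeen = true
	c.update(resultReset)
	fmt.Println(c.state == resultReset, c.isAlive()) // prints "true false"
}
```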
// acceptable determines if the segment with the given sequence number and data
diff --git a/tcpip/transport/tcpconntrack/tcp_conntrack_test.go b/tcpip/transport/tcpconntrack/tcp_conntrack_test.go
index bb6d9df..f77c617 100644
--- a/tcpip/transport/tcpconntrack/tcp_conntrack_test.go
+++ b/tcpip/transport/tcpconntrack/tcp_conntrack_test.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package tcpconntrack_test
diff --git a/tcpip/transport/udp/endpoint.go b/tcpip/transport/udp/endpoint.go
index 68ca28b..a45f780 100644
--- a/tcpip/transport/udp/endpoint.go
+++ b/tcpip/transport/udp/endpoint.go
@@ -1,13 +1,23 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package udp
import (
"sync"
- "sync/atomic"
+ "github.com/google/netstack/sleep"
"github.com/google/netstack/tcpip"
"github.com/google/netstack/tcpip/buffer"
"github.com/google/netstack/tcpip/header"
@@ -15,10 +25,13 @@
"github.com/google/netstack/waiter"
)
+// +stateify savable
type udpPacket struct {
udpPacketEntry
senderAddress tcpip.FullAddress
data buffer.VectorisedView
+ timestamp int64
+ hasTimestamp bool
// views is used as buffer for data when its length is large
// enough to store a VectorisedView.
views [8]buffer.View
@@ -37,6 +50,8 @@
// between users of the endpoint and the protocol implementation; it is legal to
// have concurrent goroutines make calls into the endpoint, they are properly
// synchronized.
+//
+// +stateify savable
type endpoint struct {
// The following fields are initialized at creation time and do not
// change throughout the lifetime of the endpoint.
@@ -52,6 +67,7 @@
rcvBufSizeMax int
rcvBufSize int
rcvClosed bool
+ rcvTimestamp bool
// The following fields are protected by the mu mutex.
mu sync.RWMutex
@@ -59,15 +75,17 @@
id stack.TransportEndpointID
state endpointState
bindNICID tcpip.NICID
- bindAddr tcpip.Address
regNICID tcpip.NICID
route stack.Route
dstPort uint16
v6only bool
multicastTTL uint8
- // A list of multicast memberships that we need to remove when the endpoint
- // is closed. Protected by the mu mutex.
+ // shutdownFlags represent the current shutdown state of the endpoint.
+ shutdownFlags tcpip.ShutdownFlags
+
+ // multicastMemberships lists the multicast memberships that need to be
+ // removed when the endpoint is closed. Protected by the mu mutex.
multicastMemberships []multicastMembership
// effectiveNetProtos contains the network protocols actually in use. In
@@ -85,12 +103,22 @@
}
func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
- // TODO: Use the send buffer size initialized here.
return &endpoint{
- stack: stack,
- netProto: netProto,
- waiterQueue: waiterQueue,
- v6only: true,
+ stack: stack,
+ netProto: netProto,
+ waiterQueue: waiterQueue,
+ // RFC 1075 section 5.4 recommends a TTL of 1 for membership
+ // requests.
+ //
+ // RFC 5135 4.2.1 appears to assume that IGMP messages have a
+ // TTL of 1.
+ //
+ // RFC 5135 Appendix A defines TTL=1: A multicast source that
+ // wants its traffic to not traverse a router (e.g., leave a
+ // home network) may find it useful to send traffic with IP
+ // TTL=1.
+ //
+ // Linux defaults to TTL=1.
multicastTTL: 1,
rcvBufSizeMax: 32 * 1024,
sndBufSize: 32 * 1024,
@@ -118,18 +146,11 @@
return ep, nil
}
-func (e *endpoint) isPortReserved() bool {
- return e.id.LocalPort != 0
-}
-
// Close puts the endpoint in a closed state and frees all resources
// associated with it.
func (e *endpoint) Close() {
- // Shutdown the endpoint so that we notify waiters that the endpoint is closed.
- e.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead)
-
e.mu.Lock()
- defer e.mu.Unlock()
+ e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
switch e.state {
case stateBound, stateConnected:
@@ -156,11 +177,15 @@
// Update the state.
e.state = stateClosed
+
+ e.mu.Unlock()
+
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
}
// Read reads data from the endpoint. This method does not block if
// there is no data pending.
-func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, *tcpip.Error) {
+func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
e.rcvMu.Lock()
if e.rcvList.Empty() {
@@ -169,12 +194,13 @@
err = tcpip.ErrClosedForReceive
}
e.rcvMu.Unlock()
- return buffer.View{}, err
+ return buffer.View{}, tcpip.ControlMessages{}, err
}
p := e.rcvList.Front()
e.rcvList.Remove(p)
e.rcvBufSize -= p.data.Size()
+ ts := e.rcvTimestamp
e.rcvMu.Unlock()
@@ -182,7 +208,12 @@
*addr = p.senderAddress
}
- return p.data.ToView(), nil
+ if ts && !p.hasTimestamp {
+ // Linux uses the current time.
+ p.timestamp = e.stack.NowNanoseconds()
+ }
+
+ return p.data.ToView(), tcpip.ControlMessages{HasTimestamp: ts, Timestamp: p.timestamp}, nil
}
// prepareForWrite prepares the endpoint for sending data. In particular, it
@@ -227,10 +258,22 @@
// Write writes data to the endpoint's peer. This method does not block
// if the data cannot be written.
-func (e *endpoint) Write(v buffer.View, to *tcpip.FullAddress) (uintptr, *tcpip.Error) {
+func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, *tcpip.Error) {
+ // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
+ if opts.More {
+ return 0, tcpip.ErrInvalidOptionValue
+ }
+
+ to := opts.To
+
e.mu.RLock()
defer e.mu.RUnlock()
+ // If we've shut down with SHUT_WR we are in an invalid state for sending.
+ if e.shutdownFlags&tcpip.ShutdownWrite != 0 {
+ return 0, tcpip.ErrClosedForSend
+ }
+
// Prepare for write.
for {
retry, err := e.prepareForWrite(to)
@@ -243,9 +286,27 @@
}
}
- route := &e.route
- dstPort := e.dstPort
- if to != nil {
+ var route *stack.Route
+ var dstPort uint16
+ if to == nil {
+ route = &e.route
+ dstPort = e.dstPort
+
+ if route.IsResolutionRequired() {
+ // Promote the lock to exclusive when using the endpoint's shared route,
+ // since the Route.Resolve() call below may modify it.
+ e.mu.RUnlock()
+ defer e.mu.RLock()
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // Recheck state after lock was re-acquired.
+ if e.state != stateConnected {
+ return 0, tcpip.ErrInvalidEndpointState
+ }
+ }
+ } else {
// Reject destination address if it goes through a different
// NIC than the endpoint was bound to.
nicid := to.NIC
@@ -259,13 +320,13 @@
toCopy := *to
to = &toCopy
- netProto, err := e.checkV4Mapped(to, true)
+ netProto, err := e.checkV4Mapped(to, false)
if err != nil {
return 0, err
}
// Find the enpoint.
- r, err := e.stack.FindRoute(nicid, e.bindAddr, to.Addr, netProto)
+ r, err := e.stack.FindRoute(nicid, e.id.LocalAddress, to.Addr, netProto)
if err != nil {
return 0, err
}
@@ -275,21 +336,41 @@
dstPort = to.Port
}
+ if route.IsResolutionRequired() {
+ waker := &sleep.Waker{}
+ if err := route.Resolve(waker); err != nil {
+ if err == tcpip.ErrWouldBlock {
+ // The link address needs to be resolved; resolution was triggered in the
+ // background, so a later write may succeed.
+ //
+ // TODO: queue up the request and send after link address
+ // is resolved.
+ route.RemoveWaker(waker)
+ return 0, tcpip.ErrNoLinkAddress
+ }
+ return 0, err
+ }
+ }
+
+ v, err := p.Get(p.Size())
+ if err != nil {
+ return 0, err
+ }
+
ttl := route.DefaultTTL()
if header.IsV4MulticastAddress(route.RemoteAddress) || header.IsV6MulticastAddress(route.RemoteAddress) {
ttl = e.multicastTTL
}
- err := sendUDP(route, v, e.id.LocalPort, dstPort, ttl)
- if err != nil {
+ if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.id.LocalPort, dstPort, ttl); err != nil {
return 0, err
}
return uintptr(len(v)), nil
}
// Peek only returns data from a single datagram, so do nothing here.
-func (e *endpoint) Peek([][]byte) (uintptr, *tcpip.Error) {
- return 0, nil
+func (e *endpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) {
+ return 0, tcpip.ControlMessages{}, nil
}
// SetSockOpt sets a socket option. Currently not supported.
@@ -311,21 +392,28 @@
}
e.v6only = v != 0
+
+ case tcpip.TimestampOption:
+ e.rcvMu.Lock()
+ e.rcvTimestamp = v != 0
+ e.rcvMu.Unlock()
+
case tcpip.MulticastTTLOption:
e.mu.Lock()
defer e.mu.Unlock()
e.multicastTTL = uint8(v)
+
case tcpip.AddMembershipOption:
- nicID := v.NIC;
- if (v.InterfaceAddr != header.IPv4Any) {
- nicID = e.stack.CheckLocalAddress(nicID, v.InterfaceAddr)
+ nicID := v.NIC
+ if v.InterfaceAddr != header.IPv4Any {
+ nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr)
}
if nicID == 0 {
return tcpip.ErrNoRoute
}
- err := e.stack.JoinGroup(e.netProto, nicID, v.MulticastAddr)
- if err != nil {
+ // TODO: check that v.MulticastAddr is a multicast address.
+ if err := e.stack.JoinGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
return err
}
@@ -333,17 +421,18 @@
defer e.mu.Unlock()
e.multicastMemberships = append(e.multicastMemberships, multicastMembership{nicID, v.MulticastAddr})
+
case tcpip.RemoveMembershipOption:
- nicID := v.NIC;
- if (v.InterfaceAddr != header.IPv4Any) {
- nicID = e.stack.CheckLocalAddress(nicID, v.InterfaceAddr)
+ nicID := v.NIC
+ if v.InterfaceAddr != header.IPv4Any {
+ nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr)
}
if nicID == 0 {
return tcpip.ErrNoRoute
}
- err := e.stack.LeaveGroup(e.netProto, nicID, v.MulticastAddr)
- if err != nil {
+ // TODO: check that v.MulticastAddr is a multicast address.
+ if err := e.stack.LeaveGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
return err
}
@@ -396,12 +485,6 @@
}
return nil
- case *tcpip.MulticastTTLOption:
- e.mu.Lock()
- *o = tcpip.MulticastTTLOption(e.multicastTTL)
- e.mu.Unlock()
- return nil
-
case *tcpip.ReceiveQueueSizeOption:
e.rcvMu.Lock()
if e.rcvList.Empty() {
@@ -412,6 +495,20 @@
}
e.rcvMu.Unlock()
return nil
+
+ case *tcpip.TimestampOption:
+ e.rcvMu.Lock()
+ *o = 0
+ if e.rcvTimestamp {
+ *o = 1
+ }
+ e.rcvMu.Unlock()
+ return nil
+ case *tcpip.MulticastTTLOption:
+ e.mu.Lock()
+ *o = tcpip.MulticastTTLOption(e.multicastTTL)
+ e.mu.Unlock()
+ return nil
}
return tcpip.ErrUnknownProtocolOption
@@ -419,32 +516,33 @@
// sendUDP sends a UDP segment via the provided network endpoint and under the
// provided identity.
-func sendUDP(r *stack.Route, data buffer.View, localPort, remotePort uint16, ttl uint8) *tcpip.Error {
+func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8) *tcpip.Error {
// Allocate a buffer for the UDP header.
hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength()))
// Initialize the header.
udp := header.UDP(hdr.Prepend(header.UDPMinimumSize))
- length := uint16(hdr.UsedLength())
- xsum := r.PseudoHeaderChecksum(ProtocolNumber)
- if data != nil {
- length += uint16(len(data))
- xsum = header.Checksum(data, xsum)
- }
-
+ length := uint16(hdr.UsedLength() + data.Size())
udp.Encode(&header.UDPFields{
SrcPort: localPort,
DstPort: remotePort,
Length: length,
})
- udp.SetChecksum(^udp.CalculateChecksum(xsum, length))
+ // Only calculate the checksum if offloading isn't supported.
+ if r.Capabilities()&stack.CapabilityChecksumOffload == 0 {
+ xsum := r.PseudoHeaderChecksum(ProtocolNumber)
+ for _, v := range data.Views() {
+ xsum = header.Checksum(v, xsum)
+ }
+ udp.SetChecksum(^udp.CalculateChecksum(xsum, length))
+ }
// Track count of packets sent.
- atomic.AddUint64(&r.MutableStats().UDP.PacketsSent, 1)
+ r.Stats().UDP.PacketsSent.Increment()
- return r.WritePacket(&hdr, data, ProtocolNumber, ttl)
+ return r.WritePacket(hdr, data, ProtocolNumber, ttl)
}
func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
@@ -460,11 +558,16 @@
if addr.Addr == "\x00\x00\x00\x00" {
addr.Addr = ""
}
+
+ // Fail if we are bound to an IPv6 address.
+ if !allowMismatch && len(e.id.LocalAddress) == 16 {
+ return 0, tcpip.ErrNetworkUnreachable
+ }
}
// Fail if we're bound to an address length different from the one we're
// checking.
- if l := len(e.id.LocalAddress); !allowMismatch && l != 0 && l != len(addr.Addr) {
+ if l := len(e.id.LocalAddress); l != 0 && l != len(addr.Addr) {
return 0, tcpip.ErrInvalidEndpointState
}
@@ -482,7 +585,7 @@
defer e.mu.Unlock()
nicid := addr.NIC
- localPort := uint16(0)
+ var localPort uint16
switch e.state {
case stateInitial:
case stateBound, stateConnected:
@@ -506,7 +609,7 @@
}
// Find a route to the desired destination.
- r, err := e.stack.FindRoute(nicid, e.bindAddr, addr.Addr, netProto)
+ r, err := e.stack.FindRoute(nicid, e.id.LocalAddress, addr.Addr, netProto)
if err != nil {
return err
}
@@ -516,14 +619,14 @@
LocalAddress: r.LocalAddress,
LocalPort: localPort,
RemotePort: addr.Port,
- RemoteAddress: addr.Addr,
+ RemoteAddress: r.RemoteAddress,
}
// Even if we're connected, this endpoint can still be used to send
// packets on a different network protocol, so we register both
// protocols when v6only is false and this is an IPv6 endpoint.
netProtos := []tcpip.NetworkProtocolNumber{netProto}
- if e.netProto == header.IPv6ProtocolNumber && !e.v6only {
+ if netProto == header.IPv6ProtocolNumber && !e.v6only {
netProtos = []tcpip.NetworkProtocolNumber{
header.IPv4ProtocolNumber,
header.IPv6ProtocolNumber,
@@ -536,7 +639,7 @@
}
// Remove the old registration.
- if e.isPortReserved() {
+ if e.id.LocalPort != 0 {
e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id)
}
@@ -563,8 +666,8 @@
// Shutdown closes the read and/or write end of the endpoint connection
// to its peer.
func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
- e.mu.RLock()
- defer e.mu.RUnlock()
+ e.mu.Lock()
+ defer e.mu.Unlock()
// A socket in the bound state can still receive multicast messages,
// so we need to notify waiters on shutdown.
@@ -572,6 +675,8 @@
return tcpip.ErrNotConnected
}
+ e.shutdownFlags |= flags
+
if flags&tcpip.ShutdownRead != 0 {
e.rcvMu.Lock()
wasClosed := e.rcvClosed
@@ -597,13 +702,11 @@
}
func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) {
- // Reserve the port.
- if !e.isPortReserved() {
+ if e.id.LocalPort == 0 {
port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort)
if err != nil {
return id, err
}
-
id.LocalPort = port
}
@@ -621,7 +724,7 @@
return tcpip.ErrInvalidEndpointState
}
- netProto, err := e.checkV4Mapped(&addr, false)
+ netProto, err := e.checkV4Mapped(&addr, true)
if err != nil {
return err
}
@@ -639,7 +742,7 @@
if len(addr.Addr) != 0 {
// A local address was specified, verify that it's valid.
- if e.stack.CheckLocalAddress(addr.NIC, addr.Addr) == 0 {
+ if e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) == 0 {
return tcpip.ErrBadLocalAddress
}
}
@@ -648,12 +751,10 @@
LocalPort: addr.Port,
LocalAddress: addr.Addr,
}
-
id, err = e.registerWithStack(addr.NIC, netProtos, id)
if err != nil {
return err
}
-
if commit != nil {
if err := commit(); err != nil {
// Unregister, the commit failed.
@@ -689,7 +790,6 @@
}
e.bindNICID = addr.NIC
- e.bindAddr = addr.Addr
return nil
}
@@ -742,23 +842,23 @@
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv *buffer.VectorisedView) {
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) {
// Get the header then trim it from the view.
hdr := header.UDP(vv.First())
if int(hdr.Length()) > vv.Size() {
// Malformed packet.
- atomic.AddUint64(&e.stack.MutableStats().UDP.MalformedPacketsReceived, 1)
+ e.stack.Stats().UDP.MalformedPacketsReceived.Increment()
return
}
vv.TrimFront(header.UDPMinimumSize)
e.rcvMu.Lock()
- atomic.AddUint64(&e.stack.MutableStats().UDP.PacketsReceived, 1)
+ e.stack.Stats().UDP.PacketsReceived.Increment()
// Drop the packet if our buffer is currently full.
if !e.rcvReady || e.rcvClosed || e.rcvBufSize >= e.rcvBufSizeMax {
- atomic.AddUint64(&e.stack.MutableStats().UDP.ReceiveBufferErrors, 1)
+ e.stack.Stats().UDP.ReceiveBufferErrors.Increment()
e.rcvMu.Unlock()
return
}
@@ -777,6 +877,11 @@
e.rcvList.PushBack(pkt)
e.rcvBufSize += vv.Size()
+ if e.rcvTimestamp {
+ pkt.timestamp = e.stack.NowNanoseconds()
+ pkt.hasTimestamp = true
+ }
+
e.rcvMu.Unlock()
// Notify any waiters that there's data to be read now.
@@ -784,3 +889,7 @@
e.waiterQueue.Notify(waiter.EventIn)
}
}
+
+// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
+func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
+}
diff --git a/tcpip/transport/udp/protocol.go b/tcpip/transport/udp/protocol.go
index cb4a9d5..7ac890c 100644
--- a/tcpip/transport/udp/protocol.go
+++ b/tcpip/transport/udp/protocol.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
// Package udp contains the implementation of the UDP transport protocol. To use
// it in the networking stack, this package must be added to the project, and
@@ -52,7 +62,7 @@
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
// that don't match any existing endpoint.
-func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *buffer.VectorisedView) bool {
+func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) bool {
return true
}
@@ -61,6 +71,11 @@
return tcpip.ErrUnknownProtocolOption
}
+// Option implements TransportProtocol.Option.
+func (p *protocol) Option(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
func init() {
stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol {
return &protocol{}
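protocol.go now pairs SetOption with an Option getter, and UDP answers ErrUnknownProtocolOption to both because it has no protocol-level options. As a sketch under stated assumptions, a protocol that did support an option could implement the pair with type switches: the setter receives the option by value, and the getter receives a pointer so it can write the current value back. The option type maxDatagramsOption and the sentinel errUnknownProtocolOption are hypothetical, and a plain error replaces *tcpip.Error so the example stands alone.

```go
package main

import (
	"errors"
	"fmt"
)

// errUnknownProtocolOption is a hypothetical stand-in for
// tcpip.ErrUnknownProtocolOption.
var errUnknownProtocolOption = errors.New("unknown option for protocol")

// maxDatagramsOption is a hypothetical option type; UDP in the diff has none.
type maxDatagramsOption int

type protocol struct {
	maxDatagrams int
}

// SetOption receives the option by value and dispatches on its type.
func (p *protocol) SetOption(option interface{}) error {
	switch v := option.(type) {
	case maxDatagramsOption:
		p.maxDatagrams = int(v)
		return nil
	default:
		return errUnknownProtocolOption
	}
}

// Option receives a pointer so the current value can be written back to the
// caller; the type switch therefore matches on the pointer type.
func (p *protocol) Option(option interface{}) error {
	switch v := option.(type) {
	case *maxDatagramsOption:
		*v = maxDatagramsOption(p.maxDatagrams)
		return nil
	default:
		return errUnknownProtocolOption
	}
}

func main() {
	var p protocol
	fmt.Println(p.SetOption(maxDatagramsOption(64))) // <nil>

	var got maxDatagramsOption
	err := p.Option(&got)
	fmt.Println(err, got) // <nil> 64

	// Passing the value instead of a pointer falls through to the default.
	fmt.Println(p.Option(got)) // unknown option for protocol
}
```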
diff --git a/tcpip/transport/udp/udp_packet_list.go b/tcpip/transport/udp/udp_packet_list.go
index 4525a3a..37e2acf 100644
--- a/tcpip/transport/udp/udp_packet_list.go
+++ b/tcpip/transport/udp/udp_packet_list.go
@@ -1,8 +1,3 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-// Package ilist provides the implementation of intrusive linked lists.
package udp
// List is an intrusive list. Entries can be added to or removed from the list
@@ -14,6 +9,8 @@
// for e := l.Front(); e != nil; e = e.Next() {
// // do something with e.
// }
+//
+// +stateify savable
type udpPacketList struct {
head *udpPacket
tail *udpPacket
@@ -133,6 +130,8 @@
// Entry is a default implementation of Linker. Users can add anonymous fields
// of this type to their structs to make them automatically implement the
// methods needed by List.
+//
+// +stateify savable
type udpPacketEntry struct {
next *udpPacket
prev *udpPacket
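The generated udpPacketList is an intrusive list: the link fields live in an entry struct embedded in each element, so the element itself is the node and pushing or removing it allocates nothing. The following is a minimal, non-generated sketch of the same idea specialised to a hypothetical packet type; it omits the Linker/ElementMapper indirection that the generated code uses and is meant only to show the shape.

```go
package main

import (
	"fmt"
	"strings"
)

// packetEntry holds the links. Embedding it in the element means the element
// itself is the list node, so no separate node is allocated per insertion.
type packetEntry struct {
	next, prev *packet
}

// packet is a hypothetical element standing in for udpPacket.
type packet struct {
	packetEntry
	data string
}

// Next returns the following element, or nil at the end of the list.
func (p *packet) Next() *packet { return p.next }

type packetList struct {
	head, tail *packet
}

// Front returns the first element, or nil if the list is empty.
func (l *packetList) Front() *packet { return l.head }

// PushBack appends p in O(1).
func (l *packetList) PushBack(p *packet) {
	p.next, p.prev = nil, l.tail
	if l.tail != nil {
		l.tail.next = p
	} else {
		l.head = p
	}
	l.tail = p
}

// Remove unlinks p in O(1); p must currently be an element of l.
func (l *packetList) Remove(p *packet) {
	if p.prev != nil {
		p.prev.next = p.next
	} else {
		l.head = p.next
	}
	if p.next != nil {
		p.next.prev = p.prev
	} else {
		l.tail = p.prev
	}
	p.next, p.prev = nil, nil
}

// contents walks the list with the iteration idiom from the generated docs.
func contents(l *packetList) string {
	var sb strings.Builder
	for e := l.Front(); e != nil; e = e.Next() {
		sb.WriteString(e.data)
	}
	return sb.String()
}

func main() {
	var l packetList
	a, b, c := &packet{data: "a"}, &packet{data: "b"}, &packet{data: "c"}
	l.PushBack(a)
	l.PushBack(b)
	l.PushBack(c)
	l.Remove(b)
	fmt.Println(contents(&l)) // ac
}
```

Because each element carries exactly one set of links, an element can be on at most one such list at a time, which suits a receive queue where a packet is either queued or not.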
diff --git a/tcpip/transport/udp/udp_test.go b/tcpip/transport/udp/udp_test.go
index 3bc29ff..1c70272 100644
--- a/tcpip/transport/udp/udp_test.go
+++ b/tcpip/transport/udp/udp_test.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package udp_test
@@ -24,7 +34,6 @@
)
const (
- testLinkAddr = "\x00\x00\x00\x00\x00\x02"
stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
stackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + stackAddr
@@ -32,12 +41,13 @@
multicastV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + multicastAddr
V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00"
- stackAddr = "\x0a\x00\x00\x01"
- stackPort = 1234
- testAddr = "\x0a\x00\x00\x02"
- testPort = 4096
- multicastAddr = "\xe8\x2b\xd3\xea"
- multicastPort = 1234
+ stackAddr = "\x0a\x00\x00\x01"
+ stackPort = 1234
+ testAddr = "\x0a\x00\x00\x02"
+ testPort = 4096
+ multicastAddr = "\xe8\x2b\xd3\xea"
+ multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
+ multicastPort = 1234
// defaultMTU is the MTU, in bytes, used throughout the tests, except
// where another value is explicitly used. It is chosen to match the MTU
@@ -60,7 +70,7 @@
}
func newDualTestContext(t *testing.T, mtu uint32) *testContext {
- s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{udp.ProtocolName})
+ s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{udp.ProtocolName}, stack.Options{})
id, linkEP := channel.New(256, mtu, "")
if testing.Verbose() {
@@ -93,10 +103,6 @@
},
})
- // Add test IP -> MAC mappings to LinkResolverCache
- s.AddLinkAddress(1, testV6Addr, testLinkAddr)
- s.AddLinkAddress(1, testV4MappedAddr, testLinkAddr)
-
return &testContext{
t: t,
s: s,
@@ -110,7 +116,7 @@
}
}
-func (c *testContext) createV6Endpoint(v4only bool) {
+func (c *testContext) createV6Endpoint(v6only bool) {
var err *tcpip.Error
c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq)
if err != nil {
@@ -118,7 +124,7 @@
}
var v tcpip.V6OnlyOption
- if v4only {
+ if v6only {
v = 1
}
if err := c.ep.SetSockOpt(v); err != nil {
@@ -126,57 +132,35 @@
}
}
-func (c *testContext) getV6Packet() []byte {
+func (c *testContext) getPacket(protocolNumber tcpip.NetworkProtocolNumber, multicast bool) []byte {
select {
case p := <-c.linkEP.C:
- if p.Proto != ipv6.ProtocolNumber {
- c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber)
+ if p.Proto != protocolNumber {
+ c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, protocolNumber)
}
b := make([]byte, len(p.Header)+len(p.Payload))
copy(b, p.Header)
copy(b[len(p.Header):], p.Payload)
- checker.IPv6(c.t, b, checker.SrcAddr(stackV6Addr), checker.DstAddr(testV6Addr))
- return b
-
- case <-time.After(2 * time.Second):
- c.t.Fatalf("Packet wasn't written out")
- }
-
- return nil
-}
-
-func (c *testContext) getPacket() []byte {
- select {
- case p := <-c.linkEP.C:
- if p.Proto != ipv4.ProtocolNumber {
- c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
+ var checkerFn func(*testing.T, []byte, ...checker.NetworkChecker)
+ var srcAddr, dstAddr tcpip.Address
+ switch protocolNumber {
+ case ipv4.ProtocolNumber:
+ checkerFn = checker.IPv4
+ srcAddr, dstAddr = stackAddr, testAddr
+ if multicast {
+ dstAddr = multicastAddr
+ }
+ case ipv6.ProtocolNumber:
+ checkerFn = checker.IPv6
+ srcAddr, dstAddr = stackV6Addr, testV6Addr
+ if multicast {
+ dstAddr = multicastV6Addr
+ }
+ default:
+ c.t.Fatalf("unknown protocol %d", protocolNumber)
}
- b := make([]byte, len(p.Header)+len(p.Payload))
- copy(b, p.Header)
- copy(b[len(p.Header):], p.Payload)
-
- checker.IPv4(c.t, b, checker.SrcAddr(stackAddr), checker.DstAddr(testAddr))
- return b
-
- case <-time.After(2 * time.Second):
- c.t.Fatalf("Packet wasn't written out")
- }
-
- return nil
-}
-
-func (c *testContext) getMCPacket() []byte {
- select {
- case p := <-c.linkEP.C:
- if p.Proto != ipv4.ProtocolNumber {
- c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
- }
- b := make([]byte, len(p.Header)+len(p.Payload))
- copy(b, p.Header)
- copy(b[len(p.Header):], p.Payload)
-
- checker.IPv4(c.t, b, checker.SrcAddr(stackAddr), checker.DstAddr(multicastAddr))
+ checkerFn(c.t, b, checker.SrcAddr(srcAddr), checker.DstAddr(dstAddr))
return b
case <-time.After(2 * time.Second):
@@ -220,9 +204,7 @@
u.SetChecksum(^u.CalculateChecksum(xsum, length))
// Inject packet.
- var views [1]buffer.View
- vv := buf.ToVectorisedView(views)
- c.linkEP.Inject(ipv6.ProtocolNumber, &vv)
+ c.linkEP.Inject(ipv6.ProtocolNumber, buf.ToVectorisedView())
}
func (c *testContext) sendPacket(payload []byte, h *headers) {
@@ -261,9 +243,7 @@
u.SetChecksum(^u.CalculateChecksum(xsum, length))
// Inject packet.
- var views [1]buffer.View
- vv := buf.ToVectorisedView(views)
- c.linkEP.Inject(ipv4.ProtocolNumber, &vv)
+ c.linkEP.Inject(ipv4.ProtocolNumber, buf.ToVectorisedView())
}
func newPayload() []byte {
@@ -288,12 +268,12 @@
defer c.wq.EventUnregister(&we)
var addr tcpip.FullAddress
- v, err := c.ep.Read(&addr)
+ v, _, err := c.ep.Read(&addr)
if err == tcpip.ErrWouldBlock {
// Wait for data to become available.
select {
case <-ch:
- v, err = c.ep.Read(&addr)
+ v, _, err = c.ep.Read(&addr)
if err != nil {
c.t.Fatalf("Read failed: %v", err)
}
@@ -314,6 +294,76 @@
}
}
+func TestBindEphemeralPort(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createV6Endpoint(false)
+
+ if err := c.ep.Bind(tcpip.FullAddress{}, nil); err != nil {
+ t.Fatalf("ep.Bind(...) failed: %v", err)
+ }
+}
+
+func TestBindReservedPort(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createV6Endpoint(false)
+
+ if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
+ c.t.Fatalf("Connect failed: %v", err)
+ }
+
+ addr, err := c.ep.GetLocalAddress()
+ if err != nil {
+ t.Fatalf("GetLocalAddress failed: %v", err)
+ }
+
+ // We can't bind the address reserved by the connected endpoint above.
+ {
+ ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+ defer ep.Close()
+ if got, want := ep.Bind(addr, nil), tcpip.ErrPortInUse; got != want {
+ t.Fatalf("got ep.Bind(...) = %v, want = %v", got, want)
+ }
+ }
+
+ func() {
+ ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+ defer ep.Close()
+ // We can't bind ipv4-any on the port reserved by the connected endpoint
+ // above, since the endpoint is dual-stack.
+ if got, want := ep.Bind(tcpip.FullAddress{Port: addr.Port}, nil), tcpip.ErrPortInUse; got != want {
+ t.Fatalf("got ep.Bind(...) = %v, want = %v", got, want)
+ }
+ // We can bind an ipv4 address on this port, though.
+ if err := ep.Bind(tcpip.FullAddress{Addr: stackAddr, Port: addr.Port}, nil); err != nil {
+ t.Fatalf("ep.Bind(...) failed: %v", err)
+ }
+ }()
+
+ // Once the connected endpoint releases its port reservation, we are able to
+ // bind ipv4-any once again.
+ c.ep.Close()
+ func() {
+ ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+ defer ep.Close()
+ if err := ep.Bind(tcpip.FullAddress{Port: addr.Port}, nil); err != nil {
+ t.Fatalf("ep.Bind(...) failed: %v", err)
+ }
+ }()
+}
+
func TestV4ReadOnV6(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
@@ -350,7 +400,7 @@
c.createV6Endpoint(false)
- // Bind to local adress.
+ // Bind to local address.
if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}, nil); err != nil {
c.t.Fatalf("Bind failed: %v", err)
}
@@ -383,12 +433,12 @@
defer c.wq.EventUnregister(&we)
var addr tcpip.FullAddress
- v, err := c.ep.Read(&addr)
+ v, _, err := c.ep.Read(&addr)
if err == tcpip.ErrWouldBlock {
// Wait for data to become available.
select {
case <-ch:
- v, err = c.ep.Read(&addr)
+ v, _, err = c.ep.Read(&addr)
if err != nil {
c.t.Fatalf("Read failed: %v", err)
}
@@ -429,10 +479,12 @@
testV4Read(c)
}
-func testDualWrite(c *testContext) uint16 {
+func testV4Write(c *testContext) uint16 {
// Write to V4 mapped address.
payload := buffer.View(newPayload())
- n, err := c.ep.Write(payload, &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort})
+ n, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
+ To: &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort},
+ })
if err != nil {
c.t.Fatalf("Write failed: %v", err)
}
@@ -441,7 +493,7 @@
}
// Check that we received the packet.
- b := c.getPacket()
+ b := c.getPacket(ipv4.ProtocolNumber, false)
udp := header.UDP(header.IPv4(b).Payload())
checker.IPv4(c.t, b,
checker.UDP(
@@ -449,16 +501,20 @@
),
)
- port := udp.SourcePort()
-
// Check the payload.
if !bytes.Equal(payload, udp.Payload()) {
c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload)
}
+ return udp.SourcePort()
+}
+
+func testV6Write(c *testContext) uint16 {
// Write to v6 address.
- payload = buffer.View(newPayload())
- n, err = c.ep.Write(payload, &tcpip.FullAddress{Addr: testV6Addr, Port: testPort})
+ payload := buffer.View(newPayload())
+ n, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
+ To: &tcpip.FullAddress{Addr: testV6Addr, Port: testPort},
+ })
if err != nil {
c.t.Fatalf("Write failed: %v", err)
}
@@ -466,14 +522,12 @@
c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
}
- // Check that we received the packet, and that the source port is the
- // same as the one used in ipv4.
- b = c.getV6Packet()
- udp = header.UDP(header.IPv6(b).Payload())
+ // Check that we received the packet.
+ b := c.getPacket(ipv6.ProtocolNumber, false)
+ udp := header.UDP(header.IPv6(b).Payload())
checker.IPv6(c.t, b,
checker.UDP(
checker.DstPort(testPort),
- checker.SrcPort(port),
),
)
@@ -482,7 +536,17 @@
c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload)
}
- return port
+ return udp.SourcePort()
+}
+
+func testDualWrite(c *testContext) uint16 {
+ v4Port := testV4Write(c)
+ v6Port := testV6Write(c)
+ if v4Port != v6Port {
+ c.t.Fatalf("expected v4 and v6 ports to be equal: got v4Port = %d, v6Port = %d", v4Port, v6Port)
+ }
+
+ return v4Port
}
func TestDualWriteUnbound(t *testing.T) {
@@ -522,7 +586,16 @@
c.t.Fatalf("Bind failed: %v", err)
}
- testDualWrite(c)
+ testV6Write(c)
+
+ // Write to V4 mapped address.
+ payload := buffer.View(newPayload())
+ _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
+ To: &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort},
+ })
+ if err != tcpip.ErrNetworkUnreachable {
+ c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrNetworkUnreachable)
+ }
}
func TestDualWriteConnectedToV4Mapped(t *testing.T) {
@@ -536,7 +609,16 @@
c.t.Fatalf("Bind failed: %v", err)
}
- testDualWrite(c)
+ testV4Write(c)
+
+ // Write to v6 address.
+ payload := buffer.View(newPayload())
+ _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
+ To: &tcpip.FullAddress{Addr: testV6Addr, Port: testPort},
+ })
+ if err != tcpip.ErrInvalidEndpointState {
+ c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrInvalidEndpointState)
+ }
}
func TestV4WriteOnV6Only(t *testing.T) {
@@ -547,7 +629,9 @@
// Write to V4 mapped address.
payload := buffer.View(newPayload())
- _, err := c.ep.Write(payload, &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort})
+ _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
+ To: &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort},
+ })
if err != tcpip.ErrNoRoute {
c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrNoRoute)
}
@@ -566,9 +650,11 @@
// Write to v6 address.
payload := buffer.View(newPayload())
- _, err := c.ep.Write(payload, &tcpip.FullAddress{Addr: testV6Addr, Port: testPort})
- if err != tcpip.ErrNoRoute {
- c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrNoRoute)
+ _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
+ To: &tcpip.FullAddress{Addr: testV6Addr, Port: testPort},
+ })
+ if err != tcpip.ErrInvalidEndpointState {
+ c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrInvalidEndpointState)
}
}
@@ -585,7 +671,7 @@
// Write without destination.
payload := buffer.View(newPayload())
- n, err := c.ep.Write(payload, nil)
+ n, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{})
if err != nil {
c.t.Fatalf("Write failed: %v", err)
}
@@ -594,7 +680,7 @@
}
// Check that we received the packet.
- b := c.getV6Packet()
+ b := c.getPacket(ipv6.ProtocolNumber, false)
udp := header.UDP(header.IPv6(b).Payload())
checker.IPv6(c.t, b,
checker.UDP(
@@ -621,7 +707,7 @@
// Write without destination.
payload := buffer.View(newPayload())
- n, err := c.ep.Write(payload, nil)
+ n, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{})
if err != nil {
c.t.Fatalf("Write failed: %v", err)
}
@@ -630,7 +716,7 @@
}
// Check that we received the packet.
- b := c.getPacket()
+ b := c.getPacket(ipv4.ProtocolNumber, false)
udp := header.UDP(header.IPv4(b).Payload())
checker.IPv4(c.t, b,
checker.UDP(
@@ -644,54 +730,6 @@
}
}
-func TestMulticastTTL(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createV6Endpoint(false)
- c.ep.SetSockOpt(tcpip.MulticastTTLOption(42))
-
- payload := buffer.View(newPayload())
- // Write a multicast packet. Its TTL value should be the above multicast value.
- {
- n, err := c.ep.Write(payload, &tcpip.FullAddress{Addr: multicastV4MappedAddr, Port: multicastPort})
- if err != nil {
- c.t.Fatalf("Write failed: %v", err)
- }
- if n != uintptr(len(payload)) {
- c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
- }
-
- // Check that we received the packet and that it has the multicastTTL value.
- b := c.getMCPacket()
- checker.IPv4(c.t, b,
- checker.TTL(42),
- checker.UDP(
- checker.DstPort(multicastPort),
- ),
- )
- }
-
- // Write a regular packet. Its TTL value should be the default.
- {
- n, err := c.ep.Write(payload, &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort})
- if err != nil {
- c.t.Fatalf("Write failed: %v", err)
- }
- if n != uintptr(len(payload)) {
- c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
- }
-
- b := c.getPacket()
- checker.IPv4(c.t, b,
- checker.TTL(header.IPv4DefaultTTL),
- checker.UDP(
- checker.DstPort(testPort),
- ),
- )
- }
-}
-
func TestReadIncrementsPacketsReceived(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
@@ -711,7 +749,7 @@
testV4Read(c)
var want uint64 = 1
- if got := c.s.Stats().UDP.PacketsReceived; got != want {
+ if got := c.s.Stats().UDP.PacketsReceived.Value(); got != want {
c.t.Fatalf("Read did not increment PacketsReceived: got %v, want %v", got, want)
}
}
@@ -725,7 +763,166 @@
testDualWrite(c)
var want uint64 = 2
- if got := c.s.Stats().UDP.PacketsSent; got != want {
+ if got := c.s.Stats().UDP.PacketsSent.Value(); got != want {
c.t.Fatalf("Write did not increment PacketsSent: got %v, want %v", got, want)
}
}
+
+func TestTTL(t *testing.T) {
+ payload := tcpip.SlicePayload(buffer.View(newPayload()))
+
+ for _, name := range []string{"v4", "v6", "dual"} {
+ t.Run(name, func(t *testing.T) {
+ var networkProtocolNumber tcpip.NetworkProtocolNumber
+ switch name {
+ case "v4":
+ networkProtocolNumber = ipv4.ProtocolNumber
+ case "v6", "dual":
+ networkProtocolNumber = ipv6.ProtocolNumber
+ default:
+ t.Fatal("unknown test variant")
+ }
+
+ var variants []string
+ switch name {
+ case "v4":
+ variants = []string{"v4"}
+ case "v6":
+ variants = []string{"v6"}
+ case "dual":
+ variants = []string{"v6", "mapped"}
+ }
+
+ for _, variant := range variants {
+ t.Run(variant, func(t *testing.T) {
+ for _, typ := range []string{"unicast", "multicast"} {
+ t.Run(typ, func(t *testing.T) {
+ var addr tcpip.Address
+ var port uint16
+ switch typ {
+ case "unicast":
+ port = testPort
+ switch variant {
+ case "v4":
+ addr = testAddr
+ case "mapped":
+ addr = testV4MappedAddr
+ case "v6":
+ addr = testV6Addr
+ default:
+ t.Fatal("unknown test variant")
+ }
+ case "multicast":
+ port = multicastPort
+ switch variant {
+ case "v4":
+ addr = multicastAddr
+ case "mapped":
+ addr = multicastV4MappedAddr
+ case "v6":
+ addr = multicastV6Addr
+ default:
+ t.Fatal("unknown test variant")
+ }
+ default:
+ t.Fatal("unknown test variant")
+ }
+
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ var err *tcpip.Error
+ c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, networkProtocolNumber, &c.wq)
+ if err != nil {
+ c.t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ switch name {
+ case "v4":
+ case "v6":
+ if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil {
+ c.t.Fatalf("SetSockOpt failed: %v", err)
+ }
+ case "dual":
+ if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(0)); err != nil {
+ c.t.Fatalf("SetSockOpt failed: %v", err)
+ }
+ default:
+ t.Fatal("unknown test variant")
+ }
+
+ const multicastTTL = 42
+ if err := c.ep.SetSockOpt(tcpip.MulticastTTLOption(multicastTTL)); err != nil {
+ c.t.Fatalf("SetSockOpt failed: %v", err)
+ }
+
+ n, err := c.ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{Addr: addr, Port: port}})
+ if err != nil {
+ c.t.Fatalf("Write failed: %v", err)
+ }
+ if n != uintptr(len(payload)) {
+ c.t.Fatalf("got c.ep.Write(...) = %d, want = %d", n, len(payload))
+ }
+
+ checkerFn := checker.IPv4
+ switch variant {
+ case "v4", "mapped":
+ case "v6":
+ checkerFn = checker.IPv6
+ default:
+ t.Fatal("unknown test variant")
+ }
+ var wantTTL uint8
+ var multicast bool
+ switch typ {
+ case "unicast":
+ multicast = false
+ switch variant {
+ case "v4", "mapped":
+ ep, err := ipv4.NewProtocol().NewEndpoint(0, "", nil, nil, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ wantTTL = ep.DefaultTTL()
+ ep.Close()
+ case "v6":
+ ep, err := ipv6.NewProtocol().NewEndpoint(0, "", nil, nil, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ wantTTL = ep.DefaultTTL()
+ ep.Close()
+ default:
+ t.Fatal("unknown test variant")
+ }
+ case "multicast":
+ wantTTL = multicastTTL
+ multicast = true
+ default:
+ t.Fatal("unknown test variant")
+ }
+
+ var networkProtocolNumber tcpip.NetworkProtocolNumber
+ switch variant {
+ case "v4", "mapped":
+ networkProtocolNumber = ipv4.ProtocolNumber
+ case "v6":
+ networkProtocolNumber = ipv6.ProtocolNumber
+ default:
+ t.Fatal("unknown test variant")
+ }
+
+ b := c.getPacket(networkProtocolNumber, multicast)
+ checkerFn(c.t, b,
+ checker.TTL(wantTTL),
+ checker.UDP(
+ checker.DstPort(port),
+ ),
+ )
+ })
+ }
+ })
+ }
+ })
+ }
+}
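TestBindReservedPort encodes a specific set of expectations: a connected dual-stack endpoint holds its port on both IPv4 and IPv6, so binding the same address again fails, binding the IPv4 wildcard on that port fails, binding a specific IPv4 address succeeds, and the wildcard becomes available once the reservation is released. The toy table below is a hypothetical model that reproduces exactly those outcomes. It is not netstack's port manager, and its conflict rule (a wildcard clashes with anything on the port, otherwise only identical addresses clash, all-or-nothing across protocols) is an assumption chosen to match the test.

```go
package main

import (
	"errors"
	"fmt"
)

var errPortInUse = errors.New("port is in use")

type key struct {
	proto string // hypothetical protocol label, e.g. "v4" or "v6"
	port  uint16
}

// portTable is a toy reservation table; the empty address "" is the wildcard.
type portTable map[key]map[string]bool

func (t portTable) conflicts(k key, addr string) bool {
	addrs := t[k]
	if len(addrs) == 0 {
		return false
	}
	// A wildcard conflicts with everything on the port, and everything
	// conflicts with an existing wildcard or an identical address.
	return addr == "" || addrs[""] || addrs[addr]
}

// reserve claims (addr, port) on every listed protocol, or on none of them.
func (t portTable) reserve(protos []string, addr string, port uint16) error {
	for _, p := range protos {
		if t.conflicts(key{p, port}, addr) {
			return errPortInUse
		}
	}
	for _, p := range protos {
		k := key{p, port}
		if t[k] == nil {
			t[k] = make(map[string]bool)
		}
		t[k][addr] = true
	}
	return nil
}

// release drops the reservation; deleting from a nil map is a no-op.
func (t portTable) release(protos []string, addr string, port uint16) {
	for _, p := range protos {
		delete(t[key{p, port}], addr)
	}
}

func main() {
	t := portTable{}
	const port = 1024
	dual := []string{"v4", "v6"}

	// A dual-stack endpoint connects and reserves its port on both protocols.
	fmt.Println(t.reserve(dual, "stack-v6", port))           // <nil>
	fmt.Println(t.reserve([]string{"v6"}, "stack-v6", port)) // port is in use
	fmt.Println(t.reserve([]string{"v4"}, "", port))         // port is in use
	fmt.Println(t.reserve([]string{"v4"}, "10.0.0.1", port)) // <nil>

	// Releasing both reservations frees the IPv4 wildcard again.
	t.release([]string{"v4"}, "10.0.0.1", port)
	t.release(dual, "stack-v6", port)
	fmt.Println(t.reserve([]string{"v4"}, "", port)) // <nil>
}
```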
diff --git a/tcpip/transport/unix/connectioned.go b/tcpip/transport/unix/connectioned.go
index 34dd051..38892dd 100644
--- a/tcpip/transport/unix/connectioned.go
+++ b/tcpip/transport/unix/connectioned.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package unix
@@ -75,6 +85,8 @@
// path != "" && acceptedChan != nil => bound and listening.
//
// Only one of these will be true at any moment.
+//
+// +stateify savable
type connectionedEndpoint struct {
baseEndpoint
@@ -394,6 +406,17 @@
return nil
}
+// SendMsg writes data and a control message to the endpoint's peer.
+// This method does not block if the data cannot be written.
+func (e *connectionedEndpoint) SendMsg(data [][]byte, c ControlMessages, to BoundEndpoint) (uintptr, *tcpip.Error) {
+ // Stream sockets do not support specifying the endpoint. Seqpacket
+ // sockets ignore the passed endpoint.
+ if e.stype == SockStream && to != nil {
+ return 0, tcpip.ErrNotSupported
+ }
+ return e.baseEndpoint.SendMsg(data, c, to)
+}
+
// Readiness returns the current readiness of the connectionedEndpoint. For
// example, if waiter.EventIn is set, the connectionedEndpoint is immediately
// readable.
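The new connectionedEndpoint.SendMsg relies on Go struct embedding: the outer type declares a method with the same name as one promoted from baseEndpoint, so callers (including callers through an interface) reach the outer method, which applies the stream-socket check and then delegates explicitly via e.baseEndpoint.SendMsg. The sketch below isolates that pattern. The sender interface, the recording baseEndpoint, the string destination, and errNotSupported are hypothetical simplifications; only the socket-type numbers are taken from unix.go.

```go
package main

import (
	"errors"
	"fmt"
)

type sockType int

// Values copied from the SockType constants in unix.go.
const (
	sockStream    sockType = 1
	sockDgram     sockType = 2
	sockSeqpacket sockType = 5
)

var errNotSupported = errors.New("operation not supported")

// sender is the one method of the Endpoint interface this sketch needs.
type sender interface {
	SendMsg(data []byte, to *string) (int, error)
}

// baseEndpoint is a hypothetical stand-in that just records what was sent.
type baseEndpoint struct {
	sent [][]byte
}

func (e *baseEndpoint) SendMsg(data []byte, to *string) (int, error) {
	e.sent = append(e.sent, data)
	return len(data), nil
}

// connectionedEndpoint embeds baseEndpoint; its own SendMsg shadows the
// promoted one, adds the socket-type check, then delegates explicitly.
type connectionedEndpoint struct {
	baseEndpoint
	stype sockType
}

func (e *connectionedEndpoint) SendMsg(data []byte, to *string) (int, error) {
	// Stream sockets reject an explicit destination; seqpacket ignores it.
	if e.stype == sockStream && to != nil {
		return 0, errNotSupported
	}
	return e.baseEndpoint.SendMsg(data, to)
}

func main() {
	peer := "peer"
	var s sender = &connectionedEndpoint{stype: sockStream}
	fmt.Println(s.SendMsg([]byte("hi"), &peer)) // 0 operation not supported
	s = &connectionedEndpoint{stype: sockSeqpacket}
	fmt.Println(s.SendMsg([]byte("hi"), &peer)) // 2 <nil>
}
```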
diff --git a/tcpip/transport/unix/connectionless.go b/tcpip/transport/unix/connectionless.go
index 1b37fed..f6b93ab 100644
--- a/tcpip/transport/unix/connectionless.go
+++ b/tcpip/transport/unix/connectionless.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package unix
@@ -15,6 +25,8 @@
//
// Specifically, this means datagram unix sockets not created with
// socketpair(2).
+//
+// +stateify savable
type connectionlessEndpoint struct {
baseEndpoint
}
@@ -65,9 +77,15 @@
// UnidirectionalConnect implements BoundEndpoint.UnidirectionalConnect.
func (e *connectionlessEndpoint) UnidirectionalConnect() (ConnectedEndpoint, *tcpip.Error) {
+ e.Lock()
+ r := e.receiver
+ e.Unlock()
+ if r == nil {
+ return nil, tcpip.ErrConnectionRefused
+ }
return &connectedEndpoint{
endpoint: e,
- writeQueue: e.receiver.(*queueReceiver).readQueue,
+ writeQueue: r.(*queueReceiver).readQueue,
}, nil
}
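The UnidirectionalConnect change reads e.receiver once while holding the endpoint lock, releases the lock, and then works only with the local copy, returning ErrConnectionRefused if the endpoint has already been closed. The sketch below shows that snapshot-under-lock pattern with hypothetical stand-ins (queue, queueReceiver, close, and a plain error). Since sync.Mutex is not reentrant, an endpoint that already held its own lock and called this method on itself would block forever, which is presumably why the updated doc comment in unix.go requires such callers not to hold their own lock.

```go
package main

import (
	"errors"
	"fmt"
	"sync"
)

var errConnectionRefused = errors.New("connection refused")

// queue and queueReceiver are hypothetical stand-ins for queue.Queue and
// the package's queueReceiver.
type queue struct{ msgs []string }

type queueReceiver struct{ readQueue *queue }

type connectionlessEndpoint struct {
	sync.Mutex
	receiver interface{} // set to nil when the endpoint is closed
}

// unidirectionalConnect copies the receiver while holding the lock, drops the
// lock, and only then inspects the copy.
func (e *connectionlessEndpoint) unidirectionalConnect() (*queue, error) {
	e.Lock()
	r := e.receiver
	e.Unlock()
	if r == nil {
		return nil, errConnectionRefused
	}
	return r.(*queueReceiver).readQueue, nil
}

// close clears the receiver under the same lock.
func (e *connectionlessEndpoint) close() {
	e.Lock()
	e.receiver = nil
	e.Unlock()
}

func main() {
	e := &connectionlessEndpoint{receiver: &queueReceiver{readQueue: &queue{}}}
	q, err := e.unidirectionalConnect()
	fmt.Println(q != nil, err) // true <nil>

	e.close()
	q, err = e.unidirectionalConnect()
	fmt.Println(q == nil, err) // true connection refused
}
```

Reading the field under the lock removes the data race with a concurrent close, and the nil check turns what would previously have been a failed type assertion on a closed endpoint into an ordinary error.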
diff --git a/tcpip/transport/unix/unix.go b/tcpip/transport/unix/unix.go
index 4f08989..e9e388a 100644
--- a/tcpip/transport/unix/unix.go
+++ b/tcpip/transport/unix/unix.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
// Package unix contains the implementation of Unix endpoints.
package unix
@@ -28,6 +38,8 @@
SockStream SockType = 1
// SockDgram corresponds to syscall.SOCK_DGRAM.
SockDgram SockType = 2
+ // SockRaw corresponds to syscall.SOCK_RAW.
+ SockRaw SockType = 3
// SockSeqpacket corresponds to syscall.SOCK_SEQPACKET.
SockSeqpacket SockType = 5
)
@@ -48,6 +60,8 @@
}
// A ControlMessages represents a collection of socket control messages.
+//
+// +stateify savable
type ControlMessages struct {
// Rights is a control message containing FDs.
Rights RightsControlMessage
@@ -112,7 +126,12 @@
// If peek is true, no data should be consumed from the Endpoint. Any and
// all data returned from a peek should be available in the next call to
// RecvMsg.
- RecvMsg(data [][]byte, creds bool, numRights uintptr, peek bool, addr *tcpip.FullAddress) (uintptr, ControlMessages, *tcpip.Error)
+ //
+ // recvLen is the number of bytes copied into data.
+ //
+ // msgLen is the length of the read message consumed for datagram Endpoints.
+ // msgLen is always the same as recvLen for stream Endpoints.
+ RecvMsg(data [][]byte, creds bool, numRights uintptr, peek bool, addr *tcpip.FullAddress) (recvLen, msgLen uintptr, cm ControlMessages, err *tcpip.Error)
// SendMsg writes data and a control message to the endpoint's peer.
// This method does not block if the data cannot be written.
@@ -205,7 +224,11 @@
// type that isn't SockStream or SockSeqpacket.
BidirectionalConnect(ep ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *tcpip.Error
- // UnidirectionalConnect establishes a write-only connection to a unix endpoint.
+ // UnidirectionalConnect establishes a write-only connection to a unix
+ // endpoint.
+ //
+ // An endpoint that calls UnidirectionalConnect and also supports it itself
+ // must not hold its own lock while calling UnidirectionalConnect.
//
// This method will return tcpip.ErrConnectionRefused on a non-SockDgram
// endpoint.
@@ -218,6 +241,8 @@
}
// message represents a message passed over a Unix domain socket.
+//
+// +stateify savable
type message struct {
ilist.Entry
@@ -256,7 +281,7 @@
// See Endpoint.RecvMsg for documentation on shared arguments.
//
// notify indicates if RecvNotify should be called.
- Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (n uintptr, cm ControlMessages, source tcpip.FullAddress, notify bool, err *tcpip.Error)
+ Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (recvLen, msgLen uintptr, cm ControlMessages, source tcpip.FullAddress, notify bool, err *tcpip.Error)
// RecvNotify notifies the Receiver of a successful Recv. This must not be
// called while holding any endpoint locks.
@@ -289,12 +314,14 @@
}
// queueReceiver implements Receiver for datagram sockets.
+//
+// +stateify savable
type queueReceiver struct {
readQueue *queue.Queue
}
// Recv implements Receiver.Recv.
-func (q *queueReceiver) Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (uintptr, ControlMessages, tcpip.FullAddress, bool, *tcpip.Error) {
+func (q *queueReceiver) Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (uintptr, uintptr, ControlMessages, tcpip.FullAddress, bool, *tcpip.Error) {
var m queue.Entry
var notify bool
var err *tcpip.Error
@@ -304,7 +331,7 @@
m, notify, err = q.readQueue.Dequeue()
}
if err != nil {
- return 0, ControlMessages{}, tcpip.FullAddress{}, false, err
+ return 0, 0, ControlMessages{}, tcpip.FullAddress{}, false, err
}
msg := m.(*message)
src := []byte(msg.Data)
@@ -314,7 +341,7 @@
copied += uintptr(n)
src = src[n:]
}
- return copied, msg.Control, msg.Address, notify, nil
+ return copied, uintptr(len(msg.Data)), msg.Control, msg.Address, notify, nil
}
// RecvNotify implements Receiver.RecvNotify.
@@ -343,7 +370,7 @@
return q.readQueue.QueuedSize()
}
-// RecvMaxQueueSize implements ConnectedEndpoint.RecvMaxQueueSize.
+// RecvMaxQueueSize implements Receiver.RecvMaxQueueSize.
func (q *queueReceiver) RecvMaxQueueSize() int64 {
return q.readQueue.MaxQueueSize()
}
@@ -352,6 +379,8 @@
func (*queueReceiver) Release() {}
// streamQueueReceiver implements Receiver for stream sockets.
+//
+// +stateify savable
type streamQueueReceiver struct {
queueReceiver
@@ -375,8 +404,35 @@
return copied, data, buf
}
+// Readable implements Receiver.Readable.
+func (q *streamQueueReceiver) Readable() bool {
+ q.mu.Lock()
+ bl := len(q.buffer)
+ r := q.readQueue.IsReadable()
+ q.mu.Unlock()
+ // We're readable if we have data in our buffer or if the queue receiver is
+ // readable.
+ return bl > 0 || r
+}
+
+// RecvQueuedSize implements Receiver.RecvQueuedSize.
+func (q *streamQueueReceiver) RecvQueuedSize() int64 {
+ q.mu.Lock()
+ bl := len(q.buffer)
+ qs := q.readQueue.QueuedSize()
+ q.mu.Unlock()
+ return int64(bl) + qs
+}
+
+// RecvMaxQueueSize implements Receiver.RecvMaxQueueSize.
+func (q *streamQueueReceiver) RecvMaxQueueSize() int64 {
+ // RecvMaxQueueSize is the readQueue's MaxQueueSize plus the largest
+ // message we can buffer, which is also the largest message we can receive.
+ return 2 * q.readQueue.MaxQueueSize()
+}
+
// Recv implements Receiver.Recv.
-func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights uintptr, peek bool) (uintptr, ControlMessages, tcpip.FullAddress, bool, *tcpip.Error) {
+func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights uintptr, peek bool) (uintptr, uintptr, ControlMessages, tcpip.FullAddress, bool, *tcpip.Error) {
q.mu.Lock()
defer q.mu.Unlock()
@@ -389,7 +445,7 @@
// the next time Recv() is called.
m, n, err := q.readQueue.Dequeue()
if err != nil {
- return 0, ControlMessages{}, tcpip.FullAddress{}, false, err
+ return 0, 0, ControlMessages{}, tcpip.FullAddress{}, false, err
}
notify = n
msg := m.(*message)
@@ -406,7 +462,7 @@
// Don't consume data since we are peeking.
copied, data, _ = vecCopy(data, q.buffer)
- return copied, c, q.addr, notify, nil
+ return copied, copied, c, q.addr, notify, nil
}
// Consume data and control message since we are not peeking.
@@ -484,7 +540,7 @@
q.control.Rights = nil
}
}
- return copied, c, q.addr, notify, nil
+ return copied, copied, c, q.addr, notify, nil
}
// A ConnectedEndpoint is an Endpoint that can be used to send Messages.
@@ -535,6 +591,7 @@
Release()
}
+// +stateify savable
type connectedEndpoint struct {
// endpoint represents the subset of the Endpoint functionality needed by
// the connectedEndpoint. It is implemented by both connectionedEndpoint
@@ -627,6 +684,8 @@
// unix domain socket Endpoint implementations.
//
// Not to be used on its own.
+//
+// +stateify savable
type baseEndpoint struct {
*waiter.Queue
@@ -651,8 +710,8 @@
// EventRegister implements waiter.Waitable.EventRegister.
func (e *baseEndpoint) EventRegister(we *waiter.Entry, mask waiter.EventMask) {
- e.Lock()
e.Queue.EventRegister(we, mask)
+ e.Lock()
if e.connected != nil {
e.connected.EventUpdate()
}
@@ -661,8 +720,8 @@
// EventUnregister implements waiter.Waitable.EventUnregister.
func (e *baseEndpoint) EventUnregister(we *waiter.Entry) {
- e.Lock()
e.Queue.EventUnregister(we)
+ e.Lock()
if e.connected != nil {
e.connected.EventUpdate()
}
@@ -695,18 +754,18 @@
}
// RecvMsg reads data and a control message from the endpoint.
-func (e *baseEndpoint) RecvMsg(data [][]byte, creds bool, numRights uintptr, peek bool, addr *tcpip.FullAddress) (uintptr, ControlMessages, *tcpip.Error) {
+func (e *baseEndpoint) RecvMsg(data [][]byte, creds bool, numRights uintptr, peek bool, addr *tcpip.FullAddress) (uintptr, uintptr, ControlMessages, *tcpip.Error) {
e.Lock()
if e.receiver == nil {
e.Unlock()
- return 0, ControlMessages{}, tcpip.ErrNotConnected
+ return 0, 0, ControlMessages{}, tcpip.ErrNotConnected
}
- n, cms, a, notify, err := e.receiver.Recv(data, creds, numRights, peek)
+ recvLen, msgLen, cms, a, notify, err := e.receiver.Recv(data, creds, numRights, peek)
e.Unlock()
if err != nil {
- return 0, ControlMessages{}, err
+ return 0, 0, ControlMessages{}, err
}
if notify {
@@ -716,7 +775,7 @@
if addr != nil {
*addr = a
}
- return n, cms, nil
+ return recvLen, msgLen, cms, nil
}
// SendMsg writes data and a control message to the endpoint's peer.
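The new RecvMsg and Recv signatures separate recvLen (bytes copied into data) from msgLen (length of the consumed message). For datagram endpoints the two differ when the caller's buffers are smaller than the message; for stream endpoints streamQueueReceiver returns copied for both. How higher layers use msgLen (for example, to report truncation in the style of MSG_TRUNC) is an assumption here, not something this diff shows. The sketch below reproduces only the queueReceiver.Recv arithmetic in isolation; scatterCopy and recvDatagram are hypothetical helpers, not netstack APIs.

```go
package main

import "fmt"

// scatterCopy copies src into the buffers in dst, in order, and returns the
// number of bytes copied. It is a hypothetical stand-in for the copy loop in
// queueReceiver.Recv.
func scatterCopy(dst [][]byte, src []byte) uintptr {
	var copied uintptr
	for _, d := range dst {
		if len(src) == 0 {
			break
		}
		n := copy(d, src)
		copied += uintptr(n)
		src = src[n:]
	}
	return copied
}

// recvDatagram mirrors the new queueReceiver.Recv return values: recvLen is
// the number of bytes actually copied into data, while msgLen is the full
// length of the dequeued message, even when it did not fit.
func recvDatagram(data [][]byte, msg []byte) (recvLen, msgLen uintptr) {
	return scatterCopy(data, msg), uintptr(len(msg))
}

func main() {
	data := [][]byte{make([]byte, 4), make([]byte, 2)}
	recvLen, msgLen := recvDatagram(data, []byte("hello, world"))
	// recvLen < msgLen means the rest of the datagram was discarded.
	fmt.Println(recvLen, msgLen, recvLen < msgLen)
}
```

Because the whole datagram is dequeued regardless of buffer size, the six bytes that did not fit are lost; msgLen is the only remaining record of how long the message was.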
diff --git a/tmutex/tmutex.go b/tmutex/tmutex.go
index 6177965..bd5c681 100644
--- a/tmutex/tmutex.go
+++ b/tmutex/tmutex.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
// Package tmutex provides the implementation of a mutex that implements an
// efficient TryLock function in addition to Lock and Unlock.
diff --git a/tmutex/tmutex_test.go b/tmutex/tmutex_test.go
index 16d60a2..a9dc997 100644
--- a/tmutex/tmutex_test.go
+++ b/tmutex/tmutex_test.go
@@ -1,10 +1,22 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package tmutex
import (
+ "fmt"
+ "runtime"
"sync"
"sync/atomic"
"testing"
@@ -133,6 +145,7 @@
const gr = 1000
const iters = 100000
total := int64(gr * iters)
+ var tryTotal int64
v := int64(0)
var wg sync.WaitGroup
for i := 0; i < gr; i++ {
@@ -154,14 +167,91 @@
local++
}
}
- atomic.AddInt64(&total, local)
+ atomic.AddInt64(&tryTotal, local)
wg.Done()
}()
}
wg.Wait()
+ t.Logf("tryTotal = %d", tryTotal)
+ total += tryTotal
+
if v != total {
t.Fatalf("Bad count: got %v, want %v", v, total)
}
}
+
+// BenchmarkTmutex is equivalent to TestMutualExclusion, with the following
+// differences:
+//
+// - The number of goroutines is variable, with the maximum value depending on
+// GOMAXPROCS.
+//
+// - The number of iterations per benchmark is controlled by the benchmarking
+// framework.
+//
+// - Care is taken to ensure that all goroutines participating in the benchmark
+// have been created before the benchmark begins.
+func BenchmarkTmutex(b *testing.B) {
+ for n, max := 1, 4*runtime.GOMAXPROCS(0); n > 0 && n <= max; n *= 2 {
+ b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
+ var m Mutex
+ m.Init()
+
+ var ready sync.WaitGroup
+ begin := make(chan struct{})
+ var end sync.WaitGroup
+ for i := 0; i < n; i++ {
+ ready.Add(1)
+ end.Add(1)
+ go func() {
+ ready.Done()
+ <-begin
+ for j := 0; j < b.N; j++ {
+ m.Lock()
+ m.Unlock()
+ }
+ end.Done()
+ }()
+ }
+
+ ready.Wait()
+ b.ResetTimer()
+ close(begin)
+ end.Wait()
+ })
+ }
+}
+
+// BenchmarkSyncMutex is equivalent to BenchmarkTmutex, but uses sync.Mutex as
+// a comparison point.
+func BenchmarkSyncMutex(b *testing.B) {
+ for n, max := 1, 4*runtime.GOMAXPROCS(0); n > 0 && n <= max; n *= 2 {
+ b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
+ var m sync.Mutex
+
+ var ready sync.WaitGroup
+ begin := make(chan struct{})
+ var end sync.WaitGroup
+ for i := 0; i < n; i++ {
+ ready.Add(1)
+ end.Add(1)
+ go func() {
+ ready.Done()
+ <-begin
+ for j := 0; j < b.N; j++ {
+ m.Lock()
+ m.Unlock()
+ }
+ end.Done()
+ }()
+ }
+
+ ready.Wait()
+ b.ResetTimer()
+ close(begin)
+ end.Wait()
+ })
+ }
+}
diff --git a/waiter/waiter.go b/waiter/waiter.go
index 268e1dd..207cf30 100644
--- a/waiter/waiter.go
+++ b/waiter/waiter.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
// Package waiter provides the implementation of a wait queue, where waiters can
// be enqueued to be notified when an event of interest happens.
@@ -147,6 +157,8 @@
// notifiers can notify them when events happen.
//
// The zero value for waiter.Queue is an empty queue ready for use.
+//
+// +stateify savable
type Queue struct {
list ilist.List
mu sync.RWMutex
@@ -174,7 +186,7 @@
q.mu.RLock()
for it := q.list.Front(); it != nil; it = it.Next() {
e := it.(*Entry)
- if (mask & e.mask) != 0 {
+ if mask&e.mask != 0 {
e.Callback.Callback(e)
}
}
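Queue.Notify invokes only those callbacks whose registered mask shares at least one bit with the notified events. The diff's rewrite of `(mask & e.mask) != 0` as `mask&e.mask != 0` does not change behavior, because in Go `&` has higher precedence than `!=`. The sketch below is a simplified, hypothetical stand-in for waiter.Queue: it uses a slice instead of ilist.List, plain func callbacks instead of Entry callbacks, and illustrative mask values.

```go
package main

import (
	"fmt"
	"sync"
)

// EventMask mirrors the idea of waiter.EventMask; the constants below are
// illustrative values for this sketch.
type EventMask uint16

const (
	EventIn  EventMask = 0x01
	EventOut EventMask = 0x04
)

type entry struct {
	mask     EventMask
	callback func(EventMask)
}

// queue is a minimal, hypothetical stand-in for waiter.Queue.
type queue struct {
	mu      sync.RWMutex
	entries []*entry
}

// EventRegister adds e to the set of waiters.
func (q *queue) EventRegister(e *entry) {
	q.mu.Lock()
	q.entries = append(q.entries, e)
	q.mu.Unlock()
}

// Notify invokes the callback of every entry interested in at least one of
// the events in mask. mask&e.mask != 0 parses as (mask & e.mask) != 0.
func (q *queue) Notify(mask EventMask) {
	q.mu.RLock()
	for _, e := range q.entries {
		if mask&e.mask != 0 {
			e.callback(mask)
		}
	}
	q.mu.RUnlock()
}

func main() {
	var q queue
	var fired []string
	q.EventRegister(&entry{mask: EventIn, callback: func(EventMask) { fired = append(fired, "reader") }})
	q.EventRegister(&entry{mask: EventOut, callback: func(EventMask) { fired = append(fired, "writer") }})
	q.Notify(EventIn)
	fmt.Println(fired) // [reader]
}
```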
diff --git a/waiter/waiter_test.go b/waiter/waiter_test.go
index 1a20335..c45f228 100644
--- a/waiter/waiter_test.go
+++ b/waiter/waiter_test.go
@@ -1,6 +1,16 @@
-// Copyright 2016 The Netstack Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
package waiter