ipv4: fix error values on header manipulation
This change makes header manipulation methods and functions return a nil
header error if the receiver or wire header is nil, a header too short
error if the header length field value of receiver or wire header is
short, and a extension header error if the wire extnsion header is
short.
Also replaces syscall.EWINDOWS or syscall.EPLAN9 with more descriptive,
platform independent error values.
Change-Id: I923fb60b1d68857cffc9df20f3f6cb2babbcdb1f
Reviewed-on: https://go-review.googlesource.com/c/net/+/129136
Reviewed-by: Matt Layher <mdlayher@gmail.com>
diff --git a/ipv4/control_windows.go b/ipv4/control_windows.go
index ce55c66..82c6306 100644
--- a/ipv4/control_windows.go
+++ b/ipv4/control_windows.go
@@ -4,13 +4,9 @@
package ipv4
-import (
- "syscall"
-
- "golang.org/x/net/internal/socket"
-)
+import "golang.org/x/net/internal/socket"
func setControlMessage(c *socket.Conn, opt *rawOpt, cf ControlFlags, on bool) error {
// TODO(mikio): implement this
- return syscall.EWINDOWS
+ return errNotImplemented
}
diff --git a/ipv4/header.go b/ipv4/header.go
index a8c8f7a..701bd4b 100644
--- a/ipv4/header.go
+++ b/ipv4/header.go
@@ -57,7 +57,7 @@
// This may differ from the wire format, depending on the system.
func (h *Header) Marshal() ([]byte, error) {
if h == nil {
- return nil, errInvalidConn
+ return nil, errNilHeader
}
if h.Len < HeaderLen {
return nil, errHeaderTooShort
@@ -107,12 +107,15 @@
// local system.
// This may differ from the wire format, depending on the system.
func (h *Header) Parse(b []byte) error {
- if h == nil || len(b) < HeaderLen {
+ if h == nil || b == nil {
+ return errNilHeader
+ }
+ if len(b) < HeaderLen {
return errHeaderTooShort
}
hdrlen := int(b[0]&0x0f) << 2
- if hdrlen > len(b) {
- return errBufferTooShort
+ if len(b) < hdrlen {
+ return errExtHeaderTooShort
}
h.Version = int(b[0] >> 4)
h.Len = hdrlen
diff --git a/ipv4/header_test.go b/ipv4/header_test.go
index a246aee..2211605 100644
--- a/ipv4/header_test.go
+++ b/ipv4/header_test.go
@@ -157,10 +157,21 @@
}
func TestMarshalHeader(t *testing.T) {
+ for i, tt := range []struct {
+ h *Header
+ err error
+ }{
+ {nil, errNilHeader},
+ {&Header{Len: HeaderLen - 1}, errHeaderTooShort},
+ } {
+ if _, err := tt.h.Marshal(); err != tt.err {
+ t.Errorf("#%d: got %v; want %v", i, err, tt.err)
+ }
+ }
+
if socket.NativeEndian != binary.LittleEndian {
t.Skip("no test for non-little endian machine yet")
}
-
for _, tt := range headerLittleEndianTests {
b, err := tt.Header.Marshal()
if err != nil {
@@ -189,10 +200,30 @@
}
func TestParseHeader(t *testing.T) {
+ for i, tt := range []struct {
+ h *Header
+ wh []byte
+ err error
+ }{
+ {nil, nil, errNilHeader},
+ {&Header{}, nil, errNilHeader},
+ {&Header{}, make([]byte, HeaderLen-1), errHeaderTooShort},
+ {&Header{}, []byte{
+ 0x46, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ }, errExtHeaderTooShort},
+ } {
+ if err := tt.h.Parse(tt.wh); err != tt.err {
+ t.Fatalf("#%d: got %v; want %v", i, err, tt.err)
+ }
+ }
+
if socket.NativeEndian != binary.LittleEndian {
t.Skip("no test for big endian machine yet")
}
-
for _, tt := range headerLittleEndianTests {
var wh []byte
switch runtime.GOOS {
diff --git a/ipv4/helper.go b/ipv4/helper.go
index 8d8ff98..ffedd4e 100644
--- a/ipv4/helper.go
+++ b/ipv4/helper.go
@@ -7,18 +7,21 @@
import (
"errors"
"net"
+ "runtime"
)
var (
errInvalidConn = errors.New("invalid connection")
errMissingAddress = errors.New("missing address")
errMissingHeader = errors.New("missing header")
+ errNilHeader = errors.New("nil header")
errHeaderTooShort = errors.New("header too short")
- errBufferTooShort = errors.New("buffer too short")
+ errExtHeaderTooShort = errors.New("extension header too short")
errInvalidConnType = errors.New("invalid conn type")
errOpNoSupport = errors.New("operation not supported")
errNoSuchInterface = errors.New("no such interface")
errNoSuchMulticastInterface = errors.New("no such multicast interface")
+ errNotImplemented = errors.New("not implemented on " + runtime.GOOS + "/" + runtime.GOARCH)
// See http://www.freebsd.org/doc/en/books/porters-handbook/freebsd-versions.html.
freebsdVersion uint32