[release-branch.go1.16] net: filter bad names from Lookup functions instead of hard failing

Instead of hard failing on a single bad record, filter the bad records
and return anything valid. This only applies to the methods which can
return multiple records, LookupMX, LookupNS, LookupSRV, and LookupAddr.

When bad results are filtered out, also return an error, indicating
that this filtering has happened.

Updates #46241
Updates #46979
Fixes #46999

Change-Id: I6493e0002beaf89f5a9795333a93605abd30d171
Reviewed-on: https://go-review.googlesource.com/c/go/+/332549
Trust: Roland Shoemaker <roland@golang.org>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
(cherry picked from commit 296ddf2a936a30866303a64d49bc0e3e034730a8)
Reviewed-on: https://go-review.googlesource.com/c/go/+/333330
Run-TryBot: Roland Shoemaker <roland@golang.org>
TryBot-Result: Go Bot <gobot@golang.org>
diff --git a/src/net/dnsclient_unix_test.go b/src/net/dnsclient_unix_test.go
index 8cd383c..e7f7621 100644
--- a/src/net/dnsclient_unix_test.go
+++ b/src/net/dnsclient_unix_test.go
@@ -1845,6 +1845,17 @@
 							Target: dnsmessage.MustNewName("<html>.golang.org."),
 						},
 					},
+					dnsmessage.Resource{
+						Header: dnsmessage.ResourceHeader{
+							Name:   n,
+							Type:   dnsmessage.TypeSRV,
+							Class:  dnsmessage.ClassINET,
+							Length: 4,
+						},
+						Body: &dnsmessage.SRVResource{
+							Target: dnsmessage.MustNewName("good.golang.org."),
+						},
+					},
 				)
 			case dnsmessage.TypeMX:
 				r.Answers = append(r.Answers,
@@ -1859,6 +1870,17 @@
 							MX: dnsmessage.MustNewName("<html>.golang.org."),
 						},
 					},
+					dnsmessage.Resource{
+						Header: dnsmessage.ResourceHeader{
+							Name:   dnsmessage.MustNewName("good.golang.org."),
+							Type:   dnsmessage.TypeMX,
+							Class:  dnsmessage.ClassINET,
+							Length: 4,
+						},
+						Body: &dnsmessage.MXResource{
+							MX: dnsmessage.MustNewName("good.golang.org."),
+						},
+					},
 				)
 			case dnsmessage.TypeNS:
 				r.Answers = append(r.Answers,
@@ -1873,6 +1895,17 @@
 							NS: dnsmessage.MustNewName("<html>.golang.org."),
 						},
 					},
+					dnsmessage.Resource{
+						Header: dnsmessage.ResourceHeader{
+							Name:   dnsmessage.MustNewName("good.golang.org."),
+							Type:   dnsmessage.TypeNS,
+							Class:  dnsmessage.ClassINET,
+							Length: 4,
+						},
+						Body: &dnsmessage.NSResource{
+							NS: dnsmessage.MustNewName("good.golang.org."),
+						},
+					},
 				)
 			case dnsmessage.TypePTR:
 				r.Answers = append(r.Answers,
@@ -1887,6 +1920,17 @@
 							PTR: dnsmessage.MustNewName("<html>.golang.org."),
 						},
 					},
+					dnsmessage.Resource{
+						Header: dnsmessage.ResourceHeader{
+							Name:   dnsmessage.MustNewName("good.golang.org."),
+							Type:   dnsmessage.TypePTR,
+							Class:  dnsmessage.ClassINET,
+							Length: 4,
+						},
+						Body: &dnsmessage.PTRResource{
+							PTR: dnsmessage.MustNewName("good.golang.org."),
+						},
+					},
 				)
 			}
 			return r, nil
@@ -1902,58 +1946,137 @@
 	defer func(orig string) { testHookHostsPath = orig }(testHookHostsPath)
 	testHookHostsPath = "testdata/hosts"
 
-	_, err := r.LookupCNAME(context.Background(), "golang.org")
-	if expected := "lookup golang.org: CNAME target is invalid"; err == nil || err.Error() != expected {
-		t.Errorf("Resolver.LookupCNAME returned unexpected error, got %q, want %q", err, expected)
-	}
-	_, err = LookupCNAME("golang.org")
-	if expected := "lookup golang.org: CNAME target is invalid"; err == nil || err.Error() != expected {
-		t.Errorf("LookupCNAME returned unexpected error, got %q, want %q", err, expected)
+	tests := []struct {
+		name string
+		f    func(*testing.T)
+	}{
+		{
+			name: "CNAME",
+			f: func(t *testing.T) {
+				expectedErr := &DNSError{Err: errMalformedDNSRecordsDetail, Name: "golang.org"}
+				_, err := r.LookupCNAME(context.Background(), "golang.org")
+				if err.Error() != expectedErr.Error() {
+					t.Fatalf("unexpected error: %s", err)
+				}
+				_, err = LookupCNAME("golang.org")
+				if err.Error() != expectedErr.Error() {
+					t.Fatalf("unexpected error: %s", err)
+				}
+			},
+		},
+		{
+			name: "SRV (bad record)",
+			f: func(t *testing.T) {
+				expected := []*SRV{
+					{
+						Target: "good.golang.org.",
+					},
+				}
+				expectedErr := &DNSError{Err: errMalformedDNSRecordsDetail, Name: "golang.org"}
+				_, records, err := r.LookupSRV(context.Background(), "target", "tcp", "golang.org")
+				if err.Error() != expectedErr.Error() {
+					t.Fatalf("unexpected error: %s", err)
+				}
+				if !reflect.DeepEqual(records, expected) {
+					t.Error("Unexpected record set")
+				}
+				_, records, err = LookupSRV("target", "tcp", "golang.org")
+				if err.Error() != expectedErr.Error() {
+					t.Errorf("unexpected error: %s", err)
+				}
+				if !reflect.DeepEqual(records, expected) {
+					t.Error("Unexpected record set")
+				}
+			},
+		},
+		{
+			name: "SRV (bad header)",
+			f: func(t *testing.T) {
+				_, _, err := r.LookupSRV(context.Background(), "hdr", "tcp", "golang.org.")
+				if expected := "lookup golang.org.: SRV header name is invalid"; err == nil || err.Error() != expected {
+					t.Errorf("Resolver.LookupSRV returned unexpected error, got %q, want %q", err, expected)
+				}
+				_, _, err = LookupSRV("hdr", "tcp", "golang.org.")
+				if expected := "lookup golang.org.: SRV header name is invalid"; err == nil || err.Error() != expected {
+					t.Errorf("LookupSRV returned unexpected error, got %q, want %q", err, expected)
+				}
+			},
+		},
+		{
+			name: "MX",
+			f: func(t *testing.T) {
+				expected := []*MX{
+					{
+						Host: "good.golang.org.",
+					},
+				}
+				expectedErr := &DNSError{Err: errMalformedDNSRecordsDetail, Name: "golang.org"}
+				records, err := r.LookupMX(context.Background(), "golang.org")
+				if err.Error() != expectedErr.Error() {
+					t.Fatalf("unexpected error: %s", err)
+				}
+				if !reflect.DeepEqual(records, expected) {
+					t.Error("Unexpected record set")
+				}
+				records, err = LookupMX("golang.org")
+				if err.Error() != expectedErr.Error() {
+					t.Fatalf("unexpected error: %s", err)
+				}
+				if !reflect.DeepEqual(records, expected) {
+					t.Error("Unexpected record set")
+				}
+			},
+		},
+		{
+			name: "NS",
+			f: func(t *testing.T) {
+				expected := []*NS{
+					{
+						Host: "good.golang.org.",
+					},
+				}
+				expectedErr := &DNSError{Err: errMalformedDNSRecordsDetail, Name: "golang.org"}
+				records, err := r.LookupNS(context.Background(), "golang.org")
+				if err.Error() != expectedErr.Error() {
+					t.Fatalf("unexpected error: %s", err)
+				}
+				if !reflect.DeepEqual(records, expected) {
+					t.Error("Unexpected record set")
+				}
+				records, err = LookupNS("golang.org")
+				if err.Error() != expectedErr.Error() {
+					t.Fatalf("unexpected error: %s", err)
+				}
+				if !reflect.DeepEqual(records, expected) {
+					t.Error("Unexpected record set")
+				}
+			},
+		},
+		{
+			name: "Addr",
+			f: func(t *testing.T) {
+				expected := []string{"good.golang.org."}
+				expectedErr := &DNSError{Err: errMalformedDNSRecordsDetail, Name: "192.0.2.42"}
+				records, err := r.LookupAddr(context.Background(), "192.0.2.42")
+				if err.Error() != expectedErr.Error() {
+					t.Fatalf("unexpected error: %s", err)
+				}
+				if !reflect.DeepEqual(records, expected) {
+					t.Error("Unexpected record set")
+				}
+				records, err = LookupAddr("192.0.2.42")
+				if err.Error() != expectedErr.Error() {
+					t.Fatalf("unexpected error: %s", err)
+				}
+				if !reflect.DeepEqual(records, expected) {
+					t.Error("Unexpected record set")
+				}
+			},
+		},
 	}
 
-	_, _, err = r.LookupSRV(context.Background(), "target", "tcp", "golang.org")
-	if expected := "lookup golang.org: SRV target is invalid"; err == nil || err.Error() != expected {
-		t.Errorf("Resolver.LookupSRV returned unexpected error, got %q, want %q", err, expected)
-	}
-	_, _, err = LookupSRV("target", "tcp", "golang.org")
-	if expected := "lookup golang.org: SRV target is invalid"; err == nil || err.Error() != expected {
-		t.Errorf("LookupSRV returned unexpected error, got %q, want %q", err, expected)
-	}
-
-	_, _, err = r.LookupSRV(context.Background(), "hdr", "tcp", "golang.org")
-	if expected := "lookup golang.org: SRV header name is invalid"; err == nil || err.Error() != expected {
-		t.Errorf("Resolver.LookupSRV returned unexpected error, got %q, want %q", err, expected)
-	}
-	_, _, err = LookupSRV("hdr", "tcp", "golang.org")
-	if expected := "lookup golang.org: SRV header name is invalid"; err == nil || err.Error() != expected {
-		t.Errorf("LookupSRV returned unexpected error, got %q, want %q", err, expected)
-	}
-
-	_, err = r.LookupMX(context.Background(), "golang.org")
-	if expected := "lookup golang.org: MX target is invalid"; err == nil || err.Error() != expected {
-		t.Errorf("Resolver.LookupMX returned unexpected error, got %q, want %q", err, expected)
-	}
-	_, err = LookupMX("golang.org")
-	if expected := "lookup golang.org: MX target is invalid"; err == nil || err.Error() != expected {
-		t.Errorf("LookupMX returned unexpected error, got %q, want %q", err, expected)
-	}
-
-	_, err = r.LookupNS(context.Background(), "golang.org")
-	if expected := "lookup golang.org: NS target is invalid"; err == nil || err.Error() != expected {
-		t.Errorf("Resolver.LookupNS returned unexpected error, got %q, want %q", err, expected)
-	}
-	_, err = LookupNS("golang.org")
-	if expected := "lookup golang.org: NS target is invalid"; err == nil || err.Error() != expected {
-		t.Errorf("LookupNS returned unexpected error, got %q, want %q", err, expected)
-	}
-
-	_, err = r.LookupAddr(context.Background(), "192.0.2.42")
-	if expected := "lookup 192.0.2.42: PTR target is invalid"; err == nil || err.Error() != expected {
-		t.Errorf("Resolver.LookupAddr returned unexpected error, got %q, want %q", err, expected)
-	}
-	_, err = LookupAddr("192.0.2.42")
-	if expected := "lookup 192.0.2.42: PTR target is invalid"; err == nil || err.Error() != expected {
-		t.Errorf("LookupAddr returned unexpected error, got %q, want %q", err, expected)
+	for _, tc := range tests {
+		t.Run(tc.name, tc.f)
 	}
 }
 
diff --git a/src/net/lookup.go b/src/net/lookup.go
index 01c81db..4135122 100644
--- a/src/net/lookup.go
+++ b/src/net/lookup.go
@@ -415,7 +415,7 @@
 		return "", err
 	}
 	if !isDomainName(cname) {
-		return "", &DNSError{Err: "CNAME target is invalid", Name: host}
+		return "", &DNSError{Err: errMalformedDNSRecordsDetail, Name: host}
 	}
 	return cname, nil
 }
@@ -431,7 +431,9 @@
 // and proto are empty strings, LookupSRV looks up name directly.
 //
 // The returned service names are validated to be properly
-// formatted presentation-format domain names.
+// formatted presentation-format domain names. If the response contains
+// invalid names, those records are filtered out and an error
+// will be returned alongside the the remaining results, if any.
 func LookupSRV(service, proto, name string) (cname string, addrs []*SRV, err error) {
 	return DefaultResolver.LookupSRV(context.Background(), service, proto, name)
 }
@@ -447,7 +449,9 @@
 // and proto are empty strings, LookupSRV looks up name directly.
 //
 // The returned service names are validated to be properly
-// formatted presentation-format domain names.
+// formatted presentation-format domain names. If the response contains
+// invalid names, those records are filtered out and an error
+// will be returned alongside the the remaining results, if any.
 func (r *Resolver) LookupSRV(ctx context.Context, service, proto, name string) (string, []*SRV, error) {
 	cname, addrs, err := r.lookupSRV(ctx, service, proto, name)
 	if err != nil {
@@ -456,21 +460,31 @@
 	if cname != "" && !isDomainName(cname) {
 		return "", nil, &DNSError{Err: "SRV header name is invalid", Name: name}
 	}
+	filteredAddrs := make([]*SRV, 0, len(addrs))
 	for _, addr := range addrs {
 		if addr == nil {
 			continue
 		}
 		if !isDomainName(addr.Target) {
-			return "", nil, &DNSError{Err: "SRV target is invalid", Name: name}
+			continue
 		}
+		filteredAddrs = append(filteredAddrs, addr)
 	}
-	return cname, addrs, nil
+	if len(addrs) != len(filteredAddrs) {
+		return cname, filteredAddrs, &DNSError{Err: errMalformedDNSRecordsDetail, Name: name}
+	}
+	return cname, filteredAddrs, nil
 }
 
 // LookupMX returns the DNS MX records for the given domain name sorted by preference.
 //
 // The returned mail server names are validated to be properly
-// formatted presentation-format domain names.
+// formatted presentation-format domain names. If the response contains
+// invalid names, those records are filtered out and an error
+// will be returned alongside the the remaining results, if any.
+//
+// LookupMX uses context.Background internally; to specify the context, use
+// Resolver.LookupMX.
 func LookupMX(name string) ([]*MX, error) {
 	return DefaultResolver.LookupMX(context.Background(), name)
 }
@@ -478,12 +492,15 @@
 // LookupMX returns the DNS MX records for the given domain name sorted by preference.
 //
 // The returned mail server names are validated to be properly
-// formatted presentation-format domain names.
+// formatted presentation-format domain names. If the response contains
+// invalid names, those records are filtered out and an error
+// will be returned alongside the the remaining results, if any.
 func (r *Resolver) LookupMX(ctx context.Context, name string) ([]*MX, error) {
 	records, err := r.lookupMX(ctx, name)
 	if err != nil {
 		return nil, err
 	}
+	filteredMX := make([]*MX, 0, len(records))
 	for _, mx := range records {
 		if mx == nil {
 			continue
@@ -491,16 +508,25 @@
 		// Bypass the hostname validity check for targets which contain only a dot,
 		// as this is used to represent a 'Null' MX record.
 		if mx.Host != "." && !isDomainName(mx.Host) {
-			return nil, &DNSError{Err: "MX target is invalid", Name: name}
+			continue
 		}
+		filteredMX = append(filteredMX, mx)
 	}
-	return records, nil
+	if len(records) != len(filteredMX) {
+		return filteredMX, &DNSError{Err: errMalformedDNSRecordsDetail, Name: name}
+	}
+	return filteredMX, nil
 }
 
 // LookupNS returns the DNS NS records for the given domain name.
 //
 // The returned name server names are validated to be properly
-// formatted presentation-format domain names.
+// formatted presentation-format domain names. If the response contains
+// invalid names, those records are filtered out and an error
+// will be returned alongside the the remaining results, if any.
+//
+// LookupNS uses context.Background internally; to specify the context, use
+// Resolver.LookupNS.
 func LookupNS(name string) ([]*NS, error) {
 	return DefaultResolver.LookupNS(context.Background(), name)
 }
@@ -508,21 +534,28 @@
 // LookupNS returns the DNS NS records for the given domain name.
 //
 // The returned name server names are validated to be properly
-// formatted presentation-format domain names.
+// formatted presentation-format domain names. If the response contains
+// invalid names, those records are filtered out and an error
+// will be returned alongside the the remaining results, if any.
 func (r *Resolver) LookupNS(ctx context.Context, name string) ([]*NS, error) {
 	records, err := r.lookupNS(ctx, name)
 	if err != nil {
 		return nil, err
 	}
+	filteredNS := make([]*NS, 0, len(records))
 	for _, ns := range records {
 		if ns == nil {
 			continue
 		}
 		if !isDomainName(ns.Host) {
-			return nil, &DNSError{Err: "NS target is invalid", Name: name}
+			continue
 		}
+		filteredNS = append(filteredNS, ns)
 	}
-	return records, nil
+	if len(records) != len(filteredNS) {
+		return filteredNS, &DNSError{Err: errMalformedDNSRecordsDetail, Name: name}
+	}
+	return filteredNS, nil
 }
 
 // LookupTXT returns the DNS TXT records for the given domain name.
@@ -539,7 +572,8 @@
 // of names mapping to that address.
 //
 // The returned names are validated to be properly formatted presentation-format
-// domain names.
+// domain names. If the response contains invalid names, those records are filtered
+// out and an error will be returned alongside the the remaining results, if any.
 //
 // When using the host C library resolver, at most one result will be
 // returned. To bypass the host resolver, use a custom Resolver.
@@ -551,16 +585,26 @@
 // of names mapping to that address.
 //
 // The returned names are validated to be properly formatted presentation-format
-// domain names.
+// domain names. If the response contains invalid names, those records are filtered
+// out and an error will be returned alongside the the remaining results, if any.
 func (r *Resolver) LookupAddr(ctx context.Context, addr string) ([]string, error) {
 	names, err := r.lookupAddr(ctx, addr)
 	if err != nil {
 		return nil, err
 	}
+	filteredNames := make([]string, 0, len(names))
 	for _, name := range names {
-		if !isDomainName(name) {
-			return nil, &DNSError{Err: "PTR target is invalid", Name: addr}
+		if isDomainName(name) {
+			filteredNames = append(filteredNames, name)
 		}
 	}
-	return names, nil
+	if len(names) != len(filteredNames) {
+		return filteredNames, &DNSError{Err: errMalformedDNSRecordsDetail, Name: addr}
+	}
+	return filteredNames, nil
 }
+
+// errMalformedDNSRecordsDetail is the DNSError detail which is returned when a Resolver.Lookup...
+// method recieves DNS records which contain invalid DNS names. This may be returned alongside
+// results which have had the malformed records filtered out.
+var errMalformedDNSRecordsDetail = "DNS response contained records which contain invalid names"