curve25519: implement new X25519 API and deprecate ScalarMult

    const ScalarSize = 32
    const PointSize = 32
    var Basepoint []byte
    func X25519(scalar, point []byte) ([]byte, error)

Fixes golang/go#32670

Change-Id: I6b08932e4123949355610b16b6053559b399516c
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/205157
Reviewed-by: Katie Hockman <katie@golang.org>
Reviewed-by: Adam Langley <agl@golang.org>
diff --git a/curve25519/curve25519.go b/curve25519/curve25519.go
index 723d9d0..4b9a655 100644
--- a/curve25519/curve25519.go
+++ b/curve25519/curve25519.go
@@ -2,22 +2,94 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-// Package curve25519 provides an implementation of scalar multiplication on
-// the elliptic curve known as curve25519. See https://cr.yp.to/ecdh.html.
+// Package curve25519 provides an implementation of the X25519 function, which
+// performs scalar multiplication on the elliptic curve known as Curve25519.
+// See RFC 7748.
 package curve25519 // import "golang.org/x/crypto/curve25519"
 
-// ScalarMult sets dst to the product in*base where dst and base are the x
-// coordinates of group points and all values are in little-endian form.
-func ScalarMult(dst, in, base *[32]byte) {
-	scalarMult(dst, in, base)
+import (
+	"crypto/subtle"
+	"fmt"
+)
+
+// ScalarMult sets dst to the product scalar * point.
+//
+// Deprecated: when provided a low-order point, ScalarMult will set dst to all
+// zeroes, irrespective of the scalar. Instead, use the X25519 function, which
+// will return an error.
+func ScalarMult(dst, scalar, point *[32]byte) {
+	scalarMult(dst, scalar, point)
 }
 
-// ScalarBaseMult sets dst to the product in*base where dst and base are the x
-// coordinates of group points, base is the standard generator and all values
-// are in little-endian form.
-func ScalarBaseMult(dst, in *[32]byte) {
-	ScalarMult(dst, in, &basePoint)
+// ScalarBaseMult sets dst to the product scalar * base where base is the
+// standard generator.
+//
+// It is recommended to use the X25519 function with Basepoint instead, as
+// copying into fixed size arrays can lead to unexpected bugs.
+func ScalarBaseMult(dst, scalar *[32]byte) {
+	ScalarMult(dst, scalar, &basePoint)
 }
 
-// basePoint is the x coordinate of the generator of the curve.
+const (
+	// ScalarSize is the size of the scalar input to X25519.
+	ScalarSize = 32
+	// PointSize is the size of the point input to X25519.
+	PointSize = 32
+)
+
+// Basepoint is the canonical Curve25519 generator.
+var Basepoint []byte
+
 var basePoint = [32]byte{9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
+
+func init() { Basepoint = basePoint[:] }
+
+func checkBasepoint() {
+	if subtle.ConstantTimeCompare(Basepoint, []byte{
+		0x09, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+	}) != 1 {
+		panic("curve25519: global Basepoint value was modified")
+	}
+}
+
+// X25519 returns the result of the scalar multiplication (scalar * point),
+// according to RFC 7748, Section 5. scalar, point and the return value are
+// slices of 32 bytes.
+//
+// scalar can be generated at random, for example with crypto/rand. point should
+// be either Basepoint or the output of another X25519 call.
+//
+// If point is Basepoint (but not if it's a different slice with the same
+// contents), a precomputed implementation might be used for performance.
+func X25519(scalar, point []byte) ([]byte, error) {
+	// Outline the body of the function, to let the allocation be inlined in the
+	// caller, and possibly avoid escaping to the heap.
+	var dst [32]byte
+	return x25519(&dst, scalar, point)
+}
+
+func x25519(dst *[32]byte, scalar, point []byte) ([]byte, error) {
+	var in [32]byte
+	if l := len(scalar); l != 32 {
+		return nil, fmt.Errorf("bad scalar length: %d, expected %d", l, 32)
+	}
+	if l := len(point); l != 32 {
+		return nil, fmt.Errorf("bad point length: %d, expected %d", l, 32)
+	}
+	copy(in[:], scalar)
+	if &point[0] == &Basepoint[0] {
+		checkBasepoint()
+		ScalarBaseMult(dst, &in)
+	} else {
+		var base, zero [32]byte
+		copy(base[:], point)
+		ScalarMult(dst, &in, &base)
+		if subtle.ConstantTimeCompare(dst[:], zero[:]) == 1 {
+			return nil, fmt.Errorf("bad input point: low order point")
+		}
+	}
+	return dst[:], nil
+}
diff --git a/curve25519/curve25519_test.go b/curve25519/curve25519_test.go
index 56ef73b..aca7695 100644
--- a/curve25519/curve25519_test.go
+++ b/curve25519/curve25519_test.go
@@ -13,27 +13,58 @@
 
 const expectedHex = "89161fde887b2b53de549af483940106ecc114d6982daa98256de23bdf77661a"
 
-func TestBaseScalarMult(t *testing.T) {
-	var a, b [32]byte
-	in := &a
-	out := &b
-	a[0] = 1
+func TestX25519Basepoint(t *testing.T) {
+	x := make([]byte, 32)
+	x[0] = 1
 
 	for i := 0; i < 200; i++ {
-		ScalarBaseMult(out, in)
-		in, out = out, in
+		var err error
+		x, err = X25519(x, Basepoint)
+		if err != nil {
+			t.Fatal(err)
+		}
 	}
 
-	result := fmt.Sprintf("%x", in[:])
+	result := fmt.Sprintf("%x", x)
 	if result != expectedHex {
 		t.Errorf("incorrect result: got %s, want %s", result, expectedHex)
 	}
 }
 
+func TestLowOrderPoints(t *testing.T) {
+	scalar := make([]byte, ScalarSize)
+	if _, err := rand.Read(scalar); err != nil {
+		t.Fatal(err)
+	}
+	for i, p := range lowOrderPoints {
+		out, err := X25519(scalar, p)
+		if err == nil {
+			t.Errorf("%d: expected error, got nil", i)
+		}
+		if out != nil {
+			t.Errorf("%d: expected nil output, got %x", i, out)
+		}
+	}
+}
+
 func TestTestVectors(t *testing.T) {
+	t.Run("Generic", func(t *testing.T) { testTestVectors(t, scalarMultGeneric) })
+	t.Run("Native", func(t *testing.T) { testTestVectors(t, ScalarMult) })
+	t.Run("X25519", func(t *testing.T) {
+		testTestVectors(t, func(dst, scalar, point *[32]byte) {
+			out, err := X25519(scalar[:], point[:])
+			if err != nil {
+				t.Fatal(err)
+			}
+			copy(dst[:], out)
+		})
+	})
+}
+
+func testTestVectors(t *testing.T, scalarMult func(dst, scalar, point *[32]byte)) {
 	for _, tv := range testVectors {
 		var got [32]byte
-		ScalarMult(&got, &tv.In, &tv.Base)
+		scalarMult(&got, &tv.In, &tv.Base)
 		if !bytes.Equal(got[:], tv.Expect[:]) {
 			t.Logf("    in = %x", tv.In)
 			t.Logf("  base = %x", tv.Base)
diff --git a/curve25519/vectors_test.go b/curve25519/vectors_test.go
index 79b7a09..946e9a8 100644
--- a/curve25519/vectors_test.go
+++ b/curve25519/vectors_test.go
@@ -4,6 +4,18 @@
 
 package curve25519
 
+// lowOrderPoints from libsodium.
+// https://github.com/jedisct1/libsodium/blob/65621a1059a37d/src/libsodium/crypto_scalarmult/curve25519/ref10/x25519_ref10.c#L11-L70
+var lowOrderPoints = [][]byte{
+	{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
+	{0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
+	{0xe0, 0xeb, 0x7a, 0x7c, 0x3b, 0x41, 0xb8, 0xae, 0x16, 0x56, 0xe3, 0xfa, 0xf1, 0x9f, 0xc4, 0x6a, 0xda, 0x09, 0x8d, 0xeb, 0x9c, 0x32, 0xb1, 0xfd, 0x86, 0x62, 0x05, 0x16, 0x5f, 0x49, 0xb8, 0x00},
+	{0x5f, 0x9c, 0x95, 0xbc, 0xa3, 0x50, 0x8c, 0x24, 0xb1, 0xd0, 0xb1, 0x55, 0x9c, 0x83, 0xef, 0x5b, 0x04, 0x44, 0x5c, 0xc4, 0x58, 0x1c, 0x8e, 0x86, 0xd8, 0x22, 0x4e, 0xdd, 0xd0, 0x9f, 0x11, 0x57},
+	{0xec, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f},
+	{0xed, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f},
+	{0xee, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f},
+}
+
 // testVectors generated with BoringSSL.
 var testVectors = []struct {
 	In     [32]byte