curve25519: implement new X25519 API and deprecate ScalarMult
const ScalarSize = 32
const PointSize = 32
var Basepoint []byte
func X25519(scalar, point []byte) ([]byte, error)
Fixes golang/go#32670
Change-Id: I6b08932e4123949355610b16b6053559b399516c
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/205157
Reviewed-by: Katie Hockman <katie@golang.org>
Reviewed-by: Adam Langley <agl@golang.org>
diff --git a/curve25519/curve25519.go b/curve25519/curve25519.go
index 723d9d0..4b9a655 100644
--- a/curve25519/curve25519.go
+++ b/curve25519/curve25519.go
@@ -2,22 +2,94 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// Package curve25519 provides an implementation of scalar multiplication on
-// the elliptic curve known as curve25519. See https://cr.yp.to/ecdh.html.
+// Package curve25519 provides an implementation of the X25519 function, which
+// performs scalar multiplication on the elliptic curve known as Curve25519.
+// See RFC 7748.
package curve25519 // import "golang.org/x/crypto/curve25519"
-// ScalarMult sets dst to the product in*base where dst and base are the x
-// coordinates of group points and all values are in little-endian form.
-func ScalarMult(dst, in, base *[32]byte) {
- scalarMult(dst, in, base)
+import (
+ "crypto/subtle"
+ "fmt"
+)
+
+// ScalarMult sets dst to the product scalar * point.
+//
+// Deprecated: when provided a low-order point, ScalarMult will set dst to all
+// zeroes, irrespective of the scalar. Instead, use the X25519 function, which
+// will return an error.
+func ScalarMult(dst, scalar, point *[32]byte) {
+ scalarMult(dst, scalar, point)
}
-// ScalarBaseMult sets dst to the product in*base where dst and base are the x
-// coordinates of group points, base is the standard generator and all values
-// are in little-endian form.
-func ScalarBaseMult(dst, in *[32]byte) {
- ScalarMult(dst, in, &basePoint)
+// ScalarBaseMult sets dst to the product scalar * base where base is the
+// standard generator.
+//
+// It is recommended to use the X25519 function with Basepoint instead, as
+// copying into fixed size arrays can lead to unexpected bugs.
+func ScalarBaseMult(dst, scalar *[32]byte) {
+ ScalarMult(dst, scalar, &basePoint)
}
-// basePoint is the x coordinate of the generator of the curve.
+const (
+ // ScalarSize is the size of the scalar input to X25519.
+ ScalarSize = 32
+ // PointSize is the size of the point input to X25519.
+ PointSize = 32
+)
+
+// Basepoint is the canonical Curve25519 generator.
+var Basepoint []byte
+
var basePoint = [32]byte{9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
+
+func init() { Basepoint = basePoint[:] }
+
+func checkBasepoint() {
+ if subtle.ConstantTimeCompare(Basepoint, []byte{
+ 0x09, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ }) != 1 {
+ panic("curve25519: global Basepoint value was modified")
+ }
+}
+
+// X25519 returns the result of the scalar multiplication (scalar * point),
+// according to RFC 7748, Section 5. scalar, point and the return value are
+// slices of 32 bytes.
+//
+// scalar can be generated at random, for example with crypto/rand. point should
+// be either Basepoint or the output of another X25519 call.
+//
+// If point is Basepoint (but not if it's a different slice with the same
+// contents) a precomputed implementation might be used for performance.
+func X25519(scalar, point []byte) ([]byte, error) {
+ // Outline the body of function, to let the allocation be inlined in the
+ // caller, and possibly avoid escaping to the heap.
+ var dst [32]byte
+ return x25519(&dst, scalar, point)
+}
+
+func x25519(dst *[32]byte, scalar, point []byte) ([]byte, error) {
+ var in [32]byte
+ if l := len(scalar); l != 32 {
+ return nil, fmt.Errorf("bad scalar length: %d, expected %d", l, 32)
+ }
+ if l := len(point); l != 32 {
+ return nil, fmt.Errorf("bad point length: %d, expected %d", l, 32)
+ }
+ copy(in[:], scalar)
+ if &point[0] == &Basepoint[0] {
+ checkBasepoint()
+ ScalarBaseMult(dst, &in)
+ } else {
+ var base, zero [32]byte
+ copy(base[:], point)
+ ScalarMult(dst, &in, &base)
+ if subtle.ConstantTimeCompare(dst[:], zero[:]) == 1 {
+ return nil, fmt.Errorf("bad input point: low order point")
+ }
+ }
+ return dst[:], nil
+}
diff --git a/curve25519/curve25519_test.go b/curve25519/curve25519_test.go
index 56ef73b..aca7695 100644
--- a/curve25519/curve25519_test.go
+++ b/curve25519/curve25519_test.go
@@ -13,27 +13,58 @@
const expectedHex = "89161fde887b2b53de549af483940106ecc114d6982daa98256de23bdf77661a"
-func TestBaseScalarMult(t *testing.T) {
- var a, b [32]byte
- in := &a
- out := &b
- a[0] = 1
+func TestX25519Basepoint(t *testing.T) {
+ x := make([]byte, 32)
+ x[0] = 1
for i := 0; i < 200; i++ {
- ScalarBaseMult(out, in)
- in, out = out, in
+ var err error
+ x, err = X25519(x, Basepoint)
+ if err != nil {
+ t.Fatal(err)
+ }
}
- result := fmt.Sprintf("%x", in[:])
+ result := fmt.Sprintf("%x", x)
if result != expectedHex {
t.Errorf("incorrect result: got %s, want %s", result, expectedHex)
}
}
+func TestLowOrderPoints(t *testing.T) {
+ scalar := make([]byte, ScalarSize)
+ if _, err := rand.Read(scalar); err != nil {
+ t.Fatal(err)
+ }
+ for i, p := range lowOrderPoints {
+ out, err := X25519(scalar, p)
+ if err == nil {
+ t.Errorf("%d: expected error, got nil", i)
+ }
+ if out != nil {
+ t.Errorf("%d: expected nil output, got %x", i, out)
+ }
+ }
+}
+
func TestTestVectors(t *testing.T) {
+ t.Run("Generic", func(t *testing.T) { testTestVectors(t, scalarMultGeneric) })
+ t.Run("Native", func(t *testing.T) { testTestVectors(t, ScalarMult) })
+ t.Run("X25519", func(t *testing.T) {
+ testTestVectors(t, func(dst, scalar, point *[32]byte) {
+ out, err := X25519(scalar[:], point[:])
+ if err != nil {
+ t.Fatal(err)
+ }
+ copy(dst[:], out)
+ })
+ })
+}
+
+func testTestVectors(t *testing.T, scalarMult func(dst, scalar, point *[32]byte)) {
for _, tv := range testVectors {
var got [32]byte
- ScalarMult(&got, &tv.In, &tv.Base)
+ scalarMult(&got, &tv.In, &tv.Base)
if !bytes.Equal(got[:], tv.Expect[:]) {
t.Logf(" in = %x", tv.In)
t.Logf(" base = %x", tv.Base)
diff --git a/curve25519/vectors_test.go b/curve25519/vectors_test.go
index 79b7a09..946e9a8 100644
--- a/curve25519/vectors_test.go
+++ b/curve25519/vectors_test.go
@@ -4,6 +4,18 @@
package curve25519
+// lowOrderPoints from libsodium.
+// https://github.com/jedisct1/libsodium/blob/65621a1059a37d/src/libsodium/crypto_scalarmult/curve25519/ref10/x25519_ref10.c#L11-L70
+var lowOrderPoints = [][]byte{
+ {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
+ {0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
+ {0xe0, 0xeb, 0x7a, 0x7c, 0x3b, 0x41, 0xb8, 0xae, 0x16, 0x56, 0xe3, 0xfa, 0xf1, 0x9f, 0xc4, 0x6a, 0xda, 0x09, 0x8d, 0xeb, 0x9c, 0x32, 0xb1, 0xfd, 0x86, 0x62, 0x05, 0x16, 0x5f, 0x49, 0xb8, 0x00},
+ {0x5f, 0x9c, 0x95, 0xbc, 0xa3, 0x50, 0x8c, 0x24, 0xb1, 0xd0, 0xb1, 0x55, 0x9c, 0x83, 0xef, 0x5b, 0x04, 0x44, 0x5c, 0xc4, 0x58, 0x1c, 0x8e, 0x86, 0xd8, 0x22, 0x4e, 0xdd, 0xd0, 0x9f, 0x11, 0x57},
+ {0xec, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f},
+ {0xed, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f},
+ {0xee, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f},
+}
+
// testVectors generated with BoringSSL.
var testVectors = []struct {
In [32]byte