blob: ef25ba2f3920ee8e30dedc6d77dd30ae1fc7ec00 [file] [log] [blame]
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
////////////////////////////////////////////////////////////////////////////////
package com.google.crypto.tink.jwt;
import static com.google.common.truth.Truth.assertThat;
import static java.nio.charset.StandardCharsets.US_ASCII;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.junit.Assert.assertThrows;
import static org.junit.Assume.assumeFalse;
import com.google.crypto.tink.CleartextKeysetHandle;
import com.google.crypto.tink.KeyTemplate;
import com.google.crypto.tink.KeyTypeManager;
import com.google.crypto.tink.KeysetHandle;
import com.google.crypto.tink.proto.JwtEcdsaAlgorithm;
import com.google.crypto.tink.proto.JwtEcdsaKeyFormat;
import com.google.crypto.tink.proto.JwtEcdsaPrivateKey;
import com.google.crypto.tink.proto.JwtEcdsaPublicKey;
import com.google.crypto.tink.proto.KeyData.KeyMaterialType;
import com.google.crypto.tink.proto.Keyset;
import com.google.crypto.tink.subtle.Base64;
import com.google.crypto.tink.subtle.EcdsaSignJce;
import com.google.crypto.tink.subtle.EllipticCurves;
import com.google.crypto.tink.subtle.EllipticCurves.EcdsaEncoding;
import com.google.crypto.tink.subtle.Enums;
import com.google.crypto.tink.subtle.Random;
import com.google.crypto.tink.testing.TestUtil;
import com.google.gson.JsonObject;
import com.google.protobuf.ByteString;
import com.google.protobuf.ExtensionRegistryLite;
import java.io.ByteArrayInputStream;
import java.security.GeneralSecurityException;
import java.security.interfaces.ECPrivateKey;
import java.util.Arrays;
import java.util.Optional;
import java.util.Set;
import java.util.TreeSet;
import junitparams.JUnitParamsRunner;
import junitparams.Parameters;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.runner.RunWith;
/**
 * Unit tests for {@link JwtEcdsaSignKeyManager}.
 *
 * <p>Covers key-format validation, key generation, public-key extraction, the predefined key
 * templates, and end-to-end sign/verify round trips, including tampered tokens and tokens with
 * hand-crafted headers.
 */
@RunWith(JUnitParamsRunner.class)
public class JwtEcdsaSignKeyManagerTest {
  @BeforeClass
  public static void setUp() throws Exception {
    JwtSignatureConfig.register();
  }

  private final JwtEcdsaSignKeyManager manager = new JwtEcdsaSignKeyManager();
  private final KeyTypeManager.KeyFactory<JwtEcdsaKeyFormat, JwtEcdsaPrivateKey> factory =
      manager.keyFactory();

  /** Builds a key format requesting the given algorithm. */
  private static JwtEcdsaKeyFormat createKeyFormat(JwtEcdsaAlgorithm algorithm) {
    return JwtEcdsaKeyFormat.newBuilder().setAlgorithm(algorithm).build();
  }

  /** JUnitParams source: each JWT ECDSA algorithm exercised by these tests. */
  private static Object[] parametersAlgos() {
    return new Object[] {JwtEcdsaAlgorithm.ES256, JwtEcdsaAlgorithm.ES384, JwtEcdsaAlgorithm.ES512};
  }

  /** JUnitParams source: the predefined key template for each algorithm in parametersAlgos(). */
  private static Object[] templates() {
    return new Object[] {
      JwtEcdsaSignKeyManager.jwtES256Template(),
      JwtEcdsaSignKeyManager.jwtES384Template(),
      JwtEcdsaSignKeyManager.jwtES512Template()
    };
  }

  @Test
  public void basics() throws Exception {
    assertThat(manager.getKeyType())
        .isEqualTo("type.googleapis.com/google.crypto.tink.JwtEcdsaPrivateKey");
    assertThat(manager.getVersion()).isEqualTo(0);
    assertThat(manager.keyMaterialType()).isEqualTo(KeyMaterialType.ASYMMETRIC_PRIVATE);
  }

  @Test
  public void validateKeyFormat_empty_throw() throws Exception {
    // The default instance carries no algorithm, so validation must reject it.
    assertThrows(
        GeneralSecurityException.class,
        () -> factory.validateKeyFormat(JwtEcdsaKeyFormat.getDefaultInstance()));
  }

  @Test
  @Parameters(method = "parametersAlgos")
  public void validateKeyFormat_ok(JwtEcdsaAlgorithm algorithm) throws GeneralSecurityException {
    JwtEcdsaKeyFormat format = createKeyFormat(algorithm);
    factory.validateKeyFormat(format);
  }

  /** Asserts that a generated private key's public half uses the algorithm requested by keyFormat. */
  private static void checkConsistency(JwtEcdsaPrivateKey privateKey, JwtEcdsaKeyFormat keyFormat) {
    assertThat(privateKey.getPublicKey().getAlgorithm()).isEqualTo(keyFormat.getAlgorithm());
  }

  @Test
  @Parameters(method = "parametersAlgos")
  public void createKeys_ok(JwtEcdsaAlgorithm algorithm) throws Exception {
    JwtEcdsaKeyFormat format = createKeyFormat(algorithm);
    JwtEcdsaPrivateKey key = factory.createKey(format);
    checkConsistency(key, format);
  }

  @Test
  @Parameters(method = "parametersAlgos")
  public void createKey_alwaysNewElement_ok(JwtEcdsaAlgorithm algorithm) throws Exception {
    JwtEcdsaKeyFormat format = createKeyFormat(algorithm);
    Set<String> keys = new TreeSet<>();
    // Call createKey several times and check that every call yields a distinct private scalar.
    int numTests = 5;
    for (int i = 0; i < numTests; i++) {
      JwtEcdsaPrivateKey key = factory.createKey(format);
      keys.add(TestUtil.hexEncode(key.getKeyValue().toByteArray()));
    }
    assertThat(keys).hasSize(numTests);
  }

  @Test
  @Parameters(method = "parametersAlgos")
  public void getPublicKey_checkValues(JwtEcdsaAlgorithm algorithm) throws Exception {
    JwtEcdsaPrivateKey privateKey = factory.createKey(createKeyFormat(algorithm));
    JwtEcdsaPublicKey publicKey = manager.getPublicKey(privateKey);
    assertThat(publicKey).isEqualTo(privateKey.getPublicKey());
  }

  @Test
  @Parameters(method = "parametersAlgos")
  public void createCorruptedPublicKeyPrimitive_throws(JwtEcdsaAlgorithm algorithm)
      throws Exception {
    JwtEcdsaKeyFormat format = createKeyFormat(algorithm);
    JwtEcdsaPrivateKey originalKey = factory.createKey(format);
    byte[] originalPubX = originalKey.getPublicKey().getX().toByteArray();
    byte[] originalPubY = originalKey.getPublicKey().getY().toByteArray();
    // Flip the lowest bit of the first byte of X so the embedded public key no longer matches
    // the private key value.
    originalPubX[0] = (byte) (originalPubX[0] ^ 0x01);
    ByteString corruptedPubX = ByteString.copyFrom(originalPubX);
    JwtEcdsaPublicKey corruptedPub =
        JwtEcdsaPublicKey.newBuilder()
            .setVersion(originalKey.getPublicKey().getVersion())
            .setAlgorithm(algorithm)
            .setX(corruptedPubX)
            .setY(ByteString.copyFrom(originalPubY))
            .build();
    JwtEcdsaPrivateKey corruptedKey =
        JwtEcdsaPrivateKey.newBuilder()
            .setVersion(originalKey.getVersion())
            .setPublicKey(corruptedPub)
            .setKeyValue(originalKey.getKeyValue())
            .build();
    // Creating a signer from the inconsistent key pair is expected to fail.
    assertThrows(
        GeneralSecurityException.class,
        () -> manager.getPrimitive(corruptedKey, JwtPublicKeySign.class));
  }

  @Test
  public void testDeriveKey_throw() throws Exception {
    assertThrows(
        UnsupportedOperationException.class,
        () ->
            factory.deriveKey(
                JwtEcdsaKeyFormat.getDefaultInstance(),
                new ByteArrayInputStream(Random.randBytes(100))));
  }

  /**
   * Asserts that {@code template} targets this key manager and uses RAW output prefixes. Also
   * asserts that its serialized key format requests {@code algorithm}.
   */
  private static void checkTemplate(KeyTemplate template, JwtEcdsaAlgorithm algorithm)
      throws Exception {
    assertThat(template.getTypeUrl()).isEqualTo(new JwtEcdsaSignKeyManager().getKeyType());
    assertThat(template.getOutputPrefixType()).isEqualTo(KeyTemplate.OutputPrefixType.RAW);
    JwtEcdsaKeyFormat format =
        JwtEcdsaKeyFormat.parseFrom(template.getValue(), ExtensionRegistryLite.getEmptyRegistry());
    assertThat(format.getAlgorithm()).isEqualTo(algorithm);
  }

  @Test
  public void testJwtES256Template_ok() throws Exception {
    KeyTemplate template = JwtEcdsaSignKeyManager.jwtES256Template();
    checkTemplate(template, JwtEcdsaAlgorithm.ES256);
  }

  @Test
  public void testJwtES384Template_ok() throws Exception {
    KeyTemplate template = JwtEcdsaSignKeyManager.jwtES384Template();
    checkTemplate(template, JwtEcdsaAlgorithm.ES384);
  }

  @Test
  public void testJwtES512Template_ok() throws Exception {
    KeyTemplate template = JwtEcdsaSignKeyManager.jwtES512Template();
    checkTemplate(template, JwtEcdsaAlgorithm.ES512);
  }

  @Test
  @Parameters(method = "templates")
  public void createSignVerify_success(KeyTemplate template) throws Exception {
    assumeFalse(TestUtil.isTsan()); // KeysetHandle.generateNew is too slow in Tsan.
    KeysetHandle handle = KeysetHandle.generateNew(template);
    JwtPublicKeySign signer = handle.getPrimitive(JwtPublicKeySign.class);
    JwtPublicKeyVerify verifier =
        handle.getPublicKeysetHandle().getPrimitive(JwtPublicKeyVerify.class);
    JwtValidator validator = new JwtValidator.Builder().build();

    // Round trip without a "typ" header.
    RawJwt rawToken = new RawJwt.Builder().setIssuer("issuer").build();
    String signedCompact = signer.signAndEncode(rawToken);
    VerifiedJwt verifiedToken = verifier.verifyAndDecode(signedCompact, validator);
    assertThat(verifiedToken.getIssuer()).isEqualTo("issuer");
    assertThat(verifiedToken.hasTypeHeader()).isFalse();

    // Round trip with a "typ" header; it must survive signing and verification.
    RawJwt rawTokenWithType = new RawJwt.Builder().setTypeHeader("typeHeader").build();
    String signedCompactWithType = signer.signAndEncode(rawTokenWithType);
    VerifiedJwt verifiedTokenWithType = verifier.verifyAndDecode(signedCompactWithType, validator);
    assertThat(verifiedTokenWithType.getTypeHeader()).isEqualTo("typeHeader");
  }

  @Test
  @Parameters(method = "templates")
  public void createSignVerifyDifferentKey_throw(KeyTemplate template) throws Exception {
    assumeFalse(TestUtil.isTsan()); // KeysetHandle.generateNew is too slow in Tsan.
    KeysetHandle handle = KeysetHandle.generateNew(template);
    JwtPublicKeySign signer = handle.getPrimitive(JwtPublicKeySign.class);
    RawJwt rawToken = new RawJwt.Builder().setIssuer("issuer").build();
    String signedCompact = signer.signAndEncode(rawToken);

    // A verifier built from an independently generated keyset must reject the token.
    KeysetHandle otherHandle = KeysetHandle.generateNew(template);
    JwtPublicKeyVerify otherVerifier =
        otherHandle.getPublicKeysetHandle().getPrimitive(JwtPublicKeyVerify.class);
    JwtValidator validator = new JwtValidator.Builder().build();
    assertThrows(
        GeneralSecurityException.class,
        () -> otherVerifier.verifyAndDecode(signedCompact, validator));
  }

  @Test
  @Parameters(method = "templates")
  public void createSignVerify_header_modification_throw(KeyTemplate template) throws Exception {
    assumeFalse(TestUtil.isTsan()); // KeysetHandle.generateNew is too slow in Tsan.
    KeysetHandle handle = KeysetHandle.generateNew(template);
    JwtPublicKeySign signer = handle.getPrimitive(JwtPublicKeySign.class);
    JwtPublicKeyVerify verifier =
        handle.getPublicKeysetHandle().getPrimitive(JwtPublicKeyVerify.class);
    RawJwt rawToken = new RawJwt.Builder().setIssuer("issuer").build();
    String signedCompact = signer.signAndEncode(rawToken);

    // Modify the header by appending a space. The JSON stays equivalent, but the signed bytes
    // change, so verification must fail.
    String[] parts = signedCompact.split("\\.", -1);
    String header = new String(Base64.urlSafeDecode(parts[0]), UTF_8);
    String headerBase64 = Base64.urlSafeEncode((header + " ").getBytes(UTF_8));
    String modifiedCompact = headerBase64 + "." + parts[1] + "." + parts[2];

    JwtValidator validator = new JwtValidator.Builder().build();
    assertThrows(
        GeneralSecurityException.class, () -> verifier.verifyAndDecode(modifiedCompact, validator));
  }

  @Test
  @Parameters(method = "templates")
  public void createSignVerify_payload_modification_throw(KeyTemplate template) throws Exception {
    assumeFalse(TestUtil.isTsan()); // KeysetHandle.generateNew is too slow in Tsan.
    KeysetHandle handle = KeysetHandle.generateNew(template);
    JwtPublicKeySign signer = handle.getPrimitive(JwtPublicKeySign.class);
    JwtPublicKeyVerify verifier =
        handle.getPublicKeysetHandle().getPrimitive(JwtPublicKeyVerify.class);
    RawJwt rawToken = new RawJwt.Builder().setIssuer("issuer").build();
    String signedCompact = signer.signAndEncode(rawToken);

    // Modify the payload by appending a space. The JSON stays equivalent, but the signed bytes
    // change, so verification must fail.
    String[] parts = signedCompact.split("\\.", -1);
    String payload = new String(Base64.urlSafeDecode(parts[1]), UTF_8);
    String payloadBase64 = Base64.urlSafeEncode((payload + " ").getBytes(UTF_8));
    String modifiedCompact = parts[0] + "." + payloadBase64 + "." + parts[2];

    JwtValidator validator = new JwtValidator.Builder().build();
    assertThrows(
        GeneralSecurityException.class, () -> verifier.verifyAndDecode(modifiedCompact, validator));
  }

  @Test
  @Parameters(method = "templates")
  public void createSignVerify_bitFlipped_throw(KeyTemplate template) throws Exception {
    assumeFalse(TestUtil.isTsan()); // KeysetHandle.generateNew is too slow in Tsan.
    KeysetHandle handle = KeysetHandle.generateNew(template);
    JwtPublicKeySign signer = handle.getPrimitive(JwtPublicKeySign.class);
    JwtPublicKeyVerify verifier =
        handle.getPublicKeysetHandle().getPrimitive(JwtPublicKeyVerify.class);
    RawJwt rawToken = new RawJwt.Builder().setIssuer("issuer").build();
    String result = signer.signAndEncode(rawToken);
    JwtValidator validator = new JwtValidator.Builder().build();

    char[] validJwt = result.toCharArray();
    // Skip the last character: the base64 decoder ignores some of its bits, so flipping those
    // bits may leave the decoded signature unchanged.
    for (int i = 0; i < validJwt.length - 1; ++i) {
      // Flip each of the low 8 bits of the i-th character, one at a time.
      for (int b = 0; b < 8; ++b) {
        char[] invalidJwt = Arrays.copyOf(validJwt, validJwt.length);
        invalidJwt[i] = (char) (validJwt[i] ^ (1 << b));
        assertThrows(
            GeneralSecurityException.class,
            () -> verifier.verifyAndDecode(new String(invalidJwt), validator));
      }
    }
  }

  /**
   * Builds a compact JWS string from an arbitrary header and payload. The resulting token is
   * signed with {@code rawSigner} directly, bypassing the JWT primitive. This lets tests craft
   * headers that the regular signer would never emit.
   */
  private static String generateSignedCompact(
      EcdsaSignJce rawSigner, JsonObject header, JsonObject payload)
      throws GeneralSecurityException {
    String payloadBase64 = Base64.urlSafeEncode(payload.toString().getBytes(UTF_8));
    String headerBase64 = Base64.urlSafeEncode(header.toString().getBytes(UTF_8));
    String unsignedCompact = headerBase64 + "." + payloadBase64;
    String signature = Base64.urlSafeEncode(rawSigner.sign(unsignedCompact.getBytes(UTF_8)));
    return unsignedCompact + "." + signature;
  }

  @Test
  public void createSignVerify_withDifferentHeaders() throws Exception {
    assumeFalse(TestUtil.isTsan()); // KeysetHandle.generateNew is too slow in Tsan.
    KeyTemplate template = JwtEcdsaSignKeyManager.jwtES256Template();
    KeysetHandle handle = KeysetHandle.generateNew(template);

    // Extract the raw private key so tokens with hand-crafted headers can be signed.
    Keyset keyset = CleartextKeysetHandle.getKeyset(handle);
    JwtEcdsaPrivateKey keyProto =
        JwtEcdsaPrivateKey.parseFrom(
            keyset.getKey(0).getKeyData().getValue(), ExtensionRegistryLite.getEmptyRegistry());
    JwtEcdsaAlgorithm algorithm = keyProto.getPublicKey().getAlgorithm();
    ECPrivateKey privateKey =
        EllipticCurves.getEcPrivateKey(
            JwtEcdsaVerifyKeyManager.getCurve(algorithm), keyProto.getKeyValue().toByteArray());
    Enums.HashType hash = JwtEcdsaVerifyKeyManager.hashForEcdsaAlgorithm(algorithm);
    EcdsaSignJce rawSigner = new EcdsaSignJce(privateKey, hash, EcdsaEncoding.IEEE_P1363);

    JsonObject payload = new JsonObject();
    payload.addProperty(JwtNames.CLAIM_ISSUER, "issuer");
    JwtValidator validator = new JwtValidator.Builder().build();
    JwtPublicKeyVerify verifier =
        handle.getPublicKeysetHandle().getPrimitive(JwtPublicKeyVerify.class);

    // Normal, valid signed compact.
    String unsignedCompact =
        JwtFormat.createUnsignedCompact(algorithm.name(), Optional.empty(), payload.toString());
    String normalSignedCompact =
        JwtFormat.createSignedCompact(
            unsignedCompact, rawSigner.sign(unsignedCompact.getBytes(US_ASCII)));
    verifier.verifyAndDecode(normalSignedCompact, validator);

    // Valid token, with "typ" set in the header.
    JsonObject goodHeader = new JsonObject();
    goodHeader.addProperty(JwtNames.HEADER_ALGORITHM, algorithm.name());
    goodHeader.addProperty("typ", "JWT");
    String goodSignedCompact = generateSignedCompact(rawSigner, goodHeader, payload);
    verifier.verifyAndDecode(goodSignedCompact, validator);

    // Invalid token with an empty header (no "alg").
    JsonObject emptyHeader = new JsonObject();
    String emptyHeaderSignedCompact = generateSignedCompact(rawSigner, emptyHeader, payload);
    assertThrows(
        GeneralSecurityException.class,
        () -> verifier.verifyAndDecode(emptyHeaderSignedCompact, validator));

    // Invalid token with a valid but incorrect algorithm in the header.
    JsonObject badAlgoHeader = new JsonObject();
    badAlgoHeader.addProperty(JwtNames.HEADER_ALGORITHM, "RS256");
    String badAlgoSignedCompact = generateSignedCompact(rawSigner, badAlgoHeader, payload);
    assertThrows(
        GeneralSecurityException.class,
        () -> verifier.verifyAndDecode(badAlgoSignedCompact, validator));

    // A token with an unknown "typ" in the header is still valid.
    JsonObject unknownTypeHeader = new JsonObject();
    unknownTypeHeader.addProperty(JwtNames.HEADER_ALGORITHM, algorithm.name());
    unknownTypeHeader.addProperty("typ", "unknown");
    String unknownTypeSignedCompact = generateSignedCompact(rawSigner, unknownTypeHeader, payload);
    verifier.verifyAndDecode(unknownTypeSignedCompact, validator);
  }
}