| // Copyright 2020 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| // |
| //////////////////////////////////////////////////////////////////////////////// |
| package com.google.crypto.tink.jwt; |
| |
| import static java.nio.charset.StandardCharsets.US_ASCII; |
| |
| import com.google.crypto.tink.KeyTemplate; |
| import com.google.crypto.tink.Registry; |
| import com.google.crypto.tink.internal.KeyTypeManager; |
| import com.google.crypto.tink.internal.PrimitiveFactory; |
| import com.google.crypto.tink.internal.PrivateKeyTypeManager; |
| import com.google.crypto.tink.proto.JwtRsaSsaPssAlgorithm; |
| import com.google.crypto.tink.proto.JwtRsaSsaPssKeyFormat; |
| import com.google.crypto.tink.proto.JwtRsaSsaPssPrivateKey; |
| import com.google.crypto.tink.proto.JwtRsaSsaPssPublicKey; |
| import com.google.crypto.tink.proto.KeyData.KeyMaterialType; |
| import com.google.crypto.tink.subtle.EngineFactory; |
| import com.google.crypto.tink.subtle.Enums; |
| import com.google.crypto.tink.subtle.RsaSsaPssSignJce; |
| import com.google.crypto.tink.subtle.SelfKeyTestValidators; |
| import com.google.crypto.tink.subtle.Validators; |
| import com.google.protobuf.ByteString; |
| import com.google.protobuf.ExtensionRegistryLite; |
| import com.google.protobuf.InvalidProtocolBufferException; |
| import java.io.InputStream; |
| import java.math.BigInteger; |
| import java.security.GeneralSecurityException; |
| import java.security.KeyPair; |
| import java.security.KeyPairGenerator; |
| import java.security.interfaces.RSAPrivateCrtKey; |
| import java.security.interfaces.RSAPublicKey; |
| import java.security.spec.RSAKeyGenParameterSpec; |
| import java.security.spec.RSAPrivateCrtKeySpec; |
| import java.security.spec.RSAPublicKeySpec; |
| import java.util.Collections; |
| import java.util.HashMap; |
| import java.util.Map; |
| import java.util.Optional; |
| |
| /** |
| * This key manager generates new {@code JwtRsaSsaPssPrivateKey} keys and produces new instances of |
| * {@code JwtPublicKeySign}. |
| */ |
| public final class JwtRsaSsaPssSignKeyManager |
| extends PrivateKeyTypeManager<JwtRsaSsaPssPrivateKey, JwtRsaSsaPssPublicKey> { |
| |
| private static final void selfTestKey( |
| RSAPrivateCrtKey privateKey, JwtRsaSsaPssPrivateKey keyProto) |
| throws GeneralSecurityException { |
| java.security.KeyFactory factory = EngineFactory.KEY_FACTORY.getInstance("RSA"); |
| RSAPublicKey publicKey = |
| (RSAPublicKey) |
| factory.generatePublic( |
| new RSAPublicKeySpec( |
| new BigInteger(1, keyProto.getPublicKey().getN().toByteArray()), |
| new BigInteger(1, keyProto.getPublicKey().getE().toByteArray()))); |
| // Sign and verify a test message to make sure that the key is correct. |
| JwtRsaSsaPssAlgorithm algorithm = keyProto.getPublicKey().getAlgorithm(); |
| Enums.HashType hash = JwtRsaSsaPssVerifyKeyManager.hashForPssAlgorithm(algorithm); |
| int saltLength = JwtRsaSsaPssVerifyKeyManager.saltLengthForPssAlgorithm(algorithm); |
| SelfKeyTestValidators.validateRsaSsaPss(privateKey, publicKey, hash, hash, saltLength); |
| } |
| |
| private static final RSAPrivateCrtKey createPrivateKey(JwtRsaSsaPssPrivateKey keyProto) |
| throws GeneralSecurityException { |
| java.security.KeyFactory factory = EngineFactory.KEY_FACTORY.getInstance("RSA"); |
| return (RSAPrivateCrtKey) |
| factory.generatePrivate( |
| new RSAPrivateCrtKeySpec( |
| new BigInteger(1, keyProto.getPublicKey().getN().toByteArray()), |
| new BigInteger(1, keyProto.getPublicKey().getE().toByteArray()), |
| new BigInteger(1, keyProto.getD().toByteArray()), |
| new BigInteger(1, keyProto.getP().toByteArray()), |
| new BigInteger(1, keyProto.getQ().toByteArray()), |
| new BigInteger(1, keyProto.getDp().toByteArray()), |
| new BigInteger(1, keyProto.getDq().toByteArray()), |
| new BigInteger(1, keyProto.getCrt().toByteArray()))); |
| } |
| |
| private static class JwtPublicKeySignFactory |
| extends PrimitiveFactory<JwtPublicKeySignInternal, JwtRsaSsaPssPrivateKey> { |
| public JwtPublicKeySignFactory() { |
| super(JwtPublicKeySignInternal.class); |
| } |
| |
| @Override |
| public JwtPublicKeySignInternal getPrimitive(JwtRsaSsaPssPrivateKey keyProto) |
| throws GeneralSecurityException { |
| RSAPrivateCrtKey privateKey = createPrivateKey(keyProto); |
| selfTestKey(privateKey, keyProto); |
| JwtRsaSsaPssAlgorithm algorithm = keyProto.getPublicKey().getAlgorithm(); |
| Enums.HashType hash = JwtRsaSsaPssVerifyKeyManager.hashForPssAlgorithm(algorithm); |
| int saltLength = JwtRsaSsaPssVerifyKeyManager.saltLengthForPssAlgorithm(algorithm); |
| final RsaSsaPssSignJce signer = new RsaSsaPssSignJce(privateKey, hash, hash, saltLength); |
| final String algorithmName = algorithm.name(); |
| final Optional<String> customKid = |
| keyProto.getPublicKey().hasCustomKid() |
| ? Optional.of(keyProto.getPublicKey().getCustomKid().getValue()) |
| : Optional.empty(); |
| |
| return new JwtPublicKeySignInternal() { |
| @Override |
| public String signAndEncodeWithKid(RawJwt rawJwt, Optional<String> kid) |
| throws GeneralSecurityException { |
| if (customKid.isPresent()) { |
| if (kid.isPresent()) { |
| throw new JwtInvalidException("custom_kid can only be set for RAW keys."); |
| } |
| kid = customKid; |
| } |
| String unsignedCompact = JwtFormat.createUnsignedCompact(algorithmName, kid, rawJwt); |
| return JwtFormat.createSignedCompact( |
| unsignedCompact, signer.sign(unsignedCompact.getBytes(US_ASCII))); |
| } |
| }; |
| } |
| } |
| |
/**
 * Creates the manager, passing the private and public key proto classes and a {@code
 * JwtPublicKeySignFactory} to the superclass constructor.
 */
JwtRsaSsaPssSignKeyManager() {
  super(
      JwtRsaSsaPssPrivateKey.class,
      JwtRsaSsaPssPublicKey.class,
      new JwtPublicKeySignFactory());
}
| |
| @Override |
| public String getKeyType() { |
| return "type.googleapis.com/google.crypto.tink.JwtRsaSsaPssPrivateKey"; |
| } |
| |
| @Override |
| public int getVersion() { |
| return 0; |
| } |
| |
| @Override |
| public JwtRsaSsaPssPublicKey getPublicKey(JwtRsaSsaPssPrivateKey privKeyProto) { |
| return privKeyProto.getPublicKey(); |
| } |
| |
| @Override |
| public KeyMaterialType keyMaterialType() { |
| return KeyMaterialType.ASYMMETRIC_PRIVATE; |
| } |
| |
| @Override |
| public JwtRsaSsaPssPrivateKey parseKey(ByteString byteString) |
| throws InvalidProtocolBufferException { |
| return JwtRsaSsaPssPrivateKey.parseFrom(byteString, ExtensionRegistryLite.getEmptyRegistry()); |
| } |
| |
| @Override |
| public void validateKey(JwtRsaSsaPssPrivateKey privKey) throws GeneralSecurityException { |
| Validators.validateVersion(privKey.getVersion(), getVersion()); |
| Validators.validateRsaModulusSize( |
| new BigInteger(1, privKey.getPublicKey().getN().toByteArray()).bitLength()); |
| Validators.validateRsaPublicExponent( |
| new BigInteger(1, privKey.getPublicKey().getE().toByteArray())); |
| } |
| |
| @Override |
| public KeyTypeManager.KeyFactory<JwtRsaSsaPssKeyFormat, JwtRsaSsaPssPrivateKey> keyFactory() { |
| return new KeyTypeManager.KeyFactory<JwtRsaSsaPssKeyFormat, JwtRsaSsaPssPrivateKey>( |
| JwtRsaSsaPssKeyFormat.class) { |
| @Override |
| public void validateKeyFormat(JwtRsaSsaPssKeyFormat keyFormat) |
| throws GeneralSecurityException { |
| Validators.validateRsaModulusSize(keyFormat.getModulusSizeInBits()); |
| Validators.validateRsaPublicExponent( |
| new BigInteger(1, keyFormat.getPublicExponent().toByteArray())); |
| } |
| |
| @Override |
| public JwtRsaSsaPssKeyFormat parseKeyFormat(ByteString byteString) |
| throws InvalidProtocolBufferException { |
| return JwtRsaSsaPssKeyFormat.parseFrom( |
| byteString, ExtensionRegistryLite.getEmptyRegistry()); |
| } |
| |
| @Override |
| public JwtRsaSsaPssPrivateKey deriveKey( |
| JwtRsaSsaPssKeyFormat format, InputStream inputStream) { |
| throw new UnsupportedOperationException(); |
| } |
| |
| @Override |
| public JwtRsaSsaPssPrivateKey createKey(JwtRsaSsaPssKeyFormat format) |
| throws GeneralSecurityException { |
| JwtRsaSsaPssAlgorithm algorithm = format.getAlgorithm(); |
| KeyPairGenerator keyGen = EngineFactory.KEY_PAIR_GENERATOR.getInstance("RSA"); |
| RSAKeyGenParameterSpec spec = |
| new RSAKeyGenParameterSpec( |
| format.getModulusSizeInBits(), |
| new BigInteger(1, format.getPublicExponent().toByteArray())); |
| keyGen.initialize(spec); |
| KeyPair keyPair = keyGen.generateKeyPair(); |
| RSAPublicKey pubKey = (RSAPublicKey) keyPair.getPublic(); |
| RSAPrivateCrtKey privKey = (RSAPrivateCrtKey) keyPair.getPrivate(); |
| // Creates JwtRsaSsaPssPublicKey. |
| JwtRsaSsaPssPublicKey pssPubKey = |
| JwtRsaSsaPssPublicKey.newBuilder() |
| .setVersion(getVersion()) |
| .setAlgorithm(algorithm) |
| .setE(ByteString.copyFrom(pubKey.getPublicExponent().toByteArray())) |
| .setN(ByteString.copyFrom(pubKey.getModulus().toByteArray())) |
| .build(); |
| // Creates JwtRsaSsaPssPrivateKey. |
| return JwtRsaSsaPssPrivateKey.newBuilder() |
| .setVersion(getVersion()) |
| .setPublicKey(pssPubKey) |
| .setD(ByteString.copyFrom(privKey.getPrivateExponent().toByteArray())) |
| .setP(ByteString.copyFrom(privKey.getPrimeP().toByteArray())) |
| .setQ(ByteString.copyFrom(privKey.getPrimeQ().toByteArray())) |
| .setDp(ByteString.copyFrom(privKey.getPrimeExponentP().toByteArray())) |
| .setDq(ByteString.copyFrom(privKey.getPrimeExponentQ().toByteArray())) |
| .setCrt(ByteString.copyFrom(privKey.getCrtCoefficient().toByteArray())) |
| .build(); |
| } |
| |
| /** |
| * List of default templates to generate tokens with algorithms "PS256", "PS384" or "PS512". |
| * Use the template with the "_RAW" suffix if you want to generate tokens without a "kid" |
| * header. |
| */ |
| @Override |
| public Map<String, KeyFactory.KeyFormat<JwtRsaSsaPssKeyFormat>> keyFormats() { |
| Map<String, KeyFactory.KeyFormat<JwtRsaSsaPssKeyFormat>> result = new HashMap<>(); |
| result.put( |
| "JWT_PS256_2048_F4_RAW", |
| createKeyFormat( |
| JwtRsaSsaPssAlgorithm.PS256, |
| 2048, |
| RSAKeyGenParameterSpec.F4, |
| KeyTemplate.OutputPrefixType.RAW)); |
| result.put( |
| "JWT_PS256_2048_F4", |
| createKeyFormat( |
| JwtRsaSsaPssAlgorithm.PS256, |
| 2048, |
| RSAKeyGenParameterSpec.F4, |
| KeyTemplate.OutputPrefixType.TINK)); |
| result.put( |
| "JWT_PS256_3072_F4_RAW", |
| createKeyFormat( |
| JwtRsaSsaPssAlgorithm.PS256, |
| 3072, |
| RSAKeyGenParameterSpec.F4, |
| KeyTemplate.OutputPrefixType.RAW)); |
| result.put( |
| "JWT_PS256_3072_F4", |
| createKeyFormat( |
| JwtRsaSsaPssAlgorithm.PS256, |
| 3072, |
| RSAKeyGenParameterSpec.F4, |
| KeyTemplate.OutputPrefixType.TINK)); |
| result.put( |
| "JWT_PS384_3072_F4_RAW", |
| createKeyFormat( |
| JwtRsaSsaPssAlgorithm.PS384, |
| 3072, |
| RSAKeyGenParameterSpec.F4, |
| KeyTemplate.OutputPrefixType.RAW)); |
| result.put( |
| "JWT_PS384_3072_F4", |
| createKeyFormat( |
| JwtRsaSsaPssAlgorithm.PS384, |
| 3072, |
| RSAKeyGenParameterSpec.F4, |
| KeyTemplate.OutputPrefixType.TINK)); |
| result.put( |
| "JWT_PS512_4096_F4_RAW", |
| createKeyFormat( |
| JwtRsaSsaPssAlgorithm.PS512, |
| 4096, |
| RSAKeyGenParameterSpec.F4, |
| KeyTemplate.OutputPrefixType.RAW)); |
| result.put( |
| "JWT_PS512_4096_F4", |
| createKeyFormat( |
| JwtRsaSsaPssAlgorithm.PS512, |
| 4096, |
| RSAKeyGenParameterSpec.F4, |
| KeyTemplate.OutputPrefixType.TINK)); |
| return Collections.unmodifiableMap(result); |
| } |
| }; |
| } |
| |
| /** |
| * Registers the {@link RsaSsaPssSignKeyManager} and the {@link RsaSsaPssVerifyKeyManager} with |
| * the registry, so that the the RsaSsaPss-Keys can be used with Tink. |
| */ |
| public static void registerPair(boolean newKeyAllowed) throws GeneralSecurityException { |
| Registry.registerAsymmetricKeyManagers( |
| new JwtRsaSsaPssSignKeyManager(), new JwtRsaSsaPssVerifyKeyManager(), newKeyAllowed); |
| } |
| |
| private static KeyFactory.KeyFormat<JwtRsaSsaPssKeyFormat> createKeyFormat( |
| JwtRsaSsaPssAlgorithm algorithm, |
| int modulusSize, |
| BigInteger publicExponent, |
| KeyTemplate.OutputPrefixType prefixType) { |
| JwtRsaSsaPssKeyFormat format = |
| JwtRsaSsaPssKeyFormat.newBuilder() |
| .setAlgorithm(algorithm) |
| .setModulusSizeInBits(modulusSize) |
| .setPublicExponent(ByteString.copyFrom(publicExponent.toByteArray())) |
| .build(); |
| return new KeyFactory.KeyFormat<>(format, prefixType); |
| } |
| } |