Migrate the RsaSsaPkcs1{Sign,Verify}KeyManager to KeyTypeManagers.

PiperOrigin-RevId: 268166210
diff --git a/java/src/main/java/com/google/crypto/tink/signature/RsaSsaPkcs1SignKeyManager.java b/java/src/main/java/com/google/crypto/tink/signature/RsaSsaPkcs1SignKeyManager.java
index ef5cfe6..ea724f1 100644
--- a/java/src/main/java/com/google/crypto/tink/signature/RsaSsaPkcs1SignKeyManager.java
+++ b/java/src/main/java/com/google/crypto/tink/signature/RsaSsaPkcs1SignKeyManager.java
@@ -16,10 +16,8 @@
 
 package com.google.crypto.tink.signature;
 
-import com.google.crypto.tink.KeyManagerBase;
-import com.google.crypto.tink.PrivateKeyManager;
+import com.google.crypto.tink.PrivateKeyTypeManager;
 import com.google.crypto.tink.PublicKeySign;
-import com.google.crypto.tink.proto.KeyData;
 import com.google.crypto.tink.proto.KeyData.KeyMaterialType;
 import com.google.crypto.tink.proto.RsaSsaPkcs1KeyFormat;
 import com.google.crypto.tink.proto.RsaSsaPkcs1Params;
@@ -48,150 +46,145 @@
  * {@code RsaSsaPkcs1SignJce}.
  */
 class RsaSsaPkcs1SignKeyManager
-    extends KeyManagerBase<PublicKeySign, RsaSsaPkcs1PrivateKey, RsaSsaPkcs1KeyFormat>
-    implements PrivateKeyManager<PublicKeySign> {
+    extends PrivateKeyTypeManager<RsaSsaPkcs1PrivateKey, RsaSsaPkcs1PublicKey> {
+  private static final byte[] TEST_MESSAGE =
+      "Tink and Wycheproof.".getBytes(Charset.forName("UTF-8"));
+
   public RsaSsaPkcs1SignKeyManager() {
-    super(PublicKeySign.class, RsaSsaPkcs1PrivateKey.class, RsaSsaPkcs1KeyFormat.class, TYPE_URL);
+    super(
+        RsaSsaPkcs1PrivateKey.class,
+        RsaSsaPkcs1PublicKey.class,
+        new PrimitiveFactory<PublicKeySign, RsaSsaPkcs1PrivateKey>(PublicKeySign.class) {
+          @Override
+          public PublicKeySign getPrimitive(RsaSsaPkcs1PrivateKey keyProto)
+              throws GeneralSecurityException {
+            java.security.KeyFactory kf = EngineFactory.KEY_FACTORY.getInstance("RSA");
+            RSAPrivateCrtKey privateKey =
+                (RSAPrivateCrtKey)
+                    kf.generatePrivate(
+                        new RSAPrivateCrtKeySpec(
+                            new BigInteger(1, keyProto.getPublicKey().getN().toByteArray()),
+                            new BigInteger(1, keyProto.getPublicKey().getE().toByteArray()),
+                            new BigInteger(1, keyProto.getD().toByteArray()),
+                            new BigInteger(1, keyProto.getP().toByteArray()),
+                            new BigInteger(1, keyProto.getQ().toByteArray()),
+                            new BigInteger(1, keyProto.getDp().toByteArray()),
+                            new BigInteger(1, keyProto.getDq().toByteArray()),
+                            new BigInteger(1, keyProto.getCrt().toByteArray())));
+            // Sign and verify a test message to make sure that the key is correct.
+            RsaSsaPkcs1SignJce signer =
+                new RsaSsaPkcs1SignJce(
+                    privateKey,
+                    SigUtil.toHashType(keyProto.getPublicKey().getParams().getHashType()));
+            RSAPublicKey publicKey =
+                (RSAPublicKey)
+                    kf.generatePublic(
+                        new RSAPublicKeySpec(
+                            new BigInteger(1, keyProto.getPublicKey().getN().toByteArray()),
+                            new BigInteger(1, keyProto.getPublicKey().getE().toByteArray())));
+            RsaSsaPkcs1VerifyJce verifier =
+                new RsaSsaPkcs1VerifyJce(
+                    publicKey,
+                    SigUtil.toHashType(keyProto.getPublicKey().getParams().getHashType()));
+            try {
+              verifier.verify(signer.sign(TEST_MESSAGE), TEST_MESSAGE);
+            } catch (GeneralSecurityException e) {
+              throw new RuntimeException(
+                  "Security bug: signing with private key followed by verifying with public key"
+                      + " failed"
+                      + e);
+            }
+            return signer;
+          }
+        });
   }
 
-  public static final String TYPE_URL =
-      "type.googleapis.com/google.crypto.tink.RsaSsaPkcs1PrivateKey";
-
-  private static final int VERSION = 0;
-
-  private static final Charset UTF_8 = Charset.forName("UTF-8");
-
-  /** Test message. */
-  private static final byte[] TEST_MESSAGE = "Tink and Wycheproof.".getBytes(UTF_8);
+  @Override
+  public String getKeyType() {
+    return "type.googleapis.com/google.crypto.tink.RsaSsaPkcs1PrivateKey";
+  }
 
   @Override
-  public PublicKeySign getPrimitiveFromKey(RsaSsaPkcs1PrivateKey keyProto)
+  public int getVersion() {
+    return 0;
+  }
+
+  @Override
+  public RsaSsaPkcs1PublicKey getPublicKey(RsaSsaPkcs1PrivateKey privKeyProto)
       throws GeneralSecurityException {
-    validateKey(keyProto);
-    KeyFactory kf = EngineFactory.KEY_FACTORY.getInstance("RSA");
-    RSAPrivateCrtKey privateKey =
-        (RSAPrivateCrtKey)
-            kf.generatePrivate(
-                new RSAPrivateCrtKeySpec(
-                    new BigInteger(1, keyProto.getPublicKey().getN().toByteArray()),
-                    new BigInteger(1, keyProto.getPublicKey().getE().toByteArray()),
-                    new BigInteger(1, keyProto.getD().toByteArray()),
-                    new BigInteger(1, keyProto.getP().toByteArray()),
-                    new BigInteger(1, keyProto.getQ().toByteArray()),
-                    new BigInteger(1, keyProto.getDp().toByteArray()),
-                    new BigInteger(1, keyProto.getDq().toByteArray()),
-                    new BigInteger(1, keyProto.getCrt().toByteArray())));
-    // Sign and verify a test message to make sure that the key is correct.
-    RsaSsaPkcs1SignJce signer =
-        new RsaSsaPkcs1SignJce(
-            privateKey, SigUtil.toHashType(keyProto.getPublicKey().getParams().getHashType()));
-    RSAPublicKey publicKey =
-        (RSAPublicKey)
-            kf.generatePublic(
-                new RSAPublicKeySpec(
-                    new BigInteger(1, keyProto.getPublicKey().getN().toByteArray()),
-                    new BigInteger(1, keyProto.getPublicKey().getE().toByteArray())));
-    RsaSsaPkcs1VerifyJce verifier =
-        new RsaSsaPkcs1VerifyJce(
-            publicKey, SigUtil.toHashType(keyProto.getPublicKey().getParams().getHashType()));
-    try {
-      verifier.verify(signer.sign(TEST_MESSAGE), TEST_MESSAGE);
-    } catch (GeneralSecurityException e) {
-      throw new RuntimeException(
-          "Security bug: signing with private key followed by verifying with public key failed"
-              + e);
-    }
-    return signer;
-  }
-
-  /**
-   * @param serializedKeyFormat serialized {@code RsaSsaPkcs1KeyFormat} proto
-   * @return new {@code RsaSsaPkcs1PrivateKey} proto
-   */
-  @Override
-  public RsaSsaPkcs1PrivateKey newKeyFromFormat(RsaSsaPkcs1KeyFormat format)
-      throws GeneralSecurityException {
-    validateKeyFormat(format);
-    RsaSsaPkcs1Params params = format.getParams();
-    KeyPairGenerator keyGen = EngineFactory.KEY_PAIR_GENERATOR.getInstance("RSA");
-    RSAKeyGenParameterSpec spec =
-        new RSAKeyGenParameterSpec(
-            format.getModulusSizeInBits(),
-            new BigInteger(1, format.getPublicExponent().toByteArray()));
-    keyGen.initialize(spec);
-    KeyPair keyPair = keyGen.generateKeyPair();
-    RSAPublicKey pubKey = (RSAPublicKey) keyPair.getPublic();
-    RSAPrivateCrtKey privKey = (RSAPrivateCrtKey) keyPair.getPrivate();
-
-    // Creates RsaSsaPkcs1PublicKey.
-    RsaSsaPkcs1PublicKey pkcs1PubKey =
-        RsaSsaPkcs1PublicKey.newBuilder()
-            .setVersion(VERSION)
-            .setParams(params)
-            .setE(ByteString.copyFrom(pubKey.getPublicExponent().toByteArray()))
-            .setN(ByteString.copyFrom(pubKey.getModulus().toByteArray()))
-            .build();
-
-    // Creates RsaSsaPkcs1PrivateKey.
-    return RsaSsaPkcs1PrivateKey.newBuilder()
-        .setVersion(VERSION)
-        .setPublicKey(pkcs1PubKey)
-        .setD(ByteString.copyFrom(privKey.getPrivateExponent().toByteArray()))
-        .setP(ByteString.copyFrom(privKey.getPrimeP().toByteArray()))
-        .setQ(ByteString.copyFrom(privKey.getPrimeQ().toByteArray()))
-        .setDp(ByteString.copyFrom(privKey.getPrimeExponentP().toByteArray()))
-        .setDq(ByteString.copyFrom(privKey.getPrimeExponentQ().toByteArray()))
-        .setCrt(ByteString.copyFrom(privKey.getCrtCoefficient().toByteArray()))
-        .build();
+    return privKeyProto.getPublicKey();
   }
 
   @Override
-  protected KeyMaterialType keyMaterialType() {
+  public KeyMaterialType keyMaterialType() {
     return KeyMaterialType.ASYMMETRIC_PRIVATE;
   }
 
   @Override
-  protected RsaSsaPkcs1PrivateKey parseKeyProto(ByteString byteString)
+  public RsaSsaPkcs1PrivateKey parseKey(ByteString byteString)
       throws InvalidProtocolBufferException {
     return RsaSsaPkcs1PrivateKey.parseFrom(byteString);
   }
 
   @Override
-  protected RsaSsaPkcs1KeyFormat parseKeyFormatProto(ByteString byteString)
-      throws InvalidProtocolBufferException {
-    return RsaSsaPkcs1KeyFormat.parseFrom(byteString);
-  }
-
-  @Override
-  public KeyData getPublicKeyData(ByteString serializedKey) throws GeneralSecurityException {
-    try {
-      RsaSsaPkcs1PrivateKey privKeyProto = RsaSsaPkcs1PrivateKey.parseFrom(serializedKey);
-      return KeyData.newBuilder()
-          .setTypeUrl(RsaSsaPkcs1VerifyKeyManager.TYPE_URL)
-          .setValue(privKeyProto.getPublicKey().toByteString())
-          .setKeyMaterialType(KeyData.KeyMaterialType.ASYMMETRIC_PUBLIC)
-          .build();
-    } catch (InvalidProtocolBufferException e) {
-      throw new GeneralSecurityException("expected serialized RsaSsaPkcs1PrivateKey proto", e);
-    }
-  }
-
-  @Override
-  public int getVersion() {
-    return VERSION;
-  }
-
-  @Override
-  protected void validateKeyFormat(RsaSsaPkcs1KeyFormat keyFormat) throws GeneralSecurityException {
-    SigUtil.validateRsaSsaPkcs1Params(keyFormat.getParams());
-    Validators.validateRsaModulusSize(keyFormat.getModulusSizeInBits());
-  }
-
-  @Override
-  protected void validateKey(RsaSsaPkcs1PrivateKey privKey) throws GeneralSecurityException {
-    Validators.validateVersion(privKey.getVersion(), VERSION);
+  public void validateKey(RsaSsaPkcs1PrivateKey privKey) throws GeneralSecurityException {
+    Validators.validateVersion(privKey.getVersion(), getVersion());
     Validators.validateRsaModulusSize(
-        (new BigInteger(1, privKey.getPublicKey().getN().toByteArray())).bitLength());
+        new BigInteger(1, privKey.getPublicKey().getN().toByteArray()).bitLength());
     SigUtil.validateRsaSsaPkcs1Params(privKey.getPublicKey().getParams());
   }
+
+  @Override
+  public KeyFactory<RsaSsaPkcs1KeyFormat, RsaSsaPkcs1PrivateKey> keyFactory() {
+    return new KeyFactory<RsaSsaPkcs1KeyFormat, RsaSsaPkcs1PrivateKey>(RsaSsaPkcs1KeyFormat.class) {
+      @Override
+      public void validateKeyFormat(RsaSsaPkcs1KeyFormat keyFormat)
+          throws GeneralSecurityException {
+        SigUtil.validateRsaSsaPkcs1Params(keyFormat.getParams());
+        Validators.validateRsaModulusSize(keyFormat.getModulusSizeInBits());
+      }
+
+      @Override
+      public RsaSsaPkcs1KeyFormat parseKeyFormat(ByteString byteString)
+          throws InvalidProtocolBufferException {
+        return RsaSsaPkcs1KeyFormat.parseFrom(byteString);
+      }
+
+      @Override
+      public RsaSsaPkcs1PrivateKey createKey(RsaSsaPkcs1KeyFormat format)
+          throws GeneralSecurityException {
+        RsaSsaPkcs1Params params = format.getParams();
+        KeyPairGenerator keyGen = EngineFactory.KEY_PAIR_GENERATOR.getInstance("RSA");
+        RSAKeyGenParameterSpec spec =
+            new RSAKeyGenParameterSpec(
+                format.getModulusSizeInBits(),
+                new BigInteger(1, format.getPublicExponent().toByteArray()));
+        keyGen.initialize(spec);
+        KeyPair keyPair = keyGen.generateKeyPair();
+        RSAPublicKey pubKey = (RSAPublicKey) keyPair.getPublic();
+        RSAPrivateCrtKey privKey = (RSAPrivateCrtKey) keyPair.getPrivate();
+
+        // Creates RsaSsaPkcs1PublicKey.
+        RsaSsaPkcs1PublicKey pkcs1PubKey =
+            RsaSsaPkcs1PublicKey.newBuilder()
+                .setVersion(getVersion())
+                .setParams(params)
+                .setE(ByteString.copyFrom(pubKey.getPublicExponent().toByteArray()))
+                .setN(ByteString.copyFrom(pubKey.getModulus().toByteArray()))
+                .build();
+
+        // Creates RsaSsaPkcs1PrivateKey.
+        return RsaSsaPkcs1PrivateKey.newBuilder()
+            .setVersion(getVersion())
+            .setPublicKey(pkcs1PubKey)
+            .setD(ByteString.copyFrom(privKey.getPrivateExponent().toByteArray()))
+            .setP(ByteString.copyFrom(privKey.getPrimeP().toByteArray()))
+            .setQ(ByteString.copyFrom(privKey.getPrimeQ().toByteArray()))
+            .setDp(ByteString.copyFrom(privKey.getPrimeExponentP().toByteArray()))
+            .setDq(ByteString.copyFrom(privKey.getPrimeExponentQ().toByteArray()))
+            .setCrt(ByteString.copyFrom(privKey.getCrtCoefficient().toByteArray()))
+            .build();
+      }
+    };
+  }
 }
diff --git a/java/src/main/java/com/google/crypto/tink/signature/RsaSsaPkcs1VerifyKeyManager.java b/java/src/main/java/com/google/crypto/tink/signature/RsaSsaPkcs1VerifyKeyManager.java
index 31c68f1..b320ac0 100644
--- a/java/src/main/java/com/google/crypto/tink/signature/RsaSsaPkcs1VerifyKeyManager.java
+++ b/java/src/main/java/com/google/crypto/tink/signature/RsaSsaPkcs1VerifyKeyManager.java
@@ -16,9 +16,8 @@
 
 package com.google.crypto.tink.signature;
 
-import com.google.crypto.tink.KeyManagerBase;
+import com.google.crypto.tink.KeyTypeManager;
 import com.google.crypto.tink.PublicKeyVerify;
-import com.google.crypto.tink.proto.Empty;
 import com.google.crypto.tink.proto.KeyData.KeyMaterialType;
 import com.google.crypto.tink.proto.RsaSsaPkcs1PublicKey;
 import com.google.crypto.tink.subtle.EngineFactory;
@@ -28,7 +27,6 @@
 import com.google.protobuf.InvalidProtocolBufferException;
 import java.math.BigInteger;
 import java.security.GeneralSecurityException;
-import java.security.KeyFactory;
 import java.security.interfaces.RSAPublicKey;
 import java.security.spec.RSAPublicKeySpec;
 
@@ -36,64 +34,50 @@
  * This key manager produces new instances of {@code RsaSsaPkcs1VerifyJce}. It doesn't support key
  * generation.
  */
-class RsaSsaPkcs1VerifyKeyManager
-    extends KeyManagerBase<PublicKeyVerify, RsaSsaPkcs1PublicKey, Empty> {
+class RsaSsaPkcs1VerifyKeyManager extends KeyTypeManager<RsaSsaPkcs1PublicKey> {
   public RsaSsaPkcs1VerifyKeyManager() {
-    super(PublicKeyVerify.class, RsaSsaPkcs1PublicKey.class, Empty.class, TYPE_URL);
-  }
-
-  protected static final int VERSION = 0;
-  public static final String TYPE_URL =
-      "type.googleapis.com/google.crypto.tink.RsaSsaPkcs1PublicKey";
-
-  @Override
-  public PublicKeyVerify getPrimitiveFromKey(RsaSsaPkcs1PublicKey keyProto)
-      throws GeneralSecurityException {
-    validateKey(keyProto);
-    KeyFactory kf = EngineFactory.KEY_FACTORY.getInstance("RSA");
-    BigInteger modulus = new BigInteger(1, keyProto.getN().toByteArray());
-    BigInteger exponent = new BigInteger(1, keyProto.getE().toByteArray());
-    RSAPublicKey publicKey =
-        (RSAPublicKey) kf.generatePublic(new RSAPublicKeySpec(modulus, exponent));
-    return new RsaSsaPkcs1VerifyJce(
-        publicKey, SigUtil.toHashType(keyProto.getParams().getHashType()));
+    super(
+        RsaSsaPkcs1PublicKey.class,
+        new PrimitiveFactory<PublicKeyVerify, RsaSsaPkcs1PublicKey>(PublicKeyVerify.class) {
+          @Override
+          public PublicKeyVerify getPrimitive(RsaSsaPkcs1PublicKey keyProto)
+              throws GeneralSecurityException {
+            java.security.KeyFactory kf = EngineFactory.KEY_FACTORY.getInstance("RSA");
+            BigInteger modulus = new BigInteger(1, keyProto.getN().toByteArray());
+            BigInteger exponent = new BigInteger(1, keyProto.getE().toByteArray());
+            RSAPublicKey publicKey =
+                (RSAPublicKey) kf.generatePublic(new RSAPublicKeySpec(modulus, exponent));
+            return new RsaSsaPkcs1VerifyJce(
+                publicKey, SigUtil.toHashType(keyProto.getParams().getHashType()));
+          }
+        });
   }
 
   @Override
-  public RsaSsaPkcs1PublicKey newKeyFromFormat(Empty serializedKeyFormat)
-      throws GeneralSecurityException {
-    throw new GeneralSecurityException("Not implemented");
+  public String getKeyType() {
+    return "type.googleapis.com/google.crypto.tink.RsaSsaPkcs1PublicKey";
   }
 
   @Override
   public int getVersion() {
-    return VERSION;
+    return 0;
   }
 
   @Override
-  protected KeyMaterialType keyMaterialType() {
+  public KeyMaterialType keyMaterialType() {
     return KeyMaterialType.ASYMMETRIC_PUBLIC;
   }
 
   @Override
-  protected RsaSsaPkcs1PublicKey parseKeyProto(ByteString byteString)
+  public RsaSsaPkcs1PublicKey parseKey(ByteString byteString)
       throws InvalidProtocolBufferException {
     return RsaSsaPkcs1PublicKey.parseFrom(byteString);
   }
 
   @Override
-  protected Empty parseKeyFormatProto(ByteString byteString)
-      throws InvalidProtocolBufferException {
-    return Empty.parseFrom(byteString);
-  }
-
-  @Override
-  protected void validateKey(RsaSsaPkcs1PublicKey pubKey) throws GeneralSecurityException {
-    Validators.validateVersion(pubKey.getVersion(), VERSION);
+  public void validateKey(RsaSsaPkcs1PublicKey pubKey) throws GeneralSecurityException {
+    Validators.validateVersion(pubKey.getVersion(), getVersion());
     Validators.validateRsaModulusSize((new BigInteger(1, pubKey.getN().toByteArray())).bitLength());
     SigUtil.validateRsaSsaPkcs1Params(pubKey.getParams());
   }
-
-  @Override
-  protected void validateKeyFormat(Empty unused) {}
 }
diff --git a/java/src/main/java/com/google/crypto/tink/signature/SignatureConfig.java b/java/src/main/java/com/google/crypto/tink/signature/SignatureConfig.java
index 5bd5eb2..7b3e027 100644
--- a/java/src/main/java/com/google/crypto/tink/signature/SignatureConfig.java
+++ b/java/src/main/java/com/google/crypto/tink/signature/SignatureConfig.java
@@ -92,8 +92,8 @@
     Registry.registerAsymmetricKeyManagers(
         new Ed25519PrivateKeyManager(), new Ed25519PublicKeyManager(), true);
 
-    Registry.registerKeyManager(new RsaSsaPkcs1SignKeyManager());
-    Registry.registerKeyManager(new RsaSsaPkcs1VerifyKeyManager());
+    Registry.registerAsymmetricKeyManagers(
+        new RsaSsaPkcs1SignKeyManager(), new RsaSsaPkcs1VerifyKeyManager(), true);
 
     Registry.registerAsymmetricKeyManagers(
         new RsaSsaPssSignKeyManager(), new RsaSsaPssVerifyKeyManager(), true);
diff --git a/java/src/main/java/com/google/crypto/tink/signature/SignatureKeyTemplates.java b/java/src/main/java/com/google/crypto/tink/signature/SignatureKeyTemplates.java
index c1a48b3..a634bc9 100644
--- a/java/src/main/java/com/google/crypto/tink/signature/SignatureKeyTemplates.java
+++ b/java/src/main/java/com/google/crypto/tink/signature/SignatureKeyTemplates.java
@@ -206,7 +206,7 @@
             .build();
     return KeyTemplate.newBuilder()
         .setValue(format.toByteString())
-        .setTypeUrl(RsaSsaPkcs1SignKeyManager.TYPE_URL)
+        .setTypeUrl(new RsaSsaPkcs1SignKeyManager().getKeyType())
         .setOutputPrefixType(OutputPrefixType.TINK)
         .build();
   }
diff --git a/java/src/main/java/com/google/crypto/tink/signature/SignaturePemKeysetReader.java b/java/src/main/java/com/google/crypto/tink/signature/SignaturePemKeysetReader.java
index bcc6b3c..fb6a4fd 100644
--- a/java/src/main/java/com/google/crypto/tink/signature/SignaturePemKeysetReader.java
+++ b/java/src/main/java/com/google/crypto/tink/signature/SignaturePemKeysetReader.java
@@ -158,13 +158,13 @@
           RsaSsaPkcs1Params.newBuilder().setHashType(getHashType(pemKeyType)).build();
       RsaSsaPkcs1PublicKey pkcs1PubKey =
           RsaSsaPkcs1PublicKey.newBuilder()
-              .setVersion(RsaSsaPkcs1VerifyKeyManager.VERSION)
+              .setVersion(new RsaSsaPkcs1VerifyKeyManager().getVersion())
               .setParams(params)
               .setE(ByteString.copyFrom(key.getPublicExponent().toByteArray()))
               .setN(ByteString.copyFrom(key.getModulus().toByteArray()))
               .build();
       return KeyData.newBuilder()
-          .setTypeUrl(RsaSsaPkcs1VerifyKeyManager.TYPE_URL)
+          .setTypeUrl(new RsaSsaPkcs1VerifyKeyManager().getKeyType())
           .setValue(pkcs1PubKey.toByteString())
           .setKeyMaterialType(KeyData.KeyMaterialType.ASYMMETRIC_PUBLIC)
           .build();
diff --git a/java/src/test/java/com/google/crypto/tink/signature/RsaSsaPkcs1SignKeyManagerTest.java b/java/src/test/java/com/google/crypto/tink/signature/RsaSsaPkcs1SignKeyManagerTest.java
index 7371320..e7fdf81 100644
--- a/java/src/test/java/com/google/crypto/tink/signature/RsaSsaPkcs1SignKeyManagerTest.java
+++ b/java/src/test/java/com/google/crypto/tink/signature/RsaSsaPkcs1SignKeyManagerTest.java
@@ -20,7 +20,11 @@
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.fail;
 
+import com.google.crypto.tink.KeyManager;
+import com.google.crypto.tink.KeyManagerImpl;
 import com.google.crypto.tink.KeysetHandle;
+import com.google.crypto.tink.PrivateKeyManager;
+import com.google.crypto.tink.PrivateKeyManagerImpl;
 import com.google.crypto.tink.PublicKeySign;
 import com.google.crypto.tink.PublicKeyVerify;
 import com.google.crypto.tink.TestUtil;
@@ -67,7 +71,11 @@
     // Call newKey multiple times and make sure that it generates different keys.
     int numTests = 3;
     RsaSsaPkcs1PrivateKey[] privKeys = new RsaSsaPkcs1PrivateKey[numTests];
-    RsaSsaPkcs1SignKeyManager signManager = new RsaSsaPkcs1SignKeyManager();
+    PrivateKeyManager<PublicKeySign> signManager =
+        new PrivateKeyManagerImpl<>(
+            new RsaSsaPkcs1SignKeyManager(),
+            new RsaSsaPkcs1VerifyKeyManager(),
+            PublicKeySign.class);
     Set<String> keys = new TreeSet<String>();
 
     privKeys[0] =
@@ -89,7 +97,8 @@
     }
 
     // Test whether signer works correctly with the corresponding verifier.
-    RsaSsaPkcs1VerifyKeyManager verifyManager = new RsaSsaPkcs1VerifyKeyManager();
+    KeyManager<PublicKeyVerify> verifyManager =
+        new KeyManagerImpl<>(new RsaSsaPkcs1VerifyKeyManager(), PublicKeyVerify.class);
     for (int j = 0; j < numTests; j++) {
       PublicKeySign signer = signManager.getPrimitive(privKeys[j]);
       byte[] signature = signer.sign(msg);
@@ -128,10 +137,15 @@
     ByteString serialized = ByteString.copyFrom(new byte[128]);
     KeyTemplate keyTemplate =
         KeyTemplate.newBuilder()
-            .setTypeUrl(RsaSsaPkcs1SignKeyManager.TYPE_URL)
+            .setTypeUrl(new RsaSsaPkcs1SignKeyManager().getKeyType())
             .setValue(serialized)
             .build();
-    RsaSsaPkcs1SignKeyManager keyManager = new RsaSsaPkcs1SignKeyManager();
+
+    PrivateKeyManager<PublicKeySign> keyManager =
+        new PrivateKeyManagerImpl<>(
+            new RsaSsaPkcs1SignKeyManager(),
+            new RsaSsaPkcs1VerifyKeyManager(),
+            PublicKeySign.class);
     try {
       keyManager.newKey(serialized);
       fail("Corrupted format, should have thrown exception");
@@ -156,14 +170,19 @@
     KeysetHandle privateHandle =
         KeysetHandle.generateNew(SignatureKeyTemplates.RSA_SSA_PKCS1_3072_SHA256_F4);
     KeyData privateKeyData = TestUtil.getKeyset(privateHandle).getKey(0).getKeyData();
-    RsaSsaPkcs1SignKeyManager privateManager = new RsaSsaPkcs1SignKeyManager();
+    PrivateKeyManager<PublicKeySign> privateManager =
+        new PrivateKeyManagerImpl<>(
+            new RsaSsaPkcs1SignKeyManager(),
+            new RsaSsaPkcs1VerifyKeyManager(),
+            PublicKeySign.class);
     KeyData publicKeyData = privateManager.getPublicKeyData(privateKeyData.getValue());
-    assertEquals(RsaSsaPkcs1VerifyKeyManager.TYPE_URL, publicKeyData.getTypeUrl());
+    assertEquals(new RsaSsaPkcs1VerifyKeyManager().getKeyType(), publicKeyData.getTypeUrl());
     assertEquals(KeyData.KeyMaterialType.ASYMMETRIC_PUBLIC, publicKeyData.getKeyMaterialType());
     RsaSsaPkcs1PrivateKey privateKey = RsaSsaPkcs1PrivateKey.parseFrom(privateKeyData.getValue());
     assertArrayEquals(
         privateKey.getPublicKey().toByteArray(), publicKeyData.getValue().toByteArray());
-    RsaSsaPkcs1VerifyKeyManager publicManager = new RsaSsaPkcs1VerifyKeyManager();
+    KeyManager<PublicKeyVerify> publicManager =
+        new KeyManagerImpl<>(new RsaSsaPkcs1VerifyKeyManager(), PublicKeyVerify.class);
     PublicKeySign signer = privateManager.getPrimitive(privateKeyData.getValue());
     PublicKeyVerify verifier = publicManager.getPrimitive(publicKeyData.getValue());
     byte[] message = Random.randBytes(20);
diff --git a/java/src/test/java/com/google/crypto/tink/signature/RsaSsaPkcs1VerifyKeyManagerTest.java b/java/src/test/java/com/google/crypto/tink/signature/RsaSsaPkcs1VerifyKeyManagerTest.java
index d62a656..204f58e 100644
--- a/java/src/test/java/com/google/crypto/tink/signature/RsaSsaPkcs1VerifyKeyManagerTest.java
+++ b/java/src/test/java/com/google/crypto/tink/signature/RsaSsaPkcs1VerifyKeyManagerTest.java
@@ -18,6 +18,8 @@
 import static com.google.crypto.tink.TestUtil.assertExceptionContains;
 import static org.junit.Assert.fail;
 
+import com.google.crypto.tink.KeyManager;
+import com.google.crypto.tink.KeyManagerImpl;
 import com.google.crypto.tink.PublicKeyVerify;
 import com.google.crypto.tink.TestUtil;
 import com.google.crypto.tink.TestUtil.BytesMutation;
@@ -76,7 +78,8 @@
     for (NistTestVector t : nistTestVectors) {
       RsaSsaPkcs1PublicKey pubKey =
           TestUtil.createRsaSsaPkcs1PubKey(t.modulus, t.exponent, t.hashType);
-      RsaSsaPkcs1VerifyKeyManager keyManager = new RsaSsaPkcs1VerifyKeyManager();
+      KeyManager<PublicKeyVerify> keyManager =
+          new KeyManagerImpl<>(new RsaSsaPkcs1VerifyKeyManager(), PublicKeyVerify.class);
       PublicKeyVerify verifier = keyManager.getPrimitive(pubKey);
       try {
         verifier.verify(t.sig, t.msg);
@@ -104,7 +107,8 @@
       RsaSsaPkcs1PublicKey pubKey =
           TestUtil.createRsaSsaPkcs1PubKey(
               TestUtil.hexDecode("23"), TestUtil.hexDecode("03"), HashType.SHA256);
-      RsaSsaPkcs1VerifyKeyManager keyManager = new RsaSsaPkcs1VerifyKeyManager();
+      KeyManager<PublicKeyVerify> keyManager =
+          new KeyManagerImpl<>(new RsaSsaPkcs1VerifyKeyManager(), PublicKeyVerify.class);
       keyManager.getPrimitive(pubKey);
       fail("Invalid modulus, should have thrown exception");
     } catch (GeneralSecurityException e) {
@@ -118,7 +122,8 @@
     try {
       RsaSsaPkcs1PublicKey pubKey =
           TestUtil.createRsaSsaPkcs1PubKey(t.modulus, t.exponent, HashType.SHA1);
-      RsaSsaPkcs1VerifyKeyManager keyManager = new RsaSsaPkcs1VerifyKeyManager();
+      KeyManager<PublicKeyVerify> keyManager =
+          new KeyManagerImpl<>(new RsaSsaPkcs1VerifyKeyManager(), PublicKeyVerify.class);
       keyManager.getPrimitive(pubKey);
       fail("Invalid hash, should have thrown exception");
     } catch (GeneralSecurityException e) {
diff --git a/java/src/test/java/com/google/crypto/tink/signature/SignatureKeyTemplatesTest.java b/java/src/test/java/com/google/crypto/tink/signature/SignatureKeyTemplatesTest.java
index c695b36..06b9f1a 100644
--- a/java/src/test/java/com/google/crypto/tink/signature/SignatureKeyTemplatesTest.java
+++ b/java/src/test/java/com/google/crypto/tink/signature/SignatureKeyTemplatesTest.java
@@ -128,7 +128,7 @@
   @Test
   public void testRSA_SSA_PKCS1_3072_SHA256_F4() throws Exception {
     KeyTemplate template = SignatureKeyTemplates.RSA_SSA_PKCS1_3072_SHA256_F4;
-    assertEquals(RsaSsaPkcs1SignKeyManager.TYPE_URL, template.getTypeUrl());
+    assertEquals(new RsaSsaPkcs1SignKeyManager().getKeyType(), template.getTypeUrl());
     assertEquals(OutputPrefixType.TINK, template.getOutputPrefixType());
     RsaSsaPkcs1KeyFormat format = RsaSsaPkcs1KeyFormat.parseFrom(template.getValue());
 
@@ -142,7 +142,7 @@
   @Test
   public void testRSA_SSA_PKCS1_4096_SHA512_F4() throws Exception {
     KeyTemplate template = SignatureKeyTemplates.RSA_SSA_PKCS1_4096_SHA512_F4;
-    assertEquals(RsaSsaPkcs1SignKeyManager.TYPE_URL, template.getTypeUrl());
+    assertEquals(new RsaSsaPkcs1SignKeyManager().getKeyType(), template.getTypeUrl());
     assertEquals(OutputPrefixType.TINK, template.getOutputPrefixType());
     RsaSsaPkcs1KeyFormat format = RsaSsaPkcs1KeyFormat.parseFrom(template.getValue());