blob: 99e8b1942b13cfdf0daf006f9be3f17c491f9915 [file] [log] [blame]
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
////////////////////////////////////////////////////////////////////////////////
package com.google.crypto.tink.jwt;
import static com.google.crypto.tink.internal.Util.UTF_8;
import com.google.crypto.tink.proto.OutputPrefixType;
import com.google.crypto.tink.subtle.Base64;
import com.google.gson.JsonObject;
import java.nio.ByteBuffer;
import java.nio.charset.CharacterCodingException;
import java.nio.charset.CharsetDecoder;
import java.security.GeneralSecurityException;
import java.security.InvalidAlgorithmParameterException;
import java.util.Optional;
/**
 * Static helpers for building, splitting, decoding and validating JWTs in JWS compact
 * serialization ({@code header.payload.signature}, each segment base64url-encoded).
 *
 * <p>This class has only static members and cannot be instantiated.
 */
final class JwtFormat {

  /**
   * The pieces of a signed compact token, as produced by {@link #splitSignedCompact}.
   *
   * <ul>
   *   <li>{@code unsignedCompact}: everything before the last {@code '.'}, i.e. the still-encoded
   *       {@code header.payload} prefix.
   *   <li>{@code signatureOrMac}: the base64url-decoded last segment.
   *   <li>{@code header}: the decoded header segment as a UTF-8 string.
   *   <li>{@code payload}: the decoded payload segment as a UTF-8 string.
   * </ul>
   */
  static class Parts {
    String unsignedCompact;
    byte[] signatureOrMac;
    String header;
    String payload;

    Parts(String unsignedCompact, byte[] signatureOrMac, String header, String payload) {
      this.unsignedCompact = unsignedCompact;
      this.signatureOrMac = signatureOrMac;
      this.header = header;
      this.payload = payload;
    }
  }

  private JwtFormat() {}

  /**
   * Returns whether {@code c} belongs to the base64url alphabet {@code [A-Za-z0-9-_]}.
   *
   * <p>The padding character {@code '='} is deliberately not accepted.
   */
  static boolean isValidUrlsafeBase64Char(char c) {
    return (((c >= 'a') && (c <= 'z'))
        || ((c >= 'A') && (c <= 'Z'))
        || ((c >= '0') && (c <= '9'))
        || ((c == '-') || (c == '_')));
  }

  /**
   * Checks that {@code data} is well-formed UTF-8.
   *
   * <p>This check is needed because {@code new String(data, UTF_8)} does not fail on malformed
   * input: it silently substitutes the replacement character (U+FFFD) instead.
   *
   * @throws JwtInvalidException if {@code data} is not valid UTF-8
   */
  static void validateUtf8(byte[] data) throws JwtInvalidException {
    // A freshly created decoder reports malformed input by throwing, rather than replacing it.
    CharsetDecoder decoder = UTF_8.newDecoder();
    try {
      decoder.decode(ByteBuffer.wrap(data));
    } catch (CharacterCodingException ex) {
      throw new JwtInvalidException(ex.getMessage());
    }
  }

  /**
   * Decodes unpadded base64url data, rejecting any character outside the base64url alphabet.
   *
   * <p>Each character is checked up front with {@link #isValidUrlsafeBase64Char}, so whitespace,
   * {@code '='} padding and other stray characters all cause an error.
   *
   * @throws JwtInvalidException if {@code encodedData} contains an invalid character or cannot
   *     be decoded
   */
  static byte[] strictUrlSafeDecode(String encodedData) throws JwtInvalidException {
    for (int i = 0; i < encodedData.length(); i++) {
      char c = encodedData.charAt(i);
      if (!isValidUrlsafeBase64Char(c)) {
        throw new JwtInvalidException("invalid encoding");
      }
    }
    try {
      return Base64.urlSafeDecode(encodedData);
    } catch (IllegalArgumentException ex) {
      throw new JwtInvalidException("invalid encoding: " + ex);
    }
  }

  /**
   * Accepts only the HMAC, ECDSA, RSA-PKCS1 and RSA-PSS algorithm names listed below.
   *
   * @throws InvalidAlgorithmParameterException for any other algorithm name
   */
  private static void validateAlgorithm(String algo) throws InvalidAlgorithmParameterException {
    switch (algo) {
      case "HS256":
      case "HS384":
      case "HS512":
      case "ES256":
      case "ES384":
      case "ES512":
      case "RS256":
      case "RS384":
      case "RS512":
      case "PS256":
      case "PS384":
      case "PS512":
        return;
      default:
        throw new InvalidAlgorithmParameterException("invalid algorithm: " + algo);
    }
  }

  /**
   * Builds the base64url-encoded JSON header of a token.
   *
   * <p>The header contains the key ID (only if {@code kid} is present), the algorithm, and the
   * type header (only if {@code typeHeader} is present).
   *
   * @throws InvalidAlgorithmParameterException if {@code algorithm} is not supported
   */
  static String createHeader(String algorithm, Optional<String> typeHeader, Optional<String> kid)
      throws InvalidAlgorithmParameterException {
    validateAlgorithm(algorithm);
    JsonObject header = new JsonObject();
    if (kid.isPresent()) {
      header.addProperty(JwtNames.HEADER_KEY_ID, kid.get());
    }
    header.addProperty(JwtNames.HEADER_ALGORITHM, algorithm);
    if (typeHeader.isPresent()) {
      header.addProperty(JwtNames.HEADER_TYPE, typeHeader.get());
    }
    return Base64.urlSafeEncode(header.toString().getBytes(UTF_8));
  }

  /**
   * Checks that the header's key-ID field is a string equal to {@code expectedKid}.
   *
   * @throws JwtInvalidException if the field is missing, is not a string, or does not match
   */
  private static void validateKidInHeader(String expectedKid, JsonObject parsedHeader)
      throws JwtInvalidException {
    String kid = getStringHeader(parsedHeader, JwtNames.HEADER_KEY_ID);
    if (!kid.equals(expectedKid)) {
      throw new JwtInvalidException("invalid kid in header");
    }
  }

  /**
   * Validates a parsed header against the algorithm and (optional) kid of a key.
   *
   * <p>The checks are:
   * <ul>
   *   <li>the header algorithm must equal {@code algorithmFromKey};
   *   <li>any header with a critical ("crit") field is rejected;
   *   <li>a missing kid is accepted only when {@code allowKidAbsent} is true;
   *   <li>if the header has a kid and {@code kidFromKey} is present, they must be equal.
   * </ul>
   *
   * <p>All other header fields are ignored.
   *
   * @param parsedHeader the decoded JSON header
   * @param algorithmFromKey the algorithm the verifying key uses
   * @param kidFromKey the key's kid, if it has one
   * @param allowKidAbsent whether a header without a kid is acceptable
   * @throws InvalidAlgorithmParameterException if the header algorithm does not match
   * @throws JwtInvalidException if any other check fails
   */
  static void validateHeader(
      JsonObject parsedHeader,
      String algorithmFromKey,
      Optional<String> kidFromKey,
      boolean allowKidAbsent)
      throws GeneralSecurityException {
    String receivedAlgorithm = JwtFormat.getStringHeader(parsedHeader, JwtNames.HEADER_ALGORITHM);
    if (!receivedAlgorithm.equals(algorithmFromKey)) {
      throw new InvalidAlgorithmParameterException(
          String.format(
              "invalid algorithm; expected %s, got %s", algorithmFromKey, receivedAlgorithm));
    }
    if (parsedHeader.has(JwtNames.HEADER_CRITICAL)) {
      throw new JwtInvalidException("all tokens with crit headers are rejected");
    }
    if (!parsedHeader.has(JwtNames.HEADER_KEY_ID)) {
      if (allowKidAbsent) {
        return;
      }
      throw new JwtInvalidException("missing kid in header");
    }
    // The header has a kid from here on. When the key has none (which implies
    // KidStrategy = IGNORED), any kid in the header is accepted.
    if (!kidFromKey.isPresent()) {
      return;
    }
    String kid = JwtFormat.getStringHeader(parsedHeader, JwtNames.HEADER_KEY_ID);
    if (!kid.equals(kidFromKey.get())) {
      throw new JwtInvalidException("invalid kid in header");
    }
  }

  /**
   * Validates the parsed header.
   *
   * <p>{@code tinkKid} should only be set for keys with output prefix type TINK, and
   * {@code customKid} only for keys with output prefix type RAW. They must not both be set.
   *
   * <p>For a TINK key the header must carry a kid equal to {@code tinkKid}. For a RAW key with a
   * custom kid, the header kid is optional but must equal {@code customKid} when present. Headers
   * with a critical ("crit") field are always rejected. All other header fields are ignored.
   *
   * @throws InvalidAlgorithmParameterException if {@code expectedAlgorithm} is unsupported or the
   *     header algorithm differs from it
   * @throws JwtInvalidException if any other check fails
   */
  static void validateHeader(
      String expectedAlgorithm,
      Optional<String> tinkKid,
      Optional<String> customKid,
      JsonObject parsedHeader)
      throws InvalidAlgorithmParameterException, JwtInvalidException {
    validateAlgorithm(expectedAlgorithm);
    String algorithm = getStringHeader(parsedHeader, JwtNames.HEADER_ALGORITHM);
    if (!algorithm.equals(expectedAlgorithm)) {
      throw new InvalidAlgorithmParameterException(
          String.format(
              "invalid algorithm; expected %s, got %s", expectedAlgorithm, algorithm));
    }
    if (parsedHeader.has(JwtNames.HEADER_CRITICAL)) {
      throw new JwtInvalidException("all tokens with crit headers are rejected");
    }
    if (tinkKid.isPresent() && customKid.isPresent()) {
      throw new JwtInvalidException("custom_kid can only be set for RAW keys.");
    }
    boolean headerHasKid = parsedHeader.has(JwtNames.HEADER_KEY_ID);
    if (tinkKid.isPresent()) {
      if (!headerHasKid) {
        // For output prefix type TINK, the kid header is required.
        throw new JwtInvalidException("missing kid in header");
      }
      validateKidInHeader(tinkKid.get(), parsedHeader);
    }
    if (customKid.isPresent() && headerHasKid) {
      // For output prefix type RAW, the kid header is not required, even if a custom kid is set.
      validateKidInHeader(customKid.get(), parsedHeader);
    }
    // All other headers are ignored.
  }

  /**
   * Returns the type header if present.
   *
   * @throws JwtInvalidException if the type header is present but is not a string
   */
  static Optional<String> getTypeHeader(JsonObject header) throws JwtInvalidException {
    if (header.has(JwtNames.HEADER_TYPE)) {
      return Optional.of(getStringHeader(header, JwtNames.HEADER_TYPE));
    }
    return Optional.empty();
  }

  /**
   * Returns the string value of header field {@code name}.
   *
   * @throws JwtInvalidException if the field is absent or is not a JSON string
   */
  static String getStringHeader(JsonObject header, String name) throws JwtInvalidException {
    if (!header.has(name)) {
      throw new JwtInvalidException("header " + name + " does not exist");
    }
    if (!header.get(name).isJsonPrimitive() || !header.get(name).getAsJsonPrimitive().isString()) {
      throw new JwtInvalidException("header " + name + " is not a string");
    }
    return header.get(name).getAsString();
  }

  /**
   * Strictly base64url-decodes a header segment and returns it as a UTF-8 string.
   *
   * @throws JwtInvalidException if the segment is not strict base64url or not valid UTF-8
   */
  static String decodeHeader(String headerStr) throws JwtInvalidException {
    byte[] data = strictUrlSafeDecode(headerStr);
    validateUtf8(data);
    return new String(data, UTF_8);
  }

  /** Base64url-encodes the UTF-8 bytes of a JSON payload. */
  static String encodePayload(String jsonPayload) {
    return Base64.urlSafeEncode(jsonPayload.getBytes(UTF_8));
  }

  /**
   * Strictly base64url-decodes a payload segment and returns it as a UTF-8 string.
   *
   * @throws JwtInvalidException if the segment is not strict base64url or not valid UTF-8
   */
  static String decodePayload(String payloadStr) throws JwtInvalidException {
    byte[] data = strictUrlSafeDecode(payloadStr);
    validateUtf8(data);
    return new String(data, UTF_8);
  }

  /** Base64url-encodes raw signature or MAC bytes. */
  static String encodeSignature(byte[] signature) {
    return Base64.urlSafeEncode(signature);
  }

  /**
   * Strictly base64url-decodes a signature segment.
   *
   * @throws JwtInvalidException if the segment is not strict base64url
   */
  static byte[] decodeSignature(String signatureStr) throws JwtInvalidException {
    return strictUrlSafeDecode(signatureStr);
  }

  /**
   * Computes the kid header value for a key.
   *
   * @return empty for RAW keys; for TINK keys, the base64url encoding of {@code keyId} as 4
   *     big-endian bytes
   * @throws JwtInvalidException for any other output prefix type
   */
  static Optional<String> getKid(int keyId, OutputPrefixType prefix) throws JwtInvalidException {
    if (prefix == OutputPrefixType.RAW) {
      return Optional.empty();
    }
    if (prefix == OutputPrefixType.TINK) {
      byte[] bigEndianKeyId = ByteBuffer.allocate(4).putInt(keyId).array();
      return Optional.of(Base64.urlSafeEncode(bigEndianKeyId));
    }
    throw new JwtInvalidException("unsupported output prefix type");
  }

  /**
   * Inverse of {@link #getKid} for TINK keys: recovers the key ID encoded in {@code kid}.
   *
   * @return the big-endian key ID, or empty if {@code kid} does not decode to exactly 4 bytes
   *     (including when it is not decodable base64url at all)
   */
  static Optional<Integer> getKeyId(String kid) {
    byte[] encodedKeyId;
    try {
      encodedKeyId = Base64.urlSafeDecode(kid);
    } catch (IllegalArgumentException ex) {
      // A kid that is not even valid base64url cannot encode a Tink key ID. Report "absent"
      // instead of letting an unchecked exception escape.
      return Optional.empty();
    }
    if (encodedKeyId.length != 4) {
      return Optional.empty();
    }
    return Optional.of(ByteBuffer.wrap(encodedKeyId).getInt());
  }

  /**
   * Splits a token in JWS compact serialization into its three segments and decodes them.
   *
   * @throws JwtInvalidException if the token contains non-ASCII characters, does not have exactly
   *     three dot-separated segments, or any segment fails strict decoding
   */
  static Parts splitSignedCompact(String signedCompact) throws JwtInvalidException {
    validateASCII(signedCompact);
    int sigPos = signedCompact.lastIndexOf('.');
    if (sigPos < 0) {
      throw new JwtInvalidException(
          "only tokens in JWS compact serialization format are supported");
    }
    String unsignedCompact = signedCompact.substring(0, sigPos);
    String encodedMac = signedCompact.substring(sigPos + 1);
    byte[] mac = decodeSignature(encodedMac);
    int payloadPos = unsignedCompact.indexOf('.');
    if (payloadPos < 0) {
      throw new JwtInvalidException(
          "only tokens in JWS compact serialization format are supported");
    }
    String encodedHeader = unsignedCompact.substring(0, payloadPos);
    String encodedPayload = unsignedCompact.substring(payloadPos + 1);
    // Any further '.' means more than three segments. Use >= 0 so that a dot at the very start
    // of the payload (e.g. "a..b.c") is also reported as a format error.
    if (encodedPayload.indexOf('.') >= 0) {
      throw new JwtInvalidException(
          "only tokens in JWS compact serialization format are supported");
    }
    String header = decodeHeader(encodedHeader);
    String payload = decodePayload(encodedPayload);
    return new Parts(unsignedCompact, mac, header, payload);
  }

  /**
   * Builds the {@code header.payload} part of a token from {@code rawJwt}.
   *
   * @throws InvalidAlgorithmParameterException if {@code algorithm} is not supported
   * @throws JwtInvalidException if reading the JWT's payload or type header fails
   */
  static String createUnsignedCompact(String algorithm, Optional<String> kid, RawJwt rawJwt)
      throws InvalidAlgorithmParameterException, JwtInvalidException {
    String jsonPayload = rawJwt.getJsonPayload();
    Optional<String> typeHeader =
        rawJwt.hasTypeHeader() ? Optional.of(rawJwt.getTypeHeader()) : Optional.empty();
    return createHeader(algorithm, typeHeader, kid) + "." + encodePayload(jsonPayload);
  }

  /** Appends the base64url-encoded {@code signature} to {@code unsignedCompact}, dot-separated. */
  static String createSignedCompact(String unsignedCompact, byte[] signature) {
    return unsignedCompact + "." + encodeSignature(signature);
  }

  /**
   * Checks that every character of {@code data} is 7-bit ASCII (U+0000..U+007F).
   *
   * @throws JwtInvalidException if any character is outside the ASCII range
   */
  static void validateASCII(String data) throws JwtInvalidException {
    for (int i = 0; i < data.length(); i++) {
      char c = data.charAt(i);
      // Compare the whole 16-bit value. Testing only bit 7 (c & 0x80) would let non-ASCII
      // characters such as U+0100 through.
      if (c > 0x7F) {
        throw new JwtInvalidException("Non ascii character");
      }
    }
  }
}