blob: 7c1d78c44593d0147fbb8ae6794845ca7c2074e0 [file] [log] [blame]
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
////////////////////////////////////////////////////////////////////////////////
package com.google.crypto.tink.jwt;
import static java.nio.charset.StandardCharsets.UTF_8;
import com.google.crypto.tink.subtle.Base64;
import com.google.gson.JsonObject;
import java.nio.ByteBuffer;
import java.nio.charset.CharacterCodingException;
import java.nio.charset.CharsetDecoder;
import java.security.InvalidAlgorithmParameterException;
import java.util.Optional;
/**
 * Static helpers for building and parsing tokens of the form {@code header.payload.signature},
 * where each of the three parts is URL-safe base64 text.
 *
 * <p>Decoding is strict: only the characters accepted by {@link #isValidUrlsafeBase64Char} are
 * allowed in an encoded part, so padding ({@code '='}), whitespace and the standard base64
 * characters {@code '+'} and {@code '/'} are all rejected.
 */
final class JwtFormat {

  /** Message used whenever a token does not consist of exactly three dot-separated parts. */
  private static final String NOT_COMPACT_SERIALIZATION =
      "only tokens in JWS compact serialization format are supported";

  /** The pieces of a signed compact token, as produced by {@link #splitSignedCompact}. */
  static class Parts {
    /** Everything before the last {@code '.'}, i.e. the still-encoded {@code header.payload}. */
    String unsignedCompact;
    /** The decoded bytes of the last part of the token. */
    byte[] signatureOrMac;
    /** The decoded header text. */
    String header;
    /** The decoded payload text. */
    String payload;

    Parts(
        String unsignedCompact, byte[] signatureOrMac, String header, String payload) {
      this.unsignedCompact = unsignedCompact;
      this.signatureOrMac = signatureOrMac;
      this.header = header;
      this.payload = payload;
    }
  }

  private JwtFormat() {}

  /**
   * Returns whether {@code c} belongs to the URL-safe base64 alphabet: {@code a-z}, {@code A-Z},
   * {@code 0-9}, {@code '-'} or {@code '_'}. The padding character {@code '='} is not accepted.
   */
  static boolean isValidUrlsafeBase64Char(char c) {
    return (((c >= 'a') && (c <= 'z'))
        || ((c >= 'A') && (c <= 'Z'))
        || ((c >= '0') && (c <= '9'))
        || ((c == '-') || (c == '_')));
  }

  /**
   * Verifies that {@code data} is well-formed UTF-8.
   *
   * <p>This explicit check is needed because {@code new String(data, UTF_8)} never fails: it
   * silently substitutes a replacement character for malformed input. A fresh decoder, in
   * contrast, reports malformed input as a {@link CharacterCodingException}.
   *
   * @throws JwtInvalidException if {@code data} cannot be decoded as UTF-8
   */
  static void validateUtf8(byte[] data) throws JwtInvalidException {
    CharsetDecoder decoder = UTF_8.newDecoder();
    try {
      // Only the success or failure of decoding matters here; the decoded text is discarded.
      decoder.decode(ByteBuffer.wrap(data));
    } catch (CharacterCodingException ex) {
      throw new JwtInvalidException(ex.getMessage());
    }
  }

  /**
   * Decodes URL-safe base64 text after checking that it contains no character outside the
   * alphabet accepted by {@link #isValidUrlsafeBase64Char}.
   *
   * @param encodedData the encoded text
   * @return the decoded bytes
   * @throws JwtInvalidException if {@code encodedData} contains a character outside the alphabet,
   *     or if the underlying decoder rejects it
   */
  static byte[] strictUrlSafeDecode(String encodedData) throws JwtInvalidException {
    for (int i = 0; i < encodedData.length(); i++) {
      if (!isValidUrlsafeBase64Char(encodedData.charAt(i))) {
        throw new JwtInvalidException("invalid encoding");
      }
    }
    try {
      return Base64.urlSafeDecode(encodedData);
    } catch (IllegalArgumentException ex) {
      throw new JwtInvalidException("invalid encoding: " + ex);
    }
  }

  /**
   * Checks that {@code algo} is one of the supported algorithm names and returns it unchanged.
   *
   * @throws InvalidAlgorithmParameterException if {@code algo} is not in the supported list
   */
  private static String validateAlgorithm(String algo) throws InvalidAlgorithmParameterException {
    switch (algo) {
      case "HS256":
      case "HS384":
      case "HS512":
      case "ES256":
      case "ES384":
      case "ES512":
      case "RS256":
      case "RS384":
      case "RS512":
      case "PS256":
      case "PS384":
      case "PS512":
        return algo;
      default:
        throw new InvalidAlgorithmParameterException("invalid algorithm: " + algo);
    }
  }

  /**
   * Builds the encoded header part of a token.
   *
   * <p>The header is a JSON object holding {@code algorithm} under
   * {@code JwtNames.HEADER_ALGORITHM} and, when {@code typeHeader} is present, its value under
   * {@code JwtNames.HEADER_TYPE}. The JSON text is UTF-8 encoded and then URL-safe base64 encoded.
   *
   * @param algorithm the algorithm name; must be one of the supported names
   * @param typeHeader optional value for the type header
   * @return the encoded header
   * @throws InvalidAlgorithmParameterException if {@code algorithm} is not supported
   */
  static String createHeader(String algorithm, Optional<String> typeHeader)
      throws InvalidAlgorithmParameterException {
    validateAlgorithm(algorithm);
    JsonObject header = new JsonObject();
    header.addProperty(JwtNames.HEADER_ALGORITHM, algorithm);
    if (typeHeader.isPresent()) {
      header.addProperty(JwtNames.HEADER_TYPE, typeHeader.get());
    }
    return Base64.urlSafeEncode(header.toString().getBytes(UTF_8));
  }

  /**
   * Validates a parsed header against the algorithm the caller expects.
   *
   * @param expectedAlgorithm the algorithm the token must declare; must itself be supported
   * @param parsedHeader the header as a JSON object
   * @throws InvalidAlgorithmParameterException if {@code expectedAlgorithm} is not supported, or
   *     if the header declares a different algorithm
   * @throws JwtInvalidException if the algorithm header is missing or is not a string, or if the
   *     header contains an entry named {@code JwtNames.HEADER_CRITICAL}
   */
  static void validateHeader(String expectedAlgorithm, JsonObject parsedHeader)
      throws InvalidAlgorithmParameterException, JwtInvalidException {
    validateAlgorithm(expectedAlgorithm);
    if (!parsedHeader.has(JwtNames.HEADER_ALGORITHM)) {
      throw new JwtInvalidException("missing algorithm in header");
    }
    // Entries are checked in the order the JSON object yields them, so when several entries are
    // invalid the first one encountered determines which exception is thrown.
    for (String name : parsedHeader.keySet()) {
      if (name.equals(JwtNames.HEADER_ALGORITHM)) {
        String algorithm = getStringHeader(parsedHeader, JwtNames.HEADER_ALGORITHM);
        if (!algorithm.equals(expectedAlgorithm)) {
          throw new InvalidAlgorithmParameterException(
              String.format(
                  "invalid algorithm; expected %s, got %s", expectedAlgorithm, algorithm));
        }
      } else if (name.equals(JwtNames.HEADER_CRITICAL)) {
        throw new JwtInvalidException(
            "all tokens with crit headers are rejected");
      }
      // Ignore all other headers
    }
  }

  /**
   * Returns the value stored under {@code JwtNames.HEADER_TYPE}, or an empty {@link Optional} if
   * the header has no such entry.
   *
   * @throws JwtInvalidException if the entry exists but is not a JSON string
   */
  static Optional<String> getTypeHeader(JsonObject header) throws JwtInvalidException {
    if (header.has(JwtNames.HEADER_TYPE)) {
      return Optional.of(getStringHeader(header, JwtNames.HEADER_TYPE));
    }
    return Optional.empty();
  }

  /**
   * Returns the string value of the header entry {@code name}.
   *
   * @throws JwtInvalidException if the entry does not exist or is not a JSON string
   */
  private static String getStringHeader(JsonObject header, String name) throws JwtInvalidException {
    if (!header.has(name)) {
      throw new JwtInvalidException("header " + name + " does not exist");
    }
    if (!header.get(name).isJsonPrimitive() || !header.get(name).getAsJsonPrimitive().isString()) {
      throw new JwtInvalidException("header " + name + " is not a string");
    }
    return header.get(name).getAsString();
  }

  /** Strictly base64-decodes {@code encoded} and returns it as text after validating UTF-8. */
  private static String strictUrlSafeDecodeToString(String encoded) throws JwtInvalidException {
    byte[] data = strictUrlSafeDecode(encoded);
    validateUtf8(data);
    return new String(data, UTF_8);
  }

  /**
   * Decodes the encoded header part of a token into text.
   *
   * @throws JwtInvalidException if {@code headerStr} is not strictly valid URL-safe base64 or the
   *     decoded bytes are not valid UTF-8
   */
  static String decodeHeader(String headerStr) throws JwtInvalidException {
    return strictUrlSafeDecodeToString(headerStr);
  }

  /** Encodes a JSON payload: UTF-8 bytes of {@code jsonPayload}, URL-safe base64 encoded. */
  static String encodePayload(String jsonPayload) {
    return Base64.urlSafeEncode(jsonPayload.getBytes(UTF_8));
  }

  /**
   * Decodes the encoded payload part of a token into text.
   *
   * @throws JwtInvalidException if {@code payloadStr} is not strictly valid URL-safe base64 or
   *     the decoded bytes are not valid UTF-8
   */
  static String decodePayload(String payloadStr) throws JwtInvalidException {
    return strictUrlSafeDecodeToString(payloadStr);
  }

  /** Encodes raw signature (or MAC) bytes as URL-safe base64 text. */
  static String encodeSignature(byte[] signature) {
    return Base64.urlSafeEncode(signature);
  }

  /**
   * Decodes the encoded signature (or MAC) part of a token.
   *
   * @throws JwtInvalidException if {@code signatureStr} is not strictly valid URL-safe base64
   */
  static byte[] decodeSignature(String signatureStr) throws JwtInvalidException {
    return strictUrlSafeDecode(signatureStr);
  }

  /**
   * Splits a signed compact token into its three parts and decodes each of them.
   *
   * @param signedCompact a token of the form {@code header.payload.signature}
   * @return the decoded parts together with the still-encoded {@code header.payload} prefix
   * @throws JwtInvalidException if the token contains a non-ASCII character, does not consist of
   *     exactly three dot-separated parts, or any part fails to decode
   */
  static Parts splitSignedCompact(String signedCompact) throws JwtInvalidException {
    validateASCII(signedCompact);
    int sigPos = signedCompact.lastIndexOf('.');
    if (sigPos < 0) {
      throw new JwtInvalidException(NOT_COMPACT_SERIALIZATION);
    }
    String unsignedCompact = signedCompact.substring(0, sigPos);
    String encodedMac = signedCompact.substring(sigPos + 1);
    byte[] mac = decodeSignature(encodedMac);
    int payloadPos = unsignedCompact.indexOf('.');
    if (payloadPos < 0) {
      throw new JwtInvalidException(NOT_COMPACT_SERIALIZATION);
    }
    String encodedHeader = unsignedCompact.substring(0, payloadPos);
    String encodedPayload = unsignedCompact.substring(payloadPos + 1);
    // Any further dot means the token has more than three parts. The comparison must be >= 0 so
    // that a dot at the very start of the payload (as in "h..p.s") is caught here as well.
    if (encodedPayload.indexOf('.') >= 0) {
      throw new JwtInvalidException(NOT_COMPACT_SERIALIZATION);
    }
    String header = decodeHeader(encodedHeader);
    String payload = decodePayload(encodedPayload);
    return new Parts(unsignedCompact, mac, header, payload);
  }

  // TODO(juerg): Refactor this to createUnsignedCompact(algorithm, rawJwt)
  /**
   * Builds the unsigned part of a token: the encoded header, a {@code '.'}, and the encoded
   * payload.
   *
   * @throws InvalidAlgorithmParameterException if {@code algorithm} is not supported
   */
  static String createUnsignedCompact(
      String algorithm, Optional<String> typeHeader, String jsonPayload)
      throws InvalidAlgorithmParameterException {
    return createHeader(algorithm, typeHeader) + "." + encodePayload(jsonPayload);
  }

  /** Appends {@code '.'} and the encoded {@code signature} to {@code unsignedCompact}. */
  static String createSignedCompact(String unsignedCompact, byte[] signature) {
    return unsignedCompact + "." + encodeSignature(signature);
  }

  /**
   * Verifies that every character of {@code data} is in the ASCII range {@code 0x00..0x7F}.
   *
   * @throws JwtInvalidException if {@code data} contains any character above {@code 0x7F}
   */
  static void validateASCII(String data) throws JwtInvalidException {
    for (int i = 0; i < data.length(); i++) {
      // A Java char is 16 bits wide, so the whole value has to be compared. Testing only bit 7
      // would let through characters such as U+0100, whose low byte has that bit clear.
      if (data.charAt(i) > 0x7F) {
        throw new JwtInvalidException("Non ascii character");
      }
    }
  }
}