Add type header to JWTs in C++.

PiperOrigin-RevId: 374206249
diff --git a/cc/jwt/internal/BUILD.bazel b/cc/jwt/internal/BUILD.bazel
index b1d51d2..f0f14de 100644
--- a/cc/jwt/internal/BUILD.bazel
+++ b/cc/jwt/internal/BUILD.bazel
@@ -87,6 +87,7 @@
         "//util:status",
         "//util:statusor",
         "@com_google_absl//absl/strings",
+        "@com_google_protobuf//:protobuf",
     ],
 )
 
@@ -94,6 +95,7 @@
     name = "jwt_format_test",
     srcs = ["jwt_format_test.cc"],
     deps = [
+        ":json_util",
         ":jwt_format",
         "//util:test_matchers",
         "//util:test_util",
@@ -528,6 +530,7 @@
     hdrs = ["jwt_public_key_verify_impl.h"],
     include_prefix = "tink/jwt/internal",
     deps = [
+        ":json_util",
         ":jwt_format",
         "//:public_key_verify",
         "//jwt:jwt_public_key_verify",
diff --git a/cc/jwt/internal/CMakeLists.txt b/cc/jwt/internal/CMakeLists.txt
index 7cd94d6..b9f0f10 100644
--- a/cc/jwt/internal/CMakeLists.txt
+++ b/cc/jwt/internal/CMakeLists.txt
@@ -54,6 +54,7 @@
     jwt_format.cc
     jwt_format.h
   DEPS
+    protobuf::libprotobuf
     tink::jwt::internal::json_util
     tink::util::status
     tink::util::statusor
@@ -64,6 +65,7 @@
   NAME jwt_format_test
   SRCS jwt_format_test.cc
   DEPS
+    tink::jwt::internal::json_util
     tink::jwt::internal::jwt_format
     tink::util::test_matchers
     tink::util::test_util
@@ -488,6 +490,7 @@
     jwt_public_key_verify_impl.cc
     jwt_public_key_verify_impl.h
   DEPS
+    tink::jwt::internal::json_util
     tink::jwt::internal::jwt_format
     tink::core::public_key_verify
     tink::jwt::jwt_public_key_verify
diff --git a/cc/jwt/internal/jwt_format.cc b/cc/jwt/internal/jwt_format.cc
index e6bc55a..663ca78 100644
--- a/cc/jwt/internal/jwt_format.cc
+++ b/cc/jwt/internal/jwt_format.cc
@@ -51,22 +51,29 @@
   return StrictWebSafeBase64Unescape(header, json_header);
 }
 
-std::string CreateHeader(absl::string_view algorithm) {
-  std::string header = absl::StrCat(R"({"alg":")", algorithm, R"("})");
-  return EncodeHeader(header);
+std::string CreateHeader(absl::string_view algorithm,
+                         absl::optional<absl::string_view> type_header) {
+  google::protobuf::Struct header;
+  auto fields = header.mutable_fields();
+  if (type_header.has_value()) {
+    google::protobuf::Value type_value;
+    type_value.set_string_value(std::string(type_header.value()));
+    (*fields)["typ"] = type_value;
+  }
+  google::protobuf::Value alg_value;
+  alg_value.set_string_value(std::string(algorithm));
+  (*fields)["alg"] = alg_value;
+  util::StatusOr<std::string> json_or =
+      jwt_internal::ProtoStructToJsonString(header);
+  if (!json_or.ok()) {
+    // do something
+  }
+  return EncodeHeader(json_or.ValueOrDie());
 }
 
-util::Status ValidateHeader(absl::string_view encoded_header,
+util::Status ValidateHeader(const google::protobuf::Struct& header,
                             absl::string_view algorithm) {
-  std::string json_header;
-  if (!DecodeHeader(encoded_header, &json_header)) {
-    return util::Status(util::error::INVALID_ARGUMENT, "invalid header");
-  }
-  auto proto_or = JsonStringToProtoStruct(json_header);
-  if (!proto_or.ok()) {
-    return proto_or.status();
-  }
-  auto fields = proto_or.ValueOrDie().fields();
+  auto fields = header.fields();
   auto it = fields.find("alg");
   if (it == fields.end()) {
     return util::Status(util::error::INVALID_ARGUMENT, "header is missing alg");
@@ -85,6 +92,19 @@
   return util::OkStatus();
 }
 
+absl::optional<std::string> GetTypeHeader(
+    const google::protobuf::Struct& header) {
+  auto it = header.fields().find("typ");
+  if (it == header.fields().end()) {
+    return absl::nullopt;
+  }
+  const auto& value = it->second;
+  if (value.kind_case() != google::protobuf::Value::kStringValue) {
+    return absl::nullopt;
+  }
+  return value.string_value();
+}
+
 std::string EncodePayload(absl::string_view json_payload) {
   return absl::WebSafeBase64Escape(json_payload);
 }
diff --git a/cc/jwt/internal/jwt_format.h b/cc/jwt/internal/jwt_format.h
index bf4c746..0ba56e1 100644
--- a/cc/jwt/internal/jwt_format.h
+++ b/cc/jwt/internal/jwt_format.h
@@ -17,6 +17,7 @@
 #ifndef TINK_JWT_INTERNAL_JWT_FORMAT_H_
 #define TINK_JWT_INTERNAL_JWT_FORMAT_H_
 
+#include "google/protobuf/struct.pb.h"
 #include "tink/util/status.h"
 #include "tink/util/statusor.h"
 
@@ -27,9 +28,12 @@
 std::string EncodeHeader(absl::string_view json_header);
 bool DecodeHeader(absl::string_view header, std::string* json_header);
 
-std::string CreateHeader(absl::string_view algorithm);
-util::Status ValidateHeader(absl::string_view encoded_header,
+std::string CreateHeader(absl::string_view algorithm,
+                         absl::optional<absl::string_view> type_header);
+util::Status ValidateHeader(const google::protobuf::Struct& header,
                             absl::string_view algorithm);
+absl::optional<std::string> GetTypeHeader(
+    const google::protobuf::Struct& header);
 
 std::string EncodePayload(absl::string_view json_payload);
 bool DecodePayload(absl::string_view payload, std::string* json_payload);
diff --git a/cc/jwt/internal/jwt_format_test.cc b/cc/jwt/internal/jwt_format_test.cc
index f68302e..8d17b5b 100644
--- a/cc/jwt/internal/jwt_format_test.cc
+++ b/cc/jwt/internal/jwt_format_test.cc
@@ -18,6 +18,7 @@
 
 #include "gmock/gmock.h"
 #include "gtest/gtest.h"
+#include "tink/jwt/internal/json_util.h"
 #include "tink/util/test_matchers.h"
 #include "tink/util/test_util.h"
 
@@ -79,72 +80,105 @@
   // Example from https://tools.ietf.org/html/rfc7515#appendix-A.1
   std::string encoded_header = "eyJ0eXAiOiJKV1QiLA0KICJhbGciOiJIUzI1NiJ9";
 
-  std::string output;
-  ASSERT_TRUE(DecodeHeader(encoded_header, &output));
-  EXPECT_THAT(output, Eq("{\"typ\":\"JWT\",\r\n \"alg\":\"HS256\"}"));
+  std::string json_header;
+  ASSERT_TRUE(DecodeHeader(encoded_header, &json_header));
+  EXPECT_THAT(json_header, Eq("{\"typ\":\"JWT\",\r\n \"alg\":\"HS256\"}"));
 
-  EXPECT_THAT(ValidateHeader(encoded_header, "HS256"), IsOk());
-  EXPECT_FALSE(ValidateHeader(encoded_header, "RS256").ok());
+  util::StatusOr<google::protobuf::Struct> header_or =
+      JsonStringToProtoStruct(json_header);
+  EXPECT_THAT(header_or.status(), IsOk());
+
+  EXPECT_THAT(ValidateHeader(header_or.ValueOrDie(), "HS256"), IsOk());
+  EXPECT_FALSE(ValidateHeader(header_or.ValueOrDie(), "RS256").ok());
 }
 
 TEST(JwtFormat, DecodeAndValidateFixedHeaderRS256) {
   // Example from https://tools.ietf.org/html/rfc7515#appendix-A.2
   std::string encoded_header = "eyJhbGciOiJSUzI1NiJ9";
 
-  std::string output;
-  ASSERT_TRUE(DecodeHeader(encoded_header, &output));
-  EXPECT_THAT(output, Eq(R"({"alg":"RS256"})"));
+  std::string json_header;
+  ASSERT_TRUE(DecodeHeader(encoded_header, &json_header));
+  EXPECT_THAT(json_header, Eq(R"({"alg":"RS256"})"));
 
-  EXPECT_THAT(ValidateHeader(encoded_header, "RS256"), IsOk());
-  EXPECT_FALSE(ValidateHeader(encoded_header, "HS256").ok());
+  util::StatusOr<google::protobuf::Struct> header_or =
+      JsonStringToProtoStruct(json_header);
+  EXPECT_THAT(header_or.status(), IsOk());
+
+  EXPECT_THAT(ValidateHeader(header_or.ValueOrDie(), "RS256"), IsOk());
+  EXPECT_FALSE(ValidateHeader(header_or.ValueOrDie(), "HS256").ok());
 }
 
 TEST(JwtFormat, CreateValidateHeader) {
-  std::string encoded_header = CreateHeader("PS384");
-  EXPECT_THAT(ValidateHeader(encoded_header, "PS384"), IsOk());
-  EXPECT_FALSE(ValidateHeader(encoded_header, "HS256").ok());
+  std::string encoded_header = CreateHeader("PS384", absl::nullopt);
+
+  std::string json_header;
+  ASSERT_TRUE(DecodeHeader(encoded_header, &json_header));
+
+  util::StatusOr<google::protobuf::Struct> header_or =
+      JsonStringToProtoStruct(json_header);
+  EXPECT_THAT(header_or.status(), IsOk());
+
+  EXPECT_THAT(ValidateHeader(header_or.ValueOrDie(), "PS384"), IsOk());
+  EXPECT_FALSE(ValidateHeader(header_or.ValueOrDie(), "HS256").ok());
+}
+
+TEST(JwtFormat, CreateValidateHeaderWithType) {
+  std::string encoded_header = CreateHeader("PS384", "JWT");
+
+  std::string json_header;
+  ASSERT_TRUE(DecodeHeader(encoded_header, &json_header));
+
+  util::StatusOr<google::protobuf::Struct> header_or =
+      JsonStringToProtoStruct(json_header);
+  EXPECT_THAT(header_or.status(), IsOk());
+
+  EXPECT_THAT(ValidateHeader(header_or.ValueOrDie(), "PS384"), IsOk());
+  EXPECT_FALSE(ValidateHeader(header_or.ValueOrDie(), "HS256").ok());
 }
 
 TEST(JwtFormat, ValidateEmptyHeaderFails) {
-  std::string header = "{}";
-  EXPECT_FALSE(ValidateHeader(EncodeHeader(header), "HS256").ok());
+  google::protobuf::Struct empty_header;
+  EXPECT_FALSE(ValidateHeader(empty_header, "HS256").ok());
 }
 
-TEST(JwtFormat, ValidateInvalidEncodedHeaderFails) {
-  EXPECT_FALSE(
-      ValidateHeader("eyJ0eXAiOiJKV1Q?LA0KICJhbGciOiJIUzI1NiJ9", "HS256").ok());
-}
+TEST(JwtFormat, ValidateHeaderWithUnknownTypeOk) {
+  std::string json_header = R"({"alg":"HS256","typ":"unknown"})";
+  util::StatusOr<google::protobuf::Struct> header_or =
+      JsonStringToProtoStruct(json_header);
+  EXPECT_THAT(header_or.status(), IsOk());
 
-TEST(JwtFormat, ValidateInvalidJsonHeaderFails) {
-  std::string header = R"({"alg":"HS256")";  // missing }
-  EXPECT_FALSE(ValidateHeader(EncodeHeader(header), "HS256").ok());
-}
-
-TEST(JwtFormat, ValidateHeaderIgnoresTyp) {
-  std::string header = R"({"alg":"HS256","typ":"unknown"})";
-  EXPECT_THAT(ValidateHeader(EncodeHeader(header), "HS256"), IsOk());
+  EXPECT_THAT(ValidateHeader(header_or.ValueOrDie(), "HS256"), IsOk());
 }
 
 TEST(JwtFormat, ValidateHeaderRejectsCrit) {
-  std::string header =
+  std::string json_header =
       R"({"alg":"HS256","crit":["http://example.invalid/UNDEFINED"],)"
       R"("http://example.invalid/UNDEFINED":true})";
-  EXPECT_FALSE(ValidateHeader(EncodeHeader(header), "HS256").ok());
+  util::StatusOr<google::protobuf::Struct> header_or =
+      JsonStringToProtoStruct(json_header);
+  EXPECT_THAT(header_or.status(), IsOk());
+  EXPECT_FALSE(ValidateHeader(header_or.ValueOrDie(), "HS256").ok());
 }
 
 TEST(JwtFormat, ValidateHeaderWithUnknownEntry) {
-  std::string header = R"({"alg":"HS256","unknown":"header"})";
-  EXPECT_THAT(ValidateHeader(EncodeHeader(header), "HS256"), IsOk());
+  std::string json_header = R"({"alg":"HS256","unknown":"header"})";
+  util::StatusOr<google::protobuf::Struct> header_or =
+      JsonStringToProtoStruct(json_header);
+  EXPECT_THAT(header_or.status(), IsOk());
+  EXPECT_THAT(ValidateHeader(header_or.ValueOrDie(), "HS256"), IsOk());
 }
 
 TEST(JwtFormat, ValidateHeaderWithInvalidAlgTypFails) {
-  std::string header = R"({"alg":true})";
-  EXPECT_FALSE(ValidateHeader(EncodeHeader(header), "HS256").ok());
+  std::string json_header = R"({"alg":true})";
+  util::StatusOr<google::protobuf::Struct> header_or =
+      JsonStringToProtoStruct(json_header);
+  EXPECT_THAT(header_or.status(), IsOk());
+  EXPECT_FALSE(ValidateHeader(header_or.ValueOrDie(), "HS256").ok());
 }
 
 TEST(JwtFormat, DecodeFixedPayload) {
   // Example from https://tools.ietf.org/html/rfc7519#section-3.1
-  std::string encoded_header =
+  std::string encoded_payload =
       "eyJpc3MiOiJqb2UiLA0KICJleHAiOjEzMDA4MTkzODAsDQogImh0"
       "dHA6Ly9leGFtcGxlLmNvbS9pc19yb290Ijp0cnVlfQ";
 
@@ -152,7 +186,7 @@
       "{\"iss\":\"joe\",\r\n \"exp\":1300819380,\r\n "
       "\"http://example.com/is_root\":true}";
   std::string output;
-  ASSERT_TRUE(DecodePayload(encoded_header, &output));
+  ASSERT_TRUE(DecodePayload(encoded_payload, &output));
   EXPECT_THAT(output, Eq(expected));
 }
 
diff --git a/cc/jwt/internal/jwt_mac_impl.cc b/cc/jwt/internal/jwt_mac_impl.cc
index f7ad3a5..8d4a8f5 100644
--- a/cc/jwt/internal/jwt_mac_impl.cc
+++ b/cc/jwt/internal/jwt_mac_impl.cc
@@ -27,8 +27,16 @@
 
 util::StatusOr<std::string> JwtMacImpl::ComputeMacAndEncode(
     const RawJwt& token) const {
-  std::string encoded_header = CreateHeader(algorithm_);
-  util::StatusOr<std::string> payload_or = token.ToString();
+  absl::optional<std::string> type_header;
+  if (token.HasTypeHeader()) {
+    util::StatusOr<std::string> type_or = token.GetTypeHeader();
+    if (!type_or.ok()) {
+      return type_or.status();
+    }
+    type_header = type_or.ValueOrDie();
+  }
+  std::string encoded_header = CreateHeader(algorithm_, type_header);
+  util::StatusOr<std::string> payload_or = token.GetJsonPayload();
   if (!payload_or.ok()) {
     return payload_or.status();
   }
@@ -64,7 +72,16 @@
         util::error::INVALID_ARGUMENT,
         "only tokens in JWS compact serialization format are supported");
   }
-  util::Status validate_header_result = ValidateHeader(parts[0], algorithm_);
+  std::string json_header;
+  if (!DecodeHeader(parts[0], &json_header)) {
+    return util::Status(util::error::INVALID_ARGUMENT, "invalid header");
+  }
+  auto header_or = JsonStringToProtoStruct(json_header);
+  if (!header_or.ok()) {
+    return header_or.status();
+  }
+  util::Status validate_header_result =
+      ValidateHeader(header_or.ValueOrDie(), algorithm_);
   if (!validate_header_result.ok()) {
     return validate_header_result;
   }
@@ -72,7 +89,8 @@
   if (!DecodePayload(parts[1], &json_payload)) {
     return util::Status(util::error::INVALID_ARGUMENT, "invalid JWT payload");
   }
-  auto raw_jwt_or = RawJwt::FromString(json_payload);
+  auto raw_jwt_or =
+      RawJwt::FromJson(GetTypeHeader(header_or.ValueOrDie()), json_payload);
   if (!raw_jwt_or.ok()) {
     return raw_jwt_or.status();
   }
diff --git a/cc/jwt/internal/jwt_mac_impl_test.cc b/cc/jwt/internal/jwt_mac_impl_test.cc
index 754f80e..31b0cb2 100644
--- a/cc/jwt/internal/jwt_mac_impl_test.cc
+++ b/cc/jwt/internal/jwt_mac_impl_test.cc
@@ -34,6 +34,7 @@
 #include "tink/util/test_util.h"
 
 using ::crypto::tink::test::IsOk;
+using ::crypto::tink::test::IsOkAndHolds;
 
 namespace crypto {
 namespace tink {
@@ -67,13 +68,16 @@
   std::unique_ptr<JwtMac> jwt_mac = std::move(jwt_mac_or.ValueOrDie());
 
   absl::Time now = absl::Now();
-  auto builder = RawJwtBuilder().SetIssuer("issuer");
+  auto builder =
+      RawJwtBuilder().SetTypeHeader("typeHeader").SetIssuer("issuer");
   ASSERT_THAT(builder.SetNotBefore(now - absl::Seconds(300)), IsOk());
   ASSERT_THAT(builder.SetIssuedAt(now), IsOk());
   ASSERT_THAT(builder.SetExpiration(now + absl::Seconds(300)), IsOk());
   auto raw_jwt_or = builder.Build();
   ASSERT_THAT(raw_jwt_or.status(), IsOk());
   RawJwt raw_jwt = raw_jwt_or.ValueOrDie();
+  EXPECT_TRUE(raw_jwt.HasTypeHeader());
+  EXPECT_THAT(raw_jwt.GetTypeHeader(), IsOkAndHolds("typeHeader"));
 
   util::StatusOr<std::string> compact_or =
       jwt_mac->ComputeMacAndEncode(raw_jwt);
@@ -86,7 +90,8 @@
       jwt_mac->VerifyMacAndDecode(compact, validator);
   ASSERT_THAT(verified_jwt_or.status(), IsOk());
   auto verified_jwt = verified_jwt_or.ValueOrDie();
-  EXPECT_THAT(verified_jwt.GetIssuer(), test::IsOkAndHolds("issuer"));
+  EXPECT_THAT(verified_jwt.GetTypeHeader(), IsOkAndHolds("typeHeader"));
+  EXPECT_THAT(verified_jwt.GetIssuer(), IsOkAndHolds("issuer"));
 
   JwtValidator validator2 = JwtValidatorBuilder().SetIssuer("unknown").Build();
   EXPECT_FALSE(jwt_mac->VerifyMacAndDecode(compact, validator2).ok());
@@ -110,9 +115,9 @@
       jwt_mac->VerifyMacAndDecode(compact, validator_1970);
   ASSERT_THAT(verified_jwt_or.status(), IsOk());
   auto verified_jwt = verified_jwt_or.ValueOrDie();
-  EXPECT_THAT(verified_jwt.GetIssuer(), test::IsOkAndHolds("joe"));
+  EXPECT_THAT(verified_jwt.GetIssuer(), IsOkAndHolds("joe"));
   EXPECT_THAT(verified_jwt.GetBooleanClaim("http://example.com/is_root"),
-              test::IsOkAndHolds(true));
+              IsOkAndHolds(true));
 
   // verification fails because token is expired
   JwtValidator validator_now = JwtValidatorBuilder().Build();
diff --git a/cc/jwt/internal/jwt_public_key_sign_impl.cc b/cc/jwt/internal/jwt_public_key_sign_impl.cc
index 9207298..ef1fda6 100644
--- a/cc/jwt/internal/jwt_public_key_sign_impl.cc
+++ b/cc/jwt/internal/jwt_public_key_sign_impl.cc
@@ -26,8 +26,16 @@
 
 util::StatusOr<std::string> JwtPublicKeySignImpl::SignAndEncode(
     const RawJwt& token) const {
-  std::string encoded_header = CreateHeader(algorithm_);
-  util::StatusOr<std::string> payload_or = token.ToString();
+  absl::optional<std::string> type_header;
+  if (token.HasTypeHeader()) {
+    util::StatusOr<std::string> type_or = token.GetTypeHeader();
+    if (!type_or.ok()) {
+      return type_or.status();
+    }
+    type_header = type_or.ValueOrDie();
+  }
+  std::string encoded_header = CreateHeader(algorithm_, type_header);
+  util::StatusOr<std::string> payload_or = token.GetJsonPayload();
   if (!payload_or.ok()) {
     return payload_or.status();
   }
diff --git a/cc/jwt/internal/jwt_public_key_sign_verify_impl_test.cc b/cc/jwt/internal/jwt_public_key_sign_verify_impl_test.cc
index 7faf810..ab87374 100644
--- a/cc/jwt/internal/jwt_public_key_sign_verify_impl_test.cc
+++ b/cc/jwt/internal/jwt_public_key_sign_verify_impl_test.cc
@@ -66,7 +66,8 @@
 
 TEST_F(JwtSignatureImplTest, CreateAndValidateToken) {
   absl::Time now = absl::Now();
-  auto builder = RawJwtBuilder().SetIssuer("issuer");
+  auto builder =
+      RawJwtBuilder().SetTypeHeader("typeHeader").SetIssuer("issuer");
   ASSERT_THAT(builder.SetNotBefore(now - absl::Seconds(300)), IsOk());
   ASSERT_THAT(builder.SetIssuedAt(now), IsOk());
   ASSERT_THAT(builder.SetExpiration(now + absl::Seconds(300)), IsOk());
@@ -86,6 +87,7 @@
       jwt_verify_->VerifyAndDecode(compact, validator);
   ASSERT_THAT(verified_jwt_or.status(), IsOk());
   auto verified_jwt = verified_jwt_or.ValueOrDie();
+  EXPECT_THAT(verified_jwt.GetTypeHeader(), test::IsOkAndHolds("typeHeader"));
   EXPECT_THAT(verified_jwt.GetIssuer(), test::IsOkAndHolds("issuer"));
 
   // Fails with wrong issuer
diff --git a/cc/jwt/internal/jwt_public_key_verify_impl.cc b/cc/jwt/internal/jwt_public_key_verify_impl.cc
index 9b5efd7..4d1f5c5 100644
--- a/cc/jwt/internal/jwt_public_key_verify_impl.cc
+++ b/cc/jwt/internal/jwt_public_key_verify_impl.cc
@@ -18,6 +18,7 @@
 
 #include "absl/strings/escaping.h"
 #include "absl/strings/str_split.h"
+#include "tink/jwt/internal/json_util.h"
 #include "tink/jwt/internal/jwt_format.h"
 
 namespace crypto {
@@ -46,7 +47,16 @@
         util::error::INVALID_ARGUMENT,
         "only tokens in JWS compact serialization format are supported");
   }
-  util::Status validate_header_result = ValidateHeader(parts[0], algorithm_);
+  std::string json_header;
+  if (!DecodeHeader(parts[0], &json_header)) {
+    return util::Status(util::error::INVALID_ARGUMENT, "invalid header");
+  }
+  auto header_or = JsonStringToProtoStruct(json_header);
+  if (!header_or.ok()) {
+    return header_or.status();
+  }
+  util::Status validate_header_result =
+      ValidateHeader(header_or.ValueOrDie(), algorithm_);
   if (!validate_header_result.ok()) {
     return validate_header_result;
   }
@@ -54,7 +64,8 @@
   if (!DecodePayload(parts[1], &json_payload)) {
     return util::Status(util::error::INVALID_ARGUMENT, "invalid JWT payload");
   }
-  auto raw_jwt_or = RawJwt::FromString(json_payload);
+  auto raw_jwt_or =
+      RawJwt::FromJson(GetTypeHeader(header_or.ValueOrDie()), json_payload);
   if (!raw_jwt_or.ok()) {
     return raw_jwt_or.status();
   }
diff --git a/cc/jwt/raw_jwt.cc b/cc/jwt/raw_jwt.cc
index ed8dcf5..a89972f 100644
--- a/cc/jwt/raw_jwt.cc
+++ b/cc/jwt/raw_jwt.cc
@@ -148,8 +148,9 @@
 
 }  // namespace
 
-util::StatusOr<RawJwt> RawJwt::FromString(absl::string_view json_string) {
-  auto proto_or = jwt_internal::JsonStringToProtoStruct(json_string);
+util::StatusOr<RawJwt> RawJwt::FromJson(absl::optional<std::string> type_header,
+                                        absl::string_view json_payload) {
+  auto proto_or = jwt_internal::JsonStringToProtoStruct(json_payload);
   if (!proto_or.ok()) {
     return proto_or.status();
   }
@@ -166,20 +167,31 @@
   if (!audStatus.ok()) {
     return audStatus;
   }
-  RawJwt token(proto);
+  RawJwt token(type_header, proto);
   return token;
 }
 
-util::StatusOr<std::string> RawJwt::ToString() const {
+util::StatusOr<std::string> RawJwt::GetJsonPayload() const {
   return jwt_internal::ProtoStructToJsonString(json_proto_);
 }
 
 RawJwt::RawJwt() {}
 
-RawJwt::RawJwt(google::protobuf::Struct json_proto) {
+RawJwt::RawJwt(absl::optional<std::string> type_header,
+               google::protobuf::Struct json_proto) {
+  type_header_ = type_header;
   json_proto_ = json_proto;
 }
 
+bool RawJwt::HasTypeHeader() const { return type_header_.has_value(); }
+
+util::StatusOr<std::string> RawJwt::GetTypeHeader() const {
+  if (!type_header_.has_value()) {
+    return util::Status(util::error::INVALID_ARGUMENT, "No type header found");
+  }
+  return *type_header_;
+}
+
 bool RawJwt::HasIssuer() const {
   return json_proto_.fields().contains(std::string(kJwtClaimIssuer));
 }
@@ -457,6 +469,11 @@
 
 RawJwtBuilder::RawJwtBuilder() {}
 
+RawJwtBuilder& RawJwtBuilder::SetTypeHeader(absl::string_view type_header) {
+  type_header_ = std::string(type_header);
+  return *this;
+}
+
 RawJwtBuilder& RawJwtBuilder::SetIssuer(absl::string_view issuer) {
   auto fields = json_proto_.mutable_fields();
   google::protobuf::Value value;
@@ -612,7 +629,7 @@
 }
 
 util::StatusOr<RawJwt> RawJwtBuilder::Build() {
-  RawJwt token(json_proto_);
+  RawJwt token(type_header_, json_proto_);
   return token;
 }
 
diff --git a/cc/jwt/raw_jwt.h b/cc/jwt/raw_jwt.h
index 73bc855..8ccca58 100644
--- a/cc/jwt/raw_jwt.h
+++ b/cc/jwt/raw_jwt.h
@@ -36,6 +36,8 @@
  public:
   RawJwt();
 
+  bool HasTypeHeader() const;
+  util::StatusOr<std::string> GetTypeHeader() const;
   bool HasIssuer() const;
   util::StatusOr<std::string> GetIssuer() const;
   bool HasSubject() const;
@@ -63,8 +65,9 @@
   util::StatusOr<std::string> GetJsonArrayClaim(absl::string_view name) const;
   std::vector<std::string> CustomClaimNames() const;
 
-  static util::StatusOr<RawJwt> FromString(absl::string_view json_string);
-  util::StatusOr<std::string> ToString() const;
+  static util::StatusOr<RawJwt> FromJson(
+      absl::optional<std::string> type_header, absl::string_view json_payload);
+  util::StatusOr<std::string> GetJsonPayload() const;
 
   // RawJwt objects are copiable and movable.
   RawJwt(const RawJwt&) = default;
@@ -73,8 +76,10 @@
   RawJwt& operator=(RawJwt&& other) = default;
 
  private:
-  explicit RawJwt(google::protobuf::Struct json_proto);
+  explicit RawJwt(absl::optional<std::string> type_header,
+                  google::protobuf::Struct json_proto);
   friend class RawJwtBuilder;
+  absl::optional<std::string> type_header_;
   google::protobuf::Struct json_proto_;
 };
 
@@ -82,6 +87,7 @@
  public:
   RawJwtBuilder();
 
+  RawJwtBuilder& SetTypeHeader(absl::string_view type_header);
   RawJwtBuilder& SetIssuer(absl::string_view issuer);
   RawJwtBuilder& SetSubject(absl::string_view subject);
   RawJwtBuilder& AddAudience(absl::string_view audience);
@@ -107,6 +113,7 @@
   RawJwtBuilder& operator=(RawJwtBuilder&& other) = default;
 
  private:
+  absl::optional<std::string> type_header_;
   google::protobuf::Struct json_proto_;
 };
 
diff --git a/cc/jwt/raw_jwt_test.cc b/cc/jwt/raw_jwt_test.cc
index 2253038..5b3daac 100644
--- a/cc/jwt/raw_jwt_test.cc
+++ b/cc/jwt/raw_jwt_test.cc
@@ -31,8 +31,9 @@
 namespace crypto {
 namespace tink {
 
-TEST(RawJwt, GetIssuerSubjectJwtIdOK) {
+TEST(RawJwt, GetTypeHeaderIssuerSubjectJwtIdOK) {
   auto jwt_or = RawJwtBuilder()
+                    .SetTypeHeader("typeHeader")
                     .SetIssuer("issuer")
                     .SetSubject("subject")
                     .SetJwtId("jwt_id")
@@ -40,6 +41,8 @@
   ASSERT_THAT(jwt_or.status(), IsOk());
   auto jwt = jwt_or.ValueOrDie();
 
+  EXPECT_TRUE(jwt.HasTypeHeader());
+  EXPECT_THAT(jwt.GetTypeHeader(), IsOkAndHolds("typeHeader"));
   EXPECT_TRUE(jwt.HasIssuer());
   EXPECT_THAT(jwt.GetIssuer(), IsOkAndHolds("issuer"));
   EXPECT_TRUE(jwt.HasSubject());
@@ -255,6 +258,7 @@
   ASSERT_THAT(jwt_or.status(), IsOk());
   auto jwt = jwt_or.ValueOrDie();
 
+  EXPECT_FALSE(jwt.HasTypeHeader());
   EXPECT_FALSE(jwt.HasIssuer());
   EXPECT_FALSE(jwt.HasSubject());
   EXPECT_FALSE(jwt.HasAudiences());
@@ -275,6 +279,7 @@
   ASSERT_THAT(jwt_or.status(), IsOk());
   auto jwt = jwt_or.ValueOrDie();
 
+  EXPECT_FALSE(jwt.GetTypeHeader().ok());
   EXPECT_FALSE(jwt.GetIssuer().ok());
   EXPECT_FALSE(jwt.GetSubject().ok());
   EXPECT_FALSE(jwt.GetAudiences().ok());
@@ -307,12 +312,14 @@
   EXPECT_THAT(jwt2.GetSubject(), IsOkAndHolds("subject2"));
 }
 
-TEST(RawJwt, FromString) {
-  auto jwt_or = RawJwt::FromString(
+TEST(RawJwt, FromJson) {
+  auto jwt_or = RawJwt::FromJson(
+      absl::nullopt,
       R"({"iss":"issuer", "sub":"subject", "exp":123, "aud":["a1", "a2"]})");
   ASSERT_THAT(jwt_or.status(), IsOk());
   RawJwt jwt = jwt_or.ValueOrDie();
 
+  EXPECT_FALSE(jwt.HasTypeHeader());
   EXPECT_THAT(jwt.GetIssuer(), IsOkAndHolds("issuer"));
   EXPECT_THAT(jwt.GetSubject(), IsOkAndHolds("subject"));
   EXPECT_THAT(jwt.GetExpiration(), IsOkAndHolds(absl::FromUnixSeconds(123)));
@@ -320,8 +327,17 @@
   EXPECT_THAT(jwt.GetAudiences(), IsOkAndHolds(expected_audiences));
 }
 
-TEST(RawJwt, FromStringExpExpiration) {
-  auto jwt_or = RawJwt::FromString(R"({"exp":1e10})");
+TEST(RawJwt, FromJsonWithTypeHeader) {
+  auto jwt_or = RawJwt::FromJson("typeHeader", R"({"iss":"issuer"})");
+  ASSERT_THAT(jwt_or.status(), IsOk());
+  RawJwt jwt = jwt_or.ValueOrDie();
+
+  EXPECT_THAT(jwt.GetTypeHeader(), IsOkAndHolds("typeHeader"));
+  EXPECT_THAT(jwt.GetIssuer(), IsOkAndHolds("issuer"));
+}
+
+TEST(RawJwt, FromJsonExpExpiration) {
+  auto jwt_or = RawJwt::FromJson(absl::nullopt, R"({"exp":1e10})");
   ASSERT_THAT(jwt_or.status(), IsOk());
   RawJwt jwt = jwt_or.ValueOrDie();
 
@@ -329,18 +345,18 @@
               IsOkAndHolds(absl::FromUnixSeconds(10000000000)));
 }
 
-TEST(RawJwt, FromStringExpirationTooLarge) {
-  auto jwt_or = RawJwt::FromString(R"({"exp":1e30})");
+TEST(RawJwt, FromJsonExpirationTooLarge) {
+  auto jwt_or = RawJwt::FromJson(absl::nullopt, R"({"exp":1e30})");
   EXPECT_FALSE(jwt_or.ok());
 }
 
-TEST(RawJwt, FromStringNegativeExpirationAreInvalid) {
-  auto jwt_or = RawJwt::FromString(R"({"exp":-1})");
+TEST(RawJwt, FromJsonNegativeExpirationAreInvalid) {
+  auto jwt_or = RawJwt::FromJson(absl::nullopt, R"({"exp":-1})");
   EXPECT_FALSE(jwt_or.ok());
 }
 
-TEST(RawJwt, FromStringConvertsStringAudIntoListOfStrings) {
-  auto jwt_or = RawJwt::FromString(R"({"aud":"audience"})");
+TEST(RawJwt, FromJsonConvertsStringAudIntoListOfStrings) {
+  auto jwt_or = RawJwt::FromJson(absl::nullopt, R"({"aud":"audience"})");
   ASSERT_THAT(jwt_or.status(), IsOk());
   RawJwt jwt = jwt_or.ValueOrDie();
 
@@ -349,23 +365,23 @@
   EXPECT_THAT(jwt.GetAudiences(), IsOkAndHolds(expected));
 }
 
-TEST(RawJwt, FromStringWithBadRegisteredTypes) {
-  EXPECT_FALSE(RawJwt::FromString(R"({"iss":123})").ok());
-  EXPECT_FALSE(RawJwt::FromString(R"({"sub":123})").ok());
-  EXPECT_FALSE(RawJwt::FromString(R"({"aud":123})").ok());
-  EXPECT_FALSE(RawJwt::FromString(R"({"aud":[]})").ok());
-  EXPECT_FALSE(RawJwt::FromString(R"({"aud":["abc",123]})").ok());
-  EXPECT_FALSE(RawJwt::FromString(R"({"exp":"abc"})").ok());
-  EXPECT_FALSE(RawJwt::FromString(R"({"nbf":"abc"})").ok());
-  EXPECT_FALSE(RawJwt::FromString(R"({"iat":"abc"})").ok());
+TEST(RawJwt, FromJsonWithBadRegisteredTypes) {
+  EXPECT_FALSE(RawJwt::FromJson(absl::nullopt, R"({"iss":123})").ok());
+  EXPECT_FALSE(RawJwt::FromJson(absl::nullopt, R"({"sub":123})").ok());
+  EXPECT_FALSE(RawJwt::FromJson(absl::nullopt, R"({"aud":123})").ok());
+  EXPECT_FALSE(RawJwt::FromJson(absl::nullopt, R"({"aud":[]})").ok());
+  EXPECT_FALSE(RawJwt::FromJson(absl::nullopt, R"({"aud":["abc",123]})").ok());
+  EXPECT_FALSE(RawJwt::FromJson(absl::nullopt, R"({"exp":"abc"})").ok());
+  EXPECT_FALSE(RawJwt::FromJson(absl::nullopt, R"({"nbf":"abc"})").ok());
+  EXPECT_FALSE(RawJwt::FromJson(absl::nullopt, R"({"iat":"abc"})").ok());
 }
 
-TEST(RawJwt, ToString) {
+TEST(RawJwt, GetJsonPayload) {
   auto jwt_or = RawJwtBuilder().SetIssuer("issuer").Build();
   ASSERT_THAT(jwt_or.status(), IsOk());
   auto jwt = jwt_or.ValueOrDie();
 
-  ASSERT_THAT(jwt.ToString(), IsOkAndHolds(R"({"iss":"issuer"})"));
+  ASSERT_THAT(jwt.GetJsonPayload(), IsOkAndHolds(R"({"iss":"issuer"})"));
 }
 
 }  // namespace tink
diff --git a/cc/jwt/verified_jwt.cc b/cc/jwt/verified_jwt.cc
index 708230a..6db74d3 100644
--- a/cc/jwt/verified_jwt.cc
+++ b/cc/jwt/verified_jwt.cc
@@ -30,6 +30,12 @@
   raw_jwt_ = raw_jwt;
 }
 
+bool VerifiedJwt::HasTypeHeader() const { return raw_jwt_.HasTypeHeader(); }
+
+util::StatusOr<std::string> VerifiedJwt::GetTypeHeader() const {
+  return raw_jwt_.GetTypeHeader();
+}
+
 bool VerifiedJwt::HasIssuer() const {
   return raw_jwt_.HasIssuer();
 }
@@ -139,8 +145,8 @@
   return raw_jwt_.CustomClaimNames();
 }
 
-util::StatusOr<std::string> VerifiedJwt::ToString() {
-  return raw_jwt_.ToString();
+util::StatusOr<std::string> VerifiedJwt::GetJsonPayload() {
+  return raw_jwt_.GetJsonPayload();
 }
 
 }  // namespace tink
diff --git a/cc/jwt/verified_jwt.h b/cc/jwt/verified_jwt.h
index 67d187c..bf9c42d 100644
--- a/cc/jwt/verified_jwt.h
+++ b/cc/jwt/verified_jwt.h
@@ -48,6 +48,8 @@
   VerifiedJwt(const VerifiedJwt&) = default;
   VerifiedJwt& operator=(const VerifiedJwt&) = default;
 
+  bool HasTypeHeader() const;
+  util::StatusOr<std::string> GetTypeHeader() const;
   bool HasIssuer() const;
   util::StatusOr<std::string> GetIssuer() const;
   bool HasSubject() const;
@@ -76,7 +78,7 @@
   util::StatusOr<std::string> GetJsonArrayClaim(absl::string_view name) const;
   std::vector<std::string> CustomClaimNames() const;
 
-  util::StatusOr<std::string> ToString();
+  util::StatusOr<std::string> GetJsonPayload();
 
  private:
   VerifiedJwt();
diff --git a/cc/jwt/verified_jwt_test.cc b/cc/jwt/verified_jwt_test.cc
index 7bfe604..57e0e93 100644
--- a/cc/jwt/verified_jwt_test.cc
+++ b/cc/jwt/verified_jwt_test.cc
@@ -76,8 +76,9 @@
                                      validator_builder.Build());
 }
 
-TEST(VerifiedJwt, GetIssuerSubjectJwtIdOK) {
+TEST(VerifiedJwt, GetTypeIssuerSubjectJwtIdOK) {
   auto raw_jwt_or = RawJwtBuilder()
+                        .SetTypeHeader("typeHeader")
                         .SetIssuer("issuer")
                         .SetSubject("subject")
                         .SetJwtId("jwt_id")
@@ -87,6 +88,8 @@
   ASSERT_THAT(verified_jwt_or.status(), IsOk());
   VerifiedJwt jwt = verified_jwt_or.ValueOrDie();
 
+  EXPECT_TRUE(jwt.HasTypeHeader());
+  EXPECT_THAT(jwt.GetTypeHeader(), IsOkAndHolds("typeHeader"));
   EXPECT_TRUE(jwt.HasIssuer());
   EXPECT_THAT(jwt.GetIssuer(), IsOkAndHolds("issuer"));
   EXPECT_TRUE(jwt.HasSubject());
@@ -255,6 +258,7 @@
   ASSERT_THAT(verified_jwt_or.status(), IsOk());
   VerifiedJwt jwt = verified_jwt_or.ValueOrDie();
 
+  EXPECT_FALSE(jwt.HasTypeHeader());
   EXPECT_FALSE(jwt.HasIssuer());
   EXPECT_FALSE(jwt.HasSubject());
   EXPECT_FALSE(jwt.HasAudiences());
@@ -277,6 +281,7 @@
   ASSERT_THAT(verified_jwt_or.status(), IsOk());
   VerifiedJwt jwt = verified_jwt_or.ValueOrDie();
 
+  EXPECT_FALSE(jwt.GetTypeHeader().ok());
   EXPECT_FALSE(jwt.GetIssuer().ok());
   EXPECT_FALSE(jwt.GetSubject().ok());
   EXPECT_FALSE(jwt.GetAudiences().ok());
@@ -292,14 +297,14 @@
   EXPECT_FALSE(jwt.GetJsonArrayClaim("array_claim").ok());
 }
 
-TEST(VerifiedJwt, ToString) {
+TEST(VerifiedJwt, GetJsonPayload) {
   auto raw_jwt_or = RawJwtBuilder().SetIssuer("issuer").Build();
   ASSERT_THAT(raw_jwt_or.status(), IsOk());
   auto verified_jwt_or = CreateVerifiedJwt(raw_jwt_or.ValueOrDie());
   ASSERT_THAT(verified_jwt_or.status(), IsOk());
   VerifiedJwt jwt = verified_jwt_or.ValueOrDie();
 
-  EXPECT_THAT(jwt.ToString(), IsOkAndHolds(R"({"iss":"issuer"})"));
+  EXPECT_THAT(jwt.GetJsonPayload(), IsOkAndHolds(R"({"iss":"issuer"})"));
 }
 
 TEST(VerifiedJwt, MoveMakesCopy) {