# Copyright 2019 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module defines KeysetHandle."""

from __future__ import absolute_import
from __future__ import division
from __future__ import google_type_annotations
from __future__ import print_function

import random

from typing import Type, TypeVar

from google.protobuf import message
from tink.proto import tink_pb2
from tink.python.aead import aead
from tink.python.core import keyset_reader as reader
from tink.python.core import keyset_writer as writer
from tink.python.core import primitive_set
from tink.python.core import registry
from tink.python.core import tink_error

P = TypeVar('P')

MAX_INT32 = 2147483647  # = 2^31 - 1


class KeysetHandle(object):
  """A KeysetHandle provides abstracted access to Keyset.

  KeysetHandle limits the exposure of actual protocol buffers that hold
  sensitive key material. This class allows reading and writing encrypted
  keysets.
  """

  def __new__(cls):
    raise tink_error.TinkError(
        ('KeysetHandle cannot be instantiated directly.'))

  def __init__(self, keyset: tink_pb2.Keyset):
    self._keyset = keyset

  @classmethod
  def generate_new(cls, key_template: tink_pb2.KeyTemplate) -> 'KeysetHandle':
    """Return a new KeysetHandle.

    It contains a single fresh key generated according to key_template.

    Args:
      key_template: A tink_pb2.KeyTemplate object.

    Returns:
      A new KeysetHandle.
    """
    keyset = tink_pb2.Keyset()
    key_data = registry.Registry.new_key_data(key_template)
    key_id = _generate_unused_key_id(keyset)
    key = keyset.key.add()
    key.key_data.CopyFrom(key_data)
    key.status = tink_pb2.ENABLED
    key.key_id = key_id
    key.output_prefix_type = key_template.output_prefix_type
    keyset.primary_key_id = key_id
    return cls._create(keyset)

  @classmethod
  def read(cls, keyset_reader: reader.KeysetReader,
           master_key_aead: aead.Aead) -> 'KeysetHandle':
    """Tries to create a KeysetHandle from an encrypted keyset."""
    encrypted_keyset = keyset_reader.read_encrypted()
    _assert_enough_encrypted_key_material(encrypted_keyset)
    return cls._create(_decrypt(encrypted_keyset, master_key_aead))

  @classmethod
  def _create(cls, keyset: tink_pb2.Keyset):
    o = object.__new__(cls)
    o.__init__(keyset)
    return o

  def keyset_info(self) -> tink_pb2.KeysetInfo:
    """Returns the KeysetInfo that doesn't contain actual key material."""
    return _keyset_info(self._keyset)

  def write(self, keyset_writer: writer.KeysetWriter,
            master_key_primitive: aead.Aead) -> None:
    """Serializes, encrypts with master_key_primitive and writes the keyset."""
    encrypted_keyset = _encrypt(self._keyset, master_key_primitive)
    keyset_writer.write_encrypted(encrypted_keyset)

  def public_keyset_handle(self) -> 'KeysetHandle':
    """Returns a new KeysetHandle for the corresponding public keys."""
    public_keyset = tink_pb2.Keyset()
    for key in self._keyset.key:
      public_key = public_keyset.key.add()
      public_key.CopyFrom(key)
      public_key.key_data.CopyFrom(
          registry.Registry.public_key_data(key.key_data))
      _validate_key(public_key)
    public_keyset.primary_key_id = self._keyset.primary_key_id
    return self._create(public_keyset)

  def primitive(self, primitive_class: Type[P]) -> P:
    """Returns a wrapped primitive from this KeysetHandle.

    Uses the KeyManager and the PrimitiveWrapper objects in the global
    registry.Registry
    to create the primitive. This function is the most common way of creating a
    primitive.

    Args:
      primitive_class: The class of the primitive.

    Returns:
      The primitive.
    Raises:
      tink.TinkError if creation of the primitive fails, for example if
      primitive_class cannot be used with this KeysetHandle.
    """
    _validate_keyset(self._keyset)
    pset = primitive_set.PrimitiveSet(primitive_class)
    for key in self._keyset.key:
      if key.status == tink_pb2.ENABLED:
        primitive = registry.Registry.primitive(key.key_data, primitive_class)
        entry = pset.add_primitive(primitive, key)
        if key.key_id == self._keyset.primary_key_id:
          pset.set_primary(entry)
    return registry.Registry.wrap(pset)


def _keyset_info(keyset: tink_pb2.Keyset) -> tink_pb2.KeysetInfo:
  keyset_info = tink_pb2.KeysetInfo(primary_key_id=keyset.primary_key_id)
  for key in keyset.key:
    key_info = keyset_info.key_info.add()
    key_info.type_url = key.key_data.type_url
    key_info.status = key.status
    key_info.output_prefix_type = key.output_prefix_type
    key_info.key_id = key.key_id
  return keyset_info


def _encrypt(keyset: tink_pb2.Keyset,
             master_key_primitive: aead.Aead) -> tink_pb2.EncryptedKeyset:
  """Encrypts a Keyset and returns an EncryptedKeyset."""
  encrypted_keyset = master_key_primitive.encrypt(keyset.SerializeToString(),
                                                  b'')
  # Check if we can decrypt, to detect errors
  try:
    keyset2 = tink_pb2.Keyset.FromString(
        master_key_primitive.decrypt(encrypted_keyset, b''))
    if keyset != keyset2:
      raise tink_error.TinkError('cannot encrypt keyset: %s != %s' %
                                 (keyset, keyset2))
  except message.DecodeError:
    raise tink_error.TinkError('invalid keyset, corrupted key material')
  return tink_pb2.EncryptedKeyset(
      encrypted_keyset=encrypted_keyset, keyset_info=_keyset_info(keyset))


def _decrypt(encrypted_keyset: tink_pb2.EncryptedKeyset,
             master_key_aead: aead.Aead) -> tink_pb2.Keyset:
  """Decrypts an EncryptedKeyset and returns a Keyset."""
  try:
    keyset = tink_pb2.Keyset.FromString(
        master_key_aead.decrypt(encrypted_keyset.encrypted_keyset, b''))
    # Check emptiness here too, in case the encrypted keys unwrapped to nothing?
    _assert_enough_key_material(keyset)
    return keyset
  except message.DecodeError:
    raise tink_error.TinkError('invalid keyset, corrupted key material')


def _validate_keyset(keyset: tink_pb2.Keyset):
  """Raises tink_error.TinkError if keyset is not valid."""
  for key in keyset.key:
    if key.status != tink_pb2.DESTROYED:
      _validate_key(key)
  num_non_destroyed_keys = sum(
      1 for key in keyset.key if key.status != tink_pb2.DESTROYED)
  num_non_public_key_material = sum(
      1 for key in keyset.key
      if key.key_data.key_material_type != tink_pb2.KeyData.ASYMMETRIC_PUBLIC)
  num_primary_keys = sum(
      1 for key in keyset.key
      if key.status == tink_pb2.ENABLED and key.key_id == keyset.primary_key_id)
  if num_non_destroyed_keys == 0:
    raise tink_error.TinkError('empty keyset')
  if num_primary_keys > 1:
    raise tink_error.TinkError('keyset contains multiple primary keys')
  if num_primary_keys == 0 and num_non_public_key_material > 0:
    raise tink_error.TinkError('keyset does not contain a valid primary key')


def _validate_key(key: tink_pb2.Keyset.Key):
  """Raises tink_error.TinkError if key is not valid."""
  if not key.HasField('key_data'):
    raise tink_error.TinkError('key {} has no key data'.format(key.key_id))
  if key.output_prefix_type == tink_pb2.UNKNOWN_PREFIX:
    raise tink_error.TinkError('key {} has unknown prefix'.format(key.key_id))
  if key.status == tink_pb2.UNKNOWN_STATUS:
    raise tink_error.TinkError('key {} has unknown status'.format(key.key_id))


def _assert_no_secret_key_material(keyset: tink_pb2.Keyset):
  for key in keyset.key:
    if (key.key_data.key_material_type == tink_pb2.KeyData.UNKNOWN_KEYMATERIAL
        or key.key_data.key_material_type == tink_pb2.KeyData.SYMMETRIC or
        key.key_data.key_material_type == tink_pb2.KeyData.ASYMMETRIC_PRIVATE):
      raise tink_error.TinkError('keyset contains secret key material')


def _assert_enough_key_material(keyset: tink_pb2.Keyset):
  if not keyset or not keyset.key:
    raise tink_error.TinkError('empty keyset')


def _assert_enough_encrypted_key_material(
    encrypted_keyset: tink_pb2.EncryptedKeyset):
  if not encrypted_keyset or not encrypted_keyset.encrypted_keyset:
    raise tink_error.TinkError('empty keyset')


def _generate_unused_key_id(keyset: tink_pb2.Keyset) -> int:
  while True:
    key_id = random.randint(1, MAX_INT32)
    if key_id not in {key.key_id for key in keyset.key}:
      return key_id
