//===--- DerivedConformanceVectorProtocol.cpp -----------------------------===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
//
// This file implements explicit derivation of the VectorProtocol protocol for
// struct types.
//
//===----------------------------------------------------------------------===//

#include "CodeSynthesis.h"
#include "TypeChecker.h"
#include "swift/AST/Decl.h"
#include "swift/AST/Expr.h"
#include "swift/AST/GenericSignature.h"
#include "swift/AST/Module.h"
#include "swift/AST/ParameterList.h"
#include "swift/AST/Pattern.h"
#include "swift/AST/ProtocolConformance.h"
#include "swift/AST/Stmt.h"
#include "swift/AST/Types.h"
#include "DerivedConformances.h"

using namespace swift;

// Return the protocol requirement with the specified name.
// TODO: Move function to shared place for use with other derived conformances.
static ValueDecl *getProtocolRequirement(ProtocolDecl *proto, Identifier name) {
  auto lookup = proto->lookupDirect(name);
  // Erase declarations that are not protocol requirements.
  // This is important for removing default implementations of the same name.
  llvm::erase_if(lookup, [](ValueDecl *v) {
    return !isa<ProtocolDecl>(v->getDeclContext()) ||
           !v->isProtocolRequirement();
  });
  assert(lookup.size() == 1 && "Ambiguous protocol requirement");
  return lookup.front();
}

// Return true if given nominal type has a `let` stored with an initial value.
// TODO: Move function to shared place for use with other derived conformances.
static bool hasLetStoredPropertyWithInitialValue(NominalTypeDecl *nominal) {
  return llvm::any_of(nominal->getStoredProperties(), [&](VarDecl *v) {
    return v->isLet() && v->hasInitialValue();
  });
}

// Return the `VectorSpaceScalar` associated type for the given `ValueDecl` if
// it conforms to `VectorProtocol` in the given context. Otherwise, return
// `nullptr`.
static Type getVectorProtocolVectorSpaceScalarAssocType(
    VarDecl *varDecl, DeclContext *DC) {
  auto &C = varDecl->getASTContext();
  auto *vectorProto = C.getProtocol(KnownProtocolKind::VectorProtocol);
  if (!varDecl->hasInterfaceType())
    C.getLazyResolver()->resolveDeclSignature(varDecl);
  if (!varDecl->hasInterfaceType())
    return nullptr;
  auto varType = DC->mapTypeIntoContext(varDecl->getValueInterfaceType());
  auto conf = TypeChecker::conformsToProtocol(varType, vectorProto, DC, None);
  if (!conf)
    return nullptr;
  return conf->getTypeWitnessByName(varType, C.Id_VectorSpaceScalar);
}

// Return the `VectorSpaceScalar` associated type for the given nominal type in
// the given context, or `nullptr` if `VectorSpaceScalar` cannot be derived.
static Type deriveVectorProtocol_VectorSpaceScalar(NominalTypeDecl *nominal,
                                                   DeclContext *DC) {
  auto &C = DC->getASTContext();
  // Nominal type must be a struct. (Zero stored properties is okay.)
  if (!isa<StructDecl>(nominal))
    return nullptr;
  // If all stored properties conform to `VectorProtocol` and have the same
  // `VectorSpaceScalar` associated type, return that `VectorSpaceScalar`
  // associated type. Otherwise, the `VectorSpaceScalar` type cannot be derived.
  Type sameScalarType;
  for (auto member : nominal->getStoredProperties()) {
    if (!member->hasInterfaceType())
      C.getLazyResolver()->resolveDeclSignature(member);
    if (!member->hasInterfaceType())
      return nullptr;
    auto scalarType = getVectorProtocolVectorSpaceScalarAssocType(member, DC);
    // If stored property does not conform to `VectorProtocol`, return nullptr.
    if (!scalarType)
      return nullptr;
    // If same `VectorSpaceScalar` type has not been set, set it for the first
    // time.
    if (!sameScalarType) {
      sameScalarType = scalarType;
      continue;
    }
    // If stored property `VectorSpaceScalar` types do not match, return
    // nullptr.
    if (!scalarType->isEqual(sameScalarType))
      return nullptr;
  }
  return sameScalarType;
}

bool DerivedConformance::canDeriveVectorProtocol(NominalTypeDecl *nominal,
                                                 DeclContext *DC) {
  // Must not have any `let` stored properties with an initial value.
  // - This restriction may be lifted later with support for "true" memberwise
  //   initializers that initialize all stored properties, including initial
  //   value information.
  if (hasLetStoredPropertyWithInitialValue(nominal))
    return false;
  // Must be able to derive `VectorSpaceScalar` associated type.
  return bool(deriveVectorProtocol_VectorSpaceScalar(nominal, DC));
}

// Synthesize body for a `VectorProtocol` method requirement.
static std::pair<BraceStmt *, bool>
deriveBodyVectorProtocol_method(AbstractFunctionDecl *funcDecl,
                                Identifier methodName,
                                Identifier methodParamLabel) {
  auto *parentDC = funcDecl->getParent();
  auto *nominal = parentDC->getSelfNominalTypeDecl();
  auto &C = nominal->getASTContext();

  // Create memberwise initializer: `Nominal.init(...)`.
  auto *memberwiseInitDecl = nominal->getEffectiveMemberwiseInitializer();
  assert(memberwiseInitDecl && "Memberwise initializer must exist");
  auto *initDRE =
      new (C) DeclRefExpr(memberwiseInitDecl, DeclNameLoc(), /*Implicit*/ true);
  initDRE->setFunctionRefKind(FunctionRefKind::SingleApply);
  auto *nominalTypeExpr = TypeExpr::createForDecl(SourceLoc(), nominal,
                                                  funcDecl, /*Implicit*/ true);
  auto *initExpr = new (C) ConstructorRefCallExpr(initDRE, nominalTypeExpr);

  // Get method protocol requirement.
  auto *vectorProto = C.getProtocol(KnownProtocolKind::VectorProtocol);
  auto *methodReq = getProtocolRequirement(vectorProto, methodName);

  // Get references to `self` and parameter declarations.
  auto *selfDecl = funcDecl->getImplicitSelfDecl();
  auto *selfDRE =
      new (C) DeclRefExpr(selfDecl, DeclNameLoc(), /*Implicit*/ true);
  auto *paramDecl = funcDecl->getParameters()->get(0);
  auto *paramDRE =
      new (C) DeclRefExpr(paramDecl, DeclNameLoc(), /*Implicit*/ true);

  // Create call expression applying a member method to the parameter.
  // Format: `<member>.method(<parameter>)`.
  // Example: `x.scaled(by: scalar)`.
  auto createMemberMethodCallExpr = [&](VarDecl *member) -> Expr * {
    auto *module = nominal->getModuleContext();
    auto memberType =
        parentDC->mapTypeIntoContext(member->getValueInterfaceType());
    auto confRef = module->lookupConformance(memberType, vectorProto);
    assert(confRef && "Member does not conform to `VectorNumeric`");

    // Get member type's method, e.g. `Member.scaled(by:)`.
    // Use protocol requirement declaration for the method by default: this
    // will be dynamically dispatched.
    ValueDecl *memberMethodDecl = methodReq;
    // If conformance reference is concrete, then use concrete witness
    // declaration for the operator.
    if (confRef->isConcrete()) {
      if (auto *concreteMemberMethodDecl =
              confRef->getConcrete()->getWitnessDecl(methodReq))
        memberMethodDecl = concreteMemberMethodDecl;
      assert(memberMethodDecl);
    }
    assert(memberMethodDecl && "Member method declaration must exist");
    auto memberMethodDRE =
        new (C) DeclRefExpr(memberMethodDecl, DeclNameLoc(), /*Implicit*/ true);
    memberMethodDRE->setFunctionRefKind(FunctionRefKind::SingleApply);

    // Create reference to member method: `x.scaled(by:)`.
    auto memberExpr =
        new (C) MemberRefExpr(selfDRE, SourceLoc(), member, DeclNameLoc(),
                              /*Implicit*/ true);
    auto memberMethodExpr =
        new (C) DotSyntaxCallExpr(memberMethodDRE, SourceLoc(), memberExpr);

    // Create expression: `x.scaled(by: scalar)`.
    return CallExpr::createImplicit(C, memberMethodExpr, {paramDRE},
                                    {methodParamLabel});
  };

  // Create array of member method call expressions.
  llvm::SmallVector<Expr *, 2> memberMethodCallExprs;
  llvm::SmallVector<Identifier, 2> memberNames;
  for (auto *member : nominal->getStoredProperties()) {
    memberMethodCallExprs.push_back(createMemberMethodCallExpr(member));
    memberNames.push_back(member->getName());
  }
  // Call memberwise initializer with member method call expressions.
  auto *callExpr =
      CallExpr::createImplicit(C, initExpr, memberMethodCallExprs, memberNames);
  ASTNode returnStmt = new (C) ReturnStmt(SourceLoc(), callExpr, true);
  auto *braceStmt =
      BraceStmt::create(C, SourceLoc(), returnStmt, SourceLoc(), true);
  return std::pair<BraceStmt *, bool>(braceStmt, false);
}

// Synthesize function declaration for a `VectorProtocol` method requirement.
static ValueDecl *deriveVectorProtocol_method(
    DerivedConformance &derived, Identifier methodBaseName,
    Identifier argumentLabel, Identifier parameterName, Type parameterType,
    Type returnType, AbstractFunctionDecl::BodySynthesizer bodySynthesizer) {
  auto nominal = derived.Nominal;
  auto &TC = derived.TC;
  auto &C = derived.TC.Context;
  auto parentDC = derived.getConformanceContext();

  auto *param =
      new (C) ParamDecl(ParamDecl::Specifier::Default, SourceLoc(), SourceLoc(),
                        argumentLabel, SourceLoc(), parameterName, parentDC);
  param->setInterfaceType(parameterType);
  ParameterList *params = ParameterList::create(C, {param});

  DeclName declName(C, methodBaseName, params);
  auto funcDecl = FuncDecl::create(C, SourceLoc(), StaticSpellingKind::None,
                                   SourceLoc(), declName, SourceLoc(),
                                   /*Throws*/ false, SourceLoc(),
                                   /*GenericParams*/ nullptr, params,
                                   TypeLoc::withoutLoc(returnType), parentDC);
  funcDecl->setImplicit();
  funcDecl->setBodySynthesizer(bodySynthesizer.Fn, bodySynthesizer.Context);

  funcDecl->setGenericSignature(parentDC->getGenericSignatureOfContext());
  funcDecl->computeType();
  funcDecl->copyFormalAccessFrom(nominal, /*sourceIsParentContext*/ true);

  derived.addMembersToConformanceContext({funcDecl});
  C.addSynthesizedDecl(funcDecl);

  // Returned nominal type must define a memberwise initializer.
  // Add memberwise initializer if necessary.
  if (!nominal->getEffectiveMemberwiseInitializer()) {
    // The implicit memberwise constructor must be explicitly created so that
    // it can called in `VectorProtocol` methods. Normally, the memberwise
    // constructor is synthesized during SILGen, which is too late.
    auto *initDecl = createMemberwiseImplicitConstructor(TC, nominal);
    nominal->addMember(initDecl);
    C.addSynthesizedDecl(initDecl);
  }

  return funcDecl;
}

/// Synthesize a method declaration that has the following signture:
///   func {methodBaseName}(
///     {argumentLabel} {parameterName}: VectorSpaceScalar
///   ) -> Self
static ValueDecl *deriveVectorProtocol_unaryMethodOnScalar(
    DerivedConformance &derived, Identifier methodBaseName,
    Identifier argumentLabel, Identifier parameterName) {
  auto &C = derived.TC.Context;
  auto *nominal = derived.Nominal;
  auto *parentDC = derived.getConformanceContext();

  auto selfInterfaceType = parentDC->getDeclaredInterfaceType();
  auto scalarType = deriveVectorProtocol_VectorSpaceScalar(nominal, parentDC)
      ->mapTypeOutOfContext();

  auto bodySynthesizer = [](AbstractFunctionDecl *funcDecl,
                            void *ctx) -> std::pair<BraceStmt *, bool> {
    auto methodNameAndLabel = reinterpret_cast<Identifier *>(ctx);
    return deriveBodyVectorProtocol_method(
        funcDecl, methodNameAndLabel[0], methodNameAndLabel[1]);
  };
  Identifier baseNameAndLabel[2] = {methodBaseName, argumentLabel};
  return deriveVectorProtocol_method(
      derived, methodBaseName, argumentLabel, parameterName, scalarType,
      selfInterfaceType,
      {bodySynthesizer, C.AllocateCopy(baseNameAndLabel).data()});
}

ValueDecl *DerivedConformance::deriveVectorProtocol(ValueDecl *requirement) {
  // Diagnose conformances in disallowed contexts.
  if (checkAndDiagnoseDisallowedContext(requirement))
    return nullptr;
  auto &C = requirement->getASTContext();
  if (requirement->getBaseName() == TC.Context.Id_scaled)
    return deriveVectorProtocol_unaryMethodOnScalar(
        *this, C.Id_scaled, C.Id_by, C.Id_scale);
  if (requirement->getBaseName() == TC.Context.Id_adding)
    return deriveVectorProtocol_unaryMethodOnScalar(
        *this, C.Id_adding, Identifier(), C.Id_x);
  if (requirement->getBaseName() == TC.Context.Id_subtracting)
    return deriveVectorProtocol_unaryMethodOnScalar(
        *this, C.Id_subtracting, Identifier(), C.Id_x);
  TC.diagnose(requirement->getLoc(), diag::broken_vector_protocol_requirement);
  return nullptr;
}

Type DerivedConformance::deriveVectorProtocol(AssociatedTypeDecl *requirement) {
  // Diagnose conformances in disallowed contexts.
  if (checkAndDiagnoseDisallowedContext(requirement))
    return nullptr;
  if (requirement->getBaseName() == TC.Context.Id_VectorSpaceScalar)
    return deriveVectorProtocol_VectorSpaceScalar(
        Nominal, getConformanceContext());
  TC.diagnose(requirement->getLoc(), diag::broken_vector_protocol_requirement);
  return nullptr;
}
