Merge pull request #7842 from slavapestov/sil-type-subst-works-so-use-it

SIL type substitution cleanup
diff --git a/include/swift/SIL/SILType.h b/include/swift/SIL/SILType.h
index 5b02ab9..cb46848 100644
--- a/include/swift/SIL/SILType.h
+++ b/include/swift/SIL/SILType.h
@@ -421,9 +421,14 @@
   SILType substGenericArgs(SILModule &M,
                            SubstitutionList Subs) const;
 
+  /// If the original type is generic, pass its generic signature as
+  /// genericSig; if omitted, the current TypeLowering generic context is used.
+  /// If the replacement types are generic, you must push a generic context
+  /// first.
   SILType subst(SILModule &silModule,
                 TypeSubstitutionFn subs,
-                LookupConformanceFn conformances) const;
+                LookupConformanceFn conformances,
+                CanGenericSignature genericSig=CanGenericSignature()) const;
 
   SILType subst(SILModule &silModule, const SubstitutionMap &subs) const;
 
diff --git a/include/swift/SIL/TypeLowering.h b/include/swift/SIL/TypeLowering.h
index e97e6cf..dd42035 100644
--- a/include/swift/SIL/TypeLowering.h
+++ b/include/swift/SIL/TypeLowering.h
@@ -676,13 +676,6 @@
   SILConstantInfo getConstantOverrideInfo(SILDeclRef constant,
                                           SILDeclRef base);
 
-  /// Substitute the given function type so that it implements the
-  /// given substituted type.
-  CanSILFunctionType substFunctionType(CanSILFunctionType origFnType,
-                                 CanAnyFunctionType origLoweredType,
-                                 CanAnyFunctionType substLoweredInterfaceType,
-                         const Optional<ForeignErrorConvention> &foreignError);
-  
   /// Get the empty tuple type as a SILType.
   SILType getEmptyTupleType() {
     return getLoweredType(TupleType::getEmpty(Context));
diff --git a/lib/SIL/SILFunction.cpp b/lib/SIL/SILFunction.cpp
index d1b799d..e6573ea 100644
--- a/lib/SIL/SILFunction.cpp
+++ b/lib/SIL/SILFunction.cpp
@@ -163,109 +163,24 @@
       getGenericEnvironment(), type);
 }
 
-namespace {
-template<typename SubstFn>
-struct SubstDependentSILType
-  : CanTypeVisitor<SubstDependentSILType<SubstFn>, CanType>
-{
-  SILModule &M;
-  SubstFn Subst;
-  
-  SubstDependentSILType(SILModule &M, SubstFn Subst)
-    : M(M), Subst(std::move(Subst))
-  {}
-  
-  using super = CanTypeVisitor<SubstDependentSILType<SubstFn>, CanType>;
-  using super::visit;
-  
-  CanType visitDependentMemberType(CanDependentMemberType t) {
-    // If a dependent member type appears in lowered position, we need to lower
-    // its context substitution against the associated type's abstraction
-    // pattern.
-    CanType astTy = Subst(t);
-    auto origTy = AbstractionPattern::getOpaque();
-    
-    return M.Types.getLoweredType(origTy, astTy)
-      .getSwiftRValueType();
-  }
-  
-  CanType visitTupleType(CanTupleType t) {
-    // Dependent members can appear in lowered position inside tuples.
-    
-    SmallVector<TupleTypeElt, 4> elements;
-    
-    for (auto &elt : t->getElements())
-      elements.push_back(elt.getWithType(visit(CanType(elt.getType()))));
-    
-    return TupleType::get(elements, t->getASTContext())
-      ->getCanonicalType();
-  }
-  
-  CanType visitSILFunctionType(CanSILFunctionType t) {
-    // Dependent members can appear in lowered position inside SIL functions.
-    
-    SmallVector<SILParameterInfo, 4> params;
-    for (auto &param : t->getParameters())
-      params.push_back(param.map([&](CanType pt) -> CanType {
-        return visit(pt);
-      }));
-
-    SmallVector<SILResultInfo, 4> results;
-    for (auto &result : t->getResults())
-      results.push_back(result.map([&](CanType pt) -> CanType {
-        return visit(pt);
-      }));
-    
-    Optional<SILResultInfo> errorResult;
-    if (t->hasErrorResult()) {
-      errorResult = t->getErrorResult().map([&](CanType elt) -> CanType {
-          return visit(elt);
-      });
-    }
-    
-    return SILFunctionType::get(t->getGenericSignature(),
-                                t->getExtInfo(),
-                                t->getCalleeConvention(),
-                                params, results, errorResult,
-                                t->getASTContext());
-  }
-  
-  CanType visitType(CanType t) {
-    // Other types get substituted into context normally.
-    return Subst(t);
-  }
-};
-
-template<typename SubstFn>
-SILType doSubstDependentSILType(SILModule &M,
-                                SubstFn Subst,
-                                SILType t) {
-  CanType result = SubstDependentSILType<SubstFn>(M, std::move(Subst))
-    .visit(t.getSwiftRValueType());
-  return SILType::getPrimitiveType(result, t.getCategory());
-}
-  
-} // end anonymous namespace
-
 SILType SILFunction::mapTypeIntoContext(SILType type) const {
-  return doSubstDependentSILType(getModule(),
-    [&](CanType t) { return mapTypeIntoContext(t)->getCanonicalType(); },
-    type);
+  if (auto *genericEnv = getGenericEnvironment())
+    return genericEnv->mapTypeIntoContext(getModule(), type);
+  return type;
 }
 
 SILType GenericEnvironment::mapTypeIntoContext(SILModule &M,
                                                SILType type) const {
-  return doSubstDependentSILType(M,
-    [&](CanType t) {
-      return mapTypeIntoContext(t)->getCanonicalType();
-    },
-    type);
+  auto genericSig = getGenericSignature()->getCanonicalSignature();
+  return type.subst(M,
+                    QueryInterfaceTypeSubstitutions(this),
+                    LookUpConformanceInSignature(*genericSig),
+                    genericSig);
 }
 
 Type SILFunction::mapTypeOutOfContext(Type type) const {
   return GenericEnvironment::mapTypeOutOfContext(
-      getGenericEnvironment(),
-      type);
+      getGenericEnvironment(), type);
 }
 
 bool SILFunction::isNoReturnFunction() const {
diff --git a/lib/SIL/SILFunctionType.cpp b/lib/SIL/SILFunctionType.cpp
index 37b5a2d..3246a39 100644
--- a/lib/SIL/SILFunctionType.cpp
+++ b/lib/SIL/SILFunctionType.cpp
@@ -1760,268 +1760,6 @@
   return result;
 }
 
-namespace {
-  class SILFunctionTypeSubstituter {
-    TypeConverter &TC;
-    CanSILFunctionType OrigFnType;
-    ArrayRef<SILParameterInfo> OrigParams;
-    ArrayRef<SILResultInfo> OrigResults;
-    unsigned NextOrigParamIndex = 0;
-    unsigned NextOrigResultIndex = 0;
-    SmallVector<SILParameterInfo, 8> SubstParams;
-    SmallVector<SILResultInfo, 8> SubstResults;
-    const Optional<ForeignErrorConvention> &ForeignError;
-
-  public:
-    SILFunctionTypeSubstituter(
-        TypeConverter &TC, CanSILFunctionType origFnType,
-        const Optional<ForeignErrorConvention> &foreignError)
-        : TC(TC), OrigFnType(origFnType),
-          OrigParams(origFnType->getParameters()),
-          OrigResults(origFnType->getResults()), ForeignError(foreignError) {}
-
-    ArrayRef<SILResultInfo> getSubstResults() const {
-      assert(NextOrigResultIndex == OrigResults.size() &&
-             "didn't claim all results?!");
-      return SubstResults;
-    }
-
-    void substResults(AbstractionPattern origType, CanType substType);
-
-    ArrayRef<SILParameterInfo> getSubstParams() const {
-      assert(NextOrigParamIndex == OrigParams.size() &&
-             "didn't claim all parameters?!");
-      return SubstParams;
-    }
-
-    void substInputs(AbstractionPattern origType, CanType substType) {
-      maybeSkipForeignErrorParameter();
-
-      // Decompose tuples.
-      if (origType.isTuple()) {
-        auto substTuple = cast<TupleType>(substType);
-        assert(origType.getNumTupleElements() == substTuple->getNumElements());
-        for (auto i : indices(substTuple.getElementTypes())) {
-          substInputs(origType.getTupleElementType(i),
-                      substTuple.getElementType(i));
-        }
-        return;
-      }
-
-      // Every other type corresponds to a single parameter in the
-      // original signature, since we're dealing with like-uncurried
-      // types and thus don't have to worry about expanding archetypes
-      // to unmaterializable parameter clauses in result function types.
-      auto origParam = claimNextOrigParam();
-
-      // If the type hasn't changed and doesn't rely on context, just use the
-      // original parameter.
-      if (origType.isExactType(substType) &&
-          !origParam.getType()->hasTypeParameter()) {
-        SubstParams.push_back(origParam);
-        return;
-      }
-
-      // Otherwise, lower the substituted type using the abstraction
-      // patterns of the original.
-      auto &substTL = TC.getTypeLowering(origType, substType);
-      auto substConvention = getSubstConvention(origParam.getConvention(),
-                                                substTL.isTrivial());
-      assert(isIndirectFormalParameter(substConvention)
-             || !substTL.isAddressOnly());
-      addSubstParam(substTL.getLoweredType().getSwiftRValueType(),
-                    substConvention);
-    }
-
-  private:
-    void decomposeResult(AbstractionPattern origType, CanType substType);
-
-    SILParameterInfo claimNextOrigParam() {
-      maybeSkipForeignErrorParameter();
-      return OrigParams[NextOrigParamIndex++];
-    }
-
-    SILResultInfo claimNextOrigResult() {
-      return OrigResults[NextOrigResultIndex++];
-    }
-
-    void maybeSkipForeignErrorParameter() {
-      if (!ForeignError ||
-          NextOrigParamIndex != ForeignError->getErrorParameterIndex())
-        return;
-      SubstParams.push_back(OrigParams[NextOrigParamIndex++]);
-    }
-
-    void addSubstParam(CanType type, ParameterConvention conv) {
-      SubstParams.push_back(SILParameterInfo(type, conv));
-    }
-
-    ParameterConvention getSubstConvention(ParameterConvention orig,
-                                           bool isTrivial) {
-      // We use the original convention, except that we have an
-      // invariant that direct trivial parameters are always unowned.
-      switch (orig) {
-      case ParameterConvention::Direct_Owned:
-      case ParameterConvention::Direct_Guaranteed:
-        if (isTrivial) return ParameterConvention::Direct_Unowned;
-        LLVM_FALLTHROUGH;
-      case ParameterConvention::Direct_Unowned:
-      case ParameterConvention::Indirect_Inout:
-      case ParameterConvention::Indirect_InoutAliasable:
-      case ParameterConvention::Indirect_In:
-      case ParameterConvention::Indirect_In_Guaranteed:
-        return orig;
-      }
-      llvm_unreachable("bad parameter convention");
-    }
-
-    ResultConvention getSubstConvention(ResultConvention orig,
-                                        bool isTrivial) {
-      // We use the original convention, except that we have an
-      // invariant that direct trivial results are always unowned.
-      switch (orig) {
-      case ResultConvention::Owned:
-      case ResultConvention::Autoreleased:
-        if (isTrivial) return ResultConvention::Unowned;
-        LLVM_FALLTHROUGH;
-      case ResultConvention::Indirect:
-      case ResultConvention::Unowned:
-      case ResultConvention::UnownedInnerPointer:
-        return orig;
-      }
-      llvm_unreachable("bad parameter convention");
-    }
-  };
-} // end anonymous namespace
-
-void SILFunctionTypeSubstituter::substResults(AbstractionPattern origResultType,
-                                              CanType substResultType) {
-  // Fast path: if the results of the original type are not type-dependent,
-  // we can just copy them over.
-  auto allResults = OrigFnType->getResults();
-  if (std::find_if(allResults.begin(), allResults.end(),
-                   [&](SILResultInfo result) { 
-                     return result.getType()->hasTypeParameter();
-                   }) == allResults.end()) {
-    SubstResults.append(allResults.begin(), allResults.end());
-    return;
-  }
-
-  // Okay, we need to walk the types and re-lower.
-
-  // If we have a foreign-error convention that strips result
-  // optionality, we need to wrap both the original and
-  // substituted types in a level of optionality.
-  if (ForeignError && ForeignError->stripsResultOptionality()) {
-    origResultType =
-      AbstractionPattern::getOptional(origResultType, OTK_Optional);
-    substResultType =
-      OptionalType::get(substResultType)->getCanonicalType();
-  }
-
-  decomposeResult(origResultType, substResultType);
-}
-
-void
-SILFunctionTypeSubstituter::decomposeResult(AbstractionPattern origResultType,
-                                            CanType substResultType) {
-  // If the result is a tuple, we need to expand it.
-  if (origResultType.isTuple()) {
-    auto substResultTupleType = cast<TupleType>(substResultType);
-    for (auto eltIndex : indices(substResultTupleType.getElementTypes())) {
-      auto origEltType = origResultType.getTupleElementType(eltIndex);
-      auto substEltType = substResultTupleType.getElementType(eltIndex);
-      decomposeResult(origEltType, substEltType);
-    }
-    return;
-  }
-
-  // Okay, the result is a single value, which will either be an
-  // indirect result or not.
-
-  // Grab the next result.
-  SILResultInfo origResult = claimNextOrigResult();
-
-  // If substitution is trivial, fast path.
-  if (!origResult.getType()->hasTypeParameter()) {
-    SubstResults.push_back(origResult);
-    return;
-  }
-
-  // Lower the substituted result using the abstraction patterns
-  // of the original result.
-  auto &substResultTL = TC.getTypeLowering(origResultType, substResultType);
-  auto loweredResultTy = substResultTL.getLoweredType().getSwiftRValueType();
-
-  // Return the new type with the old convention.
-  SILResultInfo substResult(loweredResultTy,
-                            getSubstConvention(origResult.getConvention(),
-                                               substResultTL.isTrivial()));
-  SubstResults.push_back(substResult);
-}
-
-/// Apply a substitution to the given SILFunctionType so that it has
-/// the form of the normal SILFunctionType for the substituted type,
-/// except using the original conventions.
-///
-/// This is equivalent to
-///    getLoweredType(origLoweredType,
-///                   substLoweredType).castTo<SILFunctionType>()
-/// except that origFnType's conventions may not correspond to the
-/// standard conventions of the lowered type.
-CanSILFunctionType
-TypeConverter::substFunctionType(CanSILFunctionType origFnType,
-                                 CanAnyFunctionType origLoweredType,
-                                 CanAnyFunctionType substLoweredInterfaceType,
-                         const Optional<ForeignErrorConvention> &foreignError) {
-  // FIXME: is this inefficient now?
-  if (origLoweredType == substLoweredInterfaceType)
-    return origFnType;
-
-  // Use the generic parameters from the substituted type.
-  CanGenericSignature genericSig;
-  if (auto genSubstFn = dyn_cast<GenericFunctionType>(substLoweredInterfaceType))
-    genericSig = genSubstFn.getGenericSignature();
-
-  GenericContextScope scope(*this, genericSig);
-  SILFunctionTypeSubstituter substituter(*this, origFnType, foreignError);
-
-  AbstractionPattern origLoweredPattern(origLoweredType);
-
-  // Map the results.
-  substituter.substResults(origLoweredPattern.getFunctionResultType(),
-                           substLoweredInterfaceType.getResult());
-
-  // Map the error result.  Currently this is never dependent.
-  Optional<SILResultInfo> substErrorResult
-    = origFnType->getOptionalErrorResult();
-  assert(!substErrorResult ||
-         (!substErrorResult->getType()->hasTypeParameter() &&
-          !substErrorResult->getType()->hasArchetype()));
-
-  // Map the inputs.
-  substituter.substInputs(origLoweredPattern.getFunctionInputType(),
-                          substLoweredInterfaceType.getInput());
-
-  // Allow the substituted type to add thick-ness, but not remove it.
-  assert(!origFnType->getExtInfo().hasContext()
-           || substLoweredInterfaceType->getExtInfo().hasContext());
-  assert(substLoweredInterfaceType->getExtInfo().getSILRepresentation()
-           == substLoweredInterfaceType->getExtInfo().getSILRepresentation());
-
-  auto rep = substLoweredInterfaceType->getExtInfo().getSILRepresentation();
-  auto extInfo = origFnType->getExtInfo().withRepresentation(rep);
-
-  // FIXME: Map into archetype context.
-  return SILFunctionType::get(genericSig,
-                              extInfo,
-                              origFnType->getCalleeConvention(),
-                              substituter.getSubstParams(),
-                              substituter.getSubstResults(),
-                              substErrorResult,
-                              Context);
-}
-
 /// Returns the SILParameterInfo for the given declaration's `self` parameter.
 /// `constant` must refer to a method.
 SILParameterInfo TypeConverter::getConstantSelfParameter(SILDeclRef constant) {
@@ -2129,27 +1867,38 @@
     SILModule &TheSILModule;
     TypeSubstitutionFn Subst;
     LookupConformanceFn Conformances;
+    // The signature for the original type.
+    //
+    // Replacement types are lowered with respect to the current
+    // context signature.
+    CanGenericSignature Sig;
 
     ASTContext &getASTContext() { return TheSILModule.getASTContext(); }
 
   public:
     SILTypeSubstituter(SILModule &silModule,
                        TypeSubstitutionFn Subst,
-                       LookupConformanceFn Conformances)
+                       LookupConformanceFn Conformances,
+                       CanGenericSignature Sig)
       : TheSILModule(silModule),
         Subst(Subst),
-        Conformances(Conformances)
+        Conformances(Conformances),
+        Sig(Sig)
     {}
 
     // SIL type lowering only does special things to tuples and functions.
 
-    /// Functions need to preserve their abstraction structure.
-    CanSILFunctionType visitSILFunctionType(CanSILFunctionType origType,
-                                            bool dropGenerics = false)
-    {
-      GenericContextScope scope(TheSILModule.Types,
-                                origType->getGenericSignature());
+    // When a function appears inside of another type, we only perform
+    // substitutions if it does not have a generic signature.
+    CanSILFunctionType visitSILFunctionType(CanSILFunctionType origType) {
+      if (origType->getGenericSignature())
+        return origType;
 
+      return substSILFunctionType(origType);
+    }
+
+    // Entry point for use by SILType::substGenericArgs().
+    CanSILFunctionType substSILFunctionType(CanSILFunctionType origType) {
       SmallVector<SILResultInfo, 8> substResults;
       substResults.reserve(origType->getNumResults());
       for (auto origResult : origType->getResults()) {
@@ -2167,10 +1916,7 @@
         substParams.push_back(subst(origParam));
       }
 
-      auto genericSig
-        = (dropGenerics ? nullptr : origType->getGenericSignature());
-
-      return SILFunctionType::get(genericSig,
+      return SILFunctionType::get(nullptr,
                                   origType->getExtInfo(),
                                   origType->getCalleeConvention(),
                                   substParams, substResults,
@@ -2230,16 +1976,7 @@
     CanType visitType(CanType origType) {
       assert(!isa<AnyFunctionType>(origType));
       assert(!isa<LValueType>(origType) && !isa<InOutType>(origType));
-
-      CanGenericSignature genericSig =
-          TheSILModule.Types.getCurGenericContext();
-      AbstractionPattern abstraction(genericSig, origType);
-
-      assert(TheSILModule.Types.getLoweredType(abstraction, origType)
-               .getSwiftRValueType() == origType);
-
-      CanType substType =
-        origType.subst(Subst, Conformances, None)->getCanonicalType();
+      auto substType = origType.subst(Subst, Conformances)->getCanonicalType();
 
       // If the substitution didn't change anything, we know that the
       // original type was a lowered type, so we're good.
@@ -2247,6 +1984,7 @@
         return origType;
       }
 
+      AbstractionPattern abstraction(Sig, origType);
       return TheSILModule.Types.getLoweredType(abstraction, substType)
                .getSwiftRValueType();
     }
@@ -2255,8 +1993,15 @@
 
 SILType SILType::subst(SILModule &silModule,
                        TypeSubstitutionFn subs,
-                       LookupConformanceFn conformances) const {
-  SILTypeSubstituter STST(silModule, subs, conformances);
+                       LookupConformanceFn conformances,
+                       CanGenericSignature genericSig) const {
+  if (!hasArchetype() && !hasTypeParameter())
+    return *this;
+
+  if (!genericSig)
+    genericSig = silModule.Types.getCurGenericContext();
+  SILTypeSubstituter STST(silModule, subs, conformances,
+                          genericSig);
   return STST.subst(*this);
 }
 
@@ -2303,9 +2048,9 @@
                                   TypeSubstitutionFn subs,
                                   LookupConformanceFn conformances) {
   if (!isPolymorphic()) return CanSILFunctionType(this);
-  SILTypeSubstituter substituter(silModule, subs, conformances);
-  return substituter.visitSILFunctionType(CanSILFunctionType(this),
-                                          /*dropGenerics*/ true);
+  SILTypeSubstituter substituter(silModule, subs, conformances,
+                                 getGenericSignature());
+  return substituter.substSILFunctionType(CanSILFunctionType(this));
 }
 
 /// Fast path for bridging types in a function type without uncurrying.
diff --git a/lib/SIL/SILType.cpp b/lib/SIL/SILType.cpp
index 7fe9cc6..1a154c7 100644
--- a/lib/SIL/SILType.cpp
+++ b/lib/SIL/SILType.cpp
@@ -569,21 +569,13 @@
   
   // Apply generic arguments if the layout is generic.
   if (!getGenericArgs().empty()) {
-    // FIXME: Map the field type into the layout's generic context because
-    // SIL TypeLowering currently expects to lower abstract generic parameters
-    // with a generic context pushed, but nested generic contexts are not
-    // supported by TypeLowering. If TypeLowering were properly
-    // de-contextualized and plumbed through the generic signature, this could
-    // be avoided.
-    auto *env = getLayout()->getGenericSignature()
-      .getGenericEnvironment(*M.getSwiftModule());
-    auto substMap =
-      env->getSubstitutionMap(getGenericArgs());
-    fieldTy = env->mapTypeIntoContext(fieldTy)
-      ->getCanonicalType();
-    
-    fieldTy = SILType::getPrimitiveObjectType(fieldTy)
-      .subst(M, substMap)
+    auto sig = getLayout()->getGenericSignature();
+    auto subs = sig->getSubstitutionMap(getGenericArgs());
+    return SILType::getPrimitiveObjectType(fieldTy)
+      .subst(M,
+             QuerySubstitutionMap{subs},
+             LookUpConformanceInSubstitutionMap(subs),
+             sig)
       .getSwiftRValueType();
   }
   return fieldTy;