blob: 64843e508b380a994fad17938e8a9cb86bbceb43 [file] [log] [blame]
//===--- FillInMissingProtocolStubs.cpp - Fill in protocol stubs ----------===//
//
// The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// Implements the "Add methods from protocol(s)" refactoring operation.
//
//===----------------------------------------------------------------------===//
#include "clang/AST/AST.h"
#include "clang/AST/ASTContext.h"
#include "clang/AST/Expr.h"
#include "clang/AST/NSAPI.h"
#include "clang/Edit/RefactoringFixits.h"
#include "clang/Lex/Lexer.h"
#include "llvm/ADT/DenseSet.h"
#include <algorithm>
#include <limits>
#include <string>
#include <utility>
#include <vector>
using namespace clang;
using namespace edit;
using namespace fillInMissingProtocolStubs;
// FIXME: This is duplicated with the refactoring lib.
/// Returns true if both locations are file (non-macro) locations whose
/// spelling line numbers are equal. A macro location on either side always
/// yields false.
static bool areOnSameLine(SourceLocation Loc1, SourceLocation Loc2,
                          const SourceManager &SM) {
  if (Loc1.isMacroID() || Loc2.isMacroID())
    return false;
  const unsigned FirstLine = SM.getSpellingLineNumber(Loc1);
  const unsigned SecondLine = SM.getSpellingLineNumber(Loc2);
  return FirstLine == SecondLine;
}
/// Returns true if the token that starts at \p TokenLoc is spelled ";".
static bool isSemicolonAtLocation(SourceLocation TokenLoc,
                                  const SourceManager &SM,
                                  const LangOptions &LangOpts) {
  const CharSourceRange TokenRange =
      CharSourceRange::getTokenRange(TokenLoc, TokenLoc);
  const StringRef Spelling = Lexer::getSourceText(TokenRange, SM, LangOpts);
  return Spelling == ";";
}
/// Returns the beginning of the token that covers the character just before
/// \p Loc, or an invalid location if \p Loc is at the very start of its file.
///
/// Macro locations are first mapped to their expansion location, and all
/// further computation (including the start-of-file check) is done on that
/// file location.
static SourceLocation getLocationOfPrecedingToken(SourceLocation Loc,
                                                  const SourceManager &SM,
                                                  const LangOptions &LangOpts) {
  SourceLocation Result = Loc;
  if (Result.isMacroID())
    Result = SM.getExpansionLoc(Result);
  FileID FID = SM.getFileID(Result);
  SourceLocation StartOfFile = SM.getLocForStartOfFile(FID);
  // Compare the file location rather than the (possibly macro) input: a macro
  // location never equals StartOfFile, and stepping back one character from
  // the start of the file would leave the file's buffer.
  if (Result == StartOfFile)
    return SourceLocation();
  return Lexer::GetBeginningOfToken(Result.getLocWithOffset(-1), SM, LangOpts);
}
/// Returns the location at which text should be inserted after the token at
/// \p SpellingLoc.
///
/// Any ';' tokens that directly follow on the same line are skipped. If the
/// last visited token is followed by a token on a later line, the end of the
/// last visited token's line is returned. Otherwise, the location just past
/// the last visited token is returned. That covers three cases: a non-';' token
/// follows on the same line, there is no next token, or the end-of-line
/// location can't be computed.
static SourceLocation
getLastLineLocationUnlessItHasOtherTokens(SourceLocation SpellingLoc,
                                          const SourceManager &SM,
                                          const LangOptions &LangOpts) {
  SourceLocation TokenLoc = SpellingLoc;
  while (true) {
    assert(!TokenLoc.isMacroID() && "Expecting a spelling location");
    const SourceLocation NextTokenLoc =
        Lexer::findNextTokenLocationAfterTokenAt(TokenLoc, SM, LangOpts);
    if (NextTokenLoc.isInvalid())
      break;
    if (!areOnSameLine(TokenLoc, NextTokenLoc, SM)) {
      // Nothing else on this line: insert at the end of the line.
      const SourceLocation EndOfLine = SM.translateLineCol(
          SM.getFileID(TokenLoc), SM.getSpellingLineNumber(TokenLoc),
          std::numeric_limits<unsigned>::max());
      if (EndOfLine.isValid())
        return EndOfLine;
      break;
    }
    // A ';' on the same line is skipped over; any other token stops the scan.
    if (!isSemicolonAtLocation(NextTokenLoc, SM, LangOpts))
      break;
    TokenLoc = NextTokenLoc;
  }
  return Lexer::getLocForEndOfToken(TokenLoc, 0, SM, LangOpts);
}
namespace {
struct ProtocolInfo {
/// The lower the priority, the more important this protocol is considered to
/// be. Typically protocols from the class have lower priority than protocols
/// from superclasses.
int Priority;
};
using ProtocolMapTy = llvm::DenseMap<const ObjCProtocolDecl *, ProtocolInfo>;
/// Contains the set of methods from all the protocols that the class conforms
/// to.
class MethodSet {
public:
struct MethodInfo {
const ObjCMethodDecl *M;
const ObjCProtocolDecl *P;
int ProtocolPriority;
enum MethodPresenceKind { IsDeclared = 0x1, IsImplemented = 0x2 };
unsigned PresenceKind = 0;
const ObjCMethodDecl *DeclaredOrImplementedMethod = nullptr;
MethodInfo(const ObjCMethodDecl *M, const ObjCProtocolDecl *P,
int ProtocolPriority)
: M(M), P(P), ProtocolPriority(ProtocolPriority) {}
bool isRequired() const {
return M->getImplementationControl() == ObjCMethodDecl::Required;
}
void markAs(MethodPresenceKind Kind) { PresenceKind |= Kind; }
bool is(MethodPresenceKind Kind) const {
return (PresenceKind & Kind) == Kind;
}
};
private:
llvm::DenseMap<Selector, MethodInfo> InstanceMethods;
llvm::DenseMap<Selector, MethodInfo> ClassMethods;
void markMethodsFrom(const ObjCContainerDecl *Container,
MethodInfo::MethodPresenceKind Kind) {
for (const ObjCMethodDecl *M : Container->methods()) {
auto &Map = M->isInstanceMethod() ? InstanceMethods : ClassMethods;
auto It = Map.find(M->getSelector());
if (It != Map.end()) {
It->second.markAs(Kind);
if (!It->second.DeclaredOrImplementedMethod)
It->second.DeclaredOrImplementedMethod = M;
}
}
}
public:
MethodSet() {}
MethodSet(MethodSet &&Other) = default;
MethodSet &operator=(MethodSet &&Other) = default;
void gatherMethodsFrom(const ObjCProtocolDecl *P, int Priority) {
for (const ObjCMethodDecl *M : P->methods()) {
if (M->isImplicit())
continue;
AvailabilityResult Availability = M->getAvailability();
// Methods that are unavailable or not yet introduced are not considered
// to be required.
if (Availability == AR_NotYetIntroduced || Availability == AR_Unavailable)
continue;
auto &Map = M->isInstanceMethod() ? InstanceMethods : ClassMethods;
Map.insert(std::make_pair(M->getSelector(), MethodInfo(M, P, Priority)));
}
}
void markImplementedMethods(const ObjCContainerDecl *Container) {
assert(isa<ObjCImplDecl>(Container) && "Not an implementation container");
markMethodsFrom(Container, MethodInfo::IsImplemented);
if (const auto *ID = dyn_cast<ObjCImplementationDecl>(Container)) {
const auto *I = ID->getClassInterface();
// Mark declarations from super-classes as implemented to prevent
// redundant implementations.
while ((I = I->getSuperClass()))
markMethodsFrom(I, MethodInfo::IsImplemented);
}
}
void markDeclaredMethods(const ObjCContainerDecl *Container) {
assert(!isa<ObjCImplDecl>(Container) && "Not an interface container");
markMethodsFrom(Container, MethodInfo::IsDeclared);
// Mark declarations from super-classes as declared to prevent redundant
// declarations.
if (const auto *I = dyn_cast<ObjCInterfaceDecl>(Container)) {
while ((I = I->getSuperClass()))
markMethodsFrom(I, MethodInfo::IsDeclared);
}
}
/// Returns true if the given container has missing @required method stubs.
///
/// For @interfaces, this method returns true when the interface is missing
/// a declaration for any @required method in all of the protocols.
/// For @implementations, this method returns true when the implementation is
/// missing an implementation of any @required method in all of the protocols.
bool hasMissingRequiredMethodStubs(const ObjCContainerDecl *Container) {
MethodInfo::MethodPresenceKind Kind = isa<ObjCImplDecl>(Container)
? MethodInfo::IsImplemented
: MethodInfo::IsDeclared;
for (const auto &I : InstanceMethods) {
if (!I.second.isRequired())
continue;
if (!I.second.is(Kind))
return true;
}
for (const auto &I : ClassMethods) {
if (!I.second.isRequired())
continue;
if (!I.second.is(Kind))
return true;
}
return false;
}
std::vector<MethodInfo>
getMissingRequiredMethods(const ObjCContainerDecl *Container) {
MethodInfo::MethodPresenceKind Kind = isa<ObjCImplDecl>(Container)
? MethodInfo::IsImplemented
: MethodInfo::IsDeclared;
std::vector<MethodInfo> Results;
for (const auto &I : InstanceMethods) {
if (!I.second.isRequired())
continue;
if (!I.second.is(Kind))
Results.push_back(I.second);
}
for (const auto &I : ClassMethods) {
if (!I.second.isRequired())
continue;
if (!I.second.is(Kind))
Results.push_back(I.second);
}
return Results;
}
SourceLocation findLocationForInsertionForMethodsFromProtocol(
const ObjCProtocolDecl *P, const ObjCContainerDecl *Container,
const SourceManager &SM, const LangOptions &LangOpts) {
MethodInfo::MethodPresenceKind Kind = isa<ObjCImplDecl>(Container)
? MethodInfo::IsImplemented
: MethodInfo::IsDeclared;
llvm::SmallVector<const ObjCMethodDecl *, 4> MethodsFromProtocolInContainer;
for (const ObjCMethodDecl *M : P->methods()) {
if (M->isImplicit())
continue;
const auto &Map = M->isInstanceMethod() ? InstanceMethods : ClassMethods;
auto It = Map.find(M->getSelector());
if (It == Map.end())
continue;
if (!It->second.is(Kind))
continue;
const ObjCMethodDecl *ContainerMethod =
It->second.DeclaredOrImplementedMethod;
// Ignore method declarations from superclasses.
if (ContainerMethod->getLexicalDeclContext() != Container)
continue;
// This is a method from the given protocol that either declared or
// implemented in the container.
MethodsFromProtocolInContainer.push_back(ContainerMethod);
}
// Find the appropriate source locations by looking
if (MethodsFromProtocolInContainer.empty())
return SourceLocation();
SourceLocation Loc = MethodsFromProtocolInContainer[0]->getLocEnd();
if (Loc.isMacroID())
Loc = SM.getExpansionRange(Loc).second;
for (const ObjCMethodDecl *M :
makeArrayRef(MethodsFromProtocolInContainer).drop_front()) {
SourceLocation EndLoc = M->getLocEnd();
if (EndLoc.isMacroID())
EndLoc = SM.getExpansionRange(EndLoc).second;
if (SM.isBeforeInTranslationUnit(Loc, EndLoc))
Loc = EndLoc;
}
return getLastLineLocationUnlessItHasOtherTokens(Loc, SM, LangOpts);
}
};
} // end anonymous namespace
namespace clang {
namespace edit {
namespace fillInMissingProtocolStubs {
/// State shared by FillInMissingProtocolStubs::initiate(),
/// hasMissingRequiredMethodStubs() and perform().
class FillInMissingProtocolStubsImpl {
public:
  /// The container the operation was initiated on. Null until it has been
  /// successfully initiated.
  const ObjCContainerDecl *Container = nullptr;
  /// The protocol methods that apply to \c Container and their presence in it.
  MethodSet Methods;
};
} // end namespace fillInMissingProtocolStubs
} // end namespace edit
} // end namespace clang
/// Recursively collects \p Protocols and the protocols they inherit from into
/// \p Result, giving each newly collected protocol the next \p Priority value.
///
/// Inherited protocols are visited before the protocol that inherits them, so
/// they receive a lower priority value. The 'NSObject' protocol is ignored. A
/// protocol that is already in \p Result is skipped, because its inherited
/// protocols were collected when it was first added.
static void gatherProtocols(
    llvm::iterator_range<ObjCList<ObjCProtocolDecl>::iterator> Protocols,
    NSAPI &API, ProtocolMapTy &Result, int &Priority) {
  for (const ObjCProtocolDecl *P : Protocols) {
    // Ignore the 'NSObject' protocol.
    if (API.getNSClassId(NSAPI::ClassId_NSObject) == P->getIdentifier())
      continue;
    // Already gathered, e.g. reached again through a diamond in the protocol
    // hierarchy: revisiting it would only repeat (potentially exponential)
    // work without adding anything to Result.
    if (Result.count(P))
      continue;
    gatherProtocols(P->protocols(), API, Result, Priority);
    Result.insert(std::make_pair(P, ProtocolInfo{Priority++}));
  }
}
/// Collects the protocols that apply to class \p I and its superclasses.
///
/// When the operation is initiated from the @interface itself (\p Container
/// is \p I), only the protocols written on that @interface are used.
/// Otherwise (e.g. from the @implementation or an extension) all referenced
/// protocols, including those from extensions, are used. The superclasses'
/// protocols follow with increasing priority values.
static ProtocolMapTy
gatherSuitableClassProtocols(const ObjCInterfaceDecl *I,
                             const ObjCContainerDecl *Container, NSAPI &API) {
  ProtocolMapTy Gathered;
  int NextPriority = 0;
  if (Container == I)
    gatherProtocols(I->protocols(), API, Gathered, NextPriority);
  else
    gatherProtocols(I->all_referenced_protocols(), API, Gathered, NextPriority);
  for (const ObjCInterfaceDecl *Super = I->getSuperClass(); Super;
       Super = Super->getSuperClass())
    gatherProtocols(Super->protocols(), API, Gathered, NextPriority);
  return Gathered;
}
/// Returns the declaration whose protocol list applies to \p Container.
///
/// For a category @implementation this is its category, for a class
/// @implementation its class interface, and for anything else \p Container
/// itself. The result may be null when an implementation has no
/// corresponding declaration.
static const ObjCContainerDecl *
getInterfaceOrCategory(const ObjCContainerDecl *Container) {
  if (const auto *CategoryImpl = dyn_cast<ObjCCategoryImplDecl>(Container))
    return CategoryImpl->getCategoryDecl();
  if (const auto *ClassImpl = dyn_cast<ObjCImplementationDecl>(Container))
    return ClassImpl->getClassInterface();
  return Container;
}
/// Gathers the protocols that apply to \p Container and collects their methods
/// into \p Dest. Then marks the methods that \p Container already declares (for
/// non-implementation containers) or implements (for implementations).
///
/// \returns false if the operation can't be initiated: the container has no
/// corresponding interface/category, the class has no definition, or no
/// suitable protocols were found. Otherwise records \p Container in \p Dest
/// and returns true.
static bool initiate(FillInMissingProtocolStubsImpl &Dest, ASTContext &Context,
                     const ObjCContainerDecl *Container) {
  const ObjCContainerDecl *ContainerProtocolSource =
      getInterfaceOrCategory(Container);
  if (!ContainerProtocolSource)
    return false;
  // The protocols that are specified in the @interface and/or in the
  // superclasses.
  ProtocolMapTy Protocols;
  NSAPI API(Context);
  if (const auto *I = dyn_cast<ObjCInterfaceDecl>(ContainerProtocolSource)) {
    if (!I->hasDefinition())
      return false;
    Protocols = gatherSuitableClassProtocols(I, Container, API);
    if (Protocols.empty())
      return false;
  } else if (const auto *C =
                 dyn_cast<ObjCCategoryDecl>(ContainerProtocolSource)) {
    int Priority = 0;
    gatherProtocols(C->protocols(), API, Protocols, Priority);
    if (Protocols.empty())
      return false;
  }
  // Visit the protocols in priority order instead of the hash map's
  // pointer-dependent iteration order. When several protocols declare the
  // same selector, MethodSet::gatherMethodsFrom keeps the first one it sees.
  // This ordering therefore deterministically attributes the method to the
  // most important protocol.
  using PrioritizedProtocol = std::pair<const ObjCProtocolDecl *, int>;
  std::vector<PrioritizedProtocol> OrderedProtocols;
  OrderedProtocols.reserve(Protocols.size());
  for (const auto &P : Protocols)
    OrderedProtocols.emplace_back(P.first, P.second.Priority);
  std::sort(OrderedProtocols.begin(), OrderedProtocols.end(),
            [](const PrioritizedProtocol &A, const PrioritizedProtocol &B) {
              return A.second < B.second;
            });
  // Check if there are missing @required methods.
  for (const PrioritizedProtocol &P : OrderedProtocols)
    Dest.Methods.gatherMethodsFrom(P.first, P.second);
  if (isa<ObjCImplDecl>(Container))
    Dest.Methods.markImplementedMethods(Container);
  else
    Dest.Methods.markDeclaredMethods(Container);
  Dest.Container = Container;
  return true;
}
FillInMissingProtocolStubs::FillInMissingProtocolStubs() {}
FillInMissingProtocolStubs::~FillInMissingProtocolStubs() {}
FillInMissingProtocolStubs::FillInMissingProtocolStubs(
FillInMissingProtocolStubs &&Other)
: Impl(std::move(Other.Impl)) {}
FillInMissingProtocolStubs &FillInMissingProtocolStubs::
operator=(FillInMissingProtocolStubs &&Other) {
Impl = std::move(Other.Impl);
return *this;
}
/// Initiates the edit for \p Container, replacing any previous state.
///
/// \returns true if the operation could not be initiated, false on success.
bool FillInMissingProtocolStubs::initiate(ASTContext &Context,
                                          const ObjCContainerDecl *Container) {
  Impl = llvm::make_unique<FillInMissingProtocolStubsImpl>();
  const bool Initiated = ::initiate(*Impl, Context, Container);
  return !Initiated;
}
bool FillInMissingProtocolStubs::hasMissingRequiredMethodStubs() {
return Impl->Methods.hasMissingRequiredMethodStubs(Impl->Container);
}
/// Emits insertion fix-its with stubs for every @required protocol method that
/// \p Container is missing.
///
/// Missing methods are grouped by protocol. If the container already
/// declares or implements methods from that protocol, the group is inserted
/// after the last of them. All other stubs are concatenated into a single
/// insertion near the end of the container.
///
/// \param Methods The method set prepared by ::initiate for \p Container.
/// \param Container The container to fill in.
/// \param Context Provides the source manager, language and printing options.
/// \param Consumer Receives each generated insertion fix-it.
static void perform(MethodSet &Methods, const ObjCContainerDecl *Container,
                    ASTContext &Context,
                    llvm::function_ref<void(const FixItHint &)> Consumer) {
  const SourceManager &SM = Context.getSourceManager();
  const LangOptions &LangOpts = Context.getLangOpts();
  auto MissingMethods = Methods.getMissingRequiredMethods(Container);
  // Sort the methods by grouping them into protocol clusters and then sorting
  // them alphabetically within the same protocol. An instance method goes
  // before a class method with the same selector. That makes the order total,
  // so the output does not depend on the hash map iteration order.
  std::sort(MissingMethods.begin(), MissingMethods.end(),
            [](const MethodSet::MethodInfo &A, const MethodSet::MethodInfo &B) {
              if (A.ProtocolPriority != B.ProtocolPriority) {
                assert(A.P != B.P && "Same protocols should have same priority");
                return A.ProtocolPriority < B.ProtocolPriority;
              }
              const std::string NameA = A.M->getSelector().getAsString();
              const std::string NameB = B.M->getSelector().getAsString();
              if (NameA != NameB)
                return NameA < NameB;
              return A.M->isInstanceMethod() && !B.M->isInstanceMethod();
            });
  const bool IsImplementation = isa<ObjCImplDecl>(Container);
  // Location for stubs that can't be placed next to related methods.
  SourceLocation InsertionLoc =
      IsImplementation
          ? Container->getLocEnd()
          : getLocationOfPrecedingToken(Container->getLocEnd(), SM, LangOpts);
  if (InsertionLoc.isInvalid())
    InsertionLoc = Container->getLocEnd();
  PrintingPolicy PP = Context.getPrintingPolicy();
  PP.PolishForDeclaration = true;
  PP.SuppressStrongLifetime = true;
  PP.SuppressLifetimeQualifiers = true;
  PP.SuppressUnwrittenScope = true;
  std::string EndInsertionOSStr;
  llvm::raw_string_ostream EndInsertionOS(EndInsertionOSStr);
  std::string InsertionGroupStr;
  llvm::raw_string_ostream InsertionGroupOS(InsertionGroupStr);
  const ObjCProtocolDecl *CurrentProtocol = nullptr;
  SourceLocation CurrentProtocolInsertionLoc;
  for (const auto &Method : MissingMethods) {
    const ObjCProtocolDecl *P = Method.P;
    if (CurrentProtocol != P) {
      // Emit the group collected for the previous protocol, if any.
      if (!InsertionGroupOS.str().empty()) {
        assert(CurrentProtocolInsertionLoc.isValid());
        Consumer(FixItHint::CreateInsertion(CurrentProtocolInsertionLoc,
                                            InsertionGroupOS.str()));
      }
      // str() above flushed the stream, so the string can be reset directly.
      InsertionGroupStr.clear();
      CurrentProtocol = P;
      CurrentProtocolInsertionLoc =
          Methods.findLocationForInsertionForMethodsFromProtocol(
              P, Container, SM, LangOpts);
    }
    bool IsInsertingAfterRelatedMethods = CurrentProtocolInsertionLoc.isValid();
    raw_ostream &OS =
        IsInsertingAfterRelatedMethods ? InsertionGroupOS : EndInsertionOS;
    std::string MethodDeclStr;
    llvm::raw_string_ostream MethodOS(MethodDeclStr);
    Method.M->print(MethodOS, PP);
    if (IsInsertingAfterRelatedMethods)
      OS << "\n\n";
    // The printed declaration normally ends with ';'. Strip it only when it is
    // actually there, so no other character is ever truncated.
    StringRef Declaration = MethodOS.str();
    if (!Declaration.empty() && Declaration.back() == ';')
      Declaration = Declaration.drop_back();
    OS << Declaration;
    if (IsImplementation)
      OS << " { \n  <#code#>\n}\n";
    else
      OS << ";\n";
    if (!IsInsertingAfterRelatedMethods)
      OS << "\n";
  }
  if (!InsertionGroupOS.str().empty()) {
    assert(CurrentProtocolInsertionLoc.isValid());
    Consumer(FixItHint::CreateInsertion(CurrentProtocolInsertionLoc,
                                        InsertionGroupOS.str()));
  }
  if (!EndInsertionOS.str().empty())
    Consumer(FixItHint::CreateInsertion(InsertionLoc, EndInsertionOS.str()));
}
/// Emits the stub insertion fix-its for the container passed to initiate().
/// Requires a prior call to initiate(), since the state is dereferenced
/// unconditionally.
void FillInMissingProtocolStubs::perform(
    ASTContext &Context, llvm::function_ref<void(const FixItHint &)> Consumer) {
  FillInMissingProtocolStubsImpl &State = *Impl;
  ::perform(State.Methods, State.Container, Context, Consumer);
}
/// One-shot entry point: initiates the edit for \p Container and, if that
/// succeeds, immediately passes the resulting fix-its to \p Consumer.
void fillInMissingProtocolStubs::addMissingProtocolStubs(
    ASTContext &Context, const ObjCContainerDecl *Container,
    llvm::function_ref<void(const FixItHint &)> Consumer) {
  FillInMissingProtocolStubsImpl State;
  if (!::initiate(State, Context, Container))
    return;
  ::perform(State.Methods, State.Container, Context, Consumer);
}