blob: 5ce39552876ab98fb8b2726d516177df9b0bff9b [file] [log] [blame]
//===--- IfSwitchConversion.cpp - ----------------------------------------===//
//
// The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// Implements the "convert to switch" refactoring operation.
//
//===----------------------------------------------------------------------===//
#include "RefactoringOperations.h"
#include "SourceLocationUtilities.h"
#include "clang/AST/AST.h"
#include "clang/AST/ASTContext.h"
#include "clang/AST/Expr.h"
#include "clang/AST/RecursiveASTVisitor.h"
using namespace clang;
using namespace clang::tooling;
namespace {
/// The "convert to switch" refactoring operation for one if/else-if/else
/// chain.
class IfSwitchConversionOperation : public RefactoringOperation {
public:
  /// \p If is the 'if' whose chain gets rewritten; it must be non-null.
  explicit IfSwitchConversionOperation(const IfStmt *If) : If(If) {}

  const Stmt *getTransformedStmt() const override { return If; }

  llvm::Expected<RefactoringResult>
  perform(ASTContext &Context, const Preprocessor &ThePreprocessor,
          const RefactoringOptionSet &Options,
          unsigned SelectedCandidateIndex) override;

  /// The first 'if' of the chain; it is the one replaced by 'switch'.
  const IfStmt *If;
};
/// Verifies that the bodies of an if/else-if/else chain can be moved into a
/// switch as-is.
///
/// A 'break' that is not nested in an inner loop or switch, and any 'case' or
/// 'default' label, would interact with the newly created switch. After a
/// traversal, \c IsValid is false if such a statement was found.
class ValidIfBodyVerifier : public RecursiveASTVisitor<ValidIfBodyVerifier> {
  /// False while traversing a nested loop, where a 'break' targets that loop.
  bool CheckBreaks = true;

public:
  bool IsValid = true;

  bool VisitBreakStmt(const BreakStmt *S) {
    if (!CheckBreaks)
      return true;
    IsValid = false;
    return false;
  }
  bool VisitDefaultStmt(const DefaultStmt *S) {
    IsValid = false;
    return false;
  }
  bool VisitCaseStmt(const CaseStmt *S) {
    IsValid = false;
    return false;
  }

  // Handle nested loops: breaks inside them are fine, so stop checking breaks
  // for the duration of the loop. The inner traversal's result is propagated
  // so that finding an invalid statement aborts the whole traversal.
#define TRAVERSE_LOOP(STMT)                                                    \
  bool Traverse##STMT(STMT *S) {                                               \
    bool Prev = CheckBreaks;                                                   \
    CheckBreaks = false;                                                       \
    bool Result = RecursiveASTVisitor::Traverse##STMT(S);                      \
    CheckBreaks = Prev;                                                        \
    return Result;                                                             \
  }
  TRAVERSE_LOOP(ForStmt)
  TRAVERSE_LOOP(WhileStmt)
  TRAVERSE_LOOP(DoStmt)
  TRAVERSE_LOOP(CXXForRangeStmt)
  TRAVERSE_LOOP(ObjCForCollectionStmt)
#undef TRAVERSE_LOOP

  // Handle switches:
  bool TraverseSwitchStmt(SwitchStmt *S) {
    // Don't visit the body as 'break'/'case'/'default' are all allowed inside
    // switches.
    return true;
  }
};
} // end anonymous namespace
/// Returns true if any of the if statements in the given if construct have
/// conditions that aren't allowed by the "convert to switch" operation.
static bool checkIfsHaveConditionExpression(const IfStmt *If) {
for (; If; If = dyn_cast_or_null<IfStmt>(If->getElse())) {
if (If->getConditionVariable() || If->getInit() || !If->getCond())
return true;
}
return false;
}
/// Matches \p E (looking through parentheses) against a binary operator with
/// opcode \p Kind. On success returns {LHS, RHS}, where the LHS has parens and
/// implicit casts stripped while the RHS only has parens stripped; otherwise
/// returns None.
static Optional<std::pair<const Expr *, const Expr *>>
matchBinOp(const Expr *E, BinaryOperator::Opcode Kind) {
  const Expr *Stripped = E->IgnoreParens();
  const auto *Op = dyn_cast<BinaryOperator>(Stripped);
  if (Op == nullptr || Op->getOpcode() != Kind)
    return None;
  const Expr *Left = Op->getLHS()->IgnoreParenImpCasts();
  const Expr *Right = Op->getRHS()->IgnoreParens();
  return std::make_pair(Left, Right);
}
/// Set of case values already seen in the chain, used to reject duplicates.
using RHSValueSet = llvm::SmallDenseSet<int64_t, 4>;
/// Returns true if the conditional expression of an 'if' statement allows
/// the "convert to switch" refactoring action.
///
/// \p E must be an 'LHS == RHS' comparison or a '||' of such comparisons.
/// Every LHS must be profile-identical to the LHS recorded in
/// \p MatchedLHSNodeID (the first one seen is recorded there), and every RHS
/// must be an integer constant fitting in 64 bits that is not already in
/// \p RHSValues (it is added on success).
static bool isConditionValid(const Expr *E, ASTContext &Context,
                             Optional<llvm::FoldingSetNodeID> &MatchedLHSNodeID,
                             RHSValueSet &RHSValues) {
  auto Equals = matchBinOp(E, BO_EQ);
  if (!Equals.hasValue()) {
    // Not an '==': accept 'A || B' when both operands are valid themselves.
    auto LogicalOr = matchBinOp(E, BO_LOr);
    if (!LogicalOr.hasValue())
      return false;
    return isConditionValid(LogicalOr.getValue().first, Context,
                            MatchedLHSNodeID, RHSValues) &&
           isConditionValid(LogicalOr.getValue().second, Context,
                            MatchedLHSNodeID, RHSValues);
  }
  const Expr *LHS = Equals.getValue().first;
  const Expr *RHS = Equals.getValue().second;
  if (!LHS->getType()->isIntegralOrEnumerationType() ||
      !RHS->getType()->isIntegralOrEnumerationType())
    return false;
  // A switch evaluates its condition once, while the if chain re-evaluates the
  // LHS for every comparison it reaches. An LHS with definite side effects
  // (e.g. 'i++') would therefore change behavior when converted. Calls are
  // only "possible" side effects and remain accepted.
  if (LHS->HasSideEffects(Context, /*IncludePossibleEffects=*/false))
    return false;
  // RHS must be a constant and unique.
  llvm::APSInt Value;
  if (!RHS->EvaluateAsInt(Value, Context))
    return false;
  // Only allow constants that fit into 64 bits. getExtValue() zero-extends
  // unsigned values, so they must be measured by their active bits; measuring
  // them as signed lets wide unsigned values through and overflows below.
  unsigned RequiredBits =
      Value.isSigned() ? Value.getMinSignedBits() : Value.getActiveBits();
  if (RequiredBits > 64 || !RHSValues.insert(Value.getExtValue()).second)
    return false;
  // LHS must be identical to the other LHS expressions.
  llvm::FoldingSetNodeID LHSNodeID;
  LHS->Profile(LHSNodeID, Context, /*Canonical=*/false);
  if (MatchedLHSNodeID.hasValue()) {
    if (MatchedLHSNodeID.getValue() != LHSNodeID)
      return false;
  } else
    MatchedLHSNodeID = std::move(LHSNodeID);
  return true;
}
RefactoringOperationResult clang::tooling::initiateIfSwitchConversionOperation(
    ASTSlice &Slice, ASTContext &Context, SourceLocation Location,
    SourceRange SelectionRange, bool CreateOperation) {
  // FIXME: Add support for selections.
  const auto *RootIf =
      cast_or_null<IfStmt>(Slice.nearestStmt(Stmt::IfStmtClass));
  // Require an 'if' that is followed by at least one 'else' / 'else if'.
  if (!RootIf || !RootIf->getElse())
    return None;
  // Condition variables and C++17 init-statements can't be expressed as a
  // switch.
  if (checkIfsHaveConditionExpression(RootIf))
    return None;

  // Collect the source ranges (each 'if' header plus its body, and the final
  // 'else' plus its body) in which the operation may be initiated.
  const SourceManager &SM = Context.getSourceManager();
  SmallVector<SourceRange, 4> Ranges;
  SourceLocation Begin = RootIf->getLocStart();
  const IfStmt *Link = RootIf;
  for (;;) {
    const Stmt *ThenBody = Link->getThen();
    Ranges.emplace_back(Begin,
                        findLastLocationOfSourceConstruct(
                            Link->getCond()->getLocEnd(), ThenBody, SM));
    const auto *ElseBody = Link->getElse();
    if (!ElseBody)
      break;
    Begin =
        findFirstLocationOfSourceConstruct(Link->getElseLoc(), ThenBody, SM);
    const auto *NextIf = dyn_cast<IfStmt>(ElseBody);
    if (!NextIf) {
      Ranges.emplace_back(Begin, findLastLocationOfSourceConstruct(
                                     Link->getElseLoc(), ElseBody, SM));
      break;
    }
    Link = NextIf;
  }
  if (!isLocationInAnyRange(Location, Ranges, SM))
    return None;

  // Verify that the bodies don't have any 'break'/'default'/'case' statements.
  ValidIfBodyVerifier Verifier;
  Verifier.TraverseStmt(const_cast<IfStmt *>(RootIf));
  if (!Verifier.IsValid)
    return RefactoringOperationResult(
        "if's body contains a 'break'/'default'/'case' statement");

  // FIXME: Use ASTMatchers if possible.
  Optional<llvm::FoldingSetNodeID> SharedLHSID;
  RHSValueSet SeenValues;
  for (const IfStmt *Cursor = RootIf; Cursor;
       Cursor = dyn_cast_or_null<IfStmt>(Cursor->getElse())) {
    if (!isConditionValid(Cursor->getCond(), Context, SharedLHSID, SeenValues))
      return RefactoringOperationResult("unsupported conditional expression");
  }

  RefactoringOperationResult Result;
  Result.Initiated = true;
  if (CreateOperation)
    Result.RefactoringOp.reset(new IfSwitchConversionOperation(RootIf));
  return Result;
}
/// Returns the first (leftmost) LHS expression in the if's condition.
///
/// \p E is expected to be an '==' comparison or a '||' tree of them. The
/// leftmost operand of '||' is followed until an '==' is reached. Returns
/// nullptr if some node on that path is neither '==' nor '||'.
static const Expr *getConditionFirstLHS(const Expr *E) {
  auto Equals = matchBinOp(E, BO_EQ);
  if (!Equals.hasValue()) {
    auto LogicalOr = matchBinOp(E, BO_LOr);
    if (!LogicalOr.hasValue())
      return nullptr;
    return getConditionFirstLHS(LogicalOr.getValue().first);
  }
  return Equals.getValue().first;
}
/// Gathers all of the RHS operands of the == expressions in the if's condition,
/// in left-to-right order, appending them to \p CaseValues.
///
/// \p E is expected to be an '==' comparison or a '||' tree of them. Any other
/// sub-expression contributes nothing.
static void gatherCaseValues(const Expr *E,
                             SmallVectorImpl<const Expr *> &CaseValues) {
  auto Equals = matchBinOp(E, BO_EQ);
  if (Equals.hasValue()) {
    CaseValues.push_back(Equals.getValue().second);
    return;
  }
  auto LogicalOr = matchBinOp(E, BO_LOr);
  if (!LogicalOr.hasValue())
    return;
  gatherCaseValues(LogicalOr.getValue().first, CaseValues);
  gatherCaseValues(LogicalOr.getValue().second, CaseValues);
}
/// Return true iff the given body should be terminated with a 'break' statement
/// when used inside of a switch.
static bool isBreakNeeded(const Stmt *Body) {
const auto *CS = dyn_cast<CompoundStmt>(Body);
if (!CS)
return !isa<ReturnStmt>(Body);
return CS->body_empty() ? true : isBreakNeeded(CS->body_back());
}
/// Returns true if the given statement declares a variable.
static bool isVarDeclaringStatement(const Stmt *S) {
const auto *DS = dyn_cast<DeclStmt>(S);
if (!DS)
return false;
for (const Decl *D : DS->decls()) {
if (isa<VarDecl>(D))
return true;
}
return false;
}
/// Return true if the body of an if/else if/else needs to be wrapped in braces
/// when put in a switch.
static bool areBracesNeeded(const Stmt *Body) {
const auto *CS = dyn_cast<CompoundStmt>(Body);
if (!CS)
return isVarDeclaringStatement(Body);
for (const Stmt *S : CS->body()) {
if (isVarDeclaringStatement(S))
return true;
}
return false;
}
namespace {
/// Information about the replacement that replaces 'if'/'else' with a 'case' or
/// a 'default'.
struct CasePlacement {
  /// The location of the 'case' or 'default'.
  SourceLocation CaseStartLoc;
  /// True when this 'case' or 'default' statement needs a newline.
  bool NeedsNewLine;
  /// True if this is the first 'if' in the source construct.
  bool IsFirstIf;
  /// True if we need to insert a 'break' to terminate the previous body
  /// before the 'case' or 'default'.
  bool IsBreakNeeded;
  /// True if we need to insert a '}' before the case.
  bool ArePreviousBracesNeeded;

  /// Placement of the first 'case', which starts at \p Loc (the end of the
  /// switch condition).
  explicit CasePlacement(SourceLocation Loc)
      : CaseStartLoc(Loc), NeedsNewLine(false), IsFirstIf(true),
        IsBreakNeeded(false), ArePreviousBracesNeeded(false) {}

  /// Placement of the 'case'/'default' that replaces the 'else' of \p If.
  /// It starts at the closing '}' of a compound then-body, or else at the
  /// 'else' keyword. \p AreBracesNeeded tells whether the previous body was
  /// opened with an extra '{' that has to be closed first.
  CasePlacement(const IfStmt *If, const SourceManager &SM,
                bool AreBracesNeeded) {
    CaseStartLoc = SM.getSpellingLoc(isa<CompoundStmt>(If->getThen())
                                         ? If->getThen()->getLocEnd()
                                         : If->getElseLoc());
    SourceLocation BodyEndLoc = findLastNonCompoundLocation(If->getThen());
    NeedsNewLine =
        BodyEndLoc.isValid() && areOnSameLine(CaseStartLoc, BodyEndLoc, SM);
    IsFirstIf = false;
    IsBreakNeeded = isBreakNeeded(If->getThen());
    ArePreviousBracesNeeded = AreBracesNeeded;
  }

  /// Builds the replacement text: any 'break;' and '}' that close the previous
  /// body, then 'case ' or 'default:' (plus ' {' when the default body needs
  /// its own scope). For the first 'if' this is ") {\ncase ".
  std::string getCaseReplacementString(bool IsDefault = false,
                                       bool AreNextBracesNeeded = false) const {
    if (IsFirstIf)
      return ") {\ncase ";
    std::string Result;
    llvm::raw_string_ostream OS(Result);
    if (NeedsNewLine)
      OS << '\n';
    if (IsBreakNeeded)
      OS << "break;\n";
    if (ArePreviousBracesNeeded)
      OS << "}\n";
    OS << (IsDefault ? "default:" : "case ");
    if (IsDefault && AreNextBracesNeeded)
      OS << " {";
    return std::move(OS.str());
  }
};
} // end anonymous namespace
/// Emits the replacements that turn the header of \p If into 'case' labels:
/// the text before the first value becomes the case introducer described by
/// \p CaseInfo, the text between values becomes ":\ncase ", and the text after
/// the last value up to the body becomes ':' (or ': {').
///
/// \p AreBracesNeeded is set to whether the then-body of \p If needed an extra
/// '{'. Returns an error when the closing ')' of the condition can't be found.
static llvm::Error
addCaseReplacements(const IfStmt *If, const CasePlacement &CaseInfo,
                    bool &AreBracesNeeded,
                    std::vector<RefactoringReplacement> &Replacements,
                    const SourceManager &SM, const LangOptions &LangOpts) {
  SmallVector<const Expr *, 2> CaseValues;
  gatherCaseValues(If->getCond(), CaseValues);
  assert(!CaseValues.empty());

  // End of the last token of an expression, in spelling locations.
  auto TokenEnd = [&](const Expr *Value) {
    return getPreciseTokenLocEnd(SM.getSpellingLoc(Value->getLocEnd()), SM,
                                 LangOpts);
  };

  const Expr *FirstValue = CaseValues.front();
  Replacements.emplace_back(
      SourceRange(CaseInfo.CaseStartLoc,
                  SM.getSpellingLoc(FirstValue->getLocStart())),
      CaseInfo.getCaseReplacementString());
  SourceLocation PrevCaseEnd = TokenEnd(FirstValue);

  for (unsigned I = 1, N = CaseValues.size(); I != N; ++I) {
    const Expr *Value = CaseValues[I];
    Replacements.emplace_back(
        SourceRange(PrevCaseEnd, SM.getSpellingLoc(Value->getLocStart())),
        StringRef(":\ncase "));
    PrevCaseEnd = TokenEnd(Value);
  }

  AreBracesNeeded = areBracesNeeded(If->getThen());
  StringRef ColonReplacement = AreBracesNeeded ? ": {" : ":";

  SourceLocation ColonEnd;
  if (isa<CompoundStmt>(If->getThen())) {
    // Swallow the body's own '{'.
    ColonEnd = getPreciseTokenLocEnd(
        SM.getSpellingLoc(If->getThen()->getLocStart()), SM, LangOpts);
  } else {
    // Find the location of the if's ')'
    ColonEnd = findClosingParenLocEnd(
        SM.getSpellingLoc(If->getCond()->getLocEnd()), SM, LangOpts);
    if (ColonEnd.isInvalid())
      return llvm::make_error<RefactoringOperationError>(
          "couldn't find the location of ')'");
  }
  Replacements.emplace_back(SourceRange(PrevCaseEnd, ColonEnd),
                            ColonReplacement);
  return llvm::Error::success();
}
/// Rewrites the if/else-if/else chain that starts at \c If into a switch.
///
/// The text from the first 'if' up to its first LHS becomes "switch (". Each
/// 'if'/'else if' header becomes one or more 'case' labels (see
/// addCaseReplacements). A trailing 'else' becomes 'default:'. Then 'break;'
/// and '}' are inserted after the last body where needed. Returns the list of
/// textual replacements, or the error from addCaseReplacements when a ')'
/// can't be located.
llvm::Expected<RefactoringResult>
IfSwitchConversionOperation::perform(ASTContext &Context,
const Preprocessor &ThePreprocessor,
const RefactoringOptionSet &Options,
unsigned SelectedCandidateIndex) {
std::vector<RefactoringReplacement> Replacements;
const SourceManager &SM = Context.getSourceManager();
const LangOptions &LangOpts = Context.getLangOpts();
// The first if should be replaced with a 'switch' and the text for first LHS
// should be preserved.
const Expr *LHS = getConditionFirstLHS(If->getCond());
assert(LHS && "Missing == expression");
Replacements.emplace_back(SourceRange(SM.getSpellingLoc(If->getLocStart()),
SM.getSpellingLoc(LHS->getLocStart())),
StringRef("switch ("));
// Whether the most recently converted body was opened with an extra '{';
// updated by every addCaseReplacements call and consumed by the next
// CasePlacement (which then emits the matching '}').
bool AreBracesNeeded = false;
// The first 'case' starts right after the first LHS token.
if (auto Error = addCaseReplacements(
If, CasePlacement(getPreciseTokenLocEnd(
SM.getSpellingLoc(LHS->getLocEnd()), SM, LangOpts)),
AreBracesNeeded, Replacements, SM, LangOpts))
return std::move(Error);
// Convert the remaining ifs to 'case' statements.
const IfStmt *CurrentIf = If;
while (true) {
const IfStmt *NextIf = dyn_cast_or_null<IfStmt>(CurrentIf->getElse());
if (!NextIf)
break;
// The placement of NextIf's 'case' depends on how CurrentIf's body ends.
if (auto Error = addCaseReplacements(
NextIf, CasePlacement(CurrentIf, SM, AreBracesNeeded),
AreBracesNeeded, Replacements, SM, LangOpts))
return std::move(Error);
CurrentIf = NextIf;
}
// CurrentIf is now the last 'if' in the chain.
// Convert the 'else' to 'default'
if (const Stmt *Else = CurrentIf->getElse()) {
CasePlacement DefaultInfo(CurrentIf, SM, AreBracesNeeded);
// From here on AreBracesNeeded describes the 'default' body.
AreBracesNeeded = areBracesNeeded(Else);
// Replace up to and including the else body's '{' (compound body) or the
// 'else' keyword (single-statement body).
SourceLocation EndLoc = getPreciseTokenLocEnd(
SM.getSpellingLoc(isa<CompoundStmt>(Else) ? Else->getLocStart()
: CurrentIf->getElseLoc()),
SM, LangOpts);
Replacements.emplace_back(SourceRange(DefaultInfo.CaseStartLoc, EndLoc),
DefaultInfo.getCaseReplacementString(
/*IsDefault=*/true, AreBracesNeeded));
}
// Add the trailing break and one or two '}' if needed.
const Stmt *LastBody =
CurrentIf->getElse() ? CurrentIf->getElse() : CurrentIf->getThen();
bool IsLastBreakNeeded = isBreakNeeded(LastBody);
SourceLocation TerminatingReplacementLoc;
std::string TerminatingReplacement;
llvm::raw_string_ostream OS(TerminatingReplacement);
if (!isa<CompoundStmt>(LastBody)) {
// A single-statement last body has no '}' of its own, so the switch's
// closing '}' is always appended after it.
TerminatingReplacementLoc = LastBody->getLocEnd();
// Try to adjust the location in order to preserve any trailing comments on
// the last line of the last body.
if (!TerminatingReplacementLoc.isMacroID())
TerminatingReplacementLoc = getLastLineLocationUnlessItHasOtherTokens(
TerminatingReplacementLoc, SM, LangOpts);
if (IsLastBreakNeeded)
OS << "\nbreak;";
OS << "\n}";
if (AreBracesNeeded)
OS << "\n}";
} else {
// The last body's opening '{' was replaced by the label, so its closing
// '}' is kept and presumably ends up closing the switch; insert before it.
TerminatingReplacementLoc = LastBody->getLocEnd();
if (IsLastBreakNeeded)
OS << "break;\n";
if (AreBracesNeeded)
OS << "}\n";
}
if (!OS.str().empty()) {
// Zero-length range: a pure insertion at the spelling location.
TerminatingReplacementLoc = SM.getSpellingLoc(TerminatingReplacementLoc);
Replacements.emplace_back(
SourceRange(TerminatingReplacementLoc, TerminatingReplacementLoc),
std::move(OS.str()));
}
return std::move(Replacements);
}