blob: 5d39842627d92edbce28c8cfda7f6154f4198a44 [file] [log] [blame]
//===--- Refactoring.cpp ---------------------------------------------------===//
// This source file is part of the open source project
// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
#include "swift/IDE/Refactoring.h"
#include "swift/IDE/IDERequests.h"
#include "swift/AST/ASTContext.h"
#include "swift/AST/ASTPrinter.h"
#include "swift/AST/Decl.h"
#include "swift/AST/DiagnosticsRefactoring.h"
#include "swift/AST/Expr.h"
#include "swift/AST/NameLookup.h"
#include "swift/AST/Pattern.h"
#include "swift/AST/ProtocolConformance.h"
#include "swift/AST/Stmt.h"
#include "swift/AST/Types.h"
#include "swift/AST/USRGeneration.h"
#include "swift/Basic/Edit.h"
#include "swift/Basic/StringExtras.h"
#include "swift/Frontend/Frontend.h"
#include "swift/Index/Index.h"
#include "swift/Parse/Lexer.h"
#include "swift/Sema/IDETypeChecking.h"
#include "swift/Subsystems.h"
#include "clang/Rewrite/Core/RewriteBuffer.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/SmallPtrSet.h"
using namespace swift;
using namespace swift::ide;
using namespace swift::index;
namespace {
/// Walks a SourceFile looking for the AST nodes (decls, stmts, exprs) whose
/// source range contains a target range or location, optionally filtered by
/// an IsContext predicate. Call resolve(), then read getContexts().
///
/// NOTE(review): the statement that appends matching nodes to AllContexts is
/// not visible in this view; presumably contains() records them. Confirm.
class ContextFinder : public SourceEntityWalker {
SourceFile &SF;
// Ctx is declared before SM on purpose: the constructors initialize
// SM(Ctx.SourceMgr), which requires Ctx to be initialized first.
ASTContext &Ctx;
SourceManager &SM;
// The range being searched for: either a node's full range or a single loc.
SourceRange Target;
// Extra filter deciding which enclosing nodes count as "contexts".
// NOTE(review): this is an llvm::function_ref, i.e. a non-owning reference.
// The default lambda argument in both constructors is a temporary that is
// destroyed when the constructor call finishes, so the stored reference
// dangles. This is formally UB even though it is harmless in practice for a
// captureless lambda. Consider storing a function pointer or std::function.
llvm::function_ref<bool(ASTNode)> IsContext;
SmallVector<ASTNode, 4> AllContexts;
// True when Enclosing's source range contains Target. The walkTo*Pre hooks
// below return this, which controls whether the walker descends into the
// node.
bool contains(ASTNode Enclosing) {
auto Result = SM.rangeContains(Enclosing.getSourceRange(), Target);
if (Result && IsContext(Enclosing))
return Result;
// Targets the full source range of TargetNode.
ContextFinder(SourceFile &SF, ASTNode TargetNode,
llvm::function_ref<bool(ASTNode)> IsContext =
[](ASTNode N) { return true; }) :
SF(SF), Ctx(SF.getASTContext()), SM(Ctx.SourceMgr),
Target(TargetNode.getSourceRange()), IsContext(IsContext) {}
// Targets a single (valid) source location.
ContextFinder(SourceFile &SF, SourceLoc TargetLoc,
llvm::function_ref<bool(ASTNode)> IsContext =
[](ASTNode N) { return true; }) :
SF(SF), Ctx(SF.getASTContext()), SM(Ctx.SourceMgr),
Target(TargetLoc), IsContext(IsContext) {
assert(TargetLoc.isValid() && "Invalid loc to find");
// Each kind of node is visited only if it contains the target.
bool walkToDeclPre(Decl *D, CharSourceRange Range) override { return contains(D); }
bool walkToStmtPre(Stmt *S) override { return contains(S); }
bool walkToExprPre(Expr *E) override { return contains(E); }
// Walks the whole source file.
void resolve() { walk(SF); }
// The contexts collected by the last resolve().
llvm::ArrayRef<ASTNode> getContexts() const {
return llvm::makeArrayRef(AllContexts);
/// Base class that matches the base-name and argument-label ranges of one
/// resolved occurrence against the old name (\c Old). It reports every
/// matched piece to the subclass through doRenameBase / doRenameLabel.
/// Subclasses in this file either record the ranges
/// (RenameRangeDetailCollector) or turn them into text replacements
/// (TextReplacementsRenamer).
class Renamer {
const SourceManager &SM;
Renamer(const SourceManager &SM, StringRef OldName) : SM(SM), Old(OldName) {}
// Implementor's interface.
// Called for each matched label range. NameIndex is the index of the
// corresponding label in Old.args().
virtual void doRenameLabel(CharSourceRange Label,
RefactoringRangeKind RangeKind,
unsigned NameIndex) = 0;
// Called for a base-name range that matched Old.base().
virtual void doRenameBase(CharSourceRange Range,
RefactoringRangeKind RangeKind) = 0;
// The name being renamed, viewed as a base name plus argument labels.
const DeclNameViewer Old;
virtual ~Renamer() {}
/// Reports \p Range to doRenameBase if its text, ignoring enclosing
/// backticks, equals the old base name. The unstripped range is reported.
/// \return true if the given range does not match the old name
bool renameBase(CharSourceRange Range, RefactoringRangeKind RangeKind) {
if (stripBackticks(Range).str() != Old.base())
return true;
doRenameBase(Range, RangeKind);
return false;
/// Reports the given label ranges to doRenameLabel.
/// At call sites the match is lenient (see renameLabelsLenient). Elsewhere
/// there must be exactly one range per old label, matching in order.
/// \return true if the label ranges do not match the old name
bool renameLabels(ArrayRef<CharSourceRange> LabelRanges,
Optional<unsigned> FirstTrailingLabel,
LabelRangeType RangeType, bool isCallSite) {
if (isCallSite)
return renameLabelsLenient(LabelRanges, FirstTrailingLabel, RangeType);
ArrayRef<StringRef> OldLabels = Old.args();
if (OldLabels.size() != LabelRanges.size())
return true;
size_t Index = 0;
for (const auto &LabelRange : LabelRanges) {
if (!labelRangeMatches(LabelRange, RangeType, OldLabels[Index]))
return true;
splitAndRenameLabel(LabelRange, RangeType, Index++);
return false;
// True if the old base name lexes as an operator.
bool isOperator() const { return Lexer::isOperator(Old.base()); }
/// Returns the range of the (possibly escaped) identifier at the start of
/// \p Range and updates \p IsEscaped to indicate whether it's escaped or not.
/// \p Range must be valid and non-empty (asserted).
CharSourceRange getLeadingIdentifierRange(CharSourceRange Range, bool &IsEscaped) {
assert(Range.isValid() && Range.getByteLength());
IsEscaped = Range.str().front() == '`';
SourceLoc Start = Range.getStart();
if (IsEscaped)
Start = Start.getAdvancedLoc(1);
return Lexer::getCharSourceRangeFromSourceRange(SM, Start);
// Returns the range between the backticks of `name`. Ranges shorter than 3
// characters, or not wrapped in backticks, are returned unchanged.
CharSourceRange stripBackticks(CharSourceRange Range) {
StringRef Content = Range.str();
if (Content.size() < 3 || Content.front() != '`' || Content.back() != '`') {
return Range;
return CharSourceRange(Range.getStart().getAdvancedLoc(1),
Range.getByteLength() - 2);
// Dispatches a matched label range to the splitter for its RangeType.
void splitAndRenameLabel(CharSourceRange Range, LabelRangeType RangeType,
size_t NameIndex) {
switch (RangeType) {
case LabelRangeType::CallArg:
return splitAndRenameCallArg(Range, NameIndex);
case LabelRangeType::Param:
return splitAndRenameParamLabel(Range, NameIndex, /*IsCollapsible=*/true);
case LabelRangeType::NoncollapsibleParam:
return splitAndRenameParamLabel(Range, NameIndex, /*IsCollapsible=*/false);
case LabelRangeType::Selector:
return doRenameLabel(
Range, RefactoringRangeKind::SelectorArgumentLabel, NameIndex);
case LabelRangeType::None:
llvm_unreachable("expected a label range");
void splitAndRenameParamLabel(CharSourceRange Range, size_t NameIndex, bool IsCollapsible) {
// Split parameter range foo([a b]: Int) into decl argument label [a] and
// parameter name [b] or noncollapsible parameter name [b] if IsCollapsible
// is false (as for subscript decls). If we have only foo([a]: Int), then we
// add an empty range for the local name, or for the decl argument label if
// IsCollapsible is false.
StringRef Content = Range.str();
size_t ExternalNameEnd = Content.find_first_of(" \t\n\v\f\r/");
if (ExternalNameEnd == StringRef::npos) { // foo([a]: Int)
if (IsCollapsible) {
doRenameLabel(Range, RefactoringRangeKind::DeclArgumentLabel, NameIndex);
doRenameLabel(CharSourceRange{Range.getEnd(), 0},
RefactoringRangeKind::ParameterName, NameIndex);
} else {
doRenameLabel(CharSourceRange{Range.getStart(), 0},
RefactoringRangeKind::DeclArgumentLabel, NameIndex);
doRenameLabel(Range, RefactoringRangeKind::NoncollapsibleParameterName,
} else { // foo([a b]: Int)
CharSourceRange Ext{Range.getStart(), unsigned(ExternalNameEnd)};
// Note: we consider the leading whitespace part of the parameter name
// if the parameter is collapsible, since if the parameter is collapsed
// into a matching argument label, we want to remove the whitespace too.
// FIXME: handle comments foo(a /*...*/b: Int).
size_t LocalNameStart = Content.find_last_of(" \t\n\v\f\r/");
assert(LocalNameStart != StringRef::npos);
// NOTE(review): the statement governed by `if (!IsCollapsible)` is not
// visible in this view. Presumably it skips the whitespace character so a
// noncollapsible name excludes it; confirm.
if (!IsCollapsible)
auto LocalLoc = Range.getStart().getAdvancedLocOrInvalid(LocalNameStart);
CharSourceRange Local{LocalLoc, unsigned(Content.size() - LocalNameStart)};
doRenameLabel(Ext, RefactoringRangeKind::DeclArgumentLabel, NameIndex);
if (IsCollapsible) {
doRenameLabel(Local, RefactoringRangeKind::ParameterName, NameIndex);
} else {
doRenameLabel(Local, RefactoringRangeKind::NoncollapsibleParameterName, NameIndex);
void splitAndRenameCallArg(CharSourceRange Range, size_t NameIndex) {
// Split call argument foo([a: ]1) into argument name [a] and the remainder
// [: ].
StringRef Content = Range.str();
size_t Colon = Content.find(':'); // FIXME: leading whitespace?
if (Colon == StringRef::npos) {
doRenameLabel(Range, RefactoringRangeKind::CallArgumentCombined,
// Include any whitespace before the ':'.
assert(Colon == Content.substr(0, Colon).size());
Colon = Content.substr(0, Colon).rtrim().size();
CharSourceRange Arg{Range.getStart(), unsigned(Colon)};
doRenameLabel(Arg, RefactoringRangeKind::CallArgumentLabel, NameIndex);
auto ColonLoc = Range.getStart().getAdvancedLocOrInvalid(Colon);
CharSourceRange Rest{ColonLoc, unsigned(Content.size() - Colon)};
doRenameLabel(Rest, RefactoringRangeKind::CallArgumentColon, NameIndex);
// True if the label written in Range spells Expected.
// - In a non-empty range, the leading (possibly backtick-escaped) identifier
//   is compared, and an empty Expected is spelled `_`.
// - An empty range matches only an empty Expected.
bool labelRangeMatches(CharSourceRange Range, LabelRangeType RangeType, StringRef Expected) {
if (Range.getByteLength()) {
bool IsEscaped = false;
CharSourceRange ExistingLabelRange = getLeadingIdentifierRange(Range, IsEscaped);
StringRef ExistingLabel = ExistingLabelRange.str();
// True when Range holds nothing but that one (possibly escaped) name.
bool IsSingleName = Range == ExistingLabelRange ||
(IsEscaped && Range.getByteLength() == ExistingLabel.size() + 2);
switch (RangeType) {
case LabelRangeType::NoncollapsibleParam:
if (IsSingleName && Expected.empty()) // subscript([x]: Int)
return true;
// Otherwise fall through and compare like any other label.
case LabelRangeType::CallArg:
case LabelRangeType::Param:
case LabelRangeType::Selector:
return ExistingLabel == (Expected.empty() ? "_" : Expected);
case LabelRangeType::None:
llvm_unreachable("Unhandled label range type");
return Expected.empty();
// Lenient label matching used at call sites, where arguments with defaults
// or variadics may be omitted. Labelled trailing closures are matched first,
// from the back of the old labels. The remaining labels are then matched
// front to back, skipping old labels that have no counterpart.
// Returns true on mismatch.
bool renameLabelsLenient(ArrayRef<CharSourceRange> LabelRanges,
Optional<unsigned> FirstTrailingLabel,
LabelRangeType RangeType) {
ArrayRef<StringRef> OldNames = Old.args();
// First, match trailing closure arguments in reverse
if (FirstTrailingLabel) {
auto TrailingLabels = LabelRanges.drop_front(*FirstTrailingLabel);
LabelRanges = LabelRanges.take_front(*FirstTrailingLabel);
for (auto LabelIndex: llvm::reverse(indices(TrailingLabels))) {
CharSourceRange Label = TrailingLabels[LabelIndex];
if (Label.getByteLength()) {
if (OldNames.empty())
return true;
while (!labelRangeMatches(Label, LabelRangeType::Selector,
OldNames.back())) {
if ((OldNames = OldNames.drop_back()).empty())
return true;
splitAndRenameLabel(Label, LabelRangeType::Selector,
OldNames.size() - 1);
OldNames = OldNames.drop_back();
// empty labelled trailing closure label
if (LabelIndex) {
if (OldNames.empty())
return true;
while (!OldNames.back().empty()) {
if ((OldNames = OldNames.drop_back()).empty())
return true;
splitAndRenameLabel(Label, LabelRangeType::Selector,
OldNames.size() - 1);
OldNames = OldNames.drop_back();
// unlabelled trailing closure label
// NOTE(review): no emptiness check is visible before this drop_back, and
// ArrayRef::drop_back on an empty array is invalid (LLVM asserts).
// Consider returning true when OldNames is empty.
OldNames = OldNames.drop_back();
// Next, match the non-trailing arguments.
size_t NameIndex = 0;
for (CharSourceRange Label : LabelRanges) {
// empty label
if (!Label.getByteLength()) {
// first name pos
if (!NameIndex) {
// NOTE(review): if OldNames is empty here (Old has no labels, or the
// trailing-closure pass consumed them all), OldNames[NameIndex] indexes
// an empty ArrayRef. Consider checking OldNames.empty() first.
while (!OldNames[NameIndex].empty()) {
if (++NameIndex >= OldNames.size())
return true;
splitAndRenameLabel(Label, RangeType, NameIndex++);
// other name pos
if (NameIndex >= OldNames.size() || !OldNames[NameIndex].empty()) {
// FIXME: only allow one variadic param
continue; // allow for variadic
splitAndRenameLabel(Label, RangeType, NameIndex++);
// non-empty label
if (NameIndex >= OldNames.size())
return true;
while (!labelRangeMatches(Label, RangeType, OldNames[NameIndex])) {
if (++NameIndex >= OldNames.size())
return true;
splitAndRenameLabel(Label, RangeType, NameIndex++);
return false;
// Classifies where a resolved occurrence appears:
// - no AST node: a comment;
// - a string literal expression: a string;
// - otherwise: selector, active code, or inactive code, according to the
//   resolver's flags.
static RegionType getSyntacticRenameRegionType(const ResolvedLoc &Resolved) {
if (Resolved.Node.isNull())
return RegionType::Comment;
if (Expr *E = Resolved.Node.getAsExpr()) {
if (isa<StringLiteralExpr>(E))
return RegionType::String;
if (Resolved.IsInSelector)
return RegionType::Selector;
if (Resolved.IsActive)
return RegionType::ActiveCode;
return RegionType::InactiveCode;
// Reports the base-name and label ranges of one resolved occurrence.
// Returns the region the occurrence lies in, or Unmatched / Mismatch when
// it should not be renamed.
// Special base names (init, subscript, callAsFunction) use
// KeywordBaseName and are only required to match in some usages.
RegionType addSyntacticRenameRanges(const ResolvedLoc &Resolved,
const RenameLoc &Config) {
if (!Resolved.Range.isValid())
return RegionType::Unmatched;
auto RegionKind = getSyntacticRenameRegionType(Resolved);
// Don't include unknown references coming from active code; if we don't
// have a semantic NameUsage for them, then they're likely unrelated symbols
// that happen to have the same name.
if (RegionKind == RegionType::ActiveCode &&
Config.Usage == NameUsage::Unknown)
return RegionType::Unmatched;
assert(Config.Usage != NameUsage::Call || Config.IsFunctionLike);
// FIXME: handle escaped keyword names `init`
bool IsSubscript = Old.base() == "subscript" && Config.IsFunctionLike;
bool IsInit = Old.base() == "init" && Config.IsFunctionLike;
// FIXME: this should only be treated specially for instance methods.
bool IsCallAsFunction = Old.base() == "callAsFunction" &&
bool IsSpecialBase = IsInit || IsSubscript || IsCallAsFunction;
// Filter out non-semantic special basename locations with no labels.
// We've already filtered out those in active code, so these are
// any appearance of just 'init', 'subscript', or 'callAsFunction' in
// strings, comments, and inactive code.
if (IsSpecialBase && (Config.Usage == NameUsage::Unknown &&
Resolved.LabelType == LabelRangeType::None))
return RegionType::Unmatched;
if (!Config.IsFunctionLike || !IsSpecialBase) {
if (renameBase(Resolved.Range, RefactoringRangeKind::BaseName))
return RegionType::Mismatch;
} else if (IsInit || IsCallAsFunction) {
if (renameBase(Resolved.Range, RefactoringRangeKind::KeywordBaseName)) {
// The base name doesn't need to match (but may) for calls, but
// it should for definitions and references.
if (Config.Usage == NameUsage::Definition ||
Config.Usage == NameUsage::Reference) {
return RegionType::Mismatch;
} else if (IsSubscript && Config.Usage == NameUsage::Definition) {
if (renameBase(Resolved.Range, RefactoringRangeKind::KeywordBaseName))
return RegionType::Mismatch;
bool HandleLabels = false;
if (Config.IsFunctionLike) {
// NOTE(review): no `break` is visible between these cases in this view.
// As shown, every assignment is overwritten by the following case, so
// HandleLabels always ends up as the Unknown-case value. Confirm each
// case ends with `break;`.
switch (Config.Usage) {
case NameUsage::Call:
HandleLabels = !isOperator();
case NameUsage::Definition:
HandleLabels = true;
case NameUsage::Reference:
HandleLabels = Resolved.LabelType == LabelRangeType::Selector || IsSubscript;
case NameUsage::Unknown:
HandleLabels = Resolved.LabelType != LabelRangeType::None;
} else if (Resolved.LabelType != LabelRangeType::None &&
!Config.IsNonProtocolType &&
// FIXME: Workaround for enum case labels until we support them
Config.Usage != NameUsage::Definition) {
return RegionType::Mismatch;
if (HandleLabels) {
// Call sites (and subscript references) get lenient label matching.
bool isCallSite = Config.Usage != NameUsage::Definition &&
(Config.Usage != NameUsage::Reference || IsSubscript) &&
Resolved.LabelType == LabelRangeType::CallArg;
if (renameLabels(Resolved.LabelRanges, Resolved.FirstTrailingLabel,
Resolved.LabelType, isCallSite))
return Config.Usage == NameUsage::Unknown ?
RegionType::Unmatched : RegionType::Mismatch;
return RegionKind;
/// Renamer that edits no text. It records every matched range, with its
/// RefactoringRangeKind and (for labels) its index in the old name, in
/// \c Ranges.
class RenameRangeDetailCollector : public Renamer {
void doRenameLabel(CharSourceRange Label, RefactoringRangeKind RangeKind,
unsigned NameIndex) override {
Ranges.push_back({Label, RangeKind, NameIndex});
// Base-name ranges have no label index.
void doRenameBase(CharSourceRange Range,
RefactoringRangeKind RangeKind) override {
Ranges.push_back({Range, RangeKind, None});
RenameRangeDetailCollector(const SourceManager &SM, StringRef OldName)
: Renamer(SM, OldName) {}
// The collected ranges, in the order the Renamer reported them.
std::vector<RenameRangeDetail> Ranges;
/// Renamer that turns every matched range into a text Replacement rewriting
/// the old name (\c Old) to the new name (\c New). Ranges whose text would
/// not change are not recorded.
class TextReplacementsRenamer : public Renamer {
// Interning storage for replacement strings built at runtime (see
// registerText).
llvm::StringSet<> &ReplaceTextContext;
std::vector<Replacement> Replacements;
const DeclNameViewer New;
// Interns Text in ReplaceTextContext and returns a StringRef to the stored
// copy. Empty text is returned as-is.
StringRef registerText(StringRef Text) {
if (Text.empty())
return Text;
return ReplaceTextContext.insert(Text).first->getKey();
// Call-site label: the new label itself (empty stays empty).
// NOTE(review): unlike the SelectorArgumentLabel case, NewLabel is not
// interned here, so its lifetime depends on the NewName passed to the ctor.
StringRef getCallArgLabelReplacement(StringRef OldLabelRange,
StringRef NewLabel) {
return NewLabel.empty() ? "" : NewLabel;
// Call-site colon:
// - removed when the new label is empty;
// - ": " when there was none;
// - otherwise the original spacing is kept.
StringRef getCallArgColonReplacement(StringRef OldLabelRange,
StringRef NewLabel) {
// Expected OldLabelRange: foo( []3, a[: ]2, b[ : ]3 ...)
// FIXME: Preserve comments: foo([a/*:*/ : /*:*/ ]2, ...)
if (NewLabel.empty())
return "";
if (OldLabelRange.empty())
return ": ";
return registerText(OldLabelRange);
StringRef getCallArgCombinedReplacement(StringRef OldArgLabel,
StringRef NewArgLabel) {
// This case only happens when going from foo([]1) to foo([a: ]1).
if (NewArgLabel.empty())
return "";
return registerText((llvm::Twine(NewArgLabel) + ": ").str());
StringRef getParamNameReplacement(StringRef OldParam, StringRef OldArgLabel,
StringRef NewArgLabel) {
// We don't want to get foo(a a: Int), so drop the parameter name if the
// argument label will match the original name.
// Note: the leading whitespace is part of the parameter range.
if (!NewArgLabel.empty() && OldParam.ltrim() == NewArgLabel)
return "";
// If we're renaming foo(x: Int) to foo(_:), then use the original argument
// label as the parameter name so as to not break references in the body.
if (NewArgLabel.empty() && !OldArgLabel.empty() && OldParam.empty())
return registerText((llvm::Twine(" ") + OldArgLabel).str());
return registerText(OldParam);
// Declaration argument label. An empty new label becomes `_`, unless there
// was no label text to begin with. A label inserted into an empty range
// gets a trailing space to separate it from the parameter name.
StringRef getDeclArgumentLabelReplacement(StringRef OldLabelRange,
StringRef NewArgLabel) {
// OldLabelRange is subscript([]a: Int), foo([a]: Int) or foo([a] b: Int)
if (NewArgLabel.empty())
return OldLabelRange.empty() ? "" : "_";
if (OldLabelRange.empty())
return registerText((llvm::Twine(NewArgLabel) + " ").str());
return registerText(NewArgLabel);
// Maps a label range of the given kind to its replacement text.
// Noncollapsible parameter names are returned unchanged (so no replacement
// is recorded for them).
StringRef getReplacementText(StringRef LabelRange,
RefactoringRangeKind RangeKind,
StringRef OldLabel, StringRef NewLabel) {
switch (RangeKind) {
case RefactoringRangeKind::CallArgumentLabel:
return getCallArgLabelReplacement(LabelRange, NewLabel);
case RefactoringRangeKind::CallArgumentColon:
return getCallArgColonReplacement(LabelRange, NewLabel);
case RefactoringRangeKind::CallArgumentCombined:
return getCallArgCombinedReplacement(LabelRange, NewLabel);
case RefactoringRangeKind::ParameterName:
return getParamNameReplacement(LabelRange, OldLabel, NewLabel);
case RefactoringRangeKind::NoncollapsibleParameterName:
return LabelRange;
case RefactoringRangeKind::DeclArgumentLabel:
return getDeclArgumentLabelReplacement(LabelRange, NewLabel);
case RefactoringRangeKind::SelectorArgumentLabel:
return NewLabel.empty() ? "_" : registerText(NewLabel);
llvm_unreachable("label range type is none but there are labels");
// Records a replacement for LabelRange only if its text actually changes.
void addReplacement(CharSourceRange LabelRange,
RefactoringRangeKind RangeKind, StringRef OldLabel,
StringRef NewLabel) {
StringRef ExistingLabel = LabelRange.str();
StringRef Text =
getReplacementText(ExistingLabel, RangeKind, OldLabel, NewLabel);
if (Text != ExistingLabel)
Replacements.push_back({LabelRange, Text, {}});
// NOTE(review): the final argument of the addReplacement call below
// (presumably the new label at NameIndex) is not visible in this view.
void doRenameLabel(CharSourceRange Label, RefactoringRangeKind RangeKind,
unsigned NameIndex) override {
addReplacement(Label, RangeKind, Old.args()[NameIndex],
// Replaces the base-name range unless the base name is unchanged.
void doRenameBase(CharSourceRange Range, RefactoringRangeKind) override {
if (Old.base() != New.base())
Replacements.push_back({Range, registerText(New.base()), {}});
// Both names must be valid and have the same number of parts (asserted).
TextReplacementsRenamer(const SourceManager &SM, StringRef OldName,
StringRef NewName,
llvm::StringSet<> &ReplaceTextContext)
: Renamer(SM, OldName), ReplaceTextContext(ReplaceTextContext),
New(NewName) {
assert(Old.isValid() && New.isValid());
assert(Old.partsCount() == New.partsCount());
// Returns the accumulated replacements.
// NOTE(review): this member function is const, so std::move(Replacements)
// yields a const rvalue and the vector is copied, not moved. Either drop the
// std::move, or make this non-const if moving out was intended.
std::vector<Replacement> getReplacements() const {
return std::move(Replacements);
/// Returns the declaration from a system module that \p VD is tied to, or
/// nullptr if there is none.
///
/// The candidates are checked in this order:
/// 1. \p VD itself;
/// 2. each protocol requirement \p VD satisfies;
/// 3. each declaration along \p VD's override chain, nearest first.
static const ValueDecl *getRelatedSystemDecl(const ValueDecl *VD) {
  auto InSystemModule = [](const ValueDecl *Candidate) {
    return Candidate->getModuleContext()->isSystemModule();
  };

  if (InSystemModule(VD))
    return VD;

  for (auto *Requirement : VD->getSatisfiedProtocolRequirements())
    if (InSystemModule(Requirement))
      return Requirement;

  const ValueDecl *Ancestor = VD->getOverriddenDecl();
  while (Ancestor) {
    if (InSystemModule(Ancestor))
      return Ancestor;
    Ancestor = Ancestor->getOverriddenDecl();
  }
  return nullptr;
}
/// Returns the kind of the first rename refactoring reported as available
/// for \p VD, or None if no rename is available. \p RefInfo optionally
/// describes the reference the request was made from.
static Optional<RefactoringKind>
getAvailableRenameForDecl(const ValueDecl *VD,
                          Optional<RenameRefInfo> RefInfo) {
  std::vector<RenameAvailabiliyInfo> Storage;
  for (const auto &Candidate :
       collectRenameAvailabilityInfo(VD, RefInfo, Storage)) {
    if (Candidate.AvailableKind != RenameAvailableKind::Available)
      continue;
    return Candidate.Kind;
  }
  return None;
}
/// IndexDataConsumer that converts every indexed occurrence of one symbol,
/// identified by its USR, into a RenameLoc (see results()).
class RenameRangeCollector : public IndexDataConsumer {
RenameRangeCollector(StringRef USR, StringRef newName)
: USR(USR.str()), newName(newName.str()) {}
// Computes the USR from the declaration itself.
RenameRangeCollector(const ValueDecl *D, StringRef newName)
: newName(newName.str()) {
llvm::raw_string_ostream OS(USR);
printValueDeclUSR(D, OS);
// The locations collected so far.
ArrayRef<RenameLoc> results() const { return locations; }
// Ask the indexer to report local symbols as well.
bool indexLocals() override { return true; }
// Failures and dependencies are ignored; only source entities are used.
void failed(StringRef error) override {}
bool startDependency(StringRef name, StringRef path, bool isClangModule, bool isSystem) override {
return true;
bool finishDependency(bool isClangModule) override { return true; }
// Converts occurrences whose USR matches into RenameLocs and keeps indexing.
// NOTE(review): the statement that stores *loc into `locations` is not
// visible in this view; confirm it is present inside the inner `if`.
Action startSourceEntity(const IndexSymbol &symbol) override {
if (symbol.USR == USR) {
if (auto loc = indexSymbolToRenameLoc(symbol, newName)) {
return IndexDataConsumer::Continue;
bool finishSourceEntity(SymbolInfo symInfo, SymbolRoleSet roles) override {
return true;
Optional<RenameLoc> indexSymbolToRenameLoc(const index::IndexSymbol &symbol,
StringRef NewName);
std::string USR;
std::string newName;
// Backing storage for strings (such as each old name) that the RenameLocs
// in `locations` refer to.
StringScratchSpace stringStorage;
std::vector<RenameLoc> locations;
RenameRangeCollector::indexSymbolToRenameLoc(const index::IndexSymbol &symbol,
StringRef newName) {
if (symbol.roles & (unsigned)index::SymbolRole::Implicit) {
return None;
NameUsage usage = NameUsage::Unknown;
if (symbol.roles & (unsigned)index::SymbolRole::Call) {
usage = NameUsage::Call;
} else if (symbol.roles & (unsigned)index::SymbolRole::Definition) {
usage = NameUsage::Definition;
} else if (symbol.roles & (unsigned)index::SymbolRole::Reference) {
usage = NameUsage::Reference;
} else {
llvm_unreachable("unexpected role");
bool isFunctionLike = false;
bool isNonProtocolType = false;
switch (symbol.symInfo.Kind) {
case index::SymbolKind::EnumConstant:
case index::SymbolKind::Function:
case index::SymbolKind::Constructor:
case index::SymbolKind::ConversionFunction:
case index::SymbolKind::InstanceMethod:
case index::SymbolKind::ClassMethod:
case index::SymbolKind::StaticMethod:
isFunctionLike = true;
case index::SymbolKind::Class:
case index::SymbolKind::Enum:
case index::SymbolKind::Struct:
isNonProtocolType = true;
StringRef oldName = stringStorage.copyString(;
return RenameLoc{symbol.line, symbol.column, usage, oldName, newName,
isFunctionLike, isNonProtocolType};
// Collects the SourceFiles among MD's files into Scratch and returns a view
// of Scratch.
// NOTE(review): the return-type line and the statement that appends SF to
// Scratch are not visible in this view; confirm both are present.
collectSourceFiles(ModuleDecl *MD, llvm::SmallVectorImpl<SourceFile*> &Scratch) {
for (auto Unit : MD->getFiles()) {
if (auto SF = dyn_cast<SourceFile>(Unit)) {
return llvm::makeArrayRef(Scratch);
/// Returns the source file of module \p M whose buffer is the buffer named
/// by \p Range, or nullptr if no file of \p M has that buffer.
SourceFile *getContainingFile(ModuleDecl *M, RangeConfig Range) {
  llvm::SmallVector<SourceFile*, 4> Scratch;
  for (auto *Candidate : collectSourceFiles(M, Scratch)) {
    auto BufferID = Candidate->getBufferID();
    if (BufferID && *BufferID == Range.BufferId)
      return Candidate;
  }
  return nullptr;
}
/// Base class of all refactoring actions. It holds:
/// - the module and the file containing the requested range;
/// - the edit consumer and a private diagnostic engine;
/// - the start location and preferred name taken from RefactoringOptions.
class RefactoringAction {
ModuleDecl *MD;
// Null when no file of MD has the requested range's buffer (see
// getContainingFile).
SourceFile *TheFile;
SourceEditConsumer &EditConsumer;
ASTContext &Ctx;
// SM must stay declared before DiagEngine and StartLoc: the constructor
// initializes both of them from SM.
SourceManager &SM;
DiagnosticEngine DiagEngine;
// Start of the token at the requested range's start location.
SourceLoc StartLoc;
StringRef PreferredName;
RefactoringAction(ModuleDecl *MD, RefactoringOptions &Opts,
SourceEditConsumer &EditConsumer,
DiagnosticConsumer &DiagConsumer);
virtual ~RefactoringAction() = default;
// Applies the refactoring; implementations return true on failure.
virtual bool performChange() = 0;
// Initializes the shared refactoring state from Opts. TheFile may be null
// if no file of MD contains Opts.Range.
// NOTE(review): the rest of the constructor body is not visible in this
// view.
RefactoringAction(ModuleDecl *MD, RefactoringOptions &Opts,
SourceEditConsumer &EditConsumer,
DiagnosticConsumer &DiagConsumer): MD(MD),
TheFile(getContainingFile(MD, Opts.Range)),
EditConsumer(EditConsumer), Ctx(MD->getASTContext()),
SM(MD->getASTContext().SourceMgr), DiagEngine(SM),
StartLoc(Lexer::getLocForStartOfToken(SM, Opts.Range.getStart(SM))),
PreferredName(Opts.PreferredName) {
/// Unlike RangeBasedRefactoringAction, TokenBasedRefactoringAction takes a
/// single token as input, e.g. a name or an "if" keyword. Contextual
/// refactoring kinds can then suggest refactorings applicable to that token,
/// e.g. rename, or reverse an if statement.
///
/// NOTE(review): the constructor dereferences TheFile, which
/// getContainingFile may have left null. For example,
/// RefactoringActionLocalRename::performChange only checks !TheFile after
/// construction. Consider guarding the cursor resolution.
class TokenBasedRefactoringAction : public RefactoringAction {
// Cursor info for the token at StartLoc, resolved at construction.
ResolvedCursorInfo CursorInfo;
TokenBasedRefactoringAction(ModuleDecl *MD, RefactoringOptions &Opts,
SourceEditConsumer &EditConsumer,
DiagnosticConsumer &DiagConsumer) :
RefactoringAction(MD, Opts, EditConsumer, DiagConsumer) {
// Resolve the sema token and save it for later use.
CursorInfo = evaluateOrDefault(TheFile->getASTContext().evaluator,
CursorInfoRequest{ CursorInfoOwner(TheFile, StartLoc)},
/* Declares RefactoringAction##KIND, a TokenBasedRefactoringAction subclass. */ \
/* Presumably expanded once per cursor-based refactoring kind by the */ \
/* RefactoringKinds.def include that follows. Each expansion's */ \
/* performChange() and static isApplicable() are defined out of line. */ \
class RefactoringAction##KIND: public TokenBasedRefactoringAction { \
public: \
RefactoringAction##KIND(ModuleDecl *MD, RefactoringOptions &Opts, \
SourceEditConsumer &EditConsumer, \
DiagnosticConsumer &DiagConsumer) : \
TokenBasedRefactoringAction(MD, Opts, EditConsumer, DiagConsumer) {} \
bool performChange() override; \
static bool isApplicable(ResolvedCursorInfo Tok, DiagnosticEngine &Diag); \
/* Checks applicability against the cursor resolved at construction. */ \
bool isApplicable() { \
return RefactoringAction##KIND::isApplicable(CursorInfo, DiagEngine) ; \
} \
#include "swift/IDE/RefactoringKinds.def"
/// Base class for refactorings that operate on a selected source range. The
/// range is resolved once, at construction, into RangeInfo.
///
/// NOTE(review): the line that begins the RangeInfo initializer (the
/// evaluateOrDefault call) is not visible in this view. As in
/// TokenBasedRefactoringAction, TheFile is used here without a null check.
class RangeBasedRefactoringAction : public RefactoringAction {
ResolvedRangeInfo RangeInfo;
RangeBasedRefactoringAction(ModuleDecl *MD, RefactoringOptions &Opts,
SourceEditConsumer &EditConsumer,
DiagnosticConsumer &DiagConsumer) :
RefactoringAction(MD, Opts, EditConsumer, DiagConsumer),
RangeInfoRequest(RangeInfoOwner(TheFile, Opts.Range.getStart(SM), Opts.Range.getEnd(SM))),
ResolvedRangeInfo())) {}
/* Declares RefactoringAction##KIND, a RangeBasedRefactoringAction subclass. */ \
/* Presumably expanded once per range-based refactoring kind by the */ \
/* RefactoringKinds.def include that follows. Each expansion's */ \
/* performChange() and static isApplicable() are defined out of line. */ \
class RefactoringAction##KIND: public RangeBasedRefactoringAction { \
public: \
RefactoringAction##KIND(ModuleDecl *MD, RefactoringOptions &Opts, \
SourceEditConsumer &EditConsumer, \
DiagnosticConsumer &DiagConsumer) : \
RangeBasedRefactoringAction(MD, Opts, EditConsumer, DiagConsumer) {} \
bool performChange() override; \
static bool isApplicable(ResolvedRangeInfo Info, DiagnosticEngine &Diag); \
/* Checks applicability against the range resolved at construction. */ \
bool isApplicable() { \
return RefactoringAction##KIND::isApplicable(RangeInfo, DiagEngine) ; \
} \
#include "swift/IDE/RefactoringKinds.def"
bool RefactoringActionLocalRename::
isApplicable(ResolvedCursorInfo CursorInfo, DiagnosticEngine &Diag) {
  // Only cursors resolved to a ValueRef are candidates for local rename.
  if (CursorInfo.Kind != CursorInfoKind::ValueRef)
    return false;

  // When the cursor sits on a reference rather than on the declaration,
  // describe that reference for the availability check.
  Optional<RenameRefInfo> RefInfo;
  if (CursorInfo.IsRef)
    RefInfo = {CursorInfo.SF, CursorInfo.Loc, CursorInfo.IsKeywordArgument};

  // Applicable only when the first available rename is a local one.
  if (auto Kind = getAvailableRenameForDecl(CursorInfo.ValueD, RefInfo))
    return *Kind == RefactoringKind::LocalRename;
  return false;
}
/// Computes the decl contexts to index when renaming \p VD locally and adds
/// them to \p Scopes. Emits a diagnostic when no rename is available for
/// \p VD; callers treat an empty \p Scopes as failure.
/// NOTE(review): the early return after that diagnostic, and the statement
/// that appends Scope to Scopes, are not visible in this view.
static void analyzeRenameScope(ValueDecl *VD, Optional<RenameRefInfo> RefInfo,
DiagnosticEngine &Diags,
llvm::SmallVectorImpl<DeclContext *> &Scopes) {
if (!getAvailableRenameForDecl(VD, RefInfo).hasValue()) {
Diags.diagnose(SourceLoc(), diag::value_decl_no_loc, VD->getName());
auto *Scope = VD->getDeclContext();
// If the context is a top-level code decl, there may be other sibling
// decls that the renamed symbol is visible from
if (isa<TopLevelCodeDecl>(Scope))
Scope = Scope->getParent();
/// Renames the value at StartLoc to PreferredName within its local scope.
///
/// Returns true on failure. A diagnostic is emitted for an invalid location,
/// an invalid new name, a location outside this module's files, and a
/// cursor that does not resolve to a value. Returns true without a
/// diagnostic here when analyzeRenameScope yields no scopes. Otherwise
/// returns the result of syntacticRename over the collected locations.
bool RefactoringActionLocalRename::performChange() {
if (StartLoc.isInvalid()) {
DiagEngine.diagnose(SourceLoc(), diag::invalid_location);
return true;
if (!DeclNameViewer(PreferredName).isValid()) {
DiagEngine.diagnose(SourceLoc(), diag::invalid_name, PreferredName);
return true;
if (!TheFile) {
DiagEngine.diagnose(StartLoc, diag::location_module_mismatch,
return true;
CursorInfo = evaluateOrDefault(TheFile->getASTContext().evaluator,
CursorInfoRequest{CursorInfoOwner(TheFile, StartLoc)},
if (CursorInfo.isValid() && CursorInfo.ValueD) {
// Prefer CtorTyRef when set (presumably the type named in a constructor
// call) over the resolved ValueD.
ValueDecl *VD = CursorInfo.CtorTyRef ? CursorInfo.CtorTyRef : CursorInfo.ValueD;
llvm::SmallVector<DeclContext *, 8> Scopes;
Optional<RenameRefInfo> RefInfo;
if (CursorInfo.IsRef)
RefInfo = {CursorInfo.SF, CursorInfo.Loc, CursorInfo.IsKeywordArgument};
analyzeRenameScope(VD, RefInfo, DiagEngine, Scopes);
if (Scopes.empty())
return true;
// Index only the computed scopes, collecting every occurrence of VD.
RenameRangeCollector rangeCollector(VD, PreferredName);
for (DeclContext *DC : Scopes)
indexDeclContext(DC, rangeCollector);
// Hand the single diagnostic consumer over to the syntactic rename.
auto consumers = DiagEngine.takeConsumers();
assert(consumers.size() == 1);
return syntacticRename(TheFile, rangeCollector.results(), EditConsumer,
} else {
DiagEngine.diagnose(StartLoc, diag::unresolved_location);
return true;
/// Returns the default name for the entity a refactoring of kind \p Kind
/// renames or introduces. Kinds without a default name get "".
/// RefactoringKind::None is not a valid input.
StringRef getDefaultPreferredName(RefactoringKind Kind) {
  switch (Kind) {
  case RefactoringKind::None:
    llvm_unreachable("Should be a valid refactoring kind");
  case RefactoringKind::GlobalRename:
  case RefactoringKind::LocalRename:
    return "newName";
  case RefactoringKind::ExtractExpr:
  case RefactoringKind::ExtractRepeatedExpr:
    return "extractedExpr";
  case RefactoringKind::ExtractFunction:
    return "extractedFunc";
  default:
    // Explicit fallback. Previously `return "";` sat unreachable after the
    // ExtractFunction return, so other kinds fell off the end of this
    // non-void function.
    return "";
  }
}
/// Soft reasons that a range may not be extractable as-is. A client can
/// accept some of them via ExtractCheckResult::success(ExpectedReasons).
/// NOTE(review): the enumerators are not visible in this view.
enum class CannotExtractReason {
/// Outcome of checkExtractConditions. It is either a known (hard) failure,
/// or a list of soft CannotExtractReasons that a client may choose to
/// accept.
class ExtractCheckResult {
bool KnownFailure;
llvm::SmallVector<CannotExtractReason, 2> AllReasons;
// A default-constructed result is a known failure.
ExtractCheckResult(): KnownFailure(true) {}
// A result that is not a known failure, carrying the given soft reasons.
ExtractCheckResult(ArrayRef<CannotExtractReason> AllReasons):
KnownFailure(false), AllReasons(AllReasons.begin(), AllReasons.end()) {}
// Succeeds only when there is no known failure and no reasons at all.
bool success() { return success({}); }
// Succeeds when there is no known failure and every recorded reason is
// listed in ExpectedReasons.
bool success(llvm::ArrayRef<CannotExtractReason> ExpectedReasons) {
if (KnownFailure)
return false;
bool Result = true;
// Check if any reasons aren't covered by the list of expected reasons
// provided by the client.
for (auto R: AllReasons) {
Result &= llvm::is_contained(ExpectedReasons, R);
return Result;
/// Check whether a given range can be extracted.
/// Returns a known-failure ExtractCheckResult when it cannot; some of these
/// paths emit a diagnostic first. Otherwise returns a result carrying any
/// soft CannotExtractReasons collected along the way.
ExtractCheckResult checkExtractConditions(ResolvedRangeInfo &RangeInfo,
DiagnosticEngine &DiagEngine) {
llvm::SmallVector<CannotExtractReason, 2> AllReasons;
// If any declaration declared in the range is referred to after the range,
// extraction fails.
auto Declared = RangeInfo.DeclaredDecls;
auto It = std::find_if(Declared.begin(), Declared.end(),
[](DeclaredDecl DD) { return DD.ReferredAfterRange; });
if (It != Declared.end()) {
return ExtractCheckResult();
// We cannot extract a range with multi entry points.
if (!RangeInfo.HasSingleEntry) {
DiagEngine.diagnose(SourceLoc(), diag::multi_entry_range);
return ExtractCheckResult();
// We cannot extract code whose exit state is unsure.
if (RangeInfo.exit() == ExitState::Unsure) {
return ExtractCheckResult();
// We cannot extract expressions of l-value type.
if (auto Ty = RangeInfo.getType()) {
if (Ty->hasLValueType() || Ty->is<InOutType>())
return ExtractCheckResult();
// Disallow extracting error type expressions/statements
// FIXME: diagnose what happened?
if (Ty->hasError())
return ExtractCheckResult();
if (Ty->isVoid()) {
// We cannot extract a range with orphaned loop keyword.
switch (RangeInfo.Orphan) {
case swift::ide::OrphanKind::Continue:
DiagEngine.diagnose(SourceLoc(), diag::orphan_loop_keyword, "continue");
return ExtractCheckResult();
case swift::ide::OrphanKind::Break:
DiagEngine.diagnose(SourceLoc(), diag::orphan_loop_keyword, "break");
return ExtractCheckResult();
case swift::ide::OrphanKind::None:
// Guard statements cannot be extracted.
if (llvm::any_of(RangeInfo.ContainedNodes,
[](ASTNode N) { return N.isStmt(StmtKind::Guard); })) {
return ExtractCheckResult();
// Disallow extracting certain kinds of statements.
if (RangeInfo.Kind == RangeKind::SingleStatement) {
Stmt *S = RangeInfo.ContainedNodes[0].get<Stmt *>();
// These aren't independent statements.
if (isa<BraceStmt>(S) || isa<CaseStmt>(S))
return ExtractCheckResult();
// Disallow extracting literals.
if (RangeInfo.Kind == RangeKind::SingleExpression) {
Expr *E = RangeInfo.ContainedNodes[0].get<Expr*>();
// Until implementing the performChange() part of extracting trailing
// closures, we disable them for now.
if (isa<AbstractClosureExpr>(E))
return ExtractCheckResult();
// NOTE(review): the statement governed by `if (isa<LiteralExpr>(E))`
// (presumably recording a literal CannotExtractReason) is not visible in
// this view. As shown, the `if` controls the context switch below.
// Also, no `break` is visible among the switch cases, so as shown every
// listed context kind reaches `return ExtractCheckResult();`. Confirm
// which kinds are meant to be allowed.
if (isa<LiteralExpr>(E))
switch (RangeInfo.RangeContext->getContextKind()) {
case swift::DeclContextKind::Initializer:
case swift::DeclContextKind::SubscriptDecl:
case swift::DeclContextKind::EnumElementDecl:
case swift::DeclContextKind::AbstractFunctionDecl:
case swift::DeclContextKind::AbstractClosureExpr:
case swift::DeclContextKind::TopLevelCodeDecl:
case swift::DeclContextKind::SerializedLocal:
case swift::DeclContextKind::Module:
case swift::DeclContextKind::FileUnit:
case swift::DeclContextKind::GenericTypeDecl:
case swift::DeclContextKind::ExtensionDecl:
return ExtractCheckResult();
return ExtractCheckResult(AllReasons);
// Extract Function is offered only for single expressions, single
// statements and multi-statement ranges, and then only if
// checkExtractConditions allows it. Other range kinds are rejected outright.
// NOTE(review): the continuation of the `checkExtractConditions(...).` call
// in the last case is not visible in this view.
bool RefactoringActionExtractFunction::
isApplicable(ResolvedRangeInfo Info, DiagnosticEngine &Diag) {
switch (Info.Kind) {
case RangeKind::PartOfExpression:
case RangeKind::SingleDecl:
case RangeKind::MultiTypeMemberDecl:
case RangeKind::Invalid:
return false;
case RangeKind::SingleExpression:
case RangeKind::SingleStatement:
case RangeKind::MultiStatement: {
return checkExtractConditions(Info, Diag).
llvm_unreachable("unhandled kind");
/// Returns \p Name if no declaration in \p AllVisibles has exactly that base
/// name. Otherwise returns \p Name followed by the smallest positive integer
/// that is not in UsedSuffixes. That result is interned through
/// Ctx.getIdentifier, so the StringRef stays valid after this call.
/// NOTE(review): the `continue;` for names that do not start with \p Name,
/// and the statement that inserts non-empty suffixes into UsedSuffixes, are
/// not visible in this view.
static StringRef correctNameInternal(ASTContext &Ctx, StringRef Name,
ArrayRef<ValueDecl*> AllVisibles) {
// Whether some visible declaration is named exactly Name.
bool FoundCollision = false;
// The suffixes we cannot use by appending to the original given name.
llvm::StringSet<> UsedSuffixes;
for (auto VD : AllVisibles) {
StringRef S = VD->getBaseName().userFacingName();
if (!S.startswith(Name))
StringRef Suffix = S.substr(Name.size());
if (Suffix.empty())
FoundCollision = true;
if (!FoundCollision)
return Name;
// Find the first suffix we can use.
std::string SuffixToUse;
for (unsigned I = 1; ; I ++) {
SuffixToUse = std::to_string(I);
if (UsedSuffixes.count(SuffixToUse) == 0)
return Ctx.getIdentifier((llvm::Twine(Name) + SuffixToUse).str()).str();
/// Returns a name derived from \p Name that does not clash with any value
/// declaration found by lookupVisibleDecls() from \p DC. See
/// correctNameInternal() for the suffixing rule.
static StringRef correctNewDeclName(DeclContext *DC, StringRef Name) {
// Collect all visible decls in the decl context.
llvm::SmallVector<ValueDecl*, 16> AllVisibles;
VectorDeclConsumer Consumer(AllVisibles);
ASTContext &Ctx = DC->getASTContext();
lookupVisibleDecls(Consumer, DC, true);
return correctNameInternal(Ctx, Name, AllVisibles);
/// Rewrites \p Ty via Type::transform so that an LValueType becomes an
/// InOutType of its canonical rvalue type. Any other type is returned as is.
static Type sanitizeType(Type Ty) {
// Transform lvalue type to inout type so that we can print it properly.
return Ty.transform([](Type Ty) {
if (Ty->is<LValueType>()) {
return Type(InOutType::get(Ty->getRValueType()->getCanonicalType()));
return Ty;
/// Finds where to insert a function extracted from code in \p DC.
///
/// Starts from the innermost declaration enclosing \p DC. When that is an
/// accessor, the enclosing storage is used instead: a variable's pattern
/// binding declaration, or the subscript itself. The location is then moved
/// earlier so that it precedes the declaration's attributes and its doc
/// comment.
///
/// \param DC the context of the code being extracted.
/// \param[out] InsertToContext receives the declaration context of the
///   declaration the new function is inserted before. It is not assigned
///   when no location is found.
/// \returns the insertion location, or an invalid SourceLoc when \p DC has no
///   enclosing declaration.
static SourceLoc
getNewFuncInsertLoc(DeclContext *DC, DeclContext*& InsertToContext) {
if (auto D = DC->getInnermostDeclarationDeclContext()) {
// If extracting from a getter/setter, we should skip both the immediate
// getter/setter function and the individual var decl. The pattern binding
// decl is the position before which we should insert the newly extracted
// function.
if (auto *FD = dyn_cast<AccessorDecl>(D)) {
ValueDecl *SD = FD->getStorage();
switch (SD->getKind()) {
case DeclKind::Var:
if (auto *PBD = cast<VarDecl>(SD)->getParentPatternBinding())
D = PBD;
// NOTE(review): no `break` is visible here in this copy. Without one, the
// Var case would fall through and overwrite D with SD. Upstream presumably
// breaks after each case; confirm.
case DeclKind::Subscript:
D = SD;
auto Result = D->getStartLoc();
// The insert loc should be before all of the decl's attributes.
for (auto Attr : D->getAttrs()) {
auto Loc = Attr->getRangeWithAt().Start;
// Raw pointer comparison orders locations by buffer address. This
// presumes the attribute lives in the same buffer as the declaration.
if (Loc.isValid() &&
Loc.getOpaquePointerValue() < Result.getOpaquePointerValue())
Result = Loc;
// The insert loc should be before the doc comments associated with this decl.
if (!D->getRawComment().Comments.empty()) {
auto Loc = D->getRawComment().Comments.front().Range.getStart();
if (Loc.isValid() &&
Loc.getOpaquePointerValue() < Result.getOpaquePointerValue()) {
Result = Loc;
InsertToContext = D->getDeclContext();
return Result;
return SourceLoc();
/// Computes the NoteRegions covering \p Name in the generated \p SourceText.
///
/// \p SourceText is parsed on its own in a throwaway CompilerInstance (module
/// "extract", buffer "<extract>", `#if` evaluation disabled). The name at
/// \p NameOffset is resolved with NameMatcher, and its syntactic rename
/// ranges are converted to line/column NoteRegions.
///
/// \param SourceText generated code (a new declaration or a call).
/// \param NameOffset offset of the name within \p SourceText.
/// \param Name the old name handed to the rename machinery.
/// \param IsFunctionLike forwarded into the RenameLoc.
/// \param IsNonProtocolType forwarded into the RenameLoc.
static std::vector<NoteRegion>
getNotableRegions(StringRef SourceText, unsigned NameOffset, StringRef Name,
bool IsFunctionLike = false, bool IsNonProtocolType = false) {
auto InputBuffer = llvm::MemoryBuffer::getMemBufferCopy(SourceText,"<extract>");
CompilerInvocation Invocation{};
// NOTE(review): the call this argument list belongs to is missing in this
// copy. It presumably registers InputBuffer as a frontend input.
InputFile("<extract>", true, InputBuffer.get(), file_types::TY_Swift));
Invocation.getFrontendOptions().ModuleName = "extract";
Invocation.getLangOptions().DisablePoundIfEvaluation = true;
auto Instance = std::make_unique<swift::CompilerInstance>();
// A setup failure is treated as unreachable.
if (Instance->setup(Invocation))
llvm_unreachable("Failed setup");
unsigned BufferId = Instance->getPrimarySourceFile()->getBufferID().getValue();
SourceManager &SM = Instance->getSourceMgr();
SourceLoc NameLoc = SM.getLocForOffset(BufferId, NameOffset);
auto LineAndCol = SM.getPresumedLineAndColumnForLoc(NameLoc);
UnresolvedLoc UnresoledName{NameLoc, true};
NameMatcher Matcher(*Instance->getPrimarySourceFile());
auto Resolved = Matcher.resolve(llvm::makeArrayRef(UnresoledName), None);
// NOTE(review): this is checked only in assert-enabled builds.
// Resolved.back() below is otherwise unguarded.
assert(!Resolved.empty() && "Failed to resolve generated func name loc");
RenameLoc RenameConfig = {
LineAndCol.first, LineAndCol.second,
NameUsage::Definition, /*OldName=*/Name, /*NewName=*/"",
IsFunctionLike, IsNonProtocolType
RenameRangeDetailCollector Renamer(SM, Name);
Renamer.addSyntacticRenameRanges(Resolved.back(), RenameConfig);
auto Ranges = Renamer.Ranges;
std::vector<NoteRegion> NoteRegions(Renamer.Ranges.size());
// NOTE(review): the algorithm call these arguments belong to is missing in
// this copy (presumably std::transform over Ranges into NoteRegions).
Ranges.begin(), Ranges.end(), NoteRegions.begin(),
[&SM](RenameRangeDetail &Detail) -> NoteRegion {
auto Start = SM.getPresumedLineAndColumnForLoc(Detail.Range.getStart());
auto End = SM.getPresumedLineAndColumnForLoc(Detail.Range.getEnd());
return {Detail.RangeKind, Start.first, Start.second,
End.first, End.second, Detail.Index};
return NoteRegions;
/// Performs "Extract Function": moves the selected code into a new function
/// and replaces the selection with a call to that function.
///
/// The new function:
///  - is named PreferredName, with a numeric suffix if it clashes with a
///    visible declaration;
///  - is inserted before the declaration chosen by getNewFuncInsertLoc();
///  - takes one unlabeled parameter for each referenced declaration of the
///    enclosing context that must be passed in;
///  - is marked `throws` and/or given a return type when the range requires it.
///
/// \returns true if the refactoring failed (a diagnostic has been emitted),
///   false on success.
bool RefactoringActionExtractFunction::performChange() {
// Check if the new name is ok.
if (!Lexer::isIdentifier(PreferredName)) {
DiagEngine.diagnose(SourceLoc(), diag::invalid_name, PreferredName);
return true;
DeclContext *DC = RangeInfo.RangeContext;
DeclContext *InsertToDC = nullptr;
SourceLoc InsertLoc = getNewFuncInsertLoc(DC, InsertToDC);
// Complain about no inserting position.
if (InsertLoc.isInvalid()) {
DiagEngine.diagnose(SourceLoc(), diag::no_insert_position);
return true;
// Correct the given name if collision happens.
PreferredName = correctNewDeclName(InsertToDC, PreferredName);
// Collect the parameters to pass down to the new function.
// NOTE(review): the statements guarded by the filters below (presumably
// `continue`) are not visible in this copy.
std::vector<ReferencedDecl> Parameters;
for (auto &RD: RangeInfo.ReferencedDecls) {
// If the referenced decl is declared elsewhere, no need to pass as parameter
if (RD.VD->getDeclContext() != DC)
// We don't need to pass down implicitly declared variables, e.g. error in
// a catch block.
if (RD.VD->isImplicit()) {
SourceLoc Loc = RD.VD->getStartLoc();
if (Loc.isValid() &&
SM.isBeforeInBuffer(RangeInfo.ContentRange.getStart(), Loc) &&
SM.isBeforeInBuffer(Loc, RangeInfo.ContentRange.getEnd()))
// If the referenced decl is declared inside the range, no need to pass
// as parameter.
if (RangeInfo.DeclaredDecls.end() !=
std::find_if(RangeInfo.DeclaredDecls.begin(), RangeInfo.DeclaredDecls.end(),
[RD](DeclaredDecl DD) { return RD.VD == DD.VD; }))
// We don't need to pass down self.
if (auto PD = dyn_cast<ParamDecl>(RD.VD)) {
if (PD->isSelfParameter()) {
// Lvalue types are recorded as inout so the parameter type prints properly.
Parameters.emplace_back(RD.VD, sanitizeType(RD.Ty));
// Buffer holds the new function's text in [FuncBegin, FuncEnd), followed by
// the replacement call in [ReplaceBegin, ReplaceEnd).
SmallString<64> Buffer;
unsigned FuncBegin = Buffer.size();
unsigned FuncNameOffset;
// NOTE(review): this stream and the one declared for the call below
// presumably live in separate scopes whose braces are not visible here.
llvm::raw_svector_ostream OS(Buffer);
if (!InsertToDC->isLocalContext()) {
// Default to be file private.
OS << tok::kw_fileprivate << " ";
// Inherit static if the containing function is.
if (DC->getContextKind() == DeclContextKind::AbstractFunctionDecl) {
if (auto FD = dyn_cast<FuncDecl>(static_cast<AbstractFunctionDecl*>(DC))) {
if (FD->isStatic()) {
OS << tok::kw_static << " ";
OS << tok::kw_func << " ";
// Remember where the name starts, for getNotableRegions().
FuncNameOffset = Buffer.size() - FuncBegin;
OS << PreferredName;
OS << "(";
// Every parameter is unlabeled: `_ name: <type>`.
// NOTE(review): the line printing the parameter type is not visible in
// this copy.
for (auto &RD : Parameters) {
OS << "_ " << RD.VD->getBaseName().userFacingName() << ": ";
if (&RD != &Parameters.back())
OS << ", ";
OS << ")";
if (RangeInfo.ThrowingUnhandledError)
OS << " " << tok::kw_throws;
bool InsertedReturnType = false;
if (auto Ty = RangeInfo.getType()) {
// If the type of the range is not void, specify the return type.
// NOTE(review): the line printing Ty after the arrow is not visible here.
if (!Ty->isVoid()) {
OS << " " << tok::arrow << " ";
InsertedReturnType = true;
OS << " {\n";
// Add "return" if the extracted entity is an expression.
if (RangeInfo.Kind == RangeKind::SingleExpression && InsertedReturnType)
OS << tok::kw_return << " ";
OS << RangeInfo.ContentRange.str() << "\n}\n\n";
unsigned FuncEnd = Buffer.size();
unsigned ReplaceBegin = Buffer.size();
unsigned CallNameOffset;
llvm::raw_svector_ostream OS(Buffer);
// With ExitState::Positive the range presumably returns from the enclosing
// function. The call site then returns the call's result to keep that flow.
if (RangeInfo.exit() == ExitState::Positive)
OS << tok::kw_return <<" ";
CallNameOffset = Buffer.size() - ReplaceBegin;
OS << PreferredName << "(";
for (auto &RD : Parameters) {
// Inout argument needs "&".
if (RD.Ty->is<InOutType>())
OS << "&";
OS << RD.VD->getBaseName().userFacingName();
if (&RD != &Parameters.back())
OS << ", ";
OS << ")";
unsigned ReplaceEnd = Buffer.size();
// Full name of the new function, e.g. `name(_:_:)`, with one `_:` per
// parameter.
std::string ExtractedFuncName = PreferredName.str() + "(";
for (size_t i = 0; i < Parameters.size(); ++i) {
ExtractedFuncName += "_:";
ExtractedFuncName += ")";
// NOTE(review): the remaining arguments of both getNotableRegions() calls
// are cut off in this copy.
StringRef DeclStr(Buffer.begin() + FuncBegin, FuncEnd - FuncBegin);
auto NotableFuncRegions = getNotableRegions(DeclStr, FuncNameOffset,
StringRef CallStr(Buffer.begin() + ReplaceBegin, ReplaceEnd - ReplaceBegin);
auto NotableCallRegions = getNotableRegions(CallStr, CallNameOffset,
// Insert the new function's declaration.
EditConsumer.accept(SM, InsertLoc, DeclStr, NotableFuncRegions);
// Replace the code to extract with the function call.
EditConsumer.accept(SM, RangeInfo.ContentRange, CallStr, NotableCallRegions);
return false;
/// Shared implementation of "Extract Expression" (ExtractRepeated == false)
/// and "Extract Repeated Expression" (ExtractRepeated == true).
///
/// performChange() introduces a `let` constant initialized with the selected
/// expression. It then replaces the selected expression, or, when
/// ExtractRepeated is set, every similar occurrence found by
/// SimilarExprCollector, with the constant's name.
///
/// NOTE(review): the access specifier, the SM member initializer and the
/// constructor body are not visible in this copy.
class RefactoringActionExtractExprBase {
SourceFile *TheFile;
ResolvedRangeInfo RangeInfo;
DiagnosticEngine &DiagEngine;
// When true, similar expressions are replaced too, not only the selection.
const bool ExtractRepeated;
StringRef PreferredName;
SourceEditConsumer &EditConsumer;
ASTContext &Ctx;
SourceManager &SM;
RefactoringActionExtractExprBase(SourceFile *TheFile,
ResolvedRangeInfo RangeInfo,
DiagnosticEngine &DiagEngine,
bool ExtractRepeated,
StringRef PreferredName,
SourceEditConsumer &EditConsumer) :
TheFile(TheFile), RangeInfo(RangeInfo), DiagEngine(DiagEngine),
ExtractRepeated(ExtractRepeated), PreferredName(PreferredName),
EditConsumer(EditConsumer), Ctx(TheFile->getASTContext()),
/// Performs the edit. Returns true on failure, false on success.
bool performChange();
/// This is to ensure all decl references in two expressions are identical.
///
/// Walks the given expression on construction. References is presumably
/// filled by visitDeclReference. Two collectors compare equal when they hold
/// the same number of references and std::equal finds them equal pairwise.
///
/// NOTE(review): the statement appending D to References and the trailing
/// arguments of std::equal are not visible in this copy.
struct ReferenceCollector: public SourceEntityWalker {
llvm::SmallVector<ValueDecl*, 4> References;
ReferenceCollector(Expr *E) { walk(E); }
bool visitDeclReference(ValueDecl *D, CharSourceRange Range,
TypeDecl *CtorTyRef, ExtensionDecl *ExtTyRef,
Type T, ReferenceMetaData Data) override {
return true;
bool operator==(const ReferenceCollector &Other) const {
if (References.size() != Other.References.size())
return false;
return std::equal(References.begin(), References.end(),
/// Walks an AST looking for expressions similar to SelectedExpr. A similar
/// expression is non-implicit, has the same ExprKind, has tokens with the
/// same text, and references the same declarations (per ReferenceCollector).
/// Matches are presumably collected into Bucket.
///
/// NOTE(review): the following are not visible in this copy: the
/// initializers of SelectedTokens/SelectedReferences, the comparator's
/// closing lines and the Bucket insertion.
struct SimilarExprCollector: public SourceEntityWalker {
SourceManager &SM;
/// The expression under selection.
Expr *SelectedExpr;
/// Tokens from which expression slices are taken (see getExprSlice).
llvm::ArrayRef<Token> AllTokens;
/// Output set for the similar expressions.
llvm::SetVector<Expr*> &Bucket;
/// The tokens included in the expression under selection.
llvm::ArrayRef<Token> SelectedTokens;
/// The referenced decls in the expression under selection.
ReferenceCollector SelectedReferences;
/// True if both token sequences have the same length and the same token
/// texts in order.
bool compareTokenContent(ArrayRef<Token> Left, ArrayRef<Token> Right) {
if (Left.size() != Right.size())
return false;
return std::equal(Left.begin(), Left.end(), Right.begin(),
[](const Token &L, const Token& R) {
return L.getText() == R.getText();
/// Find all tokens included by an expression.
llvm::ArrayRef<Token> getExprSlice(Expr *E) {
return slice_token_array(AllTokens, E->getStartLoc(), E->getEndLoc());
SimilarExprCollector(SourceManager &SM, Expr* SelectedExpr,
llvm::ArrayRef<Token> AllTokens,
llvm::SetVector<Expr*> &Bucket): SM(SM), SelectedExpr(SelectedExpr),
AllTokens(AllTokens), Bucket(Bucket),
bool walkToExprPre(Expr *E) override {
// We don't extract implicit expressions.
if (E->isImplicit())
return true;
// Only expressions of the same kind can match.
if (E->getKind() != SelectedExpr->getKind())
return true;
// First check the underlying token arrays have the same content.
if (compareTokenContent(getExprSlice(E), SelectedTokens)) {
ReferenceCollector CurrentReferences(E);
// Next, check the referenced decls are same.
if (CurrentReferences == SelectedReferences)
return true;
/// Inserts `let <PreferredName>[: Type] = <selected expression>` before the
/// statement that contains the first extracted occurrence. Each extracted
/// occurrence is then replaced with the constant's name.
///
/// \returns true on failure (invalid name or no insertion point; a
///   diagnostic is emitted), false on success.
bool RefactoringActionExtractExprBase::performChange() {
// Check if the new name is ok.
if (!Lexer::isIdentifier(PreferredName)) {
DiagEngine.diagnose(SourceLoc(), diag::invalid_name, PreferredName);
return true;
// Find the enclosing brace statement;
ContextFinder Finder(*TheFile, RangeInfo.ContainedNodes.front(),
[](ASTNode N) { return N.isStmt(StmtKind::Brace); });
auto *SelectedExpr = RangeInfo.ContainedNodes[0].get<Expr*>();
SourceLoc InsertLoc;
// Value decls found in the enclosing brace statement. They are presumably
// what the new name is checked against (see correctNameInternal below).
llvm::SetVector<ValueDecl*> AllVisibleDecls;
struct DeclCollector: public SourceEntityWalker {
llvm::SetVector<ValueDecl*> &Bucket;
DeclCollector(llvm::SetVector<ValueDecl*> &Bucket): Bucket(Bucket) {}
bool walkToDeclPre(Decl *D, CharSourceRange Range) override {
if (auto *VD = dyn_cast<ValueDecl>(D))
return true;
} Collector(AllVisibleDecls);
// The expressions that will be replaced by the new constant's name.
llvm::SetVector<Expr*> AllExpressions;
if (!Finder.getContexts().empty()) {
// Get the innermost brace statement.
auto BS = static_cast<BraceStmt*>(Finder.getContexts().back().get<Stmt*>());
// Collect all value decls inside the brace statement.
if (ExtractRepeated) {
// Collect all expressions we are going to extract.
SimilarExprCollector(SM, SelectedExpr,
} else {
// NOTE(review): the statement adding SelectedExpr to AllExpressions is not
// visible in this copy.
assert(!AllExpressions.empty() && "at least one expression is extracted.");
for (auto Ele : BS->getElements()) {
// Find the element that encloses the first expression under extraction.
if (SM.rangeContains(Ele.getSourceRange(),
(*AllExpressions.begin())->getSourceRange())) {
// Insert before the enclosing element.
InsertLoc = Ele.getStartLoc();
// Complain about no inserting position.
if (InsertLoc.isInvalid()) {
DiagEngine.diagnose(SourceLoc(), diag::no_insert_position);
return true;
// Correct name if collision happens.
PreferredName = correctNameInternal(TheFile->getASTContext(), PreferredName,
// Print the type name of this expression.
// NOTE(review): the line printing Ty after ": " is not visible in this copy.
llvm::SmallString<16> TyBuffer;
// We are not sure about the type of repeated expressions.
if (!ExtractRepeated) {
if (auto Ty = RangeInfo.getType()) {
llvm::raw_svector_ostream OS(TyBuffer);
OS << ": ";
// Build `let <name><optional ": Type"> = <expr>\n` and record the name's
// offsets for the note region.
llvm::SmallString<64> DeclBuffer;
llvm::raw_svector_ostream OS(DeclBuffer);
unsigned StartOffset, EndOffset;
OS << tok::kw_let << " ";
StartOffset = DeclBuffer.size();
OS << PreferredName;
EndOffset = DeclBuffer.size();
OS << TyBuffer.str() << " = " << RangeInfo.ContentRange.str() << "\n";
// The name is on line 1 of DeclBuffer; columns are offset + 1.
NoteRegion DeclNameRegion{
/*StartLine=*/1, /*StartColumn=*/StartOffset + 1,
/*EndLine=*/1, /*EndColumn=*/EndOffset + 1,
// Perform code change.
EditConsumer.accept(SM, InsertLoc, DeclBuffer.str(), {DeclNameRegion});
// Replace all occurrences of the extracted expression.
// NOTE(review): the start of the replacement call (presumably
// EditConsumer.accept) is not visible in this copy.
for (auto *E : AllExpressions) {
Lexer::getCharSourceRangeFromSourceRange(SM, E->getSourceRange()),
/*StartLine=*/1, /*StartColumn=*/1, /*EndLine=*/1,
/*EndColumn=*/static_cast<unsigned int>(PreferredName.size() + 1),
return false;
/// Whether "Extract Expression" can be offered. Only a single expression
/// qualifies, and only if checkExtractConditions() accepts it. Every other
/// range kind is rejected.
bool RefactoringActionExtractExpr::
isApplicable(ResolvedRangeInfo Info, DiagnosticEngine &Diag) {
switch (Info.Kind) {
case RangeKind::SingleExpression:
// We disallow extract literal expression for two reasons:
// (1) since we print the type for extracted expression, the type of a
// literal may print as "int2048" where it is not typically users' choice;
// (2) Extracting one literal provides little value for users.
// (The literal check itself is presumably in checkExtractConditions().)
return checkExtractConditions(Info, Diag).success();
case RangeKind::PartOfExpression:
case RangeKind::SingleDecl:
case RangeKind::MultiTypeMemberDecl:
case RangeKind::SingleStatement:
case RangeKind::MultiStatement:
case RangeKind::Invalid:
return false;
llvm_unreachable("unhandled kind");
/// Extracts only the selected expression (ExtractRepeated == false).
bool RefactoringActionExtractExpr::performChange() {
// NOTE(review): the trailing constructor argument(s) and the
// performChange() call are cut off in this copy.
return RefactoringActionExtractExprBase(TheFile, RangeInfo,
DiagEngine, false, PreferredName,
/// Whether "Extract Repeated Expression" can be offered. Only a single
/// expression that passes checkExtractConditions() qualifies.
bool RefactoringActionExtractRepeatedExpr::
isApplicable(ResolvedRangeInfo Info, DiagnosticEngine &Diag) {
switch (Info.Kind) {
case RangeKind::SingleExpression:
// NOTE(review): the member access after checkExtractConditions(...) is cut
// off in this copy; presumably `.success()`.
return checkExtractConditions(Info, Diag).
case RangeKind::PartOfExpression:
case RangeKind::SingleDecl:
case RangeKind::MultiTypeMemberDecl:
case RangeKind::SingleStatement:
case RangeKind::MultiStatement:
case RangeKind::Invalid:
return false;
llvm_unreachable("unhandled kind");
/// Extracts the selected expression and replaces its similar occurrences too
/// (ExtractRepeated == true).
bool RefactoringActionExtractRepeatedExpr::performChange() {
// NOTE(review): the trailing constructor argument(s) and the
// performChange() call are cut off in this copy.
return RefactoringActionExtractExprBase(TheFile, RangeInfo,
DiagEngine, true, PreferredName,
/// Whether "Move to Extension" can be offered for the selected declarations.
///
/// Requires a single declaration, or a run of type-member declarations, that
/// meets all of the following:
///  - the common context is a type whose parent context is a file unit;
///  - no selected node is an accessor, deinitializer, enum case or enum
///    element;
///  - no declared storage with storage sits directly in that type, because
///    extensions cannot hold stored properties.
bool RefactoringActionMoveMembersToExtension::isApplicable(
ResolvedRangeInfo Info, DiagnosticEngine &Diag) {
switch (Info.Kind) {
case RangeKind::SingleDecl:
case RangeKind::MultiTypeMemberDecl: {
DeclContext *DC = Info.RangeContext;
// If the common decl context is not a nominal type, we cannot create an
// extension for it.
// NOTE(review): the last operand of this condition is cut off in this copy.
if (!DC || !DC->getInnermostDeclarationDeclContext() ||
return false;
// Members of types not declared at top file level cannot be extracted
// to an extension at top file level
if (DC->getParent()->getContextKind() != DeclContextKind::FileUnit)
return false;
// Check if contained nodes are all allowed decls.
for (auto Node : Info.ContainedNodes) {
Decl *D = Node.dyn_cast<Decl*>();
if (!D)
return false;
if (isa<AccessorDecl>(D) || isa<DestructorDecl>(D) ||
isa<EnumCaseDecl>(D) || isa<EnumElementDecl>(D))
return false;
// We should not move instance variables with storage into the extension
// because they are not allowed to be declared there
for (auto DD : Info.DeclaredDecls) {
if (auto ASD = dyn_cast<AbstractStorageDecl>(DD.VD)) {
// Only disallow storages in the common decl context, allow them in
// any subtypes
if (ASD->hasStorage() && ASD->getDeclContext() == DC) {
return false;
return true;
case RangeKind::SingleExpression:
case RangeKind::PartOfExpression:
case RangeKind::SingleStatement:
case RangeKind::MultiStatement:
case RangeKind::Invalid:
return false;
llvm_unreachable("unhandled kind");
/// Moves the selected members into a new `extension <TypeName> { ... }`,
/// inserted right after the enclosing type declaration, and removes them from
/// their original position.
///
/// \returns false; this action reports no failure.
bool RefactoringActionMoveMembersToExtension::performChange() {
DeclContext *DC = RangeInfo.RangeContext;
// NOTE(review): the initializer of CommonTypeDecl is not visible in this
// copy. The assert relies on isApplicable() having checked that the context
// is a nominal type.
auto CommonTypeDecl =
assert(CommonTypeDecl && "Not applicable if common parent is no nomial type");
SmallString<64> Buffer;
llvm::raw_svector_ostream OS(Buffer);
OS << "\n\n";
OS << "extension " << CommonTypeDecl->getName() << " {\n";
OS << RangeInfo.ContentRange.str().trim();
OS << "\n}";
// Insert extension after the type declaration
EditConsumer.insertAfter(SM, CommonTypeDecl->getEndLoc(), Buffer);
EditConsumer.remove(SM, RangeInfo.ContentRange);
return false;
namespace {
// A SingleDecl range may not include all decls actually declared in that range:
// a var decl has accessors that aren't included. This will find those missing
// decls.
//
// NOTE(review): the access specifier, the body of the accessor lambda and
// the closing of this namespace are not visible in this copy.
class FindAllSubDecls : public SourceEntityWalker {
llvm::SmallPtrSetImpl<Decl *> &Found;
FindAllSubDecls(llvm::SmallPtrSetImpl<Decl *> &found)
: Found(found) {}
bool walkToDeclPre(Decl *D, CharSourceRange range) override {
// Record this Decl, and skip its contents if we've already touched it.
if (!Found.insert(D).second)
return false;
// For storage declarations, also visit the accessors written in source.
if (auto ASD = dyn_cast<AbstractStorageDecl>(D)) {
ASD->visitParsedAccessors([&](AccessorDecl *accessor) {
return true;
/// Whether "Replace Bodies With fatalError()" can be offered. The
/// declarations gathered into Found from the selection must include at least
/// one non-implicit AbstractFunctionDecl.
///
/// NOTE(review): the body of the first loop (presumably walking each declared
/// decl with FindAllSubDecls into Found) is not visible in this copy.
bool RefactoringActionReplaceBodiesWithFatalError::isApplicable(
ResolvedRangeInfo Info, DiagnosticEngine &Diag) {
switch (Info.Kind) {
case RangeKind::SingleDecl:
case RangeKind::MultiTypeMemberDecl: {
llvm::SmallPtrSet<Decl *, 16> Found;
for (auto decl : Info.DeclaredDecls) {
for (auto decl : Found) {
auto AFD = dyn_cast<AbstractFunctionDecl>(decl);
if (AFD && !AFD->isImplicit())
return true;
return false;
case RangeKind::SingleExpression:
case RangeKind::PartOfExpression:
case RangeKind::SingleStatement:
case RangeKind::MultiStatement:
case RangeKind::Invalid:
return false;
llvm_unreachable("unhandled kind");
/// Replaces the body of every non-implicit function among the declarations
/// gathered from the selection with a body that only calls fatalError().
///
/// NOTE(review): this copy does not show two things: how Found is filled,
/// and the statement skipping non-function or implicit decls (presumably
/// `continue`).
bool RefactoringActionReplaceBodiesWithFatalError::performChange() {
const StringRef replacement = "{\nfatalError()\n}";
llvm::SmallPtrSet<Decl *, 16> Found;
for (auto decl : RangeInfo.DeclaredDecls) {
for (auto decl : Found) {
auto AFD = dyn_cast<AbstractFunctionDecl>(decl);
if (!AFD || AFD->isImplicit())
auto range = AFD->getBodySourceRange();
// If we're in replacement mode (i.e. have an edit consumer), we can
// rewrite the function body.
auto charRange = Lexer::getCharSourceRangeFromSourceRange(SM, range);
EditConsumer.accept(SM, charRange, replacement);
return false;
/// Matches `if A { if B { ... } }` at the cursor.
///
/// \returns {outer, inner} when the cursor is at the start of an `if`
///   statement that has no `else` and whose body is exactly one `if`
///   statement, also without `else`. Returns {nullptr, nullptr} otherwise.
static std::pair<IfStmt *, IfStmt *>
findCollapseNestedIfTarget(ResolvedCursorInfo CursorInfo) {
if (CursorInfo.Kind != CursorInfoKind::StmtStart)
return {};
// Ensure the statement is 'if' statement. It must not have 'else' clause.
IfStmt *OuterIf = dyn_cast<IfStmt>(CursorInfo.TrailingStmt);
if (!OuterIf)
return {};
if (OuterIf->getElseStmt())
return {};
// The body must contain a sole inner 'if' statement.
auto Body = dyn_cast_or_null<BraceStmt>(OuterIf->getThenStmt());
if (!Body || Body->getNumElements() != 1)
return {};
IfStmt *InnerIf =
dyn_cast_or_null<IfStmt>(Body->getFirstElement().dyn_cast<Stmt *>());
if (!InnerIf)
return {};
// Inner 'if' statement also cannot have 'else' clause.
if (InnerIf->getElseStmt())
return {};
return {OuterIf, InnerIf};
/// Applicable exactly when findCollapseNestedIfTarget() finds an outer `if`
/// (its non-null first element converts to true).
bool RefactoringActionCollapseNestedIfStmt::
isApplicable(ResolvedCursorInfo CursorInfo, DiagnosticEngine &Diag) {
return findCollapseNestedIfTarget(CursorInfo).first;
/// Rewrites `if A { if B { body } }` into `if A, B { body }`. The whole outer
/// statement is replaced: the outer conditions come first, then the inner
/// ones, joined by ", ", followed by the inner statement's body.
///
/// \returns true if there is no collapsible target at the cursor, false
///   otherwise.
bool RefactoringActionCollapseNestedIfStmt::performChange() {
auto Target = findCollapseNestedIfTarget(CursorInfo);
if (!Target.first)
return true;
auto OuterIf = Target.first;
auto InnerIf = Target.second;
EditorConsumerInsertStream OS(
EditConsumer, SM,
Lexer::getCharSourceRangeFromSourceRange(SM, OuterIf->getSourceRange()));
OS << tok::kw_if << " ";
// Emit conditions.
// NOTE(review): the `else` before `OS << ", "` and the trailing `.str();` of
// both range outputs are not visible in this copy.
bool first = true;
for (auto &C : llvm::concat<StmtConditionElement>(OuterIf->getCond(),
InnerIf->getCond())) {
if (first)
first = false;
OS << ", ";
OS << Lexer::getCharSourceRangeFromSourceRange(SM, C.getSourceRange())
// Emit body.
OS << " ";
OS << Lexer::getCharSourceRangeFromSourceRange(
SM, InnerIf->getThenStmt()->getSourceRange())
return false;
/// Finds the operands of a `+` string concatenation so they can be turned
/// into a string interpolation.
///
/// Examines the single selected expression, or the common parent expression
/// of a partial selection. The walk accepts only two kinds of node:
///  - binary operators that resolve to the String or RangeReplaceableCollection
///    `+` and produce a String;
///  - String-typed operands, which are presumably collected into Bucket.
/// Any other non-implicit, typed expression invalidates the conversion.
///
/// \returns the collected operands, or nullptr when the range kind is
///   unsupported, the expression is not a valid concatenation, or fewer than
///   two operands were found.
///
/// NOTE(review): this copy does not show the `break`/`default` lines of the
/// switch, the Walker.walk(E) call or the Bucket insertion.
static std::unique_ptr<llvm::SetVector<Expr*>>
findConcatenatedExpressions(ResolvedRangeInfo Info, ASTContext &Ctx) {
Expr *E = nullptr;
switch (Info.Kind) {
case RangeKind::SingleExpression:
E = Info.ContainedNodes[0].get<Expr*>();
case RangeKind::PartOfExpression:
E = Info.CommonExprParent;
return nullptr;
struct StringInterpolationExprFinder: public SourceEntityWalker {
std::unique_ptr<llvm::SetVector<Expr *>> Bucket =
std::make_unique<llvm::SetVector<Expr *>>();
ASTContext &Ctx;
bool IsValidInterpolation = true;
StringInterpolationExprFinder(ASTContext &Ctx): Ctx(Ctx) {}
// True if Expr refers to the String `+` or the RangeReplaceableCollection `+`.
bool isConcatenationExpr(DeclRefExpr* Expr) {
if (!Expr)
return false;
auto *FD = dyn_cast<FuncDecl>(Expr->getDecl());
if (FD == nullptr || (FD != Ctx.getPlusFunctionOnString() &&
FD != Ctx.getPlusFunctionOnRangeReplaceableCollection())) {
return false;
return true;
bool walkToExprPre(Expr *E) override {
if (E->isImplicit())
return true;
// FIXME: we should have ErrorType instead of null.
if (E->getType().isNull())
return true;
auto ExprType = E->getType()->getNominalOrBoundGenericNominal();
//Only binary concatenation operators should exist in expression
if (E->getKind() == ExprKind::Binary) {
auto *BE = dyn_cast<BinaryExpr>(E);
auto *OperatorDeclRef = BE->getSemanticFn()->getMemberOperatorRef();
if (!(isConcatenationExpr(OperatorDeclRef)
&& ExprType == Ctx.getStringDecl())) {
IsValidInterpolation = false;
return false;
return true;
// Everything that evaluates to string should be gathered.
if (ExprType == Ctx.getStringDecl()) {
return false;
if (auto *DR = dyn_cast<DeclRefExpr>(E)) {
// Checks whether all function references in expression are concatenations.
auto *FD = dyn_cast<FuncDecl>(DR->getDecl());
auto IsConcatenation = isConcatenationExpr(DR);
if (FD && IsConcatenation) {
return false;
// There was non-expected expression, it's not valid interpolation then.
IsValidInterpolation = false;
return false;
} Walker(Ctx);
// There should be two or more expressions to convert.
if (!Walker.IsValidInterpolation || Walker.Bucket->size() < 2)
return nullptr;
return std::move(Walker.Bucket);
/// Writes one concatenation operand to \p OS as it should appear inside the
/// resulting interpolated string literal:
///  - a plain string literal contributes its value;
///  - an interpolated string literal contributes its source text with the
///    leading quote erased;
///  - any other expression becomes `\(<source text>)`.
///
/// NOTE(review): this copy does not show two pieces: the rest of the
/// ExpStr initializer, and the erase of the closing quote (presumably
/// alongside erase(0, 1)). Also confirm that getValue() yields text safe to
/// splice into a literal without re-escaping quotes/backslashes.
static void interpolatedExpressionForm(Expr *E, SourceManager &SM,
llvm::raw_ostream &OS) {
if (auto *Literal = dyn_cast<StringLiteralExpr>(E)) {
OS << Literal->getValue();
auto ExpStr = Lexer::getCharSourceRangeFromSourceRange(SM,
if (isa<InterpolatedStringLiteralExpr>(E)) {
ExpStr.erase(0, 1);
OS << ExpStr;
OS << "\\(" << ExpStr << ")";
/// Applicable when the range has a context and findConcatenatedExpressions()
/// finds a valid concatenation of at least two String operands.
bool RefactoringActionConvertStringsConcatenationToInterpolation::
isApplicable(ResolvedRangeInfo Info, DiagnosticEngine &Diag) {
auto RangeContext = Info.RangeContext;
if (RangeContext) {
auto &Ctx = Info.RangeContext->getASTContext();
return findConcatenatedExpressions(Info, Ctx) != nullptr;
return false;
/// Replaces the selected concatenation with one string literal: an opening
/// quote, then each operand's interpolatedExpressionForm() in order, then a
/// closing quote.
///
/// \returns true if no convertible concatenation is found, false otherwise.
bool RefactoringActionConvertStringsConcatenationToInterpolation::performChange() {
auto Expressions = findConcatenatedExpressions(RangeInfo, Ctx);
if (!Expressions)
return true;
EditorConsumerInsertStream OS(EditConsumer, SM, RangeInfo.ContentRange);
OS << "\"";
for (auto It = Expressions->begin(); It != Expressions->end(); ++It) {
interpolatedExpressionForm(*It, SM, OS);
OS << "\"";
return false;
/// Abstract helper class containing info about an IfExpr
/// that can be expanded into an IfStmt.
///
/// Subclasses supply three things: the ternary (getIf), the source range of
/// the name being assigned or declared (getNameRange), and the type to
/// declare, or null when no declaration is needed (getType).
///
/// NOTE(review): the `public:` specifier is not visible in this copy.
class ExpandableTernaryExprInfo {
virtual ~ExpandableTernaryExprInfo() {}
virtual IfExpr *getIf() = 0;
virtual SourceRange getNameRange() = 0;
virtual Type getType() = 0;
/// By default, a `name: Type` declaration is emitted iff a type is known.
virtual bool shouldDeclareNameAndType() {
return !getType().isNull();
/// With the default shouldDeclareNameAndType(), the type check below can
/// only fail for subclasses that override it.
virtual bool isValid() {
//Ensure all public properties are non-nil and valid
if (!getIf() || !getNameRange().isValid())
return false;
if (shouldDeclareNameAndType() && getType().isNull())
return false;
return true; //valid
CharSourceRange getNameCharRange(const SourceManager &SM) {
return Lexer::getCharSourceRangeFromSourceRange(SM, getNameRange());
/// Concrete subclass containing info about an AssignExpr
/// where the source is the expandable IfExpr.
///
/// The name range is the assignment's destination. getType() returns null,
/// so no declaration is emitted for an assignment.
class ExpandableAssignTernaryExprInfo: public ExpandableTernaryExprInfo {
ExpandableAssignTernaryExprInfo(AssignExpr *Assign): Assign(Assign) {}
IfExpr *getIf() override {
if (!Assign)
return nullptr;
return dyn_cast_or_null<IfExpr>(Assign->getSrc());
SourceRange getNameRange() override {
auto Invalid = SourceRange();
if (!Assign)
return Invalid;
if (auto dest = Assign->getDest())
return dest->getSourceRange();
return Invalid;
Type getType() override {
return nullptr;
AssignExpr *Assign = nullptr;
/// Concrete subclass containing info about a PatternBindingDecl
/// where the pattern initializer is the expandable IfExpr.
///
/// Only bindings with exactly one pattern entry are handled. The name is the
/// bound pattern with any type annotation (TypedPattern) stripped, and the
/// declared type is that pattern's type.
class ExpandableBindingTernaryExprInfo: public ExpandableTernaryExprInfo {
ExpandableBindingTernaryExprInfo(PatternBindingDecl *Binding):
Binding(Binding) {}
IfExpr *getIf() override {
if (Binding && Binding->getNumPatternEntries() == 1) {
if (auto *Init = Binding->getInit(0)) {
return dyn_cast<IfExpr>(Init);
return nullptr;
SourceRange getNameRange() override {
if (auto Pattern = getNamePattern())
return Pattern->getSourceRange();
return SourceRange();
Type getType() override {
if (auto Pattern = getNamePattern())
return Pattern->getType();
return nullptr;
/// The single bound pattern, unwrapped from a TypedPattern if present, or
/// null when the binding does not have exactly one entry.
Pattern *getNamePattern() {
if (!Binding || Binding->getNumPatternEntries() != 1)
return nullptr;
auto Pattern = Binding->getPattern(0);
if (!Pattern)
return nullptr;
if (auto TyPattern = dyn_cast<TypedPattern>(Pattern))
Pattern = TyPattern->getSubPattern();
return Pattern;
PatternBindingDecl *Binding = nullptr;
/// Wraps the single selected node as an expandable ternary if it is a
/// PatternBindingDecl or an AssignExpr; returns nullptr otherwise. Callers
/// still have to check the result with isValid().
///
/// NOTE(review): the return-type line of this definition is not visible in
/// this copy.
findExpandableTernaryExpression(ResolvedRangeInfo Info) {
if (Info.Kind != RangeKind::SingleDecl
&& Info.Kind != RangeKind:: SingleExpression)
return nullptr;
if (Info.ContainedNodes.size() != 1)
return nullptr;
if (auto D = Info.ContainedNodes[0].dyn_cast<Decl*>())
if (auto Binding = dyn_cast<PatternBindingDecl>(D))
return std::make_unique<ExpandableBindingTernaryExprInfo>(Binding);
if (auto E = Info.ContainedNodes[0].dyn_cast<Expr*>())
if (auto Assign = dyn_cast<AssignExpr>(E))
return std::make_unique<ExpandableAssignTernaryExprInfo>(Assign);
return nullptr;
/// Applicable when the selection is a binding or an assignment whose value
/// is a ternary, and the wrapper reports itself valid.
bool RefactoringActionExpandTernaryExpr::
isApplicable(ResolvedRangeInfo Info, DiagnosticEngine &Diag) {
auto Target = findExpandableTernaryExpression(Info);
return Target && Target->isValid();
/// Expands `x = c ? a : b` (or a binding such as `let x = c ? a : b`) into
///   x: T              -- only when a type must be declared
///   if c {
///   x = a
///   } else {
///   x = b
///   }
/// The replacement starts at the name, so a `let`/`var` specifier is kept.
///
/// \returns true if no valid target is found, false otherwise.
bool RefactoringActionExpandTernaryExpr::performChange() {
auto Target = findExpandableTernaryExpression(RangeInfo);
if (!Target || !Target->isValid())
return true; //abort
auto NameCharRange = Target->getNameCharRange(SM);
auto IfRange = Target->getIf()->getSourceRange();
auto IfCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, IfRange);
auto CondRange = Target->getIf()->getCondExpr()->getSourceRange();
auto CondCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, CondRange);
auto ThenRange = Target->getIf()->getThenExpr()->getSourceRange();
auto ThenCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, ThenRange);
auto ElseRange = Target->getIf()->getElseExpr()->getSourceRange();
auto ElseCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, ElseRange);
llvm::SmallString<64> DeclBuffer;
llvm::raw_svector_ostream OS(DeclBuffer);
llvm::StringRef Space = " ";
llvm::StringRef NewLine = "\n";
if (Target->shouldDeclareNameAndType()) {
//Specifier will not be replaced; append after specifier
OS << NameCharRange.str() << tok::colon << Space;
OS << Target->getType() << NewLine;
OS << tok::kw_if << Space;
OS << CondCharRange.str() << Space;
OS << tok::l_brace << NewLine;
OS << NameCharRange.str() << Space;
OS << tok::equal << Space;
OS << ThenCharRange.str() << NewLine;
OS << tok::r_brace << Space;
OS << tok::kw_else << Space;
OS << tok::l_brace << NewLine;
OS << NameCharRange.str() << Space;
OS << tok::equal << Space;
OS << ElseCharRange.str() << NewLine;
OS << tok::r_brace;
//Start replacement with name range, skip the specifier
// NOTE(review): IfCharRange is otherwise unused in the visible code. The
// replace range is presumably widened to cover it on a line not shown here.
auto ReplaceRange(NameCharRange);
EditConsumer.accept(SM, ReplaceRange, DeclBuffer.str());
return false; //don't abort
/// Whether "Convert to guard" can be offered. The selection must be exactly
/// one `if` statement that has a single pattern-binding condition (such as
/// `if let x = ...`) and a brace-statement body.
bool RefactoringActionConvertIfLetExprToGuardExpr::
isApplicable(ResolvedRangeInfo Info, DiagnosticEngine &Diag) {
if (Info.Kind != RangeKind::SingleStatement
&& Info.Kind != RangeKind::MultiStatement)
return false;
if (Info.ContainedNodes.empty())
return false;
IfStmt *If = nullptr;
if (Info.ContainedNodes.size() == 1) {
if (auto S = Info.ContainedNodes[0].dyn_cast<Stmt*>()) {
If = dyn_cast<IfStmt>(S);
if (!If)
return false;
auto CondList = If->getCond();
if (CondList.size() == 1) {
auto E = CondList[0];
auto P = E.getKind();
if (P == swift::StmtConditionElement::CK_PatternBinding) {
auto Body = dyn_cast_or_null<BraceStmt>(If->getThenStmt());
if (Body)
return true;
return false;
/// Rewrites the selected `if <binding> { then } [else { other }]` into
///   guard <binding> else {
///   <else-body statements, if any>
///   return
///   }
///   <then-body statements>
/// and replaces the selected range with the result.
///
/// Relies on isApplicable() having accepted RangeInfo: S, If and Body are not
/// null-checked here.
///
/// NOTE(review): the emitted `return` has no value, so the result is only
/// valid where a bare `return` is. Also, firstElement's range is the only
/// visible input for the body text; lastElement is unused in this copy. The
/// body ranges are presumably widened to the last element on lines not shown
/// here; confirm.
bool RefactoringActionConvertIfLetExprToGuardExpr::performChange() {
auto S = RangeInfo.ContainedNodes[0].dyn_cast<Stmt*>();
IfStmt *If = dyn_cast<IfStmt>(S);
auto CondList = If->getCond();
// Get if-let condition
SourceRange range = CondList[0].getSourceRange();
SourceManager &SM = RangeInfo.RangeContext->getASTContext().SourceMgr;
auto CondCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, range);
auto Body = dyn_cast_or_null<BraceStmt>(If->getThenStmt());
// Get if-let then body.
auto firstElement = Body->getFirstElement();
auto lastElement = Body->getLastElement();
SourceRange bodyRange = firstElement.getSourceRange();
auto BodyCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, bodyRange);
llvm::SmallString<64> DeclBuffer;
llvm::raw_svector_ostream OS(DeclBuffer);
llvm::StringRef Space = " ";
llvm::StringRef NewLine = "\n";
OS << tok::kw_guard << Space;
OS << CondCharRange.str().str() << Space;
OS << tok::kw_else << Space;
OS << tok::l_brace << NewLine;
// Get if-let else body.
if (auto *ElseBody = dyn_cast_or_null<BraceStmt>(If->getElseStmt())) {
auto firstElseElement = ElseBody->getFirstElement();
auto lastElseElement = ElseBody->getLastElement();
SourceRange elseBodyRange = firstElseElement.getSourceRange();
auto ElseBodyCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, elseBodyRange);
OS << ElseBodyCharRange.str().str() << NewLine;
OS << tok::kw_return << NewLine;
OS << tok::r_brace << NewLine;
OS << BodyCharRange.str().str();
// Replace if-let to guard
auto ReplaceRange = RangeInfo.ContentRange;
EditConsumer.accept(SM, ReplaceRange, DeclBuffer.str());
return false;
/// Whether "Convert to if-let" can be offered. The first selected node must
/// be a `guard` statement with exactly one pattern-binding condition. Any
/// further selected nodes become the body of the resulting `if`.
bool RefactoringActionConvertGuardExprToIfLetExpr::
isApplicable(ResolvedRangeInfo Info, DiagnosticEngine &Diag) {
if (Info.Kind != RangeKind::SingleStatement
&& Info.Kind != RangeKind::MultiStatement)
return false;
if (Info.ContainedNodes.empty())
return false;
GuardStmt *guardStmt = nullptr;
// Always true here: emptiness was rejected just above.
if (Info.ContainedNodes.size() > 0) {
if (auto S = Info.ContainedNodes[0].dyn_cast<Stmt*>()) {
guardStmt = dyn_cast<GuardStmt>(S);
if (!guardStmt)
return false;
auto CondList = guardStmt->getCond();
if (CondList.size() == 1) {
auto E = CondList[0];
auto P = E.getPatternOrNull();
if (P && E.getKind() == swift::StmtConditionElement::CK_PatternBinding)
return true;
return false;
/// Rewrites `guard <binding> else { <guard body> }` plus the selected code
/// after it into
///   if <binding> {
///   <selected code after the guard>
///   } else {     -- only when the guard body has more than one element
///   <guard body statements>
///   }
/// and replaces the selected range with the result.
///
/// Relies on isApplicable() having accepted RangeInfo: S and Guard are not
/// null-checked here.
///
/// NOTE(review): only ContainedNodes[1] and the guard body's first element
/// are visibly used for the moved text. lastElement is unused in this copy.
/// The ranges are presumably widened on lines not shown here; confirm.
bool RefactoringActionConvertGuardExprToIfLetExpr::performChange() {
// Get guard stmt
auto S = RangeInfo.ContainedNodes[0].dyn_cast<Stmt*>();
GuardStmt *Guard = dyn_cast<GuardStmt>(S);
// Get guard condition
auto CondList = Guard->getCond();
// Get guard condition source
SourceRange range = CondList[0].getSourceRange();
SourceManager &SM = RangeInfo.RangeContext->getASTContext().SourceMgr;
auto CondCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, range);
llvm::SmallString<64> DeclBuffer;
llvm::raw_svector_ostream OS(DeclBuffer);
llvm::StringRef Space = " ";
llvm::StringRef NewLine = "\n";
OS << tok::kw_if << Space;
OS << CondCharRange.str().str() << Space;
OS << tok::l_brace << NewLine;
// Get nodes after guard to place them at if-let body
if (RangeInfo.ContainedNodes.size() > 1) {
auto S = RangeInfo.ContainedNodes[1].getSourceRange();
auto BodyCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, S);
OS << BodyCharRange.str().str() << NewLine;
OS << tok::r_brace;
// Get guard body
auto Body = dyn_cast_or_null<BraceStmt>(Guard->getBody());
if (Body && Body->getNumElements() > 1) {
auto firstElement = Body->getFirstElement();
auto lastElement = Body->getLastElement();
SourceRange bodyRange = firstElement.getSourceRange();
auto BodyCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, bodyRange);
OS << Space << tok::kw_else << Space << tok::l_brace << NewLine;
OS << BodyCharRange.str().str() << NewLine;
OS << tok::r_brace;
// Replace guard to if-let
auto ReplaceRange = RangeInfo.ContentRange;
EditConsumer.accept(SM, ReplaceRange, DeclBuffer.str());
return false;
// Offers "convert to switch" for a single selected if / else-if chain whose
// conditions can be expressed as cases of one switch (see ConditionalChecker
// and SwitchConvertable below).
bool RefactoringActionConvertToSwitchStmt::
isApplicable(ResolvedRangeInfo Info, DiagnosticEngine &Diag) {
// ASTWalker over condition expressions. It tracks two flags:
// - ParamsUseSameVars: whether the latest var/param reference names the same
//   variable as the first one recorded in ExpectName;
// - ConditionUseOnlyAllowedFunctions: whether the latest function reference
//   is one of the operators accepted by checkName(FuncDecl *).
// NOTE(review): no access specifier is visible, yet SwitchConvertable reads
// these members directly — confirm a 'public:' is not missing.
class ConditionalChecker : public ASTWalker {
bool ParamsUseSameVars = true;
bool ConditionUseOnlyAllowedFunctions = false;
// Name of the first variable referenced; later references must match it.
StringRef ExpectName;
Expr *walkToExprPost(Expr *E) override {
if (E->getKind() != ExprKind::DeclRef)
return E;
auto D = dyn_cast<DeclRefExpr>(E)->getDecl();
if (D->getKind() == DeclKind::Var || D->getKind() == DeclKind::Param)
ParamsUseSameVars = checkName(dyn_cast<VarDecl>(D));
if (D->getKind() == DeclKind::Func)
ConditionUseOnlyAllowedFunctions = checkName(dyn_cast<FuncDecl>(D));
// Returning nullptr once a check fails is presumably meant to stop the walk
// (ASTWalker convention).
if (allCheckPassed())
return E;
return nullptr;
bool allCheckPassed() {
return ParamsUseSameVars && ConditionUseOnlyAllowedFunctions;
// Records the first variable name seen and reports whether VD matches it.
bool checkName(VarDecl *VD) {
auto Name = VD->getName().str();
if (ExpectName.empty())
ExpectName = Name;
return Name == ExpectName;
// Allowed functions: pattern match '~=', equality '==' (including the
// derived enum/struct equality functions), '||' and '...'.
bool checkName(FuncDecl *FD) {
const auto Name = FD->getBaseIdentifier().str();
return Name == "~="
|| Name == "=="
|| Name == "__derived_enum_equals"
|| Name == "__derived_struct_equals"
|| Name == "||"
|| Name == "...";
// Decides whether the selected range is an if / else-if chain that can be
// rewritten as a switch.
class SwitchConvertable {
SwitchConvertable(ResolvedRangeInfo Info) {
this->Info = Info;
bool isApplicable() {
if (Info.Kind != RangeKind::SingleStatement)
return false;
if (!findIfStmt())
return false;
return checkEachCondition();
ResolvedRangeInfo Info;
IfStmt *If = nullptr;
ConditionalChecker checker;
// Requires exactly one selected node, and that node must be an IfStmt.
bool findIfStmt() {
if (Info.ContainedNodes.size() != 1)
return false;
if (auto S = Info.ContainedNodes.front().dyn_cast<Stmt*>())
If = dyn_cast<IfStmt>(S);
return If != nullptr;
// Checks the conditions of If and of every 'else if' after it, with one
// checker shared across the chain (so all branches must test the same
// variable). The loop ends at the first else branch that is not an IfStmt.
// If is advanced in place and no longer points at the head afterwards.
bool checkEachCondition() {
checker = ConditionalChecker();
do {
if (!checkEachElement())
return false;
} while ((If = dyn_cast_or_null<IfStmt>(If->getElseStmt())));
return true;
// Checks every element of the current If's condition list. It does not stop
// early; all elements are checked.
bool checkEachElement() {
bool result = true;
auto ConditionalList = If->getCond();
for (auto Element : ConditionalList) {
result &= check(Element);
return result;
// Availability conditions are never convertible. Pattern bindings mark the
// function check as satisfied.
// NOTE(review): no visible statement walks ConditionElement with `checker`,
// so the result reflects only previously recorded state — confirm a walk
// call is not missing before the return.
bool check(StmtConditionElement ConditionElement) {
if (ConditionElement.getKind() == StmtConditionElement::CK_Availability)
return false;
if (ConditionElement.getKind() == StmtConditionElement::CK_PatternBinding)
checker.ConditionUseOnlyAllowedFunctions = true;
return checker.allCheckPassed();
return SwitchConvertable(Info).isApplicable();
// Rewrites the selected if / else-if chain as a switch over the variable used
// in its conditions: one case per branch, plus a default case built from the
// final else body (or 'break' when there is none).
bool RefactoringActionConvertToSwitchStmt::performChange() {
// ASTWalker that stores the name of the first var/param reference it visits
// in VarName. It returns nullptr after recording, presumably to end the walk.
class VarNameFinder : public ASTWalker {
std::string VarName;
Expr *walkToExprPost(Expr *E) override {
if (E->getKind() != ExprKind::DeclRef)
return E;
auto D = dyn_cast<DeclRefExpr>(E)->getDecl();
if (D->getKind() != DeclKind::Var && D->getKind() != DeclKind::Param)
return E;
VarName = dyn_cast<VarDecl>(D)->getName().str().str();
return nullptr;
// ASTWalker that accumulates in ConditionalPattern the source text used as
// the case pattern for one condition element.
class ConditionalPatternFinder : public ASTWalker {
ConditionalPatternFinder(SourceManager &SM) : SM(SM) {}
SmallString<64> ConditionalPattern = SmallString<64>();
Expr *walkToExprPost(Expr *E) override {
if (E->getKind() != ExprKind::Binary)
return E;
auto BE = dyn_cast<BinaryExpr>(E);
// NOTE(review): appendPattern() is never called in the visible code; the
// statements that should follow this check appear to be missing.
if (isFunctionNameAllowed(BE))
return E;
// Pattern conditions (e.g. 'case .x = v') contribute their own source text.
// Returning a null pattern for OptionalSome presumably stops the walk there.
// NOTE(review): no return is visible for the other pattern kinds.
std::pair<bool, Pattern*> walkToPatternPre(Pattern *P) override {
ConditionalPattern.append(Lexer::getCharSourceRangeFromSourceRange(SM, P->getSourceRange()).str());
if (P->getKind() == PatternKind::OptionalSome)
return { true, nullptr };
SourceManager &SM;
// Reports whether the binary operator is an equality / pattern-match
// function. The casts are unchecked: assumes the operator is a
// DotSyntaxCallExpr over a DeclRefExpr naming a FuncDecl.
// NOTE(review): the statement computing FunctionName is truncated here.
bool isFunctionNameAllowed(BinaryExpr *E) {
auto FunctionBody = dyn_cast<DotSyntaxCallExpr>(E->getFn())->getFn();
auto FunctionDeclaration = dyn_cast<DeclRefExpr>(FunctionBody)->getDecl();
const auto FunctionName = dyn_cast<FuncDecl>(FunctionDeclaration)
return FunctionName == "~="
|| FunctionName == "=="
|| FunctionName == "__derived_enum_equals"
|| FunctionName == "__derived_struct_equals";
// Appends the source text of the operand that is not a plain variable
// reference: the last tuple element, or the first one if the last is a
// DeclRef. Multiple patterns are separated by ", ".
void appendPattern(TupleExpr *Tuple) {
auto PatternArgument = Tuple->getElements().back();
if (PatternArgument->getKind() == ExprKind::DeclRef)
PatternArgument = Tuple->getElements().front();
if (ConditionalPattern.size() > 0)
ConditionalPattern.append(", ");
ConditionalPattern.append(Lexer::getCharSourceRangeFromSourceRange(SM, PatternArgument->getSourceRange()).str());
// Collects the pieces of the if / else-if chain and prints the switch.
class ConverterToSwitch {
ConverterToSwitch(ResolvedRangeInfo Info, SourceManager &SM) : SM(SM) {
this->Info = Info;
// NOTE(review): as shown, neither findPatternsAndBodies() nor
// makeSwitchStatement() is called and Out is never written. Also,
// findDefaultStatements() reads PreviousIf, which only
// findPatternsAndBodies() assigns.
void performConvert(SmallString<64> &Out) {
If = findIf();
OptionalLabel = If->getLabelInfo().Name.str().str();
ControlExpression = findControlExpression();
DefaultStatements = findDefaultStatements();
ResolvedRangeInfo Info;
SourceManager &SM;
// Current branch of the chain while it is being traversed.
IfStmt *If;
// Last IfStmt of the chain; its else branch becomes the default case.
IfStmt *PreviousIf;
std::string OptionalLabel;
std::string ControlExpression;
SmallVector<std::pair<std::string, std::string>, 16> PatternsAndBodies;
std::string DefaultStatements;
IfStmt *findIf() {
auto S = Info.ContainedNodes[0].dyn_cast<Stmt*>();
return dyn_cast<IfStmt>(S);
// Name of the variable tested by the first condition element.
// NOTE(review): Finder is never run over ConditionElement in the visible
// code, so VarName is returned empty as shown.
std::string findControlExpression() {
auto ConditionElement = If->getCond().front();
auto Finder = VarNameFinder();
return Finder.VarName;
// Appends (pattern, body) for If and each following 'else if'. It leaves
// PreviousIf at the last IfStmt of the chain.
void findPatternsAndBodies(SmallVectorImpl<std::pair<std::string, std::string>> &Out) {
do {
auto pattern = findPattern();
auto body = findBodyStatements();
Out.push_back(std::make_pair(pattern, body));
PreviousIf = If;
} while ((If = dyn_cast_or_null<IfStmt>(If->getElseStmt())));
// NOTE(review): Finder is never run over ConditionElement in the visible
// code, so the pattern is returned empty as shown.
std::string findPattern() {
auto ConditionElement = If->getCond().front();
auto Finder = ConditionalPatternFinder(SM);
return Finder.ConditionalPattern.str().str();
std::string findBodyStatements() {
return findBodyWithoutBraces(If->getThenStmt());
// Body of the chain's final plain else, or 'break' when there is none.
std::string findDefaultStatements() {
auto ElseBody = dyn_cast_or_null<BraceStmt>(PreviousIf->getElseStmt());
if (!ElseBody)
return getTokenText(tok::kw_break).str();
return findBodyWithoutBraces(ElseBody);
// Source text of a statement body without its braces. An empty brace body
// yields 'break'.
// NOTE(review): only the first element's range is used, so later statements
// of a multi-statement body would be dropped — confirm a widen to the last
// element is not missing.
std::string findBodyWithoutBraces(Stmt *body) {
auto BS = dyn_cast<BraceStmt>(body);
if (!BS)
return Lexer::getCharSourceRangeFromSourceRange(SM, body->getSourceRange()).str().str();
if (BS->getElements().empty())
return getTokenText(tok::kw_break).str();
SourceRange BodyRange = BS->getElements().front().getSourceRange();
return Lexer::getCharSourceRangeFromSourceRange(SM, BodyRange).str().str();
// Prints '[label: ]switch <var> {', one 'case <pattern>:' per branch, the
// default case, and the closing brace.
void makeSwitchStatement(SmallString<64> &Out) {
StringRef Space = " ";
StringRef NewLine = "\n";
llvm::raw_svector_ostream OS(Out);
if (OptionalLabel.size() > 0)
OS << OptionalLabel << ":" << Space;
OS << tok::kw_switch << Space << ControlExpression << Space << tok::l_brace << NewLine;
for (auto &pair : PatternsAndBodies) {
OS << tok::kw_case << Space << pair.first << tok::colon << NewLine;
OS << pair.second << NewLine;
OS << tok::kw_default << tok::colon << NewLine;
OS << DefaultStatements << NewLine;
OS << tok::r_brace;
// Replace the whole selected content with the generated switch.
SmallString<64> result;
ConverterToSwitch(RangeInfo, SM).performConvert(result);
EditConsumer.accept(SM, RangeInfo.ContentRange, result.str());
return false;
/// Struct containing info about an IfStmt that can be converted into an IfExpr,
/// i.e. rewritten as 'dest = cond ? thenValue : elseValue'.
struct ConvertToTernaryExprInfo {
ConvertToTernaryExprInfo() {}
/// Returns the common assignment destination of the then/else branches, or
/// nullptr if either assignment is missing or the destinations differ in kind.
/// Only DeclRef destinations (compared by declaration name) and tuple
/// destinations (compared via getElementNames(), presumably the element labels,
/// not the referenced variables) are supported.
Expr *AssignDest() {
if (!Then || !Then->getDest() || !Else || !Else->getDest())
return nullptr;
auto ThenDest = Then->getDest();
auto ElseDest = Else->getDest();
if (ThenDest->getKind() != ElseDest->getKind())
return nullptr;
switch (ThenDest->getKind()) {
case ExprKind::DeclRef: {
auto ThenRef = dyn_cast<DeclRefExpr>(Then->getDest());
auto ElseRef = dyn_cast<DeclRefExpr>(Else->getDest());
if (!ThenRef || !ThenRef->getDecl() || !ElseRef || !ElseRef->getDecl())
return nullptr;
const auto ThenName = ThenRef->getDecl()->getName();
const auto ElseName = ElseRef->getDecl()->getName();
// NOTE(review): the comparison of ThenName and ElseName on the next line is
// truncated — its left-hand operand is missing.
if ( != 0)
return nullptr;
return Then->getDest();
case ExprKind::Tuple: {
auto ThenTuple = dyn_cast<TupleExpr>(Then->getDest());
auto ElseTuple = dyn_cast<TupleExpr>(Else->getDest());
if (!ThenTuple || !ElseTuple)
return nullptr;
auto ThenNames = ThenTuple->getElementNames();
auto ElseNames = ElseTuple->getElementNames();
if (!ThenNames.equals(ElseNames))
return nullptr;
return ThenTuple;
return nullptr;
/// Right-hand side of the then-branch assignment, or nullptr.
Expr *ThenSrc() {
if (!Then)
return nullptr;
return Then->getSrc();
/// Right-hand side of the else-branch assignment, or nullptr.
Expr *ElseSrc() {
if (!Else)
return nullptr;
return Else->getSrc();
/// True when every required piece was found and the branches share a
/// destination.
bool isValid() {
if (!Cond || !AssignDest() || !ThenSrc() || !ElseSrc()
|| !IfRange.isValid())
return false;
return true;
PatternBindingDecl *Binding = nullptr; //optional
Expr *Cond = nullptr; //required
AssignExpr *Then = nullptr; //required
AssignExpr *Else = nullptr; //required
/// Source range of the whole if statement.
SourceRange IfRange;
// Matches the shape supported by "convert to ternary":
//   [PatternBindingDecl]    (optional first node, e.g. 'let x: Int')
//   if <boolean cond> { <assignment> } else { <assignment> }
// and returns the extracted pieces. It returns a default-constructed (invalid)
// info when the selection does not fit.
// NOTE(review): the return-type line (presumably ConvertToTernaryExprInfo) is
// not visible above this name.
findConvertToTernaryExpression(ResolvedRangeInfo Info) {
auto notFound = ConvertToTernaryExprInfo();
if (Info.Kind != RangeKind::SingleStatement
&& Info.Kind != RangeKind::MultiStatement)
return notFound;
if (Info.ContainedNodes.empty())
return notFound;
// Finds the assignment expression of one branch. Assign holds the value set
// by the most recent walkToExprPre call (nullptr if that expression was not
// an AssignExpr). walkToExprPre always returns false, presumably so children
// are not visited.
struct AssignExprFinder: public SourceEntityWalker {
AssignExpr *Assign = nullptr;
// NOTE(review): the statement guarded by `if (S)` (presumably a walk over S)
// is not visible here.
AssignExprFinder(Stmt* S) {
if (S)
virtual bool walkToExprPre(Expr *E) override {
Assign = dyn_cast<AssignExpr>(E);
return false;
ConvertToTernaryExprInfo Target;
IfStmt *If = nullptr;
// One node: just the if statement.
if (Info.ContainedNodes.size() == 1) {
if (auto S = Info.ContainedNodes[0].dyn_cast<Stmt*>())
If = dyn_cast<IfStmt>(S);
// Two nodes: a declaration followed by the if statement.
if (Info.ContainedNodes.size() == 2) {
if (auto D = Info.ContainedNodes[0].dyn_cast<Decl*>())
Target.Binding = dyn_cast<PatternBindingDecl>(D);
if (auto S = Info.ContainedNodes[1].dyn_cast<Stmt*>())
If = dyn_cast<IfStmt>(S);
if (!If)
return notFound;
auto CondList = If->getCond();
if (CondList.size() != 1)
return notFound;
// Non-boolean conditions (e.g. pattern bindings) yield nullptr here, which
// makes the result invalid.
Target.Cond = CondList[0].getBooleanOrNull();
Target.IfRange = If->getSourceRange();
Target.Then = AssignExprFinder(If->getThenStmt()).Assign;
Target.Else = AssignExprFinder(If->getElseStmt()).Assign;
return Target;
/// "Convert to ternary" is offered exactly when the selection matches the
/// shape recognized by findConvertToTernaryExpression().
bool RefactoringActionConvertToTernaryExpr::
isApplicable(ResolvedRangeInfo Info, DiagnosticEngine &Diag) {
  auto Candidate = findConvertToTernaryExpression(Info);
  return Candidate.isValid();
}
/// Replaces a matched if/else assignment with a single ternary assignment:
///   if c { x = a } else { x = b }   -->   x = c ? a : b
/// When the selection starts with a declaration, that declaration's text
/// becomes the left-hand side. The declaration is then replaced together with
/// the if statement; otherwise it would appear twice in the result.
/// Returns true (abort) if the selection no longer matches, false on success.
bool RefactoringActionConvertToTernaryExpr::performChange() {
  auto Target = findConvertToTernaryExpression(RangeInfo);

  if (!Target.isValid())
    return true; //abort

  llvm::SmallString<64> DeclBuffer;
  llvm::raw_svector_ostream OS(DeclBuffer);

  llvm::StringRef Space = " ";

  auto IfRange = Target.IfRange;
  auto ReplaceRange = Lexer::getCharSourceRangeFromSourceRange(SM, IfRange);

  auto CondRange = Target.Cond->getSourceRange();
  auto CondCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, CondRange);

  auto ThenRange = Target.ThenSrc()->getSourceRange();
  auto ThenCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, ThenRange);

  auto ElseRange = Target.ElseSrc()->getSourceRange();
  auto ElseCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, ElseRange);

  CharSourceRange DestCharRange;

  if (Target.Binding) {
    auto DestRange = Target.Binding->getSourceRange();
    DestCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, DestRange);
    // The declaration is re-emitted below as the left-hand side, so the edit
    // must also consume the original declaration.
    ReplaceRange.widen(DestCharRange);
  } else {
    auto DestRange = Target.AssignDest()->getSourceRange();
    DestCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, DestRange);
  }

  OS << DestCharRange.str() << Space << tok::equal << Space;
  OS << CondCharRange.str() << Space << tok::question_postfix << Space;
  OS << ThenCharRange.str() << Space << tok::colon << Space;
  OS << ElseCharRange.str();

  EditConsumer.accept(SM, ReplaceRange, DeclBuffer.str());

  return false; //don't abort
}
/// The helper class analyzes a given nominal decl or an extension decl to
/// decide whether stubs are required to be filled in and the context in which
/// these stubs should be filled.
class FillProtocolStubContext {
// Collects the requirements of IDC's local conformances that have no witness.
// NOTE(review): the leading part of this declaration (presumably
// 'static std::vector<ValueDecl*>') is not visible here.
getUnsatisfiedRequirements(const IterableDeclContext *IDC);
/// Context in which the content should be filled; this could be either a
/// nominal type declaration or an extension declaration.
DeclContext *DC;
/// The type that adopts the required protocol stubs. For nominal type decl, this
/// should be the declared type itself; for extension decl, this should be the
/// extended type at hand.
Type Adopter;
/// The start location of the decl, either nominal type or extension, for the
/// printer to figure out the right indentation.
SourceLoc StartLoc;
/// The location of '{' for the decl, thus we know where to insert the filling
/// stubs.
// NOTE(review): none of the constructors visible below initializes this
// member, so canProceed() would see it default-constructed — confirm an
// initializer is not missing.
SourceLoc BraceStartLoc;
/// The value decls that should be satisfied; this could be either function
/// decls, property decls, or required type alias.
std::vector<ValueDecl*> FillingContents;
FillProtocolStubContext(ExtensionDecl *ED) : DC(ED),
Adopter(ED->getExtendedType()), StartLoc(ED->getStartLoc()),
FillingContents(getUnsatisfiedRequirements(ED)) {};
FillProtocolStubContext(NominalTypeDecl *ND) : DC(ND),
Adopter(ND->getDeclaredType()), StartLoc(ND->getStartLoc()),
FillingContents(getUnsatisfiedRequirements(ND)) {};
// Empty context; its locations stay default-constructed (presumably invalid),
// so canProceed() is expected to return false.
FillProtocolStubContext() : DC(nullptr), Adopter(), FillingContents({}) {};
static FillProtocolStubContext getContextFromCursorInfo(ResolvedCursorInfo Tok);
ArrayRef<ValueDecl*> getFillingContents() const {
return llvm::makeArrayRef(FillingContents);
DeclContext *getFillingContext() const { return DC; }
// True when both locations are valid.
// NOTE(review): the final conjunct of this expression is truncated here.
bool canProceed() const {
return StartLoc.isValid() && BraceStartLoc.isValid() &&
Type getAdopter() const { return Adopter; }
SourceLoc getContextStartLoc() const { return StartLoc; }
SourceLoc getBraceStartLoc() const { return BraceStartLoc; }
// Builds the stub-filling context for the declaration under the cursor: a
// nominal type when the cursor is on its declared name, or an extension when
// the cursor is on the extended type reference. Otherwise it returns an empty
// context.
FillProtocolStubContext FillProtocolStubContext::
getContextFromCursorInfo(ResolvedCursorInfo CursorInfo) {
// NOTE(review): as shown this return is unconditional, which makes the rest
// of the function unreachable; a guarding condition (presumably a validity
// check on CursorInfo) appears to be missing.
return FillProtocolStubContext();
if (!CursorInfo.IsRef) {
// If the type name is on the declared nominal, e.g. "class A {}"
if (auto ND = dyn_cast<NominalTypeDecl>(CursorInfo.ValueD)) {
return FillProtocolStubContext(ND);
} else if (auto *ED = CursorInfo.ExtTyRef) {
// If the type ref is on a declared extension, e.g. "extension A {}"
return FillProtocolStubContext(ED);
return FillProtocolStubContext();
// Returns every requirement, across IDC's local conformances, that has no
// witness yet.
std::vector<ValueDecl*> FillProtocolStubContext::
getUnsatisfiedRequirements(const IterableDeclContext *IDC) {
// The results to return.
std::vector<ValueDecl*> NonWitnessedReqs;
// For each conformance of the extended nominal.
for(ProtocolConformance *Con : IDC->getLocalConformances()) {
// Collect non-witnessed requirements.
// NOTE(review): the call that receives the lambda below is not visible
// (presumably an iteration over Con's non-witnessed requirements).
[&](ValueDecl *VD) { NonWitnessedReqs.push_back(VD); });
return NonWitnessedReqs;
/// "Add missing protocol stubs" is offered when the cursor yields a
/// stub-filling context that reports canProceed().
bool RefactoringActionFillProtocolStub::
isApplicable(ResolvedCursorInfo Tok, DiagnosticEngine &Diag) {
  auto Context = FillProtocolStubContext::getContextFromCursorInfo(Tok);
  return Context.canProceed();
}
// Prints a stub for every unsatisfied requirement (via printRequirementStub)
// into one buffer and inserts it after the '{' of the nominal type or
// extension.
bool RefactoringActionFillProtocolStub::performChange() {
// Get the filling protocol context from the input token.
// NOTE(review): the rest of this initialization is truncated here
// (presumably getContextFromCursorInfo(CursorInfo)).
FillProtocolStubContext Context = FillProtocolStubContext::
llvm::SmallString<128> Text;
llvm::raw_svector_ostream SS(Text);
Type Adopter = Context.getAdopter();
SourceLoc Loc = Context.getContextStartLoc();
auto Contents = Context.getFillingContents();
// For each unsatisfied requirement, print the stub to the buffer.
std::for_each(Contents.begin(), Contents.end(), [&](ValueDecl *VD) {
printRequirementStub(VD, Context.getFillingContext(), Adopter, Loc, SS);
// Insert all stubs after '{' in the extension/nominal type decl.
EditConsumer.insertAfter(SM, Context.getBraceStartLoc(), Text);
return false;
// Returns the refactoring kinds available at (Line, Column) in SF. The
// location is resolved to the token start, turned into cursor info, and
// handed to collectAvailableRefactorings; rename is not excluded. Scratch is
// forwarded to collectAvailableRefactorings (presumably as backing storage
// for the result). Returns an empty result for an invalid location.
// Assumes SF has a buffer ID (getValue() is called on it unconditionally).
// NOTE(review): the return-type line is not visible above this name, and the
// evaluateOrDefault call below is missing its final argument(s).
collectAvailableRefactoringsAtCursor(SourceFile *SF, unsigned Line,
unsigned Column,
std::vector<RefactoringKind> &Scratch,
llvm::ArrayRef<DiagnosticConsumer*> DiagConsumers) {
// Prepare the tool box.
ASTContext &Ctx = SF->getASTContext();
SourceManager &SM = Ctx.SourceMgr;
DiagnosticEngine DiagEngine(SM);
std::for_each(DiagConsumers.begin(), DiagConsumers.end(),
[&](DiagnosticConsumer *Con) { DiagEngine.addConsumer(*Con); });
SourceLoc Loc = SM.getLocForLineCol(SF->getBufferID().getValue(), Line, Column);
if (Loc.isInvalid())
return {};
ResolvedCursorInfo Tok = evaluateOrDefault(SF->getASTContext().evaluator,
CursorInfoRequest{CursorInfoOwner(SF, Lexer::getLocForStartOfToken(SM, Loc))},
return collectAvailableRefactorings(SF, Tok, Scratch, /*Exclude rename*/false);
/// Returns the enum declaration named by the type of a switch's subject, or
/// nullptr when the subject has no type or its type is not an enum.
static EnumDecl* getEnumDeclFromSwitchStmt(SwitchStmt *SwitchS) {
  // FIXME: Support more complex subject like '(Enum1, Enum2)'.
  Type SubjectTy = SwitchS->getSubjectExpr()->getType();
  if (!SubjectTy)
    return nullptr;
  return dyn_cast_or_null<EnumDecl>(SubjectTy->getAnyNominal());
}
// Writes a 'case' for each enum element the switch does not handle yet to OS,
// via printEnumElementsAsCases. When nothing remains, it emits
// no_remaining_cases at ExpandedStmtLoc and returns true (failure). Returns
// false on success.
// NOTE(review): as shown, UnhandledElements is never populated (the lines
// filling it from EnumDecl and erasing handled elements appear to be missing),
// and EEP is unused, so the "no remaining cases" path would always be taken.
static bool performCasesExpansionInSwitchStmt(SwitchStmt *SwitchS,
DiagnosticEngine &DiagEngine,
SourceLoc ExpandedStmtLoc,
EditorConsumerInsertStream &OS
) {
// Assume enum elements are not handled in the switch statement.
auto EnumDecl = getEnumDeclFromSwitchStmt(SwitchS);
llvm::DenseSet<EnumElementDecl*> UnhandledElements;
for (auto Current : SwitchS->getCases()) {
// A 'default' case handles no specific element (presumably skipped).
if (Current->isDefault()) {
// For each handled enum element, remove it from the bucket.
for (auto Item : Current->getCaseLabelItems()) {
if (auto *EEP = dyn_cast_or_null<EnumElementPattern>(Item.getPattern())) {
// If all enum elements are handled in the switch statement, issue error.
if (UnhandledElements.empty()) {
DiagEngine.diagnose(ExpandedStmtLoc, diag::no_remaining_cases);
return true;
printEnumElementsAsCases(UnhandledElements, OS);
return false;
// Finds SwitchStmt that contains given CaseStmt.
// Returns nullptr, after emitting no_parent_switch at CS, when no switch
// context is found or the switch found does not list CS among its cases.
static SwitchStmt* findEnclosingSwitchStmt(CaseStmt *CS,
SourceFile *SF,
DiagnosticEngine &DiagEngine) {
// Context predicate: only switch statements qualify.
// NOTE(review): the start of the returned expression (presumably 'Node.is')
// is missing on the line below.
auto IsSwitch = [](ASTNode Node) {
return<Stmt*>() &&
Node.get<Stmt*>()->getKind() == StmtKind::Switch;
// NOTE(review): no explicit resolve/walk call on Finder is visible before its
// contexts are read — confirm the ContextFinder is populated.
ContextFinder Finder(*SF, CS, IsSwitch);
// If failed to find the switch statement, issue error.
if (Finder.getContexts().empty()) {
DiagEngine.diagnose(CS->getStartLoc(), diag::no_parent_switch);
return nullptr;
// Uses the last recorded context (presumably the innermost switch).
// NOTE(review): the rest of this expression is truncated here.
auto *SwitchS = static_cast<SwitchStmt*>(Finder.getContexts().back().
// Make sure that CaseStmt is included in switch that was found.
// ('Default' is the position of CS in the case list, not a default case.)
auto Cases = SwitchS->getCases();
auto Default = std::find(Cases.begin(), Cases.end(), CS);
if (Default == Cases.end()) {
DiagEngine.diagnose(CS->getStartLoc(), diag::no_parent_switch);
return nullptr;
return SwitchS;
// "Expand default" is offered when the cursor is at the start of a 'default'
// case inside a switch over an enum. Exit() emits invalid_default_location for
// the failure paths that go through it.
bool RefactoringActionExpandDefault::
isApplicable(ResolvedCursorInfo CursorInfo, DiagnosticEngine &Diag) {
auto Exit = [&](bool Applicable) {
if (!Applicable)
Diag.diagnose(SourceLoc(), diag::invalid_default_location);
return Applicable;
if (CursorInfo.Kind != CursorInfoKind::StmtStart)
return Exit(false);
if (auto *CS = dyn_cast<CaseStmt>(CursorInfo.TrailingStmt)) {
// NOTE(review): the remaining arguments of this call are truncated here.
auto EnclosingSwitchStmt = findEnclosingSwitchStmt(CS,
// findEnclosingSwitchStmt diagnoses its own failures, so Exit() is bypassed.
if (!EnclosingSwitchStmt)
return false;
auto EnumD = getEnumDeclFromSwitchStmt(EnclosingSwitchStmt);
// This result also bypasses Exit(): a non-default case or a non-enum subject
// returns false without a diagnostic.
auto IsApplicable = CS->isDefault() && EnumD != nullptr;
return IsApplicable;
return Exit(false);
// Expands the 'default' case under the cursor into explicit cases for the
// enum elements not yet handled, via performCasesExpansionInSwitchStmt.
// Assumes isApplicable() succeeded: TrailingStmt is static_cast to CaseStmt.
bool RefactoringActionExpandDefault::performChange() {
// If we've not seen the default statement inside the switch statement, issue
// error.
auto *CS = static_cast<CaseStmt*>(CursorInfo.TrailingStmt);
// NOTE(review): SwitchS is not null-checked in the visible code.
auto *SwitchS = findEnclosingSwitchStmt(CS, TheFile, DiagEngine);
// NOTE(review): the range argument of this stream and the remaining arguments
// of the call below are truncated here.
EditorConsumerInsertStream OS(EditConsumer, SM,
return performCasesExpansionInSwitchStmt(SwitchS,
/// "Expand switch cases" is offered when the cursor is on a switch statement
/// whose subject has enum type.
bool RefactoringActionExpandSwitchCases::
isApplicable(ResolvedCursorInfo CursorInfo, DiagnosticEngine &DiagEngine) {
  auto *Switch = dyn_cast_or_null<SwitchStmt>(CursorInfo.TrailingStmt);
  return Switch && getEnumDeclFromSwitchStmt(Switch) != nullptr;
}
// Adds cases for every unhandled enum element of the switch under the cursor.
// If the switch has a 'default', the insertion stream targets that case's
// label-items range; otherwise it targets an empty range at the closing '}'.
// A newline is written first when '{' and '}' are on the same line.
bool RefactoringActionExpandSwitchCases::performChange() {
// Assumes isApplicable() succeeded; SwitchS is not null-checked here.
auto *SwitchS = dyn_cast<SwitchStmt>(CursorInfo.TrailingStmt);
auto InsertRange = CharSourceRange();
auto Cases = SwitchS->getCases();
auto Default = std::find_if(Cases.begin(), Cases.end(), [](CaseStmt *Stmt) {
return Stmt->isDefault();
if (Default != Cases.end()) {
auto DefaultRange = (*Default)->getLabelItemsRange();
InsertRange = Lexer::getCharSourceRangeFromSourceRange(SM, DefaultRange);
} else {
auto RBraceLoc = SwitchS->getRBraceLoc();
InsertRange = CharSourceRange(SM, RBraceLoc, RBraceLoc);
EditorConsumerInsertStream OS(EditConsumer, SM, InsertRange);
if (SM.getLineAndColumnInBuffer(SwitchS->getLBraceLoc()).first ==
SM.getLineAndColumnInBuffer(SwitchS->getRBraceLoc()).first) {
OS << "\n";
// NOTE(review): the remaining arguments of this call are truncated here.
auto Result = performCasesExpansionInSwitchStmt(SwitchS,
return Result;
// Returns the string literal expression that starts exactly at the cursor's
// expression start, or nullptr. Interpolated string literals are rejected.
static Expr *findLocalizeTarget(ResolvedCursorInfo CursorInfo) {
if (CursorInfo.Kind != CursorInfoKind::ExprStart)
return nullptr;
// Walker that records the first StringLiteral starting at StartLoc.
// Returning false from walkToExprPre presumably skips that subtree.
struct StringLiteralFinder: public SourceEntityWalker {
SourceLoc StartLoc;
Expr *Target;
StringLiteralFinder(SourceLoc StartLoc): StartLoc(StartLoc), Target(nullptr) {}
bool walkToExprPre(Expr *E) override {
if (E->getStartLoc() != StartLoc)
return false;
if (E->getKind() == ExprKind::InterpolatedStringLiteral)
return false;
if (E->getKind() == ExprKind::StringLiteral) {
Target = E;
return false;
return true;
} Walker(CursorInfo.TrailingExpr->getStartLoc());
// NOTE(review): no visible statement runs Walker, so Target is returned as
// initialized (nullptr) — confirm a walk call is not missing here.
return Walker.Target;
/// "Localize string" is offered when findLocalizeTarget() finds a string
/// literal at the cursor.
bool RefactoringActionLocalizeString::
isApplicable(ResolvedCursorInfo Tok, DiagnosticEngine &Diag) {
  Expr *Literal = findLocalizeTarget(Tok);
  return Literal != nullptr;
}
bool RefactoringActionLocalizeString::performChange() {
Expr* Target = findLocalizeTarget(CursorInfo);
if (!Target)
return true;
EditConsumer.accept(SM, Target->getStartLoc(), "NSLocalizedString(");
EditConsumer.insertAfter(SM, Target->getEndLoc(), ", comment: \"\")");
return false;
struct MemberwiseParameter {
Identifier Name;
Type MemberType;
Expr *DefaultExpr;
MemberwiseParameter(Identifier name, Type type, Expr *initialExpr)
: Name(name), MemberType(type), DefaultExpr(initialExpr) {}
// Inserts at targetLocation an 'internal init(...)' whose parameters are the
// given members, in order, with their default values when available. Its body
// assigns each parameter to the property of the same name. The text is
// submitted as two edits at the same location: the opening "\ninternal init("
// and then the remainder. memberVector must be non-empty, because its last
// element is read via back().
static void generateMemberwiseInit(SourceEditConsumer &EditConsumer,
SourceManager &SM,
ArrayRef<MemberwiseParameter> memberVector,
SourceLoc targetLocation) {
EditConsumer.accept(SM, targetLocation, "\ninternal init(");
// Prints one 'name: [@escaping ]Type[ = default]' parameter, followed by ", "
// when wantsSeparator is set.
auto insertMember = [&SM](const MemberwiseParameter &memberData,
llvm::raw_ostream &OS, bool wantsSeparator) {
OS << memberData.Name << ": ";
// Unconditionally print '@escaping' if we print out a function type -
// the assignments we generate below will escape this parameter.
if (isa<AnyFunctionType>(memberData.MemberType->getCanonicalType())) {
OS << "@" << TypeAttributes::getAttrName(TAK_escaping) << " ";
OS << memberData.MemberType.getString();
// A nil literal default is printed as ' = nil'; any other default with a
// valid source range is copied verbatim from the source.
if (auto *expr = memberData.DefaultExpr) {
if (isa<NilLiteralExpr>(expr)) {
OS << " = nil";
} else if (expr->getSourceRange().isValid()) {
// NOTE(review): the start of this initializer's call expression is
// truncated on the next lines.
auto range =
SM, expr->getSourceRange());
OS << " = " << SM.extractText(range);
if (wantsSeparator) {
OS << ", ";
// Process the initial list of members, inserting commas as appropriate.
std::string Buffer;
llvm::raw_string_ostream OS(Buffer);
for (const auto &memberData : memberVector.drop_back()) {
insertMember(memberData, OS, /*wantsSeparator*/ true);
// Process the last (or perhaps, only) member.
insertMember(memberVector.back(), OS, /*wantsSeparator*/ false);
// Synthesize the body.
OS << ") {\n";
for (auto &member : memberVector) {
// self.<property> = <property>
OS << "self." << member.Name << " = " << member.Name << "\n";
OS << "}\n";
// Accept the entire edit.
EditConsumer.accept(SM, targetLocation, OS.str());
// Collects the stored properties of the nominal type declared under the
// cursor into memberVector. Returns the location just after its '{', where the
// initializer will be inserted. Returns an invalid SourceLoc when:
// - the cursor is not on a nominal type declaration (references are rejected);
// - the type has no stored properties;
// - no valid insertion location exists;
// - nothing was collected.
static SourceLoc
collectMembersForInit(ResolvedCursorInfo CursorInfo,
SmallVectorImpl<MemberwiseParameter> &memberVector) {
if (!CursorInfo.ValueD)
return SourceLoc();
NominalTypeDecl *nominalDecl = dyn_cast<NominalTypeDecl>(CursorInfo.ValueD);
if (!nominalDecl || nominalDecl->getStoredProperties().empty() ||
CursorInfo.IsRef) {
return SourceLoc();
SourceLoc bracesStart = nominalDecl->getBraces().Start;
if (!bracesStart.isValid())
return SourceLoc();
// Insert right after the '{' character.
SourceLoc targetLocation = bracesStart.getAdvancedLoc(1);
if (!targetLocation.isValid())
return SourceLoc();
// NOTE(review): as shown, the `if (!patternBinding)` and the
// isMemberwiseInitialized check have no visible bodies (presumably
// `continue;`), and the statement that appends to memberVector is truncated
// (only its trailing arguments remain).
for (auto varDecl : nominalDecl->getStoredProperties()) {
auto patternBinding = varDecl->getParentPatternBinding();
if (!patternBinding)
if (!varDecl->isMemberwiseInitialized(/*preferDeclaredProperties=*/true)) {
const auto i = patternBinding->getPatternEntryIndexForVarDecl(varDecl);
// Use the property's initializer as the parameter default when it is
// explicitly initialized or default-initializable.
Expr *defaultInit = nullptr;
if (patternBinding->isExplicitlyInitialized(i) ||
patternBinding->isDefaultInitializable()) {
defaultInit = varDecl->getParentInitializer();
varDecl->getType(), defaultInit);
if (memberVector.empty()) {
return SourceLoc();
return targetLocation;
/// "Generate memberwise initializer" is offered when collectMembersForInit()
/// finds members and a valid insertion point for the declaration at the cursor.
bool RefactoringActionMemberwiseInitLocalRefactoring::
isApplicable(ResolvedCursorInfo Tok, DiagnosticEngine &Diag) {
  // The collected members are discarded; only the insertion point matters.
  SmallVector<MemberwiseParameter, 8> Members;
  SourceLoc InsertLoc = collectMembersForInit(Tok, Members);
  return InsertLoc.isValid();
}
bool RefactoringActionMemberwiseInitLocalRefactoring::performChange() {
SmallVector<MemberwiseParameter, 8> memberVector;
SourceLoc targetLocation = collectMembersForInit(CursorInfo, memberVector);
if (targetLocation.isInvalid())
return true;
generateMemberwiseInit(EditConsumer, SM, memberVector, targetLocation);
return false;
// Gathers what is needed to add an explicit Equatable conformance, plus an
// implementation of its requirement that compares the user-accessible stored
// properties, to a nominal type or an extension.
// NOTE(review): no access specifiers are visible; the constructors, the
// static factory and the getters are presumably meant to be public.
class AddEquatableContext {
/// Declaration context
DeclContext *DC;
/// Adopter type
Type Adopter;
/// Start location of declaration context brace
SourceLoc StartLoc;
/// Array of all inherited protocols' locations
ArrayRef<TypeLoc> ProtocolsLocations;
/// Array of all conformed protocols
SmallVector<swift::ProtocolDecl *, 2> Protocols;
/// Start location of declaration,
/// a place to write protocol name
SourceLoc ProtInsertStartLoc;
/// Stored properties of extending adopter
ArrayRef<VarDecl *> StoredProperties;
/// Range of internal members in declaration
DeclRange Range;
// True if Equatable is already among the conformed protocols.
bool conformsToEquatableProtocol() {
for (ProtocolDecl *Protocol : Protocols) {
if (Protocol->getKnownProtocolKind() == KnownProtocolKind::Equatable) {
return true;
return false;
// The first Equatable requirement must be a function taking two parameters
// (presumably '==').
bool isRequirementValid() {
auto Reqs = getProtocolRequirements();
if (Reqs.empty()) {
return false;
auto Req = dyn_cast<FuncDecl>(Reqs[0]);
return Req && Req->getParameters()->size() == 2;
// At least one user-accessible stored property is needed to compare.
bool isPropertiesListValid() {
return !getUserAccessibleProperties().empty();
void printFunctionBody(ASTPrinter &Printer, StringRef ExtraIndent,
ParameterList *Params);
std::vector<ValueDecl *> getProtocolRequirements();
std::vector<VarDecl *> getUserAccessibleProperties();
// NOTE(review): ProtocolsLocations is not initialized by this constructor as
// shown.
AddEquatableContext(NominalTypeDecl *Decl) : DC(Decl),
Adopter(Decl->getDeclaredType()), StartLoc(Decl->getBraces().Start),
Protocols(Decl->getAllProtocols()), ProtInsertStartLoc(Decl->getNameLoc()),
StoredProperties(Decl->getStoredProperties()), Range(Decl->getMembers()) {};
// NOTE(review): ProtocolsLocations, Protocols and ProtInsertStartLoc are not
// initialized by this constructor as shown; an unset ProtInsertStartLoc would
// make isValid() fail — confirm initializers are not missing.
AddEquatableContext(ExtensionDecl *Decl) : DC(Decl),
Adopter(Decl->getExtendedType()), StartLoc(Decl->getBraces().Start),
StoredProperties(Decl->getExtendedNominal()->getStoredProperties()), Range(Decl->getMembers()) {};
// Empty (invalid) context.
AddEquatableContext() : DC(nullptr), Adopter(), ProtocolsLocations(),
Protocols(), StoredProperties(), Range(nullptr, nullptr) {};
static AddEquatableContext getDeclarationContextFromInfo(ResolvedCursorInfo Info);
std::string getInsertionTextForProtocol();
std::string getInsertionTextForFunction(SourceManager &SM);
// NOTE(review): the final conjunct of the expression below is truncated here.
bool isValid() {
// FIXME: Allow to generate explicit == method for declarations which already have
// compiler-generated == method
return StartLoc.isValid() && ProtInsertStartLoc.isValid() &&
!conformsToEquatableProtocol() && isPropertiesListValid() &&
// Where the protocol name is inserted: after the declaration name when there
// is no inheritance clause, else at the start of the last inherited entry.
SourceLoc getStartLocForProtocolDecl() {
if (ProtocolsLocations.empty()) {
return ProtInsertStartLoc;
return ProtocolsLocations.back().getSourceRange().Start;
bool isMembersRangeEmpty() {
return Range.empty();
SourceLoc getInsertStartLoc();
/// Returns the location after which the generated function is inserted: the
/// latest end location among the members, or the '{' location if no member
/// ends after it. Locations are ordered by their raw buffer positions.
SourceLoc AddEquatableContext::
getInsertStartLoc() {
  SourceLoc LastEnd = StartLoc;
  for (Decl *Member : Range) {
    SourceLoc MemberEnd = Member->getEndLoc();
    if (MemberEnd.getOpaquePointerValue() > LastEnd.getOpaquePointerValue())
      LastEnd = MemberEnd;
  }
  return LastEnd;
}
/// Returns the text that adds Equatable to the declaration's inheritance
/// clause: ": Equatable" when there is no inheritance clause yet, and
/// ", Equatable" otherwise.
std::string AddEquatableContext::
getInsertionTextForProtocol() {
  StringRef ProtocolName = getProtocolName(KnownProtocolKind::Equatable);
  // Build the string directly. Returning the backing string of a still-live,
  // possibly buffered raw_string_ostream without calling str()/flush() only
  // works if the stream's destructor happens to flush into the returned
  // object.
  std::string Text = ProtocolsLocations.empty() ? ": " : ", ";
  Text += ProtocolName.str();
  return Text;
}
// Returns the source text of the first Equatable requirement (Reqs[0]), printed
// with a body produced by printFunctionBody and indented for the insertion
// point. Assumes that requirement is a FuncDecl (the cast is unchecked).
std::string AddEquatableContext::
getInsertionTextForFunction(SourceManager &SM) {
auto Reqs = getProtocolRequirements();
auto Req = dyn_cast<FuncDecl>(Reqs[0]);
auto Params = Req->getParameters();
StringRef ExtraIndent;
StringRef CurrentIndent =
Lexer::getIndentationForLine(SM, getInsertStartLoc(), &ExtraIndent);
// With no members, the insertion point is the '{' line, so indent one more
// level; otherwise reuse the last member's indentation.
std::string Indent;
if (isMembersRangeEmpty()) {
Indent = (CurrentIndent + ExtraIndent).str();
} else {
Indent = CurrentIndent.str();
PrintOptions Options = PrintOptions::printVerbose();
Options.PrintDocumentationComments = false;
// NOTE(review): the end of this lambda is not visible.
Options.FunctionBody = [&](const ValueDecl *VD, ASTPrinter &Printer) {
Printer << " {";
printFunctionBody(Printer, ExtraIndent, Params);
Printer << "}";
std::string Buffer;
llvm::raw_string_ostream OS(Buffer);
ExtraIndentStreamPrinter Printer(OS, Indent);
// NOTE(review): the body of this `if` is not visible.
if (!isMembersRangeEmpty()) {
Reqs[0]->print(Printer, Options);
// NOTE(review): Buffer is returned without OS.str()/flush(); the text reaches
// the caller only if the stream's destructor flushes into the returned object.
return Buffer;
// Returns the stored properties that Decl::isUserAccessible() accepts, in
// declaration order.
// NOTE(review): the statement that appends to PublicProperties inside the
// `if` is not visible here.
std::vector<VarDecl *> AddEquatableContext::
getUserAccessibleProperties() {
std::vector<VarDecl *> PublicProperties;
for (VarDecl *Decl : StoredProperties) {
if (Decl->Decl::isUserAccessible()) {
return PublicProperties;
// Returns the valid protocol requirements declared as members of the
// Equatable protocol.
// NOTE(review): the body of the `if` (presumably `continue;`) and the
// statement that appends Req to Collection are not visible here.
std::vector<ValueDecl *> AddEquatableContext::
getProtocolRequirements() {
std::vector<ValueDecl *> Collection;
auto Proto = DC->getASTContext().getProtocol(KnownProtocolKind::Equatable);
for (auto Member : Proto->getMembers()) {
auto Req = dyn_cast<ValueDecl>(Member);
if (!Req || Req->isInvalid() || !Req->isProtocolRequirement()) {
return Collection;
// Builds the context for the declaration under the cursor:
// - a nominal type when the cursor is on its declared name;
// - an extension (with a resolvable extended nominal) when the cursor is on
//   the extended type reference.
// Otherwise it returns an empty context.
// NOTE(review): closing braces are not visible; the `else if` is presumably
// attached to `if (!Info.IsRef)`, mirroring
// FillProtocolStubContext::getContextFromCursorInfo.
AddEquatableContext AddEquatableContext::
getDeclarationContextFromInfo(ResolvedCursorInfo Info) {
if (Info.isInvalid()) {
return AddEquatableContext();
if (!Info.IsRef) {
if (auto *NomDecl = dyn_cast<NominalTypeDecl>(Info.ValueD)) {
return AddEquatableContext(NomDecl);
} else if (auto *ExtDecl = Info.ExtTyRef) {
if (ExtDecl->getExtendedNominal()) {
return AddEquatableContext(ExtDecl);
return AddEquatableContext();
// Prints 'return <p0>.<prop> == <p1>.<prop> && ...' over the user-accessible
// stored properties. <p0>/<p1> are the names of the first two parameters in
// Params. Assumes at least one property and two parameters, which isValid()
// and isRequirementValid() are meant to guarantee.
void AddEquatableContext::
printFunctionBody(ASTPrinter &Printer, StringRef ExtraIndent, ParameterList *Params) {
llvm::SmallString<128> Return;
llvm::raw_svector_ostream SS(Return);
SS << tok::kw_return;
StringRef Space = " ";
StringRef AdditionalSpace = " ";
StringRef Point = ".";
StringRef Join = " == ";
StringRef And = " &&";
auto Props = getUserAccessibleProperties();
auto FParam = Params->get(0)->getName();
auto SParam = Params->get(1)->getName();
auto Prop = Props[0]->getName();
Printer << ExtraIndent << Return << Space
<< FParam << Point << Prop << Join << SParam << Point << Prop;
// Remaining properties are each joined with '&&'.
// NOTE(review): the end of this lambda/call and any newline printing are not
// visible here.
if (Props.size() > 1) {
std::for_each(Props.begin() + 1, Props.end(), [&](VarDecl *VD){
auto Name = VD->getName();
Printer << And;
Printer << ExtraIndent << AdditionalSpace << FParam << Point
<< Name << Join << SParam << Point << Name;
/// "Add Equatable conformance" is offered when the declaration under the
/// cursor yields an AddEquatableContext that reports isValid().
bool RefactoringActionAddEquatableConformance::
isApplicable(ResolvedCursorInfo Tok, DiagnosticEngine &Diag) {
  auto Context = AddEquatableContext::getDeclarationContextFromInfo(Tok);
  return Context.isValid();
}
// Adds Equatable to the inheritance clause and inserts the generated
// implementation after the last member (or after '{').
bool RefactoringActionAddEquatableConformance::
performChange() {
auto Context = AddEquatableContext::getDeclarationContextFromInfo(CursorInfo);
// NOTE(review): the text arguments of both insertAfter calls below are
// truncated (presumably getInsertionTextForProtocol() and
// getInsertionTextForFunction(SM)).
EditConsumer.insertAfter(SM, Context.getStartLocForProtocolDecl(),
EditConsumer.insertAfter(SM, Context.getInsertStartLoc(),
return false;
// Returns the character range of the statement to wrap in do/catch. This is
// the element of the last recorded enclosing BraceStmt context that contains
// the cursor's expression, or that BraceStmt context node itself when no
// element does. Returns an invalid range when there is no expression or no
// BraceStmt context.
static CharSourceRange
findSourceRangeToWrapInCatch(ResolvedCursorInfo CursorInfo,
SourceFile *TheFile,
SourceManager &SM) {
Expr *E = CursorInfo.TrailingExpr;
if (!E)
return CharSourceRange();
auto Node = ASTNode(E);
auto NodeChecker = [](ASTNode N) { return N.isStmt(StmtKind::Brace); };
// NOTE(review): no explicit resolve/walk call on Finder is visible before its
// contexts are read — confirm the ContextFinder is populated.
ContextFinder Finder(*TheFile, Node, NodeChecker);
auto Contexts = Finder.getContexts();
if (Contexts.empty())
return CharSourceRange();
auto TargetNode = Contexts.back();
// The predicate only accepts Brace statements, so BStmt is expected non-null;
// it is dereferenced below without a check.
BraceStmt *BStmt = dyn_cast<BraceStmt>(TargetNode.dyn_cast<Stmt*>());
auto ConvertToCharRange = [&SM](SourceRange SR) {
return Lexer::getCharSourceRangeFromSourceRange(SM, SR);
auto ExprRange = ConvertToCharRange(E->getSourceRange());
// Check elements of the deepest BraceStmt, pick one that covers expression.
for (auto Elem: BStmt->getElements()) {
auto ElemRange = ConvertToCharRange(Elem.getSourceRange());
if (ElemRange.contains(ExprRange))
TargetNode = Elem;
return ConvertToCharRange(TargetNode.getSourceRange());
/// "Convert To Do/Catch" is offered only when the cursor's trailing
/// expression is a force-try (`try!`) expression.
bool RefactoringActionConvertToDoCatch::
isApplicable(ResolvedCursorInfo Tok, DiagnosticEngine &Diag) {
  if (!Tok.TrailingExpr)
    return false;
  return isa<ForceTryExpr>(Tok.TrailingExpr);
}
bool RefactoringActionConvertToDoCatch::performChange() {
auto *TryExpr = dyn_cast<ForceTryExpr>(CursorInfo.TrailingExpr);
auto Range = findSourceRangeToWrapInCatch(CursorInfo, TheFile, SM);
if (!Range.isValid())
return true;
// Wrap given range in do catch block.
EditConsumer.accept(SM, Range.getStart(), "do {\n");
EditorConsumerInsertStream OS(EditConsumer, SM, Range.getEnd());
OS << "\n} catch {\n" << getCodePlaceholder() << "\n}";
// Delete ! from try! expression
auto ExclaimLen = getKeywordLen(tok::exclaim_postfix);
auto ExclaimRange = CharSourceRange(TryExpr->getExclaimLoc(), ExclaimLen);
EditConsumer.remove(SM, ExclaimRange);
return false;
/// Given a cursor position, this function tries to collect a number literal
/// expression immediately following the cursor.
///
/// Returns nullptr unless the cursor is at the start of an expression
/// (CursorInfoKind::ExprStart). Otherwise returns the first NumberLiteralExpr
/// found inside Tok.TrailingExpr that starts at the same location as
/// Tok.TrailingExpr, or nullptr if there is none.
static NumberLiteralExpr *getTrailingNumberLiteral(ResolvedCursorInfo Tok) {
// This cursor must point to the start of an expression.
if (Tok.Kind != CursorInfoKind::ExprStart)
return nullptr;
// For every sub-expression, try to find the literal expression that matches
// our criteria.
// NOTE(review): `found` is read from outside this class below, but no
// `public:` label is present in this copy (class members default to
// private); the class's closing braces and `};` are also missing.
class FindLiteralNumber : public ASTWalker {
Expr * const parent;
NumberLiteralExpr *found = nullptr;
explicit FindLiteralNumber(Expr *parent) : parent(parent) { }
std::pair<bool, Expr *> walkToExprPre(Expr *expr) override {
if (auto *literal = dyn_cast<NumberLiteralExpr>(expr)) {
// The sub-expression must have the same start loc with the outermost
// expression, i.e. the cursor position.
if (!found &&
parent->getStartLoc().getOpaquePointerValue() ==
expr->getStartLoc().getOpaquePointerValue()) {
found = literal;
// Keep walking only while nothing has been found (presumably the bool
// tells ASTWalker whether to descend into children).
return { found == nullptr, expr };
auto parent = Tok.TrailingExpr;
FindLiteralNumber finder(parent);
// NOTE(review): the finder is never walked over `parent` in this copy, so as
// written `found` stays nullptr; a walk call appears to be missing here.
return finder.found;
/// Returns a copy of \p Text with '_' inserted before every character whose
/// index is a non-zero multiple of three, e.g. "1234567" -> "123_456_7".
/// Grouping therefore starts at the left; callers that need grouping from
/// the right reverse the text before and after the call.
static std::string insertUnderscore(StringRef Text) {
  std::string Result;
  Result.reserve(Text.size() + Text.size() / 3);
  for (size_t Index = 0, Size = Text.size(); Index != Size; ++Index) {
    // Start a new group after every third character.
    if (Index != 0 && Index % 3 == 0)
      Result.push_back('_');
    Result.push_back(Text[Index]);
  }
  return Result;
}
/// Writes \p Digits to \p OS with '_' separators grouping digits in threes.
/// The integer part is grouped leftwards from the decimal point and the
/// fractional part rightwards from it, e.g.
/// "1000000.12345" -> "1_000_000.123_45".
///
/// Separators already present in \p Digits are discarded before regrouping,
/// so an already-grouped literal such as "1_000" is reproduced unchanged
/// instead of gaining doubled separators.
///
/// Some spellings are written back verbatim: anything that is not plain
/// decimal digits with at most one '.'. This covers radix prefixes such as
/// "0x", "0b" and "0o", hexadecimal digits, and exponents. Grouping their
/// characters blindly can produce an invalid literal such as "0x_FFF_FFF".
static void insertUnderscoreInDigits(StringRef Digits,
                                     llvm::raw_ostream &OS) {
  // Collect the digits without existing separators, and give up on anything
  // that is not a plain decimal spelling.
  std::string Plain;
  Plain.reserve(Digits.size());
  unsigned NumPoints = 0;
  for (char C : Digits) {
    if (C == '_')
      continue;
    if (C == '.') {
      ++NumPoints;
    } else if (C < '0' || C > '9') {
      OS << Digits;
      return;
    }
    Plain.push_back(C);
  }
  if (NumPoints > 1 || Plain.empty()) {
    OS << Digits;
    return;
  }

  StringRef BeforePointRef, AfterPointRef;
  std::tie(BeforePointRef, AfterPointRef) = StringRef(Plain).split('.');

  // Insert '_' for the part before the decimal point, grouping from the
  // right: reverse, group from the left, reverse back.
  std::string BeforePoint(BeforePointRef);
  std::reverse(BeforePoint.begin(), BeforePoint.end());
  BeforePoint = insertUnderscore(BeforePoint);
  std::reverse(BeforePoint.begin(), BeforePoint.end());
  OS << BeforePoint;

  // Insert '_' for the part after the decimal point, if necessary.
  if (!AfterPointRef.empty()) {
    OS << '.';
    OS << insertUnderscore(AfterPointRef);
  }
}
/// The number-literal simplification is offered when the cursor starts a
/// number literal whose digit text would change after passing through
/// insertUnderscoreInDigits.
bool RefactoringActionSimplifyNumberLiteral::
isApplicable(ResolvedCursorInfo Tok, DiagnosticEngine &Diag) {
  auto *Literal = getTrailingNumberLiteral(Tok);
  if (!Literal)
    return false;

  llvm::SmallString<64> Buffer;
  llvm::raw_svector_ostream OS(Buffer);
  StringRef Digits = Literal->getDigitsText();
  insertUnderscoreInDigits(Digits, OS);
  // If inserting '_' results in a different digit sequence, this refactoring
  // is applicable.
  return OS.str() != Digits;
}
// Rewrites the digits of the number literal at the cursor in the form
// produced by insertUnderscoreInDigits. Returns false on success and true
// when there is no number literal at the cursor.
// NOTE(review): the CharSourceRange(SM, Literal->getDigitsLoc(), ...) argument
// below is incomplete in this copy: its end location, the closing parentheses
// and the `if` block's closing brace are missing.
bool RefactoringActionSimplifyNumberLiteral::performChange() {
if (auto *Literal = getTrailingNumberLiteral(CursorInfo)) {
EditorConsumerInsertStream OS(EditConsumer, SM,
CharSourceRange(SM, Literal->getDigitsLoc(),
StringRef Digits = Literal->getDigitsText();
insertUnderscoreInDigits(Digits, OS);
return false;
return true;
// Returns the innermost CallExpr around the cursor that can be rewritten to
// use a trailing closure, or nullptr if there is none.
//
// A call qualifies when it has no trailing closure yet and its last argument
// is a ClosureExpr or CaptureListExpr, after looking through one
// ImplicitConversionExpr. A cursor at a statement start never qualifies.
static CallExpr *findTrailingClosureTarget(SourceManager &SM,
ResolvedCursorInfo CursorInfo) {
if (CursorInfo.Kind == CursorInfoKind::StmtStart)
// StmtStart position can't be a part of CallExpr.
return nullptr;
// Find innermost CallExpr
// NOTE(review): the declaration below has lost its type in this copy
// (presumably ContextFinder). The predicate lambda's closing `});` is also
// missing, and no call that runs the finder precedes getContexts().
Finder(*CursorInfo.SF, CursorInfo.Loc,
[](ASTNode N) {
return N.isStmt(StmtKind::Brace) || N.isExpr(ExprKind::Call);
auto contexts = Finder.getContexts();
if (contexts.empty())
return nullptr;
// If the innermost context is a statement (which will be a BraceStmt per
// the filtering condition above), drop it.
if (contexts.back().is<Stmt *>()) {
contexts = contexts.drop_back();
// NOTE(review): the closing brace of the `if` above is missing in this copy;
// presumably the check below follows it rather than being nested in it.
if (contexts.empty() || !contexts.back().is<Expr*>())
return nullptr;
// Only BraceStmts and CallExprs pass the filter, so an Expr here is a call.
CallExpr *CE = cast<CallExpr>(contexts.back().get<Expr*>());
if (CE->hasTrailingClosure())
// Call expression already has a trailing closure.
return nullptr;
// The last argument is a closure?
Expr *Args = CE->getArg();
if (!Args)
return nullptr;
Expr *LastArg;
if (auto *PE = dyn_cast<ParenExpr>(Args)) {
LastArg = PE->getSubExpr();
} else {
auto *TE = cast<TupleExpr>(Args);
if (TE->getNumElements() == 0)
return nullptr;
LastArg = TE->getElements().back();
// NOTE(review): the closing brace of the else-branch above is missing here.
// Look through an implicit conversion to the argument as written.
if (auto *ICE = dyn_cast<ImplicitConversionExpr>(LastArg))
LastArg = ICE->getSyntacticSubExpr();
if (isa<ClosureExpr>(LastArg) || isa<CaptureListExpr>(LastArg))
return CE;
return nullptr;
/// Converting to a trailing closure is offered when findTrailingClosureTarget
/// finds a call at the cursor whose last argument is a closure and which has
/// no trailing closure yet.
bool RefactoringActionTrailingClosure::
isApplicable(ResolvedCursorInfo CursorInfo, DiagnosticEngine &Diag) {
  SourceManager &SM = CursorInfo.SF->getASTContext().SourceMgr;
  return findTrailingClosureTarget(SM, CursorInfo) != nullptr;
}
// Turns the closing closure argument of the call at the cursor into a
// trailing closure, e.g. `f(a, { ... })` becomes `f(a) { ... }`.
//
// Returns true (failure) when:
// - no suitable call is found,
// - the argument list is empty, or
// - its parentheses have invalid locations.
// Returns false on success.
bool RefactoringActionTrailingClosure::performChange() {
auto *CE = findTrailingClosureTarget(SM, CursorInfo);
if (!CE)
return true;
Expr *Arg = CE->getArg();
Expr *ClosureArg = nullptr;
Expr *PrevArg = nullptr;
OriginalArgumentList ArgList = getOriginalArgumentList(Arg);
auto NumArgs = ArgList.args.size();
if (NumArgs == 0)
return true;
ClosureArg = ArgList.args[NumArgs - 1];
if (NumArgs > 1)
PrevArg = ArgList.args[NumArgs - 2];
// Look through an implicit conversion to the closure as written.
if (auto *ICE = dyn_cast<ImplicitConversionExpr>(ClosureArg))
ClosureArg = ICE->getSyntacticSubExpr();
if (ArgList.lParenLoc.isInvalid() || ArgList.rParenLoc.isInvalid())
return true;
// Replace:
// * Open paren with ' ' if the closure is sole argument.
// * Comma with ') ' otherwise.
if (PrevArg) {
// NOTE(review): this constructor call is incomplete in this copy. Judging by
// the sole-argument branch, it presumably takes
// (SM, <end of PrevArg>, <start of ClosureArg>); confirm.
CharSourceRange PreRange(
Lexer::getLocForEndOfToken(SM, PrevArg->getEndLoc()),
EditConsumer.accept(SM, PreRange, ") ");
} else {
CharSourceRange PreRange(
SM, ArgList.lParenLoc, ClosureArg->getStartLoc());
EditConsumer.accept(SM, PreRange, " ");
// Remove original closing paren.
// NOTE(review): the `if`/`else` closing brace above and, presumably, the
// leading SM argument of this constructor are missing from this copy.
CharSourceRange PostRange(
Lexer::getLocForEndOfToken(SM, ClosureArg->getEndLoc()),
Lexer::getLocForEndOfToken(SM, ArgList.rParenLoc));
EditConsumer.remove(SM, PostRange);
return false;
static bool rangeStartMayNeedRename(ResolvedRangeInfo Info) {
switch(Info.Kind) {
case RangeKind::SingleExpression: {
Expr *E = Info.ContainedNodes[0].get<Expr*>();
// We should show rename for the selection of "foo()"
if (auto *CE = dyn_cast<CallExpr>(E)) {
if (CE->getFn()->getKind() == ExprKind::DeclRef)
return true;
// When callling an instance method inside another instance method,
// we have a dot syntax call whose dot and base are both implicit. We
// need to explicitly allow the specific case here.
if (auto *DSC = dyn_cast<DotSyntaxCallExpr>(CE->getFn())) {
if (DSC->getBase()->isImplicit() &&
DSC->getFn()->getStartLoc() == Info.TokensInRange.front().getLoc())
return true;
return false;
case RangeKind::PartOfExpression: {
if (auto *CE = dyn_cast<CallExpr>(Info.CommonExprParent)) {
if (auto *DSC = dyn_cast<DotSyntaxCallExpr>(CE->getFn())) {
if (DSC->getFn()->getStartLoc() == Info.TokensInRange.front().getLoc())
return true;
return false;
case RangeKind::SingleDecl:
case RangeKind::MultiTypeMemberDecl:
case RangeKind::SingleStatement:
case RangeKind::MultiStatement:
case RangeKind::Invalid:
return false;
llvm_unreachable("unhandled kind");
/// Converting a stored property into a computed property requires a selection
/// covering exactly one pattern binding. The binding must declare a single
/// variable and must have an initializer.
///
/// The variable must also have:
/// - no observing accessors (willSet/didSet);
/// - no `lazy`, `@NSCopying` or `@IBOutlet` attribute;
/// - no attached property wrapper.
bool RefactoringActionConvertToComputedProperty::
isApplicable(ResolvedRangeInfo Info, DiagnosticEngine &Diag) {
  if (Info.Kind != RangeKind::SingleDecl || Info.ContainedNodes.size() != 1)
    return false;

  auto *Binding = dyn_cast_or_null<PatternBindingDecl>(
      Info.ContainedNodes[0].dyn_cast<Decl*>());
  if (!Binding)
    return false;

  auto *SV = Binding->getSingleVar();
  if (!SV)
    return false;

  // willSet, didSet cannot be provided together with a getter
  for (auto *AD : SV->getAllAccessors())
    if (AD->isObservingAccessor())
      return false;

  // 'lazy' must not be used on a computed property
  // NSCopying and IBOutlet attribute requires property to be mutable
  const auto &Attrs = SV->getAttrs();
  if (Attrs.hasAttribute<LazyAttr>() ||
      Attrs.hasAttribute<NSCopyingAttr>() ||
      Attrs.hasAttribute<IBOutletAttr>())
    return false;

  // Property wrapper cannot be applied to a computed property
  if (SV->hasAttachedPropertyWrapper())
    return false;

  // has an initializer
  return Binding->hasInitStringRepresentation(0);
}
// Replaces the selected stored property with a computed property whose
// getter returns the original initializer text. The generated text has the
// form
//   var <name>: <type> {
//   return <initializer>
//   }
// and replaces the source from Binding->getLoc() through the end of the
// binding. The type is copied from the written type annotation when there is
// one. Always returns false (success).
// NOTE(review): D and Binding are used without null checks; this relies on
// isApplicable() having accepted the selection.
bool RefactoringActionConvertToComputedProperty::performChange() {
// Get an initialization
auto D = RangeInfo.ContainedNodes[0].dyn_cast<Decl*>();
auto Binding = dyn_cast<PatternBindingDecl>(D);
SmallString<128> scratch;
auto Init = Binding->getInitStringRepresentation(0, scratch);
// Get type
auto SV = Binding->getSingleVar();
auto SVType = SV->getType();
auto TR = SV->getTypeReprOrParentPatternTypeRepr();
llvm::SmallString<64> DeclBuffer;
llvm::raw_svector_ostream OS(DeclBuffer);
llvm::StringRef Space = " ";
llvm::StringRef NewLine = "\n";
OS << tok::kw_var << Space;
// Add var name
OS << SV->getNameStr().str() << ":" << Space;
// For computed property must write a type of var
if (TR) {
OS << Lexer::getCharSourceRangeFromSourceRange(SM, TR->getSourceRange()).str();
// NOTE(review): the else-branch below is missing its body in this copy.
// Presumably it printed SVType, which is otherwise unused. Its closing
// brace is missing too.
} else {
OS << Space << tok::l_brace << NewLine;
// Add an initialization
OS << tok::kw_return << Space << Init.str() << NewLine;
OS << tok::r_brace;
// Replace initializer to computed property
auto ReplaceStartLoc = Binding->getLoc();
auto ReplaceEndLoc = Binding->getSourceRange().End;
auto ReplaceRange = SourceRange(ReplaceStartLoc, ReplaceEndLoc);
auto ReplaceCharSourceRange = Lexer::getCharSourceRangeFromSourceRange(SM, ReplaceRange);
EditConsumer.accept(SM, ReplaceCharSourceRange, DeclBuffer.str());
return false; // success
}// end of anonymous namespace
/// Returns the user-visible name registered for \p Kind through the
/// REFACTORING entries of RefactoringKinds.def. \p Kind must not be
/// RefactoringKind::None.
StringRef swift::ide::
getDescriptiveRefactoringKindName(RefactoringKind Kind) {
  switch(Kind) {
  case RefactoringKind::None:
    llvm_unreachable("Should be a valid refactoring kind");
#define REFACTORING(KIND, NAME, ID) case RefactoringKind::KIND: return NAME;
#include "swift/IDE/RefactoringKinds.def"
  }
  llvm_unreachable("unhandled kind");
}
/// Returns a human-readable explanation of why a rename is unavailable, or an
/// empty string when \p Kind is RenameAvailableKind::Available.
StringRef swift::ide::
getDescriptiveRenameUnavailableReason(RenameAvailableKind Kind) {
  switch(Kind) {
  case RenameAvailableKind::Available:
    return "";
  case RenameAvailableKind::Unavailable_system_symbol:
    return "symbol from system module cannot be renamed";
  case RenameAvailableKind::Unavailable_has_no_location:
    return "symbol without a declaration location cannot be renamed";
  case RenameAvailableKind::Unavailable_has_no_name:
    return "cannot find the name of the symbol";
  case RenameAvailableKind::Unavailable_has_no_accessibility:
    return "cannot decide the accessibility of the symbol";
  case RenameAvailableKind::Unavailable_decl_from_clang:
    return "cannot rename a Clang symbol from its Swift reference";
  }
  llvm_unreachable("unhandled kind");
}
/// Returns the location of (Line, Column) in buffer BufferId.
SourceLoc swift::ide::RangeConfig::getStart(SourceManager &SM) {
  return SM.getLocForLineCol(BufferId, Line, Column);
}
/// Returns the start location (see getStart) advanced by Length.
SourceLoc swift::ide::RangeConfig::getEnd(SourceManager &SM) {
  return getStart(SM).getAdvancedLoc(Length);
}
/// Private state of FindRenameRangesAnnotatingConsumer.
///
/// Every reported rename range is wrapped in an XML-like tag naming its
/// RefactoringRangeKind (e.g. `<base>foo</base>`), with an `index` attribute
/// when the range has one. The rewritten text is forwarded to a
/// SourceEditOutputConsumer.
struct swift::ide::FindRenameRangesAnnotatingConsumer::Implementation {
  std::unique_ptr<SourceEditConsumer> pRewriter;

  Implementation(SourceManager &SM, unsigned BufferId, llvm::raw_ostream &OS)
      : pRewriter(new SourceEditOutputConsumer(SM, BufferId, OS)) {}

  /// Returns the tag name used to annotate a range of kind \p Kind.
  static StringRef tag(RefactoringRangeKind Kind) {
    switch (Kind) {
    case RefactoringRangeKind::BaseName:
      return "base";
    case RefactoringRangeKind::KeywordBaseName:
      return "keywordBase";
    case RefactoringRangeKind::ParameterName:
      return "param";
    case RefactoringRangeKind::NoncollapsibleParameterName:
      return "noncollapsibleparam";
    case RefactoringRangeKind::DeclArgumentLabel:
      return "arglabel";
    case RefactoringRangeKind::CallArgumentLabel:
      return "callarg";
    case RefactoringRangeKind::CallArgumentColon:
      return "callcolon";
    case RefactoringRangeKind::CallArgumentCombined:
      return "callcombo";
    case RefactoringRangeKind::SelectorArgumentLabel:
      return "sel";
    }
    llvm_unreachable("unhandled kind");
  }

  /// Replaces \p Range with its original text surrounded by the tag for its
  /// kind, adding an `index` attribute when the range carries an index.
  void accept(SourceManager &SM, const RenameRangeDetail &Range) {
    std::string NewText;
    llvm::raw_string_ostream OS(NewText);
    StringRef Tag = tag(Range.RangeKind);
    OS << "<" << Tag;
    if (Range.Index.hasValue())
      OS << " index=" << *Range.Index;
    OS << ">" << Range.Range.str() << "</" << Tag << ">";
    pRewriter->accept(SM, {Range.Range, OS.str(), {}});
  }
};
FindRenameRangesAnnotatingConsumer(SourceManager &SM, unsigned BufferId,
llvm::raw_ostream &OS): Impl(*new Implementation(SM, BufferId, OS)) {}
/// Releases the Implementation allocated by the constructor.
swift::ide::FindRenameRangesAnnotatingConsumer::~FindRenameRangesAnnotatingConsumer() {
  delete &Impl;
}
/// Annotates every range of the reported region. Mismatched and unmatched
/// regions are skipped entirely, leaving their text untouched.
void swift::ide::FindRenameRangesAnnotatingConsumer::
accept(SourceManager &SM, RegionType RegionType,
       ArrayRef<RenameRangeDetail> Ranges) {
  // Without this early return, the loop below would become the body of the
  // `if` and annotate only the regions that are meant to be skipped.
  if (RegionType == RegionType::Mismatch || RegionType == RegionType::Unmatched)
    return;
  for (const auto &Range : Ranges)
    Impl.accept(SM, Range);
}
// Appends to Scratch the rename refactorings (LocalRename/GlobalRename)
// available for VD and returns Scratch. Each entry is tagged with AvailKind,
// which records whether the rename is available or why it is not.
//
// Nothing is appended for:
// - accessors and deinitializers;
// - initializers and `callAsFunction` methods without parameters;
// - init/callAsFunction references whose resolved name has no
//   argument-label ranges, when RefInfo is a reference that is not on an
//   argument label.
// Parameters (ParamDecl) always get LocalRename only.
//
// NOTE(review): the return-type line that should precede the qualified name
// below is missing from this copy. The final return builds an ArrayRef over
// Scratch, so it is presumably ArrayRef<RenameAvailabiliyInfo>. Most closing
// braces are also missing, which makes the nesting of the `if` blocks
// ambiguous here.
swift::ide::collectRenameAvailabilityInfo(const ValueDecl *VD,
Optional<RenameRefInfo> RefInfo,
std::vector<RenameAvailabiliyInfo> &Scratch) {
RenameAvailableKind AvailKind = RenameAvailableKind::Available;
if (getRelatedSystemDecl(VD)){
AvailKind = RenameAvailableKind::Unavailable_system_symbol;
} else if (VD->getClangDecl()) {
AvailKind = RenameAvailableKind::Unavailable_decl_from_clang;
} else if (VD->getStartLoc().isInvalid()) {
AvailKind = RenameAvailableKind::Unavailable_has_no_location;
} else if (!VD->hasName()) {
AvailKind = RenameAvailableKind::Unavailable_has_no_name;
if (isa<AbstractFunctionDecl>(VD)) {
// Disallow renaming accessors.
if (isa<AccessorDecl>(VD))
return Scratch;
// Disallow renaming deinit.
if (isa<DestructorDecl>(VD))
return Scratch;
// Disallow renaming init with no arguments.
if (auto CD = dyn_cast<ConstructorDecl>(VD)) {
if (!CD->getParameters()->size())
return Scratch;
if (RefInfo && !RefInfo->IsArgLabel) {
NameMatcher Matcher(*(RefInfo->SF));
auto Resolved = Matcher.resolve({RefInfo->Loc, /*ResolveArgs*/true});
if (Resolved.LabelRanges.empty())
return Scratch;
// Disallow renaming 'callAsFunction' method with no arguments.
if (auto FD = dyn_cast<FuncDecl>(VD)) {
// FIXME: syntactic rename can only decide by checking the spelling, not
// whether it's an instance method, so we do the same here for now.
if (FD->getBaseIdentifier() == FD->getASTContext().Id_callAsFunction) {
if (!FD->getParameters()->size())
return Scratch;
if (RefInfo && !RefInfo->IsArgLabel) {
NameMatcher Matcher(*(RefInfo->SF));
auto Resolved = Matcher.resolve({RefInfo->Loc, /*ResolveArgs*/true});
if (Resolved.LabelRanges.empty())
return Scratch;
// Always return local rename for parameters.
// FIXME: if the cursor is on the argument, we should return global rename.
if (isa<ParamDecl>(VD)) {
Scratch.emplace_back(RefactoringKind::LocalRename, AvailKind);
return Scratch;
// NOTE(review): local symbols get LocalRename. As written in this copy,
// though, the GlobalRename emplace_back also runs for them, because there is
// no `else`. The upstream comment below suggests only non-local symbols
// should get GlobalRename; confirm and restore the `else`.
// If the indexer considers VD a global symbol, then we apply global rename.
if (index::isLocalSymbol(VD))
Scratch.emplace_back(RefactoringKind::LocalRename, AvailKind);
Scratch.emplace_back(RefactoringKind::GlobalRename, AvailKind);
return llvm::makeArrayRef(Scratch);
// Collects the refactoring kinds available at CursorInfo into Scratch and
// returns a view of Scratch. LocalRename/GlobalRename are omitted when
// ExcludeRename is set.
//
// For a ValueRef cursor, the available rename for CursorInfo.ValueD is looked
// up. The cursor is treated as a reference site when CursorInfo.IsRef is set.
// Each cursor-based action's isApplicable() is then queried through the
// RefactoringKinds.def expansion.
//
// NOTE(review): several lines are missing from this copy:
// - the case bodies and closing braces of the switch;
// - the statements that record kinds in AllKinds and Scratch;
// - the macro definition whose body continues on the
//   `if (RefactoringAction##KIND::isApplicable(...)) \` line below.
ArrayRef<RefactoringKind> swift::ide::
collectAvailableRefactorings(SourceFile *SF,
ResolvedCursorInfo CursorInfo,
std::vector<RefactoringKind> &Scratch,
bool ExcludeRename) {
llvm::SmallVector<RefactoringKind, 2> AllKinds;
switch(CursorInfo.Kind) {
case CursorInfoKind::ModuleRef:
case CursorInfoKind::Invalid:
case CursorInfoKind::StmtStart:
case CursorInfoKind::ExprStart:
case CursorInfoKind::ValueRef: {
Optional<RenameRefInfo> RefInfo;
if (CursorInfo.IsRef)
RefInfo = {CursorInfo.SF, CursorInfo.Loc, CursorInfo.IsKeywordArgument};
auto RenameOp = getAvailableRenameForDecl(CursorInfo.ValueD, RefInfo);
if (RenameOp.hasValue() &&
RenameOp.getValue() == RefactoringKind::GlobalRename)
DiagnosticEngine DiagEngine(SF->getASTContext().SourceMgr);
if (RefactoringAction##KIND::isApplicable(CursorInfo, DiagEngine)) \
#include "swift/IDE/RefactoringKinds.def"
// Exclude renames.
for(auto Kind: AllKinds) {
switch (Kind) {
case RefactoringKind::LocalRename:
case RefactoringKind::GlobalRename:
if (ExcludeRename)
return llvm::makeArrayRef(Scratch);
// Collects the refactoring kinds applicable to the selected Range into
// Scratch and returns Scratch. A zero-length range is delegated to
// collectAvailableRefactoringsAtCursor.
//
// Diagnostics go to every consumer in DiagConsumers. Kinds whose
// registration goes through the enableInternalRefactoring check are only
// offered when the environment variable
// SWIFT_ENABLE_INTERNAL_REFACTORING_ACTIONS is set. RangeStartMayNeedRename
// is set from rangeStartMayNeedRename() for the resolved range.
//
// NOTE(review): the request argument of evaluateOrDefault and the macro
// definition lines around the isApplicable checks are missing from this
// copy, as are several closing braces.
ArrayRef<RefactoringKind> swift::ide::
collectAvailableRefactorings(SourceFile *SF, RangeConfig Range,
bool &RangeStartMayNeedRename,
std::vector<RefactoringKind> &Scratch,
llvm::ArrayRef<DiagnosticConsumer*> DiagConsumers) {
if (Range.Length == 0) {
return collectAvailableRefactoringsAtCursor(SF, Range.Line, Range.Column,
Scratch, DiagConsumers);
// Prepare the tool box.
ASTContext &Ctx = SF->getASTContext();
SourceManager &SM = Ctx.SourceMgr;
DiagnosticEngine DiagEngine(SM);
std::for_each(DiagConsumers.begin(), DiagConsumers.end(),
[&](DiagnosticConsumer *Con) { DiagEngine.addConsumer(*Con); });
ResolvedRangeInfo Result = evaluateOrDefault(SF->getASTContext().evaluator,
bool enableInternalRefactoring = getenv("SWIFT_ENABLE_INTERNAL_REFACTORING_ACTIONS");
if (RefactoringAction##KIND::isApplicable(Result, DiagEngine)) \
if (enableInternalRefactoring) \
#include "swift/IDE/RefactoringKinds.def"
RangeStartMayNeedRename = rangeStartMayNeedRename(Result);
return Scratch;
// Runs the semantic refactoring Opts.Kind on module M. Edits go to
// EditConsumer and diagnostics to DiagConsumer. An empty Opts.PreferredName
// is replaced with the kind's default name.
//
// The action is performed when it is LocalRename or reports itself
// applicable, and its performChange() result is returned (false on
// success). An inapplicable action returns true. GlobalRename and the
// Find*RenameRanges kinds must not be passed here.
//
// NOTE(review): the #define that opens the per-kind `case` macro below, and
// the switch/function closing braces, are missing from this copy.
bool swift::ide::
refactorSwiftModule(ModuleDecl *M, RefactoringOptions Opts,
SourceEditConsumer &EditConsumer,
DiagnosticConsumer &DiagConsumer) {
assert(Opts.Kind != RefactoringKind::None && "should have a refactoring kind.");
// Use the default name if not specified.
if (Opts.PreferredName.empty()) {
Opts.PreferredName = getDefaultPreferredName(Opts.Kind).str();
switch (Opts.Kind) {
case RefactoringKind::KIND: { \
RefactoringAction##KIND Action(M, Opts, EditConsumer, DiagConsumer); \
if (RefactoringKind::KIND == RefactoringKind::LocalRename || \
Action.isApplicable()) \
return Action.performChange(); \
return true; \
#include "swift/IDE/RefactoringKinds.def"
case RefactoringKind::GlobalRename:
case RefactoringKind::FindGlobalRenameRanges:
case RefactoringKind::FindLocalRenameRanges:
llvm_unreachable("not a valid refactoring kind");
case RefactoringKind::None:
llvm_unreachable("should not enter here.");
llvm_unreachable("unhandled kind");
// Validates each RenameLoc and resolves them all against SF with NameMatcher.
//
// For every location, OldName must parse as a valid name. A non-empty
// NewName must:
// - be valid, with an identifier or operator base;
// - have only identifier (or empty) argument labels;
// - have the same number of parts as OldName;
// - not be used for a call to a symbol that is not function-like.
// The first violation is diagnosed through Diags and an empty vector is
// returned.
//
// NOTE(review): several lines are missing from this copy:
// - the column argument of getLocForLineCol;
// - the second argument of the arity_mismatch diagnostic;
// - the statement that appends to UnresolvedLocs (only a fragment of its
//   condition remains);
// - most closing braces.
static std::vector<ResolvedLoc>
resolveRenameLocations(ArrayRef<RenameLoc> RenameLocs, SourceFile &SF,
DiagnosticEngine &Diags) {
SourceManager &SM = SF.getASTContext().SourceMgr;
unsigned BufferID = SF.getBufferID().getValue();
std::vector<UnresolvedLoc> UnresolvedLocs;
for (const RenameLoc &RenameLoc : RenameLocs) {
DeclNameViewer OldName(RenameLoc.OldName);
SourceLoc Location = SM.getLocForLineCol(BufferID, RenameLoc.Line,
if (!OldName.isValid()) {
Diags.diagnose(Location, diag::invalid_name, RenameLoc.OldName);
return {};
if (!RenameLoc.NewName.empty()) {
DeclNameViewer NewName(RenameLoc.NewName);
ArrayRef<StringRef> ParamNames = NewName.args();
bool newOperator = Lexer::isOperator(NewName.base());
bool NewNameIsValid = NewName.isValid() &&
(Lexer::isIdentifier(NewName.base()) || newOperator) &&
std::all_of(ParamNames.begin(), ParamNames.end(), [](StringRef Label) {
return Label.empty() || Lexer::isIdentifier(Label);
if (!NewNameIsValid) {
Diags.diagnose(Location, diag::invalid_name, RenameLoc.NewName);
return {};
if (NewName.partsCount() != OldName.partsCount()) {
Diags.diagnose(Location, diag::arity_mismatch, RenameLoc.NewName,
return {};
if (RenameLoc.Usage == NameUsage::Call && !RenameLoc.IsFunctionLike) {
Diags.diagnose(Location, diag::name_not_functionlike, RenameLoc.NewName);
return {};
bool isOperator = Lexer::isOperator(OldName.base());
(RenameLoc.Usage == NameUsage::Unknown ||
(RenameLoc.Usage == NameUsage::Call && !isOperator))
NameMatcher Resolver(SF);
return Resolver.resolve(UnresolvedLocs, SF.getAllTokens());
// Applies the renames described by RenameLocs to SF syntactically and
// reports each resolved region to EditConsumer. Mismatched regions are
// diagnosed and reported with no replacements.
//
// Returns true when not every location could be resolved (already
// diagnosed), and false otherwise.
//
// NOTE(review): no DiagEngine.addConsumer(DiagConsumer) call is visible in
// this copy, so diagnostics emitted here may never reach DiagConsumer; it
// may have been lost. The trailing arguments of the TextReplacementsRenamer
// constructor and of the mismatched_rename diagnostic are missing, as are
// the closing braces.
int swift::ide::syntacticRename(SourceFile *SF, ArrayRef<RenameLoc> RenameLocs,
SourceEditConsumer &EditConsumer,
DiagnosticConsumer &DiagConsumer) {
assert(SF && "null source file");
SourceManager &SM = SF->getASTContext().SourceMgr;
DiagnosticEngine DiagEngine(SM);
auto ResolvedLocs = resolveRenameLocations(RenameLocs, *SF, DiagEngine);
if (ResolvedLocs.size() != RenameLocs.size())
return true; // Already diagnosed.
size_t index = 0;
llvm::StringSet<> ReplaceTextContext;
for(const RenameLoc &Rename: RenameLocs) {
ResolvedLoc &Resolved = ResolvedLocs[index++];
TextReplacementsRenamer Renamer(SM, Rename.OldName, Rename.NewName,
RegionType Type = Renamer.addSyntacticRenameRanges(Resolved, Rename);
if (Type == RegionType::Mismatch) {
DiagEngine.diagnose(Resolved.Range.getStart(), diag::mismatched_rename,
EditConsumer.accept(SM, Type, None);
} else {
EditConsumer.accept(SM, Type, Renamer.getReplacements());
return false;
// Resolves RenameLocs in SF and reports, for each location, the rename
// ranges found (with their region type) to RenameConsumer. No edits are
// made. Mismatched regions are diagnosed and reported with no ranges.
//
// Returns true when not every location could be resolved (already
// diagnosed), and false otherwise.
//
// NOTE(review): as in syntacticRename, no addConsumer call attaches
// DiagConsumer to DiagEngine in this copy. The trailing argument of the
// mismatched_rename diagnostic and the closing braces are also missing.
int swift::ide::findSyntacticRenameRanges(
SourceFile *SF, llvm::ArrayRef<RenameLoc> RenameLocs,
FindRenameRangesConsumer &RenameConsumer,
DiagnosticConsumer &DiagConsumer) {
assert(SF && "null source file");
SourceManager &SM = SF->getASTContext().SourceMgr;
DiagnosticEngine DiagEngine(SM);
auto ResolvedLocs = resolveRenameLocations(RenameLocs, *SF, DiagEngine);
if (ResolvedLocs.size() != RenameLocs.size())
return true; // Already diagnosed.
size_t index = 0;
for (const RenameLoc &Rename : RenameLocs) {
ResolvedLoc &Resolved = ResolvedLocs[index++];
RenameRangeDetailCollector Renamer(SM, Rename.OldName);
RegionType Type = Renamer.addSyntacticRenameRanges(Resolved, Rename);
if (Type == RegionType::Mismatch) {
DiagEngine.diagnose(Resolved.Range.getStart(), diag::mismatched_rename,
RenameConsumer.accept(SM, Type, None);
} else {
RenameConsumer.accept(SM, Type, Renamer.Ranges);
return false;
int swift::ide::findLocalRenameRanges(
SourceFile *SF, RangeConfig Range,
FindRenameRangesConsumer &RenameConsumer,
DiagnosticConsumer &DiagConsumer) {
assert(SF && "null source file");
SourceManager &SM = SF->getASTContext().SourceMgr;
DiagnosticEngine Diags(SM);
auto StartLoc = Lexer::getLocForStartOfToken(SM, Range.getStart(SM));
ResolvedCursorInfo CursorInfo =
CursorInfoRequest{CursorInfoOwner(SF, StartLoc)},
if (!CursorInfo.isValid() || !CursorInfo.ValueD) {
Diags.diagnose(StartLoc, diag::unresolved_location);
return true;
ValueDecl *VD = CursorInfo.CtorTyRef ? CursorInfo.CtorTyRef : CursorInfo.ValueD;
Optional<RenameRefInfo> RefInfo;
if (CursorInfo.IsRef)
RefInfo = {CursorInfo.SF, CursorInfo.Loc, CursorInfo.IsKeywordArgument};
llvm::SmallVector<DeclContext *, 8> Scopes;
analyzeRenameScope(VD, RefInfo, Diags, Scopes);
if (Scopes.empty())
return true;
RenameRangeCollector RangeCollector(VD, StringRef());
for (DeclContext *DC : Scopes)
indexDeclContext(DC, RangeCollector);
return findSyntacticRenameRanges(SF, RangeCollector.results(), RenameConsumer,