Revision ac4b68fbf45853ba4b9e327cb42f93f42a8fa252 authored by Ellie Shin on 17 March 2023, 04:14:20 UTC, committed by Ellie Shin on 17 March 2023, 04:14:20 UTC
1 parent f2c68fb
Refactoring.cpp
//===--- Refactoring.cpp ---------------------------------------------------===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
#include "swift/Refactoring/Refactoring.h"
#include "swift/AST/ASTContext.h"
#include "swift/AST/ASTPrinter.h"
#include "swift/AST/Decl.h"
#include "swift/AST/DiagnosticsRefactoring.h"
#include "swift/AST/Expr.h"
#include "swift/AST/ForeignAsyncConvention.h"
#include "swift/AST/GenericParamList.h"
#include "swift/AST/NameLookup.h"
#include "swift/AST/Pattern.h"
#include "swift/AST/ProtocolConformance.h"
#include "swift/AST/Stmt.h"
#include "swift/AST/Types.h"
#include "swift/AST/USRGeneration.h"
#include "swift/Basic/Edit.h"
#include "swift/Basic/StringExtras.h"
#include "swift/ClangImporter/ClangImporter.h"
#include "swift/Frontend/Frontend.h"
#include "swift/IDE/IDERequests.h"
#include "swift/Index/Index.h"
#include "swift/Parse/Lexer.h"
#include "swift/Sema/IDETypeChecking.h"
#include "swift/Subsystems.h"
#include "clang/Rewrite/Core/RewriteBuffer.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringSet.h"
using namespace swift;
using namespace swift::ide;
using namespace swift::index;
namespace {
/// Walks a source file and records every AST node whose source range encloses
/// \c Target and that the \c IsContext predicate accepts.
///
/// Nodes that do not enclose the target are not descended into. Nodes are
/// recorded in pre-order, so an enclosing node is recorded before any node
/// nested inside it.
class ContextFinder : public SourceEntityWalker {
  SourceFile &SF;
  ASTContext &Ctx;
  SourceManager &SM;
  SourceRange Target;
  std::function<bool(ASTNode)> IsContext;
  SmallVector<ASTNode, 4> AllContexts;

  /// Returns whether \p Enclosing contains \c Target. When it does, it is also
  /// recorded in \c AllContexts if \c IsContext accepts it.
  bool contains(ASTNode Enclosing) {
    bool Encloses = SM.rangeContainsRespectingReplacedRanges(
        Enclosing.getSourceRange(), Target);
    if (!Encloses)
      return false;
    if (IsContext(Enclosing))
      AllContexts.push_back(Enclosing);
    return true;
  }

public:
  /// Finds contexts enclosing the source range of \p TargetNode.
  ContextFinder(SourceFile &SF, ASTNode TargetNode,
                std::function<bool(ASTNode)> IsContext =
                    [](ASTNode) { return true; })
      : SF(SF), Ctx(SF.getASTContext()), SM(Ctx.SourceMgr),
        Target(TargetNode.getSourceRange()), IsContext(IsContext) {}

  /// Finds contexts enclosing the (required to be valid) location \p TargetLoc.
  ContextFinder(SourceFile &SF, SourceLoc TargetLoc,
                std::function<bool(ASTNode)> IsContext =
                    [](ASTNode) { return true; })
      : SF(SF), Ctx(SF.getASTContext()), SM(Ctx.SourceMgr),
        Target(TargetLoc), IsContext(IsContext) {
    assert(TargetLoc.isValid() && "Invalid loc to find");
  }

  // Macro expansions only matter for the expand refactorings, but walking
  // them is harmless: nodes that don't contain the target are skipped anyway.
  MacroWalking getMacroWalkingBehavior() const override {
    return MacroWalking::ArgumentsAndExpansion;
  }

  bool walkToDeclPre(Decl *D, CharSourceRange Range) override {
    return contains(D);
  }
  bool walkToStmtPre(Stmt *S) override { return contains(S); }
  bool walkToExprPre(Expr *E) override { return contains(E); }

  /// Walks the whole source file, collecting the enclosing contexts.
  void resolve() { walk(SF); }

  /// The contexts collected by \c resolve().
  ArrayRef<ASTNode> getContexts() const { return AllContexts; }
};
/// Base class for syntactically renaming one occurrence of a name.
///
/// Given the resolved source ranges of an occurrence (its base name and any
/// argument labels), a Renamer checks that the text at those ranges matches
/// \c Old. For each range that does, it calls \c doRenameBase or
/// \c doRenameLabel with the kind of range found. Subclasses decide what to
/// do with those ranges.
class Renamer {
protected:
  const SourceManager &SM;

protected:
  Renamer(const SourceManager &SM, StringRef OldName) : SM(SM), Old(OldName) {}

  // Implementor's interface.

  /// Called for every label-related range that matched. \p NameIndex is the
  /// index of the corresponding argument label in \c Old.args().
  virtual void doRenameLabel(CharSourceRange Label,
                             RefactoringRangeKind RangeKind,
                             unsigned NameIndex) = 0;

  /// Called for the base name range once it has matched \c Old.base().
  virtual void doRenameBase(CharSourceRange Range,
                            RefactoringRangeKind RangeKind) = 0;

public:
  /// The name being renamed, split into its base name and argument labels.
  const DeclNameViewer Old;

public:
  virtual ~Renamer() {}

  /// Adds a replacement to rename the given base name range. A surrounding
  /// pair of backticks is ignored when comparing against the old base name.
  /// \return true if the given range does not match the old name
  bool renameBase(CharSourceRange Range, RefactoringRangeKind RangeKind) {
    assert(Range.isValid());
    if (stripBackticks(Range).str() != Old.base())
      return true;
    doRenameBase(Range, RangeKind);
    return false;
  }

  /// Adds replacements to rename the given label ranges.
  ///
  /// Call sites are matched leniently: old labels may be skipped, repeated
  /// empty labels are tolerated (variadic arguments), and trailing closures
  /// are matched from the end. Every other occurrence must provide exactly
  /// one range per old label, in order.
  /// \return true if the label ranges do not match the old name
  bool renameLabels(ArrayRef<CharSourceRange> LabelRanges,
                    Optional<unsigned> FirstTrailingLabel,
                    LabelRangeType RangeType, bool isCallSite) {
    if (isCallSite)
      return renameLabelsLenient(LabelRanges, FirstTrailingLabel, RangeType);

    assert(!FirstTrailingLabel);
    ArrayRef<StringRef> OldLabels = Old.args();

    if (OldLabels.size() != LabelRanges.size())
      return true;

    size_t Index = 0;
    for (const auto &LabelRange : LabelRanges) {
      assert(LabelRange.isValid());
      if (!labelRangeMatches(LabelRange, RangeType, OldLabels[Index]))
        return true;
      splitAndRenameLabel(LabelRange, RangeType, Index++);
    }
    return false;
  }

  /// Whether the old base name is an operator.
  bool isOperator() const { return Lexer::isOperator(Old.base()); }

private:
  /// Returns the range of the (possibly escaped) identifier at the start of
  /// \p Range and updates \p IsEscaped to indicate whether it's escaped or not.
  CharSourceRange getLeadingIdentifierRange(CharSourceRange Range,
                                            bool &IsEscaped) {
    assert(Range.isValid() && Range.getByteLength());
    IsEscaped = Range.str().front() == '`';
    SourceLoc Start = Range.getStart();
    if (IsEscaped)
      Start = Start.getAdvancedLoc(1);
    return Lexer::getCharSourceRangeFromSourceRange(SM, Start);
  }

  /// Returns \p Range without its surrounding backticks if it is of the form
  /// `x` (at least one character between the backticks). Otherwise returns
  /// \p Range unchanged.
  CharSourceRange stripBackticks(CharSourceRange Range) {
    StringRef Content = Range.str();
    if (Content.size() < 3 || Content.front() != '`' || Content.back() != '`') {
      return Range;
    }
    return CharSourceRange(Range.getStart().getAdvancedLoc(1),
                           Range.getByteLength() - 2);
  }

  /// Reports \p Range, split into its component ranges according to
  /// \p RangeType, for argument \p NameIndex.
  void splitAndRenameLabel(CharSourceRange Range, LabelRangeType RangeType,
                           size_t NameIndex) {
    switch (RangeType) {
    case LabelRangeType::CallArg:
      return splitAndRenameCallArg(Range, NameIndex);
    case LabelRangeType::Param:
      return splitAndRenameParamLabel(Range, NameIndex, /*IsCollapsible=*/true);
    case LabelRangeType::NoncollapsibleParam:
      return splitAndRenameParamLabel(Range, NameIndex, /*IsCollapsible=*/false);
    case LabelRangeType::Selector:
      return doRenameLabel(
          Range, RefactoringRangeKind::SelectorArgumentLabel, NameIndex);
    case LabelRangeType::None:
      llvm_unreachable("expected a label range");
    }
  }

  void splitAndRenameParamLabel(CharSourceRange Range, size_t NameIndex,
                                bool IsCollapsible) {
    // Split parameter range foo([a b]: Int) into decl argument label [a] and
    // parameter name [b] or noncollapsible parameter name [b] if IsCollapsible
    // is false (as for subscript decls). If we have only foo([a]: Int), then we
    // add an empty range for the local name, or for the decl argument label if
    // IsCollapsible is false.
    StringRef Content = Range.str();
    size_t ExternalNameEnd = Content.find_first_of(" \t\n\v\f\r/");

    if (ExternalNameEnd == StringRef::npos) { // foo([a]: Int)
      if (IsCollapsible) {
        doRenameLabel(Range, RefactoringRangeKind::DeclArgumentLabel, NameIndex);
        doRenameLabel(CharSourceRange{Range.getEnd(), 0},
                      RefactoringRangeKind::ParameterName, NameIndex);
      } else {
        doRenameLabel(CharSourceRange{Range.getStart(), 0},
                      RefactoringRangeKind::DeclArgumentLabel, NameIndex);
        doRenameLabel(Range, RefactoringRangeKind::NoncollapsibleParameterName,
                      NameIndex);
      }
    } else { // foo([a b]: Int)
      CharSourceRange Ext{Range.getStart(), unsigned(ExternalNameEnd)};

      // Note: we consider the leading whitespace part of the parameter name
      // if the parameter is collapsible, since if the parameter is collapsed
      // into a matching argument label, we want to remove the whitespace too.
      // FIXME: handle comments foo(a /*...*/b: Int).
      size_t LocalNameStart = Content.find_last_of(" \t\n\v\f\r/");
      assert(LocalNameStart != StringRef::npos);
      if (!IsCollapsible)
        ++LocalNameStart;
      auto LocalLoc = Range.getStart().getAdvancedLocOrInvalid(LocalNameStart);
      CharSourceRange Local{LocalLoc, unsigned(Content.size() - LocalNameStart)};

      doRenameLabel(Ext, RefactoringRangeKind::DeclArgumentLabel, NameIndex);
      if (IsCollapsible) {
        doRenameLabel(Local, RefactoringRangeKind::ParameterName, NameIndex);
      } else {
        doRenameLabel(Local, RefactoringRangeKind::NoncollapsibleParameterName,
                      NameIndex);
      }
    }
  }

  void splitAndRenameCallArg(CharSourceRange Range, size_t NameIndex) {
    // Split call argument foo([a: ]1) into argument name [a] and the remainder
    // [: ].
    StringRef Content = Range.str();
    size_t Colon = Content.find(':'); // FIXME: leading whitespace?
    if (Colon == StringRef::npos) {
      assert(Content.empty());
      doRenameLabel(Range, RefactoringRangeKind::CallArgumentCombined,
                    NameIndex);
      return;
    }

    // Include any whitespace before the ':'.
    assert(Colon == Content.substr(0, Colon).size());
    Colon = Content.substr(0, Colon).rtrim().size();

    CharSourceRange Arg{Range.getStart(), unsigned(Colon)};
    doRenameLabel(Arg, RefactoringRangeKind::CallArgumentLabel, NameIndex);

    auto ColonLoc = Range.getStart().getAdvancedLocOrInvalid(Colon);
    assert(ColonLoc.isValid());
    CharSourceRange Rest{ColonLoc, unsigned(Content.size() - Colon)};
    doRenameLabel(Rest, RefactoringRangeKind::CallArgumentColon, NameIndex);
  }

  /// Returns whether the label written in \p Range matches \p Expected.
  ///
  /// An empty \p Expected matches an empty range or a written `_`. For a
  /// noncollapsible parameter, a lone name (e.g. subscript([x]: Int)) also
  /// matches an empty \p Expected.
  bool labelRangeMatches(CharSourceRange Range, LabelRangeType RangeType,
                         StringRef Expected) {
    if (Range.getByteLength()) {
      bool IsEscaped = false;
      CharSourceRange ExistingLabelRange =
          getLeadingIdentifierRange(Range, IsEscaped);
      StringRef ExistingLabel = ExistingLabelRange.str();
      bool IsSingleName = Range == ExistingLabelRange ||
          (IsEscaped && Range.getByteLength() == ExistingLabel.size() + 2);

      switch (RangeType) {
      case LabelRangeType::NoncollapsibleParam:
        if (IsSingleName && Expected.empty()) // subscript([x]: Int)
          return true;
        LLVM_FALLTHROUGH;
      case LabelRangeType::CallArg:
      case LabelRangeType::Param:
      case LabelRangeType::Selector:
        return ExistingLabel == (Expected.empty() ? "_" : Expected);
      case LabelRangeType::None:
        llvm_unreachable("Unhandled label range type");
      }
    }
    return Expected.empty();
  }

  /// Call-site label matching; see \c renameLabels.
  /// \return true if the label ranges cannot be matched to the old name
  bool renameLabelsLenient(ArrayRef<CharSourceRange> LabelRanges,
                           Optional<unsigned> FirstTrailingLabel,
                           LabelRangeType RangeType) {

    ArrayRef<StringRef> OldNames = Old.args();

    // First, match trailing closure arguments in reverse
    if (FirstTrailingLabel) {
      auto TrailingLabels = LabelRanges.drop_front(*FirstTrailingLabel);
      LabelRanges = LabelRanges.take_front(*FirstTrailingLabel);

      for (auto LabelIndex: llvm::reverse(indices(TrailingLabels))) {
        CharSourceRange Label = TrailingLabels[LabelIndex];

        if (Label.getByteLength()) {
          if (OldNames.empty())
            return true;

          while (!labelRangeMatches(Label, LabelRangeType::Selector,
                                    OldNames.back())) {
            if ((OldNames = OldNames.drop_back()).empty())
              return true;
          }
          splitAndRenameLabel(Label, LabelRangeType::Selector,
                              OldNames.size() - 1);
          OldNames = OldNames.drop_back();
          continue;
        }

        // empty labelled trailing closure label
        if (LabelIndex) {
          if (OldNames.empty())
            return true;

          while (!OldNames.back().empty()) {
            if ((OldNames = OldNames.drop_back()).empty())
              return true;
          }
          splitAndRenameLabel(Label, LabelRangeType::Selector,
                              OldNames.size() - 1);
          OldNames = OldNames.drop_back();
          continue;
        }

        // unlabelled trailing closure label
        // With no old names left there is no parameter for this closure, so
        // it's a mismatch (ArrayRef::drop_back on an empty array is invalid).
        if (OldNames.empty())
          return true;
        OldNames = OldNames.drop_back();
        continue;
      }
    }

    // Next, match the non-trailing arguments.
    size_t NameIndex = 0;

    for (CharSourceRange Label : LabelRanges) {
      // empty label
      if (!Label.getByteLength()) {

        // first name pos
        if (!NameIndex) {
          // The old name may have no (remaining) arguments at all; don't
          // index into an empty array.
          if (OldNames.empty())
            return true;
          while (!OldNames[NameIndex].empty()) {
            if (++NameIndex >= OldNames.size())
              return true;
          }
          splitAndRenameLabel(Label, RangeType, NameIndex++);
          continue;
        }

        // other name pos
        if (NameIndex >= OldNames.size() || !OldNames[NameIndex].empty()) {
          // FIXME: only allow one variadic param
          continue; // allow for variadic
        }
        splitAndRenameLabel(Label, RangeType, NameIndex++);
        continue;
      }

      // non-empty label
      if (NameIndex >= OldNames.size())
        return true;

      while (!labelRangeMatches(Label, RangeType, OldNames[NameIndex])) {
        if (++NameIndex >= OldNames.size())
          return true;
      }
      splitAndRenameLabel(Label, RangeType, NameIndex++);
    }
    return false;
  }

  /// Classifies where the resolved occurrence appears (comment, string,
  /// selector, active or inactive code).
  static RegionType getSyntacticRenameRegionType(const ResolvedLoc &Resolved) {
    if (Resolved.Node.isNull())
      return RegionType::Comment;

    if (Expr *E = Resolved.Node.getAsExpr()) {
      if (isa<StringLiteralExpr>(E))
        return RegionType::String;
    }
    if (Resolved.IsInSelector)
      return RegionType::Selector;
    if (Resolved.IsActive)
      return RegionType::ActiveCode;
    return RegionType::InactiveCode;
  }

public:
  /// Matches \p Resolved against \c Old, using the usage information in
  /// \p Config, and reports the matched ranges to the subclass.
  /// \return the kind of region the occurrence is in; \c Unmatched if the
  /// occurrence should be ignored; or \c Mismatch if it doesn't match the
  /// old name.
  RegionType addSyntacticRenameRanges(const ResolvedLoc &Resolved,
                                      const RenameLoc &Config) {

    if (!Resolved.Range.isValid())
      return RegionType::Unmatched;

    auto RegionKind = getSyntacticRenameRegionType(Resolved);
    // Don't include unknown references coming from active code; if we don't
    // have a semantic NameUsage for them, then they're likely unrelated symbols
    // that happen to have the same name.
    if (RegionKind == RegionType::ActiveCode &&
        Config.Usage == NameUsage::Unknown)
      return RegionType::Unmatched;

    assert(Config.Usage != NameUsage::Call || Config.IsFunctionLike);

    // FIXME: handle escaped keyword names `init`
    bool IsSubscript = Old.base() == "subscript" && Config.IsFunctionLike;
    bool IsInit = Old.base() == "init" && Config.IsFunctionLike;

    // FIXME: this should only be treated specially for instance methods.
    bool IsCallAsFunction = Old.base() == "callAsFunction" &&
        Config.IsFunctionLike;

    bool IsSpecialBase = IsInit || IsSubscript || IsCallAsFunction;

    // Filter out non-semantic special basename locations with no labels.
    // We've already filtered out those in active code, so these are
    // any appearance of just 'init', 'subscript', or 'callAsFunction' in
    // strings, comments, and inactive code.
    if (IsSpecialBase && (Config.Usage == NameUsage::Unknown &&
                          Resolved.LabelType == LabelRangeType::None))
      return RegionType::Unmatched;

    if (!Config.IsFunctionLike || !IsSpecialBase) {
      if (renameBase(Resolved.Range, RefactoringRangeKind::BaseName))
        return RegionType::Mismatch;

    } else if (IsInit || IsCallAsFunction) {
      if (renameBase(Resolved.Range, RefactoringRangeKind::KeywordBaseName)) {
        // The base name doesn't need to match (but may) for calls, but
        // it should for definitions and references.
        if (Config.Usage == NameUsage::Definition ||
            Config.Usage == NameUsage::Reference) {
          return RegionType::Mismatch;
        }
      }
    } else if (IsSubscript && Config.Usage == NameUsage::Definition) {
      if (renameBase(Resolved.Range, RefactoringRangeKind::KeywordBaseName))
        return RegionType::Mismatch;
    }

    bool HandleLabels = false;
    if (Config.IsFunctionLike) {
      switch (Config.Usage) {
      case NameUsage::Call:
        HandleLabels = !isOperator();
        break;
      case NameUsage::Definition:
        HandleLabels = true;
        break;
      case NameUsage::Reference:
        HandleLabels =
            Resolved.LabelType == LabelRangeType::Selector || IsSubscript;
        break;
      case NameUsage::Unknown:
        HandleLabels = Resolved.LabelType != LabelRangeType::None;
        break;
      }
    }

    if (HandleLabels) {
      bool isCallSite = Config.Usage != NameUsage::Definition &&
                        (Config.Usage != NameUsage::Reference || IsSubscript) &&
                        Resolved.LabelType == LabelRangeType::CallArg;

      if (renameLabels(Resolved.LabelRanges, Resolved.FirstTrailingLabel,
                       Resolved.LabelType, isCallSite))
        return Config.Usage == NameUsage::Unknown ?
            RegionType::Unmatched : RegionType::Mismatch;
    }

    return RegionKind;
  }
};
class RenameRangeDetailCollector : public Renamer {
void doRenameLabel(CharSourceRange Label, RefactoringRangeKind RangeKind,
unsigned NameIndex) override {
Ranges.push_back({Label, RangeKind, NameIndex});
}
void doRenameBase(CharSourceRange Range,
RefactoringRangeKind RangeKind) override {
Ranges.push_back({Range, RangeKind, None});
}
public:
RenameRangeDetailCollector(const SourceManager &SM, StringRef OldName)
: Renamer(SM, OldName) {}
std::vector<RenameRangeDetail> Ranges;
};
/// A Renamer that turns each matched range into a textual \c Replacement that
/// renames \c Old to \c New.
///
/// Replacement text is either a string literal or a copy owned by the
/// caller-provided \c ReplaceTextContext. It never points into the
/// caller's \p NewName string.
class TextReplacementsRenamer : public Renamer {
  llvm::StringSet<> &ReplaceTextContext;
  SmallVector<Replacement> Replacements;

public:
  /// The new name, split into its base name and argument labels.
  const DeclNameViewer New;

private:
  /// Copies \p Text into \c ReplaceTextContext and returns the stored copy.
  /// Empty text is returned as-is.
  StringRef registerText(StringRef Text) {
    if (Text.empty())
      return Text;
    return ReplaceTextContext.insert(Text).first->getKey();
  }

  StringRef getCallArgLabelReplacement(StringRef OldLabelRange,
                                       StringRef NewLabel) {
    // Register the label like every other replacement kind does, rather than
    // returning a reference into the caller's new-name storage.
    return registerText(NewLabel);
  }

  StringRef getCallArgColonReplacement(StringRef OldLabelRange,
                                       StringRef NewLabel) {
    // Expected OldLabelRange: foo( []3, a[: ]2, b[ : ]3 ...)
    // FIXME: Preserve comments: foo([a/*:*/ : /*:*/ ]2, ...)
    if (NewLabel.empty())
      return "";
    if (OldLabelRange.empty())
      return ": ";
    return registerText(OldLabelRange);
  }

  StringRef getCallArgCombinedReplacement(StringRef OldArgLabel,
                                          StringRef NewArgLabel) {
    // This case only happens when going from foo([]1) to foo([a: ]1).
    assert(OldArgLabel.empty());
    if (NewArgLabel.empty())
      return "";
    return registerText((Twine(NewArgLabel) + ": ").str());
  }

  StringRef getParamNameReplacement(StringRef OldParam, StringRef OldArgLabel,
                                    StringRef NewArgLabel) {
    // We don't want to get foo(a a: Int), so drop the parameter name if the
    // argument label will match the original name.
    // Note: the leading whitespace is part of the parameter range.
    if (!NewArgLabel.empty() && OldParam.ltrim() == NewArgLabel)
      return "";

    // If we're renaming foo(x: Int) to foo(_:), then use the original argument
    // label as the parameter name so as to not break references in the body.
    if (NewArgLabel.empty() && !OldArgLabel.empty() && OldParam.empty())
      return registerText((Twine(" ") + OldArgLabel).str());

    return registerText(OldParam);
  }

  StringRef getDeclArgumentLabelReplacement(StringRef OldLabelRange,
                                            StringRef NewArgLabel) {
    // OldLabelRange is subscript([]a: Int), foo([a]: Int) or foo([a] b: Int)
    if (NewArgLabel.empty())
      return OldLabelRange.empty() ? "" : "_";

    if (OldLabelRange.empty())
      return registerText((Twine(NewArgLabel) + " ").str());
    return registerText(NewArgLabel);
  }

  /// Computes the text that should replace \p LabelRange, given its kind and
  /// the old/new label of the argument it belongs to.
  StringRef getReplacementText(StringRef LabelRange,
                               RefactoringRangeKind RangeKind,
                               StringRef OldLabel, StringRef NewLabel) {
    switch (RangeKind) {
    case RefactoringRangeKind::CallArgumentLabel:
      return getCallArgLabelReplacement(LabelRange, NewLabel);
    case RefactoringRangeKind::CallArgumentColon:
      return getCallArgColonReplacement(LabelRange, NewLabel);
    case RefactoringRangeKind::CallArgumentCombined:
      return getCallArgCombinedReplacement(LabelRange, NewLabel);
    case RefactoringRangeKind::ParameterName:
      return getParamNameReplacement(LabelRange, OldLabel, NewLabel);
    case RefactoringRangeKind::NoncollapsibleParameterName:
      return LabelRange;
    case RefactoringRangeKind::DeclArgumentLabel:
      return getDeclArgumentLabelReplacement(LabelRange, NewLabel);
    case RefactoringRangeKind::SelectorArgumentLabel:
      return NewLabel.empty() ? "_" : registerText(NewLabel);
    default:
      llvm_unreachable("label range type is none but there are labels");
    }
  }

  /// Records a replacement for \p LabelRange unless the new text is identical
  /// to what is already written there.
  void addReplacement(CharSourceRange LabelRange,
                      RefactoringRangeKind RangeKind, StringRef OldLabel,
                      StringRef NewLabel) {
    StringRef ExistingLabel = LabelRange.str();
    StringRef Text =
        getReplacementText(ExistingLabel, RangeKind, OldLabel, NewLabel);
    if (Text != ExistingLabel)
      Replacements.push_back({/*Path=*/{}, LabelRange, /*BufferName=*/{}, Text,
                              /*RegionsWorthNote=*/{}});
  }

  void doRenameLabel(CharSourceRange Label, RefactoringRangeKind RangeKind,
                     unsigned NameIndex) override {
    addReplacement(Label, RangeKind, Old.args()[NameIndex],
                   New.args()[NameIndex]);
  }

  void doRenameBase(CharSourceRange Range, RefactoringRangeKind) override {
    if (Old.base() != New.base())
      Replacements.push_back({/*Path=*/{}, Range, /*BufferName=*/{},
                              registerText(New.base()),
                              /*RegionsWorthNote=*/{}});
  }

public:
  /// \p OldName and \p NewName must be valid names with the same number of
  /// parts (asserted).
  TextReplacementsRenamer(const SourceManager &SM, StringRef OldName,
                          StringRef NewName,
                          llvm::StringSet<> &ReplaceTextContext)
      : Renamer(SM, OldName), ReplaceTextContext(ReplaceTextContext),
        New(NewName) {
    assert(Old.isValid() && New.isValid());
    assert(Old.partsCount() == New.partsCount());
  }

  /// The replacements produced so far, in the order they were added.
  ArrayRef<Replacement> getReplacements() const { return Replacements; }
};
/// Returns the first declaration related to \p VD that lives in a system
/// module, or null if there is none. Candidates are checked in this order:
/// \p VD itself, then the protocol requirements it satisfies, then the chain
/// of declarations it overrides.
static const ValueDecl *getRelatedSystemDecl(const ValueDecl *VD) {
  auto IsFromSystemModule = [](const ValueDecl *D) {
    return D->getModuleContext()->isSystemModule();
  };

  if (IsFromSystemModule(VD))
    return VD;

  for (auto *Requirement : VD->getSatisfiedProtocolRequirements())
    if (IsFromSystemModule(Requirement))
      return Requirement;

  const ValueDecl *Overridden = VD->getOverriddenDecl();
  while (Overridden) {
    if (IsFromSystemModule(Overridden))
      return Overridden;
    Overridden = Overridden->getOverriddenDecl();
  }
  return nullptr;
}
/// A declaration targeted by a rename, paired with whether (and as which kind
/// of refactoring) it may be renamed.
/// NOTE: keep the field order; instances are created with aggregate
/// initialization, e.g. RenameInfo{VD, Availability}.
struct RenameInfo {
/// The declaration being renamed.
ValueDecl *VD;
/// Rename availability computed for \c VD.
RenameAvailabilityInfo Availability;
};
/// Given a cursor, return the decl and its rename availability. \c None if
/// the cursor did not resolve to a decl or it resolved to a decl that we do
/// not allow renaming on.
///
/// For a shorthand `if let x` / closure capture `[x]`, the outermost shadowed
/// decl is used instead of the cursor's decl.
static Optional<RenameInfo> getRenameInfo(ResolvedCursorInfoPtr cursorInfo) {
  auto valueCursor = dyn_cast<ResolvedValueRefCursorInfo>(cursorInfo);
  if (!valueCursor)
    return None;

  ValueDecl *decl = valueCursor->typeOrValue();
  if (!decl)
    return None;

  Optional<RenameRefInfo> refInfo;
  if (valueCursor->getShorthandShadowedDecls().empty()) {
    if (valueCursor->isRef())
      refInfo = RenameRefInfo{valueCursor->getSourceFile(),
                              valueCursor->getLoc(),
                              valueCursor->isKeywordArgument()};
  } else {
    // Find the outermost decl for a shorthand if let/closure capture.
    decl = valueCursor->getShorthandShadowedDecls().back();
  }

  if (Optional<RenameAvailabilityInfo> availability =
          renameAvailabilityInfo(decl, refInfo))
    return RenameInfo{decl, *availability};
  return None;
}
/// Index consumer that collects a \c RenameLoc for every occurrence of the
/// symbol identified by \c USR, each renamed to \c newName.
class RenameRangeCollector : public IndexDataConsumer {
public:
  RenameRangeCollector(StringRef USR, StringRef newName)
      : USR(USR), newName(newName) {}

  RenameRangeCollector(const ValueDecl *D, StringRef newName)
      : newName(newName) {
    SmallString<64> SS;
    llvm::raw_svector_ostream OS(SS);
    printValueDeclUSR(D, OS);
    // SS does not outlive this constructor, so keep our own copy of the USR.
    USR = stringStorage.copyString(SS.str());
  }

  RenameRangeCollector(RenameRangeCollector &&collector) = default;

  /// The collected rename locations, in the order they were reported.
  ArrayRef<RenameLoc> results() const { return locations; }

private:
  bool indexLocals() override { return true; }
  void failed(StringRef error) override {}
  bool startDependency(StringRef name, StringRef path, bool isClangModule,
                       bool isSystem) override {
    return true;
  }
  bool finishDependency(bool isClangModule) override { return true; }

  Action startSourceEntity(const IndexSymbol &symbol) override {
    if (symbol.USR == USR) {
      if (auto loc = indexSymbolToRenameLoc(symbol, newName)) {
        // Inside capture lists like `{ [test] in }`, 'test' refers to both the
        // newly declared, captured variable and the referenced variable it is
        // initialized from. Make sure to only rename it once.
        // Take each candidate by const reference to avoid copying RenameLocs.
        auto existingLoc =
            llvm::find_if(locations, [&](const RenameLoc &searchLoc) {
              return searchLoc.Line == loc->Line &&
                     searchLoc.Column == loc->Column;
            });
        if (existingLoc == locations.end()) {
          locations.push_back(std::move(*loc));
        } else {
          assert(existingLoc->OldName == loc->OldName &&
                 existingLoc->NewName == loc->NewName &&
                 existingLoc->IsFunctionLike == loc->IsFunctionLike &&
                 existingLoc->IsNonProtocolType == loc->IsNonProtocolType &&
                 "Asked to do a different rename for the same location?");
        }
      }
    }
    return IndexDataConsumer::Continue;
  }

  bool finishSourceEntity(SymbolInfo symInfo, SymbolRoleSet roles) override {
    return true;
  }

  /// Converts an index occurrence into a \c RenameLoc, or \c None if the
  /// occurrence is implicit.
  Optional<RenameLoc> indexSymbolToRenameLoc(const index::IndexSymbol &symbol,
                                             StringRef NewName);

private:
  StringRef USR;
  StringRef newName;
  StringScratchSpace stringStorage;
  std::vector<RenameLoc> locations;
};
/// Converts an index occurrence of the renamed symbol into a \c RenameLoc.
/// Implicit occurrences yield \c None. The usage is chosen from the symbol's
/// roles in priority order: Call, then Definition, then Reference.
Optional<RenameLoc>
RenameRangeCollector::indexSymbolToRenameLoc(const index::IndexSymbol &symbol,
                                             StringRef newName) {
  auto hasRole = [&](index::SymbolRole role) {
    return (symbol.roles & (unsigned)role) != 0;
  };

  if (hasRole(index::SymbolRole::Implicit))
    return None;

  NameUsage usage = NameUsage::Unknown;
  if (hasRole(index::SymbolRole::Call))
    usage = NameUsage::Call;
  else if (hasRole(index::SymbolRole::Definition))
    usage = NameUsage::Definition;
  else if (hasRole(index::SymbolRole::Reference))
    usage = NameUsage::Reference;
  else
    llvm_unreachable("unexpected role");

  bool isFunctionLike = false;
  bool isNonProtocolType = false;
  switch (symbol.symInfo.Kind) {
  // Callable kinds: their argument labels take part in the rename.
  case index::SymbolKind::EnumConstant:
  case index::SymbolKind::Function:
  case index::SymbolKind::Constructor:
  case index::SymbolKind::ConversionFunction:
  case index::SymbolKind::InstanceMethod:
  case index::SymbolKind::ClassMethod:
  case index::SymbolKind::StaticMethod:
    isFunctionLike = true;
    break;
  // Nominal types other than protocols.
  case index::SymbolKind::Class:
  case index::SymbolKind::Enum:
  case index::SymbolKind::Struct:
    isNonProtocolType = true;
    break;
  default:
    break;
  }

  // The index-provided name may not outlive this callback; keep a copy.
  StringRef oldName = stringStorage.copyString(symbol.name);
  return RenameLoc{symbol.line,    symbol.column, usage,
                   oldName,        newName,       isFunctionLike,
                   isNonProtocolType};
}
/// Get the source file of \p M whose buffer is \p Range.BufferID. May return
/// null when the module has no such file; callers check for this.
SourceFile *getContainingFile(ModuleDecl *M, RangeConfig Range) {
  SourceManager &SM = M->getASTContext().SourceMgr;
  SourceLoc BufferStart = SM.getRangeForBuffer(Range.BufferID).getStart();
  // TODO: We should add an ID -> SourceFile mapping.
  return M->getSourceFileContainingLocation(BufferStart);
}
/// Common state shared by all refactoring actions: the module and source file
/// being refactored, where edits and diagnostics go, the start location of the
/// requested range, and the user's preferred name.
///
/// NOTE: members are initialized in declaration order and the constructor
/// relies on it (SM feeds DiagEngine and StartLoc); don't reorder them.
class RefactoringAction {
protected:
/// The module being refactored.
ModuleDecl *MD;
/// Source file of MD containing the requested buffer; may be null if the
/// buffer is not part of MD.
SourceFile *TheFile;
/// Receives the edits produced by the refactoring.
SourceEditConsumer &EditConsumer;
ASTContext &Ctx;
SourceManager &SM;
/// Diagnostic engine forwarding to the caller's DiagnosticConsumer.
DiagnosticEngine DiagEngine;
/// Start of the token at the beginning of the requested range.
SourceLoc StartLoc;
/// The user-requested name (e.g. for rename or extracted entities).
StringRef PreferredName;
public:
RefactoringAction(ModuleDecl *MD, RefactoringOptions &Opts,
SourceEditConsumer &EditConsumer,
DiagnosticConsumer &DiagConsumer);
virtual ~RefactoringAction() = default;
/// Performs the refactoring. Returns true on failure.
virtual bool performChange() = 0;
};
/// Sets up the shared refactoring state for \p Opts and routes this action's
/// diagnostics to \p DiagConsumer.
RefactoringAction::RefactoringAction(ModuleDecl *MD, RefactoringOptions &Opts,
                                     SourceEditConsumer &EditConsumer,
                                     DiagnosticConsumer &DiagConsumer)
    : MD(MD),
      TheFile(getContainingFile(MD, Opts.Range)),
      EditConsumer(EditConsumer),
      Ctx(MD->getASTContext()),
      SM(Ctx.SourceMgr),
      DiagEngine(SM),
      // Snap the range start to the beginning of the token it falls in.
      StartLoc(Lexer::getLocForStartOfToken(SM, Opts.Range.getStart(SM))),
      PreferredName(Opts.PreferredName) {
  DiagEngine.addConsumer(DiagConsumer);
}
/// Different from RangeBasedRefactoringAction, TokenBasedRefactoringAction takes
/// the input of a given token, e.g., a name or an "if" key word. Contextual
/// refactoring kinds can suggest applicable refactorings on that token, e.g.
/// rename or reverse if statement.
class TokenBasedRefactoringAction : public RefactoringAction {
protected:
  /// Semantic info for the token at StartLoc. It is an empty
  /// ResolvedCursorInfo if resolution failed or there is no containing file.
  ResolvedCursorInfoPtr CursorInfo;

public:
  TokenBasedRefactoringAction(ModuleDecl *MD, RefactoringOptions &Opts,
                              SourceEditConsumer &EditConsumer,
                              DiagnosticConsumer &DiagConsumer) :
    RefactoringAction(MD, Opts, EditConsumer, DiagConsumer) {
    // TheFile is null when the requested buffer is not part of MD. Don't
    // dereference it here; performChange implementations (e.g. LocalRename)
    // check it and report location_module_mismatch.
    if (!TheFile) {
      CursorInfo = new ResolvedCursorInfo();
      return;
    }
    // Resolve the sema token and save it for later use.
    CursorInfo =
        evaluateOrDefault(TheFile->getASTContext().evaluator,
                          CursorInfoRequest{CursorInfoOwner(TheFile, StartLoc)},
                          new ResolvedCursorInfo());
  }
};
// Declare one RefactoringAction<KIND> class for each cursor-based refactoring
// listed in RefactoringKinds.def. Each class resolves its token through
// TokenBasedRefactoringAction and declares performChange() and a static
// isApplicable(). It also provides an instance isApplicable() that forwards
// CursorInfo and DiagEngine to the static overload. The out-of-line
// definitions are written per refactoring kind.
// (Keep comments outside the macro: a '//' comment on a line ending in '\'
// would swallow the next line of the definition.)
#define CURSOR_REFACTORING(KIND, NAME, ID) \
class RefactoringAction##KIND : public TokenBasedRefactoringAction { \
public: \
RefactoringAction##KIND(ModuleDecl *MD, RefactoringOptions &Opts, \
SourceEditConsumer &EditConsumer, \
DiagnosticConsumer &DiagConsumer) \
: TokenBasedRefactoringAction(MD, Opts, EditConsumer, DiagConsumer) {} \
bool performChange() override; \
static bool isApplicable(ResolvedCursorInfoPtr Info, \
DiagnosticEngine &Diag); \
bool isApplicable() { \
return RefactoringAction##KIND::isApplicable(CursorInfo, DiagEngine); \
} \
};
#include "swift/Refactoring/RefactoringKinds.def"
/// Base class for refactorings that operate on the selected source range,
/// which is resolved once, at construction time, into \c RangeInfo.
class RangeBasedRefactoringAction : public RefactoringAction {
protected:
  /// Semantic info for the selected range. It is a default
  /// ResolvedRangeInfo if resolution failed or there is no containing file.
  ResolvedRangeInfo RangeInfo;

public:
  RangeBasedRefactoringAction(ModuleDecl *MD, RefactoringOptions &Opts,
                              SourceEditConsumer &EditConsumer,
                              DiagnosticConsumer &DiagConsumer) :
    RefactoringAction(MD, Opts, EditConsumer, DiagConsumer),
    // TheFile is null when the requested buffer is not part of MD; never
    // evaluate a RangeInfoRequest with a null file.
    RangeInfo(TheFile
                  ? evaluateOrDefault(
                        MD->getASTContext().evaluator,
                        RangeInfoRequest(RangeInfoOwner(
                            TheFile, Opts.Range.getStart(SM),
                            Opts.Range.getEnd(SM))),
                        ResolvedRangeInfo())
                  : ResolvedRangeInfo()) {}
};
// Declare one RefactoringAction<KIND> class for each range-based refactoring
// listed in RefactoringKinds.def. Each class resolves the selected range
// through RangeBasedRefactoringAction and declares performChange() and a
// static isApplicable(). It also provides an instance isApplicable() that
// forwards RangeInfo and DiagEngine to the static overload.
// (Keep comments outside the macro: a '//' comment on a line ending in '\'
// would swallow the next line of the definition.)
#define RANGE_REFACTORING(KIND, NAME, ID) \
class RefactoringAction##KIND: public RangeBasedRefactoringAction { \
public: \
RefactoringAction##KIND(ModuleDecl *MD, RefactoringOptions &Opts, \
SourceEditConsumer &EditConsumer, \
DiagnosticConsumer &DiagConsumer) : \
RangeBasedRefactoringAction(MD, Opts, EditConsumer, DiagConsumer) {} \
bool performChange() override; \
static bool isApplicable(const ResolvedRangeInfo &Info, \
DiagnosticEngine &Diag); \
bool isApplicable() { \
return RefactoringAction##KIND::isApplicable(RangeInfo, DiagEngine) ; \
} \
};
#include "swift/Refactoring/RefactoringKinds.def"
/// Local rename applies when the cursor resolves to a decl that is available
/// for renaming and whose rename is classified as a local rename.
bool RefactoringActionLocalRename::isApplicable(
    ResolvedCursorInfoPtr CursorInfo, DiagnosticEngine &Diag) {
  Optional<RenameInfo> Info = getRenameInfo(CursorInfo);
  if (!Info)
    return false;
  const RenameAvailabilityInfo &Availability = Info->Availability;
  return Availability.AvailableKind == RenameAvailableKind::Available &&
         Availability.Kind == RefactoringKind::LocalRename;
}
/// Appends to \p Scopes the decl context to index when looking for uses of
/// \p VD: its own decl context, or that context's parent when sibling decls
/// may also see \p VD.
static void analyzeRenameScope(ValueDecl *VD,
                               SmallVectorImpl<DeclContext *> &Scopes) {
  DeclContext *Scope = VD->getDeclContext();

  // There may be sibling decls that the renamed symbol is visible from.
  bool SearchParent = false;
  switch (Scope->getContextKind()) {
  case DeclContextKind::GenericTypeDecl:
  case DeclContextKind::ExtensionDecl:
  case DeclContextKind::TopLevelCodeDecl:
  case DeclContextKind::SubscriptDecl:
  case DeclContextKind::EnumElementDecl:
  case DeclContextKind::AbstractFunctionDecl:
    SearchParent = true;
    break;
  case DeclContextKind::AbstractClosureExpr:
  case DeclContextKind::Initializer:
  case DeclContextKind::SerializedLocal:
  case DeclContextKind::Package:
  case DeclContextKind::Module:
  case DeclContextKind::FileUnit:
  case DeclContextKind::MacroDecl:
    break;
  }

  Scopes.push_back(SearchParent ? Scope->getParent() : Scope);
}
/// Resolves the decl at \p startLoc in \p SF and collects every location at
/// which it should be renamed to \p preferredName.
///
/// Emits a diagnostic to \p diags and returns \c None when the location does
/// not resolve to a renamable decl.
static Optional<RenameRangeCollector> localRenames(SourceFile *SF,
                                                   SourceLoc startLoc,
                                                   StringRef preferredName,
                                                   DiagnosticEngine &diags) {
  auto cursorInfo =
      evaluateOrDefault(SF->getASTContext().evaluator,
                        CursorInfoRequest{CursorInfoOwner(SF, startLoc)},
                        new ResolvedCursorInfo());

  Optional<RenameInfo> info = getRenameInfo(cursorInfo);
  if (!info) {
    diags.diagnose(startLoc, diag::unresolved_location);
    return None;
  }

  ValueDecl *VD = info->VD;
  auto availability = info->Availability.AvailableKind;
  if (availability != RenameAvailableKind::Available) {
    switch (availability) {
    case RenameAvailableKind::Available:
      llvm_unreachable("handled above");
    case RenameAvailableKind::Unavailable_system_symbol:
      diags.diagnose(startLoc, diag::decl_is_system_symbol, VD->getName());
      break;
    case RenameAvailableKind::Unavailable_has_no_location:
      diags.diagnose(startLoc, diag::value_decl_no_loc, VD->getName());
      break;
    case RenameAvailableKind::Unavailable_has_no_name:
      diags.diagnose(startLoc, diag::decl_has_no_name);
      break;
    case RenameAvailableKind::Unavailable_has_no_accessibility:
      diags.diagnose(startLoc, diag::decl_no_accessibility);
      break;
    case RenameAvailableKind::Unavailable_decl_from_clang:
      diags.diagnose(startLoc, diag::decl_from_clang);
      break;
    }
    return None;
  }

  SmallVector<DeclContext *, 8> scopes;
  analyzeRenameScope(VD, scopes);
  if (scopes.empty())
    return None;

  // Index every scope that may reference VD, collecting its occurrences.
  RenameRangeCollector rangeCollector(VD, preferredName);
  for (DeclContext *DC : scopes)
    indexDeclContext(DC, rangeCollector);

  return rangeCollector;
}
bool RefactoringActionLocalRename::performChange() {
if (StartLoc.isInvalid()) {
DiagEngine.diagnose(SourceLoc(), diag::invalid_location);
return true;
}
if (!DeclNameViewer(PreferredName).isValid()) {
DiagEngine.diagnose(SourceLoc(), diag::invalid_name, PreferredName);
return true;
}
if (!TheFile) {
DiagEngine.diagnose(StartLoc, diag::location_module_mismatch,
MD->getNameStr());
return true;
}
Optional<RenameRangeCollector> rangeCollector =
localRenames(TheFile, StartLoc, PreferredName, DiagEngine);
if (!rangeCollector)
return true;
auto consumers = DiagEngine.takeConsumers();
assert(consumers.size() == 1);
return syntacticRename(TheFile, rangeCollector->results(), EditConsumer,
*consumers[0]);
}
/// Name to use for refactorings that introduce a name when the user did not
/// supply one; empty for refactoring kinds that don't need a name.
StringRef getDefaultPreferredName(RefactoringKind Kind) {
  StringRef Name = "";
  switch (Kind) {
  case RefactoringKind::GlobalRename:
  case RefactoringKind::LocalRename:
    Name = "newName";
    break;
  case RefactoringKind::ExtractExpr:
  case RefactoringKind::ExtractRepeatedExpr:
    Name = "extractedExpr";
    break;
  case RefactoringKind::ExtractFunction:
    Name = "extractedFunc";
    break;
  case RefactoringKind::None:
    llvm_unreachable("Should be a valid refactoring kind");
  default:
    break;
  }
  return Name;
}
/// Caveats under which a range may still be extracted. Clients of
/// ExtractCheckResult::success choose which of these they tolerate.
enum class CannotExtractReason {
// NOTE(review): presumably recorded when the range is a literal expression.
Literal,
/// The extracted range has Void type.
VoidType,
};
class ExtractCheckResult {
bool KnownFailure;
SmallVector<CannotExtractReason, 2> AllReasons;
public:
ExtractCheckResult(): KnownFailure(true) {}
ExtractCheckResult(ArrayRef<CannotExtractReason> AllReasons):
KnownFailure(false), AllReasons(AllReasons.begin(), AllReasons.end()) {}
bool success() { return success({}); }
bool success(ArrayRef<CannotExtractReason> ExpectedReasons) {
if (KnownFailure)
return false;
bool Result = true;
// Check if any reasons aren't covered by the list of expected reasons
// provided by the client.
for (auto R: AllReasons) {
Result &= llvm::is_contained(ExpectedReasons, R);
}
return Result;
}
};
/// Check whether a given range can be extracted.
///
/// \returns a hard failure (a default-constructed ExtractCheckResult) when the
/// range can never be extracted; otherwise a result carrying the soft reasons
/// (CannotExtractReason::Literal / VoidType) that only some extraction
/// refactorings accept. Some hard failures emit a diagnostic through
/// \p DiagEngine; others (e.g. unsure exit state, l-value type) fail silently.
ExtractCheckResult checkExtractConditions(const ResolvedRangeInfo &RangeInfo,
DiagnosticEngine &DiagEngine) {
// Soft reasons collected along the way; returned if no hard failure occurs.
SmallVector<CannotExtractReason, 2> AllReasons;
// A declaration introduced inside the range must not be referenced after the
// range; otherwise extraction would leave that reference dangling.
auto Declared = RangeInfo.DeclaredDecls;
auto It = std::find_if(Declared.begin(), Declared.end(),
[](DeclaredDecl DD) { return DD.ReferredAfterRange; });
if (It != Declared.end()) {
DiagEngine.diagnose(It->VD->getLoc(),
diag::value_decl_referenced_out_of_range,
It->VD->getName());
return ExtractCheckResult();
}
// We cannot extract a range with multiple entry points.
if (!RangeInfo.HasSingleEntry) {
DiagEngine.diagnose(SourceLoc(), diag::multi_entry_range);
return ExtractCheckResult();
}
// We cannot extract code whose exit state (whether it returns) is unknown.
if (RangeInfo.exit() == ExitState::Unsure) {
return ExtractCheckResult();
}
// We cannot extract expressions of l-value or inout type.
if (auto Ty = RangeInfo.getType()) {
if (Ty->hasLValueType() || Ty->is<InOutType>())
return ExtractCheckResult();
// Disallow extracting error type expressions/statements
// FIXME: diagnose what happened?
if (Ty->hasError())
return ExtractCheckResult();
// Void is only a soft reason: function extraction accepts it.
if (Ty->isVoid()) {
AllReasons.emplace_back(CannotExtractReason::VoidType);
}
}
// We cannot extract a range with orphaned loop keyword.
switch (RangeInfo.Orphan) {
case swift::ide::OrphanKind::Continue:
DiagEngine.diagnose(SourceLoc(), diag::orphan_loop_keyword, "continue");
return ExtractCheckResult();
case swift::ide::OrphanKind::Break:
DiagEngine.diagnose(SourceLoc(), diag::orphan_loop_keyword, "break");
return ExtractCheckResult();
case swift::ide::OrphanKind::None:
break;
}
// Guard statement can not be extracted.
if (llvm::any_of(RangeInfo.ContainedNodes,
[](ASTNode N) { return N.isStmt(StmtKind::Guard); })) {
return ExtractCheckResult();
}
// Disallow extracting certain kinds of statements.
if (RangeInfo.Kind == RangeKind::SingleStatement) {
Stmt *S = RangeInfo.ContainedNodes[0].get<Stmt *>();
// These aren't independent statements.
if (isa<BraceStmt>(S) || isa<CaseStmt>(S))
return ExtractCheckResult();
}
// Single expressions: closures are rejected outright, literals only add a
// soft reason.
if (RangeInfo.Kind == RangeKind::SingleExpression) {
Expr *E = RangeInfo.ContainedNodes[0].get<Expr*>();
// Until implementing the performChange() part of extracting trailing
// closures, we disable them for now.
if (isa<AbstractClosureExpr>(E))
return ExtractCheckResult();
if (isa<LiteralExpr>(E))
AllReasons.emplace_back(CannotExtractReason::Literal);
}
// Only ranges inside initializers, subscripts, enum elements, functions,
// closures, or top-level code can be extracted.
switch (RangeInfo.RangeContext->getContextKind()) {
case swift::DeclContextKind::Initializer:
case swift::DeclContextKind::SubscriptDecl:
case swift::DeclContextKind::EnumElementDecl:
case swift::DeclContextKind::AbstractFunctionDecl:
case swift::DeclContextKind::AbstractClosureExpr:
case swift::DeclContextKind::TopLevelCodeDecl:
break;
case swift::DeclContextKind::SerializedLocal:
case swift::DeclContextKind::Package:
case swift::DeclContextKind::Module:
case swift::DeclContextKind::FileUnit:
case swift::DeclContextKind::GenericTypeDecl:
case swift::DeclContextKind::ExtensionDecl:
case swift::DeclContextKind::MacroDecl:
return ExtractCheckResult();
}
return ExtractCheckResult(AllReasons);
}
bool RefactoringActionExtractFunction::
isApplicable(const ResolvedRangeInfo &Info, DiagnosticEngine &Diag) {
switch (Info.Kind) {
case RangeKind::PartOfExpression:
case RangeKind::SingleDecl:
case RangeKind::MultiTypeMemberDecl:
case RangeKind::Invalid:
return false;
case RangeKind::SingleExpression:
case RangeKind::SingleStatement:
case RangeKind::MultiStatement: {
return checkExtractConditions(Info, Diag).
success({CannotExtractReason::VoidType});
}
}
llvm_unreachable("unhandled kind");
}
/// Returns \p Name if no declaration in \p AllVisibles is named exactly
/// \p Name. Otherwise returns \p Name followed by the smallest positive
/// integer N such that `Name + N` is not the name of any visible declaration.
/// The corrected name is interned in \p Ctx so the StringRef stays valid.
static StringRef correctNameInternal(ASTContext &Ctx, StringRef Name,
                                     ArrayRef<ValueDecl*> AllVisibles) {
  bool NameTaken = false;
  // Tails T such that some visible declaration is spelled Name + T.
  llvm::StringSet<> TakenTails;
  for (ValueDecl *Visible : AllVisibles) {
    StringRef Spelling = Visible->getBaseName().userFacingName();
    if (!Spelling.startswith(Name))
      continue;
    StringRef Tail = Spelling.drop_front(Name.size());
    if (Tail.empty()) {
      NameTaken = true;
    } else {
      TakenTails.insert(Tail);
    }
  }
  if (!NameTaken)
    return Name;

  unsigned Counter = 1;
  std::string Suffix = std::to_string(Counter);
  while (TakenTails.count(Suffix) != 0) {
    ++Counter;
    Suffix = std::to_string(Counter);
  }
  return Ctx.getIdentifier((llvm::Twine(Name) + Suffix).str()).str();
}
/// Returns a variant of \p Name that does not collide with any declaration
/// visible from \p DC (see correctNameInternal).
static StringRef correctNewDeclName(DeclContext *DC, StringRef Name) {
  llvm::SmallVector<ValueDecl*, 16> VisibleDecls;
  VectorDeclConsumer Gatherer(VisibleDecls);
  lookupVisibleDecls(Gatherer, DC, true);
  return correctNameInternal(DC->getASTContext(), Name, VisibleDecls);
}
/// Replaces every l-value type nested in \p Ty by an inout type over the
/// canonical r-value type, so that the result prints as valid source.
static Type sanitizeType(Type Ty) {
  auto LValueToInOut = [](Type Component) -> Type {
    if (!Component->is<LValueType>())
      return Component;
    Type Object = Component->getRValueType()->getCanonicalType();
    return InOutType::get(Object);
  };
  return Ty.transform(LValueToInOut);
}
/// Computes where a function extracted from code in \p DC should be inserted.
///
/// The anchor is the innermost declaration enclosing \p DC. For an accessor,
/// the anchor is instead the owning pattern binding (for a var) or the
/// subscript. The returned location is the anchor's start, moved back over
/// its attributes and its leading doc comment. On success \p InsertToContext
/// is set to the anchor's DeclContext. If \p DC has no enclosing declaration,
/// an invalid SourceLoc is returned and \p InsertToContext is left unchanged.
static SourceLoc
getNewFuncInsertLoc(DeclContext *DC, DeclContext*& InsertToContext) {
  Decl *Anchor = DC->getInnermostDeclarationDeclContext();
  if (!Anchor)
    return SourceLoc();

  // When extracting from a getter/setter, skip both the accessor and the
  // individual var decl: insert before the pattern binding (or subscript).
  if (auto *Accessor = dyn_cast<AccessorDecl>(Anchor)) {
    ValueDecl *Storage = Accessor->getStorage();
    DeclKind StorageKind = Storage->getKind();
    if (StorageKind == DeclKind::Var) {
      if (auto *Binding = cast<VarDecl>(Storage)->getParentPatternBinding())
        Anchor = Binding;
    } else if (StorageKind == DeclKind::Subscript) {
      Anchor = Storage;
    }
  }

  SourceLoc Earliest = Anchor->getStartLoc();
  assert(Earliest.isValid());

  // Move Earliest back to Candidate if Candidate is valid and comes first.
  auto MoveBackTo = [&Earliest](SourceLoc Candidate) {
    if (Candidate.isValid() &&
        Candidate.getOpaquePointerValue() < Earliest.getOpaquePointerValue())
      Earliest = Candidate;
  };

  // Insert before every attribute of the declaration...
  for (auto Attr : Anchor->getAttrs())
    MoveBackTo(Attr->getRangeWithAt().Start);

  // ...and before its doc comment.
  auto DocComment = Anchor->getRawComment();
  if (!DocComment.Comments.empty())
    MoveBackTo(DocComment.Comments.front().Range.getStart());

  InsertToContext = Anchor->getDeclContext();
  return Earliest;
}
/// Builds a throwaway compiler instance over \p SourceText (a snippet of
/// generated code) and returns the rename regions of the name \p Name that
/// starts at byte offset \p NameOffset. Positions are 1-based line/column
/// pairs relative to \p SourceText, so editors can offer them as placeholders.
///
/// If the name at \p NameOffset cannot be resolved, an assertion fires in
/// asserts builds; otherwise an empty vector is returned.
static std::vector<NoteRegion>
getNotableRegions(StringRef SourceText, unsigned NameOffset, StringRef Name,
                  bool IsFunctionLike = false, bool IsNonProtocolType = false) {
  auto InputBuffer =
      llvm::MemoryBuffer::getMemBufferCopy(SourceText, "<extract>");

  CompilerInvocation Invocation{};
  Invocation.getFrontendOptions().InputsAndOutputs.addInput(
      InputFile("<extract>", true, InputBuffer.get(), file_types::TY_Swift));
  Invocation.getFrontendOptions().ModuleName = "extract";
  Invocation.getLangOptions().DisablePoundIfEvaluation = true;

  auto Instance = std::make_unique<swift::CompilerInstance>();
  std::string InstanceSetupError;
  if (Instance->setup(Invocation, InstanceSetupError)) {
    // llvm_unreachable may compile to an optimizer hint in release builds,
    // which would make a setup failure undefined behavior. Abort explicitly.
    llvm::report_fatal_error(
        llvm::Twine("failed to set up compiler instance for <extract>: ") +
        InstanceSetupError);
  }

  unsigned BufferId = Instance->getPrimarySourceFile()->getBufferID().value();
  SourceManager &SM = Instance->getSourceMgr();
  SourceLoc NameLoc = SM.getLocForOffset(BufferId, NameOffset);
  auto LineAndCol = SM.getLineAndColumnInBuffer(NameLoc);

  UnresolvedLoc UnresolvedName{NameLoc, true};
  NameMatcher Matcher(*Instance->getPrimarySourceFile());
  auto Resolved = Matcher.resolve(llvm::makeArrayRef(UnresolvedName), None);
  assert(!Resolved.empty() && "Failed to resolve generated func name loc");
  // Without asserts, don't call back() on an empty vector.
  if (Resolved.empty())
    return {};

  RenameLoc RenameConfig = {
    LineAndCol.first, LineAndCol.second,
    NameUsage::Definition, /*OldName=*/Name, /*NewName=*/"",
    IsFunctionLike, IsNonProtocolType
  };
  RenameRangeDetailCollector Renamer(SM, Name);
  Renamer.addSyntacticRenameRanges(Resolved.back(), RenameConfig);

  // Convert each rename range into line/column form. Read the collector's
  // ranges in place instead of copying them.
  std::vector<NoteRegion> NoteRegions;
  NoteRegions.reserve(Renamer.Ranges.size());
  for (const RenameRangeDetail &Detail : Renamer.Ranges) {
    auto Start = SM.getLineAndColumnInBuffer(Detail.Range.getStart());
    auto End = SM.getLineAndColumnInBuffer(Detail.Range.getEnd());
    NoteRegions.push_back({Detail.RangeKind, Start.first, Start.second,
                           End.first, End.second, Detail.Index});
  }
  return NoteRegions;
}
/// Extract the selected range into a new function and replace the range with
/// a call to that function.
///
/// The new function is inserted before the innermost declaration enclosing
/// the range (see getNewFuncInsertLoc). It is named PreferredName, with a
/// numeric suffix appended if that collides with a visible declaration.
/// \returns true on failure (after emitting a diagnostic), false on success.
bool RefactoringActionExtractFunction::performChange() {
// Check if the new name is ok.
if (!Lexer::isIdentifier(PreferredName)) {
DiagEngine.diagnose(SourceLoc(), diag::invalid_name, PreferredName);
return true;
}
DeclContext *DC = RangeInfo.RangeContext;
// Set by getNewFuncInsertLoc to the context the new function will live in.
DeclContext *InsertToDC = nullptr;
SourceLoc InsertLoc = getNewFuncInsertLoc(DC, InsertToDC);
// Complain about no inserting position.
if (InsertLoc.isInvalid()) {
DiagEngine.diagnose(SourceLoc(), diag::no_insert_position);
return true;
}
// Correct the given name if collision happens.
PreferredName = correctNewDeclName(InsertToDC, PreferredName);
// Collect the parameters to pass down to the new function.
std::vector<ReferencedDecl> Parameters;
for (auto &RD: RangeInfo.ReferencedDecls) {
// If the referenced decl is declared elsewhere, no need to pass as parameter
if (RD.VD->getDeclContext() != DC)
continue;
// We don't need to pass down implicitly declared variables, e.g. error in
// a catch block.
if (RD.VD->isImplicit()) {
SourceLoc Loc = RD.VD->getStartLoc();
if (Loc.isValid() &&
SM.isBeforeInBuffer(RangeInfo.ContentRange.getStart(), Loc) &&
SM.isBeforeInBuffer(Loc, RangeInfo.ContentRange.getEnd()))
continue;
}
// If the referenced decl is declared inside the range, no need to pass
// as parameter.
if (RangeInfo.DeclaredDecls.end() !=
std::find_if(RangeInfo.DeclaredDecls.begin(), RangeInfo.DeclaredDecls.end(),
[RD](DeclaredDecl DD) { return RD.VD == DD.VD; }))
continue;
// We don't need to pass down self.
if (auto PD = dyn_cast<ParamDecl>(RD.VD)) {
if (PD->isSelfParameter()) {
continue;
}
}
// L-value types are turned into inout types (see sanitizeType).
Parameters.emplace_back(RD.VD, sanitizeType(RD.Ty));
}
// Buffer receives the new function's declaration text first, then the call
// that replaces the range. The offsets below delimit the two pieces.
SmallString<64> Buffer;
unsigned FuncBegin = Buffer.size();
// Offset of the function name within the declaration text.
unsigned FuncNameOffset;
{
llvm::raw_svector_ostream OS(Buffer);
if (!InsertToDC->isLocalContext()) {
// Default to be file private.
OS << tok::kw_fileprivate << " ";
}
// Inherit static if the containing function is.
if (DC->getContextKind() == DeclContextKind::AbstractFunctionDecl) {
if (auto FD = dyn_cast<FuncDecl>(static_cast<AbstractFunctionDecl*>(DC))) {
if (FD->isStatic()) {
OS << tok::kw_static << " ";
}
}
}
OS << tok::kw_func << " ";
FuncNameOffset = Buffer.size() - FuncBegin;
OS << PreferredName;
// Every parameter is unlabeled: `_ name: Type`.
OS << "(";
for (auto &RD : Parameters) {
OS << "_ " << RD.VD->getBaseName().userFacingName() << ": ";
RD.Ty->reconstituteSugar(/*Recursive*/true)->print(OS);
if (&RD != &Parameters.back())
OS << ", ";
}
OS << ")";
// Effects the range does not handle itself become effects of the new
// function.
if (RangeInfo.UnhandledEffects.contains(EffectKind::Async))
OS << " async";
if (RangeInfo.UnhandledEffects.contains(EffectKind::Throws))
OS << " " << tok::kw_throws;
bool InsertedReturnType = false;
if (auto Ty = RangeInfo.getType()) {
// If the type of the range is not void, specify the return type.
if (!Ty->isVoid()) {
OS << " " << tok::arrow << " ";
sanitizeType(Ty)->reconstituteSugar(/*Recursive*/true)->print(OS);
InsertedReturnType = true;
}
}
OS << " {\n";
// Add "return" if the extracted entity is an expression.
if (RangeInfo.Kind == RangeKind::SingleExpression && InsertedReturnType)
OS << tok::kw_return << " ";
OS << RangeInfo.ContentRange.str() << "\n}\n\n";
}
unsigned FuncEnd = Buffer.size();
// Now append the call expression that replaces the extracted range.
unsigned ReplaceBegin = Buffer.size();
// Offset of the function name within the call text.
unsigned CallNameOffset;
{
llvm::raw_svector_ostream OS(Buffer);
// If the range always returns (ExitState::Positive), the call must return.
if (RangeInfo.exit() == ExitState::Positive)
OS << tok::kw_return <<" ";
if (RangeInfo.UnhandledEffects.contains(EffectKind::Throws))
OS << tok::kw_try << " ";
if (RangeInfo.UnhandledEffects.contains(EffectKind::Async))
OS << "await ";
CallNameOffset = Buffer.size() - ReplaceBegin;
OS << PreferredName << "(";
for (auto &RD : Parameters) {
// Inout argument needs "&".
if (RD.Ty->is<InOutType>())
OS << "&";
OS << RD.VD->getBaseName().userFacingName();
if (&RD != &Parameters.back())
OS << ", ";
}
OS << ")";
}
unsigned ReplaceEnd = Buffer.size();
// Full name of the new function, e.g. `name(_:_:)`, used to locate the name
// regions in the generated snippets.
std::string ExtractedFuncName = PreferredName.str() + "(";
for (size_t i = 0; i < Parameters.size(); ++i) {
ExtractedFuncName += "_:";
}
ExtractedFuncName += ")";
StringRef DeclStr(Buffer.begin() + FuncBegin, FuncEnd - FuncBegin);
auto NotableFuncRegions = getNotableRegions(DeclStr, FuncNameOffset,
ExtractedFuncName,
/*IsFunctionLike=*/true);
StringRef CallStr(Buffer.begin() + ReplaceBegin, ReplaceEnd - ReplaceBegin);
auto NotableCallRegions = getNotableRegions(CallStr, CallNameOffset,
ExtractedFuncName,
/*IsFunctionLike=*/true);
// Insert the new function's declaration.
EditConsumer.accept(SM, InsertLoc, DeclStr, NotableFuncRegions);
// Replace the code to extract with the function call.
EditConsumer.accept(SM, RangeInfo.ContentRange, CallStr, NotableCallRegions);
return false;
}
/// Shared implementation of the "extract expression" and "extract repeated
/// expression" refactorings. Both introduce a `let` for the selected
/// expression and replace the selection with it. The repeated variant also
/// replaces similar expressions in the same brace statement.
class RefactoringActionExtractExprBase {
  SourceFile *TheFile;
  ResolvedRangeInfo RangeInfo;
  DiagnosticEngine &DiagEngine;
  /// Whether expressions similar to the selected one are replaced as well.
  const bool ExtractRepeated;
  StringRef PreferredName;
  SourceEditConsumer &EditConsumer;
  // NOTE: the constructor initializes SM from Ctx, so Ctx must remain
  // declared before SM.
  ASTContext &Ctx;
  SourceManager &SM;

public:
  RefactoringActionExtractExprBase(SourceFile *File, ResolvedRangeInfo Info,
                                   DiagnosticEngine &Diags, bool Repeated,
                                   StringRef Name,
                                   SourceEditConsumer &Consumer)
      : TheFile(File), RangeInfo(Info), DiagEngine(Diags),
        ExtractRepeated(Repeated), PreferredName(Name),
        EditConsumer(Consumer), Ctx(File->getASTContext()),
        SM(Ctx.SourceMgr) {}

  bool performChange();
};
/// Records, in walk order, every declaration referenced from an expression.
/// Two collectors compare equal when they saw the same sequence of
/// declarations. This ensures two textually identical expressions really
/// refer to the same entities.
struct ReferenceCollector: public SourceEntityWalker {
  SmallVector<ValueDecl*, 4> References;

  ReferenceCollector(Expr *E) { walk(E); }

  bool visitDeclReference(ValueDecl *D, CharSourceRange Range,
                          TypeDecl *CtorTyRef, ExtensionDecl *ExtTyRef,
                          Type T, ReferenceMetaData Data) override {
    References.push_back(D);
    return true;
  }

  bool operator==(const ReferenceCollector &Other) const {
    return References.size() == Other.References.size() &&
           std::equal(References.begin(), References.end(),
                      Other.References.begin());
  }
};
/// Walks a subtree and collects into Bucket every non-implicit expression
/// that matches SelectedExpr: same expression kind, same token text, and the
/// same sequence of referenced declarations.
///
/// NOTE: member declaration order matters. SelectedTokens is initialized via
/// getExprSlice(), which reads AllTokens, so AllTokens must stay declared
/// before SelectedTokens.
struct SimilarExprCollector: public SourceEntityWalker {
SourceManager &SM;
/// The expression under selection.
Expr *SelectedExpr;
/// Tokens of the searched region; expression token slices are taken from it.
ArrayRef<Token> AllTokens;
/// Output: matching expressions, in the order they were visited.
llvm::SetVector<Expr*> &Bucket;
/// The tokens included in the expression under selection.
ArrayRef<Token> SelectedTokens;
/// The referenced decls in the expression under selection.
ReferenceCollector SelectedReferences;
/// Whether both token sequences have equal length and pairwise equal text.
bool compareTokenContent(ArrayRef<Token> Left, ArrayRef<Token> Right) {
if (Left.size() != Right.size())
return false;
return std::equal(Left.begin(), Left.end(), Right.begin(),
[](const Token &L, const Token& R) {
return L.getText() == R.getText();
});
}
/// Find all tokens included by an expression.
ArrayRef<Token> getExprSlice(Expr *E) {
return slice_token_array(AllTokens, E->getStartLoc(), E->getEndLoc());
}
SimilarExprCollector(SourceManager &SM, Expr *SelectedExpr,
ArrayRef<Token> AllTokens,
llvm::SetVector<Expr*> &Bucket): SM(SM), SelectedExpr(SelectedExpr),
AllTokens(AllTokens), Bucket(Bucket),
SelectedTokens(getExprSlice(SelectedExpr)),
SelectedReferences(SelectedExpr){}
/// Always returns true, so the walk continues into every child expression.
bool walkToExprPre(Expr *E) override {
// We don't extract implicit expressions.
if (E->isImplicit())
return true;
if (E->getKind() != SelectedExpr->getKind())
return true;
// First check the underlying token arrays have the same content.
if (compareTokenContent(getExprSlice(E), SelectedTokens)) {
ReferenceCollector CurrentReferences(E);
// Next, check the referenced decls are same.
if (CurrentReferences == SelectedReferences)
Bucket.insert(E);
}
return true;
}
};
/// Introduce `let <name>[: Type] = <expr>` and replace the extracted
/// occurrence(s) of the expression with `<name>`.
///
/// The declaration is inserted before the element of the innermost enclosing
/// brace statement that contains the first extracted expression. The type
/// annotation is printed only when a single occurrence is extracted.
/// \returns true on failure (after emitting a diagnostic), false on success.
bool RefactoringActionExtractExprBase::performChange() {
// Check if the new name is ok.
if (!Lexer::isIdentifier(PreferredName)) {
DiagEngine.diagnose(SourceLoc(), diag::invalid_name, PreferredName);
return true;
}
// Find the enclosing brace statements.
ContextFinder Finder(*TheFile, RangeInfo.ContainedNodes.front(),
[](ASTNode N) { return N.isStmt(StmtKind::Brace); });
auto *SelectedExpr = RangeInfo.ContainedNodes[0].get<Expr*>();
Finder.resolve();
SourceLoc InsertLoc;
// Every ValueDecl declared inside the innermost brace statement; used to
// avoid choosing a colliding name.
llvm::SetVector<ValueDecl*> AllVisibleDecls;
struct DeclCollector: public SourceEntityWalker {
llvm::SetVector<ValueDecl*> &Bucket;
DeclCollector(llvm::SetVector<ValueDecl*> &Bucket): Bucket(Bucket) {}
bool walkToDeclPre(Decl *D, CharSourceRange Range) override {
if (auto *VD = dyn_cast<ValueDecl>(D))
Bucket.insert(VD);
return true;
}
} Collector(AllVisibleDecls);
llvm::SetVector<Expr*> AllExpressions;
if (!Finder.getContexts().empty()) {
// Get the innermost brace statement.
auto BS = static_cast<BraceStmt*>(Finder.getContexts().back().get<Stmt*>());
// Collect all value decls inside the brace statement.
Collector.walk(BS);
if (ExtractRepeated) {
// Collect all expressions we are going to extract.
SimilarExprCollector(SM, SelectedExpr,
slice_token_array(TheFile->getAllTokens(),
BS->getStartLoc(),
BS->getEndLoc()),
AllExpressions).walk(BS);
} else {
AllExpressions.insert(SelectedExpr);
}
assert(!AllExpressions.empty() && "at least one expression is extracted.");
// There is no early exit: if several elements contained the expression,
// the last one would determine InsertLoc.
for (auto Ele : BS->getElements()) {
// Find the element that encloses the first expression under extraction.
if (SM.rangeContains(Ele.getSourceRange(),
(*AllExpressions.begin())->getSourceRange())) {
// Insert before the enclosing element.
InsertLoc = Ele.getStartLoc();
}
}
}
// Complain about no inserting position.
if (InsertLoc.isInvalid()) {
DiagEngine.diagnose(SourceLoc(), diag::no_insert_position);
return true;
}
// Correct name if collision happens.
PreferredName = correctNameInternal(TheFile->getASTContext(), PreferredName,
AllVisibleDecls.getArrayRef());
// Print the type name of this expression.
SmallString<16> TyBuffer;
// We are not sure about the type of repeated expressions.
if (!ExtractRepeated) {
if (auto Ty = RangeInfo.getType()) {
llvm::raw_svector_ostream OS(TyBuffer);
OS << ": ";
Ty->getRValueType()->reconstituteSugar(true)->print(OS);
}
}
SmallString<64> DeclBuffer;
llvm::raw_svector_ostream OS(DeclBuffer);
// Byte offsets of the new name within DeclBuffer.
unsigned StartOffset, EndOffset;
OS << tok::kw_let << " ";
StartOffset = DeclBuffer.size();
OS << PreferredName;
EndOffset = DeclBuffer.size();
OS << TyBuffer.str() << " = " << RangeInfo.ContentRange.str() << "\n";
// Placeholder region for the name in the inserted declaration (line 1,
// 1-based columns).
NoteRegion DeclNameRegion{
RefactoringRangeKind::BaseName,
/*StartLine=*/1, /*StartColumn=*/StartOffset + 1,
/*EndLine=*/1, /*EndColumn=*/EndOffset + 1,
/*ArgIndex*/None
};
// Perform code change.
EditConsumer.accept(SM, InsertLoc, DeclBuffer.str(), {DeclNameRegion});
// Replace all occurrences of the extracted expression.
for (auto *E : AllExpressions) {
EditConsumer.accept(SM,
Lexer::getCharSourceRangeFromSourceRange(SM, E->getSourceRange()),
PreferredName,
{{
RefactoringRangeKind::BaseName,
/*StartLine=*/1, /*StartColumn=*/1, /*EndLine=*/1,
/*EndColumn=*/static_cast<unsigned int>(PreferredName.size() + 1),
/*ArgIndex*/None
}});
}
return false;
}
bool RefactoringActionExtractExpr::
isApplicable(const ResolvedRangeInfo &Info, DiagnosticEngine &Diag) {
switch (Info.Kind) {
case RangeKind::SingleExpression:
// We disallow extract literal expression for two reasons:
// (1) since we print the type for extracted expression, the type of a
// literal may print as "int2048" where it is not typically users' choice;
// (2) Extracting one literal provides little value for users.
return checkExtractConditions(Info, Diag).success();
case RangeKind::PartOfExpression:
case RangeKind::SingleDecl:
case RangeKind::MultiTypeMemberDecl:
case RangeKind::SingleStatement:
case RangeKind::MultiStatement:
case RangeKind::Invalid:
return false;
}
llvm_unreachable("unhandled kind");
}
bool RefactoringActionExtractExpr::performChange() {
return RefactoringActionExtractExprBase(TheFile, RangeInfo,
DiagEngine, false, PreferredName,
EditConsumer).performChange();
}
bool RefactoringActionExtractRepeatedExpr::
isApplicable(const ResolvedRangeInfo &Info, DiagnosticEngine &Diag) {
switch (Info.Kind) {
case RangeKind::SingleExpression:
return checkExtractConditions(Info, Diag).
success({CannotExtractReason::Literal});
case RangeKind::PartOfExpression:
case RangeKind::SingleDecl:
case RangeKind::MultiTypeMemberDecl:
case RangeKind::SingleStatement:
case RangeKind::MultiStatement:
case RangeKind::Invalid:
return false;
}
llvm_unreachable("unhandled kind");
}
bool RefactoringActionExtractRepeatedExpr::performChange() {
return RefactoringActionExtractExprBase(TheFile, RangeInfo,
DiagEngine, true, PreferredName,
EditConsumer).performChange();
}
bool RefactoringActionMoveMembersToExtension::isApplicable(
const ResolvedRangeInfo &Info, DiagnosticEngine &Diag) {
switch (Info.Kind) {
case RangeKind::SingleDecl:
case RangeKind::MultiTypeMemberDecl: {
DeclContext *DC = Info.RangeContext;
// The the common decl context is not a nomial type, we cannot create an
// extension for it
if (!DC || !DC->getInnermostDeclarationDeclContext() ||
!isa<NominalTypeDecl>(DC->getInnermostDeclarationDeclContext()))
return false;
// Members of types not declared at top file level cannot be extracted
// to an extension at top file level
if (DC->getParent()->getContextKind() != DeclContextKind::FileUnit)
return false;
// Check if contained nodes are all allowed decls.
for (auto Node : Info.ContainedNodes) {
Decl *D = Node.dyn_cast<Decl*>();
if (!D)
return false;
if (isa<AccessorDecl>(D) || isa<DestructorDecl>(D) ||
isa<EnumCaseDecl>(D) || isa<EnumElementDecl>(D))
return false;
}
// We should not move instance variables with storage into the extension
// because they are not allowed to be declared there
for (auto DD : Info.DeclaredDecls) {
if (auto ASD = dyn_cast<AbstractStorageDecl>(DD.VD)) {
// Only disallow storages in the common decl context, allow them in
// any subtypes
if (ASD->hasStorage() && ASD->getDeclContext() == DC) {
return false;
}
}
}
return true;
}
case RangeKind::SingleExpression:
case RangeKind::PartOfExpression:
case RangeKind::SingleStatement:
case RangeKind::MultiStatement:
case RangeKind::Invalid:
return false;
}
llvm_unreachable("unhandled kind");
}
bool RefactoringActionMoveMembersToExtension::performChange() {
DeclContext *DC = RangeInfo.RangeContext;
auto CommonTypeDecl =
dyn_cast<NominalTypeDecl>(DC->getInnermostDeclarationDeclContext());
assert(CommonTypeDecl && "Not applicable if common parent is no nomial type");
SmallString<64> Buffer;
llvm::raw_svector_ostream OS(Buffer);
OS << "\n\n";
OS << "extension " << CommonTypeDecl->getName() << " {\n";
OS << RangeInfo.ContentRange.str().trim();
OS << "\n}";
// Insert extension after the type declaration
EditConsumer.insertAfter(SM, CommonTypeDecl->getEndLoc(), Buffer);
EditConsumer.remove(SM, RangeInfo.ContentRange);
return false;
}
namespace {
/// Collects a declaration together with every declaration nested in it.
///
/// A SingleDecl range may not list everything it actually declares: a var
/// decl has accessors that aren't included. Walking the declared decls with
/// this walker also records those missing decls.
class FindAllSubDecls : public SourceEntityWalker {
  SmallPtrSetImpl<Decl *> &Found;

public:
  FindAllSubDecls(SmallPtrSetImpl<Decl *> &found) : Found(found) {}

  bool walkToDeclPre(Decl *D, CharSourceRange range) override {
    // Already recorded: don't walk into it again.
    bool IsNew = Found.insert(D).second;
    if (!IsNew)
      return false;
    // Parsed accessors of storage declarations are recorded explicitly.
    if (auto *Storage = dyn_cast<AbstractStorageDecl>(D))
      Storage->visitParsedAccessors(
          [&](AccessorDecl *Accessor) { Found.insert(Accessor); });
    return true;
  }
};
} // end anonymous namespace
bool RefactoringActionReplaceBodiesWithFatalError::isApplicable(
const ResolvedRangeInfo &Info, DiagnosticEngine &Diag) {
switch (Info.Kind) {
case RangeKind::SingleDecl:
case RangeKind::MultiTypeMemberDecl: {
SmallPtrSet<Decl *, 16> Found;
for (auto decl : Info.DeclaredDecls) {
FindAllSubDecls(Found).walk(decl.VD);
}
for (auto decl : Found) {
auto AFD = dyn_cast<AbstractFunctionDecl>(decl);
if (AFD && !AFD->isImplicit())
return true;
}
return false;
}
case RangeKind::SingleExpression:
case RangeKind::PartOfExpression:
case RangeKind::SingleStatement:
case RangeKind::MultiStatement:
case RangeKind::Invalid:
return false;
}
llvm_unreachable("unhandled kind");
}
/// Replaces the body of every explicit function-like declaration in the
/// selected range (including nested ones such as accessors) with
/// `{ fatalError() }`.
///
/// Declarations without a body are skipped. Their body range is invalid, so
/// there is nothing to replace.
bool RefactoringActionReplaceBodiesWithFatalError::performChange() {
  const StringRef replacement = "{\nfatalError()\n}";
  SmallPtrSet<Decl *, 16> Found;
  for (auto decl : RangeInfo.DeclaredDecls) {
    FindAllSubDecls(Found).walk(decl.VD);
  }
  for (auto decl : Found) {
    auto AFD = dyn_cast<AbstractFunctionDecl>(decl);
    if (!AFD || AFD->isImplicit())
      continue;

    auto range = AFD->getBodySourceRange();
    // No body (presumably e.g. a protocol requirement): turning an invalid
    // range into an edit would produce a bogus replacement, so skip it.
    if (range.isInvalid())
      continue;

    // If we're in replacement mode (i.e. have an edit consumer), we can
    // rewrite the function body.
    auto charRange = Lexer::getCharSourceRangeFromSourceRange(SM, range);
    EditConsumer.accept(SM, charRange, replacement);
  }
  return false;
}
/// Finds two nested 'if' statements that can be merged into one. The
/// statement starting at the cursor must be an 'if' without 'else', and its
/// body must consist solely of another 'if' without 'else'.
///
/// \returns {outer, inner} on success, {nullptr, nullptr} otherwise.
static std::pair<IfStmt *, IfStmt *>
findCollapseNestedIfTarget(ResolvedCursorInfoPtr CursorInfo) {
  auto StmtStartInfo = dyn_cast<ResolvedStmtStartCursorInfo>(CursorInfo);
  if (!StmtStartInfo)
    return {};

  auto *OuterIf = dyn_cast<IfStmt>(StmtStartInfo->getTrailingStmt());
  if (!OuterIf || OuterIf->getElseStmt())
    return {};

  auto *Body = dyn_cast_or_null<BraceStmt>(OuterIf->getThenStmt());
  if (!Body || Body->getNumElements() != 1)
    return {};

  auto *InnerIf =
      dyn_cast_or_null<IfStmt>(Body->getFirstElement().dyn_cast<Stmt *>());
  if (!InnerIf || InnerIf->getElseStmt())
    return {};

  return {OuterIf, InnerIf};
}
/// Applicable when the cursor is on an 'if' that directly wraps another 'if'
/// (see findCollapseNestedIfTarget).
bool RefactoringActionCollapseNestedIfStmt::isApplicable(
    ResolvedCursorInfoPtr CursorInfo, DiagnosticEngine &Diag) {
  IfStmt *Outer = findCollapseNestedIfTarget(CursorInfo).first;
  return Outer != nullptr;
}
bool RefactoringActionCollapseNestedIfStmt::performChange() {
auto Target = findCollapseNestedIfTarget(CursorInfo);
if (!Target.first)
return true;
auto OuterIf = Target.first;
auto InnerIf = Target.second;
EditorConsumerInsertStream OS(
EditConsumer, SM,
Lexer::getCharSourceRangeFromSourceRange(SM, OuterIf->getSourceRange()));
OS << tok::kw_if << " ";
// Emit conditions.
bool first = true;
for (auto &C : llvm::concat<StmtConditionElement>(OuterIf->getCond(),
InnerIf->getCond())) {
if (first)
first = false;
else
OS << ", ";
OS << Lexer::getCharSourceRangeFromSourceRange(SM, C.getSourceRange())
.str();
}
// Emit body.
OS << " ";
OS << Lexer::getCharSourceRangeFromSourceRange(
SM, InnerIf->getThenStmt()->getSourceRange())
.str();
return false;
}
static std::unique_ptr<llvm::SetVector<Expr*>>
findConcatenatedExpressions(const ResolvedRangeInfo &Info, ASTContext &Ctx) {
Expr *E = nullptr;
switch (Info.Kind) {
case RangeKind::SingleExpression:
E = Info.ContainedNodes[0].get<Expr*>();
break;
case RangeKind::PartOfExpression:
E = Info.CommonExprParent;
break;
default:
return nullptr;
}
assert(E);
struct StringInterpolationExprFinder: public SourceEntityWalker {
std::unique_ptr<llvm::SetVector<Expr *>> Bucket =
std::make_unique<llvm::SetVector<Expr *>>();
ASTContext &Ctx;
bool IsValidInterpolation = true;
StringInterpolationExprFinder(ASTContext &Ctx): Ctx(Ctx) {}
bool isConcatenationExpr(DeclRefExpr* Expr) {
if (!Expr)
return false;
auto *FD = dyn_cast<FuncDecl>(Expr->getDecl());
if (FD == nullptr || (FD != Ctx.getPlusFunctionOnString() &&
FD != Ctx.getPlusFunctionOnRangeReplaceableCollection())) {
return false;
}
return true;
}
bool walkToExprPre(Expr *E) override {
if (E->isImplicit())
return true;
// FIXME: we should have ErrorType instead of null.
if (E->getType().isNull())
return true;
//Only binary concatenation operators should exist in expression
if (E->getKind() == ExprKind::Binary) {
auto *BE = dyn_cast<BinaryExpr>(E);
auto *OperatorDeclRef = BE->getSemanticFn()->getMemberOperatorRef();
if (!(isConcatenationExpr(OperatorDeclRef) &&
E->getType()->isString())) {
IsValidInterpolation = false;
return false;
}
return true;
}
// Everything that evaluates to string should be gathered.
if (E->getType()->isString()) {
Bucket->insert(E);
return false;
}
if (auto *DR = dyn_cast<DeclRefExpr>(E)) {
// Checks whether all function references in expression are concatenations.
auto *FD = dyn_cast<FuncDecl>(DR->getDecl());
auto IsConcatenation = isConcatenationExpr(DR);
if (FD && IsConcatenation) {
return false;
}
}
// There was non-expected expression, it's not valid interpolation then.
IsValidInterpolation = false;
return false;
}
} Walker(Ctx);
Walker.walk(E);
// There should be two or more expressions to convert.
if (!Walker.IsValidInterpolation || Walker.Bucket->size() < 2)
return nullptr;
return std::move(Walker.Bucket);
}
static void interpolatedExpressionForm(Expr *E, SourceManager &SM,
llvm::raw_ostream &OS) {
if (auto *Literal = dyn_cast<StringLiteralExpr>(E)) {
OS << Literal->getValue();
return;
}
auto ExpStr = Lexer::getCharSourceRangeFromSourceRange(SM,
E->getSourceRange()).str().str();
if (isa<InterpolatedStringLiteralExpr>(E)) {
ExpStr.erase(0, 1);
ExpStr.pop_back();
OS << ExpStr;
return;
}
OS << "\\(" << ExpStr << ")";
}
bool RefactoringActionConvertStringsConcatenationToInterpolation::
isApplicable(const ResolvedRangeInfo &Info, DiagnosticEngine &Diag) {
auto RangeContext = Info.RangeContext;
if (RangeContext) {
auto &Ctx = Info.RangeContext->getASTContext();
return findConcatenatedExpressions(Info, Ctx) != nullptr;
}
return false;
}
bool RefactoringActionConvertStringsConcatenationToInterpolation::performChange() {
auto Expressions = findConcatenatedExpressions(RangeInfo, Ctx);
if (!Expressions)
return true;
EditorConsumerInsertStream OS(EditConsumer, SM, RangeInfo.ContentRange);
OS << "\"";
for (auto It = Expressions->begin(); It != Expressions->end(); ++It) {
interpolatedExpressionForm(*It, SM, OS);
}
OS << "\"";
return false;
}
/// Abstract description of a TernaryExpr that can be expanded into an IfStmt
/// assigning to (and possibly first declaring) a name.
class ExpandableTernaryExprInfo {
public:
  virtual ~ExpandableTernaryExprInfo() {}

  /// The ternary expression to expand, or null.
  virtual TernaryExpr *getTernary() = 0;
  /// Source range of the name being assigned or declared.
  virtual SourceRange getNameRange() = 0;
  /// Type to declare the name with; may be a null type.
  virtual Type getType() = 0;

  /// Whether the expansion first declares the name with its type. By
  /// default, this is the case exactly when a type is available.
  virtual bool shouldDeclareNameAndType() {
    return !getType().isNull();
  }

  /// Whether every piece needed for the expansion is present.
  virtual bool isValid() {
    if (!getTernary())
      return false;
    if (!getNameRange().isValid())
      return false;
    // Declaring the name requires a type to print.
    return !shouldDeclareNameAndType() || !getType().isNull();
  }

  /// The name's range, converted to a character range.
  CharSourceRange getNameCharRange(const SourceManager &SM) {
    return Lexer::getCharSourceRangeFromSourceRange(SM, getNameRange());
  }
};
/// Ternary expansion info for `dest = cond ? a : b`: an AssignExpr whose
/// source is the ternary.
class ExpandableAssignTernaryExprInfo: public ExpandableTernaryExprInfo {
public:
  ExpandableAssignTernaryExprInfo(AssignExpr *Assign): Assign(Assign) {}

  TernaryExpr *getTernary() override {
    return Assign ? dyn_cast_or_null<TernaryExpr>(Assign->getSrc()) : nullptr;
  }

  SourceRange getNameRange() override {
    Expr *Dest = Assign ? Assign->getDest() : nullptr;
    return Dest ? Dest->getSourceRange() : SourceRange();
  }

  /// No type is reported, so the name is never (re)declared.
  Type getType() override {
    return Type();
  }

private:
  AssignExpr *Assign = nullptr;
};
/// Ternary expansion info for a PatternBindingDecl with exactly one entry
/// whose initializer is the ternary, e.g. `let x: T = cond ? a : b`.
class ExpandableBindingTernaryExprInfo: public ExpandableTernaryExprInfo {
public:
  ExpandableBindingTernaryExprInfo(PatternBindingDecl *Binding):
    Binding(Binding) {}

  TernaryExpr *getTernary() override {
    if (!Binding || Binding->getNumPatternEntries() != 1)
      return nullptr;
    Expr *Init = Binding->getInit(0);
    return Init ? dyn_cast<TernaryExpr>(Init) : nullptr;
  }

  SourceRange getNameRange() override {
    Pattern *NamePattern = getNamePattern();
    return NamePattern ? NamePattern->getSourceRange() : SourceRange();
  }

  Type getType() override {
    Pattern *NamePattern = getNamePattern();
    return NamePattern ? NamePattern->getType() : Type();
  }

private:
  /// The single bound pattern with any type annotation stripped, or null.
  Pattern *getNamePattern() {
    if (!Binding || Binding->getNumPatternEntries() != 1)
      return nullptr;
    Pattern *Bound = Binding->getPattern(0);
    if (auto *Typed = dyn_cast_or_null<TypedPattern>(Bound))
      return Typed->getSubPattern();
    return Bound;
  }

  PatternBindingDecl *Binding = nullptr;
};
/// Describes the expandable ternary in the selected range. Returns null unless
/// the range is exactly one PatternBindingDecl or one AssignExpr.
std::unique_ptr<ExpandableTernaryExprInfo>
findExpandableTernaryExpression(const ResolvedRangeInfo &Info) {
  bool IsSingleDeclOrExpr = Info.Kind == RangeKind::SingleDecl ||
                            Info.Kind == RangeKind::SingleExpression;
  if (!IsSingleDeclOrExpr || Info.ContainedNodes.size() != 1)
    return nullptr;

  const auto &Node = Info.ContainedNodes.front();
  if (auto *D = Node.dyn_cast<Decl *>()) {
    if (auto *Binding = dyn_cast<PatternBindingDecl>(D))
      return std::make_unique<ExpandableBindingTernaryExprInfo>(Binding);
  }
  if (auto *E = Node.dyn_cast<Expr *>()) {
    if (auto *Assign = dyn_cast<AssignExpr>(E))
      return std::make_unique<ExpandableAssignTernaryExprInfo>(Assign);
  }
  return nullptr;
}
bool RefactoringActionExpandTernaryExpr::
isApplicable(const ResolvedRangeInfo &Info, DiagnosticEngine &Diag) {
auto Target = findExpandableTernaryExpression(Info);
return Target && Target->isValid();
}
bool RefactoringActionExpandTernaryExpr::performChange() {
auto Target = findExpandableTernaryExpression(RangeInfo);
if (!Target || !Target->isValid())
return true; //abort
auto NameCharRange = Target->getNameCharRange(SM);
auto IfRange = Target->getTernary()->getSourceRange();
auto IfCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, IfRange);
auto CondRange = Target->getTernary()->getCondExpr()->getSourceRange();
auto CondCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, CondRange);
auto ThenRange = Target->getTernary()->getThenExpr()->getSourceRange();
auto ThenCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, ThenRange);
auto ElseRange = Target->getTernary()->getElseExpr()->getSourceRange();
auto ElseCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, ElseRange);
SmallString<64> DeclBuffer;
llvm::raw_svector_ostream OS(DeclBuffer);
StringRef Space = " ";
StringRef NewLine = "\n";
if (Target->shouldDeclareNameAndType()) {
//Specifier will not be replaced; append after specifier
OS << NameCharRange.str() << tok::colon << Space;
OS << Target->getType() << NewLine;
}
OS << tok::kw_if << Space;
OS << CondCharRange.str() << Space;
OS << tok::l_brace << NewLine;
OS << NameCharRange.str() << Space;
OS << tok::equal << Space;
OS << ThenCharRange.str() << NewLine;
OS << tok::r_brace << Space;
OS << tok::kw_else << Space;
OS << tok::l_brace << NewLine;
OS << NameCharRange.str() << Space;
OS << tok::equal << Space;
OS << ElseCharRange.str() << NewLine;
OS << tok::r_brace;
//Start replacement with name range, skip the specifier
auto ReplaceRange(NameCharRange);
ReplaceRange.widen(IfCharRange);
EditConsumer.accept(SM, ReplaceRange, DeclBuffer.str());
return false; //don't abort
}
/// Applicable to a single `if` statement with exactly one condition, which is
/// a pattern binding (`if let` / `if case`), and whose then-branch is a brace.
bool RefactoringActionConvertIfLetExprToGuardExpr::
isApplicable(const ResolvedRangeInfo &Info, DiagnosticEngine &Diag) {
  switch (Info.Kind) {
  case RangeKind::SingleStatement:
  case RangeKind::MultiStatement:
    break;
  default:
    return false;
  }
  if (Info.ContainedNodes.size() != 1)
    return false;

  auto *Only = Info.ContainedNodes[0].dyn_cast<Stmt *>();
  auto *If = dyn_cast_or_null<IfStmt>(Only);
  if (!If)
    return false;

  auto Conditions = If->getCond();
  if (Conditions.size() != 1)
    return false;
  if (Conditions[0].getKind() != swift::StmtConditionElement::CK_PatternBinding)
    return false;
  return dyn_cast_or_null<BraceStmt>(If->getThenStmt()) != nullptr;
}
bool RefactoringActionConvertIfLetExprToGuardExpr::performChange() {
auto S = RangeInfo.ContainedNodes[0].dyn_cast<Stmt*>();
IfStmt *If = dyn_cast<IfStmt>(S);
auto CondList = If->getCond();
// Get if-let condition
SourceRange range = CondList[0].getSourceRange();
SourceManager &SM = RangeInfo.RangeContext->getASTContext().SourceMgr;
auto CondCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, range);
auto Body = dyn_cast_or_null<BraceStmt>(If->getThenStmt());
// Get if-let then body.
auto firstElement = Body->getFirstElement();
auto lastElement = Body->getLastElement();
SourceRange bodyRange = firstElement.getSourceRange();
bodyRange.widen(lastElement.getSourceRange());
auto BodyCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, bodyRange);
SmallString<64> DeclBuffer;
llvm::raw_svector_ostream OS(DeclBuffer);
StringRef Space = " ";
StringRef NewLine = "\n";
OS << tok::kw_guard << Space;
OS << CondCharRange.str().str() << Space;
OS << tok::kw_else << Space;
OS << tok::l_brace << NewLine;
// Get if-let else body.
if (auto *ElseBody = dyn_cast_or_null<BraceStmt>(If->getElseStmt())) {
auto firstElseElement = ElseBody->getFirstElement();
auto lastElseElement = ElseBody->getLastElement();
SourceRange elseBodyRange = firstElseElement.getSourceRange();
elseBodyRange.widen(lastElseElement.getSourceRange());
auto ElseBodyCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, elseBodyRange);
OS << ElseBodyCharRange.str().str() << NewLine;
}
OS << tok::kw_return << NewLine;
OS << tok::r_brace << NewLine;
OS << BodyCharRange.str().str();
// Replace if-let to guard
auto ReplaceRange = RangeInfo.ContentRange;
EditConsumer.accept(SM, ReplaceRange, DeclBuffer.str());
return false;
}
bool RefactoringActionConvertGuardExprToIfLetExpr::
isApplicable(const ResolvedRangeInfo &Info, DiagnosticEngine &Diag) {
if (Info.Kind != RangeKind::SingleStatement
&& Info.Kind != RangeKind::MultiStatement)
return false;
if (Info.ContainedNodes.empty())
return false;
GuardStmt *guardStmt = nullptr;
if (Info.ContainedNodes.size() > 0) {
if (auto S = Info.ContainedNodes[0].dyn_cast<Stmt*>()) {
guardStmt = dyn_cast<GuardStmt>(S);
}
}
if (!guardStmt)
return false;
auto CondList = guardStmt->getCond();
if (CondList.size() == 1) {
auto E = CondList[0];
auto P = E.getPatternOrNull();
if (P && E.getKind() == swift::StmtConditionElement::CK_PatternBinding)
return true;
}
return false;
}
bool RefactoringActionConvertGuardExprToIfLetExpr::performChange() {
// Get guard stmt
auto S = RangeInfo.ContainedNodes[0].dyn_cast<Stmt*>();
GuardStmt *Guard = dyn_cast<GuardStmt>(S);
// Get guard condition
auto CondList = Guard->getCond();
// Get guard condition source
SourceRange range = CondList[0].getSourceRange();
SourceManager &SM = RangeInfo.RangeContext->getASTContext().SourceMgr;
auto CondCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, range);
SmallString<64> DeclBuffer;
llvm::raw_svector_ostream OS(DeclBuffer);
StringRef Space = " ";
StringRef NewLine = "\n";
OS << tok::kw_if << Space;
OS << CondCharRange.str().str() << Space;
OS << tok::l_brace << NewLine;
// Get nodes after guard to place them at if-let body
if (RangeInfo.ContainedNodes.size() > 1) {
auto S = RangeInfo.ContainedNodes[1].getSourceRange();
S.widen(RangeInfo.ContainedNodes.back().getSourceRange());
auto BodyCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, S);
OS << BodyCharRange.str().str() << NewLine;
}
OS << tok::r_brace;
// Get guard body
auto Body = dyn_cast_or_null<BraceStmt>(Guard->getBody());
if (Body && Body->getNumElements() > 1) {
auto firstElement = Body->getFirstElement();
auto lastElement = Body->getLastElement();
SourceRange bodyRange = firstElement.getSourceRange();
bodyRange.widen(lastElement.getSourceRange());
auto BodyCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, bodyRange);
OS << Space << tok::kw_else << Space << tok::l_brace << NewLine;
OS << BodyCharRange.str().str() << NewLine;
OS << tok::r_brace;
}
// Replace guard to if-let
auto ReplaceRange = RangeInfo.ContentRange;
EditConsumer.accept(SM, ReplaceRange, DeclBuffer.str());
return false;
}
/// Decides whether the selected `if` / `else if` chain can become a `switch`.
/// Requires a single IfStmt (SingleStatement range). Across every condition
/// element of every `if` in the chain, all referenced vars/params must have
/// one name, and referenced functions must be among those accepted by
/// ConditionalChecker::checkName(FuncDecl *). Availability conditions
/// disqualify the chain.
bool RefactoringActionConvertToSwitchStmt::
isApplicable(const ResolvedRangeInfo &Info, DiagnosticEngine &Diag) {
/// Walks condition expressions in post-order, recording whether referenced
/// variables share one name and whether referenced functions are allowed.
/// The walk stops as soon as allCheckPassed() is false.
/// NOTE(review): with ConditionUseOnlyAllowedFunctions starting false, this
/// presumably relies on the operator's DeclRef being visited before the
/// operand DeclRefs. Confirm against ASTWalker's BinaryExpr traversal order.
class ConditionalChecker : public ASTWalker {
public:
bool ParamsUseSameVars = true;
// Updated by each FuncDecl reference, or forced to true by check() for a
// pattern-binding condition. It is never reset between condition elements.
bool ConditionUseOnlyAllowedFunctions = false;
// Name of the first var/param seen; later references must match it.
StringRef ExpectName;
MacroWalking getMacroWalkingBehavior() const override {
return MacroWalking::Arguments;
}
PostWalkResult<Expr *> walkToExprPost(Expr *E) override {
if (E->getKind() != ExprKind::DeclRef)
return Action::Continue(E);
auto D = dyn_cast<DeclRefExpr>(E)->getDecl();
if (D->getKind() == DeclKind::Var || D->getKind() == DeclKind::Param)
ParamsUseSameVars = checkName(dyn_cast<VarDecl>(D));
if (D->getKind() == DeclKind::Func)
ConditionUseOnlyAllowedFunctions = checkName(dyn_cast<FuncDecl>(D));
if (allCheckPassed())
return Action::Continue(E);
return Action::Stop();
}
bool allCheckPassed() {
return ParamsUseSameVars && ConditionUseOnlyAllowedFunctions;
}
private:
// Records the first variable name seen, then checks later names against it.
bool checkName(VarDecl *VD) {
auto Name = VD->getName().str();
if (ExpectName.empty())
ExpectName = Name;
return Name == ExpectName;
}
// Operators and derived-equality functions that can be expressed as switch
// case patterns; "||" and "..." are accepted here too.
bool checkName(FuncDecl *FD) {
const auto Name = FD->getBaseIdentifier().str();
return Name == "~="
|| Name == "=="
|| Name == "__derived_enum_equals"
|| Name == "__derived_struct_equals"
|| Name == "||"
|| Name == "...";
}
};
/// Finds the IfStmt in the range and runs a single shared
/// ConditionalChecker over the conditions of each `if` in the else-if chain.
class SwitchConvertable {
public:
SwitchConvertable(const ResolvedRangeInfo &Info) : Info(Info) { }
bool isApplicable() {
if (Info.Kind != RangeKind::SingleStatement)
return false;
if (!findIfStmt())
return false;
return checkEachCondition();
}
private:
const ResolvedRangeInfo &Info;
// The `if` currently being checked; advanced along the else-if chain.
IfStmt *If = nullptr;
ConditionalChecker checker;
bool findIfStmt() {
if (Info.ContainedNodes.size() != 1)
return false;
if (auto S = Info.ContainedNodes.front().dyn_cast<Stmt*>())
If = dyn_cast<IfStmt>(S);
return If != nullptr;
}
// The checker is reset once here, so its state (ExpectName, the flags)
// carries across the whole chain. That is what forces every `if` to test
// the same variable.
bool checkEachCondition() {
checker = ConditionalChecker();
do {
if (!checkEachElement())
return false;
} while ((If = dyn_cast_or_null<IfStmt>(If->getElseStmt())));
return true;
}
bool checkEachElement() {
bool result = true;
auto ConditionalList = If->getCond();
for (auto Element : ConditionalList) {
result &= check(Element);
}
return result;
}
// NOTE(review): once a pattern-binding element sets
// ConditionUseOnlyAllowedFunctions, it stays true for later elements that
// reference no function at all. Confirm this is intended.
bool check(StmtConditionElement ConditionElement) {
if (ConditionElement.getKind() == StmtConditionElement::CK_Availability)
return false;
if (ConditionElement.getKind() == StmtConditionElement::CK_PatternBinding)
checker.ConditionUseOnlyAllowedFunctions = true;
ConditionElement.walk(checker);
return checker.allCheckPassed();
}
};
return SwitchConvertable(Info).isApplicable();
}
/// Rewrites the selected `if` / `else if` / `else` chain as
/// `[label: ]switch <var> {\ncase <pattern>:\n<body>\n...default:\n<body>\n}`
/// and replaces RangeInfo.ContentRange with it.
bool RefactoringActionConvertToSwitchStmt::performChange() {
/// Records the name of the first var/param DeclRef reached in post-order,
/// then stops the walk. Used as the switch's control expression.
class VarNameFinder : public ASTWalker {
public:
std::string VarName;
MacroWalking getMacroWalkingBehavior() const override {
return MacroWalking::Arguments;
}
PostWalkResult<Expr *> walkToExprPost(Expr *E) override {
if (E->getKind() != ExprKind::DeclRef)
return Action::Continue(E);
auto D = dyn_cast<DeclRefExpr>(E)->getDecl();
if (D->getKind() != DeclKind::Var && D->getKind() != DeclKind::Param)
return Action::Continue(E);
VarName = dyn_cast<VarDecl>(D)->getName().str().str();
return Action::Stop();
}
};
/// Builds the text of a `case` pattern list from one condition element.
/// For comparisons (~=, ==, derived equals) it appends the operand that is
/// not a plain DeclRef (RHS, or LHS if RHS is a DeclRef), comma-separated.
/// For a pattern it appends the pattern's source text, plus "?" for an
/// OptionalSome pattern, and stops.
class ConditionalPatternFinder : public ASTWalker {
public:
ConditionalPatternFinder(SourceManager &SM) : SM(SM) {}
SmallString<64> ConditionalPattern = SmallString<64>();
MacroWalking getMacroWalkingBehavior() const override {
return MacroWalking::Arguments;
}
PostWalkResult<Expr *> walkToExprPost(Expr *E) override {
auto *BE = dyn_cast<BinaryExpr>(E);
if (!BE)
return Action::Continue(E);
if (isFunctionNameAllowed(BE))
appendPattern(BE->getLHS(), BE->getRHS());
return Action::Continue(E);
}
PreWalkResult<Pattern *> walkToPatternPre(Pattern *P) override {
ConditionalPattern.append(Lexer::getCharSourceRangeFromSourceRange(SM, P->getSourceRange()).str());
if (P->getKind() == PatternKind::OptionalSome)
ConditionalPattern.append("?");
return Action::Stop();
}
private:
SourceManager &SM;
// True if the binary operator (looking through a DotSyntaxCallExpr) is
// ~=, ==, __derived_enum_equals or __derived_struct_equals. Other
// operators such as "||" are walked through, so their operands can
// contribute patterns.
bool isFunctionNameAllowed(BinaryExpr *E) {
Expr *Fn = E->getFn();
if (auto DotSyntaxCall = dyn_cast_or_null<DotSyntaxCallExpr>(Fn)) {
Fn = DotSyntaxCall->getFn();
}
DeclRefExpr *DeclRef = dyn_cast_or_null<DeclRefExpr>(Fn);
if (!DeclRef) {
return false;
}
auto FunctionDeclaration = dyn_cast_or_null<FuncDecl>(DeclRef->getDecl());
if (!FunctionDeclaration) {
return false;
}
auto &ASTCtx = FunctionDeclaration->getASTContext();
const auto FunctionName = FunctionDeclaration->getBaseIdentifier();
return FunctionName == ASTCtx.Id_MatchOperator ||
FunctionName == ASTCtx.Id_EqualsOperator ||
FunctionName == ASTCtx.Id_derived_enum_equals ||
FunctionName == ASTCtx.Id_derived_struct_equals;
}
void appendPattern(Expr *LHS, Expr *RHS) {
auto *PatternArgument = RHS;
if (PatternArgument->getKind() == ExprKind::DeclRef)
PatternArgument = LHS;
if (ConditionalPattern.size() > 0)
ConditionalPattern.append(", ");
ConditionalPattern.append(Lexer::getCharSourceRangeFromSourceRange(SM, PatternArgument->getSourceRange()).str());
}
};
/// Collects the pieces of the switch (label, control expression, case
/// patterns with bodies, default body) and prints the switch statement.
class ConverterToSwitch {
public:
ConverterToSwitch(const ResolvedRangeInfo &Info,
SourceManager &SM) : Info(Info), SM(SM) { }
void performConvert(SmallString<64> &Out) {
If = findIf();
OptionalLabel = If->getLabelInfo().Name.str().str();
ControlExpression = findControlExpression();
findPatternsAndBodies(PatternsAndBodies);
DefaultStatements = findDefaultStatements();
makeSwitchStatement(Out);
}
private:
const ResolvedRangeInfo &Info;
SourceManager &SM;
// Current `if` in the chain; null after findPatternsAndBodies() finishes.
IfStmt *If;
// Last `if` of the chain; its else-branch (if any) becomes `default`.
IfStmt *PreviousIf;
std::string OptionalLabel;
std::string ControlExpression;
SmallVector<std::pair<std::string, std::string>, 16> PatternsAndBodies;
std::string DefaultStatements;
IfStmt *findIf() {
auto S = Info.ContainedNodes[0].dyn_cast<Stmt*>();
return dyn_cast<IfStmt>(S);
}
// The control expression is the first variable name found in the first
// condition element of the first `if`.
std::string findControlExpression() {
auto ConditionElement = If->getCond().front();
auto Finder = VarNameFinder();
ConditionElement.walk(Finder);
return Finder.VarName;
}
// One (pattern, body) pair per `if` in the else-if chain; the loop always
// runs at least once, so PreviousIf is set before findDefaultStatements().
void findPatternsAndBodies(SmallVectorImpl<std::pair<std::string, std::string>> &Out) {
do {
auto pattern = findPattern();
auto body = findBodyStatements();
Out.push_back(std::make_pair(pattern, body));
PreviousIf = If;
} while ((If = dyn_cast_or_null<IfStmt>(If->getElseStmt())));
}
// NOTE(review): only the first condition element of each `if` is used;
// any further comma-separated conditions do not appear in the output.
std::string findPattern() {
auto ConditionElement = If->getCond().front();
auto Finder = ConditionalPatternFinder(SM);
ConditionElement.walk(Finder);
return Finder.ConditionalPattern.str().str();
}
std::string findBodyStatements() {
return findBodyWithoutBraces(If->getThenStmt());
}
// Body of the final `else`, or `break` if the chain has no final else.
std::string findDefaultStatements() {
auto ElseBody = dyn_cast_or_null<BraceStmt>(PreviousIf->getElseStmt());
if (!ElseBody)
return getTokenText(tok::kw_break).str();
return findBodyWithoutBraces(ElseBody);
}
// Source text of a body without its braces; `break` for an empty brace,
// so the generated case is never empty.
std::string findBodyWithoutBraces(Stmt *body) {
auto BS = dyn_cast<BraceStmt>(body);
if (!BS)
return Lexer::getCharSourceRangeFromSourceRange(SM, body->getSourceRange()).str().str();
if (BS->getElements().empty())
return getTokenText(tok::kw_break).str();
SourceRange BodyRange = BS->getElements().front().getSourceRange();
BodyRange.widen(BS->getElements().back().getSourceRange());
return Lexer::getCharSourceRangeFromSourceRange(SM, BodyRange).str().str();
}
void makeSwitchStatement(SmallString<64> &Out) {
StringRef Space = " ";
StringRef NewLine = "\n";
llvm::raw_svector_ostream OS(Out);
if (OptionalLabel.size() > 0)
OS << OptionalLabel << ":" << Space;
OS << tok::kw_switch << Space << ControlExpression << Space << tok::l_brace << NewLine;
for (auto &pair : PatternsAndBodies) {
OS << tok::kw_case << Space << pair.first << tok::colon << NewLine;
OS << pair.second << NewLine;
}
OS << tok::kw_default << tok::colon << NewLine;
OS << DefaultStatements << NewLine;
OS << tok::r_brace;
}
};
SmallString<64> result;
ConverterToSwitch(RangeInfo, SM).performConvert(result);
EditConsumer.accept(SM, RangeInfo.ContentRange, result.str());
return false;
}
/// Struct containing info about an IfStmt that can be converted into a
/// TernaryExpr.
struct ConvertToTernaryExprInfo {
ConvertToTernaryExprInfo() {}
Expr *AssignDest() {
if (!Then || !Then->getDest() || !Else || !Else->getDest())
return nullptr;
auto ThenDest = Then->getDest();
auto ElseDest = Else->getDest();
if (ThenDest->getKind() != ElseDest->getKind())
return nullptr;
switch (ThenDest->getKind()) {
case ExprKind::DeclRef: {
auto ThenRef = dyn_cast<DeclRefExpr>(Then->getDest());
auto ElseRef = dyn_cast<DeclRefExpr>(Else->getDest());
if (!ThenRef || !ThenRef->getDecl() || !ElseRef || !ElseRef->getDecl())
return nullptr;
const auto ThenName = ThenRef->getDecl()->getName();
const auto ElseName = ElseRef->getDecl()->getName();
if (ThenName.compare(ElseName) != 0)
return nullptr;
return Then->getDest();
}
case ExprKind::Tuple: {
auto ThenTuple = dyn_cast<TupleExpr>(Then->getDest());
auto ElseTuple = dyn_cast<TupleExpr>(Else->getDest());
if (!ThenTuple || !ElseTuple)
return nullptr;
auto ThenNames = ThenTuple->getElementNames();
auto ElseNames = ElseTuple->getElementNames();
if (!ThenNames.equals(ElseNames))
return nullptr;
return ThenTuple;
}
default:
return nullptr;
}
}
Expr *ThenSrc() {
if (!Then)
return nullptr;
return Then->getSrc();
}
Expr *ElseSrc() {
if (!Else)
return nullptr;
return Else->getSrc();
}
bool isValid() {
if (!Cond || !AssignDest() || !ThenSrc() || !ElseSrc()
|| !IfRange.isValid())
return false;
return true;
}
PatternBindingDecl *Binding = nullptr; //optional
Expr *Cond = nullptr; //required
AssignExpr *Then = nullptr; //required
AssignExpr *Else = nullptr; //required
SourceRange IfRange;
};
/// Finds an `if` statement, optionally preceded by a PatternBindingDecl, with a
/// single boolean condition whose branches each contain exactly one
/// assignment. Returns a default (invalid) info when the shape does not match.
ConvertToTernaryExprInfo
findConvertToTernaryExpression(const ResolvedRangeInfo &Info) {
  auto notFound = ConvertToTernaryExprInfo();
  if (Info.Kind != RangeKind::SingleStatement
      && Info.Kind != RangeKind::MultiStatement)
    return notFound;
  if (Info.ContainedNodes.empty())
    return notFound;

  // Returns the AssignExpr only if \p S is a brace holding exactly one element
  // and that element is an assignment. Any other statement in the branch
  // would be lost by the conversion, so it disqualifies the branch. An
  // `else if` (not a BraceStmt) is disqualified for the same reason.
  auto getSoleAssign = [](Stmt *S) -> AssignExpr * {
    auto *Brace = dyn_cast_or_null<BraceStmt>(S);
    if (!Brace || Brace->getNumElements() != 1)
      return nullptr;
    return dyn_cast_or_null<AssignExpr>(
        Brace->getFirstElement().dyn_cast<Expr *>());
  };

  ConvertToTernaryExprInfo Target;
  IfStmt *If = nullptr;
  if (Info.ContainedNodes.size() == 1) {
    if (auto S = Info.ContainedNodes[0].dyn_cast<Stmt*>())
      If = dyn_cast<IfStmt>(S);
  }
  if (Info.ContainedNodes.size() == 2) {
    // NOTE(review): the binding is not checked to be uninitialized or to
    // declare the assigned variable; performChange() replaces its text.
    if (auto D = Info.ContainedNodes[0].dyn_cast<Decl*>())
      Target.Binding = dyn_cast<PatternBindingDecl>(D);
    if (auto S = Info.ContainedNodes[1].dyn_cast<Stmt*>())
      If = dyn_cast<IfStmt>(S);
  }
  if (!If)
    return notFound;

  auto CondList = If->getCond();
  if (CondList.size() != 1)
    return notFound;

  Target.Cond = CondList[0].getBooleanOrNull();
  Target.IfRange = If->getSourceRange();
  Target.Then = getSoleAssign(If->getThenStmt());
  Target.Else = getSoleAssign(If->getElseStmt());
  return Target;
}
bool RefactoringActionConvertToTernaryExpr::
isApplicable(const ResolvedRangeInfo &Info, DiagnosticEngine &Diag) {
return findConvertToTernaryExpression(Info).isValid();
}
bool RefactoringActionConvertToTernaryExpr::performChange() {
auto Target = findConvertToTernaryExpression(RangeInfo);
if (!Target.isValid())
return true; //abort
SmallString<64> DeclBuffer;
llvm::raw_svector_ostream OS(DeclBuffer);
StringRef Space = " ";
auto IfRange = Target.IfRange;
auto ReplaceRange = Lexer::getCharSourceRangeFromSourceRange(SM, IfRange);
auto CondRange = Target.Cond->getSourceRange();
auto CondCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, CondRange);
auto ThenRange = Target.ThenSrc()->getSourceRange();
auto ThenCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, ThenRange);
auto ElseRange = Target.ElseSrc()->getSourceRange();
auto ElseCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, ElseRange);
CharSourceRange DestCharRange;
if (Target.Binding) {
auto DestRange = Target.Binding->getSourceRange();
DestCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, DestRange);
ReplaceRange.widen(DestCharRange);
} else {
auto DestRange = Target.AssignDest()->getSourceRange();
DestCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, DestRange);
}
OS << DestCharRange.str() << Space << tok::equal << Space;
OS << CondCharRange.str() << Space << tok::question_postfix << Space;
OS << ThenCharRange.str() << Space << tok::colon << Space;
OS << ElseCharRange.str();
EditConsumer.accept(SM, ReplaceRange, DeclBuffer.str());
return false; //don't abort
}
/// The helper class analyzes a given nominal decl or an extension decl to
/// decide whether stubs are required to filled in and the context in which
/// these stubs should be filled.
class FillProtocolStubContext {
std::vector<ValueDecl*>
getUnsatisfiedRequirements(const IterableDeclContext *IDC);
/// Context in which the content should be filled; this could be either a
/// nominal type declaraion or an extension declaration.
DeclContext *DC;
/// The type that adopts the required protocol stubs. For nominal type decl, this
/// should be the declared type itself; for extension decl, this should be the
/// extended type at hand.
Type Adopter;
/// The start location of the decl, either nominal type or extension, for the
/// printer to figure out the right indentation.
SourceLoc StartLoc;
/// The location of '{' for the decl, thus we know where to insert the filling
/// stubs.
SourceLoc BraceStartLoc;
/// The value decls that should be satisfied; this could be either function
/// decls, property decls, or required type alias.
std::vector<ValueDecl*> FillingContents;
public:
FillProtocolStubContext(ExtensionDecl *ED) : DC(ED),
Adopter(ED->getExtendedType()), StartLoc(ED->getStartLoc()),
BraceStartLoc(ED->getBraces().Start),
FillingContents(getUnsatisfiedRequirements(ED)) {};
FillProtocolStubContext(NominalTypeDecl *ND) : DC(ND),
Adopter(ND->getDeclaredType()), StartLoc(ND->getStartLoc()),
BraceStartLoc(ND->getBraces().Start),
FillingContents(getUnsatisfiedRequirements(ND)) {};
FillProtocolStubContext() : DC(nullptr), Adopter(), FillingContents({}) {};
static FillProtocolStubContext
getContextFromCursorInfo(ResolvedCursorInfoPtr Tok);
ArrayRef<ValueDecl*> getFillingContents() const {
return llvm::makeArrayRef(FillingContents);
}
DeclContext *getFillingContext() const { return DC; }
bool canProceed() const {
return StartLoc.isValid() && BraceStartLoc.isValid() &&
!getFillingContents().empty();
}
Type getAdopter() const { return Adopter; }
SourceLoc getContextStartLoc() const { return StartLoc; }
SourceLoc getBraceStartLoc() const { return BraceStartLoc; }
};
/// Builds the filling context from a cursor resolved on a nominal type's
/// declared name or on an extension's type reference. Any other cursor yields
/// an empty context whose canProceed() is false.
FillProtocolStubContext FillProtocolStubContext::getContextFromCursorInfo(
    ResolvedCursorInfoPtr CursorInfo) {
  if (!CursorInfo->isValid())
    return FillProtocolStubContext();
  auto ValueRefInfo = dyn_cast<ResolvedValueRefCursorInfo>(CursorInfo);
  if (!ValueRefInfo)
    return FillProtocolStubContext();

  if (ValueRefInfo->isRef()) {
    // The type ref is on a declared extension, e.g. "extension A {}".
    if (auto *ED = ValueRefInfo->getExtTyRef())
      return FillProtocolStubContext(ED);
    return FillProtocolStubContext();
  }
  // The type name is on the declared nominal, e.g. "class A {}".
  if (auto *ND = dyn_cast<NominalTypeDecl>(ValueRefInfo->getValueD()))
    return FillProtocolStubContext(ND);
  return FillProtocolStubContext();
}
std::vector<ValueDecl*> FillProtocolStubContext::
getUnsatisfiedRequirements(const IterableDeclContext *IDC) {
// The results to return.
std::vector<ValueDecl*> NonWitnessedReqs;
// For each conformance of the extended nominal.
for(ProtocolConformance *Con : IDC->getLocalConformances()) {
// Collect non-witnessed requirements.
Con->forEachNonWitnessedRequirement(
[&](ValueDecl *VD) { NonWitnessedReqs.push_back(VD); });
}
return NonWitnessedReqs;
}
bool RefactoringActionFillProtocolStub::isApplicable(ResolvedCursorInfoPtr Tok,
                                                      DiagnosticEngine &Diag) {
  // Applicable when the cursor identifies a decl with missing stubs.
  auto Context = FillProtocolStubContext::getContextFromCursorInfo(Tok);
  return Context.canProceed();
}
bool RefactoringActionFillProtocolStub::performChange() {
// Get the filling protocol context from the input token.
FillProtocolStubContext Context = FillProtocolStubContext::
getContextFromCursorInfo(CursorInfo);
assert(Context.canProceed());
assert(!Context.getFillingContents().empty());
assert(Context.getFillingContext());
SmallString<128> Text;
{
llvm::raw_svector_ostream SS(Text);
Type Adopter = Context.getAdopter();
SourceLoc Loc = Context.getContextStartLoc();
auto Contents = Context.getFillingContents();
// For each unsatisfied requirement, print the stub to the buffer.
std::for_each(Contents.begin(), Contents.end(), [&](ValueDecl *VD) {
printRequirementStub(VD, Context.getFillingContext(), Adopter, Loc, SS);
});
}
// Insert all stubs after '{' in the extension/nominal type decl.
EditConsumer.insertAfter(SM, Context.getBraceStartLoc(), Text);
return false;
}
/// Collects into \p Kinds the refactorings available at \p Line / \p Column
/// of \p SF. Does nothing if the file has no buffer or the position does not
/// map to a valid location.
static void collectAvailableRefactoringsAtCursor(
    SourceFile *SF, unsigned Line, unsigned Column,
    SmallVectorImpl<RefactoringKind> &Kinds,
    ArrayRef<DiagnosticConsumer *> DiagConsumers) {
  // Prepare the tool box.
  ASTContext &Ctx = SF->getASTContext();
  SourceManager &SM = Ctx.SourceMgr;
  DiagnosticEngine DiagEngine(SM);
  for (DiagnosticConsumer *Con : DiagConsumers)
    DiagEngine.addConsumer(*Con);
  // NOTE(review): DiagEngine is not passed on below, so these consumers
  // currently receive nothing.

  // A source file without a buffer has no positions to resolve. Bail out
  // rather than dereferencing an empty optional.
  auto BufferID = SF->getBufferID();
  if (!BufferID)
    return;
  SourceLoc Loc = SM.getLocForLineCol(*BufferID, Line, Column);
  if (Loc.isInvalid())
    return;

  ResolvedCursorInfoPtr Tok =
      evaluateOrDefault(Ctx.evaluator,
                        CursorInfoRequest{CursorInfoOwner(
                            SF, Lexer::getLocForStartOfToken(SM, Loc))},
                        new ResolvedCursorInfo());
  collectAvailableRefactorings(Tok, Kinds, /*Exclude rename*/ false);
}
/// The enum declaration named by the switch subject's type, or null if the
/// subject has no type or its type is not an enum.
static EnumDecl* getEnumDeclFromSwitchStmt(SwitchStmt *SwitchS) {
  Type SubjectTy = SwitchS->getSubjectExpr()->getType();
  if (!SubjectTy)
    return nullptr;
  // FIXME: Support more complex subject like '(Enum1, Enum2)'.
  return dyn_cast_or_null<EnumDecl>(SubjectTy->getAnyNominal());
}
/// Prints a `case` for every enum element not already matched by an
/// EnumElementPattern in \p SwitchS. If nothing is missing, diagnoses
/// no_remaining_cases at \p ExpandedStmtLoc and returns true (failure).
static bool performCasesExpansionInSwitchStmt(SwitchStmt *SwitchS,
                                              DiagnosticEngine &DiagEngine,
                                              SourceLoc ExpandedStmtLoc,
                                              EditorConsumerInsertStream &OS) {
  auto *TheEnum = getEnumDeclFromSwitchStmt(SwitchS);
  assert(TheEnum);

  // Start from every element and strike out those already handled.
  llvm::DenseSet<EnumElementDecl*> UnhandledElements;
  TheEnum->getAllElements(UnhandledElements);
  for (CaseStmt *Case : SwitchS->getCases()) {
    if (Case->isDefault())
      continue;
    for (auto Item : Case->getCaseLabelItems()) {
      auto *EEP = dyn_cast_or_null<EnumElementPattern>(Item.getPattern());
      if (EEP)
        UnhandledElements.erase(EEP->getElementDecl());
    }
  }

  if (UnhandledElements.empty()) {
    DiagEngine.diagnose(ExpandedStmtLoc, diag::no_remaining_cases);
    return true;
  }
  printEnumElementsAsCases(UnhandledElements, OS);
  return false;
}
// Finds the SwitchStmt that contains the given CaseStmt. Diagnoses
// no_parent_switch and returns null if there is none, or if the switch found
// does not actually list CS among its cases.
static SwitchStmt* findEnclosingSwitchStmt(CaseStmt *CS,
                                           SourceFile *SF,
                                           DiagnosticEngine &DiagEngine) {
  auto IsSwitch = [](ASTNode Node) {
    auto *S = Node.dyn_cast<Stmt*>();
    return S && S->getKind() == StmtKind::Switch;
  };
  ContextFinder Finder(*SF, CS, IsSwitch);
  Finder.resolve();

  if (Finder.getContexts().empty()) {
    DiagEngine.diagnose(CS->getStartLoc(), diag::no_parent_switch);
    return nullptr;
  }
  // Use the last matching context (presumably the innermost switch).
  auto *SwitchS = static_cast<SwitchStmt*>(
      Finder.getContexts().back().get<Stmt*>());

  auto Cases = SwitchS->getCases();
  bool ContainsCase = std::find(Cases.begin(), Cases.end(), CS) != Cases.end();
  if (!ContainsCase) {
    DiagEngine.diagnose(CS->getStartLoc(), diag::no_parent_switch);
    return nullptr;
  }
  return SwitchS;
}
/// Applicable when the cursor is on a `default` case whose enclosing switch
/// has an enum-typed subject. Emits invalid_default_location when the cursor
/// is not on a case statement at all.
bool RefactoringActionExpandDefault::isApplicable(
    ResolvedCursorInfoPtr CursorInfo, DiagnosticEngine &Diag) {
  auto Exit = [&](bool Applicable) {
    if (!Applicable)
      Diag.diagnose(SourceLoc(), diag::invalid_default_location);
    return Applicable;
  };
  auto StmtStartInfo = dyn_cast<ResolvedStmtStartCursorInfo>(CursorInfo);
  if (!StmtStartInfo)
    return Exit(false);

  // getTrailingStmt() may be null (RefactoringActionExpandSwitchCases guards
  // against that), and dyn_cast does not accept null, so use dyn_cast_or_null.
  auto *CS = dyn_cast_or_null<CaseStmt>(StmtStartInfo->getTrailingStmt());
  if (!CS)
    return Exit(false);

  // findEnclosingSwitchStmt() diagnoses its own failure.
  auto EnclosingSwitchStmt =
      findEnclosingSwitchStmt(CS, CursorInfo->getSourceFile(), Diag);
  if (!EnclosingSwitchStmt)
    return false;
  auto EnumD = getEnumDeclFromSwitchStmt(EnclosingSwitchStmt);
  return CS->isDefault() && EnumD != nullptr;
}
bool RefactoringActionExpandDefault::performChange() {
// If we've not seen the default statement inside the switch statement, issue
// error.
auto StmtStartInfo = cast<ResolvedStmtStartCursorInfo>(CursorInfo);
auto *CS = static_cast<CaseStmt *>(StmtStartInfo->getTrailingStmt());
auto *SwitchS = findEnclosingSwitchStmt(CS, TheFile, DiagEngine);
assert(SwitchS);
EditorConsumerInsertStream OS(EditConsumer, SM,
Lexer::getCharSourceRangeFromSourceRange(SM,
CS->getLabelItemsRange()));
return performCasesExpansionInSwitchStmt(SwitchS,
DiagEngine,
CS->getStartLoc(),
OS);
}
/// Applicable when the cursor starts a `switch` over an enum-typed subject.
bool RefactoringActionExpandSwitchCases::isApplicable(
    ResolvedCursorInfoPtr CursorInfo, DiagnosticEngine &DiagEngine) {
  auto StmtStartInfo = dyn_cast<ResolvedStmtStartCursorInfo>(CursorInfo);
  if (!StmtStartInfo)
    return false;
  auto *Switch = dyn_cast_or_null<SwitchStmt>(StmtStartInfo->getTrailingStmt());
  return Switch && getEnumDeclFromSwitchStmt(Switch) != nullptr;
}
bool RefactoringActionExpandSwitchCases::performChange() {
auto StmtStartInfo = cast<ResolvedStmtStartCursorInfo>(CursorInfo);
auto *SwitchS = dyn_cast<SwitchStmt>(StmtStartInfo->getTrailingStmt());
assert(SwitchS);
auto InsertRange = CharSourceRange();
auto Cases = SwitchS->getCases();
auto Default = std::find_if(Cases.begin(), Cases.end(), [](CaseStmt *Stmt) {
return Stmt->isDefault();
});
if (Default != Cases.end()) {
auto DefaultRange = (*Default)->getLabelItemsRange();
InsertRange = Lexer::getCharSourceRangeFromSourceRange(SM, DefaultRange);
} else {
auto RBraceLoc = SwitchS->getRBraceLoc();
InsertRange = CharSourceRange(SM, RBraceLoc, RBraceLoc);
}
EditorConsumerInsertStream OS(EditConsumer, SM, InsertRange);
if (SM.getLineAndColumnInBuffer(SwitchS->getLBraceLoc()).first ==
SM.getLineAndColumnInBuffer(SwitchS->getRBraceLoc()).first) {
OS << "\n";
}
auto Result = performCasesExpansionInSwitchStmt(SwitchS,
DiagEngine,
SwitchS->getStartLoc(),
OS);
return Result;
}
/// Returns the plain string literal that begins where the cursor's trailing
/// expression begins, or null if there is none. Interpolated string literals
/// are rejected.
static Expr *findLocalizeTarget(ResolvedCursorInfoPtr CursorInfo) {
  auto ExprStartInfo = dyn_cast<ResolvedExprStartCursorInfo>(CursorInfo);
  if (!ExprStartInfo)
    return nullptr;
  Expr *Trailing = ExprStartInfo->getTrailingExpr();

  // Descends only into expressions starting at StartLoc.
  struct StringLiteralFinder: public SourceEntityWalker {
    SourceLoc StartLoc;
    Expr *Target = nullptr;
    StringLiteralFinder(SourceLoc StartLoc): StartLoc(StartLoc) {}
    bool walkToExprPre(Expr *E) override {
      if (E->getStartLoc() != StartLoc)
        return false;
      switch (E->getKind()) {
      case ExprKind::InterpolatedStringLiteral:
        return false;
      case ExprKind::StringLiteral:
        Target = E;
        return false;
      default:
        return true;
      }
    }
  };

  StringLiteralFinder Walker(Trailing->getStartLoc());
  Walker.walk(Trailing);
  return Walker.Target;
}
bool RefactoringActionLocalizeString::isApplicable(ResolvedCursorInfoPtr Tok,
                                                    DiagnosticEngine &Diag) {
  // Applicable whenever a plain string literal is under the cursor.
  return findLocalizeTarget(Tok) != nullptr;
}
/// Wraps the literal: `"foo"` becomes `NSLocalizedString("foo", comment: "")`.
/// Returns true (abort) when no literal is found.
bool RefactoringActionLocalizeString::performChange() {
  if (Expr *Target = findLocalizeTarget(CursorInfo)) {
    EditConsumer.accept(SM, Target->getStartLoc(), "NSLocalizedString(");
    EditConsumer.insertAfter(SM, Target->getEndLoc(), ", comment: \"\")");
    return false;
  }
  return true;
}
/// One parameter of a generated memberwise initializer.
struct MemberwiseParameter {
  /// Source range of the member's name; used as parameter and property name.
  CharSourceRange NameRange;
  /// The member's type, printed as the parameter type.
  Type MemberType;
  /// Initial value expression, printed as the parameter's default when it
  /// has a valid source range. May be null.
  Expr *DefaultExpr;

  MemberwiseParameter(CharSourceRange nameRange, Type type, Expr *initialExpr)
      : NameRange(nameRange), MemberType(type), DefaultExpr(initialExpr) {}
};
/// Emits `internal init(<params>) { self.<name> = <name> ... }` at
/// \p targetLocation as two edits: the opening "\ninternal init(" and then
/// the parameter list followed by the body.
static void generateMemberwiseInit(SourceEditConsumer &EditConsumer,
                                   SourceManager &SM,
                                   ArrayRef<MemberwiseParameter> memberVector,
                                   SourceLoc targetLocation) {
  EditConsumer.accept(SM, targetLocation, "\ninternal init(");

  // Prints "name: [@escaping ]Type[ = default]" for one member.
  auto printParameter = [&SM](const MemberwiseParameter &Param,
                              raw_ostream &Out) {
    Out << SM.extractText(Param.NameRange) << ": ";
    // Function-typed parameters are stored into properties by the body
    // generated below, so they always escape.
    if (isa<AnyFunctionType>(Param.MemberType->getCanonicalType()))
      Out << "@" << TypeAttributes::getAttrName(TAK_escaping) << " ";
    Out << Param.MemberType.getString();

    Expr *Default = Param.DefaultExpr;
    if (Default && Default->getSourceRange().isValid()) {
      auto DefaultRange = Lexer::getCharSourceRangeFromSourceRange(
          SM, Default->getSourceRange());
      Out << " = " << SM.extractText(DefaultRange);
    } else if (Param.MemberType->isOptional()) {
      Out << " = nil";
    }
  };

  std::string Buffer;
  llvm::raw_string_ostream OS(Buffer);

  // Comma-separated parameter list.
  bool IsFirst = true;
  for (const MemberwiseParameter &Param : memberVector) {
    if (!IsFirst)
      OS << ", ";
    IsFirst = false;
    printParameter(Param, OS);
  }

  // Body: self.<property> = <property>
  OS << ") {\n";
  for (const MemberwiseParameter &Param : memberVector) {
    StringRef Name = SM.extractText(Param.NameRange);
    OS << "self." << Name << " = " << Name << "\n";
  }
  OS << "}\n";

  EditConsumer.accept(SM, targetLocation, OS.str());
}
/// Collects the stored properties that should become memberwise-initializer
/// parameters for the nominal type *declared* at the cursor.
///
/// \returns the location just past the type's opening brace (where the
/// initializer is inserted), or an invalid SourceLoc if the cursor is not on
/// the declaration of a nominal type with stored properties.
static SourceLoc
collectMembersForInit(ResolvedCursorInfoPtr CursorInfo,
                      SmallVectorImpl<MemberwiseParameter> &memberVector) {
  auto ValueRef = dyn_cast<ResolvedValueRefCursorInfo>(CursorInfo);
  if (!ValueRef || !ValueRef->getValueD())
    return SourceLoc();

  auto *Nominal = dyn_cast<NominalTypeDecl>(ValueRef->getValueD());
  if (!Nominal || Nominal->getStoredProperties().empty() || ValueRef->isRef())
    return SourceLoc();

  // The initializer goes immediately after the `{`.
  SourceLoc LBrace = Nominal->getBraces().Start;
  if (LBrace.isInvalid())
    return SourceLoc();
  SourceLoc InsertLoc = LBrace.getAdvancedLoc(1);
  if (InsertLoc.isInvalid())
    return SourceLoc();

  SourceManager &SM = Nominal->getASTContext().SourceMgr;
  for (Decl *Member : Nominal->getMembers()) {
    auto *Var = dyn_cast<VarDecl>(Member);
    // Lazy members are skipped on purpose: although the implicit memberwise
    // initializer includes them, taking them as parameters would evaluate
    // their initializer eagerly.
    if (!Var || Var->getAttrs().hasAttribute<LazyAttr>() ||
        !Var->isMemberwiseInitialized(/*preferDeclaredProperties=*/true))
      continue;

    PatternBindingDecl *Binding = Var->getParentPatternBinding();
    if (!Binding)
      continue;

    unsigned Index = Binding->getPatternEntryIndexForVarDecl(Var);
    bool UseInit = Binding->isExplicitlyInitialized(Index) ||
                   Binding->isDefaultInitializable();
    Expr *Init = UseInit ? Binding->getOriginalInit(Index) : nullptr;

    CharSourceRange Name =
        Lexer::getCharSourceRangeFromSourceRange(SM, Var->getNameLoc());
    memberVector.emplace_back(Name, Var->getType(), Init);
  }
  return InsertLoc;
}
/// Applicable exactly when collectMembersForInit finds an insertion point;
/// the collected members are discarded here.
bool RefactoringActionMemberwiseInitLocalRefactoring::isApplicable(
    ResolvedCursorInfoPtr Tok, DiagnosticEngine &Diag) {
  SmallVector<MemberwiseParameter, 8> Members;
  SourceLoc InsertLoc = collectMembersForInit(Tok, Members);
  return InsertLoc.isValid();
}
bool RefactoringActionMemberwiseInitLocalRefactoring::performChange() {
SmallVector<MemberwiseParameter, 8> memberVector;
SourceLoc targetLocation = collectMembersForInit(CursorInfo, memberVector);
if (targetLocation.isInvalid())
return true;
generateMemberwiseInit(EditConsumer, SM, memberVector, targetLocation);
return false;
}
/// Returns the protocols related to \p NTD: the protocols it inherits from when
/// \p NTD is itself a protocol, otherwise every protocol it conforms to.
static SmallVector<ProtocolDecl *, 2> getAllProtocols(NominalTypeDecl *NTD) {
  auto *Proto = dyn_cast<ProtocolDecl>(NTD);
  if (!Proto)
    return NTD->getAllProtocols();

  auto Inherited = Proto->getInheritedProtocols();
  return SmallVector<ProtocolDecl *, 2>(Inherited.begin(), Inherited.end());
}
/// Gathers what the "Add Equatable Conformance" refactoring needs from a
/// nominal type declaration or an extension of one:
/// - where to add the `Equatable` inheritance entry,
/// - where to put the generated `==` implementation,
/// - which stored properties to compare.
class AddEquatableContext {
  /// Declaration context (the nominal type or the extension).
  DeclContext *DC;
  /// Adopter type; used as the base type when printing the `==` requirement.
  Type Adopter;
  /// Start location of the declaration's opening brace.
  SourceLoc StartLoc;
  /// Entries of the inheritance clause written on this declaration; only their
  /// source locations are used.
  ArrayRef<InheritedEntry> ProtocolsLocations;
  /// All protocols the adopter conforms to (or inherits from, if it is itself
  /// a protocol).
  SmallVector<swift::ProtocolDecl *, 2> Protocols;
  /// Location after which `: Equatable` is written when there is no
  /// inheritance clause yet. This is the type name for a nominal declaration,
  /// or the end of the extended type for an extension.
  /// NOTE(review): for a generic nominal type this is the name, i.e. before
  /// the `<...>` list. Confirm whether generic types need special handling.
  SourceLoc ProtInsertStartLoc;
  /// Stored properties of the adopting nominal type.
  ArrayRef<VarDecl *> StoredProperties;
  /// Members written inside this declaration's braces.
  DeclRange Range;

  /// Whether the adopter already conforms to Equatable.
  bool conformsToEquatableProtocol() {
    for (ProtocolDecl *Protocol : Protocols) {
      if (Protocol->getKnownProtocolKind() == KnownProtocolKind::Equatable) {
        return true;
      }
    }
    return false;
  }

  /// Whether Equatable's first requirement is a function taking exactly two
  /// parameters (the operands of `==`).
  bool isRequirementValid() {
    auto Reqs = getProtocolRequirements();
    if (Reqs.empty()) {
      return false;
    }
    auto Req = dyn_cast<FuncDecl>(Reqs[0]);
    return Req && Req->getParameters()->size() == 2;
  }

  /// Whether at least one user-accessible stored property exists to compare.
  bool isPropertiesListValid() {
    return !getUserAccessibleProperties().empty();
  }

  void printFunctionBody(ASTPrinter &Printer, StringRef ExtraIndent,
                         ParameterList *Params);
  std::vector<ValueDecl *> getProtocolRequirements();
  std::vector<VarDecl *> getUserAccessibleProperties();

public:
  AddEquatableContext(NominalTypeDecl *Decl)
      : DC(Decl), Adopter(Decl->getDeclaredType()),
        StartLoc(Decl->getBraces().Start),
        ProtocolsLocations(Decl->getInherited()),
        Protocols(getAllProtocols(Decl)),
        ProtInsertStartLoc(Decl->getNameLoc()),
        StoredProperties(Decl->getStoredProperties()),
        Range(Decl->getMembers()) {}

  AddEquatableContext(ExtensionDecl *Decl)
      : DC(Decl), Adopter(Decl->getExtendedType()),
        StartLoc(Decl->getBraces().Start),
        ProtocolsLocations(Decl->getInherited()),
        Protocols(getAllProtocols(Decl->getExtendedNominal())),
        ProtInsertStartLoc(Decl->getExtendedTypeRepr()->getEndLoc()),
        StoredProperties(Decl->getExtendedNominal()->getStoredProperties()),
        Range(Decl->getMembers()) {}

  /// An empty, invalid context (isValid() returns false).
  AddEquatableContext()
      : DC(nullptr), Adopter(), ProtocolsLocations(), Protocols(),
        StoredProperties(), Range(nullptr, nullptr) {}

  static AddEquatableContext
  getDeclarationContextFromInfo(ResolvedCursorInfoPtr Info);

  std::string getInsertionTextForProtocol();

  std::string getInsertionTextForFunction(SourceManager &SM);

  /// Whether the refactoring can be offered. This requires:
  /// - valid insertion locations,
  /// - no existing Equatable conformance,
  /// - at least one property to compare,
  /// - a usable `==` requirement.
  bool isValid() {
    // FIXME: Allow to generate explicit == method for declarations which already have
    // compiler-generated == method
    return StartLoc.isValid() && ProtInsertStartLoc.isValid() &&
           !conformsToEquatableProtocol() && isPropertiesListValid() &&
           isRequirementValid();
  }

  /// Location after which the `Equatable` entry is inserted.
  ///
  /// When an inheritance clause exists, this is the *end* of its last entry.
  /// The caller passes the location to insertAfter, which places text after
  /// the token at that location. Using the start would split multi-token
  /// entries: `struct S: Outer.Inner` would become
  /// `struct S: Outer, Equatable.Inner`. Using the end also matches the
  /// extension case, where ProtInsertStartLoc is the extended type's end loc.
  SourceLoc getStartLocForProtocolDecl() {
    if (ProtocolsLocations.empty()) {
      return ProtInsertStartLoc;
    }
    return ProtocolsLocations.back().getSourceRange().End;
  }

  /// Whether the declaration has no members between its braces.
  bool isMembersRangeEmpty() {
    return Range.empty();
  }

  SourceLoc getInsertStartLoc();
};
/// Returns the location after which the generated `==` is inserted. This is
/// whichever lies furthest into the buffer: the opening brace or the end of
/// any member.
SourceLoc AddEquatableContext::getInsertStartLoc() {
  SourceLoc Latest = StartLoc;
  for (Decl *Member : Range) {
    SourceLoc End = Member->getEndLoc();
    if (End.getOpaquePointerValue() > Latest.getOpaquePointerValue())
      Latest = End;
  }
  return Latest;
}
std::string AddEquatableContext::
getInsertionTextForProtocol() {
StringRef ProtocolName = getProtocolName(KnownProtocolKind::Equatable);
std::string Buffer;
llvm::raw_string_ostream OS(Buffer);
if (ProtocolsLocations.empty()) {
OS << ": " << ProtocolName;
return Buffer;
}
OS << ", " << ProtocolName;
return Buffer;
}
/// Builds the text of the `==` implementation. The text is inserted after the
/// last member, or after the opening brace when there are no members.
///
/// The first Equatable requirement is printed with its full signature, and
/// its body comes from printFunctionBody. Indexing Reqs[0] and casting to
/// FuncDecl assume isValid() already held (requirement list non-empty, first
/// entry a FuncDecl). TODO confirm all callers check applicability first.
std::string AddEquatableContext::
getInsertionTextForFunction(SourceManager &SM) {
auto Reqs = getProtocolRequirements();
auto Req = dyn_cast<FuncDecl>(Reqs[0]);
auto Params = Req->getParameters();
StringRef ExtraIndent;
StringRef CurrentIndent =
Lexer::getIndentationForLine(SM, getInsertStartLoc(), &ExtraIndent);
std::string Indent;
// With no members the text follows the `{` line directly, so it needs one
// more indentation level than that line. Otherwise it lines up with the line
// holding the last member's end.
if (isMembersRangeEmpty()) {
Indent = (CurrentIndent + ExtraIndent).str();
} else {
Indent = CurrentIndent.str();
}
// Print the requirement's declaration verbatim (without doc comments), with
// the adopter as base type. A FunctionBody callback supplies the body.
PrintOptions Options = PrintOptions::printVerbose();
Options.PrintDocumentationComments = false;
Options.setBaseType(Adopter);
Options.FunctionBody = [&](const ValueDecl *VD, ASTPrinter &Printer) {
Printer << " {";
Printer.printNewline();
printFunctionBody(Printer, ExtraIndent, Params);
Printer.printNewline();
Printer << "}";
};
std::string Buffer;
llvm::raw_string_ostream OS(Buffer);
ExtraIndentStreamPrinter Printer(OS, Indent);
Printer.printNewline();
// An extra newline leaves a blank line between the last member and `==`.
if (!isMembersRangeEmpty()) {
Printer.printNewline();
}
Reqs[0]->print(Printer, Options);
return Buffer;
}
std::vector<VarDecl *> AddEquatableContext::
getUserAccessibleProperties() {
std::vector<VarDecl *> PublicProperties;
for (VarDecl *Decl : StoredProperties) {
if (Decl->Decl::isUserAccessible()) {
PublicProperties.push_back(Decl);
}
}
return PublicProperties;
}
/// Returns the valid protocol-requirement members of `Equatable`.
///
/// Returns an empty list when the Equatable protocol cannot be found, for
/// example when the standard library is not loaded. ASTContext::getProtocol
/// may return null in that case. The empty list makes isRequirementValid(),
/// and therefore isValid(), report "not applicable" instead of crashing.
std::vector<ValueDecl *> AddEquatableContext::getProtocolRequirements() {
  std::vector<ValueDecl *> Collection;
  auto *Proto = DC->getASTContext().getProtocol(KnownProtocolKind::Equatable);
  if (!Proto)
    return Collection;

  for (auto Member : Proto->getMembers()) {
    auto *Req = dyn_cast<ValueDecl>(Member);
    if (!Req || Req->isInvalid() || !Req->isProtocolRequirement()) {
      continue;
    }
    Collection.push_back(Req);
  }
  return Collection;
}
/// Builds the context from the cursor.
/// - On the *declaration* of a nominal type, that type is used.
/// - On a *reference* whose cursor info names an extension with an extended
///   nominal, that extension is used.
/// - Otherwise an empty (invalid) context is returned.
AddEquatableContext
AddEquatableContext::getDeclarationContextFromInfo(ResolvedCursorInfoPtr Info) {
  auto ValueRef = dyn_cast<ResolvedValueRefCursorInfo>(Info);
  if (!ValueRef)
    return AddEquatableContext();

  if (!ValueRef->isRef()) {
    auto *Nominal = dyn_cast<NominalTypeDecl>(ValueRef->getValueD());
    return Nominal ? AddEquatableContext(Nominal) : AddEquatableContext();
  }

  auto *Ext = ValueRef->getExtTyRef();
  if (Ext && Ext->getExtendedNominal())
    return AddEquatableContext(Ext);
  return AddEquatableContext();
}
/// Prints the body of the generated `==`. The body returns the `&&` of
/// `lhs.p == rhs.p` for every user-accessible property `p`, with one
/// comparison per line. The two operand names are taken from \p Params.
/// Assumes at least one property exists (checked by isValid()).
void AddEquatableContext::printFunctionBody(ASTPrinter &Printer,
                                            StringRef ExtraIndent,
                                            ParameterList *Params) {
  SmallString<128> ReturnKw;
  llvm::raw_svector_ostream KwOS(ReturnKw);
  KwOS << tok::kw_return;

  const StringRef Space = " ";
  const StringRef ContinuationIndent = " ";
  const StringRef Dot = ".";
  const StringRef EqualsOp = " == ";
  const StringRef AndOp = " &&";

  auto Props = getUserAccessibleProperties();
  auto LHS = Params->get(0)->getName();
  auto RHS = Params->get(1)->getName();

  // Prints `lhs.<Name> == rhs.<Name>`.
  auto printComparison = [&](Identifier Name) {
    Printer << LHS << Dot << Name << EqualsOp << RHS << Dot << Name;
  };

  Printer << ExtraIndent << ReturnKw << Space;
  printComparison(Props[0]->getName());

  // Every further property continues the expression on a new line.
  for (size_t I = 1, E = Props.size(); I < E; ++I) {
    Printer << AndOp;
    Printer.printNewline();
    Printer << ExtraIndent << ContinuationIndent;
    printComparison(Props[I]->getName());
  }
}
/// Applicable when the cursor yields a valid AddEquatableContext.
bool RefactoringActionAddEquatableConformance::isApplicable(
    ResolvedCursorInfoPtr Tok, DiagnosticEngine &Diag) {
  auto Context = AddEquatableContext::getDeclarationContextFromInfo(Tok);
  return Context.isValid();
}
bool RefactoringActionAddEquatableConformance::
performChange() {
auto Context = AddEquatableContext::getDeclarationContextFromInfo(CursorInfo);
EditConsumer.insertAfter(SM, Context.getStartLocForProtocolDecl(),
Context.getInsertionTextForProtocol());
EditConsumer.insertAfter(SM, Context.getInsertStartLoc(),
Context.getInsertionTextForFunction(SM));
return false;
}
/// Gathers what the "Add Explicit Codable Implementation" refactoring needs.
/// It records the declaration whose compiler-synthesized Codable members will
/// be printed, and where to insert them.
class AddCodableContext {
/// Declaration context
DeclContext *DC;
/// Start location of declaration context brace
SourceLoc StartLoc;
/// Array of all conformed protocols
SmallVector<swift::ProtocolDecl *, 2> Protocols;
/// Range of internal members in declaration
DeclRange Range;
/// Whether any known protocol in Protocols is Encodable or Decodable.
/// Unlike the Equatable refactoring, a conformance is *required* here,
/// because the synthesized members to print only exist when one is present.
bool conformsToCodableProtocol() {
for (ProtocolDecl *Protocol : Protocols) {
if (Protocol->getKnownProtocolKind() == KnownProtocolKind::Encodable ||
Protocol->getKnownProtocolKind() == KnownProtocolKind::Decodable) {
return true;
}
}
return false;
}
public:
AddCodableContext(NominalTypeDecl *Decl)
: DC(Decl), StartLoc(Decl->getBraces().Start),
Protocols(getAllProtocols(Decl)), Range(Decl->getMembers()){};
AddCodableContext(ExtensionDecl *Decl)
: DC(Decl), StartLoc(Decl->getBraces().Start),
Protocols(getAllProtocols(Decl->getExtendedNominal())),
Range(Decl->getMembers()){};
/// An empty context; StartLoc stays invalid, so isValid() is false.
AddCodableContext() : DC(nullptr), Protocols(), Range(nullptr, nullptr){};
static AddCodableContext
getDeclarationContextFromInfo(ResolvedCursorInfoPtr Info);
/// Writes the synthesized declarations of DC to OS, indented for insertion
/// at getInsertStartLoc().
void printInsertionText(ResolvedCursorInfoPtr CursorInfo, SourceManager &SM,
llvm::raw_ostream &OS);
bool isValid() { return StartLoc.isValid() && conformsToCodableProtocol(); }
/// End of the last member, or the opening brace when there are no members.
SourceLoc getInsertStartLoc();
};
/// Returns where the synthesized Codable text is inserted. This is the
/// opening brace, moved forward to the end of any member that lies further
/// into the buffer.
SourceLoc AddCodableContext::getInsertStartLoc() {
  SourceLoc Result = StartLoc;
  for (Decl *Member : Range) {
    const void *MemberEnd = Member->getEndLoc().getOpaquePointerValue();
    if (MemberEnd > Result.getOpaquePointerValue())
      Result = Member->getEndLoc();
  }
  return Result;
}
/// Walks an AST and prints the synthesized Codable implementation.
///
/// Handling per declaration kind:
/// - Non-synthesized value declarations are walked into, so nested synthesized
///   members are still found.
/// - Non-value declarations are skipped.
/// - Among synthesized declarations, only enums, `init(from:)` and
///   `encode(to:)` are printed.
class SynthesizedCodablePrinter : public ASTWalker {
private:
/// Destination for the printed declarations.
ASTPrinter &Printer;
public:
SynthesizedCodablePrinter(ASTPrinter &Printer) : Printer(Printer) {}
/// NOTE(review): presumably walks macro arguments but not their expansions;
/// confirm against ASTWalker::MacroWalking.
MacroWalking getMacroWalkingBehavior() const override {
return MacroWalking::Arguments;
}
PreWalkAction walkToDeclPre(Decl *D) override {
auto *VD = dyn_cast<ValueDecl>(D);
if (!VD)
return Action::SkipChildren();
// Hand-written declarations are not printed, but their children may
// include synthesized ones.
if (!VD->isSynthesized()) {
return Action::Continue();
}
SmallString<32> Scratch;
auto name = VD->getName().getString(Scratch);
// Print all synthesized enums,
// since Codable can synthesize multiple enums (for associated values).
auto shouldPrint =
isa<EnumDecl>(VD) || name == "init(from:)" || name == "encode(to:)";
if (!shouldPrint) {
// Some other synthesized decl that we don't want to print.
return Action::SkipChildren();
}
Printer.printNewline();
if (auto enumDecl = dyn_cast<EnumDecl>(D)) {
// Manually print enum here, since we don't want to print synthesized
// functions.
// Output: `enum Name<inheritance> {`, one element per line, then `}`.
Printer << "enum " << enumDecl->getNameStr();
PrintOptions Options;
Options.PrintSpaceBeforeInheritance = false;
enumDecl->printInherited(Printer, Options);
Printer << " {";
for (Decl *EC : enumDecl->getAllElements()) {
Printer.printNewline();
Printer << " ";
EC->print(Printer, Options);
}
Printer.printNewline();
Printer << "}";
return Action::SkipChildren();
}
// init(from:) / encode(to:): print the full definition including bodies and
// expressions, omitting the HasInitialValue attribute. Function decls get a
// second newline, i.e. a blank line before them.
PrintOptions Options;
Options.SynthesizeSugarOnTypes = true;
Options.FunctionDefinitions = true;
Options.VarInitializers = true;
Options.PrintExprs = true;
Options.TypeDefinitions = true;
Options.ExcludeAttrList.push_back(DAK_HasInitialValue);
Printer.printNewline();
D->print(Printer, Options);
return Action::SkipChildren();
}
};
/// Prints the synthesized Codable members of DC to \p OS, indented to fit at
/// getInsertStartLoc(). \p CursorInfo is currently unused.
void AddCodableContext::printInsertionText(ResolvedCursorInfoPtr CursorInfo,
                                           SourceManager &SM,
                                           llvm::raw_ostream &OS) {
  SourceLoc InsertLoc = getInsertStartLoc();
  StringRef ExtraIndent;
  StringRef CurrentIndent =
      Lexer::getIndentationForLine(SM, InsertLoc, &ExtraIndent);

  // When inserting straight after `{` (no members), go one level deeper than
  // the brace's line. Otherwise match the last member's line.
  std::string Indent = InsertLoc == StartLoc
                           ? (CurrentIndent + ExtraIndent).str()
                           : CurrentIndent.str();

  ExtraIndentStreamPrinter Printer(OS, Indent);
  Printer.printNewline();
  SynthesizedCodablePrinter Walker(Printer);
  DC->getAsDecl()->walk(Walker);
}
/// Builds the context only when the cursor is on the *declaration* of a
/// nominal type. Otherwise it returns an empty (invalid) context.
AddCodableContext
AddCodableContext::getDeclarationContextFromInfo(ResolvedCursorInfoPtr Info) {
  auto ValueRef = dyn_cast<ResolvedValueRefCursorInfo>(Info);
  if (ValueRef && !ValueRef->isRef()) {
    if (auto *Nominal = dyn_cast<NominalTypeDecl>(ValueRef->getValueD()))
      return AddCodableContext(Nominal);
  }
  // TODO: support extensions
  // (would need to get synthesized nodes from the main decl,
  // and only if it's in the same file?)
  return AddCodableContext();
}
/// Applicable when the cursor yields a valid AddCodableContext.
bool RefactoringActionAddExplicitCodableImplementation::isApplicable(
    ResolvedCursorInfoPtr Tok, DiagnosticEngine &Diag) {
  AddCodableContext Context =
      AddCodableContext::getDeclarationContextFromInfo(Tok);
  return Context.isValid();
}
bool RefactoringActionAddExplicitCodableImplementation::performChange() {
auto Context = AddCodableContext::getDeclarationContextFromInfo(CursorInfo);
SmallString<64> Buffer;
llvm::raw_svector_ostream OS(Buffer);
Context.printInsertionText(CursorInfo, SM, OS);
EditConsumer.insertAfter(SM, Context.getInsertStartLoc(), OS.str());
return false;
}
/// Finds the source range to wrap in `do { ... } catch`. This is the element
/// of the innermost enclosing brace statement that contains the trailing
/// expression, or the brace statement itself if no element does.
/// Returns an invalid range if there is no trailing expression or no
/// enclosing brace statement.
static CharSourceRange
findSourceRangeToWrapInCatch(const ResolvedExprStartCursorInfo &CursorInfo,
                             SourceFile *TheFile, SourceManager &SM) {
  Expr *E = CursorInfo.getTrailingExpr();
  if (!E)
    return CharSourceRange();

  auto toCharRange = [&SM](SourceRange SR) {
    return Lexer::getCharSourceRangeFromSourceRange(SM, SR);
  };

  // Only brace statements are collected as contexts.
  ContextFinder Finder(*TheFile, ASTNode(E),
                       [](ASTNode N) { return N.isStmt(StmtKind::Brace); });
  Finder.resolve();
  auto Contexts = Finder.getContexts();
  if (Contexts.empty())
    return CharSourceRange();

  // The last context is the deepest brace statement.
  ASTNode Target = Contexts.back();
  auto *Brace = cast<BraceStmt>(Target.get<Stmt *>());

  CharSourceRange ExprRange = toCharRange(E->getSourceRange());
  for (ASTNode Elem : Brace->getElements()) {
    if (toCharRange(Elem.getSourceRange()).contains(ExprRange))
      Target = Elem;
  }
  return toCharRange(Target.getSourceRange());
}
/// Applicable when the expression starting at the cursor is a `try!`.
bool RefactoringActionConvertToDoCatch::isApplicable(ResolvedCursorInfoPtr Tok,
                                                     DiagnosticEngine &Diag) {
  auto ExprStart = dyn_cast<ResolvedExprStartCursorInfo>(Tok);
  if (!ExprStart)
    return false;
  Expr *Trailing = ExprStart->getTrailingExpr();
  return Trailing != nullptr && isa<ForceTryExpr>(Trailing);
}
/// Wraps the statement containing a `try!` expression in a
/// `do { ... } catch { <placeholder> }` block and removes the `!`, turning
/// `try!` into `try`.
/// Returns true (failure) if no range to wrap can be found, false otherwise.
bool RefactoringActionConvertToDoCatch::performChange() {
auto ExprStartInfo = cast<ResolvedExprStartCursorInfo>(CursorInfo);
auto *TryExpr = dyn_cast<ForceTryExpr>(ExprStartInfo->getTrailingExpr());
// isApplicable() only accepts a trailing ForceTryExpr.
assert(TryExpr);
auto Range = findSourceRangeToWrapInCatch(*ExprStartInfo, TheFile, SM);
if (!Range.isValid())
return true;
// Wrap given range in do catch block.
EditConsumer.accept(SM, Range.getStart(), "do {\n");
// NOTE(review): the insert stream presumably submits its text when it is
// destroyed at the end of this function, i.e. after the removal below.
// Confirm that the edit order is irrelevant to the consumer.
EditorConsumerInsertStream OS(EditConsumer, SM, Range.getEnd());
OS << "\n} catch {\n" << getCodePlaceholder() << "\n}";
// Delete ! from try! expression
auto ExclaimLen = getKeywordLen(tok::exclaim_postfix);
auto ExclaimRange = CharSourceRange(TryExpr->getExclaimLoc(), ExclaimLen);
EditConsumer.remove(SM, ExclaimRange);
return false;
}
/// Given a cursor position, this function tries to collect a number literal
/// expression immediately following the cursor.
///
/// Returns the first NumberLiteralExpr, in walk order, inside the expression
/// at the cursor whose start location equals that expression's start.
/// Returns null if the cursor is not at an expression start, if the cursor
/// info carries no trailing expression, or if no such literal exists.
static NumberLiteralExpr *getTrailingNumberLiteral(ResolvedCursorInfoPtr Tok) {
  // This cursor must point to the start of an expression.
  auto ExprStartInfo = dyn_cast<ResolvedExprStartCursorInfo>(Tok);
  if (!ExprStartInfo)
    return nullptr;

  // getTrailingExpr() can be null; the do/catch refactoring checks for this
  // too. Bail out instead of walking a null expression.
  Expr *Parent = ExprStartInfo->getTrailingExpr();
  if (!Parent)
    return nullptr;

  // For every sub-expression, try to find the literal expression that matches
  // our criteria.
  class FindLiteralNumber : public ASTWalker {
    Expr *const Parent;

  public:
    NumberLiteralExpr *Found = nullptr;

    explicit FindLiteralNumber(Expr *Root) : Parent(Root) {}

    MacroWalking getMacroWalkingBehavior() const override {
      return MacroWalking::Arguments;
    }

    PreWalkResult<Expr *> walkToExprPre(Expr *E) override {
      if (auto *Literal = dyn_cast<NumberLiteralExpr>(E)) {
        // The sub-expression must have the same start loc with the outermost
        // expression, i.e. the cursor position.
        if (!Found && Parent->getStartLoc().getOpaquePointerValue() ==
                          E->getStartLoc().getOpaquePointerValue()) {
          Found = Literal;
        }
      }
      // Stop descending once a match has been recorded.
      return Action::SkipChildrenIf(Found, E);
    }
  };

  FindLiteralNumber Finder(Parent);
  Parent->walk(Finder);
  return Finder.Found;
}
/// Returns \p Text with a '_' inserted before every character whose index is
/// a non-zero multiple of 3 (grouping from the left): "1234567" becomes
/// "123_456_7".
static std::string insertUnderscore(StringRef Text) {
  std::string Result;
  Result.reserve(Text.size() + Text.size() / 3);
  for (size_t Index = 0, Size = Text.size(); Index != Size; ++Index) {
    if (Index != 0 && Index % 3 == 0)
      Result.push_back('_');
    Result.push_back(Text[Index]);
  }
  return Result;
}
/// Writes \p Digits to \p OS with '_' separators in groups of three. The part
/// before the first '.' is grouped from the right ("1234567" becomes
/// "1_234_567"). The part after it, if non-empty, is grouped from the left.
void insertUnderscoreInDigits(StringRef Digits, raw_ostream &OS) {
  auto Parts = Digits.split('.');

  // Group the integer part from the right: reverse, group, reverse back.
  std::string IntegerPart = Parts.first.str();
  std::reverse(IntegerPart.begin(), IntegerPart.end());
  IntegerPart = insertUnderscore(IntegerPart);
  std::reverse(IntegerPart.begin(), IntegerPart.end());
  OS << IntegerPart;

  if (!Parts.second.empty())
    OS << '.' << insertUnderscore(Parts.second);
}
/// Applicable when the number literal at the cursor is a plain decimal literal
/// whose digits change once '_' separators are inserted.
///
/// Literals with any character other than decimal digits and '.' are
/// rejected: prefixed literals (0x/0o/0b), exponents, and literals that
/// already contain '_'. insertUnderscoreInDigits treats every character as a
/// digit, so it would otherwise produce invalid or mis-grouped code:
/// - `0x12345` becomes `0_x12_345`,
/// - `1e100` becomes `1e_100`,
/// - `1_000` becomes `1__000`.
bool RefactoringActionSimplifyNumberLiteral::isApplicable(
    ResolvedCursorInfoPtr Tok, DiagnosticEngine &Diag) {
  auto *Literal = getTrailingNumberLiteral(Tok);
  if (!Literal)
    return false;

  StringRef Digits = Literal->getDigitsText();
  if (Digits.find_first_not_of("0123456789.") != StringRef::npos)
    return false;

  SmallString<64> Buffer;
  llvm::raw_svector_ostream OS(Buffer);
  insertUnderscoreInDigits(Digits, OS);

  // If inserting '_' results in a different digit sequence, this refactoring
  // is applicable.
  return OS.str() != Digits;
}
bool RefactoringActionSimplifyNumberLiteral::performChange() {
if (auto *Literal = getTrailingNumberLiteral(CursorInfo)) {
EditorConsumerInsertStream OS(EditConsumer, SM,
CharSourceRange(SM, Literal->getDigitsLoc(),
Lexer::getLocForEndOfToken(SM,
Literal->getEndLoc())));
StringRef Digits = Literal->getDigitsText();
insertUnderscoreInDigits(Digits, OS);
return false;
}
return true;
}
/// Returns the innermost call around the cursor whose last argument is a
/// non-trailing closure, looking through implicit conversions. A capture list
/// around the closure also counts. Returns null if there is no such call.
static CallExpr *findTrailingClosureTarget(SourceManager &SM,
                                           ResolvedCursorInfoPtr CursorInfo) {
  // A statement-start position can never be part of a CallExpr.
  if (CursorInfo->getKind() == CursorInfoKind::StmtStart)
    return nullptr;

  // Collect brace statements and calls enclosing the cursor.
  auto IsCandidate = [](ASTNode N) {
    return N.isStmt(StmtKind::Brace) || N.isExpr(ExprKind::Call);
  };
  ContextFinder Finder(*CursorInfo->getSourceFile(), CursorInfo->getLoc(),
                       IsCandidate);
  Finder.resolve();
  auto Contexts = Finder.getContexts();

  // An innermost brace statement is dropped; a call must then remain.
  if (!Contexts.empty() && Contexts.back().is<Stmt *>())
    Contexts = Contexts.drop_back();
  if (Contexts.empty() || !Contexts.back().is<Expr *>())
    return nullptr;

  auto *Call = cast<CallExpr>(Contexts.back().get<Expr *>());
  auto *Args = Call->getArgs();
  if (Args->empty() || Args->hasAnyTrailingClosures())
    return nullptr;

  Expr *Last = Args->back().getExpr();
  if (auto *Conversion = dyn_cast<ImplicitConversionExpr>(Last))
    Last = Conversion->getSyntacticSubExpr();

  bool IsClosure = isa<ClosureExpr>(Last) || isa<CaptureListExpr>(Last);
  return IsClosure ? Call : nullptr;
}
/// Applicable when findTrailingClosureTarget finds a call to rewrite.
bool RefactoringActionTrailingClosure::isApplicable(
    ResolvedCursorInfoPtr CursorInfo, DiagnosticEngine &Diag) {
  SourceFile *File = CursorInfo->getSourceFile();
  SourceManager &SM = File->getASTContext().SourceMgr;
  return findTrailingClosureTarget(SM, CursorInfo) != nullptr;
}
bool RefactoringActionTrailingClosure::performChange() {
auto *CE = findTrailingClosureTarget(SM, CursorInfo);
if (!CE)
return true;
auto *ArgList = CE->getArgs()->getOriginalArgs();
auto LParenLoc = ArgList->getLParenLoc();
auto RParenLoc = ArgList->getRParenLoc();
if (LParenLoc.isInvalid() || RParenLoc.isInvalid())
return true;
auto NumArgs = ArgList->size();
if (NumArgs == 0)
return true;
auto *ClosureArg = ArgList->getExpr(NumArgs - 1);
if (auto *ICE = dyn_cast<ImplicitConversionExpr>(ClosureArg))
ClosureArg = ICE->getSyntacticSubExpr();
// Replace:
// * Open paren with ' ' if the closure is sole argument.
// * Comma with ') ' otherwise.
if (NumArgs > 1) {
auto *PrevArg = ArgList->getExpr(NumArgs - 2);
CharSourceRange PreRange(
SM,
Lexer::getLocForEndOfToken(SM, PrevArg->getEndLoc()),
ClosureArg->getStartLoc());
EditConsumer.accept(SM, PreRange, ") ");
} else {
CharSourceRange PreRange(SM, LParenLoc, ClosureArg->getStartLoc());
EditConsumer.accept(SM, PreRange, " ");
}
// Remove original closing paren.
CharSourceRange PostRange(
SM,
Lexer::getLocForEndOfToken(SM, ClosureArg->getEndLoc()),
Lexer::getLocForEndOfToken(SM, RParenLoc));
EditConsumer.remove(SM, PostRange);
return false;
}
static bool collectRangeStartRefactorings(const ResolvedRangeInfo &Info) {
switch (Info.Kind) {
case RangeKind::SingleExpression:
case RangeKind::SingleStatement:
case RangeKind::SingleDecl:
case RangeKind::PartOfExpression:
return true;
case RangeKind::MultiStatement:
case RangeKind::MultiTypeMemberDecl:
case RangeKind::Invalid:
return false;
}
}
bool RefactoringActionConvertToComputedProperty::
isApplicable(const ResolvedRangeInfo &Info, DiagnosticEngine &Diag) {
if (Info.Kind != RangeKind::SingleDecl) {
return false;
}
if (Info.ContainedNodes.size() != 1) {
return false;
}
auto D = Info.ContainedNodes[0].dyn_cast<Decl*>();
if (!D) {
return false;
}
auto Binding = dyn_cast<PatternBindingDecl>(D);
if (!Binding) {
return false;
}
auto SV = Binding->getSingleVar();
if (!SV) {
return false;
}
// willSet, didSet cannot be provided together with a getter
for (auto AD : SV->getAllAccessors()) {
if (AD->isObservingAccessor()) {
return false;
}
}
// 'lazy' must not be used on a computed property
// NSCopying and IBOutlet attribute requires property to be mutable
auto Attributies = SV->getAttrs();
if (Attributies.hasAttribute<LazyAttr>() ||
Attributies.hasAttribute<NSCopyingAttr>() ||
Attributies.hasAttribute<IBOutletAttr>()) {
return false;
}
// Property wrapper cannot be applied to a computed property
if (SV->hasAttachedPropertyWrapper()) {
return false;
}
// has an initializer
return Binding->hasInitStringRepresentation(0);
}
bool RefactoringActionConvertToComputedProperty::performChange() {
// Get an initialization
auto D = RangeInfo.ContainedNodes[0].dyn_cast<Decl*>();
auto Binding = dyn_cast<PatternBindingDecl>(D);
SmallString<128> scratch;
auto Init = Binding->getInitStringRepresentation(0, scratch);
// Get type
auto SV = Binding->getSingleVar();
auto SVType = SV->getType();
auto TR = SV->getTypeReprOrParentPatternTypeRepr();
SmallString<64> DeclBuffer;
llvm::raw_svector_ostream OS(DeclBuffer);
StringRef Space = " ";
StringRef NewLine = "\n";
OS << tok::kw_var << Space;
// Add var name
OS << SV->getNameStr().str() << ":" << Space;
// For computed property must write a type of var
if (TR) {
OS << Lexer::getCharSourceRangeFromSourceRange(SM, TR->getSourceRange()).str();
} else {
SVType.print(OS);
}
OS << Space << tok::l_brace << NewLine;
// Add an initialization
OS << tok::kw_return << Space << Init.str() << NewLine;
OS << tok::r_brace;
// Replace initializer to computed property
auto ReplaceStartLoc = Binding->getLoc();
auto ReplaceEndLoc = Binding->getSourceRange().End;
auto ReplaceRange = SourceRange(ReplaceStartLoc, ReplaceEndLoc);
auto ReplaceCharSourceRange = Lexer::getCharSourceRangeFromSourceRange(SM, ReplaceRange);
EditConsumer.accept(SM, ReplaceCharSourceRange, DeclBuffer.str());
return false; // success
}
namespace asyncrefactorings {
// TODO: Should probably split the refactorings into separate files
/// Whether the given type is (or conforms to) the stdlib Error type.
///
/// Returns false for a null type. It also returns false when the `Error`
/// protocol declaration is unavailable (getErrorDecl() returning null, e.g.
/// without a loaded stdlib), instead of passing a null protocol to
/// conformsToProtocol.
bool isErrorType(Type Ty, ModuleDecl *MD) {
  if (!Ty)
    return false;
  auto *ErrorProto = Ty->getASTContext().getErrorDecl();
  if (!ErrorProto)
    return false;
  return !MD->conformsToProtocol(Ty, ErrorProto).isInvalid();
}
/// The declaration referenced by a switch subject that is a plain DeclRefExpr
/// (e.g. `switch x`), or nullptr for any other subject expression.
Decl *singleSwitchSubject(const SwitchStmt *Switch) {
  auto *Ref = dyn_cast<DeclRefExpr>(Switch->getSubjectExpr());
  return Ref ? Ref->getDecl() : nullptr;
}
/// A more aggressive variant of \c Expr::getReferencedDecl. It looks through
/// the semantics-providing expression, self-applications (`base.method`) and
/// unwrapped curry-thunk autoclosures until a DeclRefExpr is found.
/// Returns nullptr when the chain ends in anything else.
ValueDecl *getReferencedDecl(const Expr *Fn) {
  for (;;) {
    Fn = Fn->getSemanticsProvidingExpr();

    if (auto *Ref = dyn_cast<DeclRefExpr>(Fn))
      return Ref->getDecl();

    if (auto *SelfApply = dyn_cast<SelfApplyExpr>(Fn)) {
      Fn = SelfApply->getFn();
      continue;
    }

    auto *Thunk = dyn_cast<AutoClosureExpr>(Fn);
    const Expr *Unwrapped =
        Thunk ? Thunk->getUnwrappedCurryThunkExpr() : nullptr;
    if (!Unwrapped)
      return nullptr;
    Fn = Unwrapped;
  }
}
/// The function referenced by \p Fn (see getReferencedDecl), or nullptr if
/// the referenced declaration is missing or is not a FuncDecl.
FuncDecl *getUnderlyingFunc(const Expr *Fn) {
  ValueDecl *Referenced = getReferencedDecl(Fn);
  return Referenced ? dyn_cast<FuncDecl>(Referenced) : nullptr;
}
/// Finds the outermost explicit expression around the cursor. Returns it if
/// it is a CallExpr whose callee expression contains the cursor, and nullptr
/// otherwise.
CallExpr *findOuterCall(ResolvedCursorInfoPtr CursorInfo) {
  // Only non-implicit expressions count as contexts.
  auto IsExplicitExpr = [](ASTNode N) {
    auto *E = N.dyn_cast<Expr *>();
    return E != nullptr && !E->isImplicit();
  };

  // TODO: Bit pointless using the "ContextFinder" here. Ideally we would have
  // already generated a slice of the AST for anything that contains
  // the cursor location
  ContextFinder Finder(*CursorInfo->getSourceFile(), CursorInfo->getLoc(),
                       IsExplicitExpr);
  Finder.resolve();
  auto Contexts = Finder.getContexts();
  if (Contexts.empty())
    return nullptr;

  auto *Call = dyn_cast<CallExpr>(Contexts.front().get<Expr *>());
  if (!Call)
    return nullptr;

  SourceManager &SM = CursorInfo->getSourceFile()->getASTContext().SourceMgr;
  bool CursorOnCallee =
      SM.rangeContains(Call->getFn()->getSourceRange(), CursorInfo->getLoc());
  return CursorOnCallee ? Call : nullptr;
}
/// Find the function matching the given location if it is not an accessor and
/// either has a body or is a member of a protocol.
///
/// The cursor must lie between the start of the function and its opening
/// brace, or its end if there is no body. A cursor on a parameter resolves
/// to the parameter's enclosing declaration. Returns nullptr when no such
/// function exists.
FuncDecl *findFunction(ResolvedCursorInfoPtr CursorInfo) {
  // Only non-implicit declarations count as contexts.
  auto IncludeInContext = [](ASTNode N) {
    if (auto *D = N.dyn_cast<Decl *>())
      return !D->isImplicit();
    return false;
  };

  ContextFinder Finder(*CursorInfo->getSourceFile(), CursorInfo->getLoc(),
                       IncludeInContext);
  Finder.resolve();
  auto Contexts = Finder.getContexts();
  if (Contexts.empty())
    return nullptr;

  if (Contexts.back().isDecl(DeclKind::Param))
    Contexts = Contexts.drop_back();
  // Dropping the parameter can leave nothing behind; back() on an empty
  // ArrayRef is invalid, so bail out explicitly.
  if (Contexts.empty())
    return nullptr;

  auto *FD = dyn_cast_or_null<FuncDecl>(Contexts.back().get<Decl *>());
  if (!FD || isa<AccessorDecl>(FD))
    return nullptr;

  auto *Body = FD->getBody();
  if (!Body && !isa<ProtocolDecl>(FD->getDeclContext()))
    return nullptr;

  SourceManager &SM = CursorInfo->getSourceFile()->getASTContext().SourceMgr;
  SourceLoc DeclEnd = Body ? Body->getLBraceLoc() : FD->getEndLoc();
  if (!SM.rangeContains(SourceRange(FD->getStartLoc(), DeclEnd),
                        CursorInfo->getLoc()))
    return nullptr;

  return FD;
}
/// If the callee of the binary expression \p BE is an operator function,
/// returns that FuncDecl; otherwise returns nullptr.
FuncDecl *isOperator(const BinaryExpr *BE) {
  auto *Apply = dyn_cast<ApplyExpr>(BE->getFn());
  if (!Apply)
    return nullptr;

  auto *Callee = Apply->getCalledValue();
  if (!Callee || !Callee->isOperator())
    return nullptr;

  return dyn_cast<FuncDecl>(Callee);
}
/// Describes the expressions to be kept from a call to the handler in a
/// function that has (or will have) an async alternative. E.g.
/// ```
/// func toBeAsync(completion: (String?, Error?) -> Void) {
///   ...
///   completion("something", nil) // Result = ["something"], IsError = false
///   ...
///   completion(nil, MyError.Bad) // Result = [MyError.Bad], IsError = true
/// }
/// ```
class HandlerResult {
  /// The handler-call arguments that are kept.
  SmallVector<Argument, 2> Args;
  /// Set only through the (Argument, bool) constructor. When true, the single
  /// kept argument is the error passed to the handler.
  bool IsError = false;

public:
  HandlerResult() {}
  HandlerResult(ArrayRef<Argument> ArgsRef)
      : Args(ArgsRef.begin(), ArgsRef.end()) {}
  HandlerResult(Argument Arg, bool IsError) : IsError(IsError) {
    Args.push_back(Arg);
  }

  // Accessors are const so they can be used on const HandlerResults; this
  // stays compatible with existing non-const callers.
  bool isError() const { return IsError; }
  ArrayRef<Argument> args() const { return Args; }
};
/// The type of the handler, i.e. whether it takes regular parameters or a
/// single parameter of `Result` type.
/// - INVALID: not a usable completion handler (the default).
/// - PARAMS: the handler takes plain (non-`Result`) parameters, possibly none.
/// - RESULT: the handler takes exactly one `Result<Success, Failure>` parameter.
enum class HandlerType { INVALID, PARAMS, RESULT };
/// A single return type of a refactored async function. If the async function
/// returns a tuple, each element of the tuple (represented by a \c
/// LabeledReturnType) might have a label, otherwise the \p Label is empty.
struct LabeledReturnType {
Identifier Label;
swift::Type Ty;
LabeledReturnType(Identifier Label, swift::Type Ty) : Label(Label), Ty(Ty) {}
};
/// Given a function with an async alternative (or one that *could* have an
/// async alternative), stores information about the completion handler.
/// The completion handler can be either a variable (which includes a parameter)
/// or a function.
struct AsyncHandlerDesc {
  /// The completion handler declaration, or null if none was recognized.
  PointerUnion<const VarDecl *, const AbstractFunctionDecl *> Handler = nullptr;
  /// How the handler receives its results; \c INVALID when no usable handler
  /// was recognized.
  HandlerType Type = HandlerType::INVALID;
  /// Whether the handler can receive an error. For a \c PARAMS handler this
  /// means the last parameter is an optional whose wrapped type satisfies
  /// \c isErrorType; for a \c RESULT handler it means the `Result` failure
  /// type is not uninhabited.
  bool HasError = false;

  /// Classify \p Handler as a completion handler.
  ///
  /// \param Handler The candidate handler. Only \c VarDecl and
  ///   \c AbstractFunctionDecl are accepted.
  /// \param RequireName Whether the handler must also have a completion-like
  ///   name (as decided by \c isCompletionHandlerParamName).
  /// \returns A descriptor for the handler, or an invalid descriptor (see
  ///   \c isValid) if \p Handler is not a Void-returning, function-typed decl
  ///   with a supported parameter shape.
  static AsyncHandlerDesc get(const ValueDecl *Handler, bool RequireName) {
    AsyncHandlerDesc HandlerDesc;
    if (auto Var = dyn_cast<VarDecl>(Handler)) {
      HandlerDesc.Handler = Var;
    } else if (auto Func = dyn_cast<AbstractFunctionDecl>(Handler)) {
      HandlerDesc.Handler = Func;
    } else {
      // The handler must be a variable or function
      return AsyncHandlerDesc();
    }

    // Callback must have a completion-like name
    if (RequireName && !isCompletionHandlerParamName(HandlerDesc.getNameStr()))
      return AsyncHandlerDesc();

    // Callback must be a function type and return void. Doesn't need to have
    // any parameters - may just be a "I'm done" callback
    auto *HandlerTy = HandlerDesc.getType()->getAs<AnyFunctionType>();
    if (!HandlerTy || !HandlerTy->getResult()->isVoid())
      return AsyncHandlerDesc();

    // Find the type of result in the handler (eg. whether it's a Result<...>,
    // just parameters, or nothing).
    auto HandlerParams = HandlerTy->getParams();
    if (HandlerParams.size() == 1) {
      auto ParamTy =
          HandlerParams.back().getPlainType()->getAs<BoundGenericType>();
      if (ParamTy && ParamTy->isResult()) {
        auto GenericArgs = ParamTy->getGenericArgs();
        assert(GenericArgs.size() == 2 && "Result should have two params");
        HandlerDesc.Type = HandlerType::RESULT;
        // A Result whose failure type is uninhabited can never be a failure.
        HandlerDesc.HasError = !GenericArgs.back()->isUninhabited();
      }
    }

    if (HandlerDesc.Type != HandlerType::RESULT) {
      // Only handle non-result parameters
      for (auto &Param : HandlerParams) {
        if (Param.getPlainType() && Param.getPlainType()->isResult())
          return AsyncHandlerDesc();
      }

      HandlerDesc.Type = HandlerType::PARAMS;
      if (!HandlerParams.empty()) {
        // The handler accepts an error if its last parameter is an optional
        // error type.
        auto LastParamTy = HandlerParams.back().getParameterType();
        HandlerDesc.HasError = isErrorType(LastParamTy->getOptionalObjectType(),
                                           Handler->getModuleContext());
      }
    }

    return HandlerDesc;
  }

  /// Whether \c get recognized a usable completion handler.
  bool isValid() const { return Type != HandlerType::INVALID; }

  /// Return the declaration of the completion handler as a \c ValueDecl.
  /// In practice, the handler will always be a \c VarDecl or \c
  /// AbstractFunctionDecl.
  /// \c getNameStr and \c getType provide access functions that are available
  /// for both variables and functions, but not on \c ValueDecls.
  const ValueDecl *getHandler() const {
    if (!Handler) {
      return nullptr;
    }
    if (auto Var = Handler.dyn_cast<const VarDecl *>()) {
      return Var;
    } else if (auto Func = Handler.dyn_cast<const AbstractFunctionDecl *>()) {
      return Func;
    } else {
      llvm_unreachable("Unknown handler type");
    }
  }

  /// Return the name of the completion handler. If it is a variable, the
  /// variable name, if it's a function, the function base name.
  StringRef getNameStr() const {
    if (auto Var = Handler.dyn_cast<const VarDecl *>()) {
      return Var->getNameStr();
    } else if (auto Func = Handler.dyn_cast<const AbstractFunctionDecl *>()) {
      return Func->getNameStr();
    } else {
      llvm_unreachable("Unknown handler type");
    }
  }

  HandlerType getHandlerType() const { return Type; }

  /// Get the type of the completion handler. For a member function, the
  /// curried `Self` level of the interface type is stripped.
  swift::Type getType() const {
    if (auto Var = Handler.dyn_cast<const VarDecl *>()) {
      return Var->getType();
    } else if (auto Func = Handler.dyn_cast<const AbstractFunctionDecl *>()) {
      auto Type = Func->getInterfaceType();
      // Undo the self curry thunk if we are referencing a member function.
      if (Func->hasImplicitSelfDecl()) {
        assert(Type->is<AnyFunctionType>());
        Type = Type->getAs<AnyFunctionType>()->getResult();
      }
      return Type;
    } else {
      llvm_unreachable("Unknown handler type");
    }
  }

  /// The parameters of the completion handler's function type.
  ArrayRef<AnyFunctionType::Param> params() const {
    auto Ty = getType()->getAs<AnyFunctionType>();
    assert(Ty && "Type must be a function type");
    return Ty->getParams();
  }

  /// Retrieve the parameters relevant to a successful return from the
  /// completion handler. This drops the Error parameter if present.
  ArrayRef<AnyFunctionType::Param> getSuccessParams() const {
    if (HasError && Type == HandlerType::PARAMS)
      return params().drop_back();
    return params();
  }

  /// If the completion handler has an Error parameter, return it.
  Optional<AnyFunctionType::Param> getErrorParam() const {
    if (HasError && Type == HandlerType::PARAMS)
      return params().back();
    return None;
  }

  /// Get the type of the error that will be thrown by the \c async method or \c
  /// None if the completion handler doesn't accept an error parameter.
  /// This may be more specialized than the generic 'Error' type if the
  /// completion handler of the converted function takes a more specialized
  /// error type.
  Optional<swift::Type> getErrorType() const {
    if (!HasError)
      return None;

    switch (Type) {
    case HandlerType::INVALID:
      return None;
    case HandlerType::PARAMS:
      // The last parameter of the completion handler is the error param
      return params().back().getPlainType()->lookThroughSingleOptionalType();
    case HandlerType::RESULT: {
      assert(
          params().size() == 1 &&
          "Result handler should have the Result type as the only parameter");
      auto ResultType =
          params().back().getPlainType()->getAs<BoundGenericType>();
      auto GenericArgs = ResultType->getGenericArgs();
      assert(GenericArgs.size() == 2 && "Result should have two params");
      // The second (last) generic parameter of the Result type is the error
      // type.
      return GenericArgs.back();
    }
    }
    // Every case above returns; guard against falling off the end of a
    // non-void function if Type ever holds an unexpected value.
    llvm_unreachable("Unhandled HandlerType in switch!");
  }

  /// The `CallExpr` if the given node is a call to the `Handler`
  CallExpr *getAsHandlerCall(ASTNode Node) const {
    if (!isValid())
      return nullptr;

    if (auto E = Node.dyn_cast<Expr *>()) {
      if (auto *CE = dyn_cast<CallExpr>(E->getSemanticsProvidingExpr())) {
        if (CE->getFn()->getReferencedDecl().getDecl() == getHandler()) {
          return CE;
        }
      }
    }
    return nullptr;
  }

  /// Returns \c true if the call to the completion handler contains possibly
  /// non-nil values for both the success and error parameters, e.g.
  /// \code
  ///   completion(result, error)
  /// \endcode
  /// This can only happen if the completion handler is a params handler.
  bool isAmbiguousCallToParamHandler(const CallExpr *CE) const {
    if (!HasError || Type != HandlerType::PARAMS) {
      // Only param handlers with an error can pass both an error AND a result.
      return false;
    }
    auto Args = CE->getArgs()->getArgExprs();
    if (!isa<NilLiteralExpr>(Args.back())) {
      // We've got an error parameter. If any of the success params is not nil,
      // the call is ambiguous.
      for (auto &Arg : Args.drop_back()) {
        if (!isa<NilLiteralExpr>(Arg)) {
          return true;
        }
      }
    }
    return false;
  }

  /// Given a call to the `Handler`, extract the expressions to be returned or
  /// thrown, taking care to remove the `.success`/`.failure` if it's a
  /// `RESULT` handler type.
  /// If the call is ambiguous (contains potentially non-nil arguments to both
  /// the result and the error parameters), the \p ReturnErrorArgsIfAmbiguous
  /// determines whether the success or error parameters are passed.
  HandlerResult extractResultArgs(const CallExpr *CE,
                                  bool ReturnErrorArgsIfAmbiguous) const {
    auto *ArgList = CE->getArgs();
    SmallVector<Argument, 2> Scratch(ArgList->begin(), ArgList->end());
    auto Args = llvm::makeArrayRef(Scratch);

    if (Type == HandlerType::PARAMS) {
      bool IsErrorResult;
      if (isAmbiguousCallToParamHandler(CE)) {
        IsErrorResult = ReturnErrorArgsIfAmbiguous;
      } else {
        // If there's an error parameter and the user isn't passing nil to it,
        // assume this is the error path.
        IsErrorResult =
            (HasError && !isa<NilLiteralExpr>(Args.back().getExpr()));
      }
      if (IsErrorResult)
        return HandlerResult(Args.back(), true);

      // We can drop the args altogether if they're just Void.
      if (willAsyncReturnVoid())
        return HandlerResult();

      return HandlerResult(HasError ? Args.drop_back() : Args);
    } else if (Type == HandlerType::RESULT) {
      if (Args.size() != 1)
        return HandlerResult(Args);

      auto *ResultCE = dyn_cast<CallExpr>(Args[0].getExpr());
      if (!ResultCE)
        return HandlerResult(Args);

      auto *DSC = dyn_cast<DotSyntaxCallExpr>(ResultCE->getFn());
      if (!DSC)
        return HandlerResult(Args);

      auto *D = dyn_cast<EnumElementDecl>(
          DSC->getFn()->getReferencedDecl().getDecl());
      if (!D)
        return HandlerResult(Args);

      auto ResultArgList = ResultCE->getArgs();
      auto isFailure = D->getNameStr() == StringRef("failure");

      // We can drop the arg altogether if it's just Void.
      if (!isFailure && willAsyncReturnVoid())
        return HandlerResult();

      // Otherwise the arg gets the .success() or .failure() call dropped.
      return HandlerResult(ResultArgList->get(0), isFailure);
    }

    llvm_unreachable("Unhandled result type");
  }

  // Convert the type of a success parameter in the completion handler function
  // to a return type suitable for an async function. If there is an error
  // parameter present e.g (T?, Error?) -> Void, this unwraps a level of
  // optionality from T?. If this is a Result<T, U> type, returns the success
  // type T.
  swift::Type getSuccessParamAsyncReturnType(swift::Type Ty) const {
    switch (Type) {
    case HandlerType::PARAMS: {
      // If there's an Error parameter in the handler, the success branch can
      // be unwrapped.
      if (HasError)
        Ty = Ty->lookThroughSingleOptionalType();

      return Ty;
    }
    case HandlerType::RESULT: {
      // Result<T, U> maps to T.
      return Ty->castTo<BoundGenericType>()->getGenericArgs()[0];
    }
    case HandlerType::INVALID:
      llvm_unreachable("Invalid handler type");
    }
    // Every case above returns or traps; guard against falling off the end
    // of a non-void function if Type ever holds an unexpected value.
    llvm_unreachable("Unhandled HandlerType in switch!");
  }

  /// If the async function returns a tuple, the label of the \p Index -th
  /// element in the returned tuple. If the function doesn't return a tuple or
  /// the element is unlabeled, an empty identifier is returned.
  Identifier getAsyncReturnTypeLabel(size_t Index) const {
    assert(Index < getSuccessParams().size());
    if (getSuccessParams().size() <= 1) {
      // There can't be any labels if the async function doesn't return a tuple.
      return Identifier();
    } else {
      return getSuccessParams()[Index].getInternalLabel();
    }
  }

  /// Gets the return value types for the async equivalent of this handler.
  /// The types are appended to \p Scratch, which is then returned.
  ArrayRef<LabeledReturnType>
  getAsyncReturnTypes(SmallVectorImpl<LabeledReturnType> &Scratch) const {
    // Compute the success params once rather than on every iteration.
    auto SuccessParams = getSuccessParams();
    for (size_t I = 0, E = SuccessParams.size(); I != E; ++I) {
      auto Ty = SuccessParams[I].getParameterType();
      Scratch.emplace_back(getAsyncReturnTypeLabel(I),
                           getSuccessParamAsyncReturnType(Ty));
    }
    return Scratch;
  }

  /// Whether the async equivalent of this handler returns Void.
  bool willAsyncReturnVoid() const {
    // If all of the success params will be converted to Void return types,
    // this will be a Void async function.
    return llvm::all_of(getSuccessParams(), [&](auto &Param) {
      auto Ty = Param.getParameterType();
      return getSuccessParamAsyncReturnType(Ty)->isVoid();
    });
  }

  // TODO: If we have an async alternative we should check its result types
  // for whether to unwrap or not
  bool shouldUnwrap(swift::Type Ty) const {
    return HasError && Ty->isOptional();
  }
};
/// Given a completion handler that is part of a function signature, stores
/// information about that completion handler and its index within the function
/// declaration.
struct AsyncHandlerParamDesc : public AsyncHandlerDesc {
/// Enum to represent the position of the completion handler param within
/// the parameter list. Given `(A, B, C, D)`:
/// - A is `First`
/// - B and C are `Middle`
/// - D is `Last`
/// The position is `Only` if there's a single parameter that is the
/// completion handler and `None` if there is no handler.
enum class Position {
First, Middle, Last, Only, None
};
/// The function the completion handler is a parameter of.
const FuncDecl *Func = nullptr;
/// The index of the completion handler in the function that declares it.
unsigned Index = 0;
/// The async alternative, if one is found.
const AbstractFunctionDecl *Alternative = nullptr;
AsyncHandlerParamDesc() : AsyncHandlerDesc() {}
AsyncHandlerParamDesc(const AsyncHandlerDesc &Handler, const FuncDecl *Func,
unsigned Index,
const AbstractFunctionDecl *Alternative)
: AsyncHandlerDesc(Handler), Func(Func), Index(Index),
Alternative(Alternative) {}
static AsyncHandlerParamDesc find(const FuncDecl *FD,
bool RequireAttributeOrName) {
if (!FD || FD->hasAsync() || FD->hasThrows() ||
!FD->getResultInterfaceType()->isVoid())
return AsyncHandlerParamDesc();
const auto *Alternative = FD->getAsyncAlternative();
Optional<unsigned> Index =
FD->findPotentialCompletionHandlerParam(Alternative);
if (!Index)
return AsyncHandlerParamDesc();
bool RequireName = RequireAttributeOrName && !Alternative;
return AsyncHandlerParamDesc(
AsyncHandlerDesc::get(FD->getParameters()->get(*Index), RequireName),
FD, *Index, Alternative);
}
/// Build an @available attribute with the name of the async alternative as
/// the \c renamed argument, followed by a newline.
SmallString<128> buildRenamedAttribute() const {
SmallString<128> AvailabilityAttr;
llvm::raw_svector_ostream OS(AvailabilityAttr);
// If there's an alternative then there must already be an attribute,
// don't add another.
if (!isValid() || Alternative)
return AvailabilityAttr;
DeclName Name = Func->getName();
OS << "@available(*, renamed: \"" << Name.getBaseName() << "(";
ArrayRef<Identifier> ArgNames = Name.getArgumentNames();
for (size_t I = 0; I < ArgNames.size(); ++I) {
if (I != Index) {
OS << ArgNames[I] << tok::colon;
}
}
OS << ")\")\n";
return AvailabilityAttr;
}
/// Retrieves the parameter decl for the completion handler parameter, or
/// \c nullptr if no valid completion parameter is present.
const ParamDecl *getHandlerParam() const {
if (!isValid())
return nullptr;
return cast<ParamDecl>(getHandler());
}
/// See \c Position
Position handlerParamPosition() const {
if (!isValid())
return Position::None;
const auto *Params = Func->getParameters();
if (Params->size() == 1)
return Position::Only;
if (Index == 0)
return Position::First;
if (Index == Params->size() - 1)
return Position::Last;
return Position::Middle;
}
bool operator==(const AsyncHandlerParamDesc &Other) const {
return Handler == Other.Handler && Type == Other.Type &&
HasError == Other.HasError && Index == Other.Index;
}
bool alternativeIsAccessor() const {
return isa_and_nonnull<AccessorDecl>(Alternative);
}
};
/// The kind of check performed by a condition in a conditional statement.
enum class ConditionType {
  /// `== nil`
  NIL,
  /// `!= nil`
  NOT_NIL,
  /// `if b`
  IS_TRUE,
  /// `if !b`
  IS_FALSE,
  /// `case .success`
  SUCCESS_PATTERN,
  /// `case .failure`
  /// NOTE: spelled `PATTEN`; renaming requires updating every use.
  FAILURE_PATTEN
};
/// Whether a condition selects the success path or the failure path. For
/// example, checking that an error parameter is present selects the failure
/// path, while checking that it is nil selects the success path. Unlike
/// \c ConditionType, this depends on contextual knowledge of which values
/// signal success and which signal failure.
enum class ConditionPath {
  SUCCESS,
  FAILURE
};
/// Returns the opposite of \p Path: failure for success, success for failure.
static ConditionPath flippedConditionPath(ConditionPath Path) {
  switch (Path) {
  case ConditionPath::FAILURE:
    return ConditionPath::SUCCESS;
  case ConditionPath::SUCCESS:
    return ConditionPath::FAILURE;
  }
  llvm_unreachable("Unhandled case in switch!");
}
/// Finds the `Subject` being compared to in various conditions. Also finds any
/// pattern that may have a bound name.
struct CallbackCondition {
  /// The recognized kind of check, or \c None if the condition wasn't
  /// recognized (see \c isValid).
  Optional<ConditionType> Type;
  /// The decl being checked, or null if none was found.
  const Decl *Subject = nullptr;
  /// The pattern that binds the checked value, if the condition has one.
  const Pattern *BindPattern = nullptr;

  /// Initializes a `CallbackCondition` with a `!=` or `==` comparison of
  /// an `Optional` typed `Subject` to `nil`, or a `Bool` typed `Subject` to a
  /// boolean literal, ie.
  ///  - `<Subject> != nil`
  ///  - `<Subject> == nil`
  ///  - `<Subject> != true`
  ///  - `<Subject> == false`
  /// If both operands are decl references, the right-hand one becomes the
  /// `Subject`. Unrecognized comparisons leave \c Type as \c None.
  CallbackCondition(const BinaryExpr *BE, const FuncDecl *Operator) {
    bool FoundNil = false;
    BooleanLiteralExpr *FoundBool = nullptr;
    bool DidUnwrapOptional = false;

    for (auto *Operand : {BE->getLHS(), BE->getRHS()}) {
      Operand = Operand->getSemanticsProvidingExpr();
      if (auto *IIOE = dyn_cast<InjectIntoOptionalExpr>(Operand)) {
        Operand = IIOE->getSubExpr()->getSemanticsProvidingExpr();
        DidUnwrapOptional = true;
      }
      if (isa<NilLiteralExpr>(Operand)) {
        FoundNil = true;
      } else if (auto *BLE = dyn_cast<BooleanLiteralExpr>(Operand)) {
        FoundBool = BLE;
      } else if (auto *DRE = dyn_cast<DeclRefExpr>(Operand)) {
        Subject = DRE->getDecl();
      }
    }

    if (!Subject)
      return;

    if (FoundNil) {
      if (Operator->getBaseName() == "==") {
        Type = ConditionType::NIL;
      } else if (Operator->getBaseName() == "!=") {
        Type = ConditionType::NOT_NIL;
      }
    } else if (FoundBool) {
      if (Operator->getBaseName() == "==") {
        Type = FoundBool->getValue() ? ConditionType::IS_TRUE
                                     : ConditionType::IS_FALSE;
      } else if (Operator->getBaseName() == "!=" && !DidUnwrapOptional) {
        // Note that we don't consider this case if we unwrapped an optional,
        // as e.g optBool != false is a check for true *or* nil.
        Type = FoundBool->getValue() ? ConditionType::IS_FALSE
                                     : ConditionType::IS_TRUE;
      }
    }
  }

  /// Initializes from a `Bool` typed condition that directly references
  /// `Subject`, optionally negated with a prefix `!`, ie.
  ///  - `<Subject>`
  ///  - `!<Subject>`
  explicit CallbackCondition(const Expr *E) {
    // FIXME: Sema should produce ErrorType.
    if (!E->getType() || !E->getType()->isBool())
      return;

    auto CondType = ConditionType::IS_TRUE;
    E = E->getSemanticsProvidingExpr();

    // If we have a prefix negation operator, this is a check for false.
    if (auto *PrefixOp = dyn_cast<PrefixUnaryExpr>(E)) {
      auto *Callee = PrefixOp->getCalledValue();
      if (Callee && Callee->isOperator() && Callee->getBaseName() == "!") {
        CondType = ConditionType::IS_FALSE;
        E = PrefixOp->getOperand()->getSemanticsProvidingExpr();
      }
    }

    auto *DRE = dyn_cast<DeclRefExpr>(E);
    if (!DRE)
      return;

    Subject = DRE->getDecl();
    Type = CondType;
  }

  /// Initializes a `CallbackCondition` with binding of an `Optional` or
  /// `Result` typed `Subject`, ie.
  ///  - `let bind = <Subject>`
  ///  - `case .success(let bind) = <Subject>`
  ///  - `case .failure(let bind) = <Subject>`
  ///  - `let bind = try? <Subject>.get()`
  CallbackCondition(const Pattern *P, const Expr *Init) {
    Init = Init->getSemanticsProvidingExpr();
    P = P->getSemanticsProvidingPattern();

    if (auto *DRE = dyn_cast<DeclRefExpr>(Init)) {
      if (auto *OSP = dyn_cast<OptionalSomePattern>(P)) {
        // `let bind = <Subject>`
        Type = ConditionType::NOT_NIL;
        Subject = DRE->getDecl();
        BindPattern = OSP->getSubPattern();
      } else if (auto *EEP = dyn_cast<EnumElementPattern>(P)) {
        // `case .<func>(let <bind>) = <Subject>`
        initFromEnumPattern(DRE->getDecl(), EEP);
      }
    } else if (auto *OTE = dyn_cast<OptionalTryExpr>(Init)) {
      // `let bind = try? <Subject>.get()`
      if (auto *OSP = dyn_cast<OptionalSomePattern>(P))
        initFromOptionalTry(OSP->getSubPattern(), OTE);
    }
  }

  /// Initializes a `CallbackCondition` from a case statement inside a switch
  /// on `Subject` with `Result` type, ie.
  /// ```
  /// switch <Subject> {
  /// case .success(let bind):
  /// case .failure(let bind):
  /// }
  /// ```
  CallbackCondition(const Decl *Subject, const CaseLabelItem *CaseItem) {
    if (auto *EEP = dyn_cast<EnumElementPattern>(
            CaseItem->getPattern()->getSemanticsProvidingPattern())) {
      // `case .<func>(let <bind>)`
      initFromEnumPattern(Subject, EEP);
    }
  }

  /// Whether a condition kind was recognized.
  bool isValid() const { return Type.has_value(); }

private:
  /// Records a `.success`/`.failure` pattern on a `Result` typed \p D. Enum
  /// patterns of any other enum are ignored.
  void initFromEnumPattern(const Decl *D, const EnumElementPattern *EEP) {
    if (auto *EED = EEP->getElementDecl()) {
      auto eedTy = EED->getParentEnum()->getDeclaredType();
      if (!eedTy || !eedTy->isResult())
        return;
      if (EED->getNameStr() == StringRef("failure")) {
        Type = ConditionType::FAILURE_PATTEN;
      } else {
        Type = ConditionType::SUCCESS_PATTERN;
      }
      Subject = D;
      BindPattern = EEP->getSubPattern();
    }
  }

  /// Recognizes `try? <Subject>.get()` where `<Subject>` is a direct reference
  /// to a `Result` typed decl, binding its success value to \p P.
  void initFromOptionalTry(const class Pattern *P, const OptionalTryExpr *OTE) {
    auto *ICE = dyn_cast<ImplicitConversionExpr>(OTE->getSubExpr());
    if (!ICE)
      return;
    auto *CE = dyn_cast<CallExpr>(ICE->getSyntacticSubExpr());
    if (!CE)
      return;
    auto *DSC = dyn_cast<DotSyntaxCallExpr>(CE->getFn());
    if (!DSC)
      return;

    // The base may be something other than a plain decl reference (e.g. a
    // member access such as `a.b.get()`), in which case dyn_cast yields null.
    auto *BaseDRE = dyn_cast<DeclRefExpr>(DSC->getBase());
    if (!BaseDRE || !BaseDRE->getType() || !BaseDRE->getType()->isResult())
      return;

    auto *FnDRE = dyn_cast<DeclRefExpr>(DSC->getFn());
    if (!FnDRE)
      return;
    auto *FD = dyn_cast<FuncDecl>(FnDRE->getDecl());
    if (!FD || FD->getNameStr() != StringRef("get"))
      return;

    Type = ConditionType::NOT_NIL;
    Subject = BaseDRE->getDecl();
    BindPattern = P;
  }
};
/// A \c CallbackCondition annotated with whether it guards the success path
/// or the failure path.
struct ClassifiedCondition : public CallbackCondition {
  /// The path selected by this condition.
  ConditionPath Path;

  /// Whether this represents an Obj-C style boolean flag check for success.
  bool IsObjCStyleFlagCheck;

  explicit ClassifiedCondition(CallbackCondition Cond, ConditionPath CondPath,
                               bool IsFlagCheck)
      : CallbackCondition(Cond), Path(CondPath),
        IsObjCStyleFlagCheck(IsFlagCheck) {}
};
/// Maps parameter decls to their classified conditions. \c lookup yields
/// \c None for decls that appear in no condition.
struct ClassifiedCallbackConditions final
    : llvm::MapVector<const Decl *, ClassifiedCondition> {
  /// The classified condition recorded for \p D, or \c None if there is none.
  Optional<ClassifiedCondition> lookup(const Decl *D) const {
    auto It = find(D);
    if (It != end())
      return It->second;
    return None;
  }
};
/// A list of nodes to print, along with a list of locations that may have
/// preceding comments attached, which also need printing. For example:
///
/// \code
/// if .random() {
/// // a
/// print("hello")
/// // b
/// }
/// \endcode
///
/// To print out the contents of the if statement body, we'll include the AST
/// node for the \c print call. This will also include the preceding comment
/// \c a, but won't include the comment \c b. To ensure the comment \c b gets
/// printed, the SourceLoc for the closing brace \c } is added as a possible
/// comment loc.
class NodesToPrint {
SmallVector<ASTNode, 0> Nodes;
SmallVector<SourceLoc, 2> PossibleCommentLocs;
public:
NodesToPrint() {}
NodesToPrint(ArrayRef<ASTNode> Nodes, ArrayRef<SourceLoc> PossibleCommentLocs)
: Nodes(Nodes.begin(), Nodes.end()),
PossibleCommentLocs(PossibleCommentLocs.begin(),
PossibleCommentLocs.end()) {}
ArrayRef<ASTNode> getNodes() const { return Nodes; }
ArrayRef<SourceLoc> getPossibleCommentLocs() const {
return PossibleCommentLocs;
}
/// Add an AST node to print.
void addNode(ASTNode Node) {
// Note we skip vars as they'll be printed as a part of their
// PatternBindingDecl.
if (!Node.isDecl(DeclKind::Var))
Nodes.push_back(Node);
}
/// Add a SourceLoc which may have a preceding comment attached. If so, the
/// comment will be printed out at the appropriate location.
void addPossibleCommentLoc(SourceLoc Loc) {
if (Loc.isValid())
PossibleCommentLocs.push_back(Loc);
}
/// Add all the nodes in the brace statement to the list of nodes to print.
/// This should be preferred over adding the nodes manually as it picks up the
/// end location of the brace statement as a possible comment loc, ensuring
/// that we print any trailing comments in the brace statement.
void addNodesInBraceStmt(BraceStmt *Brace) {
for (auto Node : Brace->getElements())
addNode(Node);
// Ignore the end locations of implicit braces, as they're likely bogus.
// e.g for a case statement, the r-brace loc points to the last token of the
// last node in the body.
if (!Brace->isImplicit())
addPossibleCommentLoc(Brace->getRBraceLoc());
}
/// Add the nodes and comment locs from another NodesToPrint.
void addNodes(NodesToPrint OtherNodes) {
Nodes.append(OtherNodes.Nodes.begin(), OtherNodes.Nodes.end());
PossibleCommentLocs.append(OtherNodes.PossibleCommentLocs.begin(),
OtherNodes.PossibleCommentLocs.end());
}
/// Whether the last recorded node is an explicit return or break statement.
bool hasTrailingReturnOrBreak() const {
if (Nodes.empty())
return false;
return (Nodes.back().isStmt(StmtKind::Return) ||
Nodes.back().isStmt(StmtKind::Break)) &&
!Nodes.back().isImplicit();
}
/// If the last recorded node is an explicit return or break statement that
/// can be safely dropped, drop it from the list.
void dropTrailingReturnOrBreakIfPossible() {
if (!hasTrailingReturnOrBreak())
return;
auto *Node = Nodes.back().get<Stmt *>();
// If this is a return statement with return expression, let's preserve it.
if (auto *RS = dyn_cast<ReturnStmt>(Node)) {
if (RS->hasResult())
return;
}
// Remove the node from the list, but make sure to add it as a possible
// comment loc to preserve any of its attached comments.
Nodes.pop_back();
addPossibleCommentLoc(Node->getStartLoc());
}
/// Returns a list of nodes to print in a brace statement. This picks up the
/// end location of the brace statement as a possible comment loc, ensuring
/// that we print any trailing comments in the brace statement.
static NodesToPrint inBraceStmt(BraceStmt *stmt) {
NodesToPrint Nodes;
Nodes.addNodesInBraceStmt(stmt);
return Nodes;
}
};
/// The statements within the closure of call to a function taking a callback
/// are split into a `SuccessBlock` and `ErrorBlock` (`ClassifiedBlocks`).
/// This class stores the nodes for each block, as well as a mapping of
/// decls to any patterns they are used in.
class ClassifiedBlock {
  /// The nodes (and possible comment locations) to print for this block.
  NodesToPrint Nodes;

  // A mapping of closure params to a list of patterns that bind them.
  using ParamPatternBindingsMap =
      llvm::MapVector<const Decl *, TinyPtrVector<const Pattern *>>;
  ParamPatternBindingsMap ParamPatternBindings;

public:
  const NodesToPrint &nodesToPrint() const { return Nodes; }

  /// Attempt to retrieve an existing bound name for a closure parameter, or
  /// an empty string if there's no suitable existing binding.
  StringRef boundName(const Decl *D) const {
    // Adopt the same name as the representative single pattern, if it only
    // binds a single var.
    if (auto *P = getSinglePatternFor(D)) {
      if (P->getSingleVar())
        return P->getBoundName().str();
    }
    return StringRef();
  }

  /// Checks whether a closure parameter can be represented by a single pattern
  /// that binds it. If the param is only bound by a single pattern, that will
  /// be returned. If there's a pattern with a single var that binds it, that
  /// will be returned, preferring a 'let' pattern to prefer out of line
  /// printing of 'var' patterns.
  const Pattern *getSinglePatternFor(const Decl *D) const {
    auto Iter = ParamPatternBindings.find(D);
    if (Iter == ParamPatternBindings.end())
      return nullptr;

    const auto &Patterns = Iter->second;
    if (Patterns.empty())
      return nullptr;
    if (Patterns.size() == 1)
      return Patterns[0];

    // If we have multiple patterns, search for the best single var pattern to
    // use, preferring a 'let' binding.
    const Pattern *FirstSingleVar = nullptr;
    for (auto *P : Patterns) {
      if (!P->getSingleVar())
        continue;

      if (!P->hasAnyMutableBindings())
        return P;

      if (!FirstSingleVar)
        FirstSingleVar = P;
    }
    return FirstSingleVar;
  }

  /// Retrieve any bound vars that are effectively aliases of a given closure
  /// parameter.
  llvm::SmallDenseSet<const Decl *> getAliasesFor(const Decl *D) const {
    auto Iter = ParamPatternBindings.find(D);
    if (Iter == ParamPatternBindings.end())
      return {};

    llvm::SmallDenseSet<const Decl *> Aliases;

    // The single pattern that we replace the decl with is always an alias.
    if (auto *P = getSinglePatternFor(D)) {
      if (auto *SingleVar = P->getSingleVar())
        Aliases.insert(SingleVar);
    }

    // Any other let bindings we have are also aliases.
    for (auto *P : Iter->second) {
      if (auto *SingleVar = P->getSingleVar()) {
        if (!P->hasAnyMutableBindings())
          Aliases.insert(SingleVar);
      }
    }
    return Aliases;
  }

  /// All recorded bindings, keyed by the closure parameter they bind.
  const ParamPatternBindingsMap &
  paramPatternBindings() const {
    return ParamPatternBindings;
  }

  void addNodesInBraceStmt(BraceStmt *Brace) {
    Nodes.addNodesInBraceStmt(Brace);
  }
  void addPossibleCommentLoc(SourceLoc Loc) {
    Nodes.addPossibleCommentLoc(Loc);
  }
  void addAllNodes(NodesToPrint OtherNodes) {
    Nodes.addNodes(std::move(OtherNodes));
  }

  void addNode(ASTNode Node) {
    Nodes.addNode(Node);
  }

  /// Record the pattern bound by \p FromCondition against its `Subject`.
  /// Conditions without a pattern, or whose pattern binds no variables, are
  /// ignored.
  void addBinding(const ClassifiedCondition &FromCondition) {
    auto *P = FromCondition.BindPattern;
    if (!P)
      return;

    // Patterns that don't bind anything aren't interesting.
    SmallVector<VarDecl *, 2> Vars;
    P->collectVariables(Vars);
    if (Vars.empty())
      return;

    ParamPatternBindings[FromCondition.Subject].push_back(P);
  }

  /// Record the bindings of every condition in \p FromConditions.
  void addAllBindings(const ClassifiedCallbackConditions &FromConditions) {
    for (auto &Entry : FromConditions)
      addBinding(Entry.second);
  }
};
/// The kind of block that rewritten code may be placed in.
enum class BlockKind {
  /// The success block.
  SUCCESS,
  /// The error block.
  ERROR,
  /// The fallback block (ClosureCallbackParams::hasBinding binds every
  /// closure parameter in it).
  FALLBACK
};
/// A Bool parameter of a completion handler that is known to signal whether
/// the operation succeeded or failed. Kept as a plain aggregate so it can be
/// brace-initialized.
struct KnownBoolFlagParam {
  /// The flag parameter.
  const ParamDecl *Param;
  /// Taken from the foreign async convention's error-on-zero polarity;
  /// presumably true means a non-zero flag indicates success.
  bool IsSuccessFlag;
};
/// The parameters of a closure passed as a completion callback. They are split
/// into success parameters, an optional error parameter, and an optional known
/// Bool flag parameter.
class ClosureCallbackParams final {
  const AsyncHandlerParamDesc &HandlerDesc;
  ArrayRef<const ParamDecl *> AllParams;
  llvm::SetVector<const ParamDecl *> SuccessParams;
  const ParamDecl *ErrParam = nullptr;
  Optional<KnownBoolFlagParam> BoolFlagParam;

public:
  ClosureCallbackParams(const AsyncHandlerParamDesc &HandlerDesc,
                        const ClosureExpr *Closure)
      : HandlerDesc(HandlerDesc),
        AllParams(Closure->getParameters()->getArray()) {
    assert(AllParams.size() == HandlerDesc.params().size());
    assert(HandlerDesc.Type != HandlerType::RESULT || AllParams.size() == 1);

    // Every param starts out as a success param. A params-style handler that
    // reports errors has its trailing param moved over to ErrParam.
    SuccessParams.insert(AllParams.begin(), AllParams.end());
    const bool HasTrailingErrParam =
        HandlerDesc.HasError && HandlerDesc.Type == HandlerType::PARAMS;
    if (HasTrailingErrParam)
      ErrParam = SuccessParams.pop_back_val();

    // Look for a known bool flag parameter via the foreign async convention
    // of the async alternative.
    auto *AsyncAlt = HandlerDesc.Func->getAsyncAlternative();
    if (!AsyncAlt)
      return;
    auto Conv = AsyncAlt->getForeignAsyncConvention();
    if (!Conv)
      return;
    auto FlagIdx = Conv->completionHandlerFlagParamIndex();
    if (!FlagIdx || *FlagIdx < 0 || *FlagIdx >= AllParams.size())
      return;
    const auto IsSuccessFlag = Conv->completionHandlerFlagIsErrorOnZero();
    BoolFlagParam = {AllParams[*FlagIdx], IsSuccessFlag};
  }

  /// Whether the closure has a particular parameter.
  bool hasParam(const ParamDecl *Param) const {
    return SuccessParams.contains(Param) || Param == ErrParam;
  }

  /// Whether \p Param is a success param.
  bool isSuccessParam(const ParamDecl *Param) const {
    return SuccessParams.contains(Param);
  }

  /// Whether \p Param is a closure parameter that may be unwrapped. This
  /// includes optional parameters as well as \c Result parameters that may be
  /// unwrapped through e.g 'try? res.get()'.
  bool isUnwrappableParam(const ParamDecl *Param) const {
    if (!hasParam(Param))
      return false;
    return getResultParam() == Param ||
           HandlerDesc.shouldUnwrap(Param->getType());
  }

  /// Whether \p Param is the known Bool parameter that indicates success or
  /// failure.
  bool isKnownBoolFlagParam(const ParamDecl *Param) const {
    auto BoolFlag = getKnownBoolFlagParam();
    return BoolFlag && BoolFlag->Param == Param;
  }

  /// Whether \p Param is a closure parameter that has a binding available in
  /// the async variant of the call for a particular \p Block.
  bool hasBinding(const ParamDecl *Param, BlockKind Block) const {
    switch (Block) {
    case BlockKind::ERROR:
      return Param == ErrParam;
    case BlockKind::FALLBACK:
      // We generally want to bind everything in the fallback case.
      return hasParam(Param);
    case BlockKind::SUCCESS:
      // Known bool flags get dropped from the imported async variant.
      return !isKnownBoolFlagParam(Param) && isSuccessParam(Param);
    }
    llvm_unreachable("Unhandled case in switch");
  }

  /// Retrieve the parameters to bind in a given \p Block, in closure order.
  TinyPtrVector<const ParamDecl *> getParamsToBind(BlockKind Block) {
    TinyPtrVector<const ParamDecl *> ToBind;
    for (const ParamDecl *Candidate : AllParams)
      if (hasBinding(Candidate, Block))
        ToBind.push_back(Candidate);
    return ToBind;
  }

  /// If there is a known Bool flag parameter indicating success or failure,
  /// returns it, \c None otherwise.
  Optional<KnownBoolFlagParam> getKnownBoolFlagParam() const {
    return BoolFlagParam;
  }

  /// All the parameters of the closure passed as the completion handler.
  ArrayRef<const ParamDecl *> getAllParams() const { return AllParams; }

  /// The success parameters of the closure passed as the completion handler.
  /// Note this includes a \c Result parameter.
  ArrayRef<const ParamDecl *> getSuccessParams() const {
    return SuccessParams.getArrayRef();
  }

  /// The error parameter of the closure passed as the completion handler, or
  /// \c nullptr if there is no error parameter.
  const ParamDecl *getErrParam() const { return ErrParam; }

  /// If the closure has a single \c Result parameter, returns it, \c nullptr
  /// otherwise.
  const ParamDecl *getResultParam() const {
    if (HandlerDesc.Type != HandlerType::RESULT)
      return nullptr;
    return SuccessParams[0];
  }
};
/// Whether \p S starts a new scope. Most statements are covered by the
/// \c BraceStmt check. The other kinds listed are special cases because
/// they can also declare variables in their condition.
static bool startsNewScope(Stmt *S) {
  const StmtKind Kind = S->getKind();
  return Kind == StmtKind::Brace || Kind == StmtKind::If ||
         Kind == StmtKind::While || Kind == StmtKind::ForEach ||
         Kind == StmtKind::Case;
}
/// The success and error blocks produced by classifying the body of a
/// completion callback closure.
struct ClassifiedBlocks {
  /// Nodes and bindings for the success path.
  ClassifiedBlock SuccessBlock;
  /// Nodes and bindings for the error path.
  ClassifiedBlock ErrorBlock;
};
/// Classifier of the statements in callback closures that have either multiple
/// non-Result parameters or a single Result parameter and return Void.
///
/// It performs a (possibly incorrect) best effort and may give up in certain
/// cases. Aims to cover the idiomatic cases of either having no error
/// parameter at all, or having success/error code wrapped in ifs/guards/switch
/// using either pattern binding or nil checks.
///
/// Code outside any clear conditions is assumed to be solely part of the
/// success block for now, though some heuristics could be added to classify
/// these better in the future.
struct CallbackClassifier {
/// Updates the success and error block of `Blocks` with nodes and bound
/// names from `Body`. Errors are added through `DiagEngine`, possibly
/// resulting in partially filled out blocks.
static void classifyInto(ClassifiedBlocks &Blocks,
const ClosureCallbackParams &Params,
llvm::DenseSet<SwitchStmt *> &HandledSwitches,
DiagnosticEngine &DiagEngine, BraceStmt *Body) {
assert(!Body->getElements().empty() && "Cannot classify empty body");
CallbackClassifier Classifier(Blocks, Params, HandledSwitches, DiagEngine);
Classifier.classifyNodes(Body->getElements(), Body->getRBraceLoc());
}
private:
ClassifiedBlocks &Blocks;
const ClosureCallbackParams &Params;
llvm::DenseSet<SwitchStmt *> &HandledSwitches;
DiagnosticEngine &DiagEngine;
ClassifiedBlock *CurrentBlock;
/// This is set to \c true if we're currently classifying on a known condition
/// path, where \c CurrentBlock is set to the appropriate block. This lets us
/// be more lenient with unhandled conditions as we already know the block
/// we're supposed to be in.
bool IsKnownConditionPath = false;
CallbackClassifier(ClassifiedBlocks &Blocks,
const ClosureCallbackParams &Params,
llvm::DenseSet<SwitchStmt *> &HandledSwitches,
DiagnosticEngine &DiagEngine)
: Blocks(Blocks), Params(Params), HandledSwitches(HandledSwitches),
DiagEngine(DiagEngine), CurrentBlock(&Blocks.SuccessBlock) {}
/// Attempt to apply custom classification logic to a given node, returning
/// \c true if the node was classified, otherwise \c false.
bool tryClassifyNode(ASTNode Node) {
auto *Statement = Node.dyn_cast<Stmt *>();
if (!Statement)
return false;
if (auto *IS = dyn_cast<IfStmt>(Statement)) {
NodesToPrint TempNodes;
if (auto *BS = dyn_cast<BraceStmt>(IS->getThenStmt())) {
TempNodes = NodesToPrint::inBraceStmt(BS);
} else {
TempNodes = NodesToPrint({IS->getThenStmt()}, /*commentLocs*/ {});
}
classifyConditional(IS, IS->getCond(), std::move(TempNodes),
IS->getElseStmt());
return true;
} else if (auto *GS = dyn_cast<GuardStmt>(Statement)) {
classifyConditional(GS, GS->getCond(), NodesToPrint(), GS->getBody());
return true;
} else if (auto *SS = dyn_cast<SwitchStmt>(Statement)) {
classifySwitch(SS);
return true;
} else if (auto *RS = dyn_cast<ReturnStmt>(Statement)) {
// We can look through an implicit Void return of a SingleValueStmtExpr,
// as that's semantically a statement.
if (RS->hasResult() && RS->isImplicit()) {
auto Ty = RS->getResult()->getType();
if (Ty && Ty->isVoid()) {
if (auto *SVE = dyn_cast<SingleValueStmtExpr>(RS->getResult()))
return tryClassifyNode(SVE->getStmt());
}
}
}
return false;
}
/// Classify a node, or add the node to the block if it cannot be classified.
/// Returns \c true if there was an error.
bool classifyNode(ASTNode Node) {
auto DidClassify = tryClassifyNode(Node);
if (!DidClassify)
CurrentBlock->addNode(Node);
return DiagEngine.hadAnyError();
}
void classifyNodes(ArrayRef<ASTNode> Nodes, SourceLoc EndCommentLoc) {
for (auto Node : Nodes) {
auto HadError = classifyNode(Node);
if (HadError)
return;
}
// Make sure to pick up any trailing comments.
CurrentBlock->addPossibleCommentLoc(EndCommentLoc);
}
/// Whether any of the provided ASTNodes have a child expression that force
/// unwraps the error parameter. Note that this doesn't walk into new scopes.
bool hasForceUnwrappedErrorParam(ArrayRef<ASTNode> Nodes) {
auto *ErrParam = Params.getErrParam();
if (!ErrParam)
return false;
class ErrUnwrapFinder : public ASTWalker {
const ParamDecl *ErrParam;
bool FoundUnwrap = false;
public:
explicit ErrUnwrapFinder(const ParamDecl *ErrParam)
: ErrParam(ErrParam) {}
bool foundUnwrap() const { return FoundUnwrap; }
MacroWalking getMacroWalkingBehavior() const override {
return MacroWalking::Arguments;
}
PreWalkResult<Expr *> walkToExprPre(Expr *E) override {
// Don't walk into ternary conditionals as they may have additional
// conditions such as err != nil that make a force unwrap now valid.
if (isa<TernaryExpr>(E))
return Action::SkipChildren(E);
auto *FVE = dyn_cast<ForceValueExpr>(E);
if (!FVE)
return Action::Continue(E);
auto *DRE = dyn_cast<DeclRefExpr>(FVE->getSubExpr());
if (!DRE)
return Action::Continue(E);
if (DRE->getDecl() != ErrParam)
return Action::Continue(E);
// If we find the node we're looking for, make a note of it, and abort
// the walk.
FoundUnwrap = true;
return Action::Stop();
}
PreWalkResult<Stmt *> walkToStmtPre(Stmt *S) override {
// Don't walk into new explicit scopes, we only want to consider force
// unwraps in the immediate conditional body.
if (!S->isImplicit() && startsNewScope(S))
return Action::SkipChildren(S);
return Action::Continue(S);
}
PreWalkAction walkToDeclPre(Decl *D) override {
// Don't walk into new explicit DeclContexts.
return Action::VisitChildrenIf(D->isImplicit() || !isa<DeclContext>(D));
}
};
for (auto Node : Nodes) {
ErrUnwrapFinder walker(ErrParam);
Node.walk(walker);
if (walker.foundUnwrap())
return true;
}
return false;
}
/// Given a callback condition, classify it as a success or failure path.
Optional<ClassifiedCondition>
classifyCallbackCondition(const CallbackCondition &Cond,
const NodesToPrint &SuccessNodes, Stmt *ElseStmt) {
if (!Cond.isValid())
return None;
// If the condition involves a refutable pattern, we can't currently handle
// it.
if (Cond.BindPattern && Cond.BindPattern->isRefutablePattern())
return None;
auto *SubjectParam = dyn_cast<ParamDecl>(Cond.Subject);
if (!SubjectParam)
return None;
// For certain types of condition, they need to be certain kinds of params.
auto CondType = *Cond.Type;
switch (CondType) {
case ConditionType::NOT_NIL:
case ConditionType::NIL:
if (!Params.isUnwrappableParam(SubjectParam))
return None;
break;
case ConditionType::IS_TRUE:
case ConditionType::IS_FALSE:
if (!Params.isSuccessParam(SubjectParam))
return None;
break;
case ConditionType::SUCCESS_PATTERN:
case ConditionType::FAILURE_PATTEN:
if (SubjectParam != Params.getResultParam())
return None;
break;
}
// Let's start with a success path, and flip any negative conditions.
auto Path = ConditionPath::SUCCESS;
// If it's an error param, that's a flip.
if (SubjectParam == Params.getErrParam())
Path = flippedConditionPath(Path);
// If we have a nil, false, or failure condition, that's a flip.
switch (CondType) {
case ConditionType::NIL:
case ConditionType::IS_FALSE:
case ConditionType::FAILURE_PATTEN:
Path = flippedConditionPath(Path);
break;
case ConditionType::IS_TRUE:
case ConditionType::NOT_NIL:
case ConditionType::SUCCESS_PATTERN:
break;
}
// If we have a bool condition, it could be an Obj-C style flag check, which
// we do some extra checking for. Otherwise, we're done.
if (CondType != ConditionType::IS_TRUE &&
CondType != ConditionType::IS_FALSE) {
return ClassifiedCondition(Cond, Path, /*ObjCFlagCheck*/ false);
}
// Check to see if we have a known bool flag parameter that indicates
// success or failure.
if (auto KnownBoolFlag = Params.getKnownBoolFlagParam()) {
if (KnownBoolFlag->Param != SubjectParam)
return None;
// The path may need to be flipped depending on whether the flag indicates
// success.
if (!KnownBoolFlag->IsSuccessFlag)
Path = flippedConditionPath(Path);
return ClassifiedCondition(Cond, Path, /*ObjCStyleFlagCheck*/ true);
}
// If we've reached here, we have a bool flag check that isn't specified in
// the async convention. We apply a heuristic to see if the error param is
// force unwrapped in the conditional body. In that case, the user is
// expecting it to be the error path, and it's more likely than not that the
// flag value conveys no more useful information in the error block.
// First check the success block.
auto FoundInSuccessBlock =
hasForceUnwrappedErrorParam(SuccessNodes.getNodes());
// Then check the else block if we have it.
if (ASTNode ElseNode = ElseStmt) {
// Unwrap the BraceStmt of the else clause if needed. This is needed as
// we won't walk into BraceStmts by default as they introduce new
// scopes.
ArrayRef<ASTNode> Nodes;
if (auto *BS = dyn_cast<BraceStmt>(ElseStmt)) {
Nodes = BS->getElements();
} else {
Nodes = llvm::makeArrayRef(ElseNode);
}
if (hasForceUnwrappedErrorParam(Nodes)) {
// If we also found an unwrap in the success block, we don't know what's
// happening here.
if (FoundInSuccessBlock)
return None;
// Otherwise we can determine this as a success condition. Note this is
// flipped as if the error is present in the else block, this condition
// is for success.
return ClassifiedCondition(Cond, ConditionPath::SUCCESS,
/*ObjCStyleFlagCheck*/ true);
}
}
if (FoundInSuccessBlock) {
// Note that the path is flipped as if the error is present in the success
// block, this condition is for failure.
return ClassifiedCondition(Cond, ConditionPath::FAILURE,
/*ObjCStyleFlagCheck*/ true);
}
// Otherwise we can't classify this.
return None;
}
/// Classifies all the conditions present in a given StmtCondition, taking
/// into account its success body and failure body. Returns \c true if there
/// were any conditions that couldn't be classified, \c false otherwise.
bool classifyConditionsOf(StmtCondition Cond,
const NodesToPrint &ThenNodesToPrint,
Stmt *ElseStmt,
ClassifiedCallbackConditions &Conditions) {
bool UnhandledConditions = false;
Optional<ClassifiedCondition> ObjCFlagCheck;
auto TryAddCond = [&](CallbackCondition CC) {
auto Classified =
classifyCallbackCondition(CC, ThenNodesToPrint, ElseStmt);
// If we couldn't classify this, or if there are multiple Obj-C style flag
// checks, this is unhandled.
if (!Classified || (ObjCFlagCheck && Classified->IsObjCStyleFlagCheck)) {
UnhandledConditions = true;
return;
}
// If we've seen multiple conditions for the same subject, don't handle
// this.
if (!Conditions.insert({CC.Subject, *Classified}).second) {
UnhandledConditions = true;
return;
}
if (Classified->IsObjCStyleFlagCheck)
ObjCFlagCheck = Classified;
};
for (auto &CondElement : Cond) {
if (auto *BoolExpr = CondElement.getBooleanOrNull()) {
SmallVector<Expr *, 1> Exprs;
Exprs.push_back(BoolExpr);
while (!Exprs.empty()) {
auto *Next = Exprs.pop_back_val()->getSemanticsProvidingExpr();
if (auto *ACE = dyn_cast<AutoClosureExpr>(Next))
Next = ACE->getSingleExpressionBody()->getSemanticsProvidingExpr();
if (auto *BE = dyn_cast_or_null<BinaryExpr>(Next)) {
auto *Operator = isOperator(BE);
if (Operator) {
// If we have an && operator, decompose its arguments.
if (Operator->getBaseName() == "&&") {
Exprs.push_back(BE->getLHS());
Exprs.push_back(BE->getRHS());
} else {
// Otherwise check to see if we have an == nil or != nil
// condition.
TryAddCond(CallbackCondition(BE, Operator));
}
continue;
}
}
// Check to see if we have a lone bool condition.
TryAddCond(CallbackCondition(Next));
}
} else if (auto *P = CondElement.getPatternOrNull()) {
TryAddCond(CallbackCondition(P, CondElement.getInitializer()));
}
}
return UnhandledConditions || Conditions.empty();
}
/// Classifies the conditions of a conditional statement, and adds the
/// necessary nodes to either the success or failure block.
void classifyConditional(Stmt *Statement, StmtCondition Condition,
NodesToPrint ThenNodesToPrint, Stmt *ElseStmt) {
ClassifiedCallbackConditions CallbackConditions;
bool UnhandledConditions = classifyConditionsOf(
Condition, ThenNodesToPrint, ElseStmt, CallbackConditions);
auto ErrCondition = CallbackConditions.lookup(Params.getErrParam());
if (UnhandledConditions) {
// Some unknown conditions. If there's an else, assume we can't handle
// and use the fallback case. Otherwise add to either the success or
// error block depending on some heuristics, known conditions will have
// placeholders added (ideally we'd remove them)
// TODO: Remove known conditions and split the `if` statement
if (IsKnownConditionPath) {
// If we're on a known condition path, we can be lenient as we already
// know what block we're in and can therefore just add the conditional
// straight to it.
CurrentBlock->addNode(Statement);
} else if (CallbackConditions.empty()) {
// Technically this has a similar problem, ie. the else could have
// conditions that should be in either success/error
CurrentBlock->addNode(Statement);
} else if (ElseStmt) {
DiagEngine.diagnose(Statement->getStartLoc(),
diag::unknown_callback_conditions);
} else if (ErrCondition && ErrCondition->Path == ConditionPath::FAILURE) {
Blocks.ErrorBlock.addNode(Statement);
} else {
for (auto &Entry : CallbackConditions) {
if (Entry.second.Path == ConditionPath::FAILURE) {
Blocks.ErrorBlock.addNode(Statement);
return;
}
}
Blocks.SuccessBlock.addNode(Statement);
}
return;
}
// If all the conditions were classified, make sure they're all consistently
// on the success or failure path.
Optional<ConditionPath> Path;
for (auto &Entry : CallbackConditions) {
auto &Cond = Entry.second;
if (!Path) {
Path = Cond.Path;
} else if (*Path != Cond.Path) {
// Similar to the unknown conditions case. Add the whole if unless
// there's an else, in which case use the fallback instead.
// TODO: Split the `if` statement
if (ElseStmt) {
DiagEngine.diagnose(Statement->getStartLoc(),
diag::mixed_callback_conditions);
} else {
CurrentBlock->addNode(Statement);
}
return;
}
}
assert(Path && "Didn't classify a path?");
auto *ThenBlock = &Blocks.SuccessBlock;
auto *ElseBlock = &Blocks.ErrorBlock;
// If the condition is for a failure path, the error block is ThenBlock, and
// the success block is ElseBlock.
if (*Path == ConditionPath::FAILURE)
std::swap(ThenBlock, ElseBlock);
// We'll be dropping the statement, but make sure to keep any attached
// comments.
CurrentBlock->addPossibleCommentLoc(Statement->getStartLoc());
ThenBlock->addAllBindings(CallbackConditions);
// TODO: Handle nested ifs
setNodes(ThenBlock, ElseBlock, std::move(ThenNodesToPrint));
if (ElseStmt) {
if (auto *BS = dyn_cast<BraceStmt>(ElseStmt)) {
// If this is a guard statement, we know that we'll always exit,
// allowing us to classify any additional nodes into the opposite block.
auto AlwaysExits = isa<GuardStmt>(Statement);
setNodes(ElseBlock, ThenBlock, NodesToPrint::inBraceStmt(BS),
AlwaysExits);
} else {
// If we reached here, we should have an else if statement. Given we
// know we're in the else of a known condition, temporarily flip the
// current block, and set that we know what path we're on.
llvm::SaveAndRestore<bool> CondScope(IsKnownConditionPath, true);
llvm::SaveAndRestore<ClassifiedBlock *> BlockScope(CurrentBlock,
ElseBlock);
classifyNodes(ArrayRef<ASTNode>(ElseStmt),
/*endCommentLoc*/ SourceLoc());
}
}
}
/// Adds \p Nodes to \p Block, potentially flipping the current block if we
/// can determine that the nodes being added will cause control flow to leave
/// the scope.
///
/// \param Block The block to add the nodes to.
/// \param OtherBlock The block for the opposing condition path.
/// \param Nodes The nodes to add.
/// \param AlwaysExitsScope Whether the nodes being added always exit the
/// scope, and therefore whether the current block should be flipped.
void setNodes(ClassifiedBlock *Block, ClassifiedBlock *OtherBlock,
NodesToPrint Nodes, bool AlwaysExitsScope = false) {
// Drop an explicit trailing 'return' or 'break' if we can.
bool HasTrailingReturnOrBreak = Nodes.hasTrailingReturnOrBreak();
if (HasTrailingReturnOrBreak)
Nodes.dropTrailingReturnOrBreakIfPossible();
// If we know we're exiting the scope, we can set IsKnownConditionPath, as
// we know any future nodes should be classified into the other block.
if (HasTrailingReturnOrBreak || AlwaysExitsScope) {
CurrentBlock = OtherBlock;
IsKnownConditionPath = true;
Block->addAllNodes(std::move(Nodes));
} else {
Block->addAllNodes(std::move(Nodes));
}
}
void classifySwitch(SwitchStmt *SS) {
auto *ResultParam = Params.getResultParam();
if (singleSwitchSubject(SS) != ResultParam) {
CurrentBlock->addNode(SS);
return;
}
// We'll be dropping the switch, but make sure to keep any attached
// comments.
CurrentBlock->addPossibleCommentLoc(SS->getStartLoc());
// Push the cases into a vector. This is only done to eagerly evaluate the
// AsCaseStmtRange sequence so we can know what the last case is.
SmallVector<CaseStmt *, 2> Cases;
Cases.append(SS->getCases().begin(), SS->getCases().end());
for (auto *CS : Cases) {
if (CS->hasFallthroughDest()) {
DiagEngine.diagnose(CS->getLoc(), diag::callback_with_fallthrough);
return;
}
if (CS->isDefault()) {
DiagEngine.diagnose(CS->getLoc(), diag::callback_with_default);
return;
}
auto Items = CS->getCaseLabelItems();
if (Items.size() > 1) {
DiagEngine.diagnose(CS->getLoc(), diag::callback_multiple_case_items);
return;
}
if (Items[0].getWhereLoc().isValid()) {
DiagEngine.diagnose(CS->getLoc(), diag::callback_where_case_item);
return;
}
auto *Block = &Blocks.SuccessBlock;
auto *OtherBlock = &Blocks.ErrorBlock;
auto SuccessNodes = NodesToPrint::inBraceStmt(CS->getBody());
// Classify the case pattern.
auto CC = classifyCallbackCondition(
CallbackCondition(ResultParam, &Items[0]), SuccessNodes,
/*elseStmt*/ nullptr);
if (!CC) {
DiagEngine.diagnose(CS->getLoc(), diag::unknown_callback_case_item);
return;
}
if (CC->Path == ConditionPath::FAILURE)
std::swap(Block, OtherBlock);
// We'll be dropping the case, but make sure to keep any attached
// comments. Because these comments will effectively be part of the
// previous case, add them to CurrentBlock.
CurrentBlock->addPossibleCommentLoc(CS->getStartLoc());
// Make sure to grab trailing comments in the last case stmt.
if (CS == Cases.back())
Block->addPossibleCommentLoc(SS->getRBraceLoc());
setNodes(Block, OtherBlock, std::move(SuccessNodes));
Block->addBinding(*CC);
}
// Mark this switch statement as having been transformed.
HandledSwitches.insert(SS);
}
};
/// Returns the base name of \p D when it is a named \c ValueDecl, and an
/// empty \c DeclBaseName for any other declaration.
static DeclBaseName getDeclName(const Decl *D) {
  const auto *Named = dyn_cast<ValueDecl>(D);
  if (!Named || !Named->hasName())
    return DeclBaseName();
  return Named->getBaseName();
}
/// Records the explicit declarations that live directly at one scope level,
/// i.e. those not nested inside a closure or a scope-introducing statement.
class DeclCollector : private SourceEntityWalker {
  /// Output set that collected declarations are added to.
  llvm::DenseSet<const Decl *> &Decls;

public:
  /// Collect all explicit declarations declared in \p Scope (or \p SF if
  /// \p Scope is a nullptr) that are not within their own scope.
  static void collect(BraceStmt *Scope, SourceFile &SF,
                      llvm::DenseSet<const Decl *> &Decls) {
    DeclCollector Collector(Decls);
    if (!Scope) {
      Collector.walk(SF);
      return;
    }
    for (auto Element : Scope->getElements())
      Collector.walk(Element);
  }

private:
  DeclCollector(llvm::DenseSet<const Decl *> &Decls) : Decls(Decls) {}

  bool walkToDeclPre(Decl *D, CharSourceRange Range) override {
    // Top level code decls (implicitly added around top-level non-decl code)
    // and pattern binding decls (holding the var decls we care about) are
    // walked through without being recorded themselves.
    const bool IsTransparent =
        isa<TopLevelCodeDecl>(D) || isa<PatternBindingDecl>(D);
    if (IsTransparent)
      return true;
    if (!D->isImplicit())
      Decls.insert(D);
    // Anything inside another decl belongs to a different scope.
    return false;
  }

  bool walkToExprPre(Expr *E) override {
    // Closures introduce their own scope.
    if (isa<ClosureExpr>(E))
      return false;
    return true;
  }

  bool walkToStmtPre(Stmt *S) override {
    // Implicit statements are always walked; explicit ones only if they do
    // not open a new scope.
    if (S->isImplicit())
      return true;
    return !startsNewScope(S);
  }
};
class ReferenceCollector : private SourceEntityWalker {
SourceManager *SM;
llvm::DenseSet<const Decl *> DeclaredDecls;
llvm::DenseSet<const Decl *> &ReferencedDecls;
ASTNode Target;
bool AfterTarget;
public:
/// Collect all explicit references in \p Scope (or \p SF if \p Scope is
/// a nullptr) that are after \p Target and not first declared. That is,
/// references that we don't want to shadow with hoisted declarations.
///
/// Also collect all declarations that are \c DeclContexts, which is an
/// over-appoximation but let's us ignore them elsewhere.
static void collect(ASTNode Target, BraceStmt *Scope, SourceFile &SF,
llvm::DenseSet<const Decl *> &Decls) {
ReferenceCollector Collector(Target, &SF.getASTContext().SourceMgr,
Decls);
if (Scope)
Collector.walk(Scope);
else
Collector.walk(SF);
}
private:
ReferenceCollector(ASTNode Target, SourceManager *SM,
llvm::DenseSet<const Decl *> &Decls)
: SM(SM), DeclaredDecls(), ReferencedDecls(Decls), Target(Target),
AfterTarget(false) {}
bool walkToDeclPre(Decl *D, CharSourceRange Range) override {
// Bit of a hack, include all contexts so they're never renamed (seems worse
// to rename a class/function than it does a variable). Again, an
// over-approximation, but hopefully doesn't come up too often.
if (isa<DeclContext>(D) && !D->isImplicit()) {
ReferencedDecls.insert(D);
}
if (AfterTarget && !D->isImplicit()) {
DeclaredDecls.insert(D);
} else if (D == Target.dyn_cast<Decl *>()) {
AfterTarget = true;
}
return shouldWalkInto(D->getSourceRange());
}
bool walkToExprPre(Expr *E) override {
if (AfterTarget && !E->isImplicit()) {
if (auto *DRE = dyn_cast<DeclRefExpr>(E)) {
if (auto *D = DRE->getDecl()) {
// Only care about references that aren't declared, as seen decls will
// be renamed (if necessary) during the refactoring.
if (!D->isImplicit() && !DeclaredDecls.count(D)) {
ReferencedDecls.insert(D);
// Also add the async alternative of a function to prevent
// collisions if a call is replaced with the alternative.
if (auto *AFD = dyn_cast<AbstractFunctionDecl>(D)) {
if (auto *Alternative = AFD->getAsyncAlternative())
ReferencedDecls.insert(Alternative);
}
}
}
}
} else if (E == Target.dyn_cast<Expr *>()) {
AfterTarget = true;
}
return shouldWalkInto(E->getSourceRange());
}
bool walkToStmtPre(Stmt *S) override {
if (S == Target.dyn_cast<Stmt *>())
AfterTarget = true;
return shouldWalkInto(S->getSourceRange());
}
bool walkToPatternPre(Pattern *P) override {
if (P == Target.dyn_cast<Pattern *>())
AfterTarget = true;
return shouldWalkInto(P->getSourceRange());
}
bool shouldWalkInto(SourceRange Range) {
return AfterTarget || (SM &&
SM->rangeContainsTokenLoc(Range, Target.getStartLoc()));
}
};
/// Similar to the \c ReferenceCollector but collects references in all scopes
/// without any starting point in each scope. In addition, it tracks the number
/// of references to a decl in a given scope.
class ScopedDeclCollector : private SourceEntityWalker {
public:
using DeclsTy = llvm::DenseSet<const Decl *>;
using RefDeclsTy = llvm::DenseMap<const Decl *, /*numRefs*/ unsigned>;
private:
using ScopedDeclsTy = llvm::DenseMap<const Stmt *, RefDeclsTy>;
struct Scope {
DeclsTy DeclaredDecls;
RefDeclsTy *ReferencedDecls;
Scope(RefDeclsTy *ReferencedDecls) : DeclaredDecls(),
ReferencedDecls(ReferencedDecls) {}
};
ScopedDeclsTy ReferencedDecls;
llvm::SmallVector<Scope, 4> ScopeStack;
public:
/// Starting at \c Scope, collect all explicit references in every scope
/// within (including the initial) that are not first declared, ie. those that
/// could end up shadowed. Also include all \c DeclContext declarations as
/// we'd like to avoid renaming functions and types completely.
void collect(ASTNode Node) {
walk(Node);
}
const RefDeclsTy *getReferencedDecls(Stmt *Scope) const {
auto Res = ReferencedDecls.find(Scope);
if (Res == ReferencedDecls.end())
return nullptr;
return &Res->second;
}
private:
bool walkToDeclPre(Decl *D, CharSourceRange Range) override {
if (ScopeStack.empty() || D->isImplicit())
return true;
ScopeStack.back().DeclaredDecls.insert(D);
if (isa<DeclContext>(D))
(*ScopeStack.back().ReferencedDecls)[D] += 1;
return true;
}
bool walkToExprPre(Expr *E) override {
if (ScopeStack.empty())
return true;
if (!E->isImplicit()) {
if (auto *DRE = dyn_cast<DeclRefExpr>(E)) {
if (auto *D = DRE->getDecl()) {
// If we have a reference that isn't declared in the same scope,
// increment the number of references to that decl.
if (!D->isImplicit() && !ScopeStack.back().DeclaredDecls.count(D)) {
(*ScopeStack.back().ReferencedDecls)[D] += 1;
// Also add the async alternative of a function to prevent
// collisions if a call is replaced with the alternative.
if (auto *AFD = dyn_cast<AbstractFunctionDecl>(D)) {
if (auto *Alternative = AFD->getAsyncAlternative())
(*ScopeStack.back().ReferencedDecls)[Alternative] += 1;
}
}
}
}
}
return true;
}
bool walkToStmtPre(Stmt *S) override {
// Purposely check \c BraceStmt here rather than \c startsNewScope.
// References in the condition should be applied to the previous scope, not
// the scope of that statement.
if (isa<BraceStmt>(S))
ScopeStack.emplace_back(&ReferencedDecls[S]);
return true;
}
bool walkToStmtPost(Stmt *S) override {
if (isa<BraceStmt>(S)) {
size_t NumScopes = ScopeStack.size();
if (NumScopes >= 2) {
// Add any referenced decls to the parent scope that weren't declared
// there.
auto &ParentStack = ScopeStack[NumScopes - 2];
for (auto DeclAndNumRefs : *ScopeStack.back().ReferencedDecls) {
auto *D = DeclAndNumRefs.first;
if (!ParentStack.DeclaredDecls.count(D))
(*ParentStack.ReferencedDecls)[D] += DeclAndNumRefs.second;
}
}
ScopeStack.pop_back();
}
return true;
}
};
/// Checks whether an ASTNode contains a reference to a given declaration.
class DeclReferenceFinder : private SourceEntityWalker {
  /// The declaration whose references are being searched for.
  const Decl *Search;
  /// Set once a \c DeclRefExpr to \c Search has been seen.
  bool HasFoundReference = false;

  explicit DeclReferenceFinder(const Decl *Search) : Search(Search) {}

  bool walkToExprPre(Expr *E) override {
    auto *DRE = dyn_cast<DeclRefExpr>(E);
    if (!DRE || DRE->getDecl() != Search)
      return true;
    // Found a reference; there's no need to look inside this expression.
    HasFoundReference = true;
    return false;
  }

public:
  /// Returns \c true if \p node contains a reference to \p Search, \c false
  /// otherwise.
  static bool containsReference(ASTNode Node, const ValueDecl *Search) {
    DeclReferenceFinder Finder(Search);
    Finder.walk(Node);
    return Finder.HasFoundReference;
  }
};
/// Builds up async-converted code for an AST node.
///
/// If it is a function, its declaration will have `async` added. If a
/// completion handler is present, it will be removed and the return type of
/// the function will reflect the parameters of the handler, including an
/// added `throws` if necessary.
///
/// Calls to the completion handler are replaced with either a `return` or
/// `throws` depending on the arguments.
///
/// Calls to functions with an async alternative will be replaced with a call
/// to the alternative, possibly wrapped in a do/catch. The do/catch is skipped
/// if the closure either:
/// 1. Has no error
/// 2. Has an error but no error handling (eg. just ignores)
/// 3. Has error handling that only calls the containing function's handler
/// with an error matching the error argument
///
/// (2) is technically not the correct translation, but in practice it's likely
/// the code a user would actually want.
///
/// If the success vs error handling split inside the closure cannot be
/// determined and the closure takes regular parameters (ie. not a Result), a
/// fallback translation is used that keeps all the same variable names and
/// simply moves the code within the closure out.
///
/// The fallback is generally avoided, however, since it's quite unlikely to be
/// the code the user intended. In most cases the refactoring will continue,
/// with any unhandled decls wrapped in placeholders instead.
class AsyncConverter : private SourceEntityWalker {
struct Scope {
  /// Names declared in, or otherwise reserved for, this scope.
  llvm::DenseSet<DeclBaseName> Names;
  /// If this scope is wrapped in a \c withChecked(Throwing)Continuation, the
  /// name of the continuation that must be resumed where there previously
  /// was a call to the function's completion handler. Otherwise an empty
  /// identifier.
  Identifier ContinuationName;

  Scope(Identifier ContinuationName) : ContinuationName(ContinuationName) {}

  /// Whether this scope is wrapped in a \c withChecked(Throwing)Continuation.
  bool isWrappedInContination() const {
    return !ContinuationName.empty();
  }
};
/// The source file containing \c StartNode.
SourceFile *SF;
/// Source manager used to turn AST ranges into character ranges for edits.
SourceManager &SM;
/// Engine that refactoring diagnostics are reported to. \c convert() reports
/// failure if it holds any error.
DiagnosticEngine &DiagEngine;
// Node to convert
ASTNode StartNode;
// Completion handler of `StartNode` (if it's a function with an async
// alternative)
AsyncHandlerParamDesc TopHandler;
// Text of the converted code, written through \c OS. It is cleared after
// being handed to an edit consumer by \c replace or \c insertAfter.
SmallString<0> Buffer;
llvm::raw_svector_ostream OS;
// Decls where any force unwrap or optional chain of that decl should be
// elided, e.g for a previously optional closure parameter that has become a
// non-optional local.
llvm::DenseSet<const Decl *> Unwraps;
// Decls whose references should be replaced with placeholders, either because
// they no longer exist or are a different type. Any replaced code should
// ideally be handled by the refactoring properly, but that's not possible in
// all cases.
llvm::DenseSet<const Decl *> Placeholders;
// Mapping from decl -> name, used as the name of possible new local
// declarations of old completion handler parameters, as well as the
// replacement for other hoisted declarations and their references
llvm::DenseMap<const Decl *, Identifier> Names;
/// The scopes (containing all name decls and whether the scope is wrapped in
/// a continuation) as the AST is being walked. The first element is the
/// initial scope and the last is the current scope.
llvm::SmallVector<Scope, 4> Scopes;
// Mapping of \c BraceStmt -> declarations referenced in that statement
// without first being declared (with reference counts).
// NOTE(review): presumably used to seed the names of a new entry in
// \c Scopes on entering that statement; the consumer is not visible here.
ScopedDeclCollector ScopedDecls;
/// The switch statements that have been re-written by this transform.
llvm::DenseSet<SwitchStmt *> HandledSwitches;
// The last source location that has been output. Used to output the source
// between handled nodes
SourceLoc LastAddedLoc;
// Number of expressions (or pattern binding decl) currently nested in, taking
// into account hoisting and the possible removal of ifs/switches
int NestedExprCount = 0;
// Whether a completion handler body is currently being hoisted out of its
// call
bool Hoisting = false;
/// Whether a pattern is currently being converted.
bool ConvertingPattern = false;
/// A mapping of inline patterns to print for closure parameters.
using InlinePatternsToPrint = llvm::DenseMap<const Decl *, const Pattern *>;
public:
/// Convert a function: creates a converter for \p FD, whose completion
/// handler is described by \p TopHandler.
AsyncConverter(SourceFile *SF, SourceManager &SM,
               DiagnosticEngine &DiagEngine, AbstractFunctionDecl *FD,
               const AsyncHandlerParamDesc &TopHandler)
    : SF(SF), SM(SM), DiagEngine(DiagEngine), StartNode(FD),
      TopHandler(TopHandler), OS(Buffer) {
  // Gather per-scope references across the whole function up front.
  ScopedDecls.collect(FD);
  // References to the original completion handler become placeholders.
  Placeholders.insert(TopHandler.getHandler());
  // Not strictly necessary, but an initial scope means possible shadowing
  // rather than crashes caused by a missing scope.
  addNewScope({});
}
/// Convert a call: creates a converter for \p CE, located within \p Scope
/// (or at the top level of \p SF when \p Scope is null).
AsyncConverter(SourceFile *SF, SourceManager &SM,
               DiagnosticEngine &DiagEngine, CallExpr *CE, BraceStmt *Scope)
    : SF(SF), SM(SM), DiagEngine(DiagEngine), StartNode(CE), OS(Buffer) {
  ScopedDecls.collect(CE);
  // The initial scope can be more precise than the general
  // ScopedDeclCollector because the starting point is known: reserve the
  // decls declared in the enclosing scope plus those referenced after the
  // call.
  llvm::DenseSet<const Decl *> ReservedDecls;
  DeclCollector::collect(Scope, *SF, ReservedDecls);
  ReferenceCollector::collect(StartNode, Scope, *SF, ReservedDecls);
  addNewScope(ReservedDecls);
}
/// The AST context that owns the file being converted.
ASTContext &getASTContext() const {
  return SF->getASTContext();
}
/// Converts \c StartNode into \c Buffer. A function gets its async
/// declaration plus its converted body; anything else is converted as a
/// single node with call conversion enabled and leading comments excluded.
/// Returns \c false if any error was diagnosed.
bool convert() {
  assert(Buffer.empty() && "AsyncConverter can only be used once");
  auto *FD = dyn_cast_or_null<FuncDecl>(StartNode.dyn_cast<Decl *>());
  if (!FD) {
    convertNode(StartNode, /*StartOverride=*/{}, /*ConvertCalls=*/true,
                /*IncludeComments=*/false);
    return !DiagEngine.hadAnyError();
  }
  addFuncDecl(FD);
  if (auto *Body = FD->getBody())
    convertNode(Body);
  return !DiagEngine.hadAnyError();
}
/// When adding an async alternative method for the function declaration \c
/// FD, this function tries to create a function body for the legacy function
/// (the one with a completion handler), which calls the newly converted async
/// function. There are certain situations in which we fail to create such a
/// body, e.g. if the completion handler has the signature `(String, Error?)
/// -> Void` in which case we can't synthesize the result of type \c String in
/// the error case.
///
/// The emitted body is `{ Task { <hoisted handler call around
/// [try] await <forwarding call>> } }`.
bool createLegacyBody() {
  assert(Buffer.empty() && "AsyncConverter can only be used once");
  if (!canCreateLegacyBody())
    return false;

  FuncDecl *FD = cast<FuncDecl>(StartNode.get<Decl *>());

  // Emits `[try] await` followed by the call to the async alternative.
  // Since we're *creating* that alternative here there shouldn't already be
  // one, so the call is assumed to match the call to the completion handler
  // function, minus the completion handler argument.
  auto EmitAwaitedForwardingCall = [&]() {
    if (TopHandler.HasError)
      OS << tok::kw_try << " ";
    OS << "await ";
    addForwardingCallTo(FD, /*HandlerReplacement=*/"");
  };

  OS << tok::l_brace << "\n";           // start function body
  OS << "Task " << tok::l_brace << "\n"; // start 'Task'
  addHoistedNamedCallback(FD, TopHandler, TopHandler.getNameStr(),
                          EmitAwaitedForwardingCall);
  OS << "\n";
  OS << tok::r_brace << "\n"; // end 'Task'
  OS << tok::r_brace << "\n"; // end function body
  return true;
}
/// Creates an async alternative function that forwards onto the completion
/// handler function through
/// withCheckedContinuation/withCheckedThrowingContinuation.
///
/// Emits the async declaration followed by
/// `{ return [try] await withChecked[Throwing]Continuation { continuation in
/// <call with completion closure> } }`.
bool createAsyncWrapper() {
  assert(Buffer.empty() && "AsyncConverter can only be used once");
  auto *FD = cast<FuncDecl>(StartNode.get<Decl *>());
  const bool Throws = TopHandler.HasError;

  // The new async function declaration, then its body.
  addFuncDecl(FD);
  OS << tok::l_brace << "\n";

  // return [try] await withChecked[Throwing]Continuation { continuation in
  OS << tok::kw_return << " ";
  if (Throws)
    OS << tok::kw_try << " ";
  OS << "await ";
  OS << (Throws ? "withCheckedThrowingContinuation"
                : "withCheckedContinuation");
  OS << " " << tok::l_brace << " continuation " << tok::kw_in << "\n";

  // fnWithHandler(args...) { ... }
  auto CompletionClosure =
      getAsyncWrapperCompletionClosure("continuation", TopHandler);
  addForwardingCallTo(FD, /*HandlerReplacement=*/CompletionClosure);

  OS << "\n";
  OS << tok::r_brace << "\n"; // end continuation closure
  OS << tok::r_brace << "\n"; // end function body
  return true;
}
/// Replaces the source text of \p Node with the converted text accumulated
/// in \c Buffer, reports that edit to \p EditConsumer, and then empties
/// \c Buffer. When \p StartOverride is valid, the replaced range begins there
/// instead of at the start of \p Node.
void replace(ASTNode Node, SourceEditConsumer &EditConsumer,
             SourceLoc StartOverride = SourceLoc()) {
  SourceRange Replaced = Node.getSourceRange();
  if (StartOverride.isValid())
    Replaced = SourceRange(StartOverride, Replaced.End);
  EditConsumer.accept(
      SM, Lexer::getCharSourceRangeFromSourceRange(SM, Replaced), Buffer.str());
  Buffer.clear();
}
/// Queues two insertions after the end of \p Node through \p EditConsumer: a
/// blank-line separator, followed by the converted text in \c Buffer. Then
/// empties \c Buffer.
void insertAfter(ASTNode Node, SourceEditConsumer &EditConsumer) {
  SourceLoc InsertLoc = Node.getEndLoc();
  EditConsumer.insertAfter(SM, InsertLoc, "\n\n");
  EditConsumer.insertAfter(SM, InsertLoc, Buffer.str());
  Buffer.clear();
}
private:
bool canCreateLegacyBody() {
FuncDecl *FD = dyn_cast<FuncDecl>(StartNode.dyn_cast<Decl *>());
if (!FD) {
return false;
}
if (FD == nullptr || FD->getBody() == nullptr) {
return false;
}
if (FD->hasThrows()) {
assert(!TopHandler.isValid() && "We shouldn't have found a handler desc "
"if the original function throws");
return false;
}
return TopHandler.isValid();
}
/// Prints \p Elements to \p OS using \p PrintElt. Elements are separated by
/// ", " and wrapped in parentheses, unless there is exactly one element; in
/// that case it is printed bare. An empty container prints "()".
template <typename Container, typename PrintFn>
void addTupleOf(const Container &Elements, llvm::raw_ostream &OS,
                PrintFn PrintElt) {
  const bool NeedsParens = Elements.size() != 1;
  if (NeedsParens)
    OS << tok::l_paren;
  llvm::interleave(Elements, PrintElt, [&]() { OS << tok::comma << " "; });
  if (NeedsParens)
    OS << tok::r_paren;
}
/// Builds the completion handler closure that an async wrapper passes to the
/// legacy function from inside withChecked[Throwing]Continuation.
///
/// The closure's parameters are:
/// - one per success param: 'result', or 'result1', 'result2', ... when
///   there are several;
/// - 'error', if the handler has an error param.
///
/// The closure body resumes the continuation named \p ContName:
/// - For a params handler, it resumes throwing a non-nil error. Otherwise it
///   guard-unwraps, with a fatalError, every success param that needs
///   unwrapping, and resumes returning the (tuple of) results.
/// - For a Result handler, it resumes with the result.
///
/// \param ContName name of the continuation variable to resume.
/// \param HandlerDesc the completion handler of the legacy function.
/// \returns the closure source text, from '{' to '}'.
std::string
getAsyncWrapperCompletionClosure(StringRef ContName,
                                 const AsyncHandlerParamDesc &HandlerDesc) {
  std::string OutputStr;
  llvm::raw_string_ostream OS(OutputStr);
  OS << tok::l_brace; // start closure

  // Prepare parameter names for the closure. An inline size of 8 holds
  // 'result', 'result1'...'result9' and 'error' without a heap allocation.
  auto SuccessParams = HandlerDesc.getSuccessParams();
  SmallVector<SmallString<8>, 2> SuccessParamNames;
  for (auto Idx : indices(SuccessParams)) {
    SuccessParamNames.emplace_back("result");
    // If we have multiple success params, number them e.g. result1, result2...
    if (SuccessParams.size() > 1)
      SuccessParamNames.back().append(std::to_string(Idx + 1));
  }
  Optional<SmallString<8>> ErrName;
  if (HandlerDesc.getErrorParam())
    ErrName.emplace("error");
  auto HasAnyParams = !SuccessParamNames.empty() || ErrName;
  if (HasAnyParams)
    OS << " ";

  // result1, result2
  llvm::interleave(
      SuccessParamNames, [&](const auto &Name) { OS << Name; },
      [&]() { OS << tok::comma << " "; });
  // , error
  if (ErrName) {
    if (!SuccessParamNames.empty())
      OS << tok::comma << " ";
    OS << *ErrName;
  }
  if (HasAnyParams)
    OS << " " << tok::kw_in;
  OS << "\n";

  // The closure body.
  switch (HandlerDesc.Type) {
  case HandlerType::PARAMS: {
    // For a (Success?, Error?) -> Void handler, we do an if let on the error.
    if (ErrName) {
      // if let error = error {
      OS << tok::kw_if << " " << tok::kw_let << " ";
      OS << *ErrName << " " << tok::equal << " " << *ErrName << " ";
      OS << tok::l_brace << "\n";
      // continuation.resume(throwing: error)
      OS << ContName << tok::period << "resume" << tok::l_paren;
      OS << "throwing" << tok::colon << " " << *ErrName;
      OS << tok::r_paren << "\n";
      // return }
      OS << tok::kw_return << "\n";
      OS << tok::r_brace << "\n";
    }
    // If we have any success params that we need to unwrap, insert a guard.
    for (auto Idx : indices(SuccessParamNames)) {
      const auto &Name = SuccessParamNames[Idx];
      auto ParamTy = SuccessParams[Idx].getParameterType();
      if (!HandlerDesc.shouldUnwrap(ParamTy))
        continue;
      // guard let result = result else {
      OS << tok::kw_guard << " " << tok::kw_let << " ";
      OS << Name << " " << tok::equal << " " << Name << " " << tok::kw_else;
      OS << " " << tok::l_brace << "\n";
      // fatalError(...)
      OS << "fatalError" << tok::l_paren;
      OS << "\"Expected non-nil result '" << Name << "' for nil error\"";
      OS << tok::r_paren << "\n";
      // End guard.
      OS << tok::r_brace << "\n";
    }
    // continuation.resume(returning: (result1, result2, ...))
    OS << ContName << tok::period << "resume" << tok::l_paren;
    OS << "returning" << tok::colon << " ";
    addTupleOf(SuccessParamNames, OS, [&](const auto &Ref) { OS << Ref; });
    OS << tok::r_paren << "\n";
    break;
  }
  case HandlerType::RESULT: {
    // continuation.resume(with: result)
    assert(SuccessParamNames.size() == 1);
    OS << ContName << tok::period << "resume" << tok::l_paren;
    OS << "with" << tok::colon << " " << SuccessParamNames[0];
    OS << tok::r_paren << "\n";
    break;
  }
  case HandlerType::INVALID:
    llvm_unreachable("Should not have an invalid handler here");
  }
  OS << tok::r_brace; // end closure
  // str() flushes the stream, so the returned string is complete even if the
  // stream is buffered.
  return OS.str();
}
/// Returns the character range of the comment attached to the token located
/// by token_lower_bound for \p Loc. Returns an invalid CharSourceRange when
/// there is no such token or the token has no comment.
CharSourceRange getPrecedingCommentRange(SourceLoc Loc) {
  auto Tokens = SF->getAllTokens();
  auto It = token_lower_bound(Tokens, Loc);
  const bool HasComment = It != Tokens.end() && It->hasComment();
  return HasComment ? It->getCommentRange() : CharSourceRange();
}
/// Returns where the comment attached to the token at \p Loc begins. Returns
/// \p Loc unchanged when there is no such comment.
SourceLoc getLocIncludingPrecedingComment(SourceLoc Loc) {
  CharSourceRange Comment = getPrecedingCommentRange(Loc);
  return Comment.isValid() ? Comment.getStart() : Loc;
}
/// Prints a newline followed by the comment attached to the token at \p Loc.
/// Prints nothing if there is no such comment.
void printCommentIfNeeded(SourceLoc Loc) {
  CharSourceRange Comment = getPrecedingCommentRange(Loc);
  if (Comment.isInvalid())
    return;
  OS << "\n" << Comment.str();
}
void convertNodes(const NodesToPrint &ToPrint) {
// Sort the possible comment locs in reverse order so we can pop them as we
// go.
SmallVector<SourceLoc, 2> CommentLocs;
CommentLocs.append(ToPrint.getPossibleCommentLocs().begin(),
ToPrint.getPossibleCommentLocs().end());
llvm::sort(CommentLocs.begin(), CommentLocs.end(), [](auto lhs, auto rhs) {
return lhs.getOpaquePointerValue() > rhs.getOpaquePointerValue();
});
// First print the nodes we've been asked to print.
for (auto Node : ToPrint.getNodes()) {
// If we need to print comments, do so now.
while (!CommentLocs.empty()) {
auto CommentLoc = CommentLocs.back().getOpaquePointerValue();
auto NodeLoc = Node.getStartLoc().getOpaquePointerValue();
assert(CommentLoc != NodeLoc &&
"Added node to both comment locs and nodes to print?");
// If the comment occurs after the node, don't print now. Wait until
// the right node comes along.
if (CommentLoc > NodeLoc)
break;
printCommentIfNeeded(CommentLocs.pop_back_val());
}
OS << "\n";
convertNode(Node);
}
// We're done printing nodes. Make sure to output the remaining comments.
while (!CommentLocs.empty())
printCommentIfNeeded(CommentLocs.pop_back_val());
}
/// Walks \p Node and prints its converted source. Printing starts at
/// \p StartOverride when valid, otherwise at the start of \p Node. When
/// \p IncludePrecedingComment is set, it also includes the comment attached
/// to that start location. When \p ConvertCalls is false, the walk begins as
/// if already nested inside an expression (NestedExprCount == 1). Both
/// \c LastAddedLoc and \c NestedExprCount are restored on return.
void convertNode(ASTNode Node, SourceLoc StartOverride = {},
                 bool ConvertCalls = true,
                 bool IncludePrecedingComment = true) {
  SourceLoc Start =
      StartOverride.isValid() ? StartOverride : Node.getStartLoc();
  if (IncludePrecedingComment)
    Start = getLocIncludingPrecedingComment(Start);

  llvm::SaveAndRestore<SourceLoc> RestoreLoc(LastAddedLoc, Start);
  llvm::SaveAndRestore<int> RestoreCount(NestedExprCount,
                                         ConvertCalls ? 0 : 1);
  walk(Node);
  // Copy whatever original source remains after the last customization.
  addRange(LastAddedLoc, Node.getEndLoc(), /*ToEndOfToken=*/true);
}
void convertPattern(const Pattern *P) {
// Only print semantic patterns. This cleans up the output of the transform
// and works around some bogus source locs that can appear with typed
// patterns in if let statements.
P = P->getSemanticsProvidingPattern();
// Set up the start of the pattern as the last loc printed to make sure we
// accurately fill in the gaps as we customize the printing of sub-patterns.
llvm::SaveAndRestore<SourceLoc> RestoreLoc(LastAddedLoc, P->getStartLoc());
llvm::SaveAndRestore<bool> RestoreFlag(ConvertingPattern, true);
walk(const_cast<Pattern *>(P));
addRange(LastAddedLoc, P->getEndLoc(), /*ToEndOfToken*/ true);
}
/// Checks whether the rest of the current scope must be wrapped in a
/// \c withChecked(Throwing)Continuation starting at \p Node. If so, it prints
/// 'return [try] await withChecked[Throwing]Continuation { <name> in' before
/// \p Node and records the continuation name in \c Scopes.back(), so that
/// the scope knows it is wrapped.
///
/// Wrapping is needed when \p Node references the completion handler of
/// \c TopHandler and nothing else could handle it. Those handler calls must
/// become 'return' statements in the refactored method. However, \p Node's
/// own completion handler could not be hoisted, because it has no async
/// alternative by our heuristics (e.g. a completion handler name mismatch,
/// or the call also returns a value synchronously).
///
/// Does nothing if \p Node is nested inside an expression (a continuation
/// can't start mid-expression) or if the scope is already wrapped.
void wrapScopeInContinationIfNecessary(ASTNode Node) {
  if (NestedExprCount != 0 || Scopes.back().isWrappedInContination())
    return;
  // Without a reference to the function's completion handler, the node can
  // keep its completion handler call: nothing needs promoting to 'return'.
  if (!DeclReferenceFinder::containsReference(Node, TopHandler.getHandler()))
    return;

  Identifier ContName = createUniqueName("continuation");
  auto &CurrentScope = Scopes.back();
  CurrentScope.Names.insert(ContName);
  CurrentScope.ContinuationName = ContName;

  const bool Throws = TopHandler.HasError;
  insertCustom(Node.getStartLoc(), [&]() {
    OS << tok::kw_return << ' ';
    if (Throws)
      OS << tok::kw_try << ' ';
    OS << "await "
       << (Throws ? "withCheckedThrowingContinuation "
                  : "withCheckedContinuation ");
    OS << tok::l_brace << ' ' << ContName << ' ' << tok::kw_in << '\n';
  });
}
/// Pre-visit hook for patterns. Patterns are walked normally, except while
/// converting a pattern via convertPattern. In that case, each nested
/// BindingPattern is replaced by its sub-pattern alone. A 'let'/'var' is
/// illegal when nested in a PatternBindingDecl, and a single top-level one is
/// printed separately.
bool walkToPatternPre(Pattern *P) override {
  if (!ConvertingPattern)
    return true;
  auto *BP = dyn_cast<BindingPattern>(P);
  if (!BP)
    return true;
  return addCustom(BP->getSourceRange(),
                   [&]() { convertPattern(BP->getSubPattern()); });
}
/// Pre-visit hook for declarations.
/// - A PatternBindingDecl may need to start a continuation. A closure inside
///   it can't be hoisted, so a reference to the completion handler requires
///   wrapping. Its contents are then walked as a nested expression;
///   \c NestedExprCount is decremented again in walkToDeclPost.
/// - A VarDecl is printed under its (possibly newly assigned) unique name.
/// - Everything else is skipped. Functions and types already have their
///   names in \c Scopes.Names.
///
/// Nested local function decls are deliberately not walked. If that ever
/// changes, update the logic in walkToStmtPre that turns unhandled returns
/// into placeholders.
bool walkToDeclPre(Decl *D, CharSourceRange Range) override {
  if (isa<PatternBindingDecl>(D)) {
    wrapScopeInContinationIfNecessary(D);
    ++NestedExprCount;
    return true;
  }
  if (!isa<VarDecl>(D))
    return false;

  // Vars in binding patterns may already have been assigned a name.
  if (Names.find(D) == Names.end())
    Scopes.back().Names.insert(assignUniqueName(D, StringRef()));
  addCustom(D->getSourceRange(), [&]() { OS << newNameFor(D); });
  return false;
}
/// Post-visit hook for declarations. It undoes the \c NestedExprCount
/// increment that walkToDeclPre performs. walkToDeclPre only increments
/// (and only returns true) for PatternBindingDecls, so the decrement is
/// limited to the same kind. This keeps the pair balanced even if
/// walkToDeclPre later starts returning true for other decls. (Presumably
/// the walker only post-visits decls whose pre-visit returned true; this
/// guard makes the balance independent of that.)
bool walkToDeclPost(Decl *D) override {
  if (isa<PatternBindingDecl>(D))
    NestedExprCount--;
  return true;
}
#define PLACEHOLDER_START "<#"
#define PLACEHOLDER_END "#>"
/// Pre-visit hook for expressions. It rewrites the expressions that need
/// custom printing and tracks expression nesting. Cases are checked in order:
/// - DeclRefExpr: printed with its new name (from newNameFor) and/or wrapped
///   in a placeholder when the decl is in \c Placeholders.
/// - ForceValueExpr / BindOptionalExpr whose referenced decl is in
///   \c Unwraps: the '!' or '?' is dropped and only the decl's new name is
///   printed.
/// - A call to the \c TopHandler completion handler. It becomes a
///   continuation resume if the current scope is wrapped in a continuation,
///   otherwise a 'return'/'throw' when not nested inside an expression.
/// - Any other call whose completion handler can be hoisted. This only
///   applies when not nested and not inside a continuation.
///
/// When one of these rewrites happens, the result of addCustom (false) is
/// returned so the children are not walked again. Otherwise the expression
/// may start a continuation wrap, \c NestedExprCount is incremented
/// (undone in walkToExprPost), and true is returned. The exception is a void
/// (or untyped) SingleValueStmtExpr, which is walked like a statement without
/// touching \c NestedExprCount.
bool walkToExprPre(Expr *E) override {
// TODO: Handle Result.get as well
if (auto *DRE = dyn_cast<DeclRefExpr>(E)) {
if (auto *D = DRE->getDecl()) {
// Look through to the parent var decl if we have one. This ensures we
// look at the var in a case stmt's pattern rather than the var that's
// implicitly declared in the body.
if (auto *VD = dyn_cast<VarDecl>(D)) {
if (auto *Parent = VD->getParentVarDecl())
D = Parent;
}
bool AddPlaceholder = Placeholders.count(D);
StringRef Name = newNameFor(D, false);
// Only customize the reference if it needs a placeholder or a rename;
// otherwise fall through and treat it like any other expression.
if (AddPlaceholder || !Name.empty())
return addCustom(DRE->getSourceRange(), [&]() {
if (AddPlaceholder)
OS << PLACEHOLDER_START;
if (!Name.empty())
OS << Name;
else
D->getName().print(OS);
if (AddPlaceholder)
OS << PLACEHOLDER_END;
});
}
} else if (isa<ForceValueExpr>(E) || isa<BindOptionalExpr>(E)) {
// Remove a force unwrap or optional chain of a returned success value,
// as it will no longer be optional. For force unwraps, this is always a
// valid transform. For optional chains, it is a locally valid transform
// within the optional chain e.g foo?.x -> foo.x, but may change the type
// of the overall chain, which could cause errors elsewhere in the code.
// However this is generally more useful to the user than just leaving
// 'foo' as a placeholder. Note this is only the case when no other
// optionals are involved in the chain, e.g foo?.x?.y -> foo.x?.y is
// completely valid.
if (auto *D = E->getReferencedDecl().getDecl()) {
if (Unwraps.count(D))
return addCustom(E->getSourceRange(),
[&]() { OS << newNameFor(D, true); });
}
} else if (CallExpr *CE = TopHandler.getAsHandlerCall(E)) {
// A call to the function's own completion handler.
if (Scopes.back().isWrappedInContination()) {
return addCustom(E->getSourceRange(),
[&]() { convertHandlerToContinuationResume(CE); });
} else if (NestedExprCount == 0) {
return addCustom(E->getSourceRange(),
[&]() { convertHandlerToReturnOrThrows(CE); });
}
// Nested handler calls outside a continuation fall through unchanged.
} else if (auto *CE = dyn_cast<CallExpr>(E)) {
// Try and hoist a call's completion handler. Don't do so if
// - the current expression is nested (we can't start hoisting in the
// middle of an expression)
// - the current scope is wrapped in a continuation (we can't have await
// calls in the continuation block)
if (NestedExprCount == 0 && !Scopes.back().isWrappedInContination()) {
// If the refactoring is on the call itself, do not require the callee
// to have the @available attribute or a completion-like name.
auto HandlerDesc = AsyncHandlerParamDesc::find(
getUnderlyingFunc(CE->getFn()),
/*RequireAttributeOrName=*/StartNode.dyn_cast<Expr *>() != CE);
if (HandlerDesc.isValid()) {
return addCustom(CE->getSourceRange(),
[&]() { addHoistedCallback(CE, HandlerDesc); });
}
}
}
// A void SingleValueStmtExpr is semantically more like a statement than
// an expression, so recurse without bumping the expr depth or wrapping in
// continuation.
if (auto *SVE = dyn_cast<SingleValueStmtExpr>(E)) {
auto ty = SVE->getType();
if (!ty || ty->isVoid())
return true;
}
// We didn't do any special conversion for this expression. If needed, wrap
// it in a continuation.
wrapScopeInContinationIfNecessary(E);
NestedExprCount++;
return true;
}
/// Replaces \p range with the same original source text wrapped in
/// PLACEHOLDER_START / PLACEHOLDER_END delimiters.
///
/// \returns false (the result of addCustom).
bool replaceRangeWithPlaceholder(SourceRange range) {
  auto PrintPlaceholder = [&]() {
    OS << PLACEHOLDER_START;
    addRange(range, /*ToEndOfToken=*/true);
    OS << PLACEHOLDER_END;
  };
  return addCustom(range, PrintPlaceholder);
}
/// Post-visit hook for expressions. It undoes the \c NestedExprCount
/// increment from walkToExprPre. A void (or untyped) SingleValueStmtExpr is
/// skipped, because walkToExprPre did not increment for it.
bool walkToExprPost(Expr *E) override {
  auto *SVE = dyn_cast<SingleValueStmtExpr>(E);
  const bool IsStatementLike =
      SVE && (!SVE->getType() || SVE->getType()->isVoid());
  if (!IsStatementLike)
    --NestedExprCount;
  return true;
}
#undef PLACEHOLDER_START
#undef PLACEHOLDER_END
/// Pre-visit hook for statements.
///
/// A statement for which startsNewScope returns true pushes a new scope
/// (popped again in walkToStmtPost). The scope is seeded with the decls that
/// \c ScopedDecls recorded as referenced within the statement, if any.
///
/// Otherwise, while \c Hoisting, explicit control flow that may no longer
/// behave as written after the transform is turned into a placeholder:
/// - a 'break' whose target is a switch in \c HandledSwitches: the whole
///   statement is replaced;
/// - a 'return' not nested in an expression: only the 'return' keyword is
///   replaced, so the returned expression is still walked and converted.
///
/// Returns false only when a 'break' was replaced; true otherwise.
bool walkToStmtPre(Stmt *S) override {
// CaseStmt has an implicit BraceStmt inside it, which *should* start a new
// scope, so don't check isImplicit here.
if (startsNewScope(S)) {
// Add all names of decls referenced within this statement that aren't
// also declared first, plus any contexts. Note that \c getReferencedDecl
// will only return a value for a \c BraceStmt. This means that \c IfStmt
// (and other statements with conditions) will have their own empty scope,
// which is fine for our purposes - their existing names are always valid.
// The body of those statements will include the decls if they've been
// referenced, so shadowing is still avoided there.
if (auto *ReferencedDecls = ScopedDecls.getReferencedDecls(S)) {
llvm::DenseSet<const Decl *> Decls;
for (auto DeclAndNumRefs : *ReferencedDecls)
Decls.insert(DeclAndNumRefs.first);
addNewScope(Decls);
} else {
addNewScope({});
}
} else if (Hoisting && !S->isImplicit()) {
// Some break and return statements need to be turned into placeholders,
// as they may no longer perform the control flow that the user is
// expecting.
if (auto *BS = dyn_cast<BreakStmt>(S)) {
// For a break, if it's jumping out of a switch statement that we've
// re-written as a part of the transform, turn it into a placeholder, as
// it would have been lifted out of the switch statement.
if (auto *SS = dyn_cast<SwitchStmt>(BS->getTarget())) {
if (HandledSwitches.contains(SS))
return replaceRangeWithPlaceholder(S->getSourceRange());
}
} else if (isa<ReturnStmt>(S) && NestedExprCount == 0) {
// For a return, if it's not nested inside another closure or function,
// turn it into a placeholder, as it will be lifted out of the callback.
// Note that we only turn the 'return' token into a placeholder as we
// still want to be able to apply transforms to the argument.
// The result of replaceRangeWithPlaceholder is deliberately ignored so
// the walk continues into the returned expression.
replaceRangeWithPlaceholder(S->getStartLoc());
}
}
return true;
}
/// Post-visit hook for statements. It pops the scope that walkToStmtPre
/// pushed for \p S. If that scope was wrapped in a continuation and the
/// enclosing scope is not, the continuation ends with \p S. In that case it
/// prints the '}' closing the withChecked(Throwing)Continuation call.
bool walkToStmtPost(Stmt *S) override {
  if (!startsNewScope(S))
    return true;

  const bool ClosingWrappedScope = Scopes.back().isWrappedInContination();
  Scopes.pop_back();
  if (!ClosingWrappedScope || Scopes.back().isWrappedInContination())
    return true;

  insertCustom(S->getEndLoc(), [&]() { OS << tok::r_brace << '\n'; });
  return true;
}
/// Replaces \p Range with custom text:
/// 1. copies the original source from \c LastAddedLoc up to the start of
///    \p Range;
/// 2. prints the replacement via \p Custom;
/// 3. advances \c LastAddedLoc past the last token of \p Range.
///
/// \returns false, so walker callbacks can `return addCustom(...)` to avoid
/// walking into the node they just replaced.
bool addCustom(SourceRange Range, llvm::function_ref<void()> Custom = {}) {
  addRange(LastAddedLoc, Range.Start);
  // The default argument is an empty function_ref, and calling one is
  // undefined behavior. Treat it as "print nothing", i.e. drop the range.
  if (Custom)
    Custom();
  LastAddedLoc = Lexer::getLocForEndOfToken(SM, Range.End);
  return false;
}
/// Inserts custom text at \p Loc without replacing any existing source code:
/// 1. copies the original source from \c LastAddedLoc up to \p Loc;
/// 2. prints the text via \p Custom;
/// 3. continues copying from \p Loc.
///
/// \returns false, like addCustom.
bool insertCustom(SourceLoc Loc, llvm::function_ref<void()> Custom = {}) {
  addRange(LastAddedLoc, Loc);
  // An empty function_ref (the default argument) must not be invoked.
  if (Custom)
    Custom();
  LastAddedLoc = Loc;
  return false;
}
/// Copies the original source between \p Start and \p End to the output.
/// With \p ToEndOfToken, the copy extends to the end of the token at
/// \p End; otherwise it stops at \p End.
void addRange(SourceLoc Start, SourceLoc End, bool ToEndOfToken = false) {
  CharSourceRange Chars =
      ToEndOfToken
          ? Lexer::getCharSourceRangeFromSourceRange(SM, SourceRange(Start, End))
          : CharSourceRange(SM, Start, End);
  OS << Chars.str();
}
/// Convenience overload of addRange taking a SourceRange.
void addRange(SourceRange Range, bool ToEndOfToken = false) {
  const SourceLoc Begin = Range.Start;
  const SourceLoc Finish = Range.End;
  addRange(Begin, Finish, ToEndOfToken);
}
/// Prints the signature of the async version of \p FD. Printing starts at
/// its attributes and stops before the '{' of its body:
/// - '@discardableResult' is prepended when the completion handler parameter
///   has a default argument;
/// - the completion handler parameter (\c TopHandler) is removed from the
///   parameter list, keeping the comments attached to neighbouring params;
/// - ' async' is added unless \p FD is already async;
/// - ' throws' is added when \p FD throws or the handler reports an error;
/// - without a completion handler, the rest of the original signature (after
///   any 'throws') is copied verbatim;
/// - with a completion handler, the async return type (if any) and the
///   trailing where clause (if any) are printed instead.
void addFuncDecl(const FuncDecl *FD) {
  auto *Params = FD->getParameters();
  auto *HandlerParam = TopHandler.getHandlerParam();
  auto ParamPos = TopHandler.handlerParamPosition();

  // If the completion handler parameter has a default argument, the async
  // version is effectively @discardableResult, as not all the callers care
  // about receiving the completion call.
  if (HandlerParam && HandlerParam->isDefaultArgument())
    OS << tok::at_sign << "discardableResult" << "\n";

  // First chunk: start -> the parameter to remove (if any)
  SourceLoc LeftEndLoc;
  switch (ParamPos) {
  case AsyncHandlerParamDesc::Position::None:
  case AsyncHandlerParamDesc::Position::Only:
  case AsyncHandlerParamDesc::Position::First:
    // Handler is the first param (or there is none), so only include the (
    LeftEndLoc = Params->getLParenLoc().getAdvancedLoc(1);
    break;
  case AsyncHandlerParamDesc::Position::Middle:
    // Handler is somewhere in the middle of the params, so we need to
    // include any comments and comma up until the handler
    LeftEndLoc = Params->get(TopHandler.Index)->getStartLoc();
    LeftEndLoc = getLocIncludingPrecedingComment(LeftEndLoc);
    break;
  case AsyncHandlerParamDesc::Position::Last:
    // Handler is the last param, which means we don't want the comma. This
    // is a little annoying since we *do* want the comments past for the
    // last parameter
    LeftEndLoc = Lexer::getLocForEndOfToken(
        SM, Params->get(TopHandler.Index - 1)->getEndLoc());
    // Skip to the end of any comments
    Token Next = Lexer::getTokenAtLocation(SM, LeftEndLoc,
                                           CommentRetentionMode::None);
    if (Next.getKind() != tok::NUM_TOKENS)
      LeftEndLoc = Next.getLoc();
    break;
  }
  addRange(FD->getSourceRangeIncludingAttrs().Start, LeftEndLoc);

  // Second chunk: end of the parameter to remove -> right parenthesis
  SourceLoc MidStartLoc;
  SourceLoc MidEndLoc = Params->getRParenLoc().getAdvancedLoc(1);
  switch (ParamPos) {
  case AsyncHandlerParamDesc::Position::None:
    // No handler param, so make sure to include them all
    MidStartLoc = LeftEndLoc;
    break;
  case AsyncHandlerParamDesc::Position::First:
  case AsyncHandlerParamDesc::Position::Middle:
    // Handler param is either the first or one of the middle params. Skip
    // past it but make sure to include comments preceding the param after
    // the handler
    MidStartLoc = Params->get(TopHandler.Index + 1)->getStartLoc();
    MidStartLoc = getLocIncludingPrecedingComment(MidStartLoc);
    break;
  case AsyncHandlerParamDesc::Position::Only:
  case AsyncHandlerParamDesc::Position::Last:
    // Handler param is last, this is easy since there's no other params
    // to copy over
    MidStartLoc = Params->getRParenLoc();
    break;
  }
  addRange(MidStartLoc, MidEndLoc);

  // Third chunk: add in async and throws if necessary
  if (!FD->hasAsync())
    OS << " async";
  if (FD->hasThrows() || TopHandler.HasError)
    // TODO: Add throws if converting a function and it has a converted call
    // without a do/catch
    OS << " " << tok::kw_throws;

  // Fourth chunk: if no parent handler (ie. not adding an async
  // alternative), the rest of the decl. Otherwise, add in the new return
  // type
  if (!TopHandler.isValid()) {
    SourceLoc RightStartLoc = MidEndLoc;
    if (FD->hasThrows()) {
      RightStartLoc = Lexer::getLocForEndOfToken(SM, FD->getThrowsLoc());
    }
    SourceLoc RightEndLoc =
        FD->getBody() ? FD->getBody()->getLBraceLoc() : RightStartLoc;
    addRange(RightStartLoc, RightEndLoc);
    return;
  }

  SmallVector<LabeledReturnType, 2> Scratch;
  auto ReturnTypes = TopHandler.getAsyncReturnTypes(Scratch);
  if (ReturnTypes.empty()) {
    // Nothing is returned. Keep the separating space, which is printed even
    // without a body as before. Don't return early: the trailing where
    // clause below must still be printed, since the generic parameters it
    // constrains were already copied with the rest of the signature.
    OS << " ";
  } else {
    // Print the function result type, making sure to omit a '-> Void'
    // return.
    if (!TopHandler.willAsyncReturnVoid()) {
      OS << " -> ";
      addAsyncFuncReturnType(TopHandler);
    }
    if (FD->hasBody())
      OS << " ";
  }

  // TODO: Should remove the generic param and where clause for the error
  // param if it exists (and no other parameter uses that type)
  TrailingWhereClause *TWC = FD->getTrailingWhereClause();
  if (TWC && TWC->getWhereLoc().isValid()) {
    auto Range = TWC->getSourceRange();
    OS << Lexer::getCharSourceRangeFromSourceRange(SM, Range).str();
    if (FD->hasBody())
      OS << " ";
  }
}
/// Declares one fallback variable per entry in \p FallbackParams, each named
/// via newNameFor:
/// - the param that \p AllParams identifies as the known bool flag is bound
///   in both blocks, so it becomes a non-optional 'let' of its own type;
/// - every other param becomes a 'var' of optional type, initialized to nil.
///   A '?' is appended unless the type is already optional.
void addFallbackVars(ArrayRef<const ParamDecl *> FallbackParams,
                     const ClosureCallbackParams &AllParams) {
  for (const ParamDecl *Param : FallbackParams) {
    auto ParamTy = Param->getType();
    auto ParamName = newNameFor(Param);
    const bool IsKnownBoolFlag = AllParams.isKnownBoolFlagParam(Param);

    OS << (IsKnownBoolFlag ? tok::kw_let : tok::kw_var) << " " << ParamName
       << ": ";
    ParamTy->print(OS);
    if (IsKnownBoolFlag) {
      OS << "\n";
      continue;
    }
    if (!ParamTy->getOptionalObjectType())
      OS << "?";
    OS << " = " << tok::kw_nil << "\n";
  }
}
/// Opens a 'do {' block on its own line.
void addDo() {
  OS << tok::kw_do;
  OS << " " << tok::l_brace << "\n";
}
/// Given an error \p Result passed to a completion handler, returns true if
/// that error has already been handled through a 'try await'.
///
/// The error counts as handled when it is a reference to a decl whose
/// declaration no longer exists: the decl is in \c Placeholders but not in
/// \c Unwraps. Being in both only means an optional Error was promoted to a
/// non-optional one.
bool isErrorAlreadyHandled(HandlerResult Result) {
  assert(Result.isError());
  assert(Result.args().size() == 1 &&
         "There should only be one error parameter");
  auto *DRE = dyn_cast<DeclRefExpr>(Result.args().back().getExpr());
  if (!DRE)
    return false;
  auto *D = DRE->getDecl();
  return Placeholders.count(D) && !Unwraps.count(D);
}
/// Returns true if the source text of \p E, as it will be printed, denotes
/// an Optional value.
bool isExpressionOptional(Expr *E) {
  // Injecting a non-optional result into an Optional: the source text
  // itself isn't optional.
  if (isa<InjectIntoOptionalExpr>(E))
    return false;
  // References that were promoted to non-optional values can no longer be
  // used as optionals.
  auto *DRE = dyn_cast<DeclRefExpr>(E);
  if (DRE && Unwraps.count(DRE->getDecl()))
    return false;
  // Otherwise go by the type; an unknown type is assumed non-optional.
  auto Ty = E->getType();
  return !Ty.isNull() && Ty->isOptional();
}
/// Converts a call \p CE to the completion handler. Depending on the call,
/// it is interpreted as returning a success result or as returning an error.
/// If the call is truly ambiguous, an 'if let' is printed that checks at
/// runtime whether the error is nil, dispatching to the error or the success
/// case.
///
/// \p AddConvertedHandlerCall prints the converted handler call for the
/// given \c HandlerResult, which it must interpret as a success or an error
/// call accordingly.
/// \p AddConvertedErrorCall prints the equivalent of returning an error. It
/// is given the name of a variable of type 'Error'.
void convertHandlerCall(
    const CallExpr *CE,
    llvm::function_ref<void(HandlerResult)> AddConvertedHandlerCall,
    llvm::function_ref<void(StringRef)> AddConvertedErrorCall) {
  auto Result =
      TopHandler.extractResultArgs(CE, /*ReturnErrorArgsIfAmbiguous=*/true);

  if (!TopHandler.isAmbiguousCallToParamHandler(CE)) {
    // Unambiguous call: convert it directly. The exception is an error whose
    // error was already thrown by a 'try await'; no second throw is needed.
    if (!Result.isError() || !isErrorAlreadyHandled(Result))
      AddConvertedHandlerCall(Result);
    return;
  }

  assert(Result.isError() && "If the call was ambiguous, we should have "
                             "retrieved its error representation");
  assert(Result.args().size() == 1 &&
         "There should only be one error parameter");
  Expr *ErrorExpr = Result.args().back().getExpr();

  if (isErrorAlreadyHandled(Result)) {
    // The error was handled already, so this can only be a success call.
    AddConvertedHandlerCall(TopHandler.extractResultArgs(
        CE, /*ReturnErrorArgsIfAmbiguous=*/false));
    return;
  }
  if (!isExpressionOptional(ErrorExpr)) {
    // The error can never be nil, so this is an error call whatever the
    // success argument is.
    AddConvertedHandlerCall(Result);
    return;
  }

  // Truly ambiguous: decide at runtime with
  //   if let error = <converted error arg> {
  //     throw error // or equivalent
  //   } else {
  //     <call interpreted as a success call>
  //   }
  auto SuccessResult = TopHandler.extractResultArgs(
      CE, /*ReturnErrorArgsIfAmbiguous=*/false);
  // 'error' is only visible inside the 'if let', so it needn't be unique.
  StringRef ErrorName = "error";
  OS << tok::kw_if << ' ' << tok::kw_let << ' ' << ErrorName << ' '
     << tok::equal << ' ';
  convertNode(ErrorExpr, /*StartOverride=*/{}, /*ConvertCalls=*/false);
  OS << ' ' << tok::l_brace << '\n';
  AddConvertedErrorCall(ErrorName);
  OS << tok::r_brace << ' ' << tok::kw_else << ' ' << tok::l_brace << '\n';
  AddConvertedHandlerCall(SuccessResult);
  OS << '\n' << tok::r_brace;
}
/// Converts a call \p CE to the completion handler into the equivalent
/// 'return' or 'throw' statement.
void convertHandlerToReturnOrThrows(const CallExpr *CE) {
  auto EmitReturnOrThrow = [&](HandlerResult Exprs) {
    convertHandlerToReturnOrThrowsImpl(CE, Exprs);
  };
  auto EmitThrow = [&](StringRef ErrorName) {
    OS << tok::kw_throw << ' ' << ErrorName << '\n';
  };
  convertHandlerCall(CE, EmitReturnOrThrow, EmitThrow);
}
/// Convert the call \p CE to a completion handler to its 'return' or 'throws'
/// equivalent, where \p Result determines whether the call should be
/// interpreted as an error or success call.
void convertHandlerToReturnOrThrowsImpl(const CallExpr *CE,
HandlerResult Result) {
bool AddedReturnOrThrow = true;
if (!Result.isError()) {
// It's possible the user has already written an explicit return statement
// for the completion handler call, e.g 'return completion(args...)'. In
// that case, be sure not to add another return.
auto *parent = getWalker().Parent.getAsStmt();
if (isa_and_nonnull<ReturnStmt>(parent) &&
!cast<ReturnStmt>(parent)->isImplicit()) {
// The statement already has a return keyword. Don't add another one.
AddedReturnOrThrow = false;
} else {
OS << tok::kw_return;
}
} else {
OS << tok::kw_throw;
}
auto Args = Result.args();
if (!Args.empty()) {
if (AddedReturnOrThrow)
OS << ' ';
addTupleOf(Args, OS, [&](Argument Arg) {
// Special case: If the completion handler is a params handler that
// takes an error, we could pass arguments to it without unwrapping
// them. E.g.
// simpleWithError { (res: String?, error: Error?) in
// completion(res, nil)
// }
// But after refactoring `simpleWithError` to an async function we have
// let res: String = await simple()
// and `res` is no longer an `Optional`. Thus it's in `Placeholders` and
// `Unwraps` and any reference to it will be replaced by a placeholder
// unless it is wrapped in an unwrapping expression. This would cause us
// to create `return <#res# >`.
// Under our assumption that either the error or the result parameter
// are non-nil, the above call to the completion handler is equivalent
// to
// completion(res!, nil)
// which correctly yields
// return res
// Synthesize the force unwrap so that we get the expected results.
auto *E = Arg.getExpr();
if (TopHandler.getHandlerType() == HandlerType::PARAMS &&
TopHandler.HasError) {
if (auto DRE =
dyn_cast<DeclRefExpr>(E->getSemanticsProvidingExpr())) {
auto D = DRE->getDecl();
if (Unwraps.count(D)) {
E = new (getASTContext()) ForceValueExpr(E, SourceLoc());
}
}
}
// Can't just add the range as we need to perform replacements
convertNode(E, /*StartOverride=*/Arg.getLabelLoc(),
/*ConvertCalls=*/false);
});
}
}
/// Convert a call \p CE to a completion handler to resumes of the
/// continuation that's currently on top of the stack.
void convertHandlerToContinuationResume(const CallExpr *CE) {
return convertHandlerCall(
CE,
[&](HandlerResult Exprs) {
convertHandlerToContinuationResumeImpl(CE, Exprs);
},
[&](StringRef ErrorName) {
Identifier ContinuationName = Scopes.back().ContinuationName;
OS << ContinuationName << tok::period << "resume" << tok::l_paren
<< "throwing" << tok::colon << ' ' << ErrorName;
OS << tok::r_paren << '\n';
});
}
/// Convert a call \p CE to a completion handler into a resume of the
/// continuation that's currently on top of the stack.
/// \p Result determines whether the call should be interpreted as a success
/// or error call.
///
/// For PARAMS handlers the arguments come from \p Result, and the resume label
/// is 'returning' (success) or 'throwing' (error). For RESULT handlers the
/// arguments of \p CE itself are forwarded with 'resume(with:)'.
/// When the top handler has an error and an argument is optional, that
/// argument is first bound with a 'guard let ... else { fatalError(...) }'
/// and the bound name is passed to 'resume' instead of the expression.
void convertHandlerToContinuationResumeImpl(const CallExpr *CE,
HandlerResult Result) {
assert(Scopes.back().isWrappedInContination());
std::vector<Argument> Args;
StringRef ResumeArgumentLabel;
// Choose the arguments to forward and the matching 'resume' label.
switch (TopHandler.getHandlerType()) {
case HandlerType::PARAMS: {
Args = Result.args();
if (!Result.isError()) {
ResumeArgumentLabel = "returning";
} else {
ResumeArgumentLabel = "throwing";
}
break;
}
case HandlerType::RESULT: {
// The single Result-typed argument is forwarded as-is.
Args = {CE->getArgs()->begin(), CE->getArgs()->end()};
ResumeArgumentLabel = "with";
break;
}
case HandlerType::INVALID:
llvm_unreachable("Invalid top handler");
}
// A vector in which each argument of Result has an entry. If the entry is
// not empty then that argument has been unwrapped using 'guard let' into
// a variable with that name.
SmallVector<Identifier, 4> ArgNames;
ArgNames.reserve(Args.size());
/// When unwrapping a result argument \p Arg into a variable using
/// 'guard let' return a suitable name for the unwrapped variable.
/// \p ArgIndex is the index of \p Arg in the results passed to the
/// completion handler.
auto GetSuitableNameForGuardUnwrap = [&](Expr *Arg,
unsigned ArgIndex) -> Identifier {
// If Arg is a DeclRef, use its name for the guard unwrap.
// guard let myVar1 = myVar.
if (auto DRE = dyn_cast<DeclRefExpr>(Arg)) {
return createUniqueName(DRE->getDecl()->getBaseIdentifier().str());
} else if (auto IIOE = dyn_cast<InjectIntoOptionalExpr>(Arg)) {
if (auto DRE = dyn_cast<DeclRefExpr>(IIOE->getSubExpr())) {
return createUniqueName(DRE->getDecl()->getBaseIdentifier().str());
}
}
if (Args.size() == 1) {
// We only have a single result. 'result' seems a reasonable name.
return createUniqueName("result");
} else {
// We are returning a tuple. Name the result elements 'result' +
// index in tuple.
return createUniqueName("result" + std::to_string(ArgIndex));
}
};
// First pass: emit a 'guard let' for every optional argument that needs
// unwrapping, and remember the name it was bound to (empty otherwise).
unsigned ArgIndex = 0;
for (auto Arg : Args) {
auto *ArgExpr = Arg.getExpr();
Identifier ArgName;
if (isExpressionOptional(ArgExpr) && TopHandler.HasError) {
ArgName = GetSuitableNameForGuardUnwrap(ArgExpr, ArgIndex);
// Reserve the name so later unique names in this scope avoid it.
Scopes.back().Names.insert(ArgName);
OS << tok::kw_guard << ' ' << tok::kw_let << ' ' << ArgName << ' '
<< tok::equal << ' ';
// If the argument is a call with a trailing closure, the generated
// guard statement will not compile.
// e.g. 'guard let result1 = value.map { $0 + 1 } else { ... }'
// doesn't compile. Adding parentheses makes the code compile.
auto HasTrailingClosure = false;
if (auto *CE = dyn_cast<CallExpr>(ArgExpr)) {
if (CE->getArgs()->hasAnyTrailingClosures())
HasTrailingClosure = true;
}
if (HasTrailingClosure)
OS << tok::l_paren;
convertNode(ArgExpr, /*StartOverride=*/Arg.getLabelLoc(),
/*ConvertCalls=*/false);
if (HasTrailingClosure)
OS << tok::r_paren;
// The else branch traps; the variable name is only mentioned in the
// message when it differs from the plain 'result'.
OS << ' ' << tok::kw_else << ' ' << tok::l_brace << '\n';
OS << "fatalError" << tok::l_paren;
OS << "\"Expected non-nil result ";
if (ArgName.str() != "result") {
OS << "'" << ArgName << "' ";
}
OS << "in the non-error case\"";
OS << tok::r_paren << '\n';
OS << tok::r_brace << '\n';
}
ArgNames.push_back(ArgName);
ArgIndex++;
}
// Second pass: print '<continuation>.resume(<label>: <args>)', using the
// unwrapped names where they exist.
Identifier ContName = Scopes.back().ContinuationName;
OS << ContName << tok::period << "resume" << tok::l_paren
<< ResumeArgumentLabel << tok::colon << ' ';
ArgIndex = 0;
addTupleOf(Args, OS, [&](Argument Arg) {
Identifier ArgName = ArgNames[ArgIndex];
if (!ArgName.empty()) {
OS << ArgName;
} else {
// Can't just add the range as we need to perform replacements
convertNode(Arg.getExpr(), /*StartOverride=*/Arg.getLabelLoc(),
/*ConvertCalls=*/false);
}
ArgIndex++;
});
OS << tok::r_paren;
}
/// Return the closure passed as the function-call argument \p E, or
/// \c nullptr if \p E is not a closure.
///
/// Function conversions wrapping the argument are ignored. A closure wrapped
/// in a capture list (e.g. `{ [weak self] in ... }`) is also recognized.
ClosureExpr *extractCallback(Expr *E) {
  Expr *Stripped = lookThroughFunctionConversionExpr(E);
  if (auto *CaptureList = dyn_cast<CaptureListExpr>(Stripped))
    return dyn_cast<ClosureExpr>(CaptureList->getClosureBody());
  return dyn_cast<ClosureExpr>(Stripped);
}
/// Strip every FunctionConversionExpr wrapped around \p E and return the
/// innermost non-conversion expression.
///
/// Callback arguments marked as e.g. `@convention(block)` produce such
/// conversions. The refactoring doesn't care about them.
Expr *lookThroughFunctionConversionExpr(Expr *E) {
  while (auto *Conversion = dyn_cast<FunctionConversionExpr>(E))
    E = Conversion->getSubExpr();
  return E;
}
/// Convert the call \p CE, whose completion handler is described by
/// \p HandlerDesc, into an await of its async alternative.
///
/// The completion-handler argument is handled according to its shape:
/// - a closure (possibly inside a capture list): handed to
///   addHoistedClosureCallback so its body follows the await call;
/// - a reference to the completion handler of the function being refactored:
///   the awaited value is returned directly ('return' is omitted when the
///   async alternative returns Void);
/// - a reference to another decl with a valid completion-handler shape: the
///   awaited result is forwarded to it via addHoistedNamedCallback.
/// A missing argument, or any other shape, emits diag::missing_callback_arg.
void addHoistedCallback(const CallExpr *CE,
const AsyncHandlerParamDesc &HandlerDesc) {
// Hoisting is set to true for the duration of this call and restored on exit.
llvm::SaveAndRestore<bool> RestoreHoisting(Hoisting, true);
auto *ArgList = CE->getArgs();
if (HandlerDesc.Index >= ArgList->size()) {
DiagEngine.diagnose(CE->getStartLoc(), diag::missing_callback_arg);
return;
}
Expr *CallbackArg =
lookThroughFunctionConversionExpr(ArgList->getExpr(HandlerDesc.Index));
if (ClosureExpr *Callback = extractCallback(CallbackArg)) {
// The user is using a closure for the completion handler
addHoistedClosureCallback(CE, HandlerDesc, Callback);
return;
}
if (auto CallbackDecl = getReferencedDecl(CallbackArg)) {
if (CallbackDecl == TopHandler.getHandler()) {
// We are refactoring the function that declared the completion handler
// that would be called here. We can't call the completion handler
// anymore because it will be removed. But since the function that
// declared it is being refactored to async, we can just return the
// values.
if (!HandlerDesc.willAsyncReturnVoid()) {
OS << tok::kw_return << " ";
}
InlinePatternsToPrint InlinePatterns;
addAwaitCall(CE, ClassifiedBlock(), {}, InlinePatterns, HandlerDesc,
/*AddDeclarations*/ false);
return;
}
// We are not removing the completion handler, so we can call it once the
// async function returns.
// The completion handler that is called as part of the \p CE call.
// This will be called once the async function returns.
auto CompletionHandler =
AsyncHandlerDesc::get(CallbackDecl, /*RequireAttributeOrName=*/false);
if (CompletionHandler.isValid()) {
if (auto CalledFunc = getUnderlyingFunc(CE->getFn())) {
// Reuse the handler argument's source text verbatim as the name to call.
StringRef HandlerName = Lexer::getCharSourceRangeFromSourceRange(
SM, CallbackArg->getSourceRange()).str();
addHoistedNamedCallback(
CalledFunc, CompletionHandler, HandlerName, [&] {
InlinePatternsToPrint InlinePatterns;
addAwaitCall(CE, ClassifiedBlock(), {}, InlinePatterns,
HandlerDesc, /*AddDeclarations*/ false);
});
return;
}
}
}
// Fall-through: the handler argument couldn't be interpreted.
DiagEngine.diagnose(CE->getStartLoc(), diag::missing_callback_arg);
}
/// If \p Flag is present, print `<flag> = true|false`, giving the known Bool
/// flag parameter the value it would have had inside \p Block.
///
/// The value is 'true' exactly when "we are in the success block" agrees with
/// \c Flag->IsSuccessFlag, and 'false' otherwise. \p Block must not be
/// FALLBACK.
void addBoolFlagParamBindingIfNeeded(Optional<KnownBoolFlagParam> Flag,
                                     BlockKind Block) {
  if (!Flag)
    return;

  bool InSuccessBlock = true;
  switch (Block) {
  case BlockKind::SUCCESS:
    InSuccessBlock = true;
    break;
  case BlockKind::ERROR:
    InSuccessBlock = false;
    break;
  case BlockKind::FALLBACK:
    llvm_unreachable("Not a valid place to bind");
  }

  bool FlagValue = InSuccessBlock == Flag->IsSuccessFlag;
  OS << newNameFor(Flag->Param) << " " << tok::equal << " "
     << (FlagValue ? tok::kw_true : tok::kw_false) << "\n";
}
/// Add a call to the async alternative of \p CE and convert the \p Callback
/// closure so that its body is executed after the async call.
///
/// \param CE The call being converted. Its arguments are forwarded to the
/// async alternative.
/// \param HandlerDesc Describes the completion handler parameter of the
/// function called by \p CE.
/// \param Callback The closure passed as that completion handler.
///
/// If classifying the callback body reports an error, a fallback is emitted,
/// but only for PARAMS handlers that have an error parameter. The results are
/// assigned to variables in the outer scope inside a do/catch, and the
/// original closure body is printed after it. Otherwise the success block
/// follows the await call directly, and the error block (when one is really
/// needed) becomes a catch clause.
void addHoistedClosureCallback(const CallExpr *CE,
                               const AsyncHandlerParamDesc &HandlerDesc,
                               const ClosureExpr *Callback) {
  if (HandlerDesc.params().size() != Callback->getParameters()->size()) {
    DiagEngine.diagnose(CE->getStartLoc(), diag::mismatched_callback_args);
    return;
  }
  ClosureCallbackParams CallbackParams(HandlerDesc, Callback);
  ClassifiedBlocks Blocks;
  auto *CallbackBody = Callback->getBody();
  if (!HandlerDesc.HasError) {
    // Without an error there is only a success path: take the whole body.
    Blocks.SuccessBlock.addNodesInBraceStmt(CallbackBody);
  } else if (!CallbackBody->getElements().empty()) {
    CallbackClassifier::classifyInto(Blocks, CallbackParams, HandledSwitches,
                                     DiagEngine, CallbackBody);
  }

  auto SuccessBindings = CallbackParams.getParamsToBind(BlockKind::SUCCESS);
  auto *ErrParam = CallbackParams.getErrParam();
  if (DiagEngine.hadAnyError()) {
    // For now, only fallback when the results are params with an error param,
    // in which case only the names are used (defaulted to the names of the
    // params if none).
    if (HandlerDesc.Type != HandlerType::PARAMS || !HandlerDesc.HasError)
      return;
    DiagEngine.resetHadAnyError();

    // Note that we don't print any inline patterns here as we just want
    // assignments to the names in the outer scope.
    InlinePatternsToPrint InlinePatterns;

    auto AllBindings = CallbackParams.getParamsToBind(BlockKind::FALLBACK);
    prepareNames(ClassifiedBlock(), AllBindings, InlinePatterns);
    preparePlaceholdersAndUnwraps(HandlerDesc, CallbackParams,
                                  BlockKind::FALLBACK);
    addFallbackVars(AllBindings, CallbackParams);
    addDo();
    addAwaitCall(CE, Blocks.SuccessBlock, SuccessBindings, InlinePatterns,
                 HandlerDesc, /*AddDeclarations*/ false);
    OS << "\n";

    // If we have a known Bool success param, we need to bind it.
    addBoolFlagParamBindingIfNeeded(CallbackParams.getKnownBoolFlagParam(),
                                    BlockKind::SUCCESS);
    addFallbackCatch(CallbackParams);
    OS << "\n";
    convertNodes(NodesToPrint::inBraceStmt(CallbackBody));

    clearNames(AllBindings);
    return;
  }

  // For Result-style handlers, the Result param plays the role of the error
  // param in the catch clause.
  auto *ErrOrResultParam = ErrParam;
  if (auto *ResultParam = CallbackParams.getResultParam())
    ErrOrResultParam = ResultParam;

  auto ErrorNodes = Blocks.ErrorBlock.nodesToPrint().getNodes();
  bool RequireDo = !ErrorNodes.empty();
  // Check if we *actually* need a do/catch (see class comment)
  if (ErrorNodes.size() == 1) {
    auto Node = ErrorNodes[0];
    if (auto *HandlerCall = TopHandler.getAsHandlerCall(Node)) {
      auto Res = TopHandler.extractResultArgs(
          HandlerCall, /*ReturnErrorArgsIfAmbiguous=*/true);
      if (Res.args().size() == 1) {
        // Skip if we have the param itself or the name it's bound to
        auto *ArgExpr = Res.args()[0].getExpr();
        auto *SingleDecl = ArgExpr->getReferencedDecl().getDecl();
        auto ErrName = Blocks.ErrorBlock.boundName(ErrOrResultParam);
        RequireDo = SingleDecl != ErrOrResultParam &&
                    !(Res.isError() && SingleDecl &&
                      SingleDecl->getName().isSimpleName(ErrName));
      }
    }
  }

  if (RequireDo) {
    addDo();
  } else {
    // We'll be dropping the error block, but let's make sure we at least
    // preserve the comments in it by transplanting them into the success
    // block. This should make sure they maintain a sensible ordering.
    // (Named distinctly from 'ErrorNodes' above, which has a different type.)
    auto ErrorBlockNodes = Blocks.ErrorBlock.nodesToPrint();
    for (auto CommentLoc : ErrorBlockNodes.getPossibleCommentLocs())
      Blocks.SuccessBlock.addPossibleCommentLoc(CommentLoc);
  }

  auto InlinePatterns = getInlinePatternsToPrint(Blocks.SuccessBlock,
                                                 SuccessBindings, Callback);

  prepareNames(Blocks.SuccessBlock, SuccessBindings, InlinePatterns);
  preparePlaceholdersAndUnwraps(HandlerDesc, CallbackParams,
                                BlockKind::SUCCESS);

  addAwaitCall(CE, Blocks.SuccessBlock, SuccessBindings, InlinePatterns,
               HandlerDesc, /*AddDeclarations=*/true);
  printOutOfLineBindingPatterns(Blocks.SuccessBlock, InlinePatterns);
  convertNodes(Blocks.SuccessBlock.nodesToPrint());
  clearNames(SuccessBindings);

  if (RequireDo) {
    // We don't use inline patterns for the error path.
    InlinePatternsToPrint ErrInlinePatterns;

    // Always use the ErrParam name if none is bound.
    prepareNames(Blocks.ErrorBlock, llvm::makeArrayRef(ErrOrResultParam),
                 ErrInlinePatterns,
                 /*AddIfMissing=*/HandlerDesc.Type != HandlerType::RESULT);
    preparePlaceholdersAndUnwraps(HandlerDesc, CallbackParams,
                                  BlockKind::ERROR);

    addCatch(ErrOrResultParam);
    convertNodes(Blocks.ErrorBlock.nodesToPrint());
    OS << "\n" << tok::r_brace;
    clearNames(llvm::makeArrayRef(ErrOrResultParam));
  }
}
/// Add a call to the async alternative of \p FD. Afterwards, pass the results
/// of the async call to the completion handler, named \p HandlerName and
/// described by \p HandlerDesc.
/// \p AddAwaitCall adds the call to the refactored async method to the output
/// stream without storing the result to any variables.
/// This is used when the user didn't use a closure for the callback, but
/// passed in a variable or function name for the completion handler.
///
/// If \p HandlerDesc has an error, the await is wrapped in a do/catch (opened
/// via addDo()). The handler is called with the result inside the 'do' body
/// and with the caught error in the 'catch' clause. Otherwise the handler is
/// called once after the await, using a uniquely named result variable.
void addHoistedNamedCallback(const FuncDecl *FD,
const AsyncHandlerDesc &HandlerDesc,
StringRef HandlerName,
std::function<void(void)> AddAwaitCall) {
if (HandlerDesc.HasError) {
// "result" and "error" always okay to use here since they're added
// in their own scope, which only contains new code.
addDo();
if (!HandlerDesc.willAsyncReturnVoid()) {
OS << tok::kw_let << " result";
addResultTypeAnnotationIfNecessary(FD, HandlerDesc);
OS << " " << tok::equal << " ";
}
AddAwaitCall();
OS << "\n";
// Success path: forward 'result' (a non-empty ResultName) to the handler.
addCallToCompletionHandler("result", HandlerDesc, HandlerName);
OS << "\n";
OS << tok::r_brace << " " << tok::kw_catch << " " << tok::l_brace << "\n";
// Error path: an empty ResultName makes the handler receive 'error'.
addCallToCompletionHandler(StringRef(), HandlerDesc, HandlerName);
OS << "\n" << tok::r_brace; // end catch
} else {
// This code may be placed into an existing scope, in that case create
// a unique "result" name so that it doesn't cause shadowing or redecls.
StringRef ResultName;
if (!HandlerDesc.willAsyncReturnVoid()) {
Identifier Unique = createUniqueName("result");
Scopes.back().Names.insert(Unique);
ResultName = Unique.str();
OS << tok::kw_let << " " << ResultName;
addResultTypeAnnotationIfNecessary(FD, HandlerDesc);
OS << " " << tok::equal << " ";
} else {
// The name won't end up being used, just give it a bogus one so that
// the result path is taken (versus the error path).
ResultName = "result";
}
AddAwaitCall();
OS << "\n";
addCallToCompletionHandler(ResultName, HandlerDesc, HandlerName);
}
}
/// Return the binding pattern for \p D if it can be printed inline in an
/// await call (e.g. 'let ((x, y), z) = await foo()', where '(x, y)' is the
/// inline pattern), or \c nullptr otherwise.
///
/// This only applies when converting a callback closure, and only when
/// \p Block reduces the bindings of \p D to a single pattern. A pattern
/// binding a single var is always accepted. A multi-var pattern is accepted
/// only if \p D is referenced exactly once in the closure body. Otherwise it
/// must be printed out of line.
const Pattern *
bindingPatternToPrintInline(const Decl *D, const ClassifiedBlock &Block,
                            const ClosureExpr *CallbackClosure) {
  if (!CallbackClosure)
    return nullptr;

  auto *Candidate = Block.getSinglePatternFor(D);
  if (!Candidate || Candidate->getSingleVar())
    return Candidate;

  auto *RefCounts = ScopedDecls.getReferencedDecls(CallbackClosure->getBody());
  assert(RefCounts);
  if (RefCounts->lookup(D) != 1)
    return nullptr;
  return Candidate;
}
/// Build the map from each of \p Params to the pattern that can be printed
/// inline for it. Params without an inline-printable pattern (as decided by
/// bindingPatternToPrintInline) are left out of the map.
InlinePatternsToPrint
getInlinePatternsToPrint(const ClassifiedBlock &Block,
                         ArrayRef<const ParamDecl *> Params,
                         const ClosureExpr *CallbackClosure) {
  InlinePatternsToPrint Result;
  for (const ParamDecl *PD : Params) {
    auto *Inline = bindingPatternToPrintInline(PD, Block, CallbackClosure);
    if (Inline)
      Result[PD] = Inline;
  }
  return Result;
}
/// Print the binding patterns of \p Block that were not printed inline in the
/// await call. Each is emitted on its own line directly after the call, e.g.:
/// \code
/// let x = await foo()
/// let (y, z) = x
/// \endcode
/// Patterns already in \p InlinePatterns are skipped, and so are
/// single-var patterns that merely alias the bound param. 'var' is used
/// instead of 'let' when the pattern has any mutable bindings.
void
printOutOfLineBindingPatterns(const ClassifiedBlock &Block,
                              const InlinePatternsToPrint &InlinePatterns) {
  for (const auto &Binding : Block.paramPatternBindings()) {
    auto *Bound = Binding.first;
    auto Aliases = Block.getAliasesFor(Bound);
    auto *AlreadyInline = InlinePatterns.lookup(Bound);

    for (auto *P : Binding.second) {
      if (P == AlreadyInline)
        continue;

      auto *SingleVar = P->getSingleVar();
      if (SingleVar && Aliases.contains(SingleVar))
        continue;

      auto Introducer = P->hasAnyMutableBindings() ? tok::kw_var : tok::kw_let;
      OS << "\n" << Introducer << " ";
      convertPattern(P);
      OS << " = " << newNameFor(Bound);
    }
  }
}
/// Prints an \c await call to an \c async function, binding any return values
/// into variables.
///
/// \param CE The call expr to convert.
/// \param SuccessBlock The nodes present in the success block following the
/// call.
/// \param SuccessParams The success parameters, which will be printed as
/// return values.
/// \param InlinePatterns A map of patterns that can be printed inline for
/// a given param.
/// \param HandlerDesc A description of the completion handler.
/// \param AddDeclarations Whether or not to add \c let or \c var keywords to
/// the return value bindings.
///
/// The printed shape is
/// '[let|var] <bindings> = [try] await <callee>(<args without handler>)'.
/// The parentheses are omitted when the alternative is an accessor.
void addAwaitCall(const CallExpr *CE, const ClassifiedBlock &SuccessBlock,
ArrayRef<const ParamDecl *> SuccessParams,
const InlinePatternsToPrint &InlinePatterns,
const AsyncHandlerParamDesc &HandlerDesc,
bool AddDeclarations) {
auto *Args = CE->getArgs();
// Print the bindings to match the completion handler success parameters,
// making sure to omit in the case of a Void return.
if (!SuccessParams.empty() && !HandlerDesc.willAsyncReturnVoid()) {
auto AllLet = true;
// Gather the items to print for the variable bindings. This can either be
// a param decl, or a pattern that binds it.
using DeclOrPattern = llvm::PointerUnion<const Decl *, const Pattern *>;
SmallVector<DeclOrPattern, 4> ToPrint;
for (auto *Param : SuccessParams) {
// Check if we have an inline pattern to print.
if (auto *P = InlinePatterns.lookup(Param)) {
if (P->hasAnyMutableBindings())
AllLet = false;
ToPrint.push_back(P);
continue;
}
ToPrint.push_back(Param);
}
if (AddDeclarations) {
if (AllLet) {
OS << tok::kw_let;
} else {
OS << tok::kw_var;
}
OS << " ";
}
// 'res =' or '(res1, res2, ...) ='
addTupleOf(ToPrint, OS, [&](DeclOrPattern Elt) {
if (auto *P = Elt.dyn_cast<const Pattern *>()) {
convertPattern(P);
return;
}
OS << newNameFor(Elt.get<const Decl *>());
});
OS << " " << tok::equal << " ";
}
if (HandlerDesc.HasError) {
OS << tok::kw_try << " ";
}
OS << "await ";
// Try to replace the name with that of the alternative. Use the existing
// name if for some reason that's not possible.
// For an accessor alternative, the name of its storage is used. The rename
// is recorded in Names for HandlerDesc.Func before converting the callee
// expression (presumably convertNode consults Names when printing it).
bool NameAdded = false;
if (HandlerDesc.Alternative) {
const ValueDecl *Named = HandlerDesc.Alternative;
if (auto *Accessor = dyn_cast<AccessorDecl>(HandlerDesc.Alternative))
Named = Accessor->getStorage();
if (!Named->getBaseName().isSpecial()) {
Names.try_emplace(HandlerDesc.Func,
Named->getBaseName().getIdentifier());
convertNode(CE->getFn(), /*StartOverride=*/{}, /*ConvertCalls=*/false,
/*IncludeComments=*/false);
NameAdded = true;
}
}
if (!NameAdded) {
addRange(CE->getStartLoc(), CE->getFn()->getEndLoc(),
/*ToEndOfToken=*/true);
}
if (!HandlerDesc.alternativeIsAccessor())
OS << tok::l_paren;
// ConvertedArgIndex does double duty: it counts printed arguments (to
// decide on commas) and indexes into AlternativeParams.
size_t ConvertedArgIndex = 0;
ArrayRef<ParamDecl *> AlternativeParams;
if (HandlerDesc.Alternative)
AlternativeParams = HandlerDesc.Alternative->getParameters()->getArray();
for (auto I : indices(*Args)) {
auto Arg = Args->get(I);
auto *ArgExpr = Arg.getExpr();
// The completion handler and implicit default arguments are dropped.
if (I == HandlerDesc.Index || isa<DefaultArgumentExpr>(ArgExpr))
continue;
if (ConvertedArgIndex > 0)
OS << tok::comma << " ";
if (HandlerDesc.Alternative) {
// Skip argument if it's defaulted and has a different name
while (ConvertedArgIndex < AlternativeParams.size() &&
AlternativeParams[ConvertedArgIndex]->isDefaultArgument() &&
AlternativeParams[ConvertedArgIndex]->getArgumentName() !=
Arg.getLabel()) {
ConvertedArgIndex++;
}
if (ConvertedArgIndex < AlternativeParams.size()) {
// Could have a different argument label (or none), so add it instead
auto Name = AlternativeParams[ConvertedArgIndex]->getArgumentName();
if (!Name.empty())
OS << Name << ": ";
convertNode(ArgExpr, /*StartOverride=*/{}, /*ConvertCalls=*/false);
ConvertedArgIndex++;
continue;
}
// Fallthrough if arguments don't match up for some reason
}
// Can't just add the range as we need to perform replacements. Also
// make sure to include the argument label (if any)
convertNode(ArgExpr, /*StartOverride=*/Arg.getLabelLoc(),
/*ConvertCalls=*/false);
ConvertedArgIndex++;
}
if (!HandlerDesc.alternativeIsAccessor())
OS << tok::r_paren;
}
void addFallbackCatch(const ClosureCallbackParams &Params) {
auto *ErrParam = Params.getErrParam();
assert(ErrParam);
auto ErrName = newNameFor(ErrParam);
OS << tok::r_brace << " " << tok::kw_catch << " " << tok::l_brace << "\n"
<< ErrName << " = error\n";
// If we have a known Bool success param, we need to bind it.
addBoolFlagParamBindingIfNeeded(Params.getKnownBoolFlagParam(),
BlockKind::ERROR);
OS << tok::r_brace;
}
void addCatch(const ParamDecl *ErrParam) {
OS << "\n" << tok::r_brace << " " << tok::kw_catch << " ";
auto ErrName = newNameFor(ErrParam, false);
if (!ErrName.empty() && ErrName != "_") {
OS << tok::kw_let << " " << ErrName << " ";
}
OS << tok::l_brace;
}
/// Record which closure parameters must be turned into placeholders, and
/// which must be unwrapped, while the nodes of \p Block are converted.
///
/// \param HandlerDesc The completion handler the closure was passed as.
/// \param Params The closure's parameters.
/// \param Block The block (success, error or fallback) about to be converted.
///
/// Parameters without a binding in \p Block are always placeholdered. The
/// fallback block needs nothing more. Otherwise:
/// - PARAMS handlers, error block: when there is an error param, it is
///   unwrapped and placeholdered if it needs unwrapping, and every success
///   param is placeholdered.
/// - PARAMS handlers, success block: success params that need unwrapping are
///   unwrapped and placeholdered, success params whose async return type is
///   Void are placeholdered, and the error param (if any) is placeholdered.
/// - RESULT handlers: the Result param itself is placeholdered.
void preparePlaceholdersAndUnwraps(AsyncHandlerDesc HandlerDesc,
                                   const ClosureCallbackParams &Params,
                                   BlockKind Block) {
  // Params that have been dropped always need placeholdering.
  for (auto *Param : Params.getAllParams()) {
    if (!Params.hasBinding(Param, Block))
      Placeholders.insert(Param);
  }
  // For the fallback case, no other params need placeholdering, as they are
  // all freely accessible in the fallback case.
  if (Block == BlockKind::FALLBACK)
    return;

  // Every HandlerType enumerator is listed explicitly (no 'default'), so the
  // compiler can flag this switch if a new handler type is ever added.
  switch (HandlerDesc.Type) {
  case HandlerType::PARAMS: {
    auto *ErrParam = Params.getErrParam();
    auto SuccessParams = Params.getSuccessParams();
    switch (Block) {
    case BlockKind::FALLBACK:
      llvm_unreachable("Already handled");
    case BlockKind::ERROR:
      if (ErrParam) {
        if (HandlerDesc.shouldUnwrap(ErrParam->getType())) {
          Placeholders.insert(ErrParam);
          Unwraps.insert(ErrParam);
        }
        // Can't use success params in the error body
        Placeholders.insert(SuccessParams.begin(), SuccessParams.end());
      }
      break;
    case BlockKind::SUCCESS:
      for (auto *SuccessParam : SuccessParams) {
        auto Ty = SuccessParam->getType();
        if (HandlerDesc.shouldUnwrap(Ty)) {
          // Either unwrap or replace with a placeholder if there's some other
          // reference
          Unwraps.insert(SuccessParam);
          Placeholders.insert(SuccessParam);
        }
        // Void parameters get omitted where possible, so turn any reference
        // into a placeholder, as its usage is unlikely what the user wants.
        if (HandlerDesc.getSuccessParamAsyncReturnType(Ty)->isVoid())
          Placeholders.insert(SuccessParam);
      }
      // Can't use the error param in the success body
      if (ErrParam)
        Placeholders.insert(ErrParam);
      break;
    }
    break;
  }
  case HandlerType::RESULT: {
    // Any uses of the result parameter in the current body (that aren't
    // replaced) are invalid, so replace them with a placeholder.
    auto *ResultParam = Params.getResultParam();
    assert(ResultParam);
    Placeholders.insert(ResultParam);
    break;
  }
  case HandlerType::INVALID:
    llvm_unreachable("Unhandled handler type");
  }
}
/// Add a mapping from each passed parameter to a new name, possibly
/// synthesizing a new one if hoisting it would cause a redeclaration or
/// shadowing. If there's no bound name and \c AddIfMissing is false, no
/// name will be added.
void prepareNames(const ClassifiedBlock &Block,
ArrayRef<const ParamDecl *> Params,
const InlinePatternsToPrint &InlinePatterns,
bool AddIfMissing = true) {
for (auto *PD : Params) {
// If this param is to be replaced by a pattern that binds multiple
// separate vars, it's not actually going to be added to the scope, and
// therefore doesn't need naming. This avoids needing to rename a var with
// the same name later on in the scope, as it's not actually clashing.
if (auto *P = InlinePatterns.lookup(PD)) {
if (!P->getSingleVar())
continue;
}
auto Name = Block.boundName(PD);
if (Name.empty() && !AddIfMissing)
continue;
auto Ident = assignUniqueName(PD, Name);
// Also propagate the name to any aliases.
for (auto *Alias : Block.getAliasesFor(PD))
Names[Alias] = Ident;
}
}
/// Return an identifier based on \p Name that isn't already among the names
/// of the current (innermost) scope.
///
/// \p Name itself is used when it is free. Otherwise the suffixes 1, 2, 3, ...
/// are tried in order. '_' is returned unchanged. The result is not added to
/// the scope.
Identifier createUniqueName(StringRef Name) {
  ASTContext &Ctx = getASTContext();
  Identifier Candidate = Ctx.getIdentifier(Name);
  if (Name == "_")
    return Candidate;

  const auto &InScope = Scopes.back().Names;
  for (unsigned Suffix = 1; InScope.count(Candidate); ++Suffix) {
    llvm::SmallString<32> Buffer(Name);
    Buffer += std::to_string(Suffix);
    Candidate = Ctx.getIdentifier(Buffer);
  }
  return Candidate;
}
/// Pick a name for the variable declared by \p D that doesn't clash with any
/// name in the current scope. Record it in \c Names (unless \p D already has
/// an entry there) and add it to the current scope's names.
///
/// \p BoundName is the base name when non-empty. Otherwise the decl's own
/// user-facing name is used. If that is empty as well, an empty Identifier is
/// returned and nothing is recorded. Base names starting with '$' get that
/// character replaced by 'val'.
Identifier assignUniqueName(const Decl *D, StringRef BoundName) {
  StringRef Base = BoundName;
  if (Base.empty()) {
    Base = getDeclName(D).userFacingName();
    if (Base.empty())
      return Identifier();
  }

  Identifier Chosen;
  if (Base.startswith("$")) {
    llvm::SmallString<8> Renamed("val");
    Renamed += Base.drop_front();
    Chosen = createUniqueName(Renamed);
  } else {
    Chosen = createUniqueName(Base);
  }

  Names.try_emplace(D, Chosen);
  Scopes.back().Names.insert(Chosen);
  return Chosen;
}
/// Return the name assigned to \p D in \c Names. When \p D has no name, an
/// empty string is returned; if \p Required is true this also trips an
/// assertion.
StringRef newNameFor(const Decl *D, bool Required = true) {
  auto Found = Names.find(D);
  if (Found != Names.end())
    return Found->second.str();
  assert(!Required && "Missing name for decl when one was required");
  return StringRef();
}
/// Push a new scope whose names are the non-empty names of \p Decls.
/// A nested scope inherits the enclosing scope's continuation name. The
/// outermost scope starts with an empty one.
void addNewScope(const llvm::DenseSet<const Decl *> &Decls) {
  // Copy the name before emplace_back, which may reallocate Scopes.
  Identifier Continuation =
      Scopes.empty() ? Identifier() : Scopes.back().ContinuationName;
  Scopes.emplace_back(Continuation);

  auto &ScopeNames = Scopes.back().Names;
  for (const Decl *D : Decls) {
    auto DeclName = getDeclName(D);
    if (!DeclName.empty())
      ScopeNames.insert(DeclName);
  }
}
/// Remove every entry for \p Params from \c Names, \c Placeholders and
/// \c Unwraps, undoing prepareNames / preparePlaceholdersAndUnwraps.
void clearNames(ArrayRef<const ParamDecl *> Params) {
  for (const ParamDecl *PD : Params) {
    Names.erase(PD);
    Placeholders.erase(PD);
    Unwraps.erase(PD);
  }
}
/// Adds a forwarding call to the old completion handler function, with
/// \p HandlerReplacement that allows for a custom replacement or, if empty,
/// removal of the completion handler closure.
///
/// Every other parameter of \p FD is forwarded by its parameter name, with
/// its argument label when it has one. When the handler is \p FD's last
/// parameter and a replacement is given, the replacement is printed as a
/// trailing closure after the closing parenthesis.
void addForwardingCallTo(const FuncDecl *FD, StringRef HandlerReplacement) {
OS << FD->getBaseName() << tok::l_paren;
auto *Params = FD->getParameters();
// Counts printed arguments so commas are only emitted between them.
size_t ConvertedArgsIndex = 0;
for (size_t I = 0, E = Params->size(); I < E; ++I) {
if (I == TopHandler.Index) {
/// If we're not replacing the handler with anything, drop it.
if (HandlerReplacement.empty())
continue;
// Use a trailing closure if the handler is the last param
if (I == E - 1) {
OS << tok::r_paren << " ";
OS << HandlerReplacement;
return;
}
// Otherwise fall through to do the replacement.
}
if (ConvertedArgsIndex > 0)
OS << tok::comma << " ";
const auto *Param = Params->get(I);
if (!Param->getArgumentName().empty())
OS << Param->getArgumentName() << tok::colon << " ";
if (I == TopHandler.Index) {
OS << HandlerReplacement;
} else {
OS << Param->getParameterName();
}
ConvertedArgsIndex++;
}
OS << tok::r_paren;
}
/// Print \p ErrorName as the error argument of a completion handler call.
///
/// If the handler's error type is exactly \c Error, the name is printed as-is.
/// Otherwise a force cast to the handler's (non-optional) error type is added:
/// 'ErrorName as! CustomError'. When the handler has an error parameter (an
/// 'Error?' parameter), the cast is wrapped in parentheses to silence the
/// compiler warning that a force cast never produces nil.
void addForwardedErrorArgument(StringRef ErrorName,
                               const AsyncHandlerDesc &HandlerDesc) {
  auto ErrorType = *HandlerDesc.getErrorType();
  bool IsPlainError = ErrorType->getCanonicalType() ==
                      getASTContext().getErrorExistentialType();
  if (IsPlainError) {
    OS << ErrorName;
    return;
  }

  bool WrapInParens = HandlerDesc.getErrorParam().has_value();
  if (WrapInParens)
    OS << tok::l_paren;
  OS << ErrorName << " " << tok::kw_as << tok::exclaim_postfix << " ";
  ErrorType->lookThroughSingleOptionalType()->print(OS);
  if (WrapInParens)
    OS << tok::r_paren;
}
/// Print a natural default value for \p T: 'nil' for an Optional, '()' for
/// Void. For any other type, print an editor placeholder '<#T#>' with \p T's
/// name as the hint.
void addDefaultValueOrPlaceholder(Type T) {
  if (T->isOptional()) {
    OS << tok::kw_nil;
    return;
  }
  if (T->isVoid()) {
    OS << "()";
    return;
  }
  OS << "<#";
  T.print(OS);
  OS << "#>";
}
/// Adds the argument for the \c Index -th parameter of the completion handler
/// described by \p HandlerDesc.
/// If \p ResultName is not empty, it is assumed that a variable with that
/// name contains the result returned from the async alternative. If the
/// callback also takes an error parameter, a default value (e.g. \c nil) is
/// passed to the completion handler for the error. If \p ResultName is empty,
/// it is assumed that a variable named 'error' contains the error thrown from
/// the async method, and a default value or placeholder (e.g. \c nil) is
/// passed to the completion handler for all result parameters.
void addCompletionHandlerArgument(size_t Index, StringRef ResultName,
const AsyncHandlerDesc &HandlerDesc) {
if (HandlerDesc.HasError && Index == HandlerDesc.params().size() - 1) {
// The error parameter is the last argument of the completion handler.
if (ResultName.empty()) {
addForwardedErrorArgument("error", HandlerDesc);
} else {
addDefaultValueOrPlaceholder(HandlerDesc.params()[Index].getPlainType());
}
} else {
if (ResultName.empty()) {
// Error path: result parameters get a default value or placeholder.
addDefaultValueOrPlaceholder(HandlerDesc.params()[Index].getPlainType());
} else if (HandlerDesc
.getSuccessParamAsyncReturnType(
HandlerDesc.params()[Index].getPlainType())
->isVoid()) {
// Void return types are not returned by the async function, synthesize
// a Void instance.
OS << tok::l_paren << tok::r_paren;
} else if (HandlerDesc.getSuccessParams().size() > 1) {
// If the async method returns a tuple, we need to pass its elements to
// the completion handler separately. For example:
//
// func foo() async -> (String, Int) {}
//
// causes the following legacy body to be created:
//
// func foo(completion: (String, Int) -> Void) {
// Task {
// let result = await foo()
// completion(result.0, result.1)
// }
// }
// Use the tuple element's label when it has one, otherwise its index.
OS << ResultName << tok::period;
auto Label = HandlerDesc.getAsyncReturnTypeLabel(Index);
if (!Label.empty()) {
OS << Label;
} else {
OS << Index;
}
} else {
OS << ResultName;
}
}
}
/// Add a call to the completion handler named \p HandlerName and described by
/// \p HandlerDesc, passing all the required arguments.
///
/// An empty \p ResultName selects the error path, which forwards the variable
/// 'error'. A non-empty one selects the success path. PARAMS handlers get one
/// argument per parameter (see \c addCompletionHandlerArgument for how each
/// is synthesized). RESULT handlers get '.success(...)' or '.failure(...)'.
void addCallToCompletionHandler(StringRef ResultName,
                                const AsyncHandlerDesc &HandlerDesc,
                                StringRef HandlerName) {
  OS << HandlerName << tok::l_paren;
  switch (HandlerDesc.Type) {
  case HandlerType::INVALID:
    llvm_unreachable("Cannot be rewritten");
  case HandlerType::PARAMS:
    for (size_t I = 0, E = HandlerDesc.params().size(); I != E; ++I) {
      if (I != 0)
        OS << tok::comma << " ";
      addCompletionHandlerArgument(I, ResultName, HandlerDesc);
    }
    break;
  case HandlerType::RESULT:
    if (ResultName.empty()) {
      OS << tok::period_prefix << "failure" << tok::l_paren;
      addForwardedErrorArgument("error", HandlerDesc);
    } else {
      OS << tok::period_prefix << "success" << tok::l_paren;
      if (HandlerDesc.willAsyncReturnVoid())
        OS << tok::l_paren << tok::r_paren;
      else
        OS << ResultName;
    }
    OS << tok::r_paren;
    break;
  }
  OS << tok::r_paren; // Close the call to the completion handler
}
/// Adds the result type of a refactored async function that previously
/// returned results via a completion handler described by \p HandlerDesc.
void addAsyncFuncReturnType(const AsyncHandlerDesc &HandlerDesc) {
// Type or (Type1, Type2, ...)
SmallVector<LabeledReturnType, 2> Scratch;
auto ReturnTypes = HandlerDesc.getAsyncReturnTypes(Scratch);
if (ReturnTypes.empty()) {
OS << "Void";
} else {
addTupleOf(ReturnTypes, OS, [&](LabeledReturnType LabelAndType) {
if (!LabelAndType.Label.empty()) {
OS << LabelAndType.Label << tok::colon << " ";
}
LabelAndType.Ty->print(OS);
});
}
}
/// If \p FD is generic, adds a type annotation with the return type of the
/// converted async function. This is used when creating a legacy function
/// that calls the converted 'async' function, so that the generic parameters
/// of the legacy function are passed to the generic function. For example,
/// for
/// \code
/// func foo<GenericParam>() async -> GenericParam {}
/// \endcode
/// we generate
/// \code
/// func foo<GenericParam>(completion: (GenericParam) -> Void) {
///   Task {
///     let result: GenericParam = await foo()
///               <------------>
///     completion(result)
///   }
/// }
/// \endcode
/// This function adds the range marked by \c <----->. Nothing is printed for
/// non-generic functions.
void addResultTypeAnnotationIfNecessary(const FuncDecl *FD,
                                        const AsyncHandlerDesc &HandlerDesc) {
  if (!FD->isGeneric())
    return;
  OS << tok::colon << " ";
  addAsyncFuncReturnType(HandlerDesc);
}
};
} // namespace asyncrefactorings
bool RefactoringActionConvertCallToAsyncAlternative::isApplicable(
    ResolvedCursorInfoPtr CursorInfo, DiagnosticEngine &Diag) {
  using namespace asyncrefactorings;

  // Whether the call sits in an async context is intentionally not checked:
  // converting from a synchronous context may still be useful. This may need
  // to change depending on feedback.
  auto *Call = findOuterCall(CursorInfo);
  if (!Call)
    return false;

  // Applicable only if the callee has a recognizable completion handler.
  return AsyncHandlerParamDesc::find(getUnderlyingFunc(Call->getFn()),
                                     /*RequireAttributeOrName=*/false)
      .isValid();
}
/// Converts a call of a function with a possible async alternative, to use it
/// instead. Currently this is any function that
/// 1. has a void return type,
/// 2. has a void returning closure as its last parameter, and
/// 3. is not already async
///
/// For now the call need not be in an async context, though this may change
/// depending on feedback.
bool RefactoringActionConvertCallToAsyncAlternative::performChange() {
using namespace asyncrefactorings;
auto *CE = findOuterCall(CursorInfo);
assert(CE &&
"Should not run performChange when refactoring is not applicable");
// Find the scope this call is in
ContextFinder Finder(
*CursorInfo->getSourceFile(), CursorInfo->getLoc(),
[](ASTNode N) { return N.isStmt(StmtKind::Brace) && !N.isImplicit(); });
Finder.resolve();
auto Scopes = Finder.getContexts();
BraceStmt *Scope = nullptr;
if (!Scopes.empty())
Scope = cast<BraceStmt>(Scopes.back().get<Stmt *>());
AsyncConverter Converter(TheFile, SM, DiagEngine, CE, Scope);
if (!Converter.convert())
return true;
Converter.replace(CE, EditConsumer);
return false;
}
bool RefactoringActionConvertToAsync::isApplicable(
    ResolvedCursorInfoPtr CursorInfo, DiagnosticEngine &Diag) {
  using namespace asyncrefactorings;
  // Applicable whenever the cursor is on a function. Like the call
  // refactoring, this could be restricted to functions that actually call
  // async alternatives; for now a function without such calls just gets
  // `async` added, which is probably fine.
  return findFunction(CursorInfo) != nullptr;
}
/// Converts a whole function to async, converting any calls to functions with
/// async alternatives as above.
bool RefactoringActionConvertToAsync::performChange() {
using namespace asyncrefactorings;
auto *FD = findFunction(CursorInfo);
assert(FD &&
"Should not run performChange when refactoring is not applicable");
auto HandlerDesc =
AsyncHandlerParamDesc::find(FD, /*RequireAttributeOrName=*/false);
AsyncConverter Converter(TheFile, SM, DiagEngine, FD, HandlerDesc);
if (!Converter.convert())
return true;
Converter.replace(FD, EditConsumer, FD->getSourceRangeIncludingAttrs().Start);
return false;
}
bool RefactoringActionAddAsyncAlternative::isApplicable(
    ResolvedCursorInfoPtr CursorInfo, DiagnosticEngine &Diag) {
  using namespace asyncrefactorings;
  // Requires a function under the cursor for which AsyncHandlerParamDesc::find
  // yields a valid handler description.
  if (auto *Func = findFunction(CursorInfo))
    return AsyncHandlerParamDesc::find(Func, /*RequireAttributeOrName=*/false)
        .isValid();
  return false;
}
/// Adds an async alternative and marks the current function as deprecated.
/// Equivalent to the conversion but
/// 1. only works on functions that themselves are a possible async
/// alternative, and
/// 2. has extra handling to convert the completion/handler/callback closure
/// parameter to either `return`/`throws`
bool RefactoringActionAddAsyncAlternative::performChange() {
using namespace asyncrefactorings;
auto *FD = findFunction(CursorInfo);
assert(FD &&
"Should not run performChange when refactoring is not applicable");
auto HandlerDesc =
AsyncHandlerParamDesc::find(FD, /*RequireAttributeOrName=*/false);
assert(HandlerDesc.isValid() &&
"Should not run performChange when refactoring is not applicable");
AsyncConverter Converter(TheFile, SM, DiagEngine, FD, HandlerDesc);
if (!Converter.convert())
return true;
// Add a reference to the async function so that warnings appear when the
// synchronous function is used in an async context
SmallString<128> AvailabilityAttr = HandlerDesc.buildRenamedAttribute();
EditConsumer.accept(SM, FD->getAttributeInsertionLoc(false),
AvailabilityAttr);
AsyncConverter LegacyBodyCreator(TheFile, SM, DiagEngine, FD, HandlerDesc);
if (LegacyBodyCreator.createLegacyBody()) {
LegacyBodyCreator.replace(FD->getBody(), EditConsumer);
}
// Add the async alternative
Converter.insertAfter(FD, EditConsumer);
return false;
}
bool RefactoringActionAddAsyncWrapper::isApplicable(
    ResolvedCursorInfoPtr CursorInfo, DiagnosticEngine &Diag) {
  using namespace asyncrefactorings;
  // Same precondition as adding an async alternative: a function under the
  // cursor whose completion-handler description is valid.
  if (auto *Func = findFunction(CursorInfo))
    return AsyncHandlerParamDesc::find(Func, /*RequireAttributeOrName=*/false)
        .isValid();
  return false;
}
bool RefactoringActionAddAsyncWrapper::performChange() {
using namespace asyncrefactorings;
auto *FD = findFunction(CursorInfo);
assert(FD &&
"Should not run performChange when refactoring is not applicable");
auto HandlerDesc =
AsyncHandlerParamDesc::find(FD, /*RequireAttributeOrName=*/false);
assert(HandlerDesc.isValid() &&
"Should not run performChange when refactoring is not applicable");
AsyncConverter Converter(TheFile, SM, DiagEngine, FD, HandlerDesc);
if (!Converter.createAsyncWrapper())
return true;
// Add a reference to the async function so that warnings appear when the
// synchronous function is used in an async context
SmallString<128> AvailabilityAttr = HandlerDesc.buildRenamedAttribute();
EditConsumer.accept(SM, FD->getAttributeInsertionLoc(false),
AvailabilityAttr);
// Add the async wrapper.
Converter.insertAfter(FD, EditConsumer);
return false;
}
/// Returns the ID of the buffer containing the start of the rewritten form of
/// the macro expansion expression \p expansion, or None if the expression has
/// no rewritten form.
static Optional<unsigned> getMacroExpansionBuffer(
    SourceManager &sourceMgr, MacroExpansionExpr *expansion) {
  auto rewritten = expansion->getRewritten();
  if (!rewritten)
    return None;
  return sourceMgr.findBufferContainingLoc(rewritten->getStartLoc());
}
/// Returns the buffer ID produced by evaluating ExpandMacroExpansionDeclRequest
/// for the macro expansion declaration \p expansion; the default (None) is
/// returned if evaluation yields no result.
/// \p sourceMgr is unused; it keeps the signature parallel with the
/// MacroExpansionExpr overload.
static Optional<unsigned>
getMacroExpansionBuffer(SourceManager &sourceMgr,
                        MacroExpansionDecl *expansion) {
  ASTContext &ctx = expansion->getASTContext();
  return evaluateOrDefault(ctx.evaluator,
                           ExpandMacroExpansionDeclRequest{expansion}, {});
}
/// Returns the IDs of the buffers produced by expanding the attached macro
/// \p macro, as written by the custom attribute \p attr on \p decl.
///
/// Every attached role that \p macro declares triggers the matching expansion
/// request. Buffers whose generated-source info does not name \p attr are then
/// filtered out.
static llvm::SmallVector<unsigned, 2>
getMacroExpansionBuffers(MacroDecl *macro, const CustomAttr *attr, Decl *decl) {
  auto roles = macro->getMacroRoles() & getAttachedMacroRoles();
  if (!roles)
    return { };

  ASTContext &ctx = macro->getASTContext();
  llvm::SmallVector<unsigned, 2> allBufferIDs;

  // Appends every buffer ID returned by a single expansion request.
  auto appendBufferIDs = [&](const auto &bufferIDs) {
    allBufferIDs.append(bufferIDs.begin(), bufferIDs.end());
  };

  if (roles.contains(MacroRole::Accessor)) {
    if (auto storage = dyn_cast<AbstractStorageDecl>(decl))
      appendBufferIDs(evaluateOrDefault(
          ctx.evaluator, ExpandAccessorMacros{storage}, { }));
  }

  if (roles.contains(MacroRole::MemberAttribute)) {
    if (auto idc = dyn_cast<IterableDeclContext>(decl)) {
      for (auto memberDecl : idc->getAllMembers())
        appendBufferIDs(evaluateOrDefault(
            ctx.evaluator, ExpandMemberAttributeMacros{memberDecl}, { }));
    }
  }

  if (roles.contains(MacroRole::Member))
    appendBufferIDs(evaluateOrDefault(
        ctx.evaluator, ExpandSynthesizedMemberMacroRequest{decl}, { }));

  if (roles.contains(MacroRole::Peer))
    appendBufferIDs(evaluateOrDefault(
        ctx.evaluator, ExpandPeerMacroRequest{decl}, { }));

  if (roles.contains(MacroRole::Conformance)) {
    if (auto nominal = dyn_cast<NominalTypeDecl>(decl))
      appendBufferIDs(evaluateOrDefault(
          ctx.evaluator, ExpandConformanceMacros{nominal}, { }));
  }

  // Drop buffers that come from other macros (or that have no generated-source
  // info). More fine-grained requests that expand a single custom attribute
  // would make this step unnecessary.
  SourceManager &sourceMgr = ctx.SourceMgr;
  auto isFromOtherAttr = [&](unsigned bufferID) {
    auto generatedInfo = sourceMgr.getGeneratedSourceInfo(bufferID);
    return !generatedInfo || generatedInfo->attachedMacroCustomAttr != attr;
  };
  allBufferIDs.erase(std::remove_if(allBufferIDs.begin(), allBufferIDs.end(),
                                    isFromOtherAttr),
                     allBufferIDs.end());
  return allBufferIDs;
}
/// Given a resolved cursor, determines whether it references a macro and, if
/// so, returns the IDs of the macro expansion buffers associated with that
/// reference.
///
/// An attached macro (referenced through a custom attribute) is delegated to
/// the attribute-based overload. A freestanding macro is resolved by locating
/// the expansion expression or declaration at the cursor. Returns an empty
/// list when the cursor is not on a macro reference or no buffer is found.
static llvm::SmallVector<unsigned, 2>
getMacroExpansionBuffers(SourceManager &sourceMgr, ResolvedCursorInfoPtr Info) {
  auto *refInfo = dyn_cast<ResolvedValueRefCursorInfo>(Info);
  if (!refInfo || !refInfo->isRef())
    return {};

  auto *macro = dyn_cast_or_null<MacroDecl>(refInfo->getValueD());
  if (!macro)
    return {};

  // Attached macros. `macro` is already the MacroDecl for this reference, so
  // there is no need to re-cast (or shadow it) here.
  if (auto customAttrRef = refInfo->getCustomAttrRef()) {
    return getMacroExpansionBuffers(macro, customAttrRef->first,
                                    customAttrRef->second);
  }

  // Freestanding macros: find the expansion expression or declaration that
  // starts at the cursor, or whose macro name sits at the cursor.
  // FIXME: A resolved cursor should contain a slice up to its reference.
  // We shouldn't need to find it again.
  SourceLoc cursorLoc = Info->getLoc();
  ContextFinder Finder(*Info->getSourceFile(), cursorLoc, [&](ASTNode N) {
    if (auto *expr =
            dyn_cast_or_null<MacroExpansionExpr>(N.dyn_cast<Expr *>())) {
      return expr->getStartLoc() == cursorLoc ||
             expr->getMacroNameLoc().getBaseNameLoc() == cursorLoc;
    }
    if (auto *decl =
            dyn_cast_or_null<MacroExpansionDecl>(N.dyn_cast<Decl *>())) {
      return decl->getStartLoc() == cursorLoc ||
             decl->getMacroNameLoc().getBaseNameLoc() == cursorLoc;
    }
    return false;
  });
  Finder.resolve();

  auto contexts = Finder.getContexts();
  if (contexts.empty())
    return {};

  // Only the first matching context is consulted.
  ASTNode target = contexts[0];
  Optional<unsigned> bufferID;
  if (auto *expr =
          dyn_cast_or_null<MacroExpansionExpr>(target.dyn_cast<Expr *>())) {
    bufferID = getMacroExpansionBuffer(sourceMgr, expr);
  } else if (auto *decl = dyn_cast_or_null<MacroExpansionDecl>(
                 target.dyn_cast<Decl *>())) {
    bufferID = getMacroExpansionBuffer(sourceMgr, decl);
  }

  if (bufferID)
    return {*bufferID};
  return {};
}
bool RefactoringActionExpandMacro::isApplicable(ResolvedCursorInfoPtr Info,
                                                DiagnosticEngine &Diag) {
  // Applicable when the cursor resolves to a macro reference with at least
  // one expansion buffer.
  auto bufferIDs = getMacroExpansionBuffers(Diag.SourceMgr, Info);
  return !bufferIDs.empty();
}
/// Expands the macro referenced at the cursor. For each expansion buffer, an
/// edit is emitted that replaces the buffer's original source range with the
/// buffer's generated text. For an attached macro, the custom attribute is
/// also removed, because it is fully subsumed by its expansions.
///
/// Returns true if no expansion buffer is associated with the cursor, and
/// false otherwise.
bool RefactoringActionExpandMacro::performChange() {
  auto bufferIDs = getMacroExpansionBuffers(SM, CursorInfo);
  if (bufferIDs.empty())
    return true;

  // Send all of the rewritten buffer snippets.
  CustomAttr *attachedMacroAttr = nullptr;
  for (auto bufferID : bufferIDs) {
    auto generatedInfo = SM.getGeneratedSourceInfo(bufferID);
    if (!generatedInfo || generatedInfo->originalSourceRange.isInvalid())
      continue;

    auto originalSourceRange = generatedInfo->originalSourceRange;
    auto rewrittenBuffer = SM.extractText(generatedInfo->generatedSourceRange);

    // If there's no change (empty range replaced by empty text), drop the
    // edit entirely.
    if (originalSourceRange.getStart() == originalSourceRange.getEnd() &&
        rewrittenBuffer.empty())
      continue;

    // `TheFile` is the file of the actual expansion site, whereas
    // `originalFile` is the possibly enclosing buffer. Concretely:
    // ```
    // // m.swift
    // @AddMemberAttributes
    // struct Foo {
    //   // --- expanded from @AddMemberAttributes eg. @_someBufferName ---
    //   @AddedAttribute
    //   // ---
    //   let someMember: Int
    // }
    // ```
    //
    // When expanding `AddedAttribute`, the expansion actually applies to the
    // original source (`m.swift`) rather than the buffer of the expansion
    // site (`@_someBufferName`). Thus, we need to include the path to the
    // original source as well. Note that this path could itself be another
    // expansion.
    SourceFile *originalFile =
        MD->getSourceFileContainingLocation(originalSourceRange.getStart());
    StringRef originalPath;
    // Guard against a null lookup result instead of dereferencing it
    // unconditionally. Without a file, no original path is recorded.
    if (originalFile && originalFile->getBufferID().has_value() &&
        TheFile->getBufferID() != originalFile->getBufferID()) {
      originalPath = SM.getIdentifierForBuffer(*originalFile->getBufferID());
    }

    EditConsumer.accept(SM, {originalPath,
                             originalSourceRange,
                             SM.getIdentifierForBuffer(bufferID),
                             rewrittenBuffer,
                             {}});

    // Remember the first attached-macro attribute encountered.
    if (generatedInfo->attachedMacroCustomAttr && !attachedMacroAttr)
      attachedMacroAttr = generatedInfo->attachedMacroCustomAttr;
  }

  // For an attached macro, remove the custom attribute; it's been fully
  // subsumed by its expansions.
  if (attachedMacroAttr) {
    SourceRange range = attachedMacroAttr->getRangeWithAt();
    auto charRange = Lexer::getCharSourceRangeFromSourceRange(SM, range);
    EditConsumer.accept(SM, charRange, StringRef());
  }

  return false;
}
} // end of anonymous namespace
// Returns the descriptive NAME that RefactoringKinds.def associates with
// \p Kind. RefactoringKind::None has no name and must not be passed here.
StringRef swift::ide::
getDescriptiveRefactoringKindName(RefactoringKind Kind) {
switch(Kind) {
case RefactoringKind::None:
llvm_unreachable("Should be a valid refactoring kind");
// One `case ...: return NAME;` is generated for every REFACTORING entry
// in the .def file.
#define REFACTORING(KIND, NAME, ID) case RefactoringKind::KIND: return NAME;
#include "swift/Refactoring/RefactoringKinds.def"
}
llvm_unreachable("unhandled kind");
}
/// Returns a human-readable reason why a symbol cannot be renamed, or an
/// empty string when \p Kind is RenameAvailableKind::Available.
StringRef
swift::ide::getDescriptiveRenameUnavailableReason(RenameAvailableKind Kind) {
  switch (Kind) {
  case RenameAvailableKind::Available: return "";
  case RenameAvailableKind::Unavailable_system_symbol:
    return "symbol from system module cannot be renamed";
  case RenameAvailableKind::Unavailable_has_no_location:
    return "symbol without a declaration location cannot be renamed";
  case RenameAvailableKind::Unavailable_has_no_name:
    return "cannot find the name of the symbol";
  case RenameAvailableKind::Unavailable_has_no_accessibility:
    return "cannot decide the accessibility of the symbol";
  case RenameAvailableKind::Unavailable_decl_from_clang:
    return "cannot rename a Clang symbol from its Swift reference";
  }
  llvm_unreachable("unhandled kind");
}
/// Returns the source location of (Line, Column) within buffer BufferID.
SourceLoc swift::ide::RangeConfig::getStart(SourceManager &SM) {
  SourceLoc Start = SM.getLocForLineCol(BufferID, Line, Column);
  return Start;
}
/// Returns the start location advanced by Length.
SourceLoc swift::ide::RangeConfig::getEnd(SourceManager &SM) {
  SourceLoc Start = getStart(SM);
  return Start.getAdvancedLoc(Length);
}
/// Backing implementation of FindRenameRangesAnnotatingConsumer. Each rename
/// range is wrapped in a `<tag ...>text</tag>` annotation named after its
/// range kind, and the result is forwarded to a SourceEditOutputConsumer that
/// writes to the stream supplied at construction.
struct swift::ide::FindRenameRangesAnnotatingConsumer::Implementation {
  std::unique_ptr<SourceEditConsumer> pRewriter;

  Implementation(SourceManager &SM, unsigned BufferId, raw_ostream &OS)
      : pRewriter(std::make_unique<SourceEditOutputConsumer>(SM, BufferId,
                                                             OS)) {}

  /// Returns the annotation tag name for ranges of kind \p Kind.
  static StringRef tag(RefactoringRangeKind Kind) {
    switch (Kind) {
    case RefactoringRangeKind::BaseName: return "base";
    case RefactoringRangeKind::KeywordBaseName: return "keywordBase";
    case RefactoringRangeKind::ParameterName: return "param";
    case RefactoringRangeKind::NoncollapsibleParameterName:
      return "noncollapsibleparam";
    case RefactoringRangeKind::DeclArgumentLabel: return "arglabel";
    case RefactoringRangeKind::CallArgumentLabel: return "callarg";
    case RefactoringRangeKind::CallArgumentColon: return "callcolon";
    case RefactoringRangeKind::CallArgumentCombined: return "callcombo";
    case RefactoringRangeKind::SelectorArgumentLabel: return "sel";
    }
    llvm_unreachable("unhandled kind");
  }

  /// Emits one edit replacing \p Range with its annotated text; an
  /// ` index=N` attribute is added when the range carries an index.
  void accept(SourceManager &SM, const RenameRangeDetail &Range) {
    StringRef Tag = tag(Range.RangeKind);
    std::string Annotated;
    llvm::raw_string_ostream Out(Annotated);

    Out << "<" << Tag;
    if (Range.Index.has_value())
      Out << " index=" << *Range.Index;
    Out << ">" << Range.Range.str() << "</" << Tag << ">";

    pRewriter->accept(SM, {/*Path=*/{}, Range.Range, /*BufferName=*/{},
                           Out.str(), /*RegionsWorthNote=*/{}});
  }
};
/// Heap-allocates the Implementation that \c Impl refers to; the destructor
/// releases it.
swift::ide::FindRenameRangesAnnotatingConsumer::FindRenameRangesAnnotatingConsumer(
    SourceManager &SM, unsigned BufferId, raw_ostream &OS)
    : Impl(*new Implementation(SM, BufferId, OS)) {}
/// Releases the Implementation allocated by the constructor.
swift::ide::FindRenameRangesAnnotatingConsumer::
    ~FindRenameRangesAnnotatingConsumer() {
  Implementation *Owned = &Impl;
  delete Owned;
}
/// Annotates every range in \p Ranges. Mismatched and unmatched regions
/// produce no output.
void swift::ide::FindRenameRangesAnnotatingConsumer::accept(
    SourceManager &SM, RegionType RegionType,
    ArrayRef<RenameRangeDetail> Ranges) {
  bool IsIgnored = RegionType == RegionType::Mismatch ||
                   RegionType == RegionType::Unmatched;
  if (IsIgnored)
    return;

  for (const RenameRangeDetail &Detail : Ranges)
    Impl.accept(SM, Detail);
}
/// Decides whether \p VD can be renamed and whether that rename is local or
/// global.
///
/// The returned availability kind records why a rename would be unavailable:
/// a related system declaration, a Clang declaration, no source location, or
/// no name. Returns None for declarations that are never offered for rename:
/// - accessors and deinitializers;
/// - initializers and `callAsFunction` methods that either have no
///   parameters, or whose reference (when \p RefInfo is given and is not on
///   an argument label) resolves to no argument-label ranges.
/// Parameters and symbols the indexer considers local get LocalRename;
/// everything else gets GlobalRename.
Optional<RenameAvailabilityInfo>
swift::ide::renameAvailabilityInfo(const ValueDecl *VD,
                                   Optional<RenameRefInfo> RefInfo) {
  RenameAvailableKind AvailKind = RenameAvailableKind::Available;
  if (getRelatedSystemDecl(VD)) {
    AvailKind = RenameAvailableKind::Unavailable_system_symbol;
  } else if (VD->getClangDecl()) {
    AvailKind = RenameAvailableKind::Unavailable_decl_from_clang;
  } else if (VD->getLoc().isInvalid()) {
    AvailKind = RenameAvailableKind::Unavailable_has_no_location;
  } else if (!VD->hasName()) {
    AvailKind = RenameAvailableKind::Unavailable_has_no_name;
  }

  // Shared check for initializers and `callAsFunction` methods. They are
  // renameable only if they have at least one parameter and, when renaming
  // from a reference that is not itself an argument label, that reference
  // resolves to at least one argument-label range.
  auto hasRenameableLabels = [&](const ParameterList *Params) {
    if (Params->size() == 0)
      return false;
    if (RefInfo && !RefInfo->IsArgLabel) {
      NameMatcher Matcher(*(RefInfo->SF));
      auto Resolved = Matcher.resolve({RefInfo->Loc, /*ResolveArgs*/true});
      if (Resolved.LabelRanges.empty())
        return false;
    }
    return true;
  };

  if (isa<AbstractFunctionDecl>(VD)) {
    // Disallow renaming accessors.
    if (isa<AccessorDecl>(VD))
      return None;
    // Disallow renaming deinit.
    if (isa<DestructorDecl>(VD))
      return None;
    // Disallow renaming init with no arguments.
    if (auto CD = dyn_cast<ConstructorDecl>(VD)) {
      if (!hasRenameableLabels(CD->getParameters()))
        return None;
    }
    // Disallow renaming 'callAsFunction' method with no arguments.
    if (auto FD = dyn_cast<FuncDecl>(VD)) {
      // FIXME: syntactic rename can only decide by checking the spelling, not
      // whether it's an instance method, so we do the same here for now.
      if (FD->getBaseIdentifier() == FD->getASTContext().Id_callAsFunction &&
          !hasRenameableLabels(FD->getParameters()))
        return None;
    }
  }

  // Always return local rename for parameters.
  // FIXME: if the cursor is on the argument, we should return global rename.
  if (isa<ParamDecl>(VD))
    return RenameAvailabilityInfo{RefactoringKind::LocalRename, AvailKind};

  // If the indexer considers VD a local symbol, apply local rename; otherwise
  // apply global rename.
  if (index::isLocalSymbol(VD))
    return RenameAvailabilityInfo{RefactoringKind::LocalRename, AvailKind};
  return RenameAvailabilityInfo{RefactoringKind::GlobalRename, AvailKind};
}
/// Appends to \p Kinds every refactoring that applies at the cursor described
/// by \p CursorInfo.
///
/// Unless \p ExcludeRename is set, the rename kind reported by getRenameInfo
/// is added first, but only when that rename is available. Then every
/// CURSOR_REFACTORING entry in RefactoringKinds.def is tested with its
/// isApplicable. LocalRename is skipped in that pass (presumably because
/// rename is covered by the getRenameInfo check above).
void swift::ide::collectAvailableRefactorings(
ResolvedCursorInfoPtr CursorInfo, SmallVectorImpl<RefactoringKind> &Kinds,
bool ExcludeRename) {
DiagnosticEngine DiagEngine(
CursorInfo->getSourceFile()->getASTContext().SourceMgr);
if (!ExcludeRename) {
if (auto Info = getRenameInfo(CursorInfo)) {
if (Info->Availability.AvailableKind == RenameAvailableKind::Available) {
Kinds.push_back(Info->Availability.Kind);
}
}
}
// One applicability check is generated for every CURSOR_REFACTORING entry in
// the .def file.
#define CURSOR_REFACTORING(KIND, NAME, ID) \
if (RefactoringKind::KIND != RefactoringKind::LocalRename && \
RefactoringAction##KIND::isApplicable(CursorInfo, DiagEngine)) \
Kinds.push_back(RefactoringKind::KIND);
#include "swift/Refactoring/RefactoringKinds.def"
}
/// Appends to \p Kinds every refactoring that applies to the range \p Range
/// in \p SF.
///
/// A zero-length range is treated as a cursor position and handed to
/// collectAvailableRefactoringsAtCursor. Otherwise the range is resolved via
/// RangeInfoRequest, and each RANGE_REFACTORING entry is tested. Entries
/// marked INTERNAL_RANGE_REFACTORING are tested only when the
/// SWIFT_ENABLE_INTERNAL_REFACTORING_ACTIONS environment variable is set.
/// \p CollectRangeStartRefactorings is assigned only on the non-empty-range
/// path. Diagnostics go to \p DiagConsumers.
void swift::ide::collectAvailableRefactorings(
    SourceFile *SF, RangeConfig Range, bool &CollectRangeStartRefactorings,
    SmallVectorImpl<RefactoringKind> &Kinds,
    ArrayRef<DiagnosticConsumer *> DiagConsumers) {
  if (Range.Length == 0) {
    return collectAvailableRefactoringsAtCursor(SF, Range.Line, Range.Column,
                                                Kinds, DiagConsumers);
  }

  // Prepare the tool box.
  ASTContext &Ctx = SF->getASTContext();
  SourceManager &SM = Ctx.SourceMgr;
  DiagnosticEngine DiagEngine(SM);
  for (DiagnosticConsumer *Consumer : DiagConsumers)
    DiagEngine.addConsumer(*Consumer);

  ResolvedRangeInfo Result = evaluateOrDefault(
      Ctx.evaluator,
      RangeInfoRequest(RangeInfoOwner({SF, Range.getStart(SM),
                                       Range.getEnd(SM)})),
      ResolvedRangeInfo());

  bool enableInternalRefactoring =
      getenv("SWIFT_ENABLE_INTERNAL_REFACTORING_ACTIONS");

// One applicability check is generated for every RANGE_REFACTORING entry;
// internal ones are additionally gated on the environment variable above.
#define RANGE_REFACTORING(KIND, NAME, ID) \
  if (RefactoringAction##KIND::isApplicable(Result, DiagEngine)) \
    Kinds.push_back(RefactoringKind::KIND);
#define INTERNAL_RANGE_REFACTORING(KIND, NAME, ID) \
  if (enableInternalRefactoring) \
    RANGE_REFACTORING(KIND, NAME, ID)
#include "swift/Refactoring/RefactoringKinds.def"

  CollectRangeStartRefactorings = collectRangeStartRefactorings(Result);
}
/// Runs the semantic refactoring selected by \p Opts.Kind on module \p M,
/// sending edits to \p EditConsumer and diagnostics to \p DiagConsumer.
/// If \p Opts.PreferredName is empty, the kind's default preferred name is
/// used. Returns true when the action is not applicable; otherwise returns
/// the result of the action's performChange() (which, for the actions in this
/// file, is true on failure).
bool swift::ide::
refactorSwiftModule(ModuleDecl *M, RefactoringOptions Opts,
SourceEditConsumer &EditConsumer,
DiagnosticConsumer &DiagConsumer) {
assert(Opts.Kind != RefactoringKind::None && "should have a refactoring kind.");
// Use the default name if not specified.
if (Opts.PreferredName.empty()) {
Opts.PreferredName = getDefaultPreferredName(Opts.Kind).str();
}
switch (Opts.Kind) {
// One case per SEMANTIC_REFACTORING entry in the .def file. LocalRename
// bypasses the isApplicable() check and always runs performChange().
#define SEMANTIC_REFACTORING(KIND, NAME, ID) \
case RefactoringKind::KIND: { \
RefactoringAction##KIND Action(M, Opts, EditConsumer, DiagConsumer); \
if (RefactoringKind::KIND == RefactoringKind::LocalRename || \
Action.isApplicable()) \
return Action.performChange(); \
return true; \
}
#include "swift/Refactoring/RefactoringKinds.def"
// These kinds are not dispatched through this entry point.
case RefactoringKind::GlobalRename:
case RefactoringKind::FindGlobalRenameRanges:
case RefactoringKind::FindLocalRenameRanges:
llvm_unreachable("not a valid refactoring kind");
case RefactoringKind::None:
llvm_unreachable("should not enter here.");
}
llvm_unreachable("unhandled kind");
}
/// Validates every rename request in \p RenameLocs and resolves its location
/// with a NameMatcher over the tokens of \p SF.
///
/// The old name must parse. When a new name is given, it must parse, its
/// base must be an identifier or operator, each argument label must be empty
/// or an identifier, it must have the same number of parts as the old name,
/// and a call usage must be function-like. The first failure is diagnosed on
/// \p Diags and an empty vector is returned.
static std::vector<ResolvedLoc>
resolveRenameLocations(ArrayRef<RenameLoc> RenameLocs, SourceFile &SF,
                       DiagnosticEngine &Diags) {
  SourceManager &SM = SF.getASTContext().SourceMgr;
  unsigned BufferID = SF.getBufferID().value();

  // True if every label is either empty or a valid identifier.
  auto AllLabelsValid = [](ArrayRef<StringRef> Labels) {
    for (StringRef Label : Labels) {
      if (!Label.empty() && !Lexer::isIdentifier(Label))
        return false;
    }
    return true;
  };

  std::vector<UnresolvedLoc> UnresolvedLocs;
  for (const RenameLoc &Rename : RenameLocs) {
    DeclNameViewer OldName(Rename.OldName);
    SourceLoc Location =
        SM.getLocForLineCol(BufferID, Rename.Line, Rename.Column);

    if (!OldName.isValid()) {
      Diags.diagnose(Location, diag::invalid_name, Rename.OldName);
      return {};
    }

    if (!Rename.NewName.empty()) {
      DeclNameViewer NewName(Rename.NewName);
      StringRef NewBase = NewName.base();
      bool BaseIsValid =
          Lexer::isIdentifier(NewBase) || Lexer::isOperator(NewBase);

      if (!NewName.isValid() || !BaseIsValid ||
          !AllLabelsValid(NewName.args())) {
        Diags.diagnose(Location, diag::invalid_name, Rename.NewName);
        return {};
      }
      if (NewName.partsCount() != OldName.partsCount()) {
        Diags.diagnose(Location, diag::arity_mismatch, Rename.NewName,
                       Rename.OldName);
        return {};
      }
      if (Rename.Usage == NameUsage::Call && !Rename.IsFunctionLike) {
        Diags.diagnose(Location, diag::name_not_functionlike, Rename.NewName);
        return {};
      }
    }

    // Argument labels are resolved for unknown usages and for calls to
    // non-operator names.
    bool OldIsOperator = Lexer::isOperator(OldName.base());
    bool ResolveArgs = Rename.Usage == NameUsage::Unknown ||
                       (Rename.Usage == NameUsage::Call && !OldIsOperator);
    UnresolvedLocs.push_back({Location, ResolveArgs});
  }

  NameMatcher Resolver(SF);
  return Resolver.resolve(UnresolvedLocs, SF.getAllTokens());
}
/// Performs a syntactic rename of each location in \p RenameLocs within
/// \p SF and reports the replacements per region to \p EditConsumer.
/// A mismatched region is diagnosed and reported with no replacements.
/// Returns non-zero if the locations could not all be resolved (already
/// diagnosed), zero otherwise.
int swift::ide::syntacticRename(SourceFile *SF, ArrayRef<RenameLoc> RenameLocs,
                                SourceEditConsumer &EditConsumer,
                                DiagnosticConsumer &DiagConsumer) {
  assert(SF && "null source file");
  SourceManager &SM = SF->getASTContext().SourceMgr;
  DiagnosticEngine DiagEngine(SM);
  DiagEngine.addConsumer(DiagConsumer);

  auto ResolvedLocs = resolveRenameLocations(RenameLocs, *SF, DiagEngine);
  if (ResolvedLocs.size() != RenameLocs.size())
    return true; // Already diagnosed.

  // One string set shared by every renamer below (presumably it owns the
  // replacement text the renamers hand out).
  llvm::StringSet<> ReplaceTextContext;
  for (size_t I = 0, E = RenameLocs.size(); I != E; ++I) {
    const RenameLoc &Rename = RenameLocs[I];
    ResolvedLoc &Resolved = ResolvedLocs[I];

    TextReplacementsRenamer Renamer(SM, Rename.OldName, Rename.NewName,
                                    ReplaceTextContext);
    RegionType Type = Renamer.addSyntacticRenameRanges(Resolved, Rename);
    if (Type != RegionType::Mismatch) {
      EditConsumer.accept(SM, Type, Renamer.getReplacements());
      continue;
    }

    DiagEngine.diagnose(Resolved.Range.getStart(), diag::mismatched_rename,
                        Rename.NewName);
    EditConsumer.accept(SM, Type, None);
  }
  return false;
}
/// Collects the rename ranges for each location in \p RenameLocs within
/// \p SF and reports them per region to \p RenameConsumer.
/// A mismatched region is diagnosed and reported with no ranges.
/// Returns non-zero if the locations could not all be resolved (already
/// diagnosed), zero otherwise.
int swift::ide::findSyntacticRenameRanges(
    SourceFile *SF, ArrayRef<RenameLoc> RenameLocs,
    FindRenameRangesConsumer &RenameConsumer,
    DiagnosticConsumer &DiagConsumer) {
  assert(SF && "null source file");
  SourceManager &SM = SF->getASTContext().SourceMgr;
  DiagnosticEngine DiagEngine(SM);
  DiagEngine.addConsumer(DiagConsumer);

  auto ResolvedLocs = resolveRenameLocations(RenameLocs, *SF, DiagEngine);
  if (ResolvedLocs.size() != RenameLocs.size())
    return true; // Already diagnosed.

  for (size_t I = 0, E = RenameLocs.size(); I != E; ++I) {
    const RenameLoc &Rename = RenameLocs[I];
    ResolvedLoc &Resolved = ResolvedLocs[I];

    RenameRangeDetailCollector Collector(SM, Rename.OldName);
    RegionType Type = Collector.addSyntacticRenameRanges(Resolved, Rename);
    if (Type != RegionType::Mismatch) {
      RenameConsumer.accept(SM, Type, Collector.Ranges);
      continue;
    }

    DiagEngine.diagnose(Resolved.Range.getStart(), diag::mismatched_rename,
                        Rename.NewName);
    RenameConsumer.accept(SM, Type, None);
  }
  return false;
}
/// Finds the rename ranges for the local symbol at the start of \p Range.
/// The start is snapped to the beginning of its token, the local rename
/// locations are collected, and they are passed to findSyntacticRenameRanges.
/// Returns non-zero on failure, zero otherwise.
int swift::ide::findLocalRenameRanges(
    SourceFile *SF, RangeConfig Range,
    FindRenameRangesConsumer &RenameConsumer,
    DiagnosticConsumer &DiagConsumer) {
  assert(SF && "null source file");
  SourceManager &SM = SF->getASTContext().SourceMgr;
  DiagnosticEngine Diags(SM);
  Diags.addConsumer(DiagConsumer);

  SourceLoc TokenStart = Lexer::getLocForStartOfToken(SM, Range.getStart(SM));
  Optional<RenameRangeCollector> Collector =
      localRenames(SF, TokenStart, StringRef(), Diags);
  if (!Collector)
    return true;

  return findSyntacticRenameRanges(SF, Collector->results(), RenameConsumer,
                                   DiagConsumer);
}
Computing file changes ...