Revision ac4b68fbf45853ba4b9e327cb42f93f42a8fa252 authored by Ellie Shin on 17 March 2023, 04:14:20 UTC, committed by Ellie Shin on 17 March 2023, 04:14:20 UTC
1 parent f2c68fb
Raw File
Refactoring.cpp
//===--- Refactoring.cpp ---------------------------------------------------===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//

#include "swift/Refactoring/Refactoring.h"
#include "swift/AST/ASTContext.h"
#include "swift/AST/ASTPrinter.h"
#include "swift/AST/Decl.h"
#include "swift/AST/DiagnosticsRefactoring.h"
#include "swift/AST/Expr.h"
#include "swift/AST/ForeignAsyncConvention.h"
#include "swift/AST/GenericParamList.h"
#include "swift/AST/NameLookup.h"
#include "swift/AST/Pattern.h"
#include "swift/AST/ProtocolConformance.h"
#include "swift/AST/Stmt.h"
#include "swift/AST/Types.h"
#include "swift/AST/USRGeneration.h"
#include "swift/Basic/Edit.h"
#include "swift/Basic/StringExtras.h"
#include "swift/ClangImporter/ClangImporter.h"
#include "swift/Frontend/Frontend.h"
#include "swift/IDE/IDERequests.h"
#include "swift/Index/Index.h"
#include "swift/Parse/Lexer.h"
#include "swift/Sema/IDETypeChecking.h"
#include "swift/Subsystems.h"
#include "clang/Rewrite/Core/RewriteBuffer.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringSet.h"

using namespace swift;
using namespace swift::ide;
using namespace swift::index;

namespace {

/// Source-entity walker that collects every declaration, statement and
/// expression whose source range encloses a target range (or single location)
/// and for which the caller-supplied predicate returns true.
///
/// Call \c resolve() to run the walk over the source file, then read the
/// collected nodes with \c getContexts().
class ContextFinder : public SourceEntityWalker {
  // Note: the declaration order below is also the initialization order used
  // by the constructors (Ctx is derived from SF, SM from Ctx).
  SourceFile &SF;
  ASTContext &Ctx;
  SourceManager &SM;
  SourceRange Target;
  std::function<bool(ASTNode)> IsContext;
  SmallVector<ASTNode, 4> AllContexts;

  /// Records \p Enclosing if it encloses the target and passes the predicate.
  /// \return true if \p Enclosing encloses the target (regardless of the
  /// predicate); this value is what the walk callbacks return.
  bool contains(ASTNode Enclosing) {
    const bool EnclosesTarget = SM.rangeContainsRespectingReplacedRanges(
        Enclosing.getSourceRange(), Target);
    if (!EnclosesTarget)
      return false;

    if (IsContext(Enclosing))
      AllContexts.push_back(Enclosing);
    return true;
  }

public:
  /// Finds contexts enclosing the source range of \p TargetNode. By default
  /// every enclosing node is accepted.
  ContextFinder(SourceFile &SF, ASTNode TargetNode,
                std::function<bool(ASTNode)> IsContext =
                    [](ASTNode) { return true; })
      : SF(SF), Ctx(SF.getASTContext()), SM(Ctx.SourceMgr),
        Target(TargetNode.getSourceRange()), IsContext(IsContext) {}

  /// Finds contexts enclosing the single location \p TargetLoc, which must be
  /// valid. By default every enclosing node is accepted.
  ContextFinder(SourceFile &SF, SourceLoc TargetLoc,
                std::function<bool(ASTNode)> IsContext =
                    [](ASTNode) { return true; })
      : SF(SF), Ctx(SF.getASTContext()), SM(Ctx.SourceMgr), Target(TargetLoc),
        IsContext(IsContext) {
    assert(TargetLoc.isValid() && "Invalid loc to find");
  }

  // Only need expansions for the expands refactoring, but we
  // skip nodes that don't contain the passed location anyway.
  MacroWalking getMacroWalkingBehavior() const override {
    return MacroWalking::ArgumentsAndExpansion;
  }

  bool walkToDeclPre(Decl *D, CharSourceRange Range) override {
    return contains(D);
  }

  bool walkToStmtPre(Stmt *S) override { return contains(S); }

  bool walkToExprPre(Expr *E) override { return contains(E); }

  /// Runs the walk over the source file, populating the context list.
  void resolve() { walk(SF); }

  /// The nodes recorded so far, in the order they were visited.
  ArrayRef<ASTNode> getContexts() const { return AllContexts; }
};

/// Shared driver for syntactic rename.
///
/// Given a resolved occurrence of a name, \c addSyntacticRenameRanges checks
/// the occurrence's base name and argument labels against the old name
/// (\c Old) and reports each matching range to the subclass through
/// \c doRenameBase and \c doRenameLabel. Subclasses decide what to do with
/// the ranges (e.g. collect them, or turn them into text replacements).
class Renamer {
protected:
  const SourceManager &SM;

protected:
  Renamer(const SourceManager &SM, StringRef OldName) : SM(SM), Old(OldName) {}

  // Implementor's interface.

  /// Called for each label-related range of a matching occurrence.
  /// \p NameIndex is the index of the corresponding label in \c Old.args().
  virtual void doRenameLabel(CharSourceRange Label,
                             RefactoringRangeKind RangeKind,
                             unsigned NameIndex) = 0;

  /// Called for a base-name range whose text matches \c Old.base().
  virtual void doRenameBase(CharSourceRange Range,
                            RefactoringRangeKind RangeKind) = 0;

public:
  /// The name being renamed, split into base name and argument labels.
  const DeclNameViewer Old;

public:
  virtual ~Renamer() {}

  /// Adds a replacement to rename the given base name range
  /// \return true if the given range does not match the old name
  bool renameBase(CharSourceRange Range, RefactoringRangeKind RangeKind) {
    assert(Range.isValid());

    if (stripBackticks(Range).str() != Old.base())
      return true;
    doRenameBase(Range, RangeKind);
    return false;
  }

  /// Adds replacements to rename the given label ranges
  /// \return true if the label ranges do not match the old name
  bool renameLabels(ArrayRef<CharSourceRange> LabelRanges,
                    Optional<unsigned> FirstTrailingLabel,
                    LabelRangeType RangeType, bool isCallSite) {
    // Call sites may omit or reorder-skip labels (defaulted/variadic
    // arguments, trailing closures), so they get the lenient matcher.
    if (isCallSite)
      return renameLabelsLenient(LabelRanges, FirstTrailingLabel, RangeType);

    // Everything else must match the old labels one-to-one.
    assert(!FirstTrailingLabel);
    ArrayRef<StringRef> OldLabels = Old.args();

    if (OldLabels.size() != LabelRanges.size())
      return true;

    size_t Index = 0;
    for (const auto &LabelRange : LabelRanges) {
      assert(LabelRange.isValid());
      if (!labelRangeMatches(LabelRange, RangeType, OldLabels[Index]))
        return true;
      splitAndRenameLabel(LabelRange, RangeType, Index++);
    }
    return false;
  }

  /// Whether the old base name is an operator.
  bool isOperator() const { return Lexer::isOperator(Old.base()); }

private:

  /// Returns the range of the (possibly escaped) identifier at the start of
  /// \p Range and updates \p IsEscaped to indicate whether it's escaped or not.
  CharSourceRange getLeadingIdentifierRange(CharSourceRange Range, bool &IsEscaped) {
    assert(Range.isValid() && Range.getByteLength());
    IsEscaped = Range.str().front() == '`';
    SourceLoc Start = Range.getStart();
    if (IsEscaped)
      Start = Start.getAdvancedLoc(1);
    return Lexer::getCharSourceRangeFromSourceRange(SM, Start);
  }

  /// Returns \p Range without its surrounding backticks, or \p Range itself
  /// if it is not of the form `x` (at least one character between backticks).
  CharSourceRange stripBackticks(CharSourceRange Range) {
    StringRef Content = Range.str();
    if (Content.size() < 3 || Content.front() != '`' || Content.back() != '`') {
      return Range;
    }
    return CharSourceRange(Range.getStart().getAdvancedLoc(1),
                           Range.getByteLength() - 2);
  }

  /// Dispatches \p Range to the splitting routine appropriate for
  /// \p RangeType, which in turn reports the pieces via \c doRenameLabel.
  void splitAndRenameLabel(CharSourceRange Range, LabelRangeType RangeType,
                           size_t NameIndex) {
    switch (RangeType) {
    case LabelRangeType::CallArg:
      return splitAndRenameCallArg(Range, NameIndex);
    case LabelRangeType::Param:
      return splitAndRenameParamLabel(Range, NameIndex, /*IsCollapsible=*/true);
    case LabelRangeType::NoncollapsibleParam:
      return splitAndRenameParamLabel(Range, NameIndex, /*IsCollapsible=*/false);
    case LabelRangeType::Selector:
      return doRenameLabel(
          Range, RefactoringRangeKind::SelectorArgumentLabel, NameIndex);
    case LabelRangeType::None:
      llvm_unreachable("expected a label range");
    }
  }

  void splitAndRenameParamLabel(CharSourceRange Range, size_t NameIndex, bool IsCollapsible) {
    // Split parameter range foo([a b]: Int) into decl argument label [a] and
    // parameter name [b] or noncollapsible parameter name [b] if IsCollapsible
    // is false (as for subscript decls). If we have only foo([a]: Int), then we
    // add an empty range for the local name, or for the decl argument label if
    // IsCollapsible is false.
    StringRef Content = Range.str();
    size_t ExternalNameEnd = Content.find_first_of(" \t\n\v\f\r/");

    if (ExternalNameEnd == StringRef::npos) { // foo([a]: Int)
      if (IsCollapsible) {
        doRenameLabel(Range, RefactoringRangeKind::DeclArgumentLabel, NameIndex);
        doRenameLabel(CharSourceRange{Range.getEnd(), 0},
                      RefactoringRangeKind::ParameterName, NameIndex);
      } else {
        doRenameLabel(CharSourceRange{Range.getStart(), 0},
                      RefactoringRangeKind::DeclArgumentLabel, NameIndex);
        doRenameLabel(Range, RefactoringRangeKind::NoncollapsibleParameterName,
                      NameIndex);
      }
    } else { // foo([a b]: Int)
      CharSourceRange Ext{Range.getStart(), unsigned(ExternalNameEnd)};

      // Note: we consider the leading whitespace part of the parameter name
      // if the parameter is collapsible, since if the parameter is collapsed
      // into a matching argument label, we want to remove the whitespace too.
      // FIXME: handle comments foo(a /*...*/b: Int).
      size_t LocalNameStart = Content.find_last_of(" \t\n\v\f\r/");
      assert(LocalNameStart != StringRef::npos);
      if (!IsCollapsible)
        ++LocalNameStart;
      auto LocalLoc = Range.getStart().getAdvancedLocOrInvalid(LocalNameStart);
      CharSourceRange Local{LocalLoc, unsigned(Content.size() - LocalNameStart)};

      doRenameLabel(Ext, RefactoringRangeKind::DeclArgumentLabel, NameIndex);
      if (IsCollapsible) {
        doRenameLabel(Local, RefactoringRangeKind::ParameterName, NameIndex);
      } else {
        doRenameLabel(Local, RefactoringRangeKind::NoncollapsibleParameterName, NameIndex);
      }
    }
  }

  void splitAndRenameCallArg(CharSourceRange Range, size_t NameIndex) {
    // Split call argument foo([a: ]1) into argument name [a] and the remainder
    // [: ].
    StringRef Content = Range.str();
    size_t Colon = Content.find(':'); // FIXME: leading whitespace?
    if (Colon == StringRef::npos) {
      // No colon means there is no label at all: foo([]1).
      assert(Content.empty());
      doRenameLabel(Range, RefactoringRangeKind::CallArgumentCombined,
                    NameIndex);
      return;
    }

    // Include any whitespace before the ':'.
    assert(Colon == Content.substr(0, Colon).size());
    Colon = Content.substr(0, Colon).rtrim().size();

    CharSourceRange Arg{Range.getStart(), unsigned(Colon)};
    doRenameLabel(Arg, RefactoringRangeKind::CallArgumentLabel, NameIndex);

    auto ColonLoc = Range.getStart().getAdvancedLocOrInvalid(Colon);
    assert(ColonLoc.isValid());
    CharSourceRange Rest{ColonLoc, unsigned(Content.size() - Colon)};
    doRenameLabel(Rest, RefactoringRangeKind::CallArgumentColon, NameIndex);
  }

  /// Whether the label text in \p Range corresponds to the old label
  /// \p Expected. An empty \p Expected stands for "no label": it matches an
  /// empty range or an explicit `_`.
  bool labelRangeMatches(CharSourceRange Range, LabelRangeType RangeType, StringRef Expected) {
    if (Range.getByteLength()) {
      bool IsEscaped = false;
      CharSourceRange ExistingLabelRange = getLeadingIdentifierRange(Range, IsEscaped);
      StringRef ExistingLabel = ExistingLabelRange.str();
      // True when the range holds just one identifier (optionally wrapped in
      // backticks) rather than an "external internal" pair.
      bool IsSingleName = Range == ExistingLabelRange ||
        (IsEscaped && Range.getByteLength() == ExistingLabel.size() + 2);

      switch (RangeType) {
      case LabelRangeType::NoncollapsibleParam:
        if (IsSingleName && Expected.empty()) // subscript([x]: Int)
          return true;
        LLVM_FALLTHROUGH;
      case LabelRangeType::CallArg:
      case LabelRangeType::Param:
      case LabelRangeType::Selector:
        return ExistingLabel == (Expected.empty() ? "_" : Expected);
      case LabelRangeType::None:
        llvm_unreachable("Unhandled label range type");
      }
    }
    return Expected.empty();
  }

  /// Matches call-site labels against the old labels, tolerating old labels
  /// that have no counterpart at the call site and extra unlabelled
  /// arguments after the first position (to allow for variadics).
  /// \return true if the labels cannot be matched to the old name.
  bool renameLabelsLenient(ArrayRef<CharSourceRange> LabelRanges,
                           Optional<unsigned> FirstTrailingLabel,
                           LabelRangeType RangeType) {

    ArrayRef<StringRef> OldNames = Old.args();

    // First, match trailing closure arguments in reverse
    if (FirstTrailingLabel) {
      auto TrailingLabels = LabelRanges.drop_front(*FirstTrailingLabel);
      LabelRanges = LabelRanges.take_front(*FirstTrailingLabel);

      for (auto LabelIndex: llvm::reverse(indices(TrailingLabels))) {
        CharSourceRange Label = TrailingLabels[LabelIndex];

        if (Label.getByteLength()) {
          if (OldNames.empty())
            return true;

          while (!labelRangeMatches(Label, LabelRangeType::Selector,
                                    OldNames.back())) {
            if ((OldNames = OldNames.drop_back()).empty())
              return true;
          }
          splitAndRenameLabel(Label, LabelRangeType::Selector,
                              OldNames.size() - 1);
          OldNames = OldNames.drop_back();
          continue;
        }

        // empty labelled trailing closure label
        if (LabelIndex) {
          if (OldNames.empty())
            return true;

          while (!OldNames.back().empty()) {
            if ((OldNames = OldNames.drop_back()).empty())
              return true;
          }
          splitAndRenameLabel(Label, LabelRangeType::Selector,
                              OldNames.size() - 1);
          OldNames = OldNames.drop_back();
          continue;
        }

        // unlabelled trailing closure label
        // There must still be an old label for this closure to consume;
        // dropping from an empty ArrayRef is not allowed, so treat a call
        // with more trailing closures than remaining labels as a mismatch.
        if (OldNames.empty())
          return true;
        OldNames = OldNames.drop_back();
        continue;
      }
    }

    // Next, match the non-trailing arguments.
    size_t NameIndex = 0;

    for (CharSourceRange Label : LabelRanges) {
      // empty label
      if (!Label.getByteLength()) {

        // first name pos
        if (!NameIndex) {
          // With no old labels left (the old name takes no arguments, or
          // the trailing closures consumed them all) there is nothing this
          // argument could correspond to; indexing OldNames would be out of
          // bounds, so report a mismatch instead.
          if (OldNames.empty())
            return true;

          while (!OldNames[NameIndex].empty()) {
            if (++NameIndex >= OldNames.size())
              return true;
          }
          splitAndRenameLabel(Label, RangeType, NameIndex++);
          continue;
        }

        // other name pos
        if (NameIndex >= OldNames.size() || !OldNames[NameIndex].empty()) {
          // FIXME: only allow one variadic param
          continue; // allow for variadic
        }
        splitAndRenameLabel(Label, RangeType, NameIndex++);
        continue;
      }

      // non-empty label
      if (NameIndex >= OldNames.size())
        return true;

      // Skip old labels that were omitted at the call site until one matches.
      while (!labelRangeMatches(Label, RangeType, OldNames[NameIndex])) {
        if (++NameIndex >= OldNames.size())
          return true;
      }
      splitAndRenameLabel(Label, RangeType, NameIndex++);
    }
    return false;
  }

  /// Classifies where a resolved location sits: comment (no AST node),
  /// string literal, selector, active code or inactive code, checked in
  /// that order.
  static RegionType getSyntacticRenameRegionType(const ResolvedLoc &Resolved) {
    if (Resolved.Node.isNull())
      return RegionType::Comment;

    if (Expr *E = Resolved.Node.getAsExpr()) {
      if (isa<StringLiteralExpr>(E))
        return RegionType::String;
    }
    if (Resolved.IsInSelector)
      return RegionType::Selector;
    if (Resolved.IsActive)
      return RegionType::ActiveCode;
    return RegionType::InactiveCode;
  }

public:
  /// Checks the resolved occurrence \p Resolved against the old name and
  /// reports its base-name and label ranges to the subclass.
  /// \return the kind of region the occurrence is in if it matched,
  /// \c RegionType::Mismatch if it conflicts with the old name, or
  /// \c RegionType::Unmatched if it should be ignored.
  RegionType addSyntacticRenameRanges(const ResolvedLoc &Resolved,
                                      const RenameLoc &Config) {

    if (!Resolved.Range.isValid())
      return RegionType::Unmatched;

    auto RegionKind = getSyntacticRenameRegionType(Resolved);
    // Don't include unknown references coming from active code; if we don't
    // have a semantic NameUsage for them, then they're likely unrelated symbols
    // that happen to have the same name.
    if (RegionKind == RegionType::ActiveCode &&
        Config.Usage == NameUsage::Unknown)
      return RegionType::Unmatched;

    assert(Config.Usage != NameUsage::Call || Config.IsFunctionLike);

    // FIXME: handle escaped keyword names `init`
    bool IsSubscript = Old.base() == "subscript" && Config.IsFunctionLike;
    bool IsInit = Old.base() == "init" && Config.IsFunctionLike;

    // FIXME: this should only be treated specially for instance methods.
    bool IsCallAsFunction = Old.base() == "callAsFunction" &&
        Config.IsFunctionLike;

    bool IsSpecialBase = IsInit || IsSubscript || IsCallAsFunction;

    // Filter out non-semantic special basename locations with no labels.
    // We've already filtered out those in active code, so these are
    // any appearance of just 'init', 'subscript', or 'callAsFunction' in
    // strings, comments, and inactive code.
    if (IsSpecialBase && (Config.Usage == NameUsage::Unknown &&
                          Resolved.LabelType == LabelRangeType::None))
      return RegionType::Unmatched;

    if (!Config.IsFunctionLike || !IsSpecialBase) {
      if (renameBase(Resolved.Range, RefactoringRangeKind::BaseName))
        return RegionType::Mismatch;

    } else if (IsInit || IsCallAsFunction) {
      if (renameBase(Resolved.Range, RefactoringRangeKind::KeywordBaseName)) {
        // The base name doesn't need to match (but may) for calls, but
        // it should for definitions and references.
        if (Config.Usage == NameUsage::Definition ||
            Config.Usage == NameUsage::Reference) {
          return RegionType::Mismatch;
        }
      }
    } else if (IsSubscript && Config.Usage == NameUsage::Definition) {
      if (renameBase(Resolved.Range, RefactoringRangeKind::KeywordBaseName))
        return RegionType::Mismatch;
    }

    // Decide whether this occurrence carries labels that need renaming.
    bool HandleLabels = false;
    if (Config.IsFunctionLike) {
      switch (Config.Usage) {
      case NameUsage::Call:
        HandleLabels = !isOperator();
        break;
      case NameUsage::Definition:
        HandleLabels = true;
        break;
      case NameUsage::Reference:
        HandleLabels = Resolved.LabelType == LabelRangeType::Selector || IsSubscript;
        break;
      case NameUsage::Unknown:
        HandleLabels = Resolved.LabelType != LabelRangeType::None;
        break;
      }
    }

    if (HandleLabels) {
      bool isCallSite = Config.Usage != NameUsage::Definition &&
                        (Config.Usage != NameUsage::Reference || IsSubscript) &&
                        Resolved.LabelType == LabelRangeType::CallArg;

      if (renameLabels(Resolved.LabelRanges, Resolved.FirstTrailingLabel,
                       Resolved.LabelType, isCallSite))
        return Config.Usage == NameUsage::Unknown ?
            RegionType::Unmatched : RegionType::Mismatch;
    }

    return RegionKind;
  }
};

class RenameRangeDetailCollector : public Renamer {
  void doRenameLabel(CharSourceRange Label, RefactoringRangeKind RangeKind,
                     unsigned NameIndex) override {
    Ranges.push_back({Label, RangeKind, NameIndex});
  }
  void doRenameBase(CharSourceRange Range,
                    RefactoringRangeKind RangeKind) override {
    Ranges.push_back({Range, RangeKind, None});
  }

public:
  RenameRangeDetailCollector(const SourceManager &SM, StringRef OldName)
      : Renamer(SM, OldName) {}
  std::vector<RenameRangeDetail> Ranges;
};

class TextReplacementsRenamer : public Renamer {
  llvm::StringSet<> &ReplaceTextContext;
  SmallVector<Replacement> Replacements;

public:
  const DeclNameViewer New;

private:
  StringRef registerText(StringRef Text) {
    if (Text.empty())
      return Text;
    return ReplaceTextContext.insert(Text).first->getKey();
  }

  StringRef getCallArgLabelReplacement(StringRef OldLabelRange,
                                       StringRef NewLabel) {
    return NewLabel.empty() ? "" : NewLabel;
  }

  StringRef getCallArgColonReplacement(StringRef OldLabelRange,
                                       StringRef NewLabel) {
    // Expected OldLabelRange: foo( []3, a[: ]2,  b[ : ]3 ...)
    // FIXME: Preserve comments: foo([a/*:*/ : /*:*/ ]2, ...)
    if (NewLabel.empty())
      return "";
    if (OldLabelRange.empty())
      return ": ";
    return registerText(OldLabelRange);
  }

  StringRef getCallArgCombinedReplacement(StringRef OldArgLabel,
                                          StringRef NewArgLabel) {
    // This case only happens when going from foo([]1) to foo([a: ]1).
    assert(OldArgLabel.empty());
    if (NewArgLabel.empty())
      return "";
    return registerText((Twine(NewArgLabel) + ": ").str());
  }

  StringRef getParamNameReplacement(StringRef OldParam, StringRef OldArgLabel,
                                    StringRef NewArgLabel) {
    // We don't want to get foo(a a: Int), so drop the parameter name if the
    // argument label will match the original name.
    // Note: the leading whitespace is part of the parameter range.
    if (!NewArgLabel.empty() && OldParam.ltrim() == NewArgLabel)
      return "";

    // If we're renaming foo(x: Int) to foo(_:), then use the original argument
    // label as the parameter name so as to not break references in the body.
    if (NewArgLabel.empty() && !OldArgLabel.empty() && OldParam.empty())
      return registerText((Twine(" ") + OldArgLabel).str());

    return registerText(OldParam);
  }

  StringRef getDeclArgumentLabelReplacement(StringRef OldLabelRange,
                                            StringRef NewArgLabel) {
      // OldLabelRange is subscript([]a: Int), foo([a]: Int) or foo([a] b: Int)
      if (NewArgLabel.empty())
        return OldLabelRange.empty() ? "" : "_";

      if (OldLabelRange.empty())
        return registerText((Twine(NewArgLabel) + " ").str());
      return registerText(NewArgLabel);
  }

  StringRef getReplacementText(StringRef LabelRange,
                               RefactoringRangeKind RangeKind,
                               StringRef OldLabel, StringRef NewLabel) {
    switch (RangeKind) {
    case RefactoringRangeKind::CallArgumentLabel:
      return getCallArgLabelReplacement(LabelRange, NewLabel);
    case RefactoringRangeKind::CallArgumentColon:
      return getCallArgColonReplacement(LabelRange, NewLabel);
    case RefactoringRangeKind::CallArgumentCombined:
      return getCallArgCombinedReplacement(LabelRange, NewLabel);
    case RefactoringRangeKind::ParameterName:
      return getParamNameReplacement(LabelRange, OldLabel, NewLabel);
    case RefactoringRangeKind::NoncollapsibleParameterName:
      return LabelRange;
    case RefactoringRangeKind::DeclArgumentLabel:
      return getDeclArgumentLabelReplacement(LabelRange, NewLabel);
    case RefactoringRangeKind::SelectorArgumentLabel:
      return NewLabel.empty() ? "_" : registerText(NewLabel);
    default:
      llvm_unreachable("label range type is none but there are labels");
    }
  }

  void addReplacement(CharSourceRange LabelRange,
                      RefactoringRangeKind RangeKind, StringRef OldLabel,
                      StringRef NewLabel) {
    StringRef ExistingLabel = LabelRange.str();
    StringRef Text =
        getReplacementText(ExistingLabel, RangeKind, OldLabel, NewLabel);
    if (Text != ExistingLabel)
      Replacements.push_back({/*Path=*/{}, LabelRange, /*BufferName=*/{}, Text,
                              /*RegionsWorthNote=*/{}});
  }

  void doRenameLabel(CharSourceRange Label, RefactoringRangeKind RangeKind,
                     unsigned NameIndex) override {
    addReplacement(Label, RangeKind, Old.args()[NameIndex],
                   New.args()[NameIndex]);
  }

  void doRenameBase(CharSourceRange Range, RefactoringRangeKind) override {
    if (Old.base() != New.base())
      Replacements.push_back({/*Path=*/{}, Range, /*BufferName=*/{},
                              registerText(New.base()),
                              /*RegionsWorthNote=*/{}});
  }

public:
  TextReplacementsRenamer(const SourceManager &SM, StringRef OldName,
                          StringRef NewName,
                          llvm::StringSet<> &ReplaceTextContext)
      : Renamer(SM, OldName), ReplaceTextContext(ReplaceTextContext),
        New(NewName) {
    assert(Old.isValid() && New.isValid());
    assert(Old.partsCount() == New.partsCount());
  }

  ArrayRef<Replacement> getReplacements() const { return Replacements; }
};

/// Returns the first declaration related to \p VD that lives in a system
/// module, checking in order: \p VD itself, the protocol requirements it
/// satisfies, and the chain of declarations it overrides. Returns null if
/// none of them is in a system module.
static const ValueDecl *getRelatedSystemDecl(const ValueDecl *VD) {
  auto IsInSystemModule = [](const ValueDecl *D) {
    return D->getModuleContext()->isSystemModule();
  };

  if (IsInSystemModule(VD))
    return VD;

  for (auto *Requirement : VD->getSatisfiedProtocolRequirements())
    if (IsInSystemModule(Requirement))
      return Requirement;

  // Walk up the override chain.
  auto Overridden = VD->getOverriddenDecl();
  while (Overridden) {
    if (IsInSystemModule(Overridden))
      return Overridden;
    Overridden = Overridden->getOverriddenDecl();
  }

  return nullptr;
}

/// Result of resolving a cursor for rename: the declaration to rename paired
/// with the rename availability computed for it.
///
/// This is initialized positionally (declaration first, availability second),
/// so the field order must be kept.
struct RenameInfo {
  /// The declaration that would be renamed.
  ValueDecl *VD;
  /// Availability information for renaming \c VD.
  RenameAvailabilityInfo Availability;
};

/// Resolve a cursor to the declaration that a rename would act on, along with
/// that declaration's rename availability. Yields \c None when the cursor
/// does not refer to a value declaration, or when no rename availability is
/// reported for the declaration (i.e. renaming it is not allowed).
static Optional<RenameInfo> getRenameInfo(ResolvedCursorInfoPtr cursorInfo) {
  auto valueCursor = dyn_cast<ResolvedValueRefCursorInfo>(cursorInfo);
  if (!valueCursor)
    return None;

  ValueDecl *VD = valueCursor->typeOrValue();
  if (!VD)
    return None;

  Optional<RenameRefInfo> refInfo;
  const auto &shadowedDecls = valueCursor->getShorthandShadowedDecls();
  if (shadowedDecls.empty()) {
    // Reference information is only recorded for references.
    if (valueCursor->isRef()) {
      refInfo = {valueCursor->getSourceFile(), valueCursor->getLoc(),
                 valueCursor->isKeywordArgument()};
    }
  } else {
    // Find the outermost decl for a shorthand if let/closure capture
    VD = shadowedDecls.back();
  }

  if (Optional<RenameAvailabilityInfo> availability =
          renameAvailabilityInfo(VD, refInfo))
    return RenameInfo{VD, *availability};

  return None;
}

class RenameRangeCollector : public IndexDataConsumer {
public:
  RenameRangeCollector(StringRef USR, StringRef newName)
      : USR(USR), newName(newName) {}

  RenameRangeCollector(const ValueDecl *D, StringRef newName)
      : newName(newName) {
    SmallString<64> SS;
    llvm::raw_svector_ostream OS(SS);
    printValueDeclUSR(D, OS);
    USR = stringStorage.copyString(SS.str());
  }

  RenameRangeCollector(RenameRangeCollector &&collector) = default;

  ArrayRef<RenameLoc> results() const { return locations; }

private:
  bool indexLocals() override { return true; }
  void failed(StringRef error) override {}
  bool startDependency(StringRef name, StringRef path, bool isClangModule, bool isSystem) override {
    return true;
  }
  bool finishDependency(bool isClangModule) override { return true; }

  Action startSourceEntity(const IndexSymbol &symbol) override {
    if (symbol.USR == USR) {
      if (auto loc = indexSymbolToRenameLoc(symbol, newName)) {
        // Inside capture lists like `{ [test] in }`, 'test' refers to both the
        // newly declared, captured variable and the referenced variable it is
        // initialized from. Make sure to only rename it once.
        auto existingLoc = llvm::find_if(locations, [&](RenameLoc searchLoc) {
          return searchLoc.Line == loc->Line && searchLoc.Column == loc->Column;
        });
        if (existingLoc == locations.end()) {
          locations.push_back(std::move(*loc));
        } else {
          assert(existingLoc->OldName == loc->OldName &&
                 existingLoc->NewName == loc->NewName &&
                 existingLoc->IsFunctionLike == loc->IsFunctionLike &&
                 existingLoc->IsNonProtocolType == loc->IsNonProtocolType &&
                 "Asked to do a different rename for the same location?");
        }
      }
    }
    return IndexDataConsumer::Continue;
  }

  bool finishSourceEntity(SymbolInfo symInfo, SymbolRoleSet roles) override {
    return true;
  }

  Optional<RenameLoc> indexSymbolToRenameLoc(const index::IndexSymbol &symbol,
                                             StringRef NewName);

private:
  StringRef USR;
  StringRef newName;
  StringScratchSpace stringStorage;
  std::vector<RenameLoc> locations;
};

/// Builds the \c RenameLoc for an indexed occurrence of the renamed symbol.
///
/// Implicit occurrences yield \c None. Otherwise the name usage is taken from
/// the occurrence's roles (call takes precedence over definition, which takes
/// precedence over reference) and the function-like / non-protocol-type flags
/// from the symbol kind. The old name is copied into the collector's string
/// storage.
Optional<RenameLoc>
RenameRangeCollector::indexSymbolToRenameLoc(const index::IndexSymbol &symbol,
                                             StringRef newName) {
  auto hasRole = [&symbol](index::SymbolRole role) -> bool {
    return (symbol.roles & static_cast<unsigned>(role)) != 0;
  };

  // Implicit occurrences have no spelled name to rename.
  if (hasRole(index::SymbolRole::Implicit))
    return None;

  NameUsage usage = NameUsage::Unknown;
  if (hasRole(index::SymbolRole::Call))
    usage = NameUsage::Call;
  else if (hasRole(index::SymbolRole::Definition))
    usage = NameUsage::Definition;
  else if (hasRole(index::SymbolRole::Reference))
    usage = NameUsage::Reference;
  else
    llvm_unreachable("unexpected role");

  bool isFunctionLike = false;
  bool isNonProtocolType = false;

  switch (symbol.symInfo.Kind) {
  // Kinds whose names can carry argument labels.
  case index::SymbolKind::EnumConstant:
  case index::SymbolKind::Function:
  case index::SymbolKind::Constructor:
  case index::SymbolKind::ConversionFunction:
  case index::SymbolKind::InstanceMethod:
  case index::SymbolKind::ClassMethod:
  case index::SymbolKind::StaticMethod:
    isFunctionLike = true;
    break;

  // Nominal types other than protocols.
  case index::SymbolKind::Class:
  case index::SymbolKind::Enum:
  case index::SymbolKind::Struct:
    isNonProtocolType = true;
    break;

  default:
    break;
  }

  StringRef oldName = stringStorage.copyString(symbol.name);
  return RenameLoc{symbol.line, symbol.column,  usage,
                   oldName,     newName,        isFunctionLike,
                   isNonProtocolType};
}

/// Look up the source file of module \p M that contains the start of the
/// buffer named by \p Range.BufferID.
SourceFile *getContainingFile(ModuleDecl *M, RangeConfig Range) {
  SourceManager &SM = M->getASTContext().SourceMgr;
  // TODO: We should add an ID -> SourceFile mapping.
  const SourceLoc BufferStart =
      SM.getRangeForBuffer(Range.BufferID).getStart();
  return M->getSourceFileContainingLocation(BufferStart);
}

/// Common base of all refactoring actions. It holds the state every action
/// needs (module, source file, edit consumer, diagnostics, start location and
/// preferred name); subclasses implement \c performChange().
///
/// The members are initialized in declaration order by the constructor, and
/// some depend on earlier ones (e.g. \c DiagEngine is built from \c SM), so
/// the order of the declarations below must be preserved.
class RefactoringAction {
protected:
  /// The module the refactoring operates on.
  ModuleDecl *MD;
  /// The source file of \c MD that contains the requested range's buffer.
  SourceFile *TheFile;
  /// Receiver for the edits produced by the refactoring.
  SourceEditConsumer &EditConsumer;
  /// AST context of \c MD.
  ASTContext &Ctx;
  /// Source manager of \c MD's AST context.
  SourceManager &SM;
  /// Diagnostic engine owned by this action; the consumer passed to the
  /// constructor is attached to it.
  DiagnosticEngine DiagEngine;
  /// Start of the token at the start of the requested range.
  SourceLoc StartLoc;
  /// The preferred name taken from the refactoring options.
  StringRef PreferredName;
public:
  RefactoringAction(ModuleDecl *MD, RefactoringOptions &Opts,
                    SourceEditConsumer &EditConsumer,
                    DiagnosticConsumer &DiagConsumer);
  virtual ~RefactoringAction() = default;
  /// Performs the refactoring.
  // NOTE(review): the meaning of the boolean result is not visible here;
  // presumably true signals failure, as elsewhere in this file — confirm.
  virtual bool performChange() = 0;
};

/// Sets up the shared refactoring state from \p MD and \p Opts and attaches
/// \p DiagConsumer to this action's diagnostic engine.
RefactoringAction::RefactoringAction(ModuleDecl *MD, RefactoringOptions &Opts,
                                     SourceEditConsumer &EditConsumer,
                                     DiagnosticConsumer &DiagConsumer)
    : MD(MD),
      TheFile(getContainingFile(MD, Opts.Range)),
      EditConsumer(EditConsumer),
      Ctx(MD->getASTContext()),
      // Ctx is declared (and therefore initialized) before SM.
      SM(Ctx.SourceMgr),
      DiagEngine(SM),
      // Snap the requested start position to the start of its token.
      StartLoc(Lexer::getLocForStartOfToken(SM, Opts.Range.getStart(SM))),
      PreferredName(Opts.PreferredName) {
  DiagEngine.addConsumer(DiagConsumer);
}

/// Different from RangeBasedRefactoringAction, TokenBasedRefactoringAction takes
/// the input of a given token, e.g., a name or an "if" key word. Contextual
/// refactoring kinds can suggest applicable refactorings on that token, e.g.
/// rename or reverse if statement.
class TokenBasedRefactoringAction : public RefactoringAction {
protected:
  ResolvedCursorInfoPtr CursorInfo;

public:
  TokenBasedRefactoringAction(ModuleDecl *MD, RefactoringOptions &Opts,
                              SourceEditConsumer &EditConsumer,
                              DiagnosticConsumer &DiagConsumer) :
  RefactoringAction(MD, Opts, EditConsumer, DiagConsumer) {
  // Resolve the sema token and save it for later use.
  CursorInfo =
      evaluateOrDefault(TheFile->getASTContext().evaluator,
                        CursorInfoRequest{CursorInfoOwner(TheFile, StartLoc)},
                        new ResolvedCursorInfo());
  }
};

/// Declares the class RefactoringAction<KIND> for a cursor (token) based
/// refactoring. The class derives from TokenBasedRefactoringAction and
/// declares:
///   - performChange(), defined out of line for each kind;
///   - a static isApplicable(Info, Diag), defined out of line for each kind;
///   - an instance isApplicable() that forwards this action's own cursor info
///     and diagnostic engine to the static one.
/// NAME and ID are not used by this expansion.
#define CURSOR_REFACTORING(KIND, NAME, ID)                                     \
  class RefactoringAction##KIND : public TokenBasedRefactoringAction {         \
  public:                                                                      \
    RefactoringAction##KIND(ModuleDecl *MD, RefactoringOptions &Opts,          \
                            SourceEditConsumer &EditConsumer,                  \
                            DiagnosticConsumer &DiagConsumer)                  \
        : TokenBasedRefactoringAction(MD, Opts, EditConsumer, DiagConsumer) {} \
    bool performChange() override;                                             \
    static bool isApplicable(ResolvedCursorInfoPtr Info,                       \
                             DiagnosticEngine &Diag);                          \
    bool isApplicable() {                                                      \
      return RefactoringAction##KIND::isApplicable(CursorInfo, DiagEngine);    \
    }                                                                          \
  };
// Presumably expands CURSOR_REFACTORING once per cursor-based refactoring
// kind listed in the .def file — confirm against RefactoringKinds.def.
#include "swift/Refactoring/RefactoringKinds.def"

/// Base class of refactorings that act on a source range. The range info for
/// the requested range is resolved once at construction time and kept in
/// \c RangeInfo for the subclass to use.
class RangeBasedRefactoringAction : public RefactoringAction {
protected:
  /// Range info resolved for the requested range; a default-constructed
  /// value if the request could not be evaluated.
  ResolvedRangeInfo RangeInfo;

public:
  RangeBasedRefactoringAction(ModuleDecl *MD, RefactoringOptions &Opts,
                              SourceEditConsumer &EditConsumer,
                              DiagnosticConsumer &DiagConsumer)
      : RefactoringAction(MD, Opts, EditConsumer, DiagConsumer),
        RangeInfo(evaluateOrDefault(
            MD->getASTContext().evaluator,
            RangeInfoRequest(RangeInfoOwner(TheFile,
                                            Opts.Range.getStart(SM),
                                            Opts.Range.getEnd(SM))),
            ResolvedRangeInfo())) {}
};

// Declares the action class `RefactoringAction<KIND>` for one range-based
// refactoring. The generated class derives from RangeBasedRefactoringAction
// and only *declares* performChange() and the static isApplicable(); their
// bodies are written out per kind in this file (e.g.
// RefactoringActionExtractFunction). The instance isApplicable() simply
// forwards the RangeInfo and DiagEngine members to the static overload.
// NOTE(review): RefactoringKinds.def presumably invokes this macro once per
// range-based refactoring kind -- confirm against that file.
#define RANGE_REFACTORING(KIND, NAME, ID)                                     \
class RefactoringAction##KIND: public RangeBasedRefactoringAction {           \
  public:                                                                     \
  RefactoringAction##KIND(ModuleDecl *MD, RefactoringOptions &Opts,           \
                          SourceEditConsumer &EditConsumer,                   \
                          DiagnosticConsumer &DiagConsumer) :                 \
    RangeBasedRefactoringAction(MD, Opts, EditConsumer, DiagConsumer) {}      \
  bool performChange() override;                                              \
  static bool isApplicable(const ResolvedRangeInfo &Info,                     \
                           DiagnosticEngine &Diag);                           \
  bool isApplicable() {                                                       \
    return RefactoringAction##KIND::isApplicable(RangeInfo, DiagEngine) ;     \
  }                                                                           \
};
#include "swift/Refactoring/RefactoringKinds.def"

/// Local rename applies when rename info can be computed for the cursor, the
/// rename is available, and the computed refactoring kind is a local rename.
bool RefactoringActionLocalRename::isApplicable(
    ResolvedCursorInfoPtr CursorInfo, DiagnosticEngine &Diag) {
  Optional<RenameInfo> Renamable = getRenameInfo(CursorInfo);
  if (!Renamable)
    return false;

  if (Renamable->Availability.AvailableKind != RenameAvailableKind::Available)
    return false;

  return Renamable->Availability.Kind == RefactoringKind::LocalRename;
}

/// Work out which decl context has to be searched when renaming \p VD and
/// append it to \p Scopes.
///
/// The starting point is the decl context of \p VD. For context kinds where
/// sibling decls may also see the renamed symbol, the parent context is used
/// instead.
static void analyzeRenameScope(ValueDecl *VD,
                               SmallVectorImpl<DeclContext *> &Scopes) {
  DeclContext *SearchScope = VD->getDeclContext();

  switch (SearchScope->getContextKind()) {
  // These contexts are searched as they are.
  case DeclContextKind::AbstractClosureExpr:
  case DeclContextKind::Initializer:
  case DeclContextKind::SerializedLocal:
  case DeclContextKind::Package:
  case DeclContextKind::Module:
  case DeclContextKind::FileUnit:
  case DeclContextKind::MacroDecl:
    break;

  // There may be sibling decls that the renamed symbol is visible from, so
  // widen the search to the parent context.
  case DeclContextKind::GenericTypeDecl:
  case DeclContextKind::ExtensionDecl:
  case DeclContextKind::TopLevelCodeDecl:
  case DeclContextKind::SubscriptDecl:
  case DeclContextKind::EnumElementDecl:
  case DeclContextKind::AbstractFunctionDecl:
    SearchScope = SearchScope->getParent();
    break;
  }

  Scopes.push_back(SearchScope);
}

/// Collect the rename ranges for the symbol found at \p startLoc in \p SF.
///
/// \returns None (after emitting a diagnostic on \p diags where one exists
/// for the failure) if the location cannot be resolved to a renamable decl or
/// the rename is unavailable; otherwise a collector populated by indexing the
/// decl contexts reported by analyzeRenameScope.
static Optional<RenameRangeCollector> localRenames(SourceFile *SF,
                                                   SourceLoc startLoc,
                                                   StringRef preferredName,
                                                   DiagnosticEngine &diags) {
  auto resolvedCursor =
      evaluateOrDefault(SF->getASTContext().evaluator,
                        CursorInfoRequest{CursorInfoOwner(SF, startLoc)},
                        new ResolvedCursorInfo());

  Optional<RenameInfo> renameInfo = getRenameInfo(resolvedCursor);
  if (!renameInfo) {
    diags.diagnose(startLoc, diag::unresolved_location);
    return None;
  }

  // Every unavailable kind maps to its own diagnostic; only `Available` lets
  // us continue.
  switch (renameInfo->Availability.AvailableKind) {
  case RenameAvailableKind::Unavailable_system_symbol:
    diags.diagnose(startLoc, diag::decl_is_system_symbol,
                   renameInfo->VD->getName());
    return None;
  case RenameAvailableKind::Unavailable_has_no_location:
    diags.diagnose(startLoc, diag::value_decl_no_loc,
                   renameInfo->VD->getName());
    return None;
  case RenameAvailableKind::Unavailable_has_no_name:
    diags.diagnose(startLoc, diag::decl_has_no_name);
    return None;
  case RenameAvailableKind::Unavailable_has_no_accessibility:
    diags.diagnose(startLoc, diag::decl_no_accessibility);
    return None;
  case RenameAvailableKind::Unavailable_decl_from_clang:
    diags.diagnose(startLoc, diag::decl_from_clang);
    return None;
  case RenameAvailableKind::Available:
    break;
  }

  SmallVector<DeclContext *, 8> searchScopes;
  analyzeRenameScope(renameInfo->VD, searchScopes);
  if (searchScopes.empty())
    return None;

  // Index each scope, letting the collector record every occurrence.
  RenameRangeCollector collector(renameInfo->VD, preferredName);
  for (DeclContext *scope : searchScopes)
    indexDeclContext(scope, collector);

  return collector;
}

bool RefactoringActionLocalRename::performChange() {
  if (StartLoc.isInvalid()) {
    DiagEngine.diagnose(SourceLoc(), diag::invalid_location);
    return true;
  }
  if (!DeclNameViewer(PreferredName).isValid()) {
    DiagEngine.diagnose(SourceLoc(), diag::invalid_name, PreferredName);
    return true;
  }
  if (!TheFile) {
    DiagEngine.diagnose(StartLoc, diag::location_module_mismatch,
                        MD->getNameStr());
    return true;
  }

  Optional<RenameRangeCollector> rangeCollector =
      localRenames(TheFile, StartLoc, PreferredName, DiagEngine);
  if (!rangeCollector)
    return true;

  auto consumers = DiagEngine.takeConsumers();
  assert(consumers.size() == 1);
  return syntacticRename(TheFile, rangeCollector->results(), EditConsumer,
                         *consumers[0]);
}

/// Return the placeholder name used for a refactoring of kind \p Kind when the
/// client supplied none. Kinds without a placeholder yield an empty string;
/// RefactoringKind::None is not a valid input.
StringRef getDefaultPreferredName(RefactoringKind Kind) {
  if (Kind == RefactoringKind::None)
    llvm_unreachable("Should be a valid refactoring kind");

  if (Kind == RefactoringKind::GlobalRename ||
      Kind == RefactoringKind::LocalRename)
    return "newName";

  if (Kind == RefactoringKind::ExtractExpr ||
      Kind == RefactoringKind::ExtractRepeatedExpr)
    return "extractedExpr";

  if (Kind == RefactoringKind::ExtractFunction)
    return "extractedFunc";

  return "";
}

/// Soft reasons against extracting a range. checkExtractConditions records
/// them in its result instead of failing outright, so that each refactoring
/// can decide which ones it tolerates via ExtractCheckResult::success().
enum class CannotExtractReason {
  /// The selection is a single expression that is a literal.
  Literal,
  /// The type of the selected range is void.
  VoidType,
};

class ExtractCheckResult {
  bool KnownFailure;
  SmallVector<CannotExtractReason, 2> AllReasons;

public:
  ExtractCheckResult(): KnownFailure(true) {}
  ExtractCheckResult(ArrayRef<CannotExtractReason> AllReasons):
    KnownFailure(false), AllReasons(AllReasons.begin(), AllReasons.end()) {}
  bool success() { return success({}); }
  bool success(ArrayRef<CannotExtractReason> ExpectedReasons) {
    if (KnownFailure)
      return false;
    bool Result = true;

    // Check if any reasons aren't covered by the list of expected reasons
    // provided by the client.
    for (auto R: AllReasons) {
      Result &= llvm::is_contained(ExpectedReasons, R);
    }
    return Result;
  }
};

/// Check whether a given range can be extracted.
///
/// \returns a default-constructed (failed) ExtractCheckResult as soon as a
/// hard blocker is found, in some cases after emitting a diagnostic on
/// \p DiagEngine. Otherwise returns a result carrying the soft
/// CannotExtractReasons gathered along the way, which the caller accepts or
/// rejects through ExtractCheckResult::success().
ExtractCheckResult checkExtractConditions(const ResolvedRangeInfo &RangeInfo,
                                          DiagnosticEngine &DiagEngine) {
  // Soft reasons collected while checking; returned to the caller at the end.
  SmallVector<CannotExtractReason, 2> AllReasons;
  // If any decl declared in the range is referenced after the range, fail.
  auto Declared = RangeInfo.DeclaredDecls;
  auto It = std::find_if(Declared.begin(), Declared.end(),
                         [](DeclaredDecl DD) { return DD.ReferredAfterRange; });
  if (It != Declared.end()) {
    DiagEngine.diagnose(It->VD->getLoc(),
                        diag::value_decl_referenced_out_of_range,
                        It->VD->getName());
    return ExtractCheckResult();
  }

  // We cannot extract a range with multi entry points.
  if (!RangeInfo.HasSingleEntry) {
    DiagEngine.diagnose(SourceLoc(), diag::multi_entry_range);
    return ExtractCheckResult();
  }

  // We cannot extract code when it is unclear whether it exits or not.
  if (RangeInfo.exit() == ExitState::Unsure) {
    return ExtractCheckResult();
  }

  // We cannot extract expressions of l-value type.
  if (auto Ty = RangeInfo.getType()) {
    if (Ty->hasLValueType() || Ty->is<InOutType>())
      return ExtractCheckResult();

    // Disallow extracting error type expressions/statements
    // FIXME: diagnose what happened?
    if (Ty->hasError())
      return ExtractCheckResult();

    // A void type is only a soft reason; the caller decides whether to allow it.
    if (Ty->isVoid()) {
      AllReasons.emplace_back(CannotExtractReason::VoidType);
    }
  }

  // We cannot extract a range with orphaned loop keyword.
  switch (RangeInfo.Orphan) {
  case swift::ide::OrphanKind::Continue:
    DiagEngine.diagnose(SourceLoc(), diag::orphan_loop_keyword, "continue");
    return ExtractCheckResult();
  case swift::ide::OrphanKind::Break:
    DiagEngine.diagnose(SourceLoc(), diag::orphan_loop_keyword, "break");
    return ExtractCheckResult();
  case swift::ide::OrphanKind::None:
    break;
  }

  // Guard statements cannot be extracted.
  if (llvm::any_of(RangeInfo.ContainedNodes,
                  [](ASTNode N) { return N.isStmt(StmtKind::Guard); })) {
    return ExtractCheckResult();
  }

  // Disallow extracting certain kinds of statements.
  if (RangeInfo.Kind == RangeKind::SingleStatement) {
    Stmt *S = RangeInfo.ContainedNodes[0].get<Stmt *>();

    // These aren't independent statements.
    if (isa<BraceStmt>(S) || isa<CaseStmt>(S))
      return ExtractCheckResult();
  }

  // Closures are rejected outright; literals are only a soft reason.
  if (RangeInfo.Kind == RangeKind::SingleExpression) {
    Expr *E = RangeInfo.ContainedNodes[0].get<Expr*>();

    // Until implementing the performChange() part of extracting trailing
    // closures, we disable them for now.
    if (isa<AbstractClosureExpr>(E))
      return ExtractCheckResult();

    if (isa<LiteralExpr>(E))
      AllReasons.emplace_back(CannotExtractReason::Literal);
  }

  // Only the first group of context kinds below is accepted; a range whose
  // context is of any other kind cannot be extracted.
  switch (RangeInfo.RangeContext->getContextKind()) {
  case swift::DeclContextKind::Initializer:
  case swift::DeclContextKind::SubscriptDecl:
  case swift::DeclContextKind::EnumElementDecl:
  case swift::DeclContextKind::AbstractFunctionDecl:
  case swift::DeclContextKind::AbstractClosureExpr:
  case swift::DeclContextKind::TopLevelCodeDecl:
    break;

  case swift::DeclContextKind::SerializedLocal:
  case swift::DeclContextKind::Package:
  case swift::DeclContextKind::Module:
  case swift::DeclContextKind::FileUnit:
  case swift::DeclContextKind::GenericTypeDecl:
  case swift::DeclContextKind::ExtensionDecl:
  case swift::DeclContextKind::MacroDecl:
    return ExtractCheckResult();
  }
  return ExtractCheckResult(AllReasons);
}

bool RefactoringActionExtractFunction::
isApplicable(const ResolvedRangeInfo &Info, DiagnosticEngine &Diag) {
  switch (Info.Kind) {
  case RangeKind::PartOfExpression:
  case RangeKind::SingleDecl:
  case RangeKind::MultiTypeMemberDecl:
  case RangeKind::Invalid:
    return false;
  case RangeKind::SingleExpression:
  case RangeKind::SingleStatement:
  case RangeKind::MultiStatement: {
    return checkExtractConditions(Info, Diag).
      success({CannotExtractReason::VoidType});
  }
  }
  llvm_unreachable("unhandled kind");
}

/// Return \p Name if no decl in \p AllVisibles has exactly that base name.
/// Otherwise return \p Name with the smallest positive decimal number appended
/// such that no visible decl is named \p Name followed by that number. The
/// adjusted name is interned as an identifier in \p Ctx.
static StringRef correctNameInternal(ASTContext &Ctx, StringRef Name,
                                     ArrayRef<ValueDecl*> AllVisibles) {
  // Whether some visible decl is named exactly `Name`.
  bool NameIsTaken = false;

  // Suffixes already in use by visible decls named `Name<suffix>`.
  llvm::StringSet<> TakenSuffixes;
  for (ValueDecl *Visible : AllVisibles) {
    StringRef VisibleName = Visible->getBaseName().userFacingName();
    if (!VisibleName.startswith(Name))
      continue;

    StringRef Remainder = VisibleName.substr(Name.size());
    if (Remainder.empty()) {
      NameIsTaken = true;
    } else {
      TakenSuffixes.insert(Remainder);
    }
  }

  if (!NameIsTaken)
    return Name;

  // Count upwards from 1 until we hit a suffix nobody uses.
  unsigned Counter = 1;
  std::string Candidate = std::to_string(Counter);
  while (TakenSuffixes.count(Candidate) != 0) {
    ++Counter;
    Candidate = std::to_string(Counter);
  }

  return Ctx.getIdentifier((llvm::Twine(Name) + Candidate).str()).str();
}

/// Adjust \p Name so that it does not collide with any decl found by
/// lookupVisibleDecls in \p DC; see correctNameInternal for how the name is
/// adjusted.
static StringRef correctNewDeclName(DeclContext *DC, StringRef Name) {
  // Gather every decl the lookup reports for this context.
  llvm::SmallVector<ValueDecl*, 16> VisibleDecls;
  VectorDeclConsumer DeclSink(VisibleDecls);
  lookupVisibleDecls(DeclSink, DC, true);

  return correctNameInternal(DC->getASTContext(), Name, VisibleDecls);
}

/// Rewrite every lvalue type inside \p Ty into the matching inout type so that
/// the result can be printed properly.
static Type sanitizeType(Type Ty) {
  auto LValueToInOut = [](Type Component) -> Type {
    if (!Component->is<LValueType>())
      return Component;
    return Type(
        InOutType::get(Component->getRValueType()->getCanonicalType()));
  };
  return Ty.transform(LValueToInOut);
}

/// Find the location at which a newly extracted function should be inserted.
///
/// The function goes before the innermost declaration enclosing \p DC, ahead
/// of that declaration's attributes and doc comments. On success
/// \p InsertToContext is set to the decl context of that declaration.
///
/// \returns the insertion location, or an invalid location (leaving
/// \p InsertToContext untouched) if \p DC has no enclosing declaration.
static SourceLoc
getNewFuncInsertLoc(DeclContext *DC, DeclContext*& InsertToContext) {
  auto D = DC->getInnermostDeclarationDeclContext();
  if (!D)
    return SourceLoc();

  // If extracting from a getter/setter, we should skip both the immediate
  // getter/setter function and the individual var decl. The pattern binding
  // decl (or the subscript itself) is the position before which we should
  // insert the newly extracted function.
  if (auto *Accessor = dyn_cast<AccessorDecl>(D)) {
    ValueDecl *Storage = Accessor->getStorage();
    const DeclKind StorageKind = Storage->getKind();
    if (StorageKind == DeclKind::Var) {
      if (auto *Binding = cast<VarDecl>(Storage)->getParentPatternBinding())
        D = Binding;
    } else if (StorageKind == DeclKind::Subscript) {
      D = Storage;
    }
  }

  auto InsertLoc = D->getStartLoc();
  assert(InsertLoc.isValid());

  // Move the insertion point to `Candidate` if that is valid and earlier.
  auto MoveEarlierTo = [&InsertLoc](SourceLoc Candidate) {
    if (Candidate.isValid() &&
        Candidate.getOpaquePointerValue() < InsertLoc.getOpaquePointerValue())
      InsertLoc = Candidate;
  };

  // The insert loc should be before every decl attribute.
  for (auto Attr : D->getAttrs())
    MoveEarlierTo(Attr->getRangeWithAt().Start);

  // The insert loc should be before the doc comments associated with this decl.
  if (!D->getRawComment().Comments.empty())
    MoveEarlierTo(D->getRawComment().Comments.front().Range.getStart());

  InsertToContext = D->getDeclContext();
  return InsertLoc;
}

/// Compute the notable regions (name pieces) of a freshly generated snippet.
///
/// \p SourceText is loaded as the only input, named "<extract>", of a new
/// CompilerInstance. The name starting at byte offset \p NameOffset is
/// resolved with a NameMatcher, and the syntactic rename ranges for \p Name
/// are collected and converted to line/column based NoteRegions relative to
/// the snippet.
///
/// \param SourceText the generated source snippet.
/// \param NameOffset byte offset of the name within \p SourceText.
/// \param Name the name passed as the old name to the rename machinery.
/// \param IsFunctionLike forwarded to the RenameLoc configuration.
/// \param IsNonProtocolType forwarded to the RenameLoc configuration.
/// \returns one NoteRegion per collected rename range, in collection order.
static std::vector<NoteRegion>
getNotableRegions(StringRef SourceText, unsigned NameOffset, StringRef Name,
                    bool IsFunctionLike = false, bool IsNonProtocolType = false) {
  auto InputBuffer = llvm::MemoryBuffer::getMemBufferCopy(SourceText,"<extract>");

  CompilerInvocation Invocation{};

  Invocation.getFrontendOptions().InputsAndOutputs.addInput(
      InputFile("<extract>", true, InputBuffer.get(), file_types::TY_Swift));
  Invocation.getFrontendOptions().ModuleName = "extract";
  Invocation.getLangOptions().DisablePoundIfEvaluation = true;

  auto Instance = std::make_unique<swift::CompilerInstance>();
  std::string InstanceSetupError;
  if (Instance->setup(Invocation, InstanceSetupError))
    llvm_unreachable(InstanceSetupError.c_str());

  unsigned BufferId = Instance->getPrimarySourceFile()->getBufferID().value();
  SourceManager &SM = Instance->getSourceMgr();
  SourceLoc NameLoc = SM.getLocForOffset(BufferId, NameOffset);
  auto LineAndCol = SM.getLineAndColumnInBuffer(NameLoc);

  UnresolvedLoc UnresolvedName{NameLoc, true};

  NameMatcher Matcher(*Instance->getPrimarySourceFile());
  auto Resolved = Matcher.resolve(llvm::makeArrayRef(UnresolvedName), None);
  assert(!Resolved.empty() && "Failed to resolve generated func name loc");

  RenameLoc RenameConfig = {
    LineAndCol.first, LineAndCol.second,
    NameUsage::Definition, /*OldName=*/Name, /*NewName=*/"",
    IsFunctionLike, IsNonProtocolType
  };
  RenameRangeDetailCollector Renamer(SM, Name);
  Renamer.addSyntacticRenameRanges(Resolved.back(), RenameConfig);

  // Convert the collected ranges in place; there is no need to copy the
  // collector's range list first.
  std::vector<NoteRegion> NoteRegions(Renamer.Ranges.size());
  llvm::transform(
      Renamer.Ranges, NoteRegions.begin(),
      [&SM](const RenameRangeDetail &Detail) -> NoteRegion {
        auto Start = SM.getLineAndColumnInBuffer(Detail.Range.getStart());
        auto End = SM.getLineAndColumnInBuffer(Detail.Range.getEnd());
        return {Detail.RangeKind, Start.first, Start.second,
                End.first,        End.second,  Detail.Index};
      });

  return NoteRegions;
}

/// Extract the selected range into a new function and replace the range with
/// a call to it.
///
/// The new function is inserted at the location computed by
/// getNewFuncInsertLoc. Decls referenced in the range that are declared
/// directly in the range's context (and not inside the range itself) become
/// unlabeled parameters of the new function.
///
/// \returns true on failure (after emitting a diagnostic), false once both
/// edits have been handed to the edit consumer.
bool RefactoringActionExtractFunction::performChange() {
  // Check if the new name is ok.
  if (!Lexer::isIdentifier(PreferredName)) {
    DiagEngine.diagnose(SourceLoc(), diag::invalid_name, PreferredName);
    return true;
  }
  DeclContext *DC = RangeInfo.RangeContext;
  DeclContext *InsertToDC = nullptr;
  SourceLoc InsertLoc = getNewFuncInsertLoc(DC, InsertToDC);

  // Complain about no inserting position.
  if (InsertLoc.isInvalid()) {
    DiagEngine.diagnose(SourceLoc(), diag::no_insert_position);
    return true;
  }

  // Correct the given name if collision happens.
  PreferredName = correctNewDeclName(InsertToDC, PreferredName);

  // Collect the parameters to pass down to the new function.
  std::vector<ReferencedDecl> Parameters;
  for (auto &RD: RangeInfo.ReferencedDecls) {
    // If the referenced decl is declared elsewhere, no need to pass as parameter
    if (RD.VD->getDeclContext() != DC)
      continue;

    // We don't need to pass down implicitly declared variables, e.g. error in
    // a catch block.
    if (RD.VD->isImplicit()) {
      SourceLoc Loc = RD.VD->getStartLoc();
      if (Loc.isValid() &&
          SM.isBeforeInBuffer(RangeInfo.ContentRange.getStart(), Loc) &&
          SM.isBeforeInBuffer(Loc, RangeInfo.ContentRange.getEnd()))
        continue;
    }

    // If the referenced decl is declared inside the range, no need to pass
    // as parameter.
    if (RangeInfo.DeclaredDecls.end() !=
      std::find_if(RangeInfo.DeclaredDecls.begin(), RangeInfo.DeclaredDecls.end(),
        [RD](DeclaredDecl DD) { return RD.VD == DD.VD; }))
      continue;

    // We don't need to pass down self.
    if (auto PD = dyn_cast<ParamDecl>(RD.VD)) {
      if (PD->isSelfParameter()) {
        continue;
      }
    }

    Parameters.emplace_back(RD.VD, sanitizeType(RD.Ty));
  }
  // Buffer receives the text of the new function first and the text of the
  // replacement call second; the Begin/End offsets delimit the two pieces.
  SmallString<64> Buffer;
  unsigned FuncBegin = Buffer.size();
  // Offset of the function name relative to the start of the function text.
  unsigned FuncNameOffset;
  {
    llvm::raw_svector_ostream OS(Buffer);

    if (!InsertToDC->isLocalContext()) {
      // Default to be file private.
      OS << tok::kw_fileprivate << " ";
    }

    // Inherit static if the containing function is.
    if (DC->getContextKind() == DeclContextKind::AbstractFunctionDecl) {
      if (auto FD = dyn_cast<FuncDecl>(static_cast<AbstractFunctionDecl*>(DC))) {
        if (FD->isStatic()) {
          OS << tok::kw_static << " ";
        }
      }
    }

    OS << tok::kw_func << " ";
    FuncNameOffset = Buffer.size() - FuncBegin;
    OS << PreferredName;
    OS << "(";
    // Every parameter is printed with an unnamed ("_") argument label.
    for (auto &RD : Parameters) {
      OS << "_ " << RD.VD->getBaseName().userFacingName() << ": ";
      RD.Ty->reconstituteSugar(/*Recursive*/true)->print(OS);
      if (&RD != &Parameters.back())
        OS << ", ";
    }
    OS << ")";

    // Propagate the effects the range leaves unhandled to the signature.
    if (RangeInfo.UnhandledEffects.contains(EffectKind::Async))
      OS << " async";
    if (RangeInfo.UnhandledEffects.contains(EffectKind::Throws))
      OS << " " << tok::kw_throws;

    bool InsertedReturnType = false;
    if (auto Ty = RangeInfo.getType()) {
      // If the type of the range is not void, specify the return type.
      if (!Ty->isVoid()) {
        OS << " " << tok::arrow << " ";
        sanitizeType(Ty)->reconstituteSugar(/*Recursive*/true)->print(OS);
        InsertedReturnType = true;
      }
    }

    OS << " {\n";

    // Add "return" if the extracted entity is an expression.
    if (RangeInfo.Kind == RangeKind::SingleExpression && InsertedReturnType)
      OS << tok::kw_return << " ";
    OS << RangeInfo.ContentRange.str() << "\n}\n\n";
  }
  unsigned FuncEnd = Buffer.size();

  unsigned ReplaceBegin = Buffer.size();
  // Offset of the callee name relative to the start of the call text.
  unsigned CallNameOffset;
  {
    llvm::raw_svector_ostream OS(Buffer);
    // If the range's exit state is positive, the call is prefixed with "return".
    if (RangeInfo.exit() == ExitState::Positive)
      OS << tok::kw_return <<" ";

    if (RangeInfo.UnhandledEffects.contains(EffectKind::Throws))
      OS << tok::kw_try << " ";
    if (RangeInfo.UnhandledEffects.contains(EffectKind::Async))
      OS << "await ";

    CallNameOffset = Buffer.size() - ReplaceBegin;
    OS << PreferredName << "(";
    for (auto &RD : Parameters) {

      // Inout argument needs "&".
      if (RD.Ty->is<InOutType>())
        OS << "&";
      OS << RD.VD->getBaseName().userFacingName();
      if (&RD != &Parameters.back())
        OS << ", ";
    }
    OS << ")";
  }
  unsigned ReplaceEnd = Buffer.size();

  // Build the compound name of the new function, e.g. "name(_:_:)", with one
  // "_:" per parameter since all of them are unlabeled.
  std::string ExtractedFuncName = PreferredName.str() + "(";
  for (size_t i = 0; i < Parameters.size(); ++i) {
    ExtractedFuncName += "_:";
  }
  ExtractedFuncName += ")";

  // Buffer is no longer modified from here on, so these views stay valid.
  StringRef DeclStr(Buffer.begin() + FuncBegin, FuncEnd - FuncBegin);
  auto NotableFuncRegions = getNotableRegions(DeclStr, FuncNameOffset,
                                              ExtractedFuncName,
                                              /*IsFunctionLike=*/true);

  StringRef CallStr(Buffer.begin() + ReplaceBegin, ReplaceEnd - ReplaceBegin);
  auto NotableCallRegions = getNotableRegions(CallStr, CallNameOffset,
                                              ExtractedFuncName,
                                              /*IsFunctionLike=*/true);

  // Insert the new function's declaration.
  EditConsumer.accept(SM, InsertLoc, DeclStr, NotableFuncRegions);

  // Replace the code to extract with the function call.
  EditConsumer.accept(SM, RangeInfo.ContentRange, CallStr, NotableCallRegions);

  return false;
}

/// Implementation shared by the ExtractExpr and ExtractRepeatedExpr actions,
/// which differ only in the \c ExtractRepeated flag they pass in.
class RefactoringActionExtractExprBase {
  /// The file containing the selection.
  SourceFile *TheFile;
  /// Resolved information about the selected range (held by value).
  ResolvedRangeInfo RangeInfo;
  DiagnosticEngine &DiagEngine;
  /// Whether all similar expressions are extracted instead of only the
  /// selected one.
  const bool ExtractRepeated;
  /// Name for the new variable; may be adjusted by performChange().
  StringRef PreferredName;
  SourceEditConsumer &EditConsumer;

  // Derived from TheFile. SM is initialized from Ctx, so Ctx must stay
  // declared before it.
  ASTContext &Ctx;
  SourceManager &SM;

public:
  RefactoringActionExtractExprBase(SourceFile *TheFile,
                                   ResolvedRangeInfo RangeInfo,
                                   DiagnosticEngine &DiagEngine,
                                   bool ExtractRepeated,
                                   StringRef PreferredName,
                                   SourceEditConsumer &EditConsumer)
      : TheFile(TheFile),
        RangeInfo(RangeInfo),
        DiagEngine(DiagEngine),
        ExtractRepeated(ExtractRepeated),
        PreferredName(PreferredName),
        EditConsumer(EditConsumer),
        Ctx(TheFile->getASTContext()),
        SM(Ctx.SourceMgr) {}

  /// Carry out the extraction; defined out of line.
  bool performChange();
};

/// Records, in walk order, every decl referenced inside an expression. Two
/// collectors compare equal when they recorded the same decls in the same
/// order, which is used to check that two expressions refer to identical
/// decls.
struct ReferenceCollector: public SourceEntityWalker {
  /// The referenced decls, in the order the walker reported them.
  SmallVector<ValueDecl*, 4> References;

  /// Walk \p E right away and collect its references.
  ReferenceCollector(Expr *E) { walk(E); }

  bool visitDeclReference(ValueDecl *D, CharSourceRange Range,
                          TypeDecl *CtorTyRef, ExtensionDecl *ExtTyRef,
                          Type T, ReferenceMetaData Data) override {
    References.emplace_back(D);
    return true;
  }

  bool operator==(const ReferenceCollector &Other) const {
    const auto &Theirs = Other.References;
    return References.size() == Theirs.size() &&
           std::equal(References.begin(), References.end(), Theirs.begin());
  }
};

struct SimilarExprCollector: public SourceEntityWalker {
  SourceManager &SM;

  /// The expression under selection.
  Expr *SelectedExpr;
  ArrayRef<Token> AllTokens;
  llvm::SetVector<Expr*> &Bucket;

  /// The tokens included in the expression under selection.
  ArrayRef<Token> SelectedTokens;

  /// The referenced decls in the expression under selection.
  ReferenceCollector SelectedReferences;

  bool compareTokenContent(ArrayRef<Token> Left, ArrayRef<Token> Right) {
    if (Left.size() != Right.size())
      return false;
    return std::equal(Left.begin(), Left.end(), Right.begin(),
                      [](const Token &L, const Token& R) {
                        return L.getText() == R.getText();
                      });
  }

  /// Find all tokens included by an expression.
  ArrayRef<Token> getExprSlice(Expr *E) {
    return slice_token_array(AllTokens, E->getStartLoc(), E->getEndLoc());
  }

  SimilarExprCollector(SourceManager &SM, Expr *SelectedExpr,
                       ArrayRef<Token> AllTokens,
    llvm::SetVector<Expr*> &Bucket): SM(SM), SelectedExpr(SelectedExpr),
    AllTokens(AllTokens), Bucket(Bucket),
    SelectedTokens(getExprSlice(SelectedExpr)),
    SelectedReferences(SelectedExpr){}

  bool walkToExprPre(Expr *E) override {
    // We don't extract implicit expressions.
    if (E->isImplicit())
      return true;
    if (E->getKind() != SelectedExpr->getKind())
      return true;

    // First check the underlying token arrays have the same content.
    if (compareTokenContent(getExprSlice(E), SelectedTokens)) {
      ReferenceCollector CurrentReferences(E);

      // Next, check the referenced decls are same.
      if (CurrentReferences == SelectedReferences)
        Bucket.insert(E);
    }
    return true;
  }
};

/// Extract the selected expression (and, if ExtractRepeated is set, every
/// similar expression in the innermost enclosing brace statement) into a new
/// `let` binding and replace each occurrence with the new name.
///
/// \returns true on failure (after emitting a diagnostic), false once all
/// edits have been handed to the edit consumer.
bool RefactoringActionExtractExprBase::performChange() {
  // Check if the new name is ok.
  if (!Lexer::isIdentifier(PreferredName)) {
    DiagEngine.diagnose(SourceLoc(), diag::invalid_name, PreferredName);
    return true;
  }

  // Find the enclosing brace statement;
  ContextFinder Finder(*TheFile, RangeInfo.ContainedNodes.front(),
                       [](ASTNode N) { return N.isStmt(StmtKind::Brace); });

  auto *SelectedExpr = RangeInfo.ContainedNodes[0].get<Expr*>();
  Finder.resolve();
  SourceLoc InsertLoc;
  llvm::SetVector<ValueDecl*> AllVisibleDecls;
  // Collects every value decl encountered while walking a node.
  struct DeclCollector: public SourceEntityWalker {
    llvm::SetVector<ValueDecl*> &Bucket;
    DeclCollector(llvm::SetVector<ValueDecl*> &Bucket): Bucket(Bucket) {}
    bool walkToDeclPre(Decl *D, CharSourceRange Range) override {
      if (auto *VD = dyn_cast<ValueDecl>(D))
        Bucket.insert(VD);
      return true;
    }
  } Collector(AllVisibleDecls);

  // The expressions that will be replaced by the new variable.
  llvm::SetVector<Expr*> AllExpressions;

  if (!Finder.getContexts().empty()) {

    // Get the innermost brace statement.
    auto BS = static_cast<BraceStmt*>(Finder.getContexts().back().get<Stmt*>());

    // Collect all value decls inside the brace statement.
    Collector.walk(BS);

    if (ExtractRepeated) {
      // Collect all expressions we are going to extract.
      SimilarExprCollector(SM, SelectedExpr,
                           slice_token_array(TheFile->getAllTokens(),
                                             BS->getStartLoc(),
                                             BS->getEndLoc()),
                           AllExpressions).walk(BS);
    } else {
      AllExpressions.insert(SelectedExpr);
    }

    assert(!AllExpressions.empty() && "at least one expression is extracted.");
    for (auto Ele : BS->getElements()) {
      // Find the element that encloses the first expression under extraction.
      if (SM.rangeContains(Ele.getSourceRange(),
                           (*AllExpressions.begin())->getSourceRange())) {

        // Insert before the enclosing element.
        InsertLoc = Ele.getStartLoc();
      }
    }
  }

  // Complain about no inserting position.
  if (InsertLoc.isInvalid()) {
    DiagEngine.diagnose(SourceLoc(), diag::no_insert_position);
    return true;
  }

  // Correct name if collision happens.
  PreferredName = correctNameInternal(TheFile->getASTContext(), PreferredName,
                                      AllVisibleDecls.getArrayRef());

  // Print the type name of this expression.
  SmallString<16> TyBuffer;

  // We are not sure about the type of repeated expressions.
  if (!ExtractRepeated) {
    if (auto Ty = RangeInfo.getType()) {
      llvm::raw_svector_ostream OS(TyBuffer);
      OS << ": ";
      Ty->getRValueType()->reconstituteSugar(true)->print(OS);
    }
  }

  // Build the declaration text "let <name>[: <type>] = <selected text>\n" and
  // remember where the name starts and ends within it.
  SmallString<64> DeclBuffer;
  llvm::raw_svector_ostream OS(DeclBuffer);
  unsigned StartOffset, EndOffset;
  OS << tok::kw_let << " ";
  StartOffset = DeclBuffer.size();
  OS << PreferredName;
  EndOffset = DeclBuffer.size();
  OS << TyBuffer.str() <<  " = " << RangeInfo.ContentRange.str() << "\n";

  // The name sits on the first line of the inserted text; the 0-based buffer
  // offsets are shifted by one to form the column values.
  NoteRegion DeclNameRegion{
    RefactoringRangeKind::BaseName,
    /*StartLine=*/1, /*StartColumn=*/StartOffset + 1,
    /*EndLine=*/1, /*EndColumn=*/EndOffset + 1,
    /*ArgIndex*/None
  };

  // Perform code change.
  EditConsumer.accept(SM, InsertLoc, DeclBuffer.str(), {DeclNameRegion});

  // Replace all occurrences of the extracted expression.
  for (auto *E : AllExpressions) {
    EditConsumer.accept(SM,
      Lexer::getCharSourceRangeFromSourceRange(SM, E->getSourceRange()),
      PreferredName,
      {{
        RefactoringRangeKind::BaseName,
        /*StartLine=*/1, /*StartColumn=*/1, /*EndLine=*/1,
        /*EndColumn=*/static_cast<unsigned int>(PreferredName.size() + 1),
        /*ArgIndex*/None
      }});
  }
  return false;
}

bool RefactoringActionExtractExpr::
isApplicable(const ResolvedRangeInfo &Info, DiagnosticEngine &Diag) {
  switch (Info.Kind) {
    case RangeKind::SingleExpression:
      // We disallow extract literal expression for two reasons:
      // (1) since we print the type for extracted expression, the type of a
      // literal may print as "int2048" where it is not typically users' choice;
      // (2) Extracting one literal provides little value for users.
      return checkExtractConditions(Info, Diag).success();
    case RangeKind::PartOfExpression:
    case RangeKind::SingleDecl:
    case RangeKind::MultiTypeMemberDecl:
    case RangeKind::SingleStatement:
    case RangeKind::MultiStatement:
    case RangeKind::Invalid:
      return false;
  }
  llvm_unreachable("unhandled kind");
}

bool RefactoringActionExtractExpr::performChange() {
  return RefactoringActionExtractExprBase(TheFile, RangeInfo,
                                          DiagEngine, false, PreferredName,
                                          EditConsumer).performChange();
}

bool RefactoringActionExtractRepeatedExpr::
isApplicable(const ResolvedRangeInfo &Info, DiagnosticEngine &Diag) {
  switch (Info.Kind) {
    case RangeKind::SingleExpression:
      return checkExtractConditions(Info, Diag).
        success({CannotExtractReason::Literal});
    case RangeKind::PartOfExpression:
    case RangeKind::SingleDecl:
    case RangeKind::MultiTypeMemberDecl:
    case RangeKind::SingleStatement:
    case RangeKind::MultiStatement:
    case RangeKind::Invalid:
      return false;
  }
  llvm_unreachable("unhandled kind");
}
bool RefactoringActionExtractRepeatedExpr::performChange() {
  return RefactoringActionExtractExprBase(TheFile, RangeInfo,
                                          DiagEngine, true, PreferredName,
                                          EditConsumer).performChange();
}


bool RefactoringActionMoveMembersToExtension::isApplicable(
    const ResolvedRangeInfo &Info, DiagnosticEngine &Diag) {
  switch (Info.Kind) {
  case RangeKind::SingleDecl:
  case RangeKind::MultiTypeMemberDecl: {
    DeclContext *DC = Info.RangeContext;

    // The the common decl context is not a nomial type, we cannot create an
    // extension for it
    if (!DC || !DC->getInnermostDeclarationDeclContext() ||
        !isa<NominalTypeDecl>(DC->getInnermostDeclarationDeclContext()))
      return false;


    // Members of types not declared at top file level cannot be extracted
    // to an extension at top file level
    if (DC->getParent()->getContextKind() != DeclContextKind::FileUnit)
      return false;

    // Check if contained nodes are all allowed decls.
    for (auto Node : Info.ContainedNodes) {
      Decl *D = Node.dyn_cast<Decl*>();
      if (!D)
        return false;

      if (isa<AccessorDecl>(D) || isa<DestructorDecl>(D) ||
          isa<EnumCaseDecl>(D) || isa<EnumElementDecl>(D))
        return false;
    }

    // We should not move instance variables with storage into the extension
    // because they are not allowed to be declared there
    for (auto DD : Info.DeclaredDecls) {
      if (auto ASD = dyn_cast<AbstractStorageDecl>(DD.VD)) {
        // Only disallow storages in the common decl context, allow them in
        // any subtypes
        if (ASD->hasStorage() && ASD->getDeclContext() == DC) {
          return false;
        }
      }
    }

    return true;
  }
  case RangeKind::SingleExpression:
  case RangeKind::PartOfExpression:
  case RangeKind::SingleStatement:
  case RangeKind::MultiStatement:
  case RangeKind::Invalid:
    return false;
  }
  llvm_unreachable("unhandled kind");
}

bool RefactoringActionMoveMembersToExtension::performChange() {
  DeclContext *DC = RangeInfo.RangeContext;

  auto CommonTypeDecl =
      dyn_cast<NominalTypeDecl>(DC->getInnermostDeclarationDeclContext());
  assert(CommonTypeDecl && "Not applicable if common parent is no nomial type");

  SmallString<64> Buffer;
  llvm::raw_svector_ostream OS(Buffer);
  OS << "\n\n";
  OS << "extension " << CommonTypeDecl->getName() << " {\n";
  OS << RangeInfo.ContentRange.str().trim();
  OS << "\n}";

  // Insert extension after the type declaration
  EditConsumer.insertAfter(SM, CommonTypeDecl->getEndLoc(), Buffer);
  EditConsumer.remove(SM, RangeInfo.ContentRange);

  return false;
}

namespace {
// A SingleDecl range may not include all decls actually declared in that range:
// a var decl has accessors that aren't included. This will find those missing
// decls.
class FindAllSubDecls : public SourceEntityWalker {
  SmallPtrSetImpl<Decl *> &Found;
  public:
  FindAllSubDecls(SmallPtrSetImpl<Decl *> &found)
    : Found(found) {}

  bool walkToDeclPre(Decl *D, CharSourceRange range) override {
    // Record this Decl, and skip its contents if we've already touched it.
    if (!Found.insert(D).second)
      return false;

    if (auto ASD = dyn_cast<AbstractStorageDecl>(D)) {
      ASD->visitParsedAccessors([&](AccessorDecl *accessor) {
        Found.insert(accessor);
      });
    }
    return true;
  }
};
}
bool RefactoringActionReplaceBodiesWithFatalError::isApplicable(
  const ResolvedRangeInfo &Info, DiagnosticEngine &Diag) {
  switch (Info.Kind) {
  case RangeKind::SingleDecl:
  case RangeKind::MultiTypeMemberDecl: {
    SmallPtrSet<Decl *, 16> Found;
    for (auto decl : Info.DeclaredDecls) {
      FindAllSubDecls(Found).walk(decl.VD);
    }
    for (auto decl : Found) {
      auto AFD = dyn_cast<AbstractFunctionDecl>(decl);
      if (AFD && !AFD->isImplicit())
        return true;
    }

    return false;
 }
  case RangeKind::SingleExpression:
  case RangeKind::PartOfExpression:
  case RangeKind::SingleStatement:
  case RangeKind::MultiStatement:
  case RangeKind::Invalid:
    return false;
  }
  llvm_unreachable("unhandled kind");
}

/// Replaces the body of every non-implicit function found in the selected
/// declarations (including parsed accessors) with a call to `fatalError()`.
///
/// \returns false after all replacement edits have been handed to the edit
/// consumer.
bool RefactoringActionReplaceBodiesWithFatalError::performChange() {
  const StringRef replacement = "{\nfatalError()\n}";
  SmallPtrSet<Decl *, 16> Found;
  for (auto decl : RangeInfo.DeclaredDecls) {
    FindAllSubDecls(Found).walk(decl.VD);
  }
  for (auto decl : Found) {
    auto AFD = dyn_cast<AbstractFunctionDecl>(decl);
    if (!AFD || AFD->isImplicit())
      continue;

    // A function declaration is not guaranteed to have a body written in
    // source; without a valid body range there is nothing to rewrite, so
    // skip it instead of emitting an edit for an invalid range.
    auto range = AFD->getBodySourceRange();
    if (!range.isValid())
      continue;

    // Rewrite the whole body, braces included.
    auto charRange = Lexer::getCharSourceRangeFromSourceRange(SM, range);
    EditConsumer.accept(SM, charRange, replacement);
  }
  return false;
}

/// Looks for a pair of nested 'if' statements that can be collapsed into one.
///
/// The cursor must be at the start of an 'if' statement without an 'else'
/// clause whose body consists of exactly one statement: another 'if'
/// statement that also has no 'else' clause.
///
/// \returns {outer, inner} on success, or a pair of null pointers otherwise.
static std::pair<IfStmt *, IfStmt *>
findCollapseNestedIfTarget(ResolvedCursorInfoPtr CursorInfo) {
  auto StartInfo = dyn_cast<ResolvedStmtStartCursorInfo>(CursorInfo);
  if (!StartInfo)
    return {};

  // The statement under the cursor has to be an 'if' without 'else'.
  auto *Outer = dyn_cast<IfStmt>(StartInfo->getTrailingStmt());
  if (!Outer || Outer->getElseStmt())
    return {};

  // Its body must contain a sole element...
  auto *OuterBody = dyn_cast_or_null<BraceStmt>(Outer->getThenStmt());
  if (!OuterBody || OuterBody->getNumElements() != 1)
    return {};

  // ...which is itself an 'if' statement without an 'else' clause.
  auto *Inner = dyn_cast_or_null<IfStmt>(
      OuterBody->getFirstElement().dyn_cast<Stmt *>());
  if (!Inner || Inner->getElseStmt())
    return {};

  return {Outer, Inner};
}

/// Applicable exactly when findCollapseNestedIfTarget() finds a collapsible
/// outer 'if' statement at the cursor.
bool RefactoringActionCollapseNestedIfStmt::isApplicable(
    ResolvedCursorInfoPtr CursorInfo, DiagnosticEngine &Diag) {
  IfStmt *OuterIf = findCollapseNestedIfTarget(CursorInfo).first;
  return OuterIf != nullptr;
}

bool RefactoringActionCollapseNestedIfStmt::performChange() {
  auto Target = findCollapseNestedIfTarget(CursorInfo);
  if (!Target.first)
    return true;
  auto OuterIf = Target.first;
  auto InnerIf = Target.second;

  EditorConsumerInsertStream OS(
      EditConsumer, SM,
      Lexer::getCharSourceRangeFromSourceRange(SM, OuterIf->getSourceRange()));

  OS << tok::kw_if << " ";

  // Emit conditions.
  bool first = true;
  for (auto &C : llvm::concat<StmtConditionElement>(OuterIf->getCond(),
                                                    InnerIf->getCond())) {
    if (first)
      first = false;
    else
      OS << ", ";
    OS << Lexer::getCharSourceRangeFromSourceRange(SM, C.getSourceRange())
              .str();
  }

  // Emit body.
  OS << " ";
  OS << Lexer::getCharSourceRangeFromSourceRange(
            SM, InnerIf->getThenStmt()->getSourceRange())
            .str();
  return false;
}

/// Collects the string-typed operands of a string concatenation expression.
///
/// The expression examined is the single contained expression for a
/// SingleExpression range, or the common parent expression for a
/// PartOfExpression range.
///
/// \returns the collected operands in the order the walker inserted them, or
/// nullptr when the range is of any other kind, when the walker found
/// something other than the accepted concatenation operators and
/// string-typed operands, or when fewer than two operands were collected.
static std::unique_ptr<llvm::SetVector<Expr*>>
findConcatenatedExpressions(const ResolvedRangeInfo &Info, ASTContext &Ctx) {
  Expr *E = nullptr;

  switch (Info.Kind) {
  case RangeKind::SingleExpression:
    E = Info.ContainedNodes[0].get<Expr*>();
    break;
  case RangeKind::PartOfExpression:
    E = Info.CommonExprParent;
    break;
  default:
    return nullptr;
  }

  assert(E);

  // Walks the expression tree, gathering string-typed sub-expressions into
  // Bucket and clearing IsValidInterpolation as soon as an unexpected
  // expression is encountered.
  struct StringInterpolationExprFinder: public SourceEntityWalker {
    std::unique_ptr<llvm::SetVector<Expr *>> Bucket =
        std::make_unique<llvm::SetVector<Expr *>>();
    ASTContext &Ctx;

    bool IsValidInterpolation = true;
    StringInterpolationExprFinder(ASTContext &Ctx): Ctx(Ctx) {}

    // True if the reference names one of the two '+' functions that the
    // ASTContext exposes via getPlusFunctionOnString() and
    // getPlusFunctionOnRangeReplaceableCollection().
    bool isConcatenationExpr(DeclRefExpr* Expr) {
      if (!Expr)
        return false;
      auto *FD = dyn_cast<FuncDecl>(Expr->getDecl());
      if (FD == nullptr || (FD != Ctx.getPlusFunctionOnString() &&
          FD != Ctx.getPlusFunctionOnRangeReplaceableCollection())) {
        return false;
      }
      return true;
    }

    bool walkToExprPre(Expr *E) override {
      // Implicit expressions are looked through, not judged.
      if (E->isImplicit())
        return true;
      // FIXME: we should have ErrorType instead of null.
      if (E->getType().isNull())
        return true;

      // Only binary concatenation operators that produce a string may appear
      // in the expression.
      if (E->getKind() == ExprKind::Binary) {
        auto *BE = dyn_cast<BinaryExpr>(E);
        auto *OperatorDeclRef = BE->getSemanticFn()->getMemberOperatorRef();
        if (!(isConcatenationExpr(OperatorDeclRef) &&
            E->getType()->isString())) {
          IsValidInterpolation = false;
          return false;
        }
        return true;
      }
      // Everything that evaluates to string should be gathered. Its children
      // are not visited, so it is kept as one unit.
      if (E->getType()->isString()) {
        Bucket->insert(E);
        return false;
      }
      if (auto *DR = dyn_cast<DeclRefExpr>(E)) {
        // Checks whether all function references in expression are concatenations.
        auto *FD = dyn_cast<FuncDecl>(DR->getDecl());
        auto IsConcatenation = isConcatenationExpr(DR);
        if (FD && IsConcatenation) {
          return false;
        }
      }
      // There was non-expected expression, it's not valid interpolation then.
      IsValidInterpolation = false;
      return false;
    }
  } Walker(Ctx);
  Walker.walk(E);

  // There should be two or more expressions to convert.
  if (!Walker.IsValidInterpolation || Walker.Bucket->size() < 2)
    return nullptr;

  // Bucket is a member of the walker, so it has to be moved out explicitly.
  return std::move(Walker.Bucket);
}

/// Writes the form \p E takes inside a string interpolation to \p OS.
///
///  - A plain string literal contributes its value.
///  - An interpolated string literal contributes its source text with the
///    first and the last character removed.
///  - Any other expression contributes its source text wrapped in `\( ... )`.
static void interpolatedExpressionForm(Expr *E, SourceManager &SM,
                                              llvm::raw_ostream &OS) {
  if (auto *Literal = dyn_cast<StringLiteralExpr>(E)) {
    OS << Literal->getValue();
    return;
  }

  std::string SourceText =
      Lexer::getCharSourceRangeFromSourceRange(SM, E->getSourceRange())
          .str()
          .str();

  if (!isa<InterpolatedStringLiteralExpr>(E)) {
    OS << "\\(" << SourceText << ")";
    return;
  }

  // NOTE(review): dropping one character on each side assumes the literal is
  // delimited by a single quote character at both ends -- confirm for
  // multi-line and raw string literals.
  SourceText.erase(0, 1);
  SourceText.pop_back();
  OS << SourceText;
}

/// Applicable when the range has a context and the selected expression is a
/// string concatenation with at least two collectable operands (see
/// findConcatenatedExpressions()).
bool RefactoringActionConvertStringsConcatenationToInterpolation::
isApplicable(const ResolvedRangeInfo &Info, DiagnosticEngine &Diag) {
  if (!Info.RangeContext)
    return false;

  auto &Ctx = Info.RangeContext->getASTContext();
  return findConcatenatedExpressions(Info, Ctx) != nullptr;
}

bool RefactoringActionConvertStringsConcatenationToInterpolation::performChange() {
  auto Expressions = findConcatenatedExpressions(RangeInfo, Ctx);
  if (!Expressions)
    return true;
  EditorConsumerInsertStream OS(EditConsumer, SM, RangeInfo.ContentRange);
  OS << "\"";
  for (auto It = Expressions->begin(); It != Expressions->end(); ++It) {
    interpolatedExpressionForm(*It, SM, OS);
  }
  OS << "\"";
  return false;
}

/// Abstract description of a TernaryExpr that can be expanded into an IfStmt,
/// together with the name it is assigned to.
class ExpandableTernaryExprInfo {

public:
  virtual ~ExpandableTernaryExprInfo() = default;

  /// The ternary expression to expand, or null if there is none.
  virtual TernaryExpr *getTernary() = 0;

  /// Source range of the name that receives the ternary's value.
  virtual SourceRange getNameRange() = 0;

  /// Type to spell out when the name must be declared separately; a null
  /// type means no declaration is emitted.
  virtual Type getType() = 0;

  /// Whether the expansion must emit a `name: Type` declaration first.
  virtual bool shouldDeclareNameAndType() {
    return !getType().isNull();
  }

  /// True when a ternary is present, the name range is valid and, if a
  /// declaration is to be emitted, the type is non-null.
  virtual bool isValid() {
    return getTernary() && getNameRange().isValid() &&
           !(shouldDeclareNameAndType() && getType().isNull());
  }

  /// The name range converted to a character-based range.
  CharSourceRange getNameCharRange(const SourceManager &SM) {
    return Lexer::getCharSourceRangeFromSourceRange(SM, getNameRange());
  }
};

/// Describes an AssignExpr whose source (right-hand side) is the expandable
/// TernaryExpr. No type is reported, so no declaration is emitted for the
/// destination.
class ExpandableAssignTernaryExprInfo: public ExpandableTernaryExprInfo {

public:
  ExpandableAssignTernaryExprInfo(AssignExpr *Assign): Assign(Assign) {}

  /// The assignment's source if it is a TernaryExpr, null otherwise.
  TernaryExpr *getTernary() override {
    return Assign ? dyn_cast_or_null<TernaryExpr>(Assign->getSrc()) : nullptr;
  }

  /// Range of the assignment's destination, or an invalid range when there
  /// is no assignment or no destination.
  SourceRange getNameRange() override {
    if (Assign)
      if (auto *Dest = Assign->getDest())
        return Dest->getSourceRange();

    return SourceRange();
  }

  /// Always the null type.
  Type getType() override {
    return Type();
  }

private:
  AssignExpr *Assign = nullptr;
};

/// Describes a PatternBindingDecl with exactly one pattern entry whose
/// initializer is the expandable TernaryExpr.
class ExpandableBindingTernaryExprInfo: public ExpandableTernaryExprInfo {

public:
  ExpandableBindingTernaryExprInfo(PatternBindingDecl *Binding):
  Binding(Binding) {}

  /// The initializer of the sole pattern entry if it is a TernaryExpr, null
  /// otherwise.
  TernaryExpr *getTernary() override {
    if (!Binding || Binding->getNumPatternEntries() != 1)
      return nullptr;

    return dyn_cast_or_null<TernaryExpr>(Binding->getInit(0));
  }

  /// Range of the name pattern, or an invalid range if there is none.
  SourceRange getNameRange() override {
    auto *NamePattern = getNamePattern();
    return NamePattern ? NamePattern->getSourceRange() : SourceRange();
  }

  /// Type of the name pattern, or the null type if there is none.
  Type getType() override {
    auto *NamePattern = getNamePattern();
    return NamePattern ? NamePattern->getType() : Type();
  }

private:
  /// The pattern of the sole entry, with a TypedPattern replaced by its
  /// sub-pattern.
  Pattern *getNamePattern() {
    if (!Binding || Binding->getNumPatternEntries() != 1)
      return nullptr;

    Pattern *Entry = Binding->getPattern(0);
    if (!Entry)
      return nullptr;

    if (auto *Typed = dyn_cast<TypedPattern>(Entry))
      return Typed->getSubPattern();

    return Entry;
  }

  PatternBindingDecl *Binding = nullptr;
};

/// Builds the info object for an expandable ternary expression from a range
/// containing exactly one node: either a pattern binding declaration or an
/// assignment expression.
///
/// \returns null when the range has a different kind or shape. The result
/// still has to be checked with isValid().
std::unique_ptr<ExpandableTernaryExprInfo>
findExpandableTernaryExpression(const ResolvedRangeInfo &Info) {
  const bool IsSupportedKind = Info.Kind == RangeKind::SingleDecl ||
                               Info.Kind == RangeKind::SingleExpression;
  if (!IsSupportedKind || Info.ContainedNodes.size() != 1)
    return nullptr;

  auto Node = Info.ContainedNodes[0];

  // `let x = cond ? a : b`
  if (auto *Binding =
          dyn_cast_or_null<PatternBindingDecl>(Node.dyn_cast<Decl*>()))
    return std::make_unique<ExpandableBindingTernaryExprInfo>(Binding);

  // `x = cond ? a : b`
  if (auto *Assign = dyn_cast_or_null<AssignExpr>(Node.dyn_cast<Expr*>()))
    return std::make_unique<ExpandableAssignTernaryExprInfo>(Assign);

  return nullptr;
}

/// Applicable when an expandable ternary target exists and is valid.
bool RefactoringActionExpandTernaryExpr::
isApplicable(const ResolvedRangeInfo &Info, DiagnosticEngine &Diag) {
  auto Candidate = findExpandableTernaryExpression(Info);
  if (!Candidate)
    return false;
  return Candidate->isValid();
}

bool RefactoringActionExpandTernaryExpr::performChange() {
  auto Target = findExpandableTernaryExpression(RangeInfo);

  if (!Target || !Target->isValid())
    return true; //abort

  auto NameCharRange = Target->getNameCharRange(SM);

  auto IfRange = Target->getTernary()->getSourceRange();
  auto IfCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, IfRange);

  auto CondRange = Target->getTernary()->getCondExpr()->getSourceRange();
  auto CondCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, CondRange);

  auto ThenRange = Target->getTernary()->getThenExpr()->getSourceRange();
  auto ThenCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, ThenRange);

  auto ElseRange = Target->getTernary()->getElseExpr()->getSourceRange();
  auto ElseCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, ElseRange);

  SmallString<64> DeclBuffer;
  llvm::raw_svector_ostream OS(DeclBuffer);

  StringRef Space = " ";
  StringRef NewLine = "\n";

  if (Target->shouldDeclareNameAndType()) {
    //Specifier will not be replaced; append after specifier
    OS << NameCharRange.str() << tok::colon << Space;
    OS << Target->getType() << NewLine;
  }

  OS << tok::kw_if << Space;
  OS << CondCharRange.str() << Space;
  OS << tok::l_brace << NewLine;

  OS << NameCharRange.str() << Space;
  OS << tok::equal << Space;
  OS << ThenCharRange.str() << NewLine;

  OS << tok::r_brace << Space;
  OS << tok::kw_else << Space;
  OS << tok::l_brace << NewLine;

  OS << NameCharRange.str() << Space;
  OS << tok::equal << Space;
  OS << ElseCharRange.str() << NewLine;

  OS << tok::r_brace;

  //Start replacement with name range, skip the specifier
  auto ReplaceRange(NameCharRange);
  ReplaceRange.widen(IfCharRange);

  EditConsumer.accept(SM, ReplaceRange, DeclBuffer.str());

  return false; //don't abort
}

/// Applicable when the range is a statement range containing exactly one
/// node, that node is an 'if' statement with a single pattern-binding
/// condition, and its 'then' branch is a brace statement.
bool RefactoringActionConvertIfLetExprToGuardExpr::
  isApplicable(const ResolvedRangeInfo &Info, DiagnosticEngine &Diag) {

  const bool IsStatementRange = Info.Kind == RangeKind::SingleStatement ||
                                Info.Kind == RangeKind::MultiStatement;
  if (!IsStatementRange)
    return false;

  // Exactly one contained node, and it has to be an 'if' statement.
  if (Info.ContainedNodes.size() != 1)
    return false;

  auto *If =
      dyn_cast_or_null<IfStmt>(Info.ContainedNodes[0].dyn_cast<Stmt*>());
  if (!If)
    return false;

  // Exactly one condition, and it has to be a pattern binding.
  auto Conditions = If->getCond();
  if (Conditions.size() != 1)
    return false;
  if (Conditions[0].getKind() !=
      swift::StmtConditionElement::CK_PatternBinding)
    return false;

  return dyn_cast_or_null<BraceStmt>(If->getThenStmt()) != nullptr;
}

bool RefactoringActionConvertIfLetExprToGuardExpr::performChange() {

  auto S = RangeInfo.ContainedNodes[0].dyn_cast<Stmt*>();
  IfStmt *If = dyn_cast<IfStmt>(S);
  auto CondList = If->getCond();

  // Get if-let condition
  SourceRange range = CondList[0].getSourceRange();
  SourceManager &SM = RangeInfo.RangeContext->getASTContext().SourceMgr;
  auto CondCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, range);
  
  auto Body = dyn_cast_or_null<BraceStmt>(If->getThenStmt());
  
  // Get if-let then body.
  auto firstElement = Body->getFirstElement();
  auto lastElement = Body->getLastElement();
  SourceRange bodyRange = firstElement.getSourceRange();
  bodyRange.widen(lastElement.getSourceRange());
  auto BodyCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, bodyRange);
  
  SmallString<64> DeclBuffer;
  llvm::raw_svector_ostream OS(DeclBuffer);
  
  StringRef Space = " ";
  StringRef NewLine = "\n";
  
  OS << tok::kw_guard << Space;
  OS << CondCharRange.str().str() << Space;
  OS << tok::kw_else << Space;
  OS << tok::l_brace << NewLine;
  
  // Get if-let else body.
  if (auto *ElseBody = dyn_cast_or_null<BraceStmt>(If->getElseStmt())) {
    auto firstElseElement = ElseBody->getFirstElement();
    auto lastElseElement = ElseBody->getLastElement();
    SourceRange elseBodyRange = firstElseElement.getSourceRange();
    elseBodyRange.widen(lastElseElement.getSourceRange());
    auto ElseBodyCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, elseBodyRange);
    OS << ElseBodyCharRange.str().str() << NewLine;
  }
  
  OS << tok::kw_return << NewLine;
  OS << tok::r_brace << NewLine;
  OS << BodyCharRange.str().str();
  
  // Replace if-let to guard
  auto ReplaceRange = RangeInfo.ContentRange;
  EditConsumer.accept(SM, ReplaceRange, DeclBuffer.str());

  return false;
}

bool RefactoringActionConvertGuardExprToIfLetExpr::
isApplicable(const ResolvedRangeInfo &Info, DiagnosticEngine &Diag) {
  if (Info.Kind != RangeKind::SingleStatement
      && Info.Kind != RangeKind::MultiStatement)
    return false;

  if (Info.ContainedNodes.empty())
    return false;

  GuardStmt *guardStmt = nullptr;

  if (Info.ContainedNodes.size() > 0) {
    if (auto S = Info.ContainedNodes[0].dyn_cast<Stmt*>()) {
      guardStmt = dyn_cast<GuardStmt>(S);
    }
  }

  if (!guardStmt)
    return false;

  auto CondList = guardStmt->getCond();

  if (CondList.size() == 1) {
    auto E = CondList[0];
    auto P = E.getPatternOrNull();
    if (P && E.getKind() == swift::StmtConditionElement::CK_PatternBinding)
      return true;
  }

  return false;
}

bool RefactoringActionConvertGuardExprToIfLetExpr::performChange() {

  // Get guard stmt
  auto S = RangeInfo.ContainedNodes[0].dyn_cast<Stmt*>();
  GuardStmt *Guard = dyn_cast<GuardStmt>(S);

  // Get guard condition
  auto CondList = Guard->getCond();

  // Get guard condition source
  SourceRange range = CondList[0].getSourceRange();
  SourceManager &SM = RangeInfo.RangeContext->getASTContext().SourceMgr;
  auto CondCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, range);
  
  SmallString<64> DeclBuffer;
  llvm::raw_svector_ostream OS(DeclBuffer);
  
  StringRef Space = " ";
  StringRef NewLine = "\n";
  
  OS << tok::kw_if << Space;
  OS << CondCharRange.str().str() << Space;
  OS << tok::l_brace << NewLine;

  // Get nodes after guard to place them at if-let body
  if (RangeInfo.ContainedNodes.size() > 1) {
    auto S = RangeInfo.ContainedNodes[1].getSourceRange();
    S.widen(RangeInfo.ContainedNodes.back().getSourceRange());
    auto BodyCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, S);
    OS << BodyCharRange.str().str() << NewLine;
  }
  OS << tok::r_brace;

  // Get guard body
  auto Body = dyn_cast_or_null<BraceStmt>(Guard->getBody());
  
  if (Body && Body->getNumElements() > 1) {
    auto firstElement = Body->getFirstElement();
    auto lastElement = Body->getLastElement();
    SourceRange bodyRange = firstElement.getSourceRange();
    bodyRange.widen(lastElement.getSourceRange());
    auto BodyCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, bodyRange);
    OS << Space << tok::kw_else << Space << tok::l_brace << NewLine;
    OS << BodyCharRange.str().str() << NewLine;
    OS << tok::r_brace;
  }
  
  // Replace guard to if-let
  auto ReplaceRange = RangeInfo.ContentRange;
  EditConsumer.accept(SM, ReplaceRange, DeclBuffer.str());
  
  return false;
}

/// Decides whether the selected statement is an if / else-if chain that can
/// be rewritten as a switch statement.
///
/// The selection must be a single 'if' statement. Every condition element of
/// every 'if' in the chain must pass ConditionalChecker: no availability
/// conditions, all referenced variables/parameters share one name, and the
/// referenced functions are limited to an allow-list of operator names.
bool RefactoringActionConvertToSwitchStmt::
isApplicable(const ResolvedRangeInfo &Info, DiagnosticEngine &Diag) {

  // Walks a condition element and tracks two flags across all the DeclRefExprs
  // it visits. The flags persist between walks, so one instance checks the
  // whole if / else-if chain.
  class ConditionalChecker : public ASTWalker {
  public:
    // Result of the most recent variable/parameter name check.
    bool ParamsUseSameVars = true;
    // Result of the most recent function name check. Starts out false, so it
    // must be set by a function reference (or by the caller for pattern
    // bindings) before allCheckPassed() can succeed.
    bool ConditionUseOnlyAllowedFunctions = false;
    // Name of the first variable/parameter seen; later ones must match it.
    StringRef ExpectName;

    MacroWalking getMacroWalkingBehavior() const override {
      return MacroWalking::Arguments;
    }

    PostWalkResult<Expr *> walkToExprPost(Expr *E) override {
      // Only declaration references are of interest.
      if (E->getKind() != ExprKind::DeclRef)
        return Action::Continue(E);
      auto D = dyn_cast<DeclRefExpr>(E)->getDecl();
      if (D->getKind() == DeclKind::Var || D->getKind() == DeclKind::Param)
        ParamsUseSameVars = checkName(dyn_cast<VarDecl>(D));
      if (D->getKind() == DeclKind::Func)
        ConditionUseOnlyAllowedFunctions = checkName(dyn_cast<FuncDecl>(D));
      if (allCheckPassed())
        return Action::Continue(E);

      // Stop the walk as soon as either flag is false.
      return Action::Stop();
    }

    bool allCheckPassed() {
      return ParamsUseSameVars && ConditionUseOnlyAllowedFunctions;
    }

  private:
    // Remembers the first variable name and reports whether VD has that name.
    bool checkName(VarDecl *VD) {
      auto Name = VD->getName().str();
      if (ExpectName.empty())
        ExpectName = Name;
      return Name == ExpectName;
    }

    // Reports whether FD's base name is one of the allowed operator names.
    bool checkName(FuncDecl *FD) {
      const auto Name = FD->getBaseIdentifier().str();
      return Name == "~="
      || Name == "=="
      || Name == "__derived_enum_equals"
      || Name == "__derived_struct_equals"
      || Name == "||"
      || Name == "...";
    }
  };

  // Locates the 'if' statement in the range and runs the checker over every
  // condition of the if / else-if chain.
  class SwitchConvertable {
  public:
    SwitchConvertable(const ResolvedRangeInfo &Info) : Info(Info) { }

    bool isApplicable() {
      if (Info.Kind != RangeKind::SingleStatement)
        return false;
      if (!findIfStmt())
        return false;
      return checkEachCondition();
    }

  private:
    const ResolvedRangeInfo &Info;
    // The 'if' currently being examined; advanced along the else-if chain.
    IfStmt *If = nullptr;
    ConditionalChecker checker;

    // Sets If to the sole contained node when that node is an IfStmt.
    bool findIfStmt() {
      if (Info.ContainedNodes.size() != 1)
        return false;
      if (auto S = Info.ContainedNodes.front().dyn_cast<Stmt*>())
        If = dyn_cast<IfStmt>(S);
      return If != nullptr;
    }

    // Checks the conditions of the current 'if' and then of each 'else if'
    // that follows it; a non-IfStmt else branch ends the chain.
    bool checkEachCondition() {
      checker = ConditionalChecker();
      do {
        if (!checkEachElement())
          return false;
      } while ((If = dyn_cast_or_null<IfStmt>(If->getElseStmt())));
      return true;
    }

    // Checks every condition element of the current 'if'. All elements are
    // visited even after one has failed.
    bool checkEachElement() {
      bool result = true;
      auto ConditionalList = If->getCond();
      for (auto Element : ConditionalList) {
        result &= check(Element);
      }
      return result;
    }

    bool check(StmtConditionElement ConditionElement) {
      // Availability conditions are never accepted.
      if (ConditionElement.getKind() == StmtConditionElement::CK_Availability)
        return false;
      // A pattern binding does not need an allowed function reference to
      // pass, so the function flag is pre-set before walking it.
      if (ConditionElement.getKind() == StmtConditionElement::CK_PatternBinding)
        checker.ConditionUseOnlyAllowedFunctions = true;
      ConditionElement.walk(checker);
      return checker.allCheckPassed();
    }
  };
  return SwitchConvertable(Info).isApplicable();
}

/// Rewrites the selected if / else-if chain as a switch statement and hands
/// the replacement for the whole content range to the edit consumer.
///
/// The switch subject is the name of the first variable or parameter
/// referenced in the first condition. Each 'if' in the chain becomes a
/// `case`, and the final 'else' body (or `break` when there is none) becomes
/// the `default` case.
///
/// \returns false once the replacement has been emitted.
bool RefactoringActionConvertToSwitchStmt::performChange() {

  // Records the name of the first variable/parameter reference reached by the
  // post-order walk, then stops the walk.
  class VarNameFinder : public ASTWalker {
  public:
    std::string VarName;

    MacroWalking getMacroWalkingBehavior() const override {
      return MacroWalking::Arguments;
    }

    PostWalkResult<Expr *> walkToExprPost(Expr *E) override {
      if (E->getKind() != ExprKind::DeclRef)
        return Action::Continue(E);
      auto D = dyn_cast<DeclRefExpr>(E)->getDecl();
      if (D->getKind() != DeclKind::Var && D->getKind() != DeclKind::Param)
        return Action::Continue(E);
      VarName = dyn_cast<VarDecl>(D)->getName().str().str();
      return Action::Stop();
    }
  };

  // Builds the text of a `case` pattern from a condition element.
  class ConditionalPatternFinder : public ASTWalker {
  public:
    ConditionalPatternFinder(SourceManager &SM) : SM(SM) {}

    // The accumulated pattern text; multiple patterns are joined with ", ".
    SmallString<64> ConditionalPattern = SmallString<64>();

    MacroWalking getMacroWalkingBehavior() const override {
      return MacroWalking::Arguments;
    }

    // For a binary expression whose operator is allowed, append the operand
    // that serves as the pattern.
    PostWalkResult<Expr *> walkToExprPost(Expr *E) override {
      auto *BE = dyn_cast<BinaryExpr>(E);
      if (!BE)
        return Action::Continue(E);
      if (isFunctionNameAllowed(BE))
        appendPattern(BE->getLHS(), BE->getRHS());
      return Action::Continue(E);
    }

    // A pattern is copied from source as-is, with "?" appended for an
    // OptionalSome pattern; the walk stops afterwards.
    PreWalkResult<Pattern *> walkToPatternPre(Pattern *P) override {
      ConditionalPattern.append(Lexer::getCharSourceRangeFromSourceRange(SM, P->getSourceRange()).str());
      if (P->getKind() == PatternKind::OptionalSome)
        ConditionalPattern.append("?");
      return Action::Stop();
    }

  private:

    SourceManager &SM;

    // True when the binary expression's callee (looking through a
    // DotSyntaxCallExpr) refers to a function whose base identifier is one of
    // Id_MatchOperator, Id_EqualsOperator, Id_derived_enum_equals or
    // Id_derived_struct_equals.
    bool isFunctionNameAllowed(BinaryExpr *E) {
      Expr *Fn = E->getFn();
      if (auto DotSyntaxCall = dyn_cast_or_null<DotSyntaxCallExpr>(Fn)) {
        Fn = DotSyntaxCall->getFn();
      }
      DeclRefExpr *DeclRef = dyn_cast_or_null<DeclRefExpr>(Fn);
      if (!DeclRef) {
        return false;
      }
      auto FunctionDeclaration = dyn_cast_or_null<FuncDecl>(DeclRef->getDecl());
      if (!FunctionDeclaration) {
        return false;
      }
      auto &ASTCtx = FunctionDeclaration->getASTContext();
      const auto FunctionName = FunctionDeclaration->getBaseIdentifier();
      return FunctionName == ASTCtx.Id_MatchOperator ||
             FunctionName == ASTCtx.Id_EqualsOperator ||
             FunctionName == ASTCtx.Id_derived_enum_equals ||
             FunctionName == ASTCtx.Id_derived_struct_equals;
    }

    // Appends the source text of the pattern operand: the right-hand side,
    // unless that is a declaration reference, in which case the left-hand
    // side is used instead.
    void appendPattern(Expr *LHS, Expr *RHS) {
      auto *PatternArgument = RHS;
      if (PatternArgument->getKind() == ExprKind::DeclRef)
        PatternArgument = LHS;
      if (ConditionalPattern.size() > 0)
        ConditionalPattern.append(", ");
      ConditionalPattern.append(Lexer::getCharSourceRangeFromSourceRange(SM, PatternArgument->getSourceRange()).str());
    }
  };

  // Gathers the pieces of the switch statement from the if / else-if chain and
  // prints them.
  class ConverterToSwitch {
  public:
    ConverterToSwitch(const ResolvedRangeInfo &Info,
                      SourceManager &SM) : Info(Info), SM(SM) { }

    // Appends the text of the complete switch statement to Out. The steps
    // depend on each other's side effects on If / PreviousIf, so their order
    // matters.
    void performConvert(SmallString<64> &Out) {
      If = findIf();
      OptionalLabel = If->getLabelInfo().Name.str().str();
      ControlExpression = findControlExpression();
      findPatternsAndBodies(PatternsAndBodies);
      DefaultStatements = findDefaultStatements();
      makeSwitchStatement(Out);
    }

  private:
    const ResolvedRangeInfo &Info;
    SourceManager &SM;

    // The 'if' currently being processed; null once the chain is exhausted.
    IfStmt *If;
    // The last 'if' of the chain, set by findPatternsAndBodies().
    IfStmt *PreviousIf;

    std::string OptionalLabel;
    std::string ControlExpression;
    // (case pattern, case body) for every 'if' in the chain, in order.
    SmallVector<std::pair<std::string, std::string>, 16> PatternsAndBodies;
    std::string DefaultStatements;

    // The first contained node, cast to IfStmt.
    IfStmt *findIf() {
      auto S = Info.ContainedNodes[0].dyn_cast<Stmt*>();
      return dyn_cast<IfStmt>(S);
    }

    // The switch subject: the variable name found in the first condition
    // element of the first 'if'.
    std::string findControlExpression() {
      auto ConditionElement = If->getCond().front();
      auto Finder = VarNameFinder();
      ConditionElement.walk(Finder);
      return Finder.VarName;
    }

    // Walks the else-if chain, recording one pattern/body pair per 'if'.
    // Leaves If null and PreviousIf pointing at the last 'if' of the chain.
    void findPatternsAndBodies(SmallVectorImpl<std::pair<std::string, std::string>> &Out) {
      do {
        auto pattern = findPattern();
        auto body = findBodyStatements();
        Out.push_back(std::make_pair(pattern, body));
        PreviousIf = If;
      } while ((If = dyn_cast_or_null<IfStmt>(If->getElseStmt())));
    }

    // Only the first condition element of the current 'if' is used.
    std::string findPattern() {
      auto ConditionElement = If->getCond().front();
      auto Finder = ConditionalPatternFinder(SM);
      ConditionElement.walk(Finder);
      return Finder.ConditionalPattern.str().str();
    }

    std::string findBodyStatements() {
      return findBodyWithoutBraces(If->getThenStmt());
    }

    // The body of the final 'else' block, or `break` when the last 'if' has no
    // brace-statement else branch.
    std::string findDefaultStatements() {
      auto ElseBody = dyn_cast_or_null<BraceStmt>(PreviousIf->getElseStmt());
      if (!ElseBody)
        return getTokenText(tok::kw_break).str();
      return findBodyWithoutBraces(ElseBody);
    }

    // Source text of a statement without its surrounding braces: the text from
    // the first through the last element of a brace statement, `break` for an
    // empty brace statement, or the whole statement if it is not a brace
    // statement.
    std::string findBodyWithoutBraces(Stmt *body) {
      auto BS = dyn_cast<BraceStmt>(body);
      if (!BS)
        return Lexer::getCharSourceRangeFromSourceRange(SM, body->getSourceRange()).str().str();
      if (BS->getElements().empty())
        return getTokenText(tok::kw_break).str();
      SourceRange BodyRange = BS->getElements().front().getSourceRange();
      BodyRange.widen(BS->getElements().back().getSourceRange());
      return Lexer::getCharSourceRangeFromSourceRange(SM, BodyRange).str().str();
    }

    // Prints `[label: ]switch <subject> {`, one `case` per collected pair, the
    // `default` case and the closing brace.
    void makeSwitchStatement(SmallString<64> &Out) {
      StringRef Space = " ";
      StringRef NewLine = "\n";
      llvm::raw_svector_ostream OS(Out);
      if (OptionalLabel.size() > 0)
        OS << OptionalLabel << ":" << Space;
      OS << tok::kw_switch << Space << ControlExpression << Space << tok::l_brace << NewLine;
      for (auto &pair : PatternsAndBodies) {
        OS << tok::kw_case << Space << pair.first << tok::colon << NewLine;
        OS << pair.second << NewLine;
      }
      OS << tok::kw_default << tok::colon << NewLine;
      OS << DefaultStatements << NewLine;
      OS << tok::r_brace;
    }

  };

  SmallString<64> result;
  ConverterToSwitch(RangeInfo, SM).performConvert(result);
  EditConsumer.accept(SM, RangeInfo.ContentRange, result.str());
  return false;
}

/// Describes an IfStmt of the form `if cond { x = a } else { x = b }` that
/// can be converted into `x = cond ? a : b`.
struct ConvertToTernaryExprInfo {
  ConvertToTernaryExprInfo() {}

  /// The destination shared by both assignments, or null when the two
  /// branches do not assign to the same thing.
  ///
  /// Destinations match when they are both declaration references whose
  /// declarations have equal names, or both tuple expressions with equal
  /// element names. Every other kind of destination is rejected.
  Expr *AssignDest() {
    if (!Then || !Else)
      return nullptr;

    Expr *ThenDest = Then->getDest();
    Expr *ElseDest = Else->getDest();
    if (!ThenDest || !ElseDest)
      return nullptr;

    if (ThenDest->getKind() != ElseDest->getKind())
      return nullptr;

    switch (ThenDest->getKind()) {
      case ExprKind::DeclRef: {
        auto *ThenRef = dyn_cast<DeclRefExpr>(ThenDest);
        auto *ElseRef = dyn_cast<DeclRefExpr>(ElseDest);

        if (!ThenRef || !ThenRef->getDecl() || !ElseRef || !ElseRef->getDecl())
          return nullptr;

        const bool SameName = ThenRef->getDecl()->getName().compare(
                                  ElseRef->getDecl()->getName()) == 0;
        return SameName ? ThenDest : nullptr;
      }
      case ExprKind::Tuple: {
        auto *ThenTuple = dyn_cast<TupleExpr>(ThenDest);
        auto *ElseTuple = dyn_cast<TupleExpr>(ElseDest);

        if (!ThenTuple || !ElseTuple)
          return nullptr;

        if (!ThenTuple->getElementNames().equals(ElseTuple->getElementNames()))
          return nullptr;

        return ThenTuple;
      }
      default:
        return nullptr;
    }
  }

  /// Value assigned in the 'then' branch, or null without a 'then' assignment.
  Expr *ThenSrc() {
    return Then ? Then->getSrc() : nullptr;
  }

  /// Value assigned in the 'else' branch, or null without an 'else'
  /// assignment.
  Expr *ElseSrc() {
    return Else ? Else->getSrc() : nullptr;
  }

  /// True when the condition, a common destination, both sources and a valid
  /// 'if' range are all present.
  bool isValid() {
    return Cond && AssignDest() && ThenSrc() && ElseSrc() && IfRange.isValid();
  }

  PatternBindingDecl *Binding = nullptr; //optional

  Expr *Cond = nullptr; //required
  AssignExpr *Then = nullptr; //required
  AssignExpr *Else = nullptr; //required
  SourceRange IfRange;
};

/// Extracts the pieces of an 'if' statement that may be convertible to a
/// ternary expression.
///
/// The range must be a statement range that contains either a lone 'if'
/// statement or a node followed by an 'if' statement; in the latter case a
/// leading pattern binding declaration is recorded as the binding. The 'if'
/// must have exactly one condition.
///
/// \returns the collected info, or a default-constructed (invalid) info when
/// the range does not have that shape. The result still has to be checked
/// with isValid().
ConvertToTernaryExprInfo
findConvertToTernaryExpression(const ResolvedRangeInfo &Info) {

  const bool IsStatementRange = Info.Kind == RangeKind::SingleStatement ||
                                Info.Kind == RangeKind::MultiStatement;
  if (!IsStatementRange || Info.ContainedNodes.empty())
    return ConvertToTernaryExprInfo();

  // Walks a statement on construction and records whether the expression
  // handed to walkToExprPre is an AssignExpr; children of visited
  // expressions are skipped.
  struct AssignExprFinder: public SourceEntityWalker {

    AssignExpr *Assign = nullptr;

    AssignExprFinder(Stmt* S) {
      if (S)
        walk(S);
    }

    virtual bool walkToExprPre(Expr *E) override {
      Assign = dyn_cast<AssignExpr>(E);
      return false;
    }
  };

  ConvertToTernaryExprInfo Result;
  IfStmt *If = nullptr;

  const auto &Nodes = Info.ContainedNodes;
  switch (Nodes.size()) {
  case 1:
    // Just the 'if' statement.
    If = dyn_cast_or_null<IfStmt>(Nodes[0].dyn_cast<Stmt*>());
    break;
  case 2:
    // A possible binding declaration followed by the 'if' statement.
    Result.Binding =
        dyn_cast_or_null<PatternBindingDecl>(Nodes[0].dyn_cast<Decl*>());
    If = dyn_cast_or_null<IfStmt>(Nodes[1].dyn_cast<Stmt*>());
    break;
  default:
    break;
  }

  if (!If)
    return ConvertToTernaryExprInfo();

  auto Conditions = If->getCond();
  if (Conditions.size() != 1)
    return ConvertToTernaryExprInfo();

  Result.Cond = Conditions[0].getBooleanOrNull();
  Result.IfRange = If->getSourceRange();

  Result.Then = AssignExprFinder(If->getThenStmt()).Assign;
  Result.Else = AssignExprFinder(If->getElseStmt()).Assign;

  return Result;
}

/// Applicable when the range yields a valid ConvertToTernaryExprInfo.
bool RefactoringActionConvertToTernaryExpr::
isApplicable(const ResolvedRangeInfo &Info, DiagnosticEngine &Diag) {
  auto Candidate = findConvertToTernaryExpression(Info);
  return Candidate.isValid();
}

bool RefactoringActionConvertToTernaryExpr::performChange() {
  auto Target = findConvertToTernaryExpression(RangeInfo);

  if (!Target.isValid())
    return true; //abort

  SmallString<64> DeclBuffer;
  llvm::raw_svector_ostream OS(DeclBuffer);

  StringRef Space = " ";

  auto IfRange = Target.IfRange;
  auto ReplaceRange = Lexer::getCharSourceRangeFromSourceRange(SM, IfRange);

  auto CondRange = Target.Cond->getSourceRange();
  auto CondCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, CondRange);

  auto ThenRange = Target.ThenSrc()->getSourceRange();
  auto ThenCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, ThenRange);

  auto ElseRange = Target.ElseSrc()->getSourceRange();
  auto ElseCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, ElseRange);

  CharSourceRange DestCharRange;

  if (Target.Binding) {
    auto DestRange = Target.Binding->getSourceRange();
    DestCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, DestRange);
    ReplaceRange.widen(DestCharRange);
  } else {
    auto DestRange = Target.AssignDest()->getSourceRange();
    DestCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, DestRange);
  }

  OS << DestCharRange.str() << Space << tok::equal << Space;
  OS << CondCharRange.str() << Space << tok::question_postfix << Space;
  OS << ThenCharRange.str() << Space << tok::colon << Space;
  OS << ElseCharRange.str();

  EditConsumer.accept(SM, ReplaceRange, DeclBuffer.str());

  return false; //don't abort
}

/// The helper class analyzes a given nominal decl or an extension decl to
/// decide whether stubs are required to filled in and the context in which
/// these stubs should be filled.
class FillProtocolStubContext {

  std::vector<ValueDecl*>
  getUnsatisfiedRequirements(const IterableDeclContext *IDC);

  /// Context in which the content should be filled; this could be either a
  /// nominal type declaraion or an extension declaration.
  DeclContext *DC;

  /// The type that adopts the required protocol stubs. For nominal type decl, this
  /// should be the declared type itself; for extension decl, this should be the
  /// extended type at hand.
  Type Adopter;

  /// The start location of the decl, either nominal type or extension, for the
  /// printer to figure out the right indentation.
  SourceLoc StartLoc;

  /// The location of '{' for the decl, thus we know where to insert the filling
  /// stubs.
  SourceLoc BraceStartLoc;

  /// The value decls that should be satisfied; this could be either function
  /// decls, property decls, or required type alias.
  std::vector<ValueDecl*> FillingContents;

public:
  FillProtocolStubContext(ExtensionDecl *ED) : DC(ED),
    Adopter(ED->getExtendedType()), StartLoc(ED->getStartLoc()),
    BraceStartLoc(ED->getBraces().Start),
    FillingContents(getUnsatisfiedRequirements(ED)) {};

  FillProtocolStubContext(NominalTypeDecl *ND) : DC(ND),
    Adopter(ND->getDeclaredType()), StartLoc(ND->getStartLoc()),
    BraceStartLoc(ND->getBraces().Start),
    FillingContents(getUnsatisfiedRequirements(ND)) {};

  FillProtocolStubContext() : DC(nullptr), Adopter(), FillingContents({}) {};

  static FillProtocolStubContext
  getContextFromCursorInfo(ResolvedCursorInfoPtr Tok);

  ArrayRef<ValueDecl*> getFillingContents() const {
    return llvm::makeArrayRef(FillingContents);
  }

  DeclContext *getFillingContext() const { return DC; }

  bool canProceed() const {
    return StartLoc.isValid() && BraceStartLoc.isValid() &&
      !getFillingContents().empty();
  }

  Type getAdopter() const { return Adopter; }
  SourceLoc getContextStartLoc() const { return StartLoc; }
  SourceLoc getBraceStartLoc() const { return BraceStartLoc; }
};

/// Build a stub-filling context from the cursor position described by
/// \p CursorInfo. An empty context is returned unless the cursor is on the
/// name of a nominal type declaration or on the type reference of an
/// extension declaration.
FillProtocolStubContext FillProtocolStubContext::getContextFromCursorInfo(
    ResolvedCursorInfoPtr CursorInfo) {
  if (!CursorInfo->isValid())
    return FillProtocolStubContext();

  auto ValueRefInfo = dyn_cast<ResolvedValueRefCursorInfo>(CursorInfo);
  if (!ValueRefInfo)
    return FillProtocolStubContext();

  if (ValueRefInfo->isRef()) {
    // The type ref is on a declared extension, e.g. "extension A {}"
    if (auto *ED = ValueRefInfo->getExtTyRef())
      return FillProtocolStubContext(ED);
    return FillProtocolStubContext();
  }

  // The type name is on the declared nominal, e.g. "class A {}"
  if (auto *ND = dyn_cast<NominalTypeDecl>(ValueRefInfo->getValueD()))
    return FillProtocolStubContext(ND);

  return FillProtocolStubContext();
}

/// Gather every requirement that the local conformances of \p IDC report as
/// non-witnessed, in the order the conformances report them.
std::vector<ValueDecl*> FillProtocolStubContext::
getUnsatisfiedRequirements(const IterableDeclContext *IDC) {
  std::vector<ValueDecl*> Missing;
  auto Collect = [&Missing](ValueDecl *Req) { Missing.push_back(Req); };

  // Visit each local conformance and append its non-witnessed requirements.
  for (ProtocolConformance *Conformance : IDC->getLocalConformances())
    Conformance->forEachNonWitnessedRequirement(Collect);

  return Missing;
}

/// The refactoring applies when a stub-filling context can be built from
/// \p Tok and that context has something to fill in.
bool RefactoringActionFillProtocolStub::isApplicable(ResolvedCursorInfoPtr Tok,
                                                     DiagnosticEngine &Diag) {
  auto Context = FillProtocolStubContext::getContextFromCursorInfo(Tok);
  return Context.canProceed();
}

bool RefactoringActionFillProtocolStub::performChange() {
  // Get the filling protocol context from the input token.
  FillProtocolStubContext Context = FillProtocolStubContext::
    getContextFromCursorInfo(CursorInfo);

  assert(Context.canProceed());
  assert(!Context.getFillingContents().empty());
  assert(Context.getFillingContext());
  SmallString<128> Text;
  {
    llvm::raw_svector_ostream SS(Text);
    Type Adopter = Context.getAdopter();
    SourceLoc Loc = Context.getContextStartLoc();
    auto Contents = Context.getFillingContents();

    // For each unsatisfied requirement, print the stub to the buffer.
    std::for_each(Contents.begin(), Contents.end(), [&](ValueDecl *VD) {
      printRequirementStub(VD, Context.getFillingContext(), Adopter, Loc, SS);
    });
  }

  // Insert all stubs after '{' in the extension/nominal type decl.
  EditConsumer.insertAfter(SM, Context.getBraceStartLoc(), Text);
  return false;
}

/// Resolve the cursor at \p Line / \p Column in \p SF and append the
/// refactorings available there (rename included) to \p Kinds. Nothing is
/// appended when the position does not map to a valid source location.
static void collectAvailableRefactoringsAtCursor(
    SourceFile *SF, unsigned Line, unsigned Column,
    SmallVectorImpl<RefactoringKind> &Kinds,
    ArrayRef<DiagnosticConsumer *> DiagConsumers) {
  // Prepare the tool box.
  ASTContext &Ctx = SF->getASTContext();
  SourceManager &SM = Ctx.SourceMgr;
  DiagnosticEngine DiagEngine(SM);
  for (DiagnosticConsumer *Consumer : DiagConsumers)
    DiagEngine.addConsumer(*Consumer);

  SourceLoc Loc = SM.getLocForLineCol(SF->getBufferID().value(), Line, Column);
  if (Loc.isInvalid())
    return;

  // Resolve the cursor info at the start of the token under the location.
  SourceLoc TokenStart = Lexer::getLocForStartOfToken(SM, Loc);
  ResolvedCursorInfoPtr Tok = evaluateOrDefault(
      Ctx.evaluator, CursorInfoRequest{CursorInfoOwner(SF, TokenStart)},
      new ResolvedCursorInfo());

  collectAvailableRefactorings(Tok, Kinds, /*Exclude rename*/ false);
}

/// Return the enum declaration that the subject expression of \p SwitchS is
/// typed as, or nullptr when the subject has no type or its nominal type is
/// not an enum.
static EnumDecl* getEnumDeclFromSwitchStmt(SwitchStmt *SwitchS) {
  auto SubjectTy = SwitchS->getSubjectExpr()->getType();
  if (!SubjectTy)
    return nullptr;

  // FIXME: Support more complex subject like '(Enum1, Enum2)'.
  return dyn_cast_or_null<EnumDecl>(SubjectTy->getAnyNominal());
}

/// Print a `case` for every element of the switched-over enum that no
/// non-default case of \p SwitchS already names.
///
/// \param ExpandedStmtLoc location used for the diagnostic that is emitted
/// when there is nothing left to expand.
/// \returns true (after diagnosing) when all elements are already handled,
/// false when cases were printed to \p OS.
static bool performCasesExpansionInSwitchStmt(SwitchStmt *SwitchS,
                                              DiagnosticEngine &DiagEngine,
                                              SourceLoc ExpandedStmtLoc,
                                              EditorConsumerInsertStream &OS
                                              ) {
  auto *Enum = getEnumDeclFromSwitchStmt(SwitchS);
  assert(Enum);

  // Start from the assumption that no element is handled, then remove every
  // element that is matched by an explicit case label.
  llvm::DenseSet<EnumElementDecl*> Remaining;
  Enum->getAllElements(Remaining);

  for (auto Case : SwitchS->getCases()) {
    // The default case does not name any element.
    if (!Case->isDefault()) {
      for (auto LabelItem : Case->getCaseLabelItems()) {
        auto *ElementPattern =
            dyn_cast_or_null<EnumElementPattern>(LabelItem.getPattern());
        if (ElementPattern)
          Remaining.erase(ElementPattern->getElementDecl());
      }
    }
  }

  // If all enum elements are handled in the switch statement, issue error.
  if (Remaining.empty()) {
    DiagEngine.diagnose(ExpandedStmtLoc, diag::no_remaining_cases);
    return true;
  }

  printEnumElementsAsCases(Remaining, OS);
  return false;
}

// Finds the SwitchStmt that contains the given CaseStmt. Diagnoses
// `no_parent_switch` and returns nullptr when no enclosing switch is found or
// when the switch that was found does not list the case among its cases.
static SwitchStmt* findEnclosingSwitchStmt(CaseStmt *CS,
                                           SourceFile *SF,
                                           DiagnosticEngine &DiagEngine) {
  auto IsSwitch = [](ASTNode Node) {
    return Node.isStmt(StmtKind::Switch);
  };
  ContextFinder Finder(*SF, CS, IsSwitch);
  Finder.resolve();

  auto reportNoParent = [&]() -> SwitchStmt * {
    DiagEngine.diagnose(CS->getStartLoc(), diag::no_parent_switch);
    return nullptr;
  };

  // If failed to find the switch statement, issue error.
  if (Finder.getContexts().empty())
    return reportNoParent();

  // The last recorded context is the candidate switch.
  Stmt *Candidate = Finder.getContexts().back().get<Stmt*>();
  auto *SwitchS = static_cast<SwitchStmt*>(Candidate);

  // Make sure that CaseStmt is included in switch that was found.
  auto Cases = SwitchS->getCases();
  if (std::find(Cases.begin(), Cases.end(), CS) == Cases.end())
    return reportNoParent();

  return SwitchS;
}

/// The refactoring applies when the cursor is at the start of a `default`
/// case whose enclosing switch statement switches over an enum.
///
/// `invalid_default_location` is diagnosed when the cursor is not at the
/// start of a case statement; a missing enclosing switch is diagnosed by
/// `findEnclosingSwitchStmt`.
bool RefactoringActionExpandDefault::isApplicable(
    ResolvedCursorInfoPtr CursorInfo, DiagnosticEngine &Diag) {
  auto Exit = [&](bool Applicable) {
    if (!Applicable)
      Diag.diagnose(SourceLoc(), diag::invalid_default_location);
    return Applicable;
  };
  auto StmtStartInfo = dyn_cast<ResolvedStmtStartCursorInfo>(CursorInfo);
  if (!StmtStartInfo)
    return Exit(false);
  // The trailing statement is treated as possibly null here, mirroring the
  // explicit null check in RefactoringActionExpandSwitchCases::isApplicable;
  // plain dyn_cast must not be handed a null pointer.
  if (auto *CS =
          dyn_cast_or_null<CaseStmt>(StmtStartInfo->getTrailingStmt())) {
    auto EnclosingSwitchStmt =
        findEnclosingSwitchStmt(CS, CursorInfo->getSourceFile(), Diag);
    if (!EnclosingSwitchStmt)
      return false;
    auto EnumD = getEnumDeclFromSwitchStmt(EnclosingSwitchStmt);
    auto IsApplicable = CS->isDefault() && EnumD != nullptr;
    return IsApplicable;
  }
  return Exit(false);
}

bool RefactoringActionExpandDefault::performChange() {
  // If we've not seen the default statement inside the switch statement, issue
  // error.
  auto StmtStartInfo = cast<ResolvedStmtStartCursorInfo>(CursorInfo);
  auto *CS = static_cast<CaseStmt *>(StmtStartInfo->getTrailingStmt());
  auto *SwitchS = findEnclosingSwitchStmt(CS, TheFile, DiagEngine);
  assert(SwitchS);
  EditorConsumerInsertStream OS(EditConsumer, SM,
                                Lexer::getCharSourceRangeFromSourceRange(SM,
                                  CS->getLabelItemsRange()));
  return performCasesExpansionInSwitchStmt(SwitchS,
                                           DiagEngine,
                                           CS->getStartLoc(),
                                           OS);
}

/// The refactoring applies when the cursor is at the start of a switch
/// statement whose subject is typed as an enum.
bool RefactoringActionExpandSwitchCases::isApplicable(
    ResolvedCursorInfoPtr CursorInfo, DiagnosticEngine &DiagEngine) {
  auto StmtStartInfo = dyn_cast<ResolvedStmtStartCursorInfo>(CursorInfo);
  if (!StmtStartInfo)
    return false;

  auto *Trailing = StmtStartInfo->getTrailingStmt();
  if (!Trailing)
    return false;

  auto *Switch = dyn_cast<SwitchStmt>(Trailing);
  return Switch && getEnumDeclFromSwitchStmt(Switch) != nullptr;
}

bool RefactoringActionExpandSwitchCases::performChange() {
  auto StmtStartInfo = cast<ResolvedStmtStartCursorInfo>(CursorInfo);
  auto *SwitchS = dyn_cast<SwitchStmt>(StmtStartInfo->getTrailingStmt());
  assert(SwitchS);

  auto InsertRange = CharSourceRange();
  auto Cases = SwitchS->getCases();
  auto Default = std::find_if(Cases.begin(), Cases.end(), [](CaseStmt *Stmt) {
    return Stmt->isDefault();
  });
  if (Default != Cases.end()) {
    auto DefaultRange = (*Default)->getLabelItemsRange();
    InsertRange = Lexer::getCharSourceRangeFromSourceRange(SM, DefaultRange);
  } else {
    auto RBraceLoc = SwitchS->getRBraceLoc();
    InsertRange = CharSourceRange(SM, RBraceLoc, RBraceLoc);
  }
  EditorConsumerInsertStream OS(EditConsumer, SM, InsertRange);
  if (SM.getLineAndColumnInBuffer(SwitchS->getLBraceLoc()).first ==
      SM.getLineAndColumnInBuffer(SwitchS->getRBraceLoc()).first) {
    OS << "\n";
  }
  auto Result = performCasesExpansionInSwitchStmt(SwitchS,
                                           DiagEngine,
                                           SwitchS->getStartLoc(),
                                           OS);
  return Result;
}

/// Find the string literal that starts exactly at the expression under the
/// cursor described by \p CursorInfo.
///
/// \returns the string literal expression, or nullptr when the cursor is not
/// at the start of an expression, there is no trailing expression, or the
/// expression starting there is not a plain (non-interpolated) string literal.
static Expr *findLocalizeTarget(ResolvedCursorInfoPtr CursorInfo) {
  auto ExprStartInfo = dyn_cast<ResolvedExprStartCursorInfo>(CursorInfo);
  if (!ExprStartInfo)
    return nullptr;

  // Other users of the trailing expression in this file treat it as possibly
  // null; bail out instead of dereferencing a null pointer below.
  Expr *Trailing = ExprStartInfo->getTrailingExpr();
  if (!Trailing)
    return nullptr;

  struct StringLiteralFinder: public SourceEntityWalker {
    SourceLoc StartLoc;
    Expr *Target;
    StringLiteralFinder(SourceLoc StartLoc): StartLoc(StartLoc), Target(nullptr) {}
    bool walkToExprPre(Expr *E) override {
      // Only expressions that begin at the cursor location are of interest.
      if (E->getStartLoc() != StartLoc)
        return false;
      // Interpolated string literals are not localized by this refactoring.
      if (E->getKind() == ExprKind::InterpolatedStringLiteral)
        return false;
      if (E->getKind() == ExprKind::StringLiteral) {
        Target = E;
        return false;
      }
      return true;
    }
  } Walker(Trailing->getStartLoc());
  Walker.walk(Trailing);
  return Walker.Target;
}

/// The refactoring applies when a string literal to localize is found at the
/// cursor.
bool RefactoringActionLocalizeString::isApplicable(ResolvedCursorInfoPtr Tok,
                                                   DiagnosticEngine &Diag) {
  Expr *Target = findLocalizeTarget(Tok);
  return Target != nullptr;
}

bool RefactoringActionLocalizeString::performChange() {
  Expr* Target = findLocalizeTarget(CursorInfo);
   if (!Target)
    return true;
  EditConsumer.accept(SM, Target->getStartLoc(), "NSLocalizedString(");
  EditConsumer.insertAfter(SM, Target->getEndLoc(), ", comment: \"\")");
  return false;
}

struct MemberwiseParameter {
  CharSourceRange NameRange;
  Type MemberType;
  Expr *DefaultExpr;

  MemberwiseParameter(CharSourceRange nameRange, Type type, Expr *initialExpr)
      : NameRange(nameRange), MemberType(type), DefaultExpr(initialExpr) {}
};

/// Emit an `internal init(...)` at \p targetLocation that takes one parameter
/// per entry of \p memberVector and assigns each parameter to the property of
/// the same name.
///
/// The edit is delivered to \p EditConsumer in two pieces at the same
/// location: the "\ninternal init(" prefix, followed by the parameter list
/// and body.
static void generateMemberwiseInit(SourceEditConsumer &EditConsumer,
                                   SourceManager &SM,
                                   ArrayRef<MemberwiseParameter> memberVector,
                                   SourceLoc targetLocation) {

  EditConsumer.accept(SM, targetLocation, "\ninternal init(");

  std::string Buffer;
  llvm::raw_string_ostream OS(Buffer);

  // Print the parameter list, separating the parameters with commas.
  const size_t Count = memberVector.size();
  for (size_t Idx = 0; Idx != Count; ++Idx) {
    const MemberwiseParameter &Param = memberVector[Idx];

    OS << SM.extractText(Param.NameRange) << ": ";
    // Unconditionally print '@escaping' if we print out a function type -
    // the assignments we generate below will escape this parameter.
    if (isa<AnyFunctionType>(Param.MemberType->getCanonicalType()))
      OS << "@" << TypeAttributes::getAttrName(TAK_escaping) << " ";
    OS << Param.MemberType.getString();

    // Prefer the default expression as written in the source; otherwise fall
    // back to `nil` for optional members.
    bool PrintedDefault = false;
    if (auto *Default = Param.DefaultExpr) {
      if (Default->getSourceRange().isValid()) {
        auto DefaultRange = Lexer::getCharSourceRangeFromSourceRange(
            SM, Default->getSourceRange());
        OS << " = " << SM.extractText(DefaultRange);
        PrintedDefault = true;
      }
    }
    if (!PrintedDefault && Param.MemberType->isOptional())
      OS << " = nil";

    // Every parameter but the last one is followed by a separator.
    if (Idx + 1 != Count)
      OS << ", ";
  }

  // Synthesize the body.
  OS << ") {\n";
  for (const MemberwiseParameter &Param : memberVector) {
    // self.<property> = <property>
    auto PropertyName = SM.extractText(Param.NameRange);
    OS << "self." << PropertyName << " = " << PropertyName << "\n";
  }
  OS << "}\n";

  // Accept the entire edit.
  EditConsumer.accept(SM, targetLocation, OS.str());
}

/// Collect the members of the nominal type under the cursor that should
/// become parameters of a generated memberwise initializer.
///
/// \param memberVector receives one entry per collected member.
/// \returns the location just past the start of the type's braces, where the
/// initializer should be inserted, or an invalid location when the cursor is
/// not on the declaration of a nominal type with stored properties.
static SourceLoc
collectMembersForInit(ResolvedCursorInfoPtr CursorInfo,
                      SmallVectorImpl<MemberwiseParameter> &memberVector) {
  auto ValueRefInfo = dyn_cast<ResolvedValueRefCursorInfo>(CursorInfo);
  if (!ValueRefInfo || !ValueRefInfo->getValueD())
    return SourceLoc();

  auto *nominalDecl = dyn_cast<NominalTypeDecl>(ValueRefInfo->getValueD());
  if (!nominalDecl)
    return SourceLoc();
  if (nominalDecl->getStoredProperties().empty() || ValueRefInfo->isRef())
    return SourceLoc();

  const SourceLoc openBrace = nominalDecl->getBraces().Start;
  if (openBrace.isInvalid())
    return SourceLoc();

  const SourceLoc targetLocation = openBrace.getAdvancedLoc(1);
  if (targetLocation.isInvalid())
    return SourceLoc();

  SourceManager &SM = nominalDecl->getASTContext().SourceMgr;

  for (auto *member : nominalDecl->getMembers()) {
    auto *varDecl = dyn_cast<VarDecl>(member);
    if (!varDecl)
      continue;

    // Exclude lazy members from the memberwise initializer. This is
    // inconsistent with the implicitly synthesized memberwise initializer but
    // we think it makes more sense because otherwise the lazy variable's
    // initializer gets evaluated eagerly.
    if (varDecl->getAttrs().hasAttribute<LazyAttr>())
      continue;

    if (!varDecl->isMemberwiseInitialized(/*preferDeclaredProperties=*/true))
      continue;

    auto *patternBinding = varDecl->getParentPatternBinding();
    if (!patternBinding)
      continue;

    // Pick up the original initializer as the parameter's default value when
    // the binding has one available.
    const auto entryIndex =
        patternBinding->getPatternEntryIndexForVarDecl(varDecl);
    const bool hasUsableInit =
        patternBinding->isExplicitlyInitialized(entryIndex) ||
        patternBinding->isDefaultInitializable();
    Expr *defaultInit =
        hasUsableInit ? patternBinding->getOriginalInit(entryIndex) : nullptr;

    memberVector.emplace_back(
        Lexer::getCharSourceRangeFromSourceRange(SM, varDecl->getNameLoc()),
        varDecl->getType(), defaultInit);
  }

  return targetLocation;
}

/// The refactoring applies when an insertion location for the initializer
/// can be determined from \p Tok.
bool RefactoringActionMemberwiseInitLocalRefactoring::isApplicable(
    ResolvedCursorInfoPtr Tok, DiagnosticEngine &Diag) {
  // The collected members are not needed here, only the resulting location.
  SmallVector<MemberwiseParameter, 8> Scratch;
  SourceLoc InsertLoc = collectMembersForInit(Tok, Scratch);
  return InsertLoc.isValid();
}

bool RefactoringActionMemberwiseInitLocalRefactoring::performChange() {
  
  SmallVector<MemberwiseParameter, 8> memberVector;
  SourceLoc targetLocation = collectMembersForInit(CursorInfo, memberVector);
  if (targetLocation.isInvalid())
    return true;
  
  generateMemberwiseInit(EditConsumer, SM, memberVector, targetLocation);
  
  return false;
}

/// If \p NTD is a protocol, return all the protocols it inherits from. If it's
/// any other nominal type, return all the protocols it conforms to.
static SmallVector<ProtocolDecl *, 2> getAllProtocols(NominalTypeDecl *NTD) {
  auto *Proto = dyn_cast<ProtocolDecl>(NTD);
  if (!Proto)
    return NTD->getAllProtocols();

  auto Inherited = Proto->getInheritedProtocols();
  return SmallVector<ProtocolDecl *, 2>(Inherited.begin(), Inherited.end());
}

/// Gathers the information needed to add an explicit `Equatable` conformance
/// (the protocol name in the inheritance clause plus an implementation of the
/// protocol's requirement) to a nominal type or extension declaration.
class AddEquatableContext {

  /// Declaration context
  DeclContext *DC;

  /// Adopter type
  Type Adopter;

  /// Start location of declaration context brace
  SourceLoc StartLoc;

  /// Array of all inherited protocols' locations
  ArrayRef<InheritedEntry> ProtocolsLocations;

  /// Array of all conformed protocols
  SmallVector<swift::ProtocolDecl *, 2> Protocols;

  /// Start location of declaration,
  /// a place to write protocol name
  SourceLoc ProtInsertStartLoc;

  /// Stored properties of extending adopter
  ArrayRef<VarDecl *> StoredProperties;

  /// Range of internal members in declaration
  DeclRange Range;

  /// Whether `Equatable` is among the collected protocols.
  bool conformsToEquatableProtocol() {
    for (ProtocolDecl *Protocol : Protocols) {
      if (Protocol->getKnownProtocolKind() == KnownProtocolKind::Equatable) {
        return true;
      }
    }
    return false;
  }

  /// Whether the first collected `Equatable` requirement is a function that
  /// takes exactly two parameters.
  bool isRequirementValid() {
    auto Reqs = getProtocolRequirements();
    if (Reqs.empty()) {
      return false;
    }
    auto Req = dyn_cast<FuncDecl>(Reqs[0]);
    return Req && Req->getParameters()->size() == 2;
  }

  /// Whether there is at least one user-accessible stored property to compare.
  bool isPropertiesListValid() {
    return !getUserAccessibleProperties().empty();
  }

  void printFunctionBody(ASTPrinter &Printer, StringRef ExtraIndent,
  ParameterList *Params);

  std::vector<ValueDecl *> getProtocolRequirements();

  std::vector<VarDecl *> getUserAccessibleProperties();

public:
  /// Build the context for the nominal type \p Decl.
  AddEquatableContext(NominalTypeDecl *Decl)
      : DC(Decl), Adopter(Decl->getDeclaredType()),
        StartLoc(Decl->getBraces().Start),
        ProtocolsLocations(Decl->getInherited()),
        Protocols(getAllProtocols(Decl)),
        ProtInsertStartLoc(Decl->getNameLoc()),
        StoredProperties(Decl->getStoredProperties()),
        Range(Decl->getMembers()){};

  /// Build the context for the extension \p Decl. The extension is expected
  /// to have an extended nominal type.
  AddEquatableContext(ExtensionDecl *Decl)
      : DC(Decl), Adopter(Decl->getExtendedType()),
        StartLoc(Decl->getBraces().Start),
        ProtocolsLocations(Decl->getInherited()),
        Protocols(getAllProtocols(Decl->getExtendedNominal())),
        ProtInsertStartLoc(Decl->getExtendedTypeRepr()->getEndLoc()),
        StoredProperties(Decl->getExtendedNominal()->getStoredProperties()),
        Range(Decl->getMembers()){};

  /// Build an empty context; `isValid()` is false for it.
  AddEquatableContext() : DC(nullptr), Adopter(), ProtocolsLocations(),
  Protocols(), StoredProperties(), Range(nullptr, nullptr) {};

  static AddEquatableContext
  getDeclarationContextFromInfo(ResolvedCursorInfoPtr Info);

  std::string getInsertionTextForProtocol();

  std::string getInsertionTextForFunction(SourceManager &SM);

  bool isValid() {
    // FIXME: Allow to generate explicit == method for declarations which already have
    // compiler-generated == method
    return StartLoc.isValid() && ProtInsertStartLoc.isValid() &&
    !conformsToEquatableProtocol() && isPropertiesListValid() &&
    isRequirementValid();
  }

  /// Location of the token after which the protocol name is inserted: the
  /// last token of the last inherited entry when there is an inheritance
  /// clause, otherwise the declaration's name / extended type.
  SourceLoc getStartLocForProtocolDecl() {
    if (ProtocolsLocations.empty()) {
      return ProtInsertStartLoc;
    }
    // The caller inserts *after the token* at the returned location, so use
    // the end of the entry's range. Using its start would splice the new
    // protocol into the middle of a multi-token entry such as `A.B`.
    return ProtocolsLocations.back().getSourceRange().End;
  }

  bool isMembersRangeEmpty() {
    return Range.empty();
  }

  SourceLoc getInsertStartLoc();
};

/// Return the furthest location among the opening brace and the end locations
/// of the declaration's members; the generated function is inserted after it.
SourceLoc AddEquatableContext::
getInsertStartLoc() {
  SourceLoc Furthest = StartLoc;
  for (auto *Member : Range) {
    const SourceLoc MemberEnd = Member->getEndLoc();
    // Locations are ordered by comparing their opaque pointer values.
    if (MemberEnd.getOpaquePointerValue() > Furthest.getOpaquePointerValue())
      Furthest = MemberEnd;
  }
  return Furthest;
}

/// Return the text that adds `Equatable` to the declaration: ": <name>" when
/// there is no inheritance clause yet, ", <name>" otherwise.
std::string AddEquatableContext::
getInsertionTextForProtocol() {
  StringRef ProtocolName = getProtocolName(KnownProtocolKind::Equatable);
  StringRef Separator = ProtocolsLocations.empty() ? ": " : ", ";
  return (Separator + ProtocolName).str();
}

/// Return the text of an implementation of the first `Equatable` requirement,
/// indented to match the place where it is going to be inserted.
///
/// The context is expected to be valid, i.e. the first requirement must be a
/// function declaration.
std::string AddEquatableContext::
getInsertionTextForFunction(SourceManager &SM) {
  auto Reqs = getProtocolRequirements();
  auto Req = dyn_cast<FuncDecl>(Reqs[0]);
  auto Params = Req->getParameters();

  // When the declaration has no members yet, the insertion point is the
  // opening brace, so one extra level of indentation is needed.
  const bool NoMembers = isMembersRangeEmpty();
  StringRef ExtraIndent;
  StringRef CurrentIndent =
    Lexer::getIndentationForLine(SM, getInsertStartLoc(), &ExtraIndent);
  std::string Indent = NoMembers ? (CurrentIndent + ExtraIndent).str()
                                 : CurrentIndent.str();

  PrintOptions Options = PrintOptions::printVerbose();
  Options.PrintDocumentationComments = false;
  Options.setBaseType(Adopter);
  // Print a custom body that compares the properties of both parameters.
  Options.FunctionBody = [&](const ValueDecl *VD, ASTPrinter &Printer) {
    Printer << " {";
    Printer.printNewline();
    printFunctionBody(Printer, ExtraIndent, Params);
    Printer.printNewline();
    Printer << "}";
  };

  std::string Buffer;
  llvm::raw_string_ostream OS(Buffer);
  ExtraIndentStreamPrinter Printer(OS, Indent);
  Printer.printNewline();
  // Separate the function from the preceding member with an empty line.
  if (!NoMembers)
    Printer.printNewline();
  Reqs[0]->print(Printer, Options);
  return Buffer;
}

/// Return the stored properties for which `Decl::isUserAccessible()` holds,
/// in declaration order of the stored-property list.
std::vector<VarDecl *> AddEquatableContext::
getUserAccessibleProperties() {
  std::vector<VarDecl *> Accessible;
  for (VarDecl *Property : StoredProperties) {
    if (!Property->Decl::isUserAccessible())
      continue;
    Accessible.push_back(Property);
  }
  return Accessible;
}

/// Return the members of the `Equatable` protocol that are valid value
/// declarations and protocol requirements.
std::vector<ValueDecl *> AddEquatableContext::
getProtocolRequirements() {
  std::vector<ValueDecl *> Requirements;
  auto Proto = DC->getASTContext().getProtocol(KnownProtocolKind::Equatable);
  for (auto *Member : Proto->getMembers()) {
    if (auto *Req = dyn_cast<ValueDecl>(Member)) {
      if (!Req->isInvalid() && Req->isProtocolRequirement())
        Requirements.push_back(Req);
    }
  }
  return Requirements;
}

/// Build a context from the cursor position described by \p Info. An empty
/// context is returned unless the cursor is on the name of a nominal type
/// declaration or on the type reference of an extension that has an extended
/// nominal type.
AddEquatableContext
AddEquatableContext::getDeclarationContextFromInfo(ResolvedCursorInfoPtr Info) {
  auto ValueRefInfo = dyn_cast<ResolvedValueRefCursorInfo>(Info);
  if (!ValueRefInfo)
    return AddEquatableContext();

  if (ValueRefInfo->isRef()) {
    // A reference: only the extended type of an extension qualifies.
    auto *ExtDecl = ValueRefInfo->getExtTyRef();
    if (ExtDecl && ExtDecl->getExtendedNominal())
      return AddEquatableContext(ExtDecl);
    return AddEquatableContext();
  }

  // A declaration: it has to be a nominal type.
  if (auto *NomDecl = dyn_cast<NominalTypeDecl>(ValueRefInfo->getValueD()))
    return AddEquatableContext(NomDecl);

  return AddEquatableContext();
}

/// Print the body of the generated function: a `return` statement that
/// compares every user-accessible property of the first parameter with the
/// same property of the second parameter, joined by `&&`, one comparison per
/// line.
///
/// Expects at least one user-accessible property and at least two entries in
/// \p Params.
void AddEquatableContext::
printFunctionBody(ASTPrinter &Printer, StringRef ExtraIndent, ParameterList *Params) {
  SmallString<128> Return;
  llvm::raw_svector_ostream SS(Return);
  SS << tok::kw_return;

  StringRef Space = " ";
  // Aligns continuation lines with the first comparison after "return ".
  StringRef AdditionalSpace = "       ";
  StringRef Point = ".";
  StringRef Join = " == ";
  StringRef And = " &&";

  auto Props = getUserAccessibleProperties();
  auto Lhs = Params->get(0)->getName();
  auto Rhs = Params->get(1)->getName();

  // The first comparison shares its line with the `return` keyword.
  auto FirstName = Props[0]->getName();
  Printer << ExtraIndent << Return << Space
  << Lhs << Point << FirstName << Join << Rhs << Point << FirstName;

  // Every further comparison goes onto its own line.
  for (size_t Idx = 1; Idx < Props.size(); ++Idx) {
    auto Name = Props[Idx]->getName();
    Printer << And;
    Printer.printNewline();
    Printer << ExtraIndent << AdditionalSpace << Lhs << Point
    << Name << Join << Rhs << Point << Name;
  }
}

/// The refactoring applies when a valid context can be built from \p Tok.
bool RefactoringActionAddEquatableConformance::isApplicable(
    ResolvedCursorInfoPtr Tok, DiagnosticEngine &Diag) {
  auto Context = AddEquatableContext::getDeclarationContextFromInfo(Tok);
  return Context.isValid();
}

/// Add `Equatable` to the inheritance clause of the declaration under the
/// cursor and insert an implementation of its requirement after the
/// declaration's last member (or opening brace).
///
/// \returns true to abort when the context is not valid, false otherwise.
bool RefactoringActionAddEquatableConformance::
performChange() {
  auto Context = AddEquatableContext::getDeclarationContextFromInfo(CursorInfo);
  // Generating the function text indexes into the requirement and property
  // lists, which is only safe for a valid context.
  if (!Context.isValid())
    return true;
  EditConsumer.insertAfter(SM, Context.getStartLocForProtocolDecl(),
                           Context.getInsertionTextForProtocol());
  EditConsumer.insertAfter(SM, Context.getInsertStartLoc(),
                           Context.getInsertionTextForFunction(SM));
  return false;
}

/// Gathers the information needed to print an explicit implementation of a
/// declaration's `Encodable` / `Decodable` conformance.
class AddCodableContext {

  /// Declaration context
  DeclContext *DC;

  /// Start location of declaration context brace
  SourceLoc StartLoc;

  /// Array of all conformed protocols
  SmallVector<swift::ProtocolDecl *, 2> Protocols;

  /// Range of internal members in declaration
  DeclRange Range;

  /// Whether `Encodable` or `Decodable` is among the collected protocols.
  bool conformsToCodableProtocol() {
    for (ProtocolDecl *Protocol : Protocols) {
      auto Kind = Protocol->getKnownProtocolKind();
      if (Kind == KnownProtocolKind::Encodable ||
          Kind == KnownProtocolKind::Decodable)
        return true;
    }
    return false;
  }

public:
  /// Build the context for the nominal type \p Decl.
  AddCodableContext(NominalTypeDecl *Decl)
      : DC(Decl), StartLoc(Decl->getBraces().Start),
        Protocols(getAllProtocols(Decl)), Range(Decl->getMembers()) {}

  /// Build the context for the extension \p Decl.
  AddCodableContext(ExtensionDecl *Decl)
      : DC(Decl), StartLoc(Decl->getBraces().Start),
        Protocols(getAllProtocols(Decl->getExtendedNominal())),
        Range(Decl->getMembers()) {}

  /// Build an empty context; `isValid()` is false for it.
  AddCodableContext() : DC(nullptr), Protocols(), Range(nullptr, nullptr) {}

  static AddCodableContext
  getDeclarationContextFromInfo(ResolvedCursorInfoPtr Info);

  void printInsertionText(ResolvedCursorInfoPtr CursorInfo, SourceManager &SM,
                          llvm::raw_ostream &OS);

  /// Whether the brace location is known and the declaration conforms to
  /// `Encodable` or `Decodable`.
  bool isValid() {
    if (!StartLoc.isValid())
      return false;
    return conformsToCodableProtocol();
  }

  SourceLoc getInsertStartLoc();
};

/// Return the furthest location among the opening brace and the end locations
/// of the declaration's members; the generated code is inserted after it.
SourceLoc AddCodableContext::getInsertStartLoc() {
  SourceLoc Furthest = StartLoc;
  for (auto *Member : Range) {
    const SourceLoc MemberEnd = Member->getEndLoc();
    // Locations are ordered by comparing their opaque pointer values.
    if (MemberEnd.getOpaquePointerValue() > Furthest.getOpaquePointerValue())
      Furthest = MemberEnd;
  }
  return Furthest;
}

/// Walks an AST and prints the synthesized Codable implementation.
class SynthesizedCodablePrinter : public ASTWalker {
private:
  ASTPrinter &Printer;

public:
  SynthesizedCodablePrinter(ASTPrinter &Printer) : Printer(Printer) {}

  MacroWalking getMacroWalkingBehavior() const override {
    return MacroWalking::Arguments;
  }

  PreWalkAction walkToDeclPre(Decl *D) override {
    auto *VD = dyn_cast<ValueDecl>(D);
    if (!VD)
      return Action::SkipChildren();

    if (!VD->isSynthesized()) {
      return Action::Continue();
    }
    SmallString<32> Scratch;
    auto name = VD->getName().getString(Scratch);
    // Print all synthesized enums,
    // since Codable can synthesize multiple enums (for associated values).
    auto shouldPrint =
        isa<EnumDecl>(VD) || name == "init(from:)" || name == "encode(to:)";
    if (!shouldPrint) {
      // Some other synthesized decl that we don't want to print.
      return Action::SkipChildren();
    }

    Printer.printNewline();

    if (auto enumDecl = dyn_cast<EnumDecl>(D)) {
      // Manually print enum here, since we don't want to print synthesized
      // functions.
      Printer << "enum " << enumDecl->getNameStr();
      PrintOptions Options;
      Options.PrintSpaceBeforeInheritance = false;
      enumDecl->printInherited(Printer, Options);
      Printer << " {";
      for (Decl *EC : enumDecl->getAllElements()) {
        Printer.printNewline();
        Printer << "  ";
        EC->print(Printer, Options);
      }
      Printer.printNewline();
      Printer << "}";
      return Action::SkipChildren();
    }

    PrintOptions Options;
    Options.SynthesizeSugarOnTypes = true;
    Options.FunctionDefinitions = true;
    Options.VarInitializers = true;
    Options.PrintExprs = true;
    Options.TypeDefinitions = true;
    Options.ExcludeAttrList.push_back(DAK_HasInitialValue);

    Printer.printNewline();
    D->print(Printer, Options);

    return Action::SkipChildren();
  }
};

/// Print the synthesized Codable implementation of the declaration to \p OS,
/// indented to match the place where it is going to be inserted.
void AddCodableContext::printInsertionText(ResolvedCursorInfoPtr CursorInfo,
                                           SourceManager &SM,
                                           llvm::raw_ostream &OS) {
  const SourceLoc InsertLoc = getInsertStartLoc();

  StringRef ExtraIndent;
  StringRef CurrentIndent =
      Lexer::getIndentationForLine(SM, InsertLoc, &ExtraIndent);

  // Inserting right after the opening brace needs one extra level of
  // indentation; inserting after a member reuses that member's indentation.
  std::string Indent = InsertLoc == StartLoc
                           ? (CurrentIndent + ExtraIndent).str()
                           : CurrentIndent.str();

  ExtraIndentStreamPrinter Printer(OS, Indent);
  Printer.printNewline();

  SynthesizedCodablePrinter Walker(Printer);
  DC->getAsDecl()->walk(Walker);
}

/// Build a context from the cursor position described by \p Info. An empty
/// context is returned unless the cursor is on the name of a nominal type
/// declaration.
AddCodableContext
AddCodableContext::getDeclarationContextFromInfo(ResolvedCursorInfoPtr Info) {
  auto ValueRefInfo = dyn_cast<ResolvedValueRefCursorInfo>(Info);
  if (!ValueRefInfo)
    return AddCodableContext();

  // TODO: support extensions
  // (would need to get synthesized nodes from the main decl,
  // and only if it's in the same file?)
  if (ValueRefInfo->isRef())
    return AddCodableContext();

  if (auto *NomDecl = dyn_cast<NominalTypeDecl>(ValueRefInfo->getValueD()))
    return AddCodableContext(NomDecl);

  return AddCodableContext();
}

/// The refactoring is applicable whenever a valid \c AddCodableContext can be
/// derived from the cursor info.
bool RefactoringActionAddExplicitCodableImplementation::isApplicable(
    ResolvedCursorInfoPtr Tok, DiagnosticEngine &Diag) {
  auto Context = AddCodableContext::getDeclarationContextFromInfo(Tok);
  return Context.isValid();
}

bool RefactoringActionAddExplicitCodableImplementation::performChange() {
  auto Context = AddCodableContext::getDeclarationContextFromInfo(CursorInfo);

  SmallString<64> Buffer;
  llvm::raw_svector_ostream OS(Buffer);
  Context.printInsertionText(CursorInfo, SM, OS);

  EditConsumer.insertAfter(SM, Context.getInsertStartLoc(), OS.str());
  return false;
}

/// Computes the character range that should be wrapped in a do/catch for the
/// expression trailing the cursor.
///
/// The innermost brace statement containing the expression is looked up; if
/// one of its elements covers the expression, that element's range is
/// returned, otherwise the range of the brace statement itself. An invalid
/// (default) range is returned when there is no trailing expression or no
/// enclosing brace statement.
static CharSourceRange
findSourceRangeToWrapInCatch(const ResolvedExprStartCursorInfo &CursorInfo,
                             SourceFile *TheFile, SourceManager &SM) {
  Expr *E = CursorInfo.getTrailingExpr();
  if (!E)
    return CharSourceRange();

  auto IsBraceStmt = [](ASTNode N) { return N.isStmt(StmtKind::Brace); };
  auto Node = ASTNode(E);
  ContextFinder Finder(*TheFile, Node, IsBraceStmt);
  Finder.resolve();

  auto Contexts = Finder.getContexts();
  if (Contexts.empty())
    return CharSourceRange();

  auto ToCharRange = [&SM](SourceRange SR) {
    return Lexer::getCharSourceRangeFromSourceRange(SM, SR);
  };

  // Start from the deepest brace statement found.
  auto TargetNode = Contexts.back();
  BraceStmt *BStmt = dyn_cast<BraceStmt>(TargetNode.dyn_cast<Stmt*>());
  assert(BStmt);

  // Narrow the target to an element of the brace statement that covers the
  // expression; if several do, the last one in element order wins.
  auto ExprRange = ToCharRange(E->getSourceRange());
  for (auto Elem: BStmt->getElements()) {
    if (ToCharRange(Elem.getSourceRange()).contains(ExprRange))
      TargetNode = Elem;
  }
  return ToCharRange(TargetNode.getSourceRange());
}

/// Applicable when the cursor is at the start of an expression and that
/// trailing expression is a \c ForceTryExpr.
bool RefactoringActionConvertToDoCatch::isApplicable(ResolvedCursorInfoPtr Tok,
                                                     DiagnosticEngine &Diag) {
  auto ExprStartInfo = dyn_cast<ResolvedExprStartCursorInfo>(Tok);
  if (!ExprStartInfo)
    return false;
  if (!ExprStartInfo->getTrailingExpr())
    return false;
  return isa<ForceTryExpr>(ExprStartInfo->getTrailingExpr());
}

/// Wraps the statement containing the force-try expression in a do/catch
/// block and removes the '!' from the try expression. Returns \c true (without
/// emitting any edit) when no range to wrap could be determined, \c false
/// otherwise.
bool RefactoringActionConvertToDoCatch::performChange() {
  auto ExprStartInfo = cast<ResolvedExprStartCursorInfo>(CursorInfo);
  auto *TryExpr = dyn_cast<ForceTryExpr>(ExprStartInfo->getTrailingExpr());
  assert(TryExpr);

  auto WrapRange = findSourceRangeToWrapInCatch(*ExprStartInfo, TheFile, SM);
  if (!WrapRange.isValid())
    return true;

  // Open the do block in front of the range and close it, followed by a catch
  // block containing a placeholder, behind the range.
  EditConsumer.accept(SM, WrapRange.getStart(), "do {\n");
  EditorConsumerInsertStream OS(EditConsumer, SM, WrapRange.getEnd());
  OS << "\n} catch {\n" << getCodePlaceholder() << "\n}";

  // Turn 'try!' into 'try' by deleting the exclamation mark.
  EditConsumer.remove(SM,
                      CharSourceRange(TryExpr->getExclaimLoc(),
                                      getKeywordLen(tok::exclaim_postfix)));
  return false;
}

/// Given a cursor position, this function tries to collect a number literal
/// expression immediately following the cursor.
///
/// \returns the first \c NumberLiteralExpr found while walking the expression
/// trailing the cursor whose start location equals the start location of that
/// trailing expression, or \c nullptr if \p Tok is not an expression-start
/// cursor, has no trailing expression, or no such literal exists.
static NumberLiteralExpr *getTrailingNumberLiteral(ResolvedCursorInfoPtr Tok) {
  // This cursor must point to the start of an expression.
  auto ExprStartInfo = dyn_cast<ResolvedExprStartCursorInfo>(Tok);
  if (!ExprStartInfo)
    return nullptr;

  // For every sub-expression, try to find the literal expression that matches
  // our criteria.
  class FindLiteralNumber : public ASTWalker {
    Expr * const parent;

  public:
    NumberLiteralExpr *found = nullptr;

    explicit FindLiteralNumber(Expr *parent) : parent(parent) { }

    MacroWalking getMacroWalkingBehavior() const override {
      return MacroWalking::Arguments;
    }

    PreWalkResult<Expr *> walkToExprPre(Expr *expr) override {
      if (auto *literal = dyn_cast<NumberLiteralExpr>(expr)) {
        // The sub-expression must have the same start loc with the outermost
        // expression, i.e. the cursor position.
        if (!found &&
            parent->getStartLoc().getOpaquePointerValue() ==
              expr->getStartLoc().getOpaquePointerValue()) {
          found = literal;
        }
      }
      // Once a literal has been found there is no need to descend further.
      return Action::SkipChildrenIf(found, expr);
    }
  };

  // Other users of the trailing expression in this file treat it as
  // potentially null, so don't dereference it blindly here either.
  auto parent = ExprStartInfo->getTrailingExpr();
  if (!parent)
    return nullptr;

  FindLiteralNumber finder(parent);
  parent->walk(finder);
  return finder.found;
}

/// Returns a copy of \p Text with a '_' inserted in front of every character
/// whose zero-based index is a non-zero multiple of three, i.e. the text is
/// split into groups of three characters counted from the left.
static std::string insertUnderscore(StringRef Text) {
  SmallString<64> Buffer;
  llvm::raw_svector_ostream OS(Buffer);
  unsigned Index = 0;
  for (char C : Text) {
    if (Index != 0 && Index % 3 == 0)
      OS << '_';
    OS << C;
    ++Index;
  }
  return OS.str().str();
}

/// Writes \p Digits to \p OS with '_' separators inserted every three
/// characters. The text is split at the first '.'; the part before it is
/// grouped counting from its right end, the part after it (if non-empty) is
/// grouped counting from its left end.
void insertUnderscoreInDigits(StringRef Digits,
                              raw_ostream &OS) {
  auto Parts = Digits.split('.');
  std::string Integral = Parts.first.str();
  std::string Fraction = Parts.second.str();

  // Grouping from the right is done by grouping the reversed text from the
  // left and reversing the result again.
  std::reverse(Integral.begin(), Integral.end());
  Integral = insertUnderscore(Integral);
  std::reverse(Integral.begin(), Integral.end());
  OS << Integral;

  // The decimal point is only emitted when something follows it.
  if (Fraction.empty())
    return;
  OS << '.';
  OS << insertUnderscore(Fraction);
}

/// Applicable when a number literal trails the cursor and inserting
/// underscores into its digits text yields a text different from the digits
/// text itself.
bool RefactoringActionSimplifyNumberLiteral::isApplicable(
    ResolvedCursorInfoPtr Tok, DiagnosticEngine &Diag) {
  auto *Literal = getTrailingNumberLiteral(Tok);
  if (!Literal)
    return false;

  StringRef Digits = Literal->getDigitsText();
  SmallString<64> Separated;
  llvm::raw_svector_ostream OS(Separated);
  insertUnderscoreInDigits(Digits, OS);

  // If inserting '_' results in a different digit sequence, this refactoring
  // is applicable.
  return OS.str() != Digits;
}

/// Replaces the range from the literal's digits location to the end of its
/// last token with the digits text separated by underscores. Returns \c true
/// when no number literal trails the cursor, \c false otherwise.
bool RefactoringActionSimplifyNumberLiteral::performChange() {
  auto *Literal = getTrailingNumberLiteral(CursorInfo);
  if (!Literal)
    return true;

  EditorConsumerInsertStream OS(
      EditConsumer, SM,
      CharSourceRange(SM, Literal->getDigitsLoc(),
                      Lexer::getLocForEndOfToken(SM, Literal->getEndLoc())));
  StringRef Digits = Literal->getDigitsText();
  insertUnderscoreInDigits(Digits, OS);
  return false;
}

/// Finds the call expression around the cursor whose last argument could be
/// turned into a trailing closure.
///
/// \returns the innermost enclosing \c CallExpr (ignoring an innermost brace
/// statement) if it has arguments, none of which is already a trailing
/// closure, and its last argument is a closure or capture list expression
/// (possibly behind an implicit conversion); \c nullptr otherwise.
static CallExpr *findTrailingClosureTarget(SourceManager &SM,
                                           ResolvedCursorInfoPtr CursorInfo) {
  // A StmtStart position can't be a part of a CallExpr.
  if (CursorInfo->getKind() == CursorInfoKind::StmtStart)
    return nullptr;

  // Collect the brace statements and call expressions around the cursor.
  ContextFinder Finder(
      *CursorInfo->getSourceFile(), CursorInfo->getLoc(), [](ASTNode N) {
        return N.isStmt(StmtKind::Brace) || N.isExpr(ExprKind::Call);
      });
  Finder.resolve();
  auto Enclosing = Finder.getContexts();
  if (Enclosing.empty())
    return nullptr;

  // If the innermost context is a statement (which will be a BraceStmt per
  // the filtering condition above), drop it.
  if (Enclosing.back().is<Stmt *>())
    Enclosing = Enclosing.drop_back();

  if (Enclosing.empty() || !Enclosing.back().is<Expr*>())
    return nullptr;
  CallExpr *CE = cast<CallExpr>(Enclosing.back().get<Expr*>());

  // Only calls whose last argument is a non-trailing closure qualify.
  auto *Args = CE->getArgs();
  if (Args->empty() || Args->hasAnyTrailingClosures())
    return nullptr;

  auto *LastArg = Args->back().getExpr();
  if (auto *ICE = dyn_cast<ImplicitConversionExpr>(LastArg))
    LastArg = ICE->getSyntacticSubExpr();

  if (!isa<ClosureExpr>(LastArg) && !isa<CaptureListExpr>(LastArg))
    return nullptr;
  return CE;
}

/// Applicable when a call suitable for the trailing closure conversion is
/// found around the cursor.
bool RefactoringActionTrailingClosure::isApplicable(
    ResolvedCursorInfoPtr CursorInfo, DiagnosticEngine &Diag) {
  SourceManager &SM = CursorInfo->getSourceFile()->getASTContext().SourceMgr;
  auto *CE = findTrailingClosureTarget(SM, CursorInfo);
  return CE;
}

/// Rewrites the call around the cursor so that its last (closure) argument
/// becomes a trailing closure. Returns \c true without editing when no
/// suitable call is found, a paren location is invalid, or the original
/// argument list is empty; \c false once the edits have been emitted.
bool RefactoringActionTrailingClosure::performChange() {
  auto *CE = findTrailingClosureTarget(SM, CursorInfo);
  if (!CE)
    return true;

  auto *ArgList = CE->getArgs()->getOriginalArgs();
  auto LParenLoc = ArgList->getLParenLoc();
  auto RParenLoc = ArgList->getRParenLoc();
  if (LParenLoc.isInvalid() || RParenLoc.isInvalid())
    return true;

  auto NumArgs = ArgList->size();
  if (NumArgs == 0)
    return true;

  // The closure is the last argument; look through an implicit conversion.
  auto *ClosureArg = ArgList->getExpr(NumArgs - 1);
  if (auto *ICE = dyn_cast<ImplicitConversionExpr>(ClosureArg))
    ClosureArg = ICE->getSyntacticSubExpr();

  if (NumArgs == 1) {
    // The closure is the sole argument: replace everything from the open
    // paren up to the closure with ' '.
    CharSourceRange PreRange(SM, LParenLoc, ClosureArg->getStartLoc());
    EditConsumer.accept(SM, PreRange, " ");
  } else {
    // Otherwise replace everything between the end of the previous argument
    // and the closure (i.e. the comma) with ') '.
    auto *PrevArg = ArgList->getExpr(NumArgs - 2);
    CharSourceRange PreRange(
        SM,
        Lexer::getLocForEndOfToken(SM, PrevArg->getEndLoc()),
        ClosureArg->getStartLoc());
    EditConsumer.accept(SM, PreRange, ") ");
  }

  // Remove everything from the end of the closure through the original
  // closing paren.
  CharSourceRange PostRange(
      SM,
      Lexer::getLocForEndOfToken(SM, ClosureArg->getEndLoc()),
      Lexer::getLocForEndOfToken(SM, RParenLoc));
  EditConsumer.remove(SM, PostRange);
  return false;
}

static bool collectRangeStartRefactorings(const ResolvedRangeInfo &Info) {
  switch (Info.Kind) {
  case RangeKind::SingleExpression:
  case RangeKind::SingleStatement:
  case RangeKind::SingleDecl:
  case RangeKind::PartOfExpression:
    return true;
  case RangeKind::MultiStatement:
  case RangeKind::MultiTypeMemberDecl:
  case RangeKind::Invalid:
    return false;
  }
}
    
/// Applicable when the range is a single pattern binding declaration of one
/// variable that has an initializer string representation and that has no
/// observing accessors, no \c lazy / \c NSCopying / \c IBOutlet attribute and
/// no attached property wrapper.
bool RefactoringActionConvertToComputedProperty::
isApplicable(const ResolvedRangeInfo &Info, DiagnosticEngine &Diag) {
  // Exactly one declaration must be selected.
  if (Info.Kind != RangeKind::SingleDecl || Info.ContainedNodes.size() != 1)
    return false;

  auto D = Info.ContainedNodes[0].dyn_cast<Decl*>();
  if (!D)
    return false;

  auto Binding = dyn_cast<PatternBindingDecl>(D);
  if (!Binding)
    return false;

  auto SV = Binding->getSingleVar();
  if (!SV)
    return false;

  // willSet, didSet cannot be provided together with a getter
  for (auto AD : SV->getAllAccessors()) {
    if (AD->isObservingAccessor())
      return false;
  }

  // 'lazy' must not be used on a computed property
  // NSCopying and IBOutlet attribute requires property to be mutable
  auto Attrs = SV->getAttrs();
  bool HasBlockingAttr = Attrs.hasAttribute<LazyAttr>() ||
                         Attrs.hasAttribute<NSCopyingAttr>() ||
                         Attrs.hasAttribute<IBOutletAttr>();
  if (HasBlockingAttr)
    return false;

  // Property wrapper cannot be applied to a computed property
  if (SV->hasAttachedPropertyWrapper())
    return false;

  // Finally, the binding needs an initializer to move into the getter.
  return Binding->hasInitStringRepresentation(0);
}

/// Replaces the selected stored property with a computed property of the
/// form `var <name>: <type> { return <initializer> }`. The type is taken from
/// the written type representation when there is one and printed from the
/// variable's type otherwise. Always returns \c false.
bool RefactoringActionConvertToComputedProperty::performChange() {
  auto D = RangeInfo.ContainedNodes[0].dyn_cast<Decl*>();
  auto Binding = dyn_cast<PatternBindingDecl>(D);

  // The initializer text becomes the returned expression.
  SmallString<128> scratch;
  auto Init = Binding->getInitStringRepresentation(0, scratch);

  // Gather the type information of the single variable.
  auto SV = Binding->getSingleVar();
  auto SVType = SV->getType();
  auto TR = SV->getTypeReprOrParentPatternTypeRepr();

  SmallString<64> DeclBuffer;
  llvm::raw_svector_ostream OS(DeclBuffer);
  StringRef Space = " ";
  StringRef NewLine = "\n";

  // 'var <name>: '
  OS << tok::kw_var << Space << SV->getNameStr().str() << ":" << Space;

  // A computed property must spell out its type.
  if (!TR) {
    SVType.print(OS);
  } else {
    OS << Lexer::getCharSourceRangeFromSourceRange(SM, TR->getSourceRange()).str();
  }

  // ' {\nreturn <init>\n}'
  OS << Space << tok::l_brace << NewLine;
  OS << tok::kw_return << Space << Init.str() << NewLine;
  OS << tok::r_brace;

  // Replace everything from the binding's location to its end with the newly
  // built declaration.
  auto ReplaceRange =
      SourceRange(Binding->getLoc(), Binding->getSourceRange().End);
  auto ReplaceCharSourceRange = Lexer::getCharSourceRangeFromSourceRange(SM, ReplaceRange);
  EditConsumer.accept(SM, ReplaceCharSourceRange, DeclBuffer.str());
  return false; // success
}

namespace asyncrefactorings {

// TODO: Should probably split the refactorings into separate files

/// Whether the given type is (or conforms to) the stdlib Error type. A null
/// type never is; otherwise the conformance is looked up in \p MD.
bool isErrorType(Type Ty, ModuleDecl *MD) {
  if (!Ty)
    return false;
  auto *ErrorProto = Ty->getASTContext().getErrorDecl();
  auto Conformance = MD->conformsToProtocol(Ty, ErrorProto);
  return !Conformance.isInvalid();
}

// The single Decl* subject of a switch statement, i.e. the declaration
// referenced by the subject expression if that expression is a plain
// declaration reference, or nullptr if none.
Decl *singleSwitchSubject(const SwitchStmt *Switch) {
  auto *DRE = dyn_cast<DeclRefExpr>(Switch->getSubjectExpr());
  if (!DRE)
    return nullptr;
  return DRE->getDecl();
}

/// A more aggressive variant of \c Expr::getReferencedDecl that also looks
/// through self applications and through autoclosures created to pass the
/// \c self parameter to member funcs. Returns \c nullptr if no declaration
/// reference is found that way.
ValueDecl *getReferencedDecl(const Expr *Fn) {
  // Peel off one layer per iteration until a declaration reference is reached
  // or nothing is left to look through.
  while (true) {
    Fn = Fn->getSemanticsProvidingExpr();
    if (auto *DRE = dyn_cast<DeclRefExpr>(Fn))
      return DRE->getDecl();
    if (auto ApplyE = dyn_cast<SelfApplyExpr>(Fn)) {
      Fn = ApplyE->getFn();
      continue;
    }
    if (auto *ACE = dyn_cast<AutoClosureExpr>(Fn)) {
      if (auto *Unwrapped = ACE->getUnwrappedCurryThunkExpr()) {
        Fn = Unwrapped;
        continue;
      }
    }
    return nullptr;
  }
}

/// The declaration referenced by \p Fn (see \c getReferencedDecl) if it is a
/// \c FuncDecl, \c nullptr otherwise.
FuncDecl *getUnderlyingFunc(const Expr *Fn) {
  auto *Referenced = getReferencedDecl(Fn);
  return dyn_cast_or_null<FuncDecl>(Referenced);
}

/// Find the outermost call of the given location. Only non-implicit
/// expressions are considered, and the call is only returned if the cursor is
/// located within the source range of the call's function expression.
CallExpr *findOuterCall(ResolvedCursorInfoPtr CursorInfo) {
  auto IsExplicitExpr = [](ASTNode N) {
    auto *E = N.dyn_cast<Expr *>();
    return E && !E->isImplicit();
  };

  // TODO: Bit pointless using the "ContextFinder" here. Ideally we would have
  //       already generated a slice of the AST for anything that contains
  //       the cursor location
  ContextFinder Finder(*CursorInfo->getSourceFile(), CursorInfo->getLoc(),
                       IsExplicitExpr);
  Finder.resolve();
  auto Contexts = Finder.getContexts();
  if (Contexts.empty())
    return nullptr;

  // The first context is the outermost one; it has to be a call.
  CallExpr *CE = dyn_cast<CallExpr>(Contexts[0].get<Expr *>());
  if (!CE)
    return nullptr;

  SourceManager &SM = CursorInfo->getSourceFile()->getASTContext().SourceMgr;
  bool CursorInCallee =
      SM.rangeContains(CE->getFn()->getSourceRange(), CursorInfo->getLoc());
  return CursorInCallee ? CE : nullptr;
}

/// Find the function matching the given location if it is not an accessor and
/// either has a body or is a member of a protocol.
///
/// Only non-implicit declarations are considered. An innermost parameter
/// declaration is skipped so that a cursor on a parameter resolves to the
/// declaration containing it. The cursor must be located between the start
/// of the function and the opening brace of its body (or the end of the
/// function if there is no body). Returns \c nullptr if no such function is
/// found.
FuncDecl *findFunction(ResolvedCursorInfoPtr CursorInfo) {
  auto IncludeInContext = [](ASTNode N) {
    if (auto *D = N.dyn_cast<Decl *>())
      return !D->isImplicit();
    return false;
  };

  ContextFinder Finder(*CursorInfo->getSourceFile(), CursorInfo->getLoc(),
                       IncludeInContext);
  Finder.resolve();

  auto Contexts = Finder.getContexts();
  if (Contexts.empty())
    return nullptr;

  if (Contexts.back().isDecl(DeclKind::Param)) {
    Contexts = Contexts.drop_back();
    // Dropping the parameter may leave no context at all; calling back() on
    // an empty list is invalid.
    if (Contexts.empty())
      return nullptr;
  }

  auto *FD = dyn_cast_or_null<FuncDecl>(Contexts.back().get<Decl *>());
  if (!FD || isa<AccessorDecl>(FD))
    return nullptr;

  auto *Body = FD->getBody();
  if (!Body && !isa<ProtocolDecl>(FD->getDeclContext()))
    return nullptr;

  // Only the signature (up to the body's opening brace) counts as a match.
  SourceManager &SM = CursorInfo->getSourceFile()->getASTContext().SourceMgr;
  SourceLoc DeclEnd = Body ? Body->getLBraceLoc() : FD->getEndLoc();
  if (!SM.rangeContains(SourceRange(FD->getStartLoc(), DeclEnd),
                        CursorInfo->getLoc()))
    return nullptr;

  return FD;
}

/// If the function expression of \p BE is an apply expression whose called
/// value is an operator function, returns that function; \c nullptr
/// otherwise.
FuncDecl *isOperator(const BinaryExpr *BE) {
  auto *AE = dyn_cast<ApplyExpr>(BE->getFn());
  if (!AE)
    return nullptr;

  auto *Callee = AE->getCalledValue();
  if (!Callee || !Callee->isOperator() || !isa<FuncDecl>(Callee))
    return nullptr;
  return cast<FuncDecl>(Callee);
}

/// Describes the expressions to be kept from a call to the handler in a
/// function that has (or will have) an async alternative. Eg.
/// ```
/// func toBeAsync(completion: (String?, Error?) -> Void) {
///   ...
///   completion("something", nil) // Result = ["something"], IsError = false
///   ...
///   completion(nil, MyError.Bad) // Result = [MyError.Bad], IsError = true
/// }
/// ```
class HandlerResult {
  /// The arguments of the handler call that are kept.
  SmallVector<Argument, 2> Args;
  /// Whether the kept arguments describe an error.
  bool IsError = false;

public:
  /// A non-error result without any arguments.
  HandlerResult() {}

  /// A non-error result keeping a copy of all arguments in \p ArgsRef.
  HandlerResult(ArrayRef<Argument> ArgsRef)
      : Args(ArgsRef.begin(), ArgsRef.end()) {}

  /// A result keeping the single argument \p Arg, flagged as an error
  /// according to \p IsError.
  HandlerResult(Argument Arg, bool IsError) : IsError(IsError) {
    Args.push_back(Arg);
  }

  // The accessors don't modify the result, so they are usable on const
  // objects as well.
  bool isError() const { return IsError; }

  /// The kept arguments. The returned view refers to storage owned by this
  /// object.
  ArrayRef<Argument> args() const { return Args; }
};

/// The type of the handler, i.e. whether it takes regular parameters
/// (\c PARAMS) or a single parameter of `Result` type (\c RESULT).
/// \c INVALID is the value \c AsyncHandlerDesc starts out with and the one its
/// \c isValid() check rejects.
enum class HandlerType { INVALID, PARAMS, RESULT };

/// A single return type of a refactored async function. If the async function
/// returns a tuple, each element of the tuple (represented by a \c
/// LabeledReturnType) might have a label, otherwise the \p Label is empty.
struct LabeledReturnType {
  /// The label of the tuple element, or an empty identifier if unlabeled.
  Identifier Label;
  /// The type of the returned value.
  swift::Type Ty;

  /// Stores \p Label and \p Ty as given.
  LabeledReturnType(Identifier Label, swift::Type Ty) : Label(Label), Ty(Ty) {}
};

/// Given a function with an async alternative (or one that *could* have an
/// async alternative), stores information about the completion handler.
/// The completion handler can be either a variable (which includes a parameter)
/// or a function
struct AsyncHandlerDesc {
  /// The declaration of the completion handler, null if there is none.
  PointerUnion<const VarDecl *, const AbstractFunctionDecl *> Handler = nullptr;
  /// How the handler receives its results; \c INVALID marks an unusable
  /// descriptor.
  HandlerType Type = HandlerType::INVALID;
  /// Whether the handler can be passed an error.
  bool HasError = false;

  /// Builds the descriptor for the completion handler \p Handler.
  ///
  /// An invalid (default-constructed) descriptor is returned if \p Handler is
  /// neither a variable nor a function, if \p RequireName is set and the
  /// handler doesn't have a completion-like name, if its type is not a
  /// function type returning Void, or if it takes a `Result` typed parameter
  /// without being a single-`Result` handler.
  static AsyncHandlerDesc get(const ValueDecl *Handler, bool RequireName) {
    AsyncHandlerDesc HandlerDesc;
    if (auto Var = dyn_cast<VarDecl>(Handler)) {
      HandlerDesc.Handler = Var;
    } else if (auto Func = dyn_cast<AbstractFunctionDecl>(Handler)) {
      HandlerDesc.Handler = Func;
    } else {
      // The handler must be a variable or function
      return AsyncHandlerDesc();
    }

    // Callback must have a completion-like name
    if (RequireName && !isCompletionHandlerParamName(HandlerDesc.getNameStr()))
      return AsyncHandlerDesc();

    // Callback must be a function type and return void. Doesn't need to have
    // any parameters - may just be a "I'm done" callback
    auto *HandlerTy = HandlerDesc.getType()->getAs<AnyFunctionType>();
    if (!HandlerTy || !HandlerTy->getResult()->isVoid())
      return AsyncHandlerDesc();

    // Find the type of result in the handler (eg. whether it's a Result<...>,
    // just parameters, or nothing).
    auto HandlerParams = HandlerTy->getParams();
    if (HandlerParams.size() == 1) {
      auto ParamTy =
          HandlerParams.back().getPlainType()->getAs<BoundGenericType>();
      if (ParamTy && ParamTy->isResult()) {
        auto GenericArgs = ParamTy->getGenericArgs();
        assert(GenericArgs.size() == 2 && "Result should have two params");
        HandlerDesc.Type = HandlerType::RESULT;
        // A Result with an uninhabited failure type can never carry an error.
        HandlerDesc.HasError = !GenericArgs.back()->isUninhabited();
      }
    }

    if (HandlerDesc.Type != HandlerType::RESULT) {
      // Only handle non-result parameters
      for (auto &Param : HandlerParams) {
        if (Param.getPlainType() && Param.getPlainType()->isResult())
          return AsyncHandlerDesc();
      }

      HandlerDesc.Type = HandlerType::PARAMS;
      if (!HandlerParams.empty()) {
        // The handler has an error if its last parameter is an optional whose
        // object type is an error type.
        auto LastParamTy = HandlerParams.back().getParameterType();
        HandlerDesc.HasError = isErrorType(LastParamTy->getOptionalObjectType(),
                                           Handler->getModuleContext());
      }
    }

    return HandlerDesc;
  }

  bool isValid() const { return Type != HandlerType::INVALID; }

  /// Return the declaration of the completion handler as a \c ValueDecl.
  /// In practice, the handler will always be a \c VarDecl or \c
  /// AbstractFunctionDecl.
  /// \c getNameStr and \c getType provide access functions that are available
  /// for both variables and functions, but not on \c ValueDecls.
  const ValueDecl *getHandler() const {
    if (!Handler) {
      return nullptr;
    }
    if (auto Var = Handler.dyn_cast<const VarDecl *>()) {
      return Var;
    } else if (auto Func = Handler.dyn_cast<const AbstractFunctionDecl *>()) {
      return Func;
    } else {
      llvm_unreachable("Unknown handler type");
    }
  }

  /// Return the name of the completion handler. If it is a variable, the
  /// variable name, if it's a function, the function base name.
  StringRef getNameStr() const {
    if (auto Var = Handler.dyn_cast<const VarDecl *>()) {
      return Var->getNameStr();
    } else if (auto Func = Handler.dyn_cast<const AbstractFunctionDecl *>()) {
      return Func->getNameStr();
    } else {
      llvm_unreachable("Unknown handler type");
    }
  }

  HandlerType getHandlerType() const { return Type; }

  /// Get the type of the completion handler.
  swift::Type getType() const {
    if (auto Var = Handler.dyn_cast<const VarDecl *>()) {
      return Var->getType();
    } else if (auto Func = Handler.dyn_cast<const AbstractFunctionDecl *>()) {
      auto Type = Func->getInterfaceType();
      // Undo the self curry thunk if we are referencing a member function.
      if (Func->hasImplicitSelfDecl()) {
        assert(Type->is<AnyFunctionType>());
        Type = Type->getAs<AnyFunctionType>()->getResult();
      }
      return Type;
    } else {
      llvm_unreachable("Unknown handler type");
    }
  }

  /// The parameters of the completion handler's function type.
  ArrayRef<AnyFunctionType::Param> params() const {
    auto Ty = getType()->getAs<AnyFunctionType>();
    assert(Ty && "Type must be a function type");
    return Ty->getParams();
  }

  /// Retrieve the parameters relevant to a successful return from the
  /// completion handler. This drops the Error parameter if present.
  ArrayRef<AnyFunctionType::Param> getSuccessParams() const {
    if (HasError && Type == HandlerType::PARAMS)
      return params().drop_back();
    return params();
  }

  /// If the completion handler has an Error parameter, return it.
  Optional<AnyFunctionType::Param> getErrorParam() const {
    if (HasError && Type == HandlerType::PARAMS)
      return params().back();
    return None;
  }

  /// Get the type of the error that will be thrown by the \c async method or \c
  /// None if the completion handler doesn't accept an error parameter.
  /// This may be more specialized than the generic 'Error' type if the
  /// completion handler of the converted function takes a more specialized
  /// error type.
  Optional<swift::Type> getErrorType() const {
    if (!HasError)
      return None;

    switch (Type) {
    case HandlerType::INVALID:
      return None;
    case HandlerType::PARAMS:
      // The last parameter of the completion handler is the error param
      return params().back().getPlainType()->lookThroughSingleOptionalType();
    case HandlerType::RESULT: {
      assert(
          params().size() == 1 &&
          "Result handler should have the Result type as the only parameter");
      auto ResultType =
          params().back().getPlainType()->getAs<BoundGenericType>();
      auto GenericArgs = ResultType->getGenericArgs();
      assert(GenericArgs.size() == 2 && "Result should have two params");
      // The second (last) generic parameter of the Result type is the error
      // type.
      return GenericArgs.back();
    }
    }
    // All enumerators return above; never fall off the end of the function.
    llvm_unreachable("Unhandled case in switch!");
  }

  /// The `CallExpr` if the given node is a call to the `Handler`
  CallExpr *getAsHandlerCall(ASTNode Node) const {
    if (!isValid())
      return nullptr;

    if (auto E = Node.dyn_cast<Expr *>()) {
      if (auto *CE = dyn_cast<CallExpr>(E->getSemanticsProvidingExpr())) {
        if (CE->getFn()->getReferencedDecl().getDecl() == getHandler()) {
          return CE;
        }
      }
    }
    return nullptr;
  }

  /// Returns \c true if the call to the completion handler contains possibly
  /// non-nil values for both the success and error parameters, e.g.
  /// \code
  ///   completion(result, error)
  /// \endcode
  /// This can only happen if the completion handler is a params handler.
  bool isAmbiguousCallToParamHandler(const CallExpr *CE) const {
    if (!HasError || Type != HandlerType::PARAMS) {
      // Only param handlers with an error can pass both an error AND a result.
      return false;
    }
    auto Args = CE->getArgs()->getArgExprs();
    // Without any argument there is nothing that could be ambiguous (and
    // nothing back() could return).
    if (Args.empty())
      return false;
    if (!isa<NilLiteralExpr>(Args.back())) {
      // We've got an error parameter. If any of the success params is not nil,
      // the call is ambiguous.
      for (auto &Arg : Args.drop_back()) {
        if (!isa<NilLiteralExpr>(Arg)) {
          return true;
        }
      }
    }
    return false;
  }

  /// Given a call to the `Handler`, extract the expressions to be returned or
  /// thrown, taking care to remove the `.success`/`.failure` if it's a
  /// `RESULT` handler type.
  /// If the call is ambiguous (contains potentially non-nil arguments to both
  /// the result and the error parameters), the \p ReturnErrorArgsIfAmbiguous
  /// determines whether the success or error parameters are passed.
  HandlerResult extractResultArgs(const CallExpr *CE,
                                  bool ReturnErrorArgsIfAmbiguous) const {
    auto *ArgList = CE->getArgs();
    SmallVector<Argument, 2> Scratch(ArgList->begin(), ArgList->end());
    auto Args = llvm::makeArrayRef(Scratch);

    if (Type == HandlerType::PARAMS) {
      bool IsErrorResult;
      if (isAmbiguousCallToParamHandler(CE)) {
        IsErrorResult = ReturnErrorArgsIfAmbiguous;
      } else {
        // If there's an error parameter and the user isn't passing nil to it,
        // assume this is the error path. A call without arguments can't pass
        // an error.
        IsErrorResult = (HasError && !Args.empty() &&
                         !isa<NilLiteralExpr>(Args.back().getExpr()));
      }
      if (IsErrorResult)
        return HandlerResult(Args.back(), true);

      // We can drop the args altogether if they're just Void.
      if (willAsyncReturnVoid())
        return HandlerResult();

      return HandlerResult((HasError && !Args.empty()) ? Args.drop_back()
                                                       : Args);
    } else if (Type == HandlerType::RESULT) {
      // Anything that doesn't look like `handler(.success(x))` or
      // `handler(.failure(x))` is passed through unchanged.
      if (Args.size() != 1)
        return HandlerResult(Args);

      auto *ResultCE = dyn_cast<CallExpr>(Args[0].getExpr());
      if (!ResultCE)
        return HandlerResult(Args);

      auto *DSC = dyn_cast<DotSyntaxCallExpr>(ResultCE->getFn());
      if (!DSC)
        return HandlerResult(Args);

      // The callee might not reference any declaration; dyn_cast must not be
      // handed a null pointer.
      auto *Callee = DSC->getFn()->getReferencedDecl().getDecl();
      auto *D = Callee ? dyn_cast<EnumElementDecl>(Callee) : nullptr;
      if (!D)
        return HandlerResult(Args);

      auto ResultArgList = ResultCE->getArgs();
      auto isFailure = D->getNameStr() == StringRef("failure");

      // We can drop the arg altogether if it's just Void.
      if (!isFailure && willAsyncReturnVoid())
        return HandlerResult();

      // Without a payload there is nothing to unwrap.
      if (ResultArgList->empty())
        return HandlerResult(Args);

      // Otherwise the arg gets the .success() or .failure() call dropped.
      return HandlerResult(ResultArgList->get(0), isFailure);
    }

    llvm_unreachable("Unhandled result type");
  }

  // Convert the type of a success parameter in the completion handler function
  // to a return type suitable for an async function. If there is an error
  // parameter present e.g (T?, Error?) -> Void, this unwraps a level of
  // optionality from T?. If this is a Result<T, U> type, returns the success
  // type T.
  swift::Type getSuccessParamAsyncReturnType(swift::Type Ty) const {
    switch (Type) {
    case HandlerType::PARAMS: {
      // If there's an Error parameter in the handler, the success branch can
      // be unwrapped.
      if (HasError)
        Ty = Ty->lookThroughSingleOptionalType();

      return Ty;
    }
    case HandlerType::RESULT: {
      // Result<T, U> maps to T.
      return Ty->castTo<BoundGenericType>()->getGenericArgs()[0];
    }
    case HandlerType::INVALID:
      llvm_unreachable("Invalid handler type");
    }
    // All enumerators are handled above; never fall off the end.
    llvm_unreachable("Unhandled case in switch!");
  }

  /// If the async function returns a tuple, the label of the \p Index -th
  /// element in the returned tuple. If the function doesn't return a tuple or
  /// the element is unlabeled, an empty identifier is returned.
  Identifier getAsyncReturnTypeLabel(size_t Index) const {
    assert(Index < getSuccessParams().size());
    if (getSuccessParams().size() <= 1) {
      // There can't be any labels if the async function doesn't return a tuple.
      return Identifier();
    } else {
      return getSuccessParams()[Index].getInternalLabel();
    }
  }

  /// Gets the return value types for the async equivalent of this handler.
  /// One entry per success parameter is appended to \p Scratch, which backs
  /// the returned array.
  ArrayRef<LabeledReturnType>
  getAsyncReturnTypes(SmallVectorImpl<LabeledReturnType> &Scratch) const {
    for (size_t I = 0; I < getSuccessParams().size(); ++I) {
      auto Ty = getSuccessParams()[I].getParameterType();
      Scratch.emplace_back(getAsyncReturnTypeLabel(I),
                           getSuccessParamAsyncReturnType(Ty));
    }
    return Scratch;
  }

  /// Whether the async equivalent of this handler returns Void.
  bool willAsyncReturnVoid() const {
    // If all of the success params will be converted to Void return types,
    // this will be a Void async function.
    return llvm::all_of(getSuccessParams(), [&](auto &param) {
      auto Ty = param.getParameterType();
      return getSuccessParamAsyncReturnType(Ty)->isVoid();
    });
  }

  // TODO: If we have an async alternative we should check its result types
  //       for whether to unwrap or not
  bool shouldUnwrap(swift::Type Ty) const {
    return HasError && Ty->isOptional();
  }
};

/// Describes a completion handler that is a parameter of a function: in
/// addition to the handler information itself, the declaring function, the
/// handler's index within that function's parameters and the async
/// alternative (if any) are stored.
struct AsyncHandlerParamDesc : public AsyncHandlerDesc {
  /// Where the completion handler param is located within the parameter
  /// list. Given `(A, B, C, D)`:
  ///   - A is `First`
  ///   - B and C are `Middle`
  ///   - D is `Last`
  /// The position is `Only` if there's a single parameter that is the
  /// completion handler and `None` if there is no handler.
  enum class Position {
    First, Middle, Last, Only, None
  };

  /// The function the completion handler is a parameter of.
  const FuncDecl *Func = nullptr;
  /// The index of the completion handler in the function that declares it.
  unsigned Index = 0;
  /// The async alternative, if one is found.
  const AbstractFunctionDecl *Alternative = nullptr;

  AsyncHandlerParamDesc() : AsyncHandlerDesc() {}
  AsyncHandlerParamDesc(const AsyncHandlerDesc &Handler, const FuncDecl *Func,
                        unsigned Index,
                        const AbstractFunctionDecl *Alternative)
      : AsyncHandlerDesc(Handler), Func(Func), Index(Index),
        Alternative(Alternative) {}

  /// Looks for a completion handler parameter of \p FD. A default-constructed
  /// descriptor is returned if \p FD is null, is async, throws, doesn't
  /// return Void, or has no potential completion handler parameter. A
  /// completion-like handler name is only required if
  /// \p RequireAttributeOrName is set and \p FD has no async alternative.
  static AsyncHandlerParamDesc find(const FuncDecl *FD,
                                    bool RequireAttributeOrName) {
    if (!FD)
      return AsyncHandlerParamDesc();
    if (FD->hasAsync() || FD->hasThrows() ||
        !FD->getResultInterfaceType()->isVoid())
      return AsyncHandlerParamDesc();

    const auto *Alternative = FD->getAsyncAlternative();
    Optional<unsigned> Index =
        FD->findPotentialCompletionHandlerParam(Alternative);
    if (!Index)
      return AsyncHandlerParamDesc();

    bool RequireName = RequireAttributeOrName && !Alternative;
    return AsyncHandlerParamDesc(
        AsyncHandlerDesc::get(FD->getParameters()->get(*Index), RequireName),
        FD, *Index, Alternative);
  }

  /// Build an @available attribute with the name of the async alternative as
  /// the \c renamed argument, followed by a newline. The result is empty if
  /// this descriptor is invalid or an alternative is already known.
  SmallString<128> buildRenamedAttribute() const {
    SmallString<128> AvailabilityAttr;
    llvm::raw_svector_ostream OS(AvailabilityAttr);

    // If there's an alternative then there must already be an attribute,
    // don't add another.
    if (!isValid() || Alternative)
      return AvailabilityAttr;

    DeclName Name = Func->getName();
    OS << "@available(*, renamed: \"" << Name.getBaseName() << "(";
    ArrayRef<Identifier> ArgNames = Name.getArgumentNames();
    for (size_t I = 0, E = ArgNames.size(); I != E; ++I) {
      // The completion handler is not part of the renamed name.
      if (I == Index)
        continue;
      OS << ArgNames[I] << tok::colon;
    }
    OS << ")\")\n";

    return AvailabilityAttr;
  }

  /// Retrieves the parameter decl for the completion handler parameter, or
  /// \c nullptr if no valid completion parameter is present.
  const ParamDecl *getHandlerParam() const {
    return isValid() ? cast<ParamDecl>(getHandler()) : nullptr;
  }

  /// See \c Position
  Position handlerParamPosition() const {
    if (!isValid())
      return Position::None;
    const auto *Params = Func->getParameters();
    if (Params->size() == 1)
      return Position::Only;
    if (Index == 0)
      return Position::First;
    return Index == Params->size() - 1 ? Position::Last : Position::Middle;
  }

  /// Two descriptors are equal if handler, handler type, error flag and index
  /// agree.
  bool operator==(const AsyncHandlerParamDesc &Other) const {
    if (Handler != Other.Handler || Type != Other.Type)
      return false;
    return HasError == Other.HasError && Index == Other.Index;
  }

  /// Whether an async alternative is known and is an accessor.
  bool alternativeIsAccessor() const {
    return isa_and_nonnull<AccessorDecl>(Alternative);
  }
};

/// The type of a condition in a conditional statement.
enum class ConditionType {
  /// `== nil`
  NIL,
  /// `!= nil`
  NOT_NIL,
  /// `if b`
  IS_TRUE,
  /// `if !b`
  IS_FALSE,
  /// `case .success`
  SUCCESS_PATTERN,
  /// `case .failure`
  FAILURE_PATTEN
};

/// Says whether a condition leads to the success path or the failure path.
/// Checking that an error parameter is present is a failure path; checking
/// that it is nil is a success path. Unlike ConditionType, this depends on
/// contextual information about which values need to be checked for success
/// or failure.
enum class ConditionPath {
  SUCCESS,
  FAILURE
};

/// Returns the opposite of \p Path: success becomes failure and failure
/// becomes success.
static ConditionPath flippedConditionPath(ConditionPath Path) {
  // No default case, so every enumerator is listed explicitly.
  switch (Path) {
  case ConditionPath::SUCCESS:
    return ConditionPath::FAILURE;
  case ConditionPath::FAILURE:
    return ConditionPath::SUCCESS;
  }
  llvm_unreachable("Unhandled case in switch!");
}

/// Finds the `Subject` being compared to in various conditions. Also finds any
/// pattern that may have a bound name.
///
/// Each constructor inspects one syntactic form of condition. If the form is
/// recognised, `Type` and `Subject` are filled in (and `BindPattern` where a
/// pattern is involved); otherwise `Type` is left unset and `isValid()`
/// returns false.
struct CallbackCondition {
  /// The kind of check performed, unset if the condition wasn't recognised.
  Optional<ConditionType> Type;
  /// The declaration being checked.
  const Decl *Subject = nullptr;
  /// The pattern binding the unwrapped value of `Subject`, if there is one.
  const Pattern *BindPattern = nullptr;

  /// Initializes a `CallbackCondition` with a `!=` or `==` comparison of
  /// an `Optional` typed `Subject` to `nil`, or a `Bool` typed `Subject` to a
  /// boolean literal, ie.
  ///   - `<Subject> != nil`
  ///   - `<Subject> == nil`
  ///   - `<Subject> != true`
  ///   - `<Subject> == false`
  CallbackCondition(const BinaryExpr *BE, const FuncDecl *Operator) {
    bool FoundNil = false;
    BooleanLiteralExpr *FoundBool = nullptr;
    bool DidUnwrapOptional = false;

    // The subject and the literal may appear on either side of the operator.
    for (auto *Operand : {BE->getLHS(), BE->getRHS()}) {
      Operand = Operand->getSemanticsProvidingExpr();
      if (auto *IIOE = dyn_cast<InjectIntoOptionalExpr>(Operand)) {
        Operand = IIOE->getSubExpr()->getSemanticsProvidingExpr();
        DidUnwrapOptional = true;
      }
      if (isa<NilLiteralExpr>(Operand)) {
        FoundNil = true;
      } else if (auto *BLE = dyn_cast<BooleanLiteralExpr>(Operand)) {
        FoundBool = BLE;
      } else if (auto *DRE = dyn_cast<DeclRefExpr>(Operand)) {
        Subject = DRE->getDecl();
      }
    }

    if (!Subject)
      return;

    if (FoundNil) {
      if (Operator->getBaseName() == "==") {
        Type = ConditionType::NIL;
      } else if (Operator->getBaseName() == "!=") {
        Type = ConditionType::NOT_NIL;
      }
    } else if (FoundBool) {
      if (Operator->getBaseName() == "==") {
        Type = FoundBool->getValue() ? ConditionType::IS_TRUE
                                     : ConditionType::IS_FALSE;
      } else if (Operator->getBaseName() == "!=" && !DidUnwrapOptional) {
        // Note that we don't consider this case if we unwrapped an optional,
        // as e.g optBool != false is a check for true *or* nil.
        Type = FoundBool->getValue() ? ConditionType::IS_FALSE
                                     : ConditionType::IS_TRUE;
      }
    }
  }

  /// A bool condition expression, ie. `<Subject>` or `!<Subject>` where
  /// `<Subject>` is a direct reference to a declaration.
  explicit CallbackCondition(const Expr *E) {
    // FIXME: Sema should produce ErrorType.
    if (!E->getType() || !E->getType()->isBool())
      return;

    auto CondType = ConditionType::IS_TRUE;
    E = E->getSemanticsProvidingExpr();

    // If we have a prefix negation operator, this is a check for false.
    if (auto *PrefixOp = dyn_cast<PrefixUnaryExpr>(E)) {
      auto *Callee = PrefixOp->getCalledValue();
      if (Callee && Callee->isOperator() && Callee->getBaseName() == "!") {
        CondType = ConditionType::IS_FALSE;
        E = PrefixOp->getOperand()->getSemanticsProvidingExpr();
      }
    }

    auto *DRE = dyn_cast<DeclRefExpr>(E);
    if (!DRE)
      return;

    Subject = DRE->getDecl();
    Type = CondType;
  }

  /// Initializes a `CallbackCondition` with binding of an `Optional` or
  /// `Result` typed `Subject`, ie.
  ///   - `let bind = <Subject>`
  ///   - `case .success(let bind) = <Subject>`
  ///   - `case .failure(let bind) = <Subject>`
  ///   - `let bind = try? <Subject>.get()`
  CallbackCondition(const Pattern *P, const Expr *Init) {
    Init = Init->getSemanticsProvidingExpr();
    P = P->getSemanticsProvidingPattern();

    if (auto *DRE = dyn_cast<DeclRefExpr>(Init)) {
      if (auto *OSP = dyn_cast<OptionalSomePattern>(P)) {
        // `let bind = <Subject>`
        Type = ConditionType::NOT_NIL;
        Subject = DRE->getDecl();
        BindPattern = OSP->getSubPattern();
      } else if (auto *EEP = dyn_cast<EnumElementPattern>(P)) {
        // `case .<func>(let <bind>) = <Subject>`
        initFromEnumPattern(DRE->getDecl(), EEP);
      }
    } else if (auto *OTE = dyn_cast<OptionalTryExpr>(Init)) {
      // `let bind = try? <Subject>.get()`
      if (auto *OSP = dyn_cast<OptionalSomePattern>(P))
        initFromOptionalTry(OSP->getSubPattern(), OTE);
    }
  }

  /// Initializes a `CallbackCondition` from a case statement inside a switch
  /// on `Subject` with `Result` type, ie.
  /// ```
  /// switch <Subject> {
  /// case .success(let bind):
  /// case .failure(let bind):
  /// }
  /// ```
  CallbackCondition(const Decl *Subject, const CaseLabelItem *CaseItem) {
    if (auto *EEP = dyn_cast<EnumElementPattern>(
            CaseItem->getPattern()->getSemanticsProvidingPattern())) {
      // `case .<func>(let <bind>)`
      initFromEnumPattern(Subject, EEP);
    }
  }

  /// Whether the condition was recognised by the constructor that built it.
  bool isValid() const { return Type.has_value(); }

private:
  /// Fills in the condition from an enum element pattern matched against
  /// \p D. Only elements of `Result` are accepted: `failure` gives a failure
  /// pattern, any other element gives a success pattern.
  void initFromEnumPattern(const Decl *D, const EnumElementPattern *EEP) {
    if (auto *EED = EEP->getElementDecl()) {
      auto eedTy = EED->getParentEnum()->getDeclaredType();
      if (!eedTy || !eedTy->isResult())
        return;
      if (EED->getNameStr() == StringRef("failure")) {
        Type = ConditionType::FAILURE_PATTEN;
      } else {
        Type = ConditionType::SUCCESS_PATTERN;
      }
      Subject = D;
      BindPattern = EEP->getSubPattern();
    }
  }

  /// Fills in the condition from `try? <Subject>.get()`, where `<Subject>` is
  /// a direct reference to a `Result` typed declaration and \p P is the
  /// pattern binding the unwrapped value. Leaves the condition invalid for
  /// any other shape of expression.
  void initFromOptionalTry(const class Pattern *P, const OptionalTryExpr *OTE) {
    auto *ICE = dyn_cast<ImplicitConversionExpr>(OTE->getSubExpr());
    if (!ICE)
      return;
    auto *CE = dyn_cast<CallExpr>(ICE->getSyntacticSubExpr());
    if (!CE)
      return;
    auto *DSC = dyn_cast<DotSyntaxCallExpr>(CE->getFn());
    if (!DSC)
      return;

    // The base need not be a plain decl reference (e.g. `try? foo().get()`),
    // in which case dyn_cast yields null and there's no subject to record.
    auto *BaseDRE = dyn_cast<DeclRefExpr>(DSC->getBase());
    if (!BaseDRE || !BaseDRE->getType() || !BaseDRE->getType()->isResult())
      return;

    auto *FnDRE = dyn_cast<DeclRefExpr>(DSC->getFn());
    if (!FnDRE)
      return;
    auto *FD = dyn_cast<FuncDecl>(FnDRE->getDecl());
    if (!FD || FD->getNameStr() != StringRef("get"))
      return;

    Type = ConditionType::NOT_NIL;
    Subject = BaseDRE->getDecl();
    BindPattern = P;
  }
};

/// A CallbackCondition with additional semantic information about whether it
/// is for a success path or failure path.
struct ClassifiedCondition : public CallbackCondition {
  /// Whether the condition holding puts us on the success or failure path.
  ConditionPath Path;

  /// Whether this represents an Obj-C style boolean flag check for success.
  bool IsObjCStyleFlagCheck;

  /// Copies \p Cond into the base subobject and records its classification.
  explicit ClassifiedCondition(CallbackCondition Cond, ConditionPath Path,
                               bool IsObjCStyleFlagCheck)
      : CallbackCondition(Cond), Path(Path),
        IsObjCStyleFlagCheck(IsObjCStyleFlagCheck) {}
};

/// An insertion-ordered map from parameter decls to the classified condition
/// each one appears in.
struct ClassifiedCallbackConditions final
    : llvm::MapVector<const Decl *, ClassifiedCondition> {
  /// The classified condition recorded for \p D, or \c None if \p D is not
  /// present in any condition.
  Optional<ClassifiedCondition> lookup(const Decl *D) const {
    auto Found = find(D);
    if (Found != end())
      return Found->second;
    return None;
  }
};

/// A list of nodes to print, along with a list of locations that may have
/// preceding comments attached, which also need printing. For example:
///
/// \code
/// if .random() {
///   // a
///   print("hello")
///   // b
/// }
/// \endcode
///
/// To print out the contents of the if statement body, we'll include the AST
/// node for the \c print call. This will also include the preceding comment
/// \c a, but won't include the comment \c b. To ensure the comment \c b gets
/// printed, the SourceLoc for the closing brace \c } is added as a possible
/// comment loc.
class NodesToPrint {
  SmallVector<ASTNode, 0> Nodes;
  SmallVector<SourceLoc, 2> PossibleCommentLocs;

public:
  NodesToPrint() {}
  NodesToPrint(ArrayRef<ASTNode> Nodes, ArrayRef<SourceLoc> PossibleCommentLocs)
      : Nodes(Nodes.begin(), Nodes.end()),
        PossibleCommentLocs(PossibleCommentLocs.begin(),
                            PossibleCommentLocs.end()) {}

  ArrayRef<ASTNode> getNodes() const { return Nodes; }
  ArrayRef<SourceLoc> getPossibleCommentLocs() const {
    return PossibleCommentLocs;
  }

  /// Add an AST node to print.
  void addNode(ASTNode Node) {
    // Note we skip vars as they'll be printed as a part of their
    // PatternBindingDecl.
    if (!Node.isDecl(DeclKind::Var))
      Nodes.push_back(Node);
  }

  /// Add a SourceLoc which may have a preceding comment attached. If so, the
  /// comment will be printed out at the appropriate location.
  void addPossibleCommentLoc(SourceLoc Loc) {
    if (Loc.isValid())
      PossibleCommentLocs.push_back(Loc);
  }

  /// Add all the nodes in the brace statement to the list of nodes to print.
  /// This should be preferred over adding the nodes manually as it picks up the
  /// end location of the brace statement as a possible comment loc, ensuring
  /// that we print any trailing comments in the brace statement.
  void addNodesInBraceStmt(BraceStmt *Brace) {
    for (auto Node : Brace->getElements())
      addNode(Node);

    // Ignore the end locations of implicit braces, as they're likely bogus.
    // e.g for a case statement, the r-brace loc points to the last token of the
    // last node in the body.
    if (!Brace->isImplicit())
      addPossibleCommentLoc(Brace->getRBraceLoc());
  }

  /// Add the nodes and comment locs from another NodesToPrint.
  void addNodes(NodesToPrint OtherNodes) {
    Nodes.append(OtherNodes.Nodes.begin(), OtherNodes.Nodes.end());
    PossibleCommentLocs.append(OtherNodes.PossibleCommentLocs.begin(),
                               OtherNodes.PossibleCommentLocs.end());
  }

  /// Whether the last recorded node is an explicit return or break statement.
  bool hasTrailingReturnOrBreak() const {
    if (Nodes.empty())
      return false;
    return (Nodes.back().isStmt(StmtKind::Return) ||
            Nodes.back().isStmt(StmtKind::Break)) &&
           !Nodes.back().isImplicit();
  }

  /// If the last recorded node is an explicit return or break statement that
  /// can be safely dropped, drop it from the list.
  void dropTrailingReturnOrBreakIfPossible() {
    if (!hasTrailingReturnOrBreak())
      return;

    auto *Node = Nodes.back().get<Stmt *>();

    // If this is a return statement with return expression, let's preserve it.
    if (auto *RS = dyn_cast<ReturnStmt>(Node)) {
      if (RS->hasResult())
        return;
    }

    // Remove the node from the list, but make sure to add it as a possible
    // comment loc to preserve any of its attached comments.
    Nodes.pop_back();
    addPossibleCommentLoc(Node->getStartLoc());
  }

  /// Returns a list of nodes to print in a brace statement. This picks up the
  /// end location of the brace statement as a possible comment loc, ensuring
  /// that we print any trailing comments in the brace statement.
  static NodesToPrint inBraceStmt(BraceStmt *stmt) {
    NodesToPrint Nodes;
    Nodes.addNodesInBraceStmt(stmt);
    return Nodes;
  }
};

/// The statements within the closure of call to a function taking a callback
/// are split into a `SuccessBlock` and `ErrorBlock` (`ClassifiedBlocks`).
/// This class stores the nodes for each block, as well as a mapping of
/// decls to any patterns they are used in.
class ClassifiedBlock {
  NodesToPrint Nodes;

  // A mapping of closure params to a list of patterns that bind them.
  using ParamPatternBindingsMap =
      llvm::MapVector<const Decl *, TinyPtrVector<const Pattern *>>;
  ParamPatternBindingsMap ParamPatternBindings;

public:
  const NodesToPrint &nodesToPrint() const { return Nodes; }

  /// Attempt to retrieve an existing bound name for a closure parameter, or
  /// an empty string if there's no suitable existing binding.
  StringRef boundName(const Decl *D) const {
    // Adopt the same name as the representative single pattern, if it only
    // binds a single var.
    if (auto *P = getSinglePatternFor(D)) {
      if (P->getSingleVar())
        return P->getBoundName().str();
    }
    return StringRef();
  }

  /// Checks whether a closure parameter can be represented by a single pattern
  /// that binds it. If the param is only bound by a single pattern, that will
  /// be returned. If there's a pattern with a single var that binds it, that
  /// will be returned, preferring a 'let' pattern to prefer out of line
  /// printing of 'var' patterns.
  const Pattern *getSinglePatternFor(const Decl *D) const {
    auto Iter = ParamPatternBindings.find(D);
    if (Iter == ParamPatternBindings.end())
      return nullptr;

    const auto &Patterns = Iter->second;
    if (Patterns.empty())
      return nullptr;
    if (Patterns.size() == 1)
      return Patterns[0];

    // If we have multiple patterns, search for the best single var pattern to
    // use, preferring a 'let' binding.
    const Pattern *FirstSingleVar = nullptr;
    for (auto *P : Patterns) {
      if (!P->getSingleVar())
        continue;

      if (!P->hasAnyMutableBindings())
        return P;

      if (!FirstSingleVar)
        FirstSingleVar = P;
    }
    return FirstSingleVar;
  }

  /// Retrieve any bound vars that are effectively aliases of a given closure
  /// parameter.
  llvm::SmallDenseSet<const Decl *> getAliasesFor(const Decl *D) const {
    auto Iter = ParamPatternBindings.find(D);
    if (Iter == ParamPatternBindings.end())
      return {};

    llvm::SmallDenseSet<const Decl *> Aliases;

    // The single pattern that we replace the decl with is always an alias.
    if (auto *P = getSinglePatternFor(D)) {
      if (auto *SingleVar = P->getSingleVar())
        Aliases.insert(SingleVar);
    }

    // Any other let bindings we have are also aliases.
    for (auto *P : Iter->second) {
      if (auto *SingleVar = P->getSingleVar()) {
        if (!P->hasAnyMutableBindings())
          Aliases.insert(SingleVar);
      }
    }
    return Aliases;
  }

  const ParamPatternBindingsMap &paramPatternBindings() const {
    return ParamPatternBindings;
  }

  void addNodesInBraceStmt(BraceStmt *Brace) {
    Nodes.addNodesInBraceStmt(Brace);
  }
  void addPossibleCommentLoc(SourceLoc Loc) {
    Nodes.addPossibleCommentLoc(Loc);
  }
  void addAllNodes(NodesToPrint OtherNodes) {
    Nodes.addNodes(std::move(OtherNodes));
  }

  void addNode(ASTNode Node) {
    Nodes.addNode(Node);
  }

  void addBinding(const ClassifiedCondition &FromCondition) {
    auto *P = FromCondition.BindPattern;
    if (!P)
      return;

    // Patterns that don't bind anything aren't interesting.
    SmallVector<VarDecl *, 2> Vars;
    P->collectVariables(Vars);
    if (Vars.empty())
      return;

    ParamPatternBindings[FromCondition.Subject].push_back(P);
  }

  void addAllBindings(const ClassifiedCallbackConditions &FromConditions) {
    for (auto &Entry : FromConditions)
      addBinding(Entry.second);
  }
};

/// The kind of block that rewritten code may be placed in.
enum class BlockKind {
  SUCCESS,
  ERROR,
  FALLBACK
};

/// A completion handler function parameter that is known to be a Bool flag
/// indicating success or failure.
struct KnownBoolFlagParam {
  /// The closure parameter that carries the flag.
  const ParamDecl *Param;
  /// Whether a true value of the flag indicates success rather than failure.
  bool IsSuccessFlag;
};

/// A set of parameters for a completion callback closure.
class ClosureCallbackParams final {
  const AsyncHandlerParamDesc &HandlerDesc;
  ArrayRef<const ParamDecl *> AllParams;
  llvm::SetVector<const ParamDecl *> SuccessParams;
  const ParamDecl *ErrParam = nullptr;
  Optional<KnownBoolFlagParam> BoolFlagParam;

public:
  ClosureCallbackParams(const AsyncHandlerParamDesc &HandlerDesc,
                        const ClosureExpr *Closure)
      : HandlerDesc(HandlerDesc),
        AllParams(Closure->getParameters()->getArray()) {
    assert(AllParams.size() == HandlerDesc.params().size());
    assert(HandlerDesc.Type != HandlerType::RESULT || AllParams.size() == 1);

    SuccessParams.insert(AllParams.begin(), AllParams.end());
    if (HandlerDesc.HasError && HandlerDesc.Type == HandlerType::PARAMS)
      ErrParam = SuccessParams.pop_back_val();

    // Check to see if we have a known bool flag parameter.
    if (auto *AsyncAlt = HandlerDesc.Func->getAsyncAlternative()) {
      if (auto Conv = AsyncAlt->getForeignAsyncConvention()) {
        auto FlagIdx = Conv->completionHandlerFlagParamIndex();
        if (FlagIdx && *FlagIdx >= 0 && *FlagIdx < AllParams.size()) {
          auto IsSuccessFlag = Conv->completionHandlerFlagIsErrorOnZero();
          BoolFlagParam = {AllParams[*FlagIdx], IsSuccessFlag};
        }
      }
    }
  }

  /// Whether the closure has a particular parameter.
  bool hasParam(const ParamDecl *Param) const {
    return Param == ErrParam || SuccessParams.contains(Param);
  }

  /// Whether \p Param is a success param.
  bool isSuccessParam(const ParamDecl *Param) const {
    return SuccessParams.contains(Param);
  }

  /// Whether \p Param is a closure parameter that may be unwrapped. This
  /// includes optional parameters as well as \c Result parameters that may be
  /// unwrapped through e.g 'try? res.get()'.
  bool isUnwrappableParam(const ParamDecl *Param) const {
    if (!hasParam(Param))
      return false;
    if (getResultParam() == Param)
      return true;
    return HandlerDesc.shouldUnwrap(Param->getType());
  }

  /// Whether \p Param is the known Bool parameter that indicates success or
  /// failure.
  bool isKnownBoolFlagParam(const ParamDecl *Param) const {
    if (auto BoolFlag = getKnownBoolFlagParam())
      return BoolFlag->Param == Param;
    return false;
  }

  /// Whether \p Param is a closure parameter that has a binding available in
  /// the async variant of the call for a particular \p Block.
  bool hasBinding(const ParamDecl *Param, BlockKind Block) const {
    switch (Block) {
    case BlockKind::SUCCESS:
      // Known bool flags get dropped from the imported async variant.
      if (isKnownBoolFlagParam(Param))
        return false;

      return isSuccessParam(Param);
    case BlockKind::ERROR:
      return Param == ErrParam;
    case BlockKind::FALLBACK:
      // We generally want to bind everything in the fallback case.
      return hasParam(Param);
    }
    llvm_unreachable("Unhandled case in switch");
  }

  /// Retrieve the parameters to bind in a given \p Block.
  TinyPtrVector<const ParamDecl *> getParamsToBind(BlockKind Block) {
    TinyPtrVector<const ParamDecl *> Result;
    for (auto *Param : AllParams) {
      if (hasBinding(Param, Block))
        Result.push_back(Param);
    }
    return Result;
  }

  /// If there is a known Bool flag parameter indicating success or failure,
  /// returns it, \c None otherwise.
  Optional<KnownBoolFlagParam> getKnownBoolFlagParam() const {
    return BoolFlagParam;
  }

  /// All the parameters of the closure passed as the completion handler.
  ArrayRef<const ParamDecl *> getAllParams() const { return AllParams; }

  /// The success parameters of the closure passed as the completion handler.
  /// Note this includes a \c Result parameter.
  ArrayRef<const ParamDecl *> getSuccessParams() const {
    return SuccessParams.getArrayRef();
  }

  /// The error parameter of the closure passed as the completion handler, or
  /// \c nullptr if there is no error parameter.
  const ParamDecl *getErrParam() const { return ErrParam; }

  /// If the closure has a single \c Result parameter, returns it, \c nullptr
  /// otherwise.
  const ParamDecl *getResultParam() const {
    return HandlerDesc.Type == HandlerType::RESULT ? SuccessParams[0] : nullptr;
  }
};

/// Whether the given statement opens a new scope. Most statements are covered
/// by the \c BraceStmt check; the remaining kinds are listed separately as a
/// somewhat special case because they can also declare variables in their
/// condition.
static bool startsNewScope(Stmt *S) {
  const StmtKind Kind = S->getKind();
  return Kind == StmtKind::Brace || Kind == StmtKind::If ||
         Kind == StmtKind::While || Kind == StmtKind::ForEach ||
         Kind == StmtKind::Case;
}

/// The pair of blocks that the statements of a callback closure are split
/// into.
struct ClassifiedBlocks {
  /// The nodes and bindings classified as the success block.
  ClassifiedBlock SuccessBlock;
  /// The nodes and bindings classified as the error block.
  ClassifiedBlock ErrorBlock;
};

/// Classifier of callback closure statements that have either multiple
/// non-Result parameters or a single Result parameter and return Void.
///
/// It performs a (possibly incorrect) best effort and may give up in certain
/// cases. Aims to cover the idiomatic cases of either having no error
/// parameter at all, or having success/error code wrapped in ifs/guards/switch
/// using either pattern binding or nil checks.
///
/// Code outside any clear conditions is assumed to be solely part of the
/// success block for now, though some heuristics could be added to classify
/// these better in the future.
struct CallbackClassifier {
  /// Fills the success and error block of `Blocks` with the nodes and bound
  /// names found in `Body`, which must not be empty. Errors are reported
  /// through `DiagEngine`, in which case the blocks may be left only partially
  /// filled out.
  static void classifyInto(ClassifiedBlocks &Blocks,
                           const ClosureCallbackParams &Params,
                           llvm::DenseSet<SwitchStmt *> &HandledSwitches,
                           DiagnosticEngine &DiagEngine, BraceStmt *Body) {
    const auto Elements = Body->getElements();
    assert(!Elements.empty() && "Cannot classify empty body");
    CallbackClassifier(Blocks, Params, HandledSwitches, DiagEngine)
        .classifyNodes(Elements, Body->getRBraceLoc());
  }

private:
  /// The blocks being filled in.
  ClassifiedBlocks &Blocks;
  /// The parameters of the closure being classified.
  const ClosureCallbackParams &Params;
  /// NOTE(review): presumably the switch statements dealt with by
  /// \c classifySwitch, which is defined outside this view; confirm there.
  llvm::DenseSet<SwitchStmt *> &HandledSwitches;
  /// Queried after each node to decide whether classification should stop.
  DiagnosticEngine &DiagEngine;
  /// The block that nodes without custom classification get added to.
  ClassifiedBlock *CurrentBlock;

  /// This is set to \c true if we're currently classifying on a known condition
  /// path, where \c CurrentBlock is set to the appropriate block. This lets us
  /// be more lenient with unhandled conditions as we already know the block
  /// we're supposed to be in.
  bool IsKnownConditionPath = false;

  /// Creates a classifier whose current block starts out as the success block.
  CallbackClassifier(ClassifiedBlocks &Blocks,
                     const ClosureCallbackParams &Params,
                     llvm::DenseSet<SwitchStmt *> &HandledSwitches,
                     DiagnosticEngine &DiagEngine)
      : Blocks(Blocks), Params(Params), HandledSwitches(HandledSwitches),
        DiagEngine(DiagEngine), CurrentBlock(&Blocks.SuccessBlock) {}

  /// Try to run the custom classification logic on a node. Returns \c true if
  /// the node was classified, \c false if the caller still has to deal with
  /// it.
  bool tryClassifyNode(ASTNode Node) {
    auto *S = Node.dyn_cast<Stmt *>();
    if (!S)
      return false;

    if (auto *If = dyn_cast<IfStmt>(S)) {
      // A braced then-body contributes its elements, anything else is taken
      // as a single node with no comment locs.
      auto *Then = If->getThenStmt();
      auto *ThenBrace = dyn_cast<BraceStmt>(Then);
      NodesToPrint ThenNodes = ThenBrace
                                   ? NodesToPrint::inBraceStmt(ThenBrace)
                                   : NodesToPrint({Then}, /*commentLocs*/ {});

      classifyConditional(If, If->getCond(), std::move(ThenNodes),
                          If->getElseStmt());
      return true;
    }

    if (auto *Guard = dyn_cast<GuardStmt>(S)) {
      classifyConditional(Guard, Guard->getCond(), NodesToPrint(),
                          Guard->getBody());
      return true;
    }

    if (auto *Switch = dyn_cast<SwitchStmt>(S)) {
      classifySwitch(Switch);
      return true;
    }

    // The only other thing we look through is an implicit Void return of a
    // SingleValueStmtExpr, as that's semantically a statement.
    auto *Return = dyn_cast<ReturnStmt>(S);
    if (!Return || !Return->hasResult() || !Return->isImplicit())
      return false;

    auto *Result = Return->getResult();
    auto ResultTy = Result->getType();
    if (!ResultTy || !ResultTy->isVoid())
      return false;

    if (auto *SVE = dyn_cast<SingleValueStmtExpr>(Result))
      return tryClassifyNode(SVE->getStmt());
    return false;
  }

  /// Classify a node, falling back to adding it to the current block when no
  /// custom classification applies. Returns \c true if there was an error.
  bool classifyNode(ASTNode Node) {
    if (!tryClassifyNode(Node))
      CurrentBlock->addNode(Node);
    return DiagEngine.hadAnyError();
  }

  /// Classify each of \p Nodes in order, stopping at the first error. If all
  /// nodes were classified without error, \p EndCommentLoc is recorded on the
  /// current block as a possible comment loc.
  void classifyNodes(ArrayRef<ASTNode> Nodes, SourceLoc EndCommentLoc) {
    for (auto Node : Nodes) {
      if (classifyNode(Node))
        return;
    }
    // Pick up any trailing comments.
    CurrentBlock->addPossibleCommentLoc(EndCommentLoc);
  }

  /// Whether any of the given ASTNodes contains a child expression that force
  /// unwraps the error parameter. New scopes are not walked into. Always
  /// \c false when the closure has no error parameter.
  bool hasForceUnwrappedErrorParam(ArrayRef<ASTNode> Nodes) {
    auto *ErrParam = Params.getErrParam();
    if (!ErrParam)
      return false;

    class ErrUnwrapFinder : public ASTWalker {
      const ParamDecl *ErrParam;
      bool FoundUnwrap = false;

    public:
      explicit ErrUnwrapFinder(const ParamDecl *ErrParam)
          : ErrParam(ErrParam) {}
      bool foundUnwrap() const { return FoundUnwrap; }

      MacroWalking getMacroWalkingBehavior() const override {
        return MacroWalking::Arguments;
      }

      PreWalkResult<Expr *> walkToExprPre(Expr *E) override {
        // Ternary conditionals are skipped, as they may carry additional
        // conditions such as err != nil that make a force unwrap valid.
        if (isa<TernaryExpr>(E))
          return Action::SkipChildren(E);

        // A force unwrap applied directly to a reference to the error param
        // is what we're after. Note it and abort the walk.
        if (auto *FVE = dyn_cast<ForceValueExpr>(E)) {
          if (auto *DRE = dyn_cast<DeclRefExpr>(FVE->getSubExpr())) {
            if (DRE->getDecl() == ErrParam) {
              FoundUnwrap = true;
              return Action::Stop();
            }
          }
        }
        return Action::Continue(E);
      }

      PreWalkResult<Stmt *> walkToStmtPre(Stmt *S) override {
        // Only force unwraps in the immediate conditional body count, so stay
        // out of new explicit scopes.
        const bool IsExplicitScope = !S->isImplicit() && startsNewScope(S);
        if (IsExplicitScope)
          return Action::SkipChildren(S);
        return Action::Continue(S);
      }

      PreWalkAction walkToDeclPre(Decl *D) override {
        // Stay out of new explicit DeclContexts.
        const bool IsExplicitContext = !D->isImplicit() && isa<DeclContext>(D);
        return Action::VisitChildrenIf(!IsExplicitContext);
      }
    };

    // Each node gets its own walker; the first hit settles the answer.
    for (auto Node : Nodes) {
      ErrUnwrapFinder Finder(ErrParam);
      Node.walk(Finder);
      if (Finder.foundUnwrap())
        return true;
    }
    return false;
  }

  /// Given a callback condition, classify it as a success or failure path.
  ///
  /// \param Cond the condition to classify.
  /// \param SuccessNodes the nodes executed when the condition holds, searched
  /// for force unwraps of the error parameter by the bool flag heuristic.
  /// \param ElseStmt the statement executed when the condition doesn't hold,
  /// or \c nullptr; searched in the same way.
  /// \returns the classified condition, or \c None if the condition is
  /// invalid, binds a refutable pattern, doesn't have a suitable closure
  /// parameter as its subject, or is a bool check that can't be determined to
  /// be a success/failure flag check.
  Optional<ClassifiedCondition>
  classifyCallbackCondition(const CallbackCondition &Cond,
                            const NodesToPrint &SuccessNodes, Stmt *ElseStmt) {
    if (!Cond.isValid())
      return None;

    // If the condition involves a refutable pattern, we can't currently handle
    // it.
    if (Cond.BindPattern && Cond.BindPattern->isRefutablePattern())
      return None;

    auto *SubjectParam = dyn_cast<ParamDecl>(Cond.Subject);
    if (!SubjectParam)
      return None;

    // For certain types of condition, they need to be certain kinds of params.
    auto CondType = *Cond.Type;
    switch (CondType) {
    case ConditionType::NOT_NIL:
    case ConditionType::NIL:
      if (!Params.isUnwrappableParam(SubjectParam))
        return None;
      break;
    case ConditionType::IS_TRUE:
    case ConditionType::IS_FALSE:
      if (!Params.isSuccessParam(SubjectParam))
        return None;
      break;
    case ConditionType::SUCCESS_PATTERN:
    case ConditionType::FAILURE_PATTEN:
      if (SubjectParam != Params.getResultParam())
        return None;
      break;
    }

    // Let's start with a success path, and flip any negative conditions.
    auto Path = ConditionPath::SUCCESS;

    // If it's an error param, that's a flip.
    if (SubjectParam == Params.getErrParam())
      Path = flippedConditionPath(Path);

    // If we have a nil, false, or failure condition, that's a flip.
    switch (CondType) {
    case ConditionType::NIL:
    case ConditionType::IS_FALSE:
    case ConditionType::FAILURE_PATTEN:
      Path = flippedConditionPath(Path);
      break;
    case ConditionType::IS_TRUE:
    case ConditionType::NOT_NIL:
    case ConditionType::SUCCESS_PATTERN:
      break;
    }

    // If we have a bool condition, it could be an Obj-C style flag check, which
    // we do some extra checking for. Otherwise, we're done.
    if (CondType != ConditionType::IS_TRUE &&
        CondType != ConditionType::IS_FALSE) {
      return ClassifiedCondition(Cond, Path, /*ObjCFlagCheck*/ false);
    }

    // Check to see if we have a known bool flag parameter that indicates
    // success or failure.
    if (auto KnownBoolFlag = Params.getKnownBoolFlagParam()) {
      // A bool check on any other parameter isn't handled.
      if (KnownBoolFlag->Param != SubjectParam)
        return None;

      // The path may need to be flipped depending on whether the flag indicates
      // success.
      if (!KnownBoolFlag->IsSuccessFlag)
        Path = flippedConditionPath(Path);

      return ClassifiedCondition(Cond, Path, /*ObjCStyleFlagCheck*/ true);
    }

    // If we've reached here, we have a bool flag check that isn't specified in
    // the async convention. We apply a heuristic to see if the error param is
    // force unwrapped in the conditional body. In that case, the user is
    // expecting it to be the error path, and it's more likely than not that the
    // flag value conveys no more useful information in the error block.

    // First check the success block.
    auto FoundInSuccessBlock =
        hasForceUnwrappedErrorParam(SuccessNodes.getNodes());

    // Then check the else block if we have it.
    if (ASTNode ElseNode = ElseStmt) {
      // Unwrap the BraceStmt of the else clause if needed. This is needed as
      // we won't walk into BraceStmts by default as they introduce new
      // scopes.
      ArrayRef<ASTNode> Nodes;
      if (auto *BS = dyn_cast<BraceStmt>(ElseStmt)) {
        Nodes = BS->getElements();
      } else {
        Nodes = llvm::makeArrayRef(ElseNode);
      }
      if (hasForceUnwrappedErrorParam(Nodes)) {
        // If we also found an unwrap in the success block, we don't know what's
        // happening here.
        if (FoundInSuccessBlock)
          return None;

        // Otherwise we can determine this as a success condition. Note this is
        // flipped as if the error is present in the else block, this condition
        // is for success.
        return ClassifiedCondition(Cond, ConditionPath::SUCCESS,
                                   /*ObjCStyleFlagCheck*/ true);
      }
    }

    if (FoundInSuccessBlock) {
      // Note that the path is flipped as if the error is present in the success
      // block, this condition is for failure.
      return ClassifiedCondition(Cond, ConditionPath::FAILURE,
                                 /*ObjCStyleFlagCheck*/ true);
    }

    // Otherwise we can't classify this.
    return None;
  }

  /// Classifies every condition of \p Cond, consulting the success body
  /// \p ThenNodesToPrint and the failure body \p ElseStmt, and records each
  /// classified condition in \p Conditions keyed by its subject. Boolean
  /// conditions joined with \c && are decomposed and classified one by one.
  ///
  /// \returns \c true if any condition couldn't be classified or if
  /// \p Conditions ends up empty, \c false otherwise.
  bool classifyConditionsOf(StmtCondition Cond,
                            const NodesToPrint &ThenNodesToPrint,
                            Stmt *ElseStmt,
                            ClassifiedCallbackConditions &Conditions) {
    bool SawUnhandled = false;
    Optional<ClassifiedCondition> SeenFlagCheck;

    // Classifies a single condition and stores it, marking the statement as
    // unhandled whenever that isn't possible.
    auto Record = [&](CallbackCondition CC) {
      auto Result = classifyCallbackCondition(CC, ThenNodesToPrint, ElseStmt);

      // Nothing we can do with a condition we couldn't classify.
      if (!Result) {
        SawUnhandled = true;
        return;
      }

      // More than one Obj-C style flag check is also unhandled.
      if (SeenFlagCheck && Result->IsObjCStyleFlagCheck) {
        SawUnhandled = true;
        return;
      }

      // A second condition on an already-seen subject is unhandled too.
      if (!Conditions.insert({CC.Subject, *Result}).second) {
        SawUnhandled = true;
        return;
      }

      if (Result->IsObjCStyleFlagCheck)
        SeenFlagCheck = Result;
    };

    for (auto &Element : Cond) {
      auto *BoolExpr = Element.getBooleanOrNull();
      if (!BoolExpr) {
        // Not a boolean condition; only pattern conditions are of interest.
        if (auto *P = Element.getPatternOrNull())
          Record(CallbackCondition(P, Element.getInitializer()));
        continue;
      }

      // Walk the boolean expression with an explicit stack so that chains of
      // '&&' are flattened into their individual operands.
      SmallVector<Expr *, 1> Worklist;
      Worklist.push_back(BoolExpr);

      while (!Worklist.empty()) {
        auto *Current = Worklist.pop_back_val()->getSemanticsProvidingExpr();
        if (auto *ACE = dyn_cast<AutoClosureExpr>(Current))
          Current =
              ACE->getSingleExpressionBody()->getSemanticsProvidingExpr();

        auto *BE = dyn_cast_or_null<BinaryExpr>(Current);
        auto *Operator = BE ? isOperator(BE) : nullptr;
        if (!Operator) {
          // Not an operator application: treat it as a lone bool condition.
          Record(CallbackCondition(Current));
          continue;
        }

        if (Operator->getBaseName() == "&&") {
          // Decompose the conjunction into its two sides.
          Worklist.push_back(BE->getLHS());
          Worklist.push_back(BE->getRHS());
        } else {
          // Any other operator may be an `== nil` / `!= nil` check.
          Record(CallbackCondition(BE, Operator));
        }
      }
    }
    return SawUnhandled || Conditions.empty();
  }

  /// Classifies the conditions of the conditional \p Statement and distributes
  /// its bodies between the success and error blocks.
  ///
  /// \param Statement The conditional statement being classified.
  /// \param Condition The conditions of \p Statement.
  /// \param ThenNodesToPrint The nodes executed when \p Condition holds.
  /// \param ElseStmt The statement executed otherwise, or \c nullptr.
  void classifyConditional(Stmt *Statement, StmtCondition Condition,
                           NodesToPrint ThenNodesToPrint, Stmt *ElseStmt) {
    ClassifiedCallbackConditions Classified;
    const bool HasUnhandled = classifyConditionsOf(Condition, ThenNodesToPrint,
                                                   ElseStmt, Classified);
    auto ErrCond = Classified.lookup(Params.getErrParam());

    if (HasUnhandled) {
      // Some conditions are unknown. With an else we give up and diagnose;
      // without one the whole statement is placed into a block chosen by a
      // few heuristics, leaving known conditions to become placeholders
      // (ideally they'd be removed).
      // TODO: Remove known conditions and split the `if` statement

      // On a known condition path we already know which block we're in, so
      // be lenient and add the conditional directly. When nothing was
      // classified do the same; technically the else could still contain
      // conditions belonging to either success or error.
      if (IsKnownConditionPath || Classified.empty()) {
        CurrentBlock->addNode(Statement);
        return;
      }

      if (ElseStmt) {
        DiagEngine.diagnose(Statement->getStartLoc(),
                            diag::unknown_callback_conditions);
        return;
      }

      // Any failure condition, starting with the error parameter's own, sends
      // the statement to the error block.
      bool OnFailurePath = ErrCond && ErrCond->Path == ConditionPath::FAILURE;
      if (!OnFailurePath) {
        for (auto &Entry : Classified) {
          if (Entry.second.Path == ConditionPath::FAILURE) {
            OnFailurePath = true;
            break;
          }
        }
      }
      auto &Destination =
          OnFailurePath ? Blocks.ErrorBlock : Blocks.SuccessBlock;
      Destination.addNode(Statement);
      return;
    }

    // Everything was classified; check that all conditions agree on a single
    // path.
    Optional<ConditionPath> Path;
    for (auto &Entry : Classified) {
      auto EntryPath = Entry.second.Path;
      if (!Path) {
        Path = EntryPath;
        continue;
      }
      if (*Path == EntryPath)
        continue;

      // Mixed paths are treated like unknown conditions: keep the whole `if`
      // unless there's an else, in which case diagnose instead.
      // TODO: Split the `if` statement
      if (ElseStmt) {
        DiagEngine.diagnose(Statement->getStartLoc(),
                            diag::mixed_callback_conditions);
      } else {
        CurrentBlock->addNode(Statement);
      }
      return;
    }
    assert(Path && "Didn't classify a path?");

    // A failure condition means the 'then' body is the error handling and the
    // 'else' body is the success handling; otherwise it's the other way round.
    const bool IsFailure = *Path == ConditionPath::FAILURE;
    auto *ThenBlock = IsFailure ? &Blocks.ErrorBlock : &Blocks.SuccessBlock;
    auto *ElseBlock = IsFailure ? &Blocks.SuccessBlock : &Blocks.ErrorBlock;

    // The statement itself goes away, but its attached comments should stay.
    CurrentBlock->addPossibleCommentLoc(Statement->getStartLoc());

    ThenBlock->addAllBindings(Classified);

    // TODO: Handle nested ifs
    setNodes(ThenBlock, ElseBlock, std::move(ThenNodesToPrint));

    if (!ElseStmt)
      return;

    if (auto *BS = dyn_cast<BraceStmt>(ElseStmt)) {
      // A guard's else always exits, which lets any following nodes be
      // classified into the opposite block.
      const bool AlwaysExits = isa<GuardStmt>(Statement);
      setNodes(ElseBlock, ThenBlock, NodesToPrint::inBraceStmt(BS),
               AlwaysExits);
      return;
    }

    // Otherwise this should be an `else if`. We're in the else of a known
    // condition, so temporarily switch the current block and record that the
    // path is known while classifying it.
    llvm::SaveAndRestore<bool> CondScope(IsKnownConditionPath, true);
    llvm::SaveAndRestore<ClassifiedBlock *> BlockScope(CurrentBlock,
                                                       ElseBlock);
    classifyNodes(ArrayRef<ASTNode>(ElseStmt),
                  /*endCommentLoc*/ SourceLoc());
  }

  /// Moves \p Nodes into \p Block. When the nodes are known to leave the scope
  /// (they end in a 'return'/'break', or \p AlwaysExitsScope is set), the
  /// current block is switched to \p OtherBlock and the condition path is
  /// marked as known.
  ///
  /// \param Block The block to add the nodes to.
  /// \param OtherBlock The block for the opposing condition path.
  /// \param Nodes The nodes to add.
  /// \param AlwaysExitsScope Whether the nodes being added always exit the
  /// scope, and therefore whether the current block should be flipped.
  void setNodes(ClassifiedBlock *Block, ClassifiedBlock *OtherBlock,
                NodesToPrint Nodes, bool AlwaysExitsScope = false) {
    // An explicit trailing 'return' or 'break' is dropped where possible.
    const bool EndsInReturnOrBreak = Nodes.hasTrailingReturnOrBreak();
    if (EndsInReturnOrBreak)
      Nodes.dropTrailingReturnOrBreakIfPossible();

    // Leaving the scope means any later nodes belong to the other block, so
    // the condition path is known from here on.
    const bool ExitsScope = EndsInReturnOrBreak || AlwaysExitsScope;
    if (ExitsScope) {
      CurrentBlock = OtherBlock;
      IsKnownConditionPath = true;
    }

    Block->addAllNodes(std::move(Nodes));
  }

  /// Classifies a \c switch statement. A switch whose single subject isn't the
  /// result parameter is added to the current block untouched. Otherwise the
  /// body of each case is moved to the success or error block according to how
  /// its single case item classifies, and the switch is recorded in
  /// \c HandledSwitches once every case has been processed.
  ///
  /// A diagnostic is emitted and classification stops at the first case that
  /// has a fallthrough destination, is a default, has several case items, has
  /// a where clause, or can't be classified.
  void classifySwitch(SwitchStmt *SS) {
    auto *ResultParam = Params.getResultParam();
    if (singleSwitchSubject(SS) != ResultParam) {
      CurrentBlock->addNode(SS);
      return;
    }

    // The switch is dropped, so hold on to any comments attached to it.
    CurrentBlock->addPossibleCommentLoc(SS->getStartLoc());

    // Materialise the AsCaseStmtRange sequence eagerly purely so that the
    // final case can be identified.
    SmallVector<CaseStmt *, 2> Cases;
    Cases.append(SS->getCases().begin(), SS->getCases().end());

    for (size_t Idx = 0, NumCases = Cases.size(); Idx != NumCases; ++Idx) {
      auto *CS = Cases[Idx];

      if (CS->hasFallthroughDest()) {
        DiagEngine.diagnose(CS->getLoc(), diag::callback_with_fallthrough);
        return;
      }

      if (CS->isDefault()) {
        DiagEngine.diagnose(CS->getLoc(), diag::callback_with_default);
        return;
      }

      auto Items = CS->getCaseLabelItems();
      if (Items.size() > 1) {
        DiagEngine.diagnose(CS->getLoc(), diag::callback_multiple_case_items);
        return;
      }

      if (Items[0].getWhereLoc().isValid()) {
        DiagEngine.diagnose(CS->getLoc(), diag::callback_where_case_item);
        return;
      }

      auto CaseNodes = NodesToPrint::inBraceStmt(CS->getBody());

      // Work out which path the case pattern belongs to.
      auto CC = classifyCallbackCondition(
          CallbackCondition(ResultParam, &Items[0]), CaseNodes,
          /*elseStmt*/ nullptr);
      if (!CC) {
        DiagEngine.diagnose(CS->getLoc(), diag::unknown_callback_case_item);
        return;
      }

      // Failure cases feed the error block, everything else the success one.
      const bool IsFailure = CC->Path == ConditionPath::FAILURE;
      auto *Block = IsFailure ? &Blocks.ErrorBlock : &Blocks.SuccessBlock;
      auto *OtherBlock = IsFailure ? &Blocks.SuccessBlock : &Blocks.ErrorBlock;

      // The case is dropped as well; its comments effectively belong to the
      // previous case, hence they go to CurrentBlock.
      CurrentBlock->addPossibleCommentLoc(CS->getStartLoc());

      // The last case also picks up comments trailing before the closing
      // brace of the switch.
      const bool IsLastCase = Idx + 1 == NumCases;
      if (IsLastCase)
        Block->addPossibleCommentLoc(SS->getRBraceLoc());

      setNodes(Block, OtherBlock, std::move(CaseNodes));
      Block->addBinding(*CC);
    }

    // Every case was handled: remember that this switch was transformed.
    HandledSwitches.insert(SS);
  }
};

/// Returns the base name of \p D when it is a \c ValueDecl that has a name,
/// and an empty \c DeclBaseName in every other case.
static DeclBaseName getDeclName(const Decl *D) {
  auto *VD = dyn_cast<ValueDecl>(D);
  if (!VD || !VD->hasName())
    return DeclBaseName();
  return VD->getBaseName();
}

/// Gathers the explicit declarations of a scope, without descending into
/// closures or into statements that start a scope of their own.
class DeclCollector : private SourceEntityWalker {
  llvm::DenseSet<const Decl *> &Decls;

public:
  /// Collect all explicit declarations declared in \p Scope (or \p SF if
  /// \p Scope is a nullptr) that are not within their own scope, adding them
  /// to \p Decls.
  static void collect(BraceStmt *Scope, SourceFile &SF,
                      llvm::DenseSet<const Decl *> &Decls) {
    DeclCollector Collector(Decls);
    if (!Scope) {
      Collector.walk(SF);
      return;
    }
    for (auto Element : Scope->getElements())
      Collector.walk(Element);
  }

private:
  DeclCollector(llvm::DenseSet<const Decl *> &Decls) : Decls(Decls) {}

  bool walkToDeclPre(Decl *D, CharSourceRange Range) override {
    // Top level code decls (implicitly added around top level non-decl code)
    // and pattern binding decls (which hold the var decls we're after) are
    // looked through rather than recorded.
    const bool IsTransparent =
        isa<TopLevelCodeDecl>(D) || isa<PatternBindingDecl>(D);
    if (IsTransparent)
      return true;

    // Record explicit decls only, and never look inside them.
    if (!D->isImplicit())
      Decls.insert(D);
    return false;
  }

  bool walkToExprPre(Expr *E) override {
    // Closures are not walked into.
    if (isa<ClosureExpr>(E))
      return false;
    return true;
  }

  bool walkToStmtPre(Stmt *S) override {
    // Implicit statements are always walked; explicit ones only when they
    // don't start a new scope.
    if (S->isImplicit())
      return true;
    return !startsNewScope(S);
  }
};

class ReferenceCollector : private SourceEntityWalker {
  SourceManager *SM;
  llvm::DenseSet<const Decl *> DeclaredDecls;
  llvm::DenseSet<const Decl *> &ReferencedDecls;

  ASTNode Target;
  bool AfterTarget;

public:
  /// Collect all explicit references in \p Scope (or \p SF if \p Scope is
  /// a nullptr) that are after \p Target and not first declared. That is,
  /// references that we don't want to shadow with hoisted declarations.
  ///
  /// Also collect all declarations that are \c DeclContexts, which is an
  /// over-appoximation but let's us ignore them elsewhere.
  static void collect(ASTNode Target, BraceStmt *Scope, SourceFile &SF,
                      llvm::DenseSet<const Decl *> &Decls) {
    ReferenceCollector Collector(Target, &SF.getASTContext().SourceMgr,
                                 Decls);
    if (Scope)
      Collector.walk(Scope);
    else
      Collector.walk(SF);
  }

private:
  ReferenceCollector(ASTNode Target, SourceManager *SM,
                     llvm::DenseSet<const Decl *> &Decls)
      : SM(SM), DeclaredDecls(), ReferencedDecls(Decls), Target(Target),
        AfterTarget(false) {}

  bool walkToDeclPre(Decl *D, CharSourceRange Range) override {
    // Bit of a hack, include all contexts so they're never renamed (seems worse
    // to rename a class/function than it does a variable). Again, an
    // over-approximation, but hopefully doesn't come up too often.
    if (isa<DeclContext>(D) && !D->isImplicit()) {
      ReferencedDecls.insert(D);
    }

    if (AfterTarget && !D->isImplicit()) {
      DeclaredDecls.insert(D);
    } else if (D == Target.dyn_cast<Decl *>()) {
      AfterTarget = true;
    }
    return shouldWalkInto(D->getSourceRange());
  }

  bool walkToExprPre(Expr *E) override {
    if (AfterTarget && !E->isImplicit()) {
      if (auto *DRE = dyn_cast<DeclRefExpr>(E)) {
        if (auto *D = DRE->getDecl()) {
          // Only care about references that aren't declared, as seen decls will
          // be renamed (if necessary) during the refactoring.
          if (!D->isImplicit() && !DeclaredDecls.count(D)) {
            ReferencedDecls.insert(D);

            // Also add the async alternative of a function to prevent
            // collisions if a call is replaced with the alternative.
            if (auto *AFD = dyn_cast<AbstractFunctionDecl>(D)) {
              if (auto *Alternative = AFD->getAsyncAlternative())
                ReferencedDecls.insert(Alternative);
            }
          }
        }
      }
    } else if (E == Target.dyn_cast<Expr *>()) {
      AfterTarget = true;
    }
    return shouldWalkInto(E->getSourceRange());
  }

  bool walkToStmtPre(Stmt *S) override {
    if (S == Target.dyn_cast<Stmt *>())
      AfterTarget = true;
    return shouldWalkInto(S->getSourceRange());
  }

  bool walkToPatternPre(Pattern *P) override {
    if (P == Target.dyn_cast<Pattern *>())
      AfterTarget = true;
    return shouldWalkInto(P->getSourceRange());
  }

  bool shouldWalkInto(SourceRange Range) {
    return AfterTarget || (SM &&
        SM->rangeContainsTokenLoc(Range, Target.getStartLoc()));
  }
};

/// Similar to the \c ReferenceCollector but collects references in all scopes
/// without any starting point in each scope. In addition, it tracks the number
/// of references to a decl in a given scope.
class ScopedDeclCollector : private SourceEntityWalker {
public:
  using DeclsTy = llvm::DenseSet<const Decl *>;
  using RefDeclsTy = llvm::DenseMap<const Decl *, /*numRefs*/ unsigned>;

private:
  using ScopedDeclsTy = llvm::DenseMap<const Stmt *, RefDeclsTy>;

  struct Scope {
    DeclsTy DeclaredDecls;
    RefDeclsTy *ReferencedDecls;
    Scope(RefDeclsTy *ReferencedDecls) : DeclaredDecls(),
        ReferencedDecls(ReferencedDecls) {}
  };

  ScopedDeclsTy ReferencedDecls;
  llvm::SmallVector<Scope, 4> ScopeStack;

public:
  /// Starting at \c Scope, collect all explicit references in every scope
  /// within (including the initial) that are not first declared, ie. those that
  /// could end up shadowed. Also include all \c DeclContext declarations as
  /// we'd like to avoid renaming functions and types completely.
  void collect(ASTNode Node) {
    walk(Node);
  }

  const RefDeclsTy *getReferencedDecls(Stmt *Scope) const {
    auto Res = ReferencedDecls.find(Scope);
    if (Res == ReferencedDecls.end())
      return nullptr;
    return &Res->second;
  }

private:
  bool walkToDeclPre(Decl *D, CharSourceRange Range) override {
    if (ScopeStack.empty() || D->isImplicit())
      return true;

    ScopeStack.back().DeclaredDecls.insert(D);
    if (isa<DeclContext>(D))
      (*ScopeStack.back().ReferencedDecls)[D] += 1;
    return true;
  }

  bool walkToExprPre(Expr *E) override {
    if (ScopeStack.empty())
      return true;

    if (!E->isImplicit()) {
      if (auto *DRE = dyn_cast<DeclRefExpr>(E)) {
        if (auto *D = DRE->getDecl()) {
          // If we have a reference that isn't declared in the same scope,
          // increment the number of references to that decl.
          if (!D->isImplicit() && !ScopeStack.back().DeclaredDecls.count(D)) {
            (*ScopeStack.back().ReferencedDecls)[D] += 1;

            // Also add the async alternative of a function to prevent
            // collisions if a call is replaced with the alternative.
            if (auto *AFD = dyn_cast<AbstractFunctionDecl>(D)) {
              if (auto *Alternative = AFD->getAsyncAlternative())
                (*ScopeStack.back().ReferencedDecls)[Alternative] += 1;
            }
          }
        }
      }
    }
    return true;
  }

  bool walkToStmtPre(Stmt *S) override {
    // Purposely check \c BraceStmt here rather than \c startsNewScope.
    // References in the condition should be applied to the previous scope, not
    // the scope of that statement.
    if (isa<BraceStmt>(S))
      ScopeStack.emplace_back(&ReferencedDecls[S]);
    return true;
  }

  bool walkToStmtPost(Stmt *S) override {
    if (isa<BraceStmt>(S)) {
      size_t NumScopes = ScopeStack.size();
      if (NumScopes >= 2) {
        // Add any referenced decls to the parent scope that weren't declared
        // there.
        auto &ParentStack = ScopeStack[NumScopes - 2];
        for (auto DeclAndNumRefs : *ScopeStack.back().ReferencedDecls) {
          auto *D = DeclAndNumRefs.first;
          if (!ParentStack.DeclaredDecls.count(D))
            (*ParentStack.ReferencedDecls)[D] += DeclAndNumRefs.second;
        }
      }
      ScopeStack.pop_back();
    }
    return true;
  }
};

/// Walks an ASTNode looking for a \c DeclRefExpr that refers to a given
/// declaration.
class DeclReferenceFinder : private SourceEntityWalker {
  bool HasFoundReference = false;
  const Decl *Search;

  bool walkToExprPre(Expr *E) override {
    auto *DRE = dyn_cast<DeclRefExpr>(E);
    if (!DRE || DRE->getDecl() != Search)
      return true;

    // Found one; there's no need to look inside this expression.
    HasFoundReference = true;
    return false;
  }

  DeclReferenceFinder(const Decl *Search) : Search(Search) {}

public:
  /// Returns \c true if \p Node contains a reference to \p Search, \c false
  /// otherwise.
  static bool containsReference(ASTNode Node, const ValueDecl *Search) {
    DeclReferenceFinder Finder(Search);
    Finder.walk(Node);
    return Finder.HasFoundReference;
  }
};

/// Builds up async-converted code for an AST node.
///
/// If it is a function, its declaration will have `async` added. If a
/// completion handler is present, it will be removed and the return type of
/// the function will reflect the parameters of the handler, including an
/// added `throws` if necessary.
///
/// Calls to the completion handler are replaced with either a `return` or
/// `throw` depending on the arguments.
///
/// Calls to functions with an async alternative will be replaced with a call
/// to the alternative, possibly wrapped in a do/catch. The do/catch is skipped
/// if the closure either:
///   1. Has no error
///   2. Has an error but no error handling (eg. just ignores)
///   3. Has error handling that only calls the containing function's handler
///      with an error matching the error argument
///
/// (2) is technically not the correct translation, but in practice it's likely
/// the code a user would actually want.
///
/// If the success vs error handling split inside the closure cannot be
/// determined and the closure takes regular parameters (ie. not a Result), a
/// fallback translation is used that keeps all the same variable names and
/// simply moves the code within the closure out.
///
/// The fallback is generally avoided, however, since it's quite unlikely to be
/// the code the user intended. In most cases the refactoring will continue,
/// with any unhandled decls wrapped in placeholders instead.
class AsyncConverter : private SourceEntityWalker {
  /// State tracked for a single scope while the AST is being walked.
  struct Scope {
    /// The base names recorded for this scope.
    llvm::DenseSet<DeclBaseName> Names;

    /// For a scope wrapped in a \c withChecked(Throwing)Continuation, the name
    /// of the continuation to resume wherever the function's completion
    /// handler used to be called. An empty identifier otherwise.
    Identifier ContinuationName;

    Scope(Identifier ContinuationName) : ContinuationName(ContinuationName) {}

    /// Whether this scope is wrapped in a \c withChecked(Throwing)Continuation,
    /// ie. whether it has a non-empty continuation name.
    bool isWrappedInContination() const {
      const bool HasContinuation = !ContinuationName.empty();
      return HasContinuation;
    }
  };
  SourceFile *SF;
  SourceManager &SM;
  DiagnosticEngine &DiagEngine;

  // Node to convert
  ASTNode StartNode;

  // Completion handler of `StartNode` (if it's a function with an async
  // alternative)
  AsyncHandlerParamDesc TopHandler;

  SmallString<0> Buffer;
  llvm::raw_svector_ostream OS;

  // Decls where any force unwrap or optional chain of that decl should be
  // elided, e.g for a previously optional closure parameter that has become a
  // non-optional local.
  llvm::DenseSet<const Decl *> Unwraps;

  // Decls whose references should be replaced with placeholders, either
  // because they no longer exist or are a different type. Any replaced code
  // should ideally be handled by the refactoring properly, but that's not
  // possible in all cases
  llvm::DenseSet<const Decl *> Placeholders;

  // Mapping from decl -> name, used as the name of possible new local
  // declarations of old completion handler parameters, as well as the
  // replacement for other hoisted declarations and their references
  llvm::DenseMap<const Decl *, Identifier> Names;

  /// The scopes (containing all name decls and whether the scope is wrapped in
  /// a continuation) as the AST is being walked. The first element is the
  /// initial scope and the last is the current scope.
  llvm::SmallVector<Scope, 4> Scopes;

  // Mapping of \c BraceStmt -> declarations referenced in that statement
  // without first being declared. These are used to fill the \c ScopeNames
  // map on entering that scope.
  ScopedDeclCollector ScopedDecls;

  /// The switch statements that have been re-written by this transform.
  llvm::DenseSet<SwitchStmt *> HandledSwitches;

  // The last source location that has been output. Used to output the source
  // between handled nodes
  SourceLoc LastAddedLoc;

  // Number of expressions (or pattern binding decl) currently nested in, taking
  // into account hoisting and the possible removal of ifs/switches
  int NestedExprCount = 0;

  // Whether a completion handler body is currently being hoisted out of its
  // call
  bool Hoisting = false;

  /// Whether a pattern is currently being converted.
  bool ConvertingPattern = false;

  /// A mapping of inline patterns to print for closure parameters.
  using InlinePatternsToPrint = llvm::DenseMap<const Decl *, const Pattern *>;

public:
  /// Convert a function.
  ///
  /// \param SF The source file being refactored.
  /// \param SM The source manager used when producing edits.
  /// \param DiagEngine The engine that conversion problems are diagnosed to.
  /// \param FD The function to convert; it becomes the start node.
  /// \param TopHandler Description of the completion handler of \p FD.
  AsyncConverter(SourceFile *SF, SourceManager &SM,
                 DiagnosticEngine &DiagEngine, AbstractFunctionDecl *FD,
                 const AsyncHandlerParamDesc &TopHandler)
      : SF(SF), SM(SM), DiagEngine(DiagEngine), StartNode(FD),
        TopHandler(TopHandler), OS(Buffer) {
    // References to the completion handler itself are treated as placeholders.
    Placeholders.insert(TopHandler.getHandler());
    // Gather the per-scope referenced decls for the whole function up front.
    ScopedDecls.collect(FD);

    // Shouldn't strictly be necessary, but prefer possible shadowing over
    // crashes caused by a missing scope
    addNewScope({});
  }

  /// Convert a call.
  ///
  /// \param SF The source file being refactored.
  /// \param SM The source manager used when producing edits.
  /// \param DiagEngine The engine that conversion problems are diagnosed to.
  /// \param CE The call to convert; it becomes the start node.
  /// \param Scope The scope used to seed the initial set of used decls, or
  /// \c nullptr to use the whole source file instead.
  AsyncConverter(SourceFile *SF, SourceManager &SM,
                 DiagnosticEngine &DiagEngine, CallExpr *CE, BraceStmt *Scope)
      : SF(SF), SM(SM), DiagEngine(DiagEngine), StartNode(CE), OS(Buffer) {
    // Gather the per-scope referenced decls within the call up front.
    ScopedDecls.collect(CE);

    // Create the initial scope, can be more accurate than the general
    // \c ScopedDeclCollector as there is a starting point.
    llvm::DenseSet<const Decl *> UsedDecls;
    DeclCollector::collect(Scope, *SF, UsedDecls);
    ReferenceCollector::collect(StartNode, Scope, *SF, UsedDecls);
    addNewScope(UsedDecls);
  }

  /// Returns the AST context of the source file being refactored.
  ASTContext &getASTContext() const {
    ASTContext &Ctx = SF->getASTContext();
    return Ctx;
  }

  /// Converts the start node into the output buffer. A \c FuncDecl has its
  /// async declaration written first, followed by its converted body if it has
  /// one; any other node is converted directly, with call conversion enabled
  /// and comments excluded.
  ///
  /// \returns \c true if the diagnostic engine has not recorded any error.
  bool convert() {
    assert(Buffer.empty() && "AsyncConverter can only be used once");

    auto *FD = dyn_cast_or_null<FuncDecl>(StartNode.dyn_cast<Decl *>());
    if (!FD) {
      convertNode(StartNode, /*StartOverride=*/{}, /*ConvertCalls=*/true,
                  /*IncludeComments=*/false);
    } else {
      addFuncDecl(FD);
      if (auto *Body = FD->getBody())
        convertNode(Body);
    }

    return !DiagEngine.hadAnyError();
  }

  /// Used when adding an async alternative for the start function: tries to
  /// write a new body for the legacy function (the one taking a completion
  /// handler) that calls the newly converted async function from within a
  /// \c Task.
  ///
  /// This isn't always possible, e.g. for a completion handler with the
  /// signature `(String, Error?) -> Void` there is no way to synthesize the
  /// \c String result in the error case.
  ///
  /// \returns \c true if a body was written, \c false if none can be created.
  bool createLegacyBody() {
    assert(Buffer.empty() &&
           "AsyncConverter can only be used once");

    if (!canCreateLegacyBody())
      return false;

    auto *FD = cast<FuncDecl>(StartNode.get<Decl *>());

    // Open the function body, then the 'Task' closure inside it.
    OS << tok::l_brace << "\n"
       << "Task " << tok::l_brace << "\n";

    // Writes the `[try] await <call>` expression whose result is passed on to
    // the completion handler.
    auto EmitAsyncCall = [&]() {
      if (TopHandler.HasError)
        OS << tok::kw_try << " ";
      OS << "await ";

      // The async alternative is being *created* here, so there shouldn't be
      // one already. The call to it is therefore assumed to look exactly like
      // the call to the completion handler function, without the completion
      // handler argument.
      addForwardingCallTo(FD, /*HandlerReplacement=*/"");
    };
    addHoistedNamedCallback(FD, TopHandler, TopHandler.getNameStr(),
                            EmitAsyncCall);

    // Close the 'Task' closure and then the function body.
    OS << "\n"
       << tok::r_brace << "\n"
       << tok::r_brace << "\n";
    return true;
  }

  /// Writes an async alternative of the start function whose body forwards to
  /// the completion handler function from inside a
  /// withCheckedContinuation/withCheckedThrowingContinuation call.
  ///
  /// \returns \c true, unconditionally.
  bool createAsyncWrapper() {
    assert(Buffer.empty() && "AsyncConverter can only be used once");
    auto *FD = cast<FuncDecl>(StartNode.get<Decl *>());
    const bool Throws = TopHandler.HasError;

    // The new async function declaration, followed by the start of its body.
    addFuncDecl(FD);
    OS << tok::l_brace << "\n";

    // return [try] await withChecked[Throwing]Continuation { continuation in
    OS << tok::kw_return << " ";
    if (Throws)
      OS << tok::kw_try << " ";
    OS << "await "
       << (Throws ? "withCheckedThrowingContinuation"
                  : "withCheckedContinuation")
       << " " << tok::l_brace << " continuation " << tok::kw_in << "\n";

    // fnWithHandler(args...) { ... }
    auto ClosureStr =
        getAsyncWrapperCompletionClosure("continuation", TopHandler);
    addForwardingCallTo(FD, /*HandlerReplacement=*/ClosureStr);

    // Close the continuation closure and then the function body.
    OS << "\n"
       << tok::r_brace << "\n"
       << tok::r_brace << "\n";
    return true;
  }

  /// Hands the buffered output to \p EditConsumer as the replacement for the
  /// source range of \p Node, then clears the buffer.
  ///
  /// \param StartOverride If valid, the replaced range begins here instead of
  /// at the start of \p Node.
  void replace(ASTNode Node, SourceEditConsumer &EditConsumer,
               SourceLoc StartOverride = SourceLoc()) {
    const SourceRange NodeRange = Node.getSourceRange();
    const SourceRange Range = StartOverride.isValid()
                                  ? SourceRange(StartOverride, NodeRange.End)
                                  : NodeRange;

    EditConsumer.accept(SM,
                        Lexer::getCharSourceRangeFromSourceRange(SM, Range),
                        Buffer.str());
    Buffer.clear();
  }

  /// Issues two insertions to \p EditConsumer at the end location of \p Node,
  /// first a blank-line separator and then the buffered output, and clears the
  /// buffer afterwards.
  void insertAfter(ASTNode Node, SourceEditConsumer &EditConsumer) {
    const SourceLoc EndLoc = Node.getEndLoc();
    EditConsumer.insertAfter(SM, EndLoc, "\n\n");
    EditConsumer.insertAfter(SM, EndLoc, Buffer.str());
    Buffer.clear();
  }

private:
  bool canCreateLegacyBody() {
    FuncDecl *FD = dyn_cast<FuncDecl>(StartNode.dyn_cast<Decl *>());
    if (!FD) {
      return false;
    }
    if (FD == nullptr || FD->getBody() == nullptr) {
      return false;
    }
    if (FD->hasThrows()) {
      assert(!TopHandler.isValid() && "We shouldn't have found a handler desc "
                                       "if the original function throws");
      return false;
    }
    return TopHandler.isValid();
  }

  /// Prints \p Elements to \p OS through \p PrintElt, separated by commas. The
  /// list is wrapped in parentheses to form a tuple unless it consists of
  /// exactly one element, which is printed on its own.
  template <typename Container, typename PrintFn>
  void addTupleOf(const Container &Elements, llvm::raw_ostream &OS,
                  PrintFn PrintElt) {
    const bool IsTuple = Elements.size() != 1;
    if (IsTuple)
      OS << tok::l_paren;
    llvm::interleave(Elements, PrintElt, [&]() { OS << tok::comma << " "; });
    if (IsTuple)
      OS << tok::r_paren;
  }

  /// Builds the source text of the completion handler closure that an async
  /// wrapper function passes to the completion handler function. The closure
  /// forwards whatever it receives to the continuation.
  ///
  /// \param ContName The name of the continuation to resume.
  /// \param HandlerDesc Description of the completion handler the closure
  /// stands in for. Its type must not be \c HandlerType::INVALID.
  /// \returns The text of the closure, braces included.
  std::string
  getAsyncWrapperCompletionClosure(StringRef ContName,
                                   const AsyncHandlerParamDesc &HandlerDesc) {
    std::string OutputStr;
    llvm::raw_string_ostream OS(OutputStr);

    OS << tok::l_brace; // start closure

    // Prepare parameter names for the closure.
    auto SuccessParams = HandlerDesc.getSuccessParams();
    SmallVector<SmallString<4>, 2> SuccessParamNames;
    for (auto idx : indices(SuccessParams)) {
      SuccessParamNames.emplace_back("result");

      // If we have multiple success params, number them e.g result1,
      // result2...
      if (SuccessParams.size() > 1)
        SuccessParamNames.back().append(std::to_string(idx + 1));
    }
    Optional<SmallString<4>> ErrName;
    if (HandlerDesc.getErrorParam())
      ErrName.emplace("error");

    auto HasAnyParams = !SuccessParamNames.empty() || ErrName;
    if (HasAnyParams)
      OS << " ";

    // result1, result2
    llvm::interleave(
        SuccessParamNames, [&](const auto &Name) { OS << Name; },
        [&]() { OS << tok::comma << " "; });

    // , error
    if (ErrName) {
      if (!SuccessParamNames.empty())
        OS << tok::comma << " ";

      OS << *ErrName;
    }
    if (HasAnyParams)
      OS << " " << tok::kw_in;

    OS << "\n";

    // The closure body.
    switch (HandlerDesc.Type) {
    case HandlerType::PARAMS: {
      // For a (Success?, Error?) -> Void handler, we do an if let on the error.
      if (ErrName) {
        // if let error = error {
        OS << tok::kw_if << " " << tok::kw_let << " ";
        OS << *ErrName << " " << tok::equal << " " << *ErrName << " ";
        OS << tok::l_brace << "\n";

        // continuation.resume(throwing: error)
        OS << ContName << tok::period << "resume" << tok::l_paren;
        OS << "throwing" << tok::colon << " " << *ErrName;
        OS << tok::r_paren << "\n";

        // return }
        OS << tok::kw_return << "\n";
        OS << tok::r_brace << "\n";
      }

      // If we have any success params that we need to unwrap, insert a guard.
      for (auto Idx : indices(SuccessParamNames)) {
        auto &Name = SuccessParamNames[Idx];
        auto ParamTy = SuccessParams[Idx].getParameterType();
        if (!HandlerDesc.shouldUnwrap(ParamTy))
          continue;

        // guard let result = result else {
        OS << tok::kw_guard << " " << tok::kw_let << " ";
        OS << Name << " " << tok::equal << " " << Name << " " << tok::kw_else;
        OS << " " << tok::l_brace << "\n";

        // fatalError(...)
        OS << "fatalError" << tok::l_paren;
        OS << "\"Expected non-nil result '" << Name << "' for nil error\"";
        OS << tok::r_paren << "\n";

        // End guard.
        OS << tok::r_brace << "\n";
      }

      // continuation.resume(returning: (result1, result2, ...))
      OS << ContName << tok::period << "resume" << tok::l_paren;
      OS << "returning" << tok::colon << " ";
      addTupleOf(SuccessParamNames, OS, [&](const auto &Ref) { OS << Ref; });
      OS << tok::r_paren << "\n";
      break;
    }
    case HandlerType::RESULT: {
      // continuation.resume(with: result)
      assert(SuccessParamNames.size() == 1);
      OS << ContName << tok::period << "resume" << tok::l_paren;
      OS << "with" << tok::colon << " " << SuccessParamNames[0];
      OS << tok::r_paren << "\n";
      break;
    }
    case HandlerType::INVALID:
      llvm_unreachable("Should not have an invalid handler here");
    }

    OS << tok::r_brace; // end closure
    return OutputStr;
  }

  /// Retrieves the SourceRange of the preceding comment, or an invalid range if
  /// there is no preceding comment.
  CharSourceRange getPrecedingCommentRange(SourceLoc Loc) {
    auto AllTokens = SF->getAllTokens();
    auto Found = token_lower_bound(AllTokens, Loc);
    // Only a token that was actually found and that carries a comment can
    // provide a comment range; everything else yields the default range.
    if (Found != AllTokens.end() && Found->hasComment())
      return Found->getCommentRange();
    return CharSourceRange();
  }

  /// Retrieves the location for the start of a comment attached to the token
  /// at the provided location, or the location itself if there is no comment.
  SourceLoc getLocIncludingPrecedingComment(SourceLoc Loc) {
    auto Comment = getPrecedingCommentRange(Loc);
    // Fall back to the location itself when no comment range was found.
    return Comment.isInvalid() ? Loc : Comment.getStart();
  }

  /// If the provided SourceLoc has a preceding comment, print it out.
  void printCommentIfNeeded(SourceLoc Loc) {
    auto Comment = getPrecedingCommentRange(Loc);
    // Nothing to emit when no comment range was found for Loc.
    if (!Comment.isValid())
      return;
    // Put the comment text on a line of its own.
    OS << "\n";
    OS << Comment.str();
  }

  void convertNodes(const NodesToPrint &ToPrint) {
    // Sort the possible comment locs in reverse order so we can pop them as we
    // go.
    SmallVector<SourceLoc, 2> CommentLocs;
    CommentLocs.append(ToPrint.getPossibleCommentLocs().begin(),
                       ToPrint.getPossibleCommentLocs().end());
    llvm::sort(CommentLocs.begin(), CommentLocs.end(), [](auto lhs, auto rhs) {
      return lhs.getOpaquePointerValue() > rhs.getOpaquePointerValue();
    });

    // First print the nodes we've been asked to print.
    for (auto Node : ToPrint.getNodes()) {
      // If we need to print comments, do so now.
      while (!CommentLocs.empty()) {
        auto CommentLoc = CommentLocs.back().getOpaquePointerValue();
        auto NodeLoc = Node.getStartLoc().getOpaquePointerValue();
        assert(CommentLoc != NodeLoc &&
               "Added node to both comment locs and nodes to print?");

        // If the comment occurs after the node, don't print now. Wait until
        // the right node comes along.
        if (CommentLoc > NodeLoc)
          break;

        printCommentIfNeeded(CommentLocs.pop_back_val());
      }
      OS << "\n";
      convertNode(Node);
    }

    // We're done printing nodes. Make sure to output the remaining comments.
    while (!CommentLocs.empty())
      printCommentIfNeeded(CommentLocs.pop_back_val());
  }

  /// Walks \p Node and prints its converted source text.
  ///
  /// \param StartOverride If valid, printing starts here instead of at the
  ///        start of \p Node.
  /// \param ConvertCalls If \c false, \c NestedExprCount starts at 1 for the
  ///        duration of the walk instead of 0.
  /// \param IncludePrecedingComment If \c true, the start is moved back to
  ///        the beginning of a comment attached to the start location.
  void convertNode(ASTNode Node, SourceLoc StartOverride = {},
                   bool ConvertCalls = true,
                   bool IncludePrecedingComment = true) {
    SourceLoc Begin =
        StartOverride.isValid() ? StartOverride : Node.getStartLoc();

    // Pull in any comment attached to the starting location.
    if (IncludePrecedingComment)
      Begin = getLocIncludingPrecedingComment(Begin);

    // Both pieces of walker state are restored once this node is done.
    const int InitialExprDepth = ConvertCalls ? 0 : 1;
    llvm::SaveAndRestore<SourceLoc> RestoreLoc(LastAddedLoc, Begin);
    llvm::SaveAndRestore<int> RestoreCount(NestedExprCount, InitialExprDepth);

    walk(Node);

    // Emit whatever source remains between the last custom edit and the end
    // of the node.
    addRange(LastAddedLoc, Node.getEndLoc(), /*ToEndOfToken=*/true);
  }

  void convertPattern(const Pattern *P) {
    // Only print semantic patterns. This cleans up the output of the transform
    // and works around some bogus source locs that can appear with typed
    // patterns in if let statements.
    P = P->getSemanticsProvidingPattern();

    // Set up the start of the pattern as the last loc printed to make sure we
    // accurately fill in the gaps as we customize the printing of sub-patterns.
    llvm::SaveAndRestore<SourceLoc> RestoreLoc(LastAddedLoc, P->getStartLoc());
    llvm::SaveAndRestore<bool> RestoreFlag(ConvertingPattern, true);

    walk(const_cast<Pattern *>(P));
    addRange(LastAddedLoc, P->getEndLoc(), /*ToEndOfToken*/ true);
  }

  /// Check whether \p Node requires the remainder of this scope to be wrapped
  /// in a \c withChecked(Throwing)Continuation. If it is necessary, add
  /// a call to \c withChecked(Throwing)Continuation and modify the current
  /// scope (\c Scopes.back() ) so that it knows it's wrapped in a continuation.
  ///
  /// Wrapping a node in a continuation is necessary if the following conditions
  /// are satisfied:
  ///  - It contains a reference to the \c TopHandler's completion handler,
  ///    because these completion handler calls need to be promoted to \c return
  ///    statements in the refactored method, but
  ///  - We cannot hoist the completion handler of \p Node, because it doesn't
  ///    have an async alternative by our heuristics (e.g. because of a
  ///    completion handler name mismatch or because it also returns a value
  ///    synchronously).
  void wrapScopeInContinationIfNecessary(ASTNode Node) {
    // A continuation can't be started in the middle of an expression, and a
    // scope that is already wrapped doesn't need a second continuation.
    if (NestedExprCount != 0 || Scopes.back().isWrappedInContination())
      return;

    // Without a reference to the function's completion handler there is no
    // handler call that would have to be promoted to a 'return', so the node
    // can stay a call with a completion handler.
    const bool ReferencesTopHandler =
        DeclReferenceFinder::containsReference(Node, TopHandler.getHandler());
    if (!ReferencesTopHandler)
      return;

    // Register the continuation's name with the current scope, which also
    // marks the scope as being wrapped in a continuation.
    Identifier contName = createUniqueName("continuation");
    auto &CurrentScope = Scopes.back();
    CurrentScope.Names.insert(contName);
    CurrentScope.ContinuationName = contName;

    // Open the continuation call in front of the node:
    //   return [try] await withChecked[Throwing]Continuation { <name> in
    const bool Throws = TopHandler.HasError;
    insertCustom(Node.getStartLoc(), [&]() {
      OS << tok::kw_return << ' ';
      if (Throws)
        OS << tok::kw_try << ' ';
      OS << "await ";
      OS << (Throws ? "withCheckedThrowingContinuation "
                    : "withCheckedContinuation ");
      OS << tok::l_brace << ' ' << contName << ' ' << tok::kw_in << '\n';
    });
  }

  /// Pre-visit hook for patterns. While a pattern is being converted, a
  /// binding pattern is replaced by the conversion of its sub-pattern.
  bool walkToPatternPre(Pattern *P) override {
    // Outside of pattern conversion no special handling is required.
    if (ConvertingPattern) {
      // Drop the 'let' or 'var' of binding subpatterns: they're illegal when
      // nested in PBDs, and a top-level one gets printed instead.
      if (auto *Binding = dyn_cast<BindingPattern>(P)) {
        auto PrintSubPattern = [&]() {
          convertPattern(Binding->getSubPattern());
        };
        return addCustom(Binding->getSourceRange(), PrintSubPattern);
      }
    }
    return true;
  }

  /// Pre-visit hook for declarations. Pattern binding decls are walked into
  /// (bumping \c NestedExprCount); variable decls are replaced by their new
  /// name; every other decl is left alone and not walked into.
  bool walkToDeclPre(Decl *D, CharSourceRange Range) override {
    if (!isa<PatternBindingDecl>(D)) {
      // Functions and types already have their names in \c Scopes.Names, so
      // variables are the only decls that may need to be renamed.
      if (isa<VarDecl>(D)) {
        // Vars in binding patterns may already have been given a name; only
        // assign one when none exists yet.
        const bool AlreadyNamed = Names.find(D) != Names.end();
        if (!AlreadyNamed) {
          auto Ident = assignUniqueName(D, StringRef());
          Scopes.back().Names.insert(Ident);
        }
        addCustom(D->getSourceRange(), [&]() { OS << newNameFor(D); });
      }

      // Nested local function decls are deliberately not walked. If that
      // ever changes, the logic in walkToStmtPre that converts unhandled
      // returns into placeholders has to be updated as well.
      return false;
    }

    // A closure inside a PatternBindingDecl can't be hoisted. If the decl
    // contains a call to the completion handler, wrap it in a continuation.
    wrapScopeInContinationIfNecessary(D);
    NestedExprCount++;
    return true;
  }

  /// Post-visit hook for declarations. Decrements \c NestedExprCount, the
  /// counter that \c walkToDeclPre increments for a pattern binding decl.
  bool walkToDeclPost(Decl *D) override {
    --NestedExprCount;
    return true;
  }

#define PLACEHOLDER_START "<#"
#define PLACEHOLDER_END "#>"
  /// Pre-visit hook for expressions. Depending on the kind of \p E this
  /// either replaces the expression with custom output (returning the result
  /// of \c addCustom) or falls through to the generic handling at the bottom:
  ///  - A decl reference that is in \c Placeholders and/or has a new name is
  ///    printed with that name, wrapped in placeholder markers if needed.
  ///  - A force unwrap or optional chain of a decl in \c Unwraps is replaced
  ///    by the decl's new name.
  ///  - A call to the top handler becomes a continuation resume (if the
  ///    current scope is wrapped in a continuation) or a return/throw (if the
  ///    expression isn't nested).
  ///  - Any other call with a valid async handler description gets its
  ///    completion handler hoisted, provided the expression isn't nested and
  ///    the scope isn't wrapped in a continuation.
  /// Everything else is (if necessary) wrapped in a continuation and bumps
  /// \c NestedExprCount before its children are visited.
  bool walkToExprPre(Expr *E) override {
    // TODO: Handle Result.get as well
    if (auto *DRE = dyn_cast<DeclRefExpr>(E)) {
      if (auto *D = DRE->getDecl()) {
        // Look through to the parent var decl if we have one. This ensures we
        // look at the var in a case stmt's pattern rather than the var that's
        // implicitly declared in the body.
        if (auto *VD = dyn_cast<VarDecl>(D)) {
          if (auto *Parent = VD->getParentVarDecl())
            D = Parent;
        }

        bool AddPlaceholder = Placeholders.count(D);
        StringRef Name = newNameFor(D, false);
        // Only customize the output if there is something to change: either
        // a placeholder has to be added or the decl has a new name.
        if (AddPlaceholder || !Name.empty())
          return addCustom(DRE->getSourceRange(), [&]() {
            if (AddPlaceholder)
              OS << PLACEHOLDER_START;
            if (!Name.empty())
              OS << Name;
            else
              D->getName().print(OS);
            if (AddPlaceholder)
              OS << PLACEHOLDER_END;
          });
      }
    } else if (isa<ForceValueExpr>(E) || isa<BindOptionalExpr>(E)) {
      // Remove a force unwrap or optional chain of a returned success value,
      // as it will no longer be optional. For force unwraps, this is always a
      // valid transform. For optional chains, it is a locally valid transform
      // within the optional chain e.g foo?.x -> foo.x, but may change the type
      // of the overall chain, which could cause errors elsewhere in the code.
      // However this is generally more useful to the user than just leaving
      // 'foo' as a placeholder. Note this is only the case when no other
      // optionals are involved in the chain, e.g foo?.x?.y -> foo.x?.y is
      // completely valid.
      if (auto *D = E->getReferencedDecl().getDecl()) {
        if (Unwraps.count(D))
          return addCustom(E->getSourceRange(),
                           [&]() { OS << newNameFor(D, true); });
      }
    } else if (CallExpr *CE = TopHandler.getAsHandlerCall(E)) {
      // A call to the completion handler of the function being converted.
      if (Scopes.back().isWrappedInContination()) {
        return addCustom(E->getSourceRange(),
                         [&]() { convertHandlerToContinuationResume(CE); });
      } else if (NestedExprCount == 0) {
        return addCustom(E->getSourceRange(),
                         [&]() { convertHandlerToReturnOrThrows(CE); });
      }
    } else if (auto *CE = dyn_cast<CallExpr>(E)) {
      // Try and hoist a call's completion handler. Don't do so if
      //  - the current expression is nested (we can't start hoisting in the
      //    middle of an expression)
      //  - the current scope is wrapped in a continuation (we can't have await
      //    calls in the continuation block)
      if (NestedExprCount == 0 && !Scopes.back().isWrappedInContination()) {
        // If the refactoring is on the call itself, do not require the callee
        // to have the @available attribute or a completion-like name.
        auto HandlerDesc = AsyncHandlerParamDesc::find(
            getUnderlyingFunc(CE->getFn()),
            /*RequireAttributeOrName=*/StartNode.dyn_cast<Expr *>() != CE);
        if (HandlerDesc.isValid()) {
          return addCustom(CE->getSourceRange(),
                           [&]() { addHoistedCallback(CE, HandlerDesc); });
        }
      }
    }

    // A void SingleValueStmtExpr is semantically more like a statement than
    // an expression, so recurse without bumping the expr depth or wrapping in
    // continuation.
    if (auto *SVE = dyn_cast<SingleValueStmtExpr>(E)) {
      auto ty = SVE->getType();
      if (!ty || ty->isVoid())
        return true;
    }

    // We didn't do any special conversion for this expression. If needed, wrap
    // it in a continuation.
    wrapScopeInContinationIfNecessary(E);

    // Everything visited below this expression is nested inside of it.
    NestedExprCount++;
    return true;
  }

  /// Re-emits the source text of \p range surrounded by the placeholder start
  /// and end markers. Returns the result of \c addCustom.
  bool replaceRangeWithPlaceholder(SourceRange range) {
    auto EmitWrappedSource = [&]() {
      OS << PLACEHOLDER_START;
      addRange(range, /*toEndOfToken*/ true);
      OS << PLACEHOLDER_END;
    };
    return addCustom(range, EmitWrappedSource);
  }

  /// Post-visit hook for expressions. Decrements \c NestedExprCount, except
  /// for a \c SingleValueStmtExpr that has no type or a void type.
  bool walkToExprPost(Expr *E) override {
    bool IsStatementLike = false;
    if (auto *SVE = dyn_cast<SingleValueStmtExpr>(E)) {
      auto ty = SVE->getType();
      IsStatementLike = !ty || ty->isVoid();
    }

    if (!IsStatementLike)
      NestedExprCount--;
    return true;
  }

#undef PLACEHOLDER_START
#undef PLACEHOLDER_END

  /// Pre-visit hook for statements. Opens a new naming scope for statements
  /// that start one and, while hoisting, turns certain 'break' and 'return'
  /// statements into placeholders.
  bool walkToStmtPre(Stmt *S) override {
    // Note that isImplicit is intentionally not checked here: a CaseStmt has
    // an implicit BraceStmt inside it, which *should* start a new scope.
    if (startsNewScope(S)) {
      // Seed the scope with the names of all decls referenced within this
      // statement that aren't also declared first, plus any contexts.
      // \c getReferencedDecls only returns a value for a \c BraceStmt, so an
      // \c IfStmt (and other statements with conditions) gets an empty scope
      // of its own. That's fine: their existing names are always valid, and
      // their bodies include the decls if they've been referenced, so
      // shadowing is still avoided there.
      auto *ReferencedDecls = ScopedDecls.getReferencedDecls(S);
      if (!ReferencedDecls) {
        addNewScope({});
      } else {
        llvm::DenseSet<const Decl *> Decls;
        for (auto DeclAndNumRefs : *ReferencedDecls)
          Decls.insert(DeclAndNumRefs.first);
        addNewScope(Decls);
      }
      return true;
    }

    // The remaining handling only applies to explicit statements while
    // hoisting.
    if (!Hoisting || S->isImplicit())
      return true;

    // Some break and return statements may no longer perform the control
    // flow that the user is expecting, so they become placeholders.
    if (auto *BS = dyn_cast<BreakStmt>(S)) {
      // A break that jumps out of a switch statement re-written as part of
      // the transform would have been lifted out of that switch.
      if (auto *SS = dyn_cast<SwitchStmt>(BS->getTarget())) {
        if (HandledSwitches.contains(SS))
          return replaceRangeWithPlaceholder(S->getSourceRange());
      }
      return true;
    }

    if (isa<ReturnStmt>(S) && NestedExprCount == 0) {
      // A return that isn't nested inside another closure or function will
      // be lifted out of the callback. Only the 'return' token itself is
      // replaced, so that transforms can still be applied to its argument.
      replaceRangeWithPlaceholder(S->getStartLoc());
    }
    return true;
  }

  /// Post-visit hook for statements. Closes the scope of a statement that
  /// started one and, if that scope was the one wrapped in a continuation,
  /// emits the closing brace of the continuation call.
  bool walkToStmtPost(Stmt *S) override {
    if (!startsNewScope(S))
      return true;

    const bool InnerScopeWasWrapped = Scopes.back().isWrappedInContination();
    Scopes.pop_back();

    // The scope that was just closed was wrapped in a continuation but the
    // one that is current now isn't. Add the '}' that corresponds to the call
    // to withChecked(Throwing)Continuation.
    if (InnerScopeWasWrapped && !Scopes.back().isWrappedInContination())
      insertCustom(S->getEndLoc(), [&]() { OS << tok::r_brace << '\n'; });

    return true;
  }

  /// Emits the untouched source text between \c LastAddedLoc and the start of
  /// \p Range, runs \p Custom to emit the replacement for \p Range, and then
  /// moves \c LastAddedLoc to the end of the token at \c Range.End.
  ///
  /// \p Custom may be omitted (it defaults to an empty \c function_ref), in
  /// which case nothing is emitted in place of \p Range.
  ///
  /// \returns always \c false, so that walker pre-visit hooks can return the
  /// result of this call directly.
  bool addCustom(SourceRange Range, llvm::function_ref<void()> Custom = {}) {
    addRange(LastAddedLoc, Range.Start);
    // Invoking an empty function_ref is undefined behavior, so only call the
    // callback if one was actually provided.
    if (Custom)
      Custom();
    LastAddedLoc = Lexer::getLocForEndOfToken(SM, Range.End);
    return false;
  }

  /// Insert custom text at the given \p Loc that shouldn't replace any existing
  /// source code.
  bool insertCustom(SourceLoc Loc, llvm::function_ref<void()> Custom = {}) {
    // Emit the untouched source text up to the insertion point first.
    addRange(LastAddedLoc, Loc);
    // Invoking an empty function_ref is undefined behavior, so only call the
    // callback if one was actually provided (it defaults to empty).
    if (Custom)
      Custom();
    // Nothing was replaced, so output continues right at the insertion point.
    LastAddedLoc = Loc;
    return false;
  }

  /// Prints the source text between \p Start and \p End. If \p ToEndOfToken
  /// is set, the two locations are treated as a token-based \c SourceRange
  /// and converted to a character range by the lexer; otherwise they are used
  /// directly as the bounds of a character range.
  void addRange(SourceLoc Start, SourceLoc End, bool ToEndOfToken = false) {
    if (!ToEndOfToken) {
      OS << CharSourceRange(SM, Start, End).str();
      return;
    }

    auto TokenRange = Lexer::getCharSourceRangeFromSourceRange(
        SM, SourceRange(Start, End));
    OS << TokenRange.str();
  }

  /// Convenience overload that prints the source text covered by \p Range by
  /// forwarding its start and end locations to the three-argument overload.
  void addRange(SourceRange Range, bool ToEndOfToken = false) {
    const SourceLoc Begin = Range.Start;
    const SourceLoc Finish = Range.End;
    addRange(Begin, Finish, ToEndOfToken);
  }

  /// Prints the declaration of \p FD (up to, but not including, its body)
  /// rewritten as an async function. The output is assembled in chunks:
  ///  1. the original source from the start of the decl (including
  ///     attributes) up to the handler parameter described by \c TopHandler,
  ///  2. the original source after the handler parameter up to and including
  ///     the closing parenthesis,
  ///  3. 'async' and 'throws', where they are needed,
  ///  4. either the rest of the original declaration up to the body's opening
  ///     brace (if \c TopHandler is invalid), or the new result type followed
  ///     by the trailing where clause, if any.
  void addFuncDecl(const FuncDecl *FD) {
    auto *Params = FD->getParameters();
    auto *HandlerParam = TopHandler.getHandlerParam();
    auto ParamPos = TopHandler.handlerParamPosition();

    // If the completion handler parameter has a default argument, the async
    // version is effectively @discardableResult, as not all the callers care
    // about receiving the completion call.
    if (HandlerParam && HandlerParam->isDefaultArgument())
      OS << tok::at_sign << "discardableResult" << "\n";

    // First chunk: start -> the parameter to remove (if any)
    SourceLoc LeftEndLoc;
    switch (ParamPos) {
    case AsyncHandlerParamDesc::Position::None:
    case AsyncHandlerParamDesc::Position::Only:
    case AsyncHandlerParamDesc::Position::First:
      // Handler is the first param (or there is none), so only include the (
      LeftEndLoc = Params->getLParenLoc().getAdvancedLoc(1);
      break;
    case AsyncHandlerParamDesc::Position::Middle:
      // Handler is somewhere in the middle of the params, so we need to
      // include any comments and comma up until the handler
      LeftEndLoc = Params->get(TopHandler.Index)->getStartLoc();
      LeftEndLoc = getLocIncludingPrecedingComment(LeftEndLoc);
      break;
    case AsyncHandlerParamDesc::Position::Last:
      // Handler is the last param, which means we don't want the comma. This
      // is a little annoying since we *do* want the comments that follow the
      // parameter preceding the handler.
      LeftEndLoc = Lexer::getLocForEndOfToken(
          SM, Params->get(TopHandler.Index - 1)->getEndLoc());
      // Skip to the end of any comments
      Token Next = Lexer::getTokenAtLocation(SM, LeftEndLoc,
                                             CommentRetentionMode::None);
      if (Next.getKind() != tok::NUM_TOKENS)
        LeftEndLoc = Next.getLoc();
      break;
    }
    addRange(FD->getSourceRangeIncludingAttrs().Start, LeftEndLoc);

    // Second chunk: end of the parameter to remove -> right parenthesis
    SourceLoc MidStartLoc;
    // One past the ')' so that the parenthesis itself is printed.
    SourceLoc MidEndLoc = Params->getRParenLoc().getAdvancedLoc(1);
    switch (ParamPos) {
    case AsyncHandlerParamDesc::Position::None:
      // No handler param, so make sure to include them all
      MidStartLoc = LeftEndLoc;
      break;
    case AsyncHandlerParamDesc::Position::First:
    case AsyncHandlerParamDesc::Position::Middle:
      // Handler param is either the first or one of the middle params. Skip
      // past it but make sure to include comments preceding the param after
      // the handler
      MidStartLoc = Params->get(TopHandler.Index + 1)->getStartLoc();
      MidStartLoc = getLocIncludingPrecedingComment(MidStartLoc);
      break;
    case AsyncHandlerParamDesc::Position::Only:
    case AsyncHandlerParamDesc::Position::Last:
      // Handler param is last, this is easy since there's no other params
      // to copy over
      MidStartLoc = Params->getRParenLoc();
      break;
    }
    addRange(MidStartLoc, MidEndLoc);

    // Third chunk: add in async and throws if necessary
    if (!FD->hasAsync())
      OS << " async";
    if (FD->hasThrows() || TopHandler.HasError)
      // TODO: Add throws if converting a function and it has a converted call
      //       without a do/catch
      OS << " " << tok::kw_throws;

    // Fourth chunk: if no parent handler (ie. not adding an async
    // alternative), the rest of the decl. Otherwise, add in the new return
    // type
    if (!TopHandler.isValid()) {
      SourceLoc RightStartLoc = MidEndLoc;
      // 'throws' has already been printed above, so continue after the
      // original 'throws' token rather than printing it a second time.
      if (FD->hasThrows()) {
        RightStartLoc = Lexer::getLocForEndOfToken(SM, FD->getThrowsLoc());
      }
      // Without a body there is nothing left to copy (empty range).
      SourceLoc RightEndLoc =
          FD->getBody() ? FD->getBody()->getLBraceLoc() : RightStartLoc;
      addRange(RightStartLoc, RightEndLoc);
      return;
    }

    SmallVector<LabeledReturnType, 2> Scratch;
    auto ReturnTypes = TopHandler.getAsyncReturnTypes(Scratch);
    if (ReturnTypes.empty()) {
      OS << " ";
      return;
    }

    // Print the function result type, making sure to omit a '-> Void' return.
    if (!TopHandler.willAsyncReturnVoid()) {
      OS << " -> ";
      addAsyncFuncReturnType(TopHandler);
    }

    if (FD->hasBody())
      OS << " ";

    // TODO: Should remove the generic param and where clause for the error
    //       param if it exists (and no other parameter uses that type)
    TrailingWhereClause *TWC = FD->getTrailingWhereClause();
    if (TWC && TWC->getWhereLoc().isValid()) {
      auto Range = TWC->getSourceRange();
      OS << Lexer::getCharSourceRangeFromSourceRange(SM, Range).str();
      if (FD->hasBody())
        OS << " ";
    }
  }

  /// Prints one variable declaration per parameter in \p FallbackParams,
  /// using the parameter's new name and type. The known bool flag parameter
  /// of \p AllParams is declared with 'let' and no initializer; every other
  /// parameter is declared as an optional 'var' initialized to 'nil'.
  void addFallbackVars(ArrayRef<const ParamDecl *> FallbackParams,
                       const ClosureCallbackParams &AllParams) {
    for (auto *Param : FallbackParams) {
      auto Ty = Param->getType();
      auto ParamName = newNameFor(Param);

      // The known bool success param gets bound in both blocks, so it can be
      // a 'let' that keeps its type as is.
      const bool IsBoolFlag = AllParams.isKnownBoolFlagParam(Param);

      OS << (IsBoolFlag ? tok::kw_let : tok::kw_var) << " " << ParamName
         << ": ";
      Ty->print(OS);

      if (IsBoolFlag) {
        OS << "\n";
        continue;
      }

      // Make the type optional unless it already is one.
      if (!Ty->getOptionalObjectType())
        OS << "?";

      OS << " = " << tok::kw_nil << "\n";
    }
  }

  /// Prints the 'do' keyword, a space, an opening brace and a newline.
  void addDo() {
    OS << tok::kw_do << " ";
    OS << tok::l_brace << "\n";
  }

  /// Assuming that \p Result represents an error result to completion handler,
  /// returns \c true if the error has already been handled through a
  /// 'try await'.
  bool isErrorAlreadyHandled(HandlerResult Result) {
    assert(Result.isError());
    assert(Result.args().size() == 1 &&
           "There should only be one error parameter");

    // Only a plain reference to a declaration can be recognised as handled.
    auto *DRE = dyn_cast<DeclRefExpr>(Result.args().back().getExpr());
    if (!DRE)
      return false;

    // The error counts as handled if its variable declaration doesn't exist
    // anymore, i.e. it's in Placeholders but not in Unwraps. (Being in both
    // means an optional Error has merely been promoted to a non-optional
    // Error.)
    auto *ErrorDecl = DRE->getDecl();
    return Placeholders.count(ErrorDecl) && !Unwraps.count(ErrorDecl);
  }

  /// Returns \c true if the source representation of \p E can be interpreted
  /// as an expression returning an Optional value.
  bool isExpressionOptional(Expr *E) {
    // An injection into Optional downgrades a non-Optional result, so what
    // is written in the source isn't Optional.
    if (isa<InjectIntoOptionalExpr>(E))
      return false;

    // A reference to a decl that has been promoted to a non-Optional value
    // can't be used as an Optional anymore.
    if (auto DRE = dyn_cast<DeclRefExpr>(E)) {
      if (Unwraps.count(DRE->getDecl()))
        return false;
    }

    // Otherwise trust the expression's type. If the type couldn't be
    // determined, assume non-Optional.
    return !E->getType().isNull() && E->getType()->isOptional();
  }

  /// Converts a call \p CE to a completion handler. Depending on the call it
  /// will be interpreted as a call that's returning a success result, an error
  /// or, if the call is completely ambiguous, adds an if-let that checks if the
  /// error is \c nil at runtime and dispatches to the success or error case
  /// depending on it.
  /// \p AddConvertedHandlerCall needs to add the converted version of the
  /// completion handler. Depending on the given \c HandlerResult, it must be
  /// interpreted as a success or error call.
  /// \p AddConvertedErrorCall must add the converted equivalent of returning an
  /// error. The passed \c StringRef contains the name of a variable that is of
  /// type 'Error'.
  void convertHandlerCall(
      const CallExpr *CE,
      llvm::function_ref<void(HandlerResult)> AddConvertedHandlerCall,
      llvm::function_ref<void(StringRef)> AddConvertedErrorCall) {
    auto ErrorBiasedResult =
        TopHandler.extractResultArgs(CE, /*ReturnErrorArgsIfAmbiguous=*/true);

    if (!TopHandler.isAmbiguousCallToParamHandler(CE)) {
      // Unambiguous call. Convert it, unless it reports an error that has
      // already been handled, in which case no further throwing call is
      // needed.
      if (!ErrorBiasedResult.isError() ||
          !isErrorAlreadyHandled(ErrorBiasedResult))
        AddConvertedHandlerCall(ErrorBiasedResult);
      return;
    }

    assert(ErrorBiasedResult.isError() &&
           "If the call was ambiguous, we should have "
           "retrieved its error representation");
    assert(ErrorBiasedResult.args().size() == 1 &&
           "There should only be one error parameter");
    Expr *ErrorExpr = ErrorBiasedResult.args().back().getExpr();

    if (isErrorAlreadyHandled(ErrorBiasedResult)) {
      // With the error out of the way, the call can only be a success call.
      auto SuccessExprs = TopHandler.extractResultArgs(
          CE, /*ReturnErrorArgsIfAmbiguous=*/false);
      AddConvertedHandlerCall(SuccessExprs);
      return;
    }

    if (!isExpressionOptional(ErrorExpr)) {
      // The error can never be nil, so regardless of the success param this
      // is an error call.
      AddConvertedHandlerCall(ErrorBiasedResult);
      return;
    }

    // The call is truly ambiguous. Decide at runtime by emitting
    // if let error = <convert error arg> {
    //   throw error // or equivalent
    // } else {
    //   <interpret call as success call>
    // }
    auto SuccessExprs = TopHandler.extractResultArgs(
        CE, /*ReturnErrorArgsIfAmbiguous=*/false);

    // 'error' only lives inside the 'if let' scope, so there is no need to
    // create a new unique name for it.
    StringRef ErrorName = "error";

    OS << tok::kw_if << ' ' << tok::kw_let << ' ' << ErrorName << ' ';
    OS << tok::equal << ' ';
    convertNode(ErrorExpr, /*StartOverride=*/{}, /*ConvertCalls=*/false);
    OS << ' ' << tok::l_brace << '\n';

    AddConvertedErrorCall(ErrorName);

    OS << tok::r_brace << ' ' << tok::kw_else << ' ' << tok::l_brace;
    OS << '\n';

    AddConvertedHandlerCall(SuccessExprs);

    OS << '\n' << tok::r_brace;
  }

  /// Convert a call \p CE to a completion handler to its 'return' or 'throws'
  /// equivalent.
  void convertHandlerToReturnOrThrows(const CallExpr *CE) {
    // Converts the handler call once it's known whether it's a success or an
    // error call.
    auto EmitReturnOrThrow = [&](HandlerResult Exprs) {
      convertHandlerToReturnOrThrowsImpl(CE, Exprs);
    };

    // Rethrows the error bound by the 'if let' of a truly ambiguous call.
    auto EmitThrowOfBoundError = [&](StringRef ErrorName) {
      OS << tok::kw_throw << ' ' << ErrorName << '\n';
    };

    convertHandlerCall(CE, EmitReturnOrThrow, EmitThrowOfBoundError);
  }

  /// Convert the call \p CE to a completion handler to its 'return' or 'throws'
  /// equivalent, where \p Result determines whether the call should be
  /// interpreted as an error or success call.
  void convertHandlerToReturnOrThrowsImpl(const CallExpr *CE,
                                          HandlerResult Result) {
    // Stays true as long as this function printed the leading keyword
    // itself, in which case a separating space is needed before the
    // arguments.
    bool AddedReturnOrThrow = true;
    if (Result.isError()) {
      OS << tok::kw_throw;
    } else {
      // The user may already have written an explicit return statement for
      // the completion handler call, e.g 'return completion(args...)'. Don't
      // add a second 'return' in that case.
      auto *ParentStmt = getWalker().Parent.getAsStmt();
      const bool HasExplicitReturn =
          isa_and_nonnull<ReturnStmt>(ParentStmt) &&
          !cast<ReturnStmt>(ParentStmt)->isImplicit();
      if (HasExplicitReturn)
        AddedReturnOrThrow = false;
      else
        OS << tok::kw_return;
    }

    auto Args = Result.args();
    if (Args.empty())
      return;

    if (AddedReturnOrThrow)
      OS << ' ';

    addTupleOf(Args, OS, [&](Argument Arg) {
      auto *E = Arg.getExpr();

      // Special case: a params handler that takes an error can be passed
      // arguments without unwrapping them, e.g.
      //   simpleWithError { (res: String?, error: Error?) in
      //     completion(res, nil)
      //   }
      // After refactoring `simpleWithError` to an async function we have
      //   let res: String = await simple()
      // where `res` is no longer an `Optional`. It's therefore in
      // `Placeholders` and `Unwraps`, and any reference to it that isn't
      // wrapped in an unwrapping expression gets replaced by a placeholder,
      // which would give us `return <#res# >`.
      // Assuming that either the error or the result parameter is non-nil,
      // the call above is equivalent to
      //   completion(res!, nil)
      // which correctly yields
      //   return res
      // So synthesize the force unwrap to get the expected result.
      const bool IsParamsHandlerWithError =
          TopHandler.getHandlerType() == HandlerType::PARAMS &&
          TopHandler.HasError;
      if (IsParamsHandlerWithError) {
        auto DRE = dyn_cast<DeclRefExpr>(E->getSemanticsProvidingExpr());
        if (DRE && Unwraps.count(DRE->getDecl()))
          E = new (getASTContext()) ForceValueExpr(E, SourceLoc());
      }

      // The range can't simply be copied over, because replacements have to
      // be performed within it.
      convertNode(E, /*StartOverride=*/Arg.getLabelLoc(),
                  /*ConvertCalls=*/false);
    });
  }

  /// Convert a call \p CE to a completion handler to resumes of the
  /// continuation that's currently on top of the stack.
  void convertHandlerToContinuationResume(const CallExpr *CE) {
    // Converts the handler call once it's known whether it's a success or an
    // error call.
    auto EmitResume = [&](HandlerResult Exprs) {
      convertHandlerToContinuationResumeImpl(CE, Exprs);
    };

    // Resumes the current scope's continuation by throwing the error bound
    // by the 'if let' of a truly ambiguous call.
    auto EmitThrowingResume = [&](StringRef ErrorName) {
      Identifier ContinuationName = Scopes.back().ContinuationName;
      OS << ContinuationName << tok::period << "resume" << tok::l_paren;
      OS << "throwing" << tok::colon << ' ' << ErrorName;
      OS << tok::r_paren << '\n';
    };

    convertHandlerCall(CE, EmitResume, EmitThrowingResume);
  }

  /// Convert a call \p CE to a completion handler to resumes of the
  /// continuation that's currently on top of the stack.
  /// \p Result determines whether the call should be interpreted as a success
  /// or error call.
  void convertHandlerToContinuationResumeImpl(const CallExpr *CE,
                                              HandlerResult Result) {
    // The output below refers to the continuation name of the innermost
    // scope, so that scope must be wrapped in a continuation.
    assert(Scopes.back().isWrappedInContination());

    // Choose which arguments are forwarded to 'resume' and which argument
    // label is used, depending on the kind of the top-level handler.
    std::vector<Argument> Args;
    StringRef ResumeArgumentLabel;
    switch (TopHandler.getHandlerType()) {
    case HandlerType::PARAMS: {
      // Forward the extracted result arguments, labelled 'returning' for a
      // non-error result and 'throwing' for an error result.
      Args = Result.args();
      if (!Result.isError()) {
        ResumeArgumentLabel = "returning";
      } else {
        ResumeArgumentLabel = "throwing";
      }
      break;
    }
    case HandlerType::RESULT: {
      // Forward all arguments of the original call unchanged, labelled
      // 'with'.
      Args = {CE->getArgs()->begin(), CE->getArgs()->end()};
      ResumeArgumentLabel = "with";
      break;
    }
    case HandlerType::INVALID:
      llvm_unreachable("Invalid top handler");
    }

    // A vector in which each argument of Result has an entry. If the entry is
    // not empty then that argument has been unwrapped using 'guard let' into
    // a variable with that name.
    SmallVector<Identifier, 4> ArgNames;
    ArgNames.reserve(Args.size());

    /// When unwrapping a result argument \p Arg into a variable using
    /// 'guard let' return a suitable name for the unwrapped variable.
    /// \p ArgIndex is the index of \p Arg in the results passed to the
    /// completion handler.
    auto GetSuitableNameForGuardUnwrap = [&](Expr *Arg,
                                             unsigned ArgIndex) -> Identifier {
      // If Arg is a DeclRef, use its name for the guard unwrap.
      // guard let myVar1 = myVar.
      if (auto DRE = dyn_cast<DeclRefExpr>(Arg)) {
        return createUniqueName(DRE->getDecl()->getBaseIdentifier().str());
      } else if (auto IIOE = dyn_cast<InjectIntoOptionalExpr>(Arg)) {
        // Same as above, but for a DeclRef wrapped in an
        // InjectIntoOptionalExpr.
        if (auto DRE = dyn_cast<DeclRefExpr>(IIOE->getSubExpr())) {
          return createUniqueName(DRE->getDecl()->getBaseIdentifier().str());
        }
      }
      if (Args.size() == 1) {
        // We only have a single result. 'result' seems a reasonable name.
        return createUniqueName("result");
      } else {
        // We are returning a tuple. Name the result elements 'result' +
        // index in tuple.
        return createUniqueName("result" + std::to_string(ArgIndex));
      }
    };

    // First pass: for every optional argument (only if the top handler has an
    // error), print a 'guard let' that unwraps it into a fresh variable and
    // calls 'fatalError' in the 'else' branch. Arguments that don't need
    // unwrapping get an empty entry in ArgNames.
    unsigned ArgIndex = 0;
    for (auto Arg : Args) {
      auto *ArgExpr = Arg.getExpr();
      Identifier ArgName;
      if (isExpressionOptional(ArgExpr) && TopHandler.HasError) {
        ArgName = GetSuitableNameForGuardUnwrap(ArgExpr, ArgIndex);
        // Record the name in the current scope so that later names created
        // with createUniqueName don't clash with it.
        Scopes.back().Names.insert(ArgName);
        OS << tok::kw_guard << ' ' << tok::kw_let << ' ' << ArgName << ' '
           << tok::equal << ' ';

        // If the argument is a call with a trailing closure, the generated
        // guard statement will not compile.
        // e.g. 'guard let result1 = value.map { $0 + 1 } else { ... }'
        // doesn't compile. Adding parentheses makes the code compile.
        auto HasTrailingClosure = false;
        if (auto *CE = dyn_cast<CallExpr>(ArgExpr)) {
          if (CE->getArgs()->hasAnyTrailingClosures())
            HasTrailingClosure = true;
        }

        if (HasTrailingClosure)
          OS << tok::l_paren;

        convertNode(ArgExpr, /*StartOverride=*/Arg.getLabelLoc(),
                    /*ConvertCalls=*/false);

        if (HasTrailingClosure)
          OS << tok::r_paren;

        // 'else { fatalError("Expected non-nil result ['<name>' ]in the
        // non-error case") }'. The variable name is only mentioned in the
        // message if it isn't the plain 'result'.
        OS << ' ' << tok::kw_else << ' ' << tok::l_brace << '\n';
        OS << "fatalError" << tok::l_paren;
        OS << "\"Expected non-nil result ";
        if (ArgName.str() != "result") {
          OS << "'" << ArgName << "' ";
        }
        OS << "in the non-error case\"";
        OS << tok::r_paren << '\n';
        OS << tok::r_brace << '\n';
      }
      ArgNames.push_back(ArgName);
      ArgIndex++;
    }

    // Second pass: print '<continuation>.resume(<label>: <args>)'.
    Identifier ContName = Scopes.back().ContinuationName;
    OS << ContName << tok::period << "resume" << tok::l_paren
       << ResumeArgumentLabel << tok::colon << ' ';

    // ArgIndex is advanced once per callback invocation to look up the
    // matching entry in ArgNames.
    // NOTE(review): this relies on addTupleOf invoking the callback exactly
    // once per element of Args, in order - confirm against addTupleOf.
    ArgIndex = 0;
    addTupleOf(Args, OS, [&](Argument Arg) {
      Identifier ArgName = ArgNames[ArgIndex];
      if (!ArgName.empty()) {
        // The argument was unwrapped above; refer to the unwrapped variable.
        OS << ArgName;
      } else {
        // Can't just add the range as we need to perform replacements
        convertNode(Arg.getExpr(), /*StartOverride=*/Arg.getLabelLoc(),
                    /*ConvertCalls=*/false);
      }
      ArgIndex++;
    });
    OS << tok::r_paren;
  }

  /// From the given expression \p E, which is an argument to a function call,
  /// extract the passed closure if there is one. Otherwise return \c nullptr.
  ClosureExpr *extractCallback(Expr *E) {
    // Function conversions wrapped around the argument are irrelevant here.
    Expr *Stripped = lookThroughFunctionConversionExpr(E);

    // A bare closure is the callback itself.
    if (auto *Direct = dyn_cast<ClosureExpr>(Stripped))
      return Direct;

    // A closure with a capture list: the callback is the closure body of the
    // capture list expression (if that is a ClosureExpr).
    if (auto *WithCaptures = dyn_cast<CaptureListExpr>(Stripped))
      return dyn_cast<ClosureExpr>(WithCaptures->getClosureBody());

    // Anything else is not a closure.
    return nullptr;
  }

  /// Callback arguments marked as e.g. `@convention(block)` produce arguments
  /// that are `FunctionConversionExpr`.
  /// We don't care about the conversions and want to shave them off.
  Expr *lookThroughFunctionConversionExpr(Expr *E) {
    // Peel off one conversion layer at a time until an expression of a
    // different kind is reached.
    while (auto *Conversion = dyn_cast<FunctionConversionExpr>(E))
      E = Conversion->getSubExpr();
    return E;
  }

  void addHoistedCallback(const CallExpr *CE,
                          const AsyncHandlerParamDesc &HandlerDesc) {
    llvm::SaveAndRestore<bool> RestoreHoisting(Hoisting, true);

    auto *ArgList = CE->getArgs();
    if (HandlerDesc.Index >= ArgList->size()) {
      DiagEngine.diagnose(CE->getStartLoc(), diag::missing_callback_arg);
      return;
    }

    Expr *CallbackArg =
        lookThroughFunctionConversionExpr(ArgList->getExpr(HandlerDesc.Index));
    if (ClosureExpr *Callback = extractCallback(CallbackArg)) {
      // The user is using a closure for the completion handler
      addHoistedClosureCallback(CE, HandlerDesc, Callback);
      return;
    }
    if (auto CallbackDecl = getReferencedDecl(CallbackArg)) {
      if (CallbackDecl == TopHandler.getHandler()) {
        // We are refactoring the function that declared the completion handler
        // that would be called here. We can't call the completion handler
        // anymore because it will be removed. But since the function that
        // declared it is being refactored to async, we can just return the
        // values.
        if (!HandlerDesc.willAsyncReturnVoid()) {
          OS << tok::kw_return << " ";
        }
        InlinePatternsToPrint InlinePatterns;
        addAwaitCall(CE, ClassifiedBlock(), {}, InlinePatterns, HandlerDesc,
                     /*AddDeclarations*/ false);
        return;
      }
      // We are not removing the completion handler, so we can call it once the
      // async function returns.

      // The completion handler that is called as part of the \p CE call.
      // This will be called once the async function returns.
      auto CompletionHandler =
          AsyncHandlerDesc::get(CallbackDecl, /*RequireAttributeOrName=*/false);
      if (CompletionHandler.isValid()) {
        if (auto CalledFunc = getUnderlyingFunc(CE->getFn())) {
          StringRef HandlerName = Lexer::getCharSourceRangeFromSourceRange(
              SM, CallbackArg->getSourceRange()).str();
          addHoistedNamedCallback(
              CalledFunc, CompletionHandler, HandlerName, [&] {
                InlinePatternsToPrint InlinePatterns;
                addAwaitCall(CE, ClassifiedBlock(), {}, InlinePatterns,
                             HandlerDesc, /*AddDeclarations*/ false);
              });
          return;
        }
      }
    }
    DiagEngine.diagnose(CE->getStartLoc(), diag::missing_callback_arg);
  }

  /// Add a binding to a known bool flag that indicates success or failure.
  void addBoolFlagParamBindingIfNeeded(Optional<KnownBoolFlagParam> Flag,
                                       BlockKind Block) {
    // Nothing to bind without a known flag.
    if (!Flag)
      return;

    // Determine whether the binding is emitted for the success path. Binding
    // in the fallback block is not supported.
    bool OnSuccessPath = true;
    switch (Block) {
    case BlockKind::SUCCESS:
      break;
    case BlockKind::ERROR:
      OnSuccessPath = false;
      break;
    case BlockKind::FALLBACK:
      llvm_unreachable("Not a valid place to bind");
    }

    // A success flag is 'true' on the success path and 'false' on the error
    // path; a flag that doesn't indicate success is the other way around.
    const bool BoundValue = Flag->IsSuccessFlag ? OnSuccessPath : !OnSuccessPath;

    // '<name> = true' or '<name> = false'
    OS << newNameFor(Flag->Param) << " " << tok::equal << " "
       << (BoundValue ? tok::kw_true : tok::kw_false) << "\n";
  }

  /// Add a call to the async alternative of \p CE and convert the \p Callback
  /// to be executed after the async call. \p HandlerDesc describes the
  /// completion handler in the function that's called by \p CE. The arguments
  /// of the call are taken from \p CE itself.
  void addHoistedClosureCallback(const CallExpr *CE,
                                 const AsyncHandlerParamDesc &HandlerDesc,
                                 const ClosureExpr *Callback) {
    // The closure must take exactly as many parameters as the completion
    // handler it is passed for; otherwise diagnose and give up.
    if (HandlerDesc.params().size() != Callback->getParameters()->size()) {
      DiagEngine.diagnose(CE->getStartLoc(), diag::mismatched_callback_args);
      return;
    }
    ClosureCallbackParams CallbackParams(HandlerDesc, Callback);
    ClassifiedBlocks Blocks;
    auto *CallbackBody = Callback->getBody();
    if (!HandlerDesc.HasError) {
      // Without an error there is nothing to classify: the whole body is the
      // success block.
      Blocks.SuccessBlock.addNodesInBraceStmt(CallbackBody);
    } else if (!CallbackBody->getElements().empty()) {
      // Split the non-empty body into success and error blocks.
      CallbackClassifier::classifyInto(Blocks, CallbackParams, HandledSwitches,
                                       DiagEngine, CallbackBody);
    }

    auto SuccessBindings = CallbackParams.getParamsToBind(BlockKind::SUCCESS);
    auto *ErrParam = CallbackParams.getErrParam();
    if (DiagEngine.hadAnyError()) {
      // For now, only fallback when the results are params with an error param,
      // in which case only the names are used (defaulted to the names of the
      // params if none).
      if (HandlerDesc.Type != HandlerType::PARAMS || !HandlerDesc.HasError)
        return;
      // We are recovering via the fallback, so clear the error state.
      DiagEngine.resetHadAnyError();

      // Note that we don't print any inline patterns here as we just want
      // assignments to the names in the outer scope.
      InlinePatternsToPrint InlinePatterns;

      auto AllBindings = CallbackParams.getParamsToBind(BlockKind::FALLBACK);

      // Fallback output, in order: the fallback variables, a 'do' with the
      // await call (assigning without declarations), the bool flag binding for
      // the success path, the fallback 'catch', and finally the complete,
      // unclassified callback body.
      prepareNames(ClassifiedBlock(), AllBindings, InlinePatterns);
      preparePlaceholdersAndUnwraps(HandlerDesc, CallbackParams,
                                    BlockKind::FALLBACK);
      addFallbackVars(AllBindings, CallbackParams);
      addDo();
      addAwaitCall(CE, Blocks.SuccessBlock, SuccessBindings, InlinePatterns,
                   HandlerDesc, /*AddDeclarations*/ false);
      OS << "\n";

      // If we have a known Bool success param, we need to bind it.
      addBoolFlagParamBindingIfNeeded(CallbackParams.getKnownBoolFlagParam(),
                                      BlockKind::SUCCESS);
      addFallbackCatch(CallbackParams);
      OS << "\n";
      convertNodes(NodesToPrint::inBraceStmt(CallbackBody));

      clearNames(AllBindings);
      return;
    }

    // The param that the error block is keyed on: the result param if there
    // is one, the error param otherwise.
    auto *ErrOrResultParam = ErrParam;
    if (auto *ResultParam = CallbackParams.getResultParam())
      ErrOrResultParam = ResultParam;

    // A do/catch is needed whenever the error block has anything to print...
    auto ErrorNodes = Blocks.ErrorBlock.nodesToPrint().getNodes();
    bool RequireDo = !ErrorNodes.empty();
    // Check if we *actually* need a do/catch (see class comment)
    // ...unless its only node is a call to the top-level handler with a single
    // argument that is the error/result param itself, or (for an error result)
    // a decl with the name that param is bound to in the error block.
    if (ErrorNodes.size() == 1) {
      auto Node = ErrorNodes[0];
      if (auto *HandlerCall = TopHandler.getAsHandlerCall(Node)) {
        auto Res = TopHandler.extractResultArgs(
            HandlerCall, /*ReturnErrorArgsIfAmbiguous=*/true);
        if (Res.args().size() == 1) {
          // Skip if we have the param itself or the name it's bound to
          auto *ArgExpr = Res.args()[0].getExpr();
          auto *SingleDecl = ArgExpr->getReferencedDecl().getDecl();
          auto ErrName = Blocks.ErrorBlock.boundName(ErrOrResultParam);
          RequireDo = SingleDecl != ErrOrResultParam &&
                      !(Res.isError() && SingleDecl &&
                        SingleDecl->getName().isSimpleName(ErrName));
        }
      }
    }

    // If we're not requiring a 'do', we'll be dropping the error block. But
    // let's make sure we at least preserve the comments in the error block by
    // transplanting them into the success block. This should make sure they
    // maintain a sensible ordering.
    if (!RequireDo) {
      // Note: this 'ErrorNodes' intentionally shadows the outer one; it holds
      // the nodes-to-print object rather than just the node list.
      auto ErrorNodes = Blocks.ErrorBlock.nodesToPrint();
      for (auto CommentLoc : ErrorNodes.getPossibleCommentLocs())
        Blocks.SuccessBlock.addPossibleCommentLoc(CommentLoc);
    }

    if (RequireDo) {
      addDo();
    }

    // Success path: set up names/placeholders for the success bindings, print
    // the await call (declaring the bindings), any patterns that couldn't be
    // printed inline, and then the success block itself.
    auto InlinePatterns = getInlinePatternsToPrint(Blocks.SuccessBlock,
                                                   SuccessBindings, Callback);

    prepareNames(Blocks.SuccessBlock, SuccessBindings, InlinePatterns);
    preparePlaceholdersAndUnwraps(HandlerDesc, CallbackParams,
                                  BlockKind::SUCCESS);

    addAwaitCall(CE, Blocks.SuccessBlock, SuccessBindings, InlinePatterns,
                 HandlerDesc, /*AddDeclarations=*/true);
    printOutOfLineBindingPatterns(Blocks.SuccessBlock, InlinePatterns);
    convertNodes(Blocks.SuccessBlock.nodesToPrint());
    // The success names/placeholders must not leak into the error path.
    clearNames(SuccessBindings);

    if (RequireDo) {
      // We don't use inline patterns for the error path.
      InlinePatternsToPrint ErrInlinePatterns;

      // Always use the ErrParam name if none is bound.
      // (For a Result-typed handler, no name is added if none is bound.)
      prepareNames(Blocks.ErrorBlock, llvm::makeArrayRef(ErrOrResultParam),
                   ErrInlinePatterns,
                   /*AddIfMissing=*/HandlerDesc.Type != HandlerType::RESULT);
      preparePlaceholdersAndUnwraps(HandlerDesc, CallbackParams,
                                    BlockKind::ERROR);

      // Error path: the 'catch' clause, the error block, and the closing
      // brace of the 'catch'.
      addCatch(ErrOrResultParam);
      convertNodes(Blocks.ErrorBlock.nodesToPrint());
      OS << "\n" << tok::r_brace;
      clearNames(llvm::makeArrayRef(ErrOrResultParam));
    }
  }

  /// Add a call to the async alternative of \p FD. Afterwards, pass the results
  /// of the async call to the completion handler, named \p HandlerName and
  /// described by \p HandlerDesc.
  /// \p AddAwaitCall adds the call to the refactored async method to the output
  /// stream without storing the result to any variables.
  /// This is used when the user didn't use a closure for the callback, but
  /// passed in a variable or function name for the completion handler.
  void addHoistedNamedCallback(const FuncDecl *FD,
                               const AsyncHandlerDesc &HandlerDesc,
                               StringRef HandlerName,
                               std::function<void(void)> AddAwaitCall) {
    if (!HandlerDesc.HasError) {
      // Non-throwing case: the code may be placed into an existing scope, so
      // a unique name is created for the result to avoid shadowing or
      // redeclarations. For a Void return the name is never printed; a bogus
      // non-empty one is still passed so that the result path is taken
      // (versus the error path).
      StringRef ResultName = "result";
      if (!HandlerDesc.willAsyncReturnVoid()) {
        Identifier Unique = createUniqueName("result");
        Scopes.back().Names.insert(Unique);
        ResultName = Unique.str();

        // 'let <result>[: Type] = '
        OS << tok::kw_let << " " << ResultName;
        addResultTypeAnnotationIfNecessary(FD, HandlerDesc);
        OS << " " << tok::equal << " ";
      }
      AddAwaitCall();
      OS << "\n";
      addCallToCompletionHandler(ResultName, HandlerDesc, HandlerName);
      return;
    }

    // Throwing case: everything is printed into a fresh do/catch, which only
    // contains new code, so "result" and "error" are always okay to use.
    addDo();
    if (!HandlerDesc.willAsyncReturnVoid()) {
      // 'let result[: Type] = '
      OS << tok::kw_let << " result";
      addResultTypeAnnotationIfNecessary(FD, HandlerDesc);
      OS << " " << tok::equal << " ";
    }
    AddAwaitCall();
    OS << "\n";
    // Success: forward the result to the completion handler.
    addCallToCompletionHandler("result", HandlerDesc, HandlerName);
    OS << "\n";

    // '} catch {' followed by the error call to the completion handler (an
    // empty result name selects the error path) and the closing '}'.
    OS << tok::r_brace << " " << tok::kw_catch << " " << tok::l_brace << "\n";
    addCallToCompletionHandler(StringRef(), HandlerDesc, HandlerName);
    OS << "\n" << tok::r_brace; // end catch
  }

  /// Checks whether a binding pattern for a given decl can be printed inline in
  /// an await call, e.g 'let ((x, y), z) = await foo()', where '(x, y)' is the
  /// inline pattern.
  const Pattern *
  bindingPatternToPrintInline(const Decl *D, const ClassifiedBlock &Block,
                              const ClosureExpr *CallbackClosure) {
    // Inline patterns are currently only supported for callback closures,
    // and only if the bindings of D reduce to one single pattern.
    auto *Single = CallbackClosure ? Block.getSinglePatternFor(D) : nullptr;
    if (!Single)
      return nullptr;

    // A pattern binding just one var can always go inline.
    if (Single->getSingleVar())
      return Single;

    // For a multi-var binding, the pattern can only be printed directly in
    // the await call if the bound decl isn't referenced anywhere else in the
    // closure body. Otherwise it has to be printed out of line.
    auto *Referenced =
        ScopedDecls.getReferencedDecls(CallbackClosure->getBody());
    assert(Referenced);
    if (Referenced->lookup(D) != 1)
      return nullptr;
    return Single;
  }

  /// Retrieve a map of patterns to print inline for an array of param decls.
  InlinePatternsToPrint
  getInlinePatternsToPrint(const ClassifiedBlock &Block,
                           ArrayRef<const ParamDecl *> Params,
                           const ClosureExpr *CallbackClosure) {
    InlinePatternsToPrint Result;
    for (const ParamDecl *PD : Params) {
      // Only params that actually have an inline-printable pattern get an
      // entry in the map.
      const Pattern *Inline =
          bindingPatternToPrintInline(PD, Block, CallbackClosure);
      if (!Inline)
        continue;
      Result[PD] = Inline;
    }
    return Result;
  }

  /// Print any out of line binding patterns that could not be printed as inline
  /// patterns. These typically appear directly after an await call, e.g:
  /// \code
  /// let x = await foo()
  /// let (y, z) = x
  /// \endcode
  void
  printOutOfLineBindingPatterns(const ClassifiedBlock &Block,
                                const InlinePatternsToPrint &InlinePatterns) {
    for (auto &Binding : Block.paramPatternBindings()) {
      auto *BoundDecl = Binding.first;
      auto KnownAliases = Block.getAliasesFor(BoundDecl);

      for (auto *Pat : Binding.second) {
        // Patterns that were already printed inline need nothing more.
        const bool PrintedInline = InlinePatterns.lookup(BoundDecl) == Pat;
        if (PrintedInline)
          continue;

        // A pattern that merely introduces an alias can be elided.
        auto *OnlyVar = Pat->getSingleVar();
        if (OnlyVar && KnownAliases.contains(OnlyVar))
          continue;

        // '<newline>let <pattern> = <name>', using 'var' instead of 'let' if
        // the pattern has any mutable bindings.
        const bool IsMutable = Pat->hasAnyMutableBindings();
        OS << "\n" << (IsMutable ? tok::kw_var : tok::kw_let) << " ";
        convertPattern(Pat);
        OS << " = " << newNameFor(BoundDecl);
      }
    }
  }

  /// Prints an \c await call to an \c async function, binding any return values
  /// into variables.
  ///
  /// \param CE The call expr to convert.
  /// \param SuccessBlock The nodes present in the success block following the
  /// call.
  /// \param SuccessParams The success parameters, which will be printed as
  /// return values.
  /// \param InlinePatterns A map of patterns that can be printed inline for
  /// a given param.
  /// \param HandlerDesc A description of the completion handler.
  /// \param AddDeclarations Whether or not to add \c let or \c var keywords to
  /// the return value bindings.
  void addAwaitCall(const CallExpr *CE, const ClassifiedBlock &SuccessBlock,
                    ArrayRef<const ParamDecl *> SuccessParams,
                    const InlinePatternsToPrint &InlinePatterns,
                    const AsyncHandlerParamDesc &HandlerDesc,
                    bool AddDeclarations) {
    auto *Args = CE->getArgs();

    // Print the bindings to match the completion handler success parameters,
    // making sure to omit in the case of a Void return.
    if (!SuccessParams.empty() && !HandlerDesc.willAsyncReturnVoid()) {
      // Stays true unless one of the inline patterns has a mutable binding,
      // in which case 'var' has to be used for the whole declaration.
      auto AllLet = true;

      // Gather the items to print for the variable bindings. This can either be
      // a param decl, or a pattern that binds it.
      using DeclOrPattern = llvm::PointerUnion<const Decl *, const Pattern *>;
      SmallVector<DeclOrPattern, 4> ToPrint;
      for (auto *Param : SuccessParams) {
        // Check if we have an inline pattern to print.
        if (auto *P = InlinePatterns.lookup(Param)) {
          if (P->hasAnyMutableBindings())
            AllLet = false;
          ToPrint.push_back(P);
          continue;
        }
        // No inline pattern: the param itself is printed (by its new name).
        ToPrint.push_back(Param);
      }

      // Only introduce new variables if requested; otherwise the bindings are
      // printed as plain assignment targets.
      if (AddDeclarations) {
        if (AllLet) {
          OS << tok::kw_let;
        } else {
          OS << tok::kw_var;
        }
        OS << " ";
      }
      // 'res =' or '(res1, res2, ...) ='
      addTupleOf(ToPrint, OS, [&](DeclOrPattern Elt) {
        if (auto *P = Elt.dyn_cast<const Pattern *>()) {
          convertPattern(P);
          return;
        }
        OS << newNameFor(Elt.get<const Decl *>());
      });
      OS << " " << tok::equal << " ";
    }

    // '[try ]await '
    if (HandlerDesc.HasError) {
      OS << tok::kw_try << " ";
    }
    OS << "await ";

    // Try to replace the name with that of the alternative. Use the existing
    // name if for some reason that's not possible.
    bool NameAdded = false;
    if (HandlerDesc.Alternative) {
      const ValueDecl *Named = HandlerDesc.Alternative;
      // For an accessor, the name to use is that of its storage.
      if (auto *Accessor = dyn_cast<AccessorDecl>(HandlerDesc.Alternative))
        Named = Accessor->getStorage();
      if (!Named->getBaseName().isSpecial()) {
        // Register the alternative's name for the called function and print
        // the callee through convertNode.
        // NOTE(review): presumably convertNode picks up the entry in Names to
        // perform the rename - confirm against convertNode.
        Names.try_emplace(HandlerDesc.Func,
                          Named->getBaseName().getIdentifier());
        convertNode(CE->getFn(), /*StartOverride=*/{}, /*ConvertCalls=*/false,
                    /*IncludeComments=*/false);
        NameAdded = true;
      }
    }
    if (!NameAdded) {
      // Fall back to copying the callee from the original source.
      addRange(CE->getStartLoc(), CE->getFn()->getEndLoc(),
               /*ToEndOfToken=*/true);
    }

    // No argument list parentheses when the alternative is an accessor.
    if (!HandlerDesc.alternativeIsAccessor())
      OS << tok::l_paren;

    // Number of arguments printed so far, plus the number of defaulted
    // parameters of the alternative that were skipped. Used both to decide
    // whether a separating comma is needed and as the index of the
    // alternative's parameter that corresponds to the next argument.
    size_t ConvertedArgIndex = 0;
    ArrayRef<ParamDecl *> AlternativeParams;
    if (HandlerDesc.Alternative)
      AlternativeParams = HandlerDesc.Alternative->getParameters()->getArray();

    for (auto I : indices(*Args)) {
      auto Arg = Args->get(I);
      auto *ArgExpr = Arg.getExpr();
      // Drop the completion handler argument and any defaulted arguments.
      if (I == HandlerDesc.Index || isa<DefaultArgumentExpr>(ArgExpr))
        continue;

      if (ConvertedArgIndex > 0)
        OS << tok::comma << " ";

      if (HandlerDesc.Alternative) {
        // Skip argument if it's defaulted and has a different name
        while (ConvertedArgIndex < AlternativeParams.size() &&
               AlternativeParams[ConvertedArgIndex]->isDefaultArgument() &&
               AlternativeParams[ConvertedArgIndex]->getArgumentName() !=
               Arg.getLabel()) {
          ConvertedArgIndex++;
        }

        if (ConvertedArgIndex < AlternativeParams.size()) {
          // Could have a different argument label (or none), so add it instead
          auto Name = AlternativeParams[ConvertedArgIndex]->getArgumentName();
          if (!Name.empty())
            OS << Name << ": ";
          convertNode(ArgExpr, /*StartOverride=*/{}, /*ConvertCalls=*/false);

          ConvertedArgIndex++;
          continue;
        }

        // Fallthrough if arguments don't match up for some reason
      }

      // Can't just add the range as we need to perform replacements. Also
      // make sure to include the argument label (if any)
      convertNode(ArgExpr, /*StartOverride=*/Arg.getLabelLoc(),
                  /*ConvertCalls=*/false);
      ConvertedArgIndex++;
    }

    if (!HandlerDesc.alternativeIsAccessor())
      OS << tok::r_paren;
  }

  void addFallbackCatch(const ClosureCallbackParams &Params) {
    auto *ErrParam = Params.getErrParam();
    assert(ErrParam);
    auto ErrName = newNameFor(ErrParam);
    OS << tok::r_brace << " " << tok::kw_catch << " " << tok::l_brace << "\n"
       << ErrName << " = error\n";

    // If we have a known Bool success param, we need to bind it.
    addBoolFlagParamBindingIfNeeded(Params.getKnownBoolFlagParam(),
                                    BlockKind::ERROR);
    OS << tok::r_brace;
  }

  void addCatch(const ParamDecl *ErrParam) {
    OS << "\n" << tok::r_brace << " " << tok::kw_catch << " ";
    auto ErrName = newNameFor(ErrParam, false);
    if (!ErrName.empty() && ErrName != "_") {
      OS << tok::kw_let << " " << ErrName << " ";
    }
    OS << tok::l_brace;
  }

  /// Fills \c Placeholders and \c Unwraps for printing the \p Block part of a
  /// callback whose params are described by \p Params, where \p HandlerDesc
  /// describes the completion handler. Params without a binding in \p Block
  /// always become placeholders; which other params are added depends on the
  /// handler type and on \p Block.
  void preparePlaceholdersAndUnwraps(AsyncHandlerDesc HandlerDesc,
                                     const ClosureCallbackParams &Params,
                                     BlockKind Block) {
    // A param without a binding in this block has been dropped, so it always
    // needs placeholdering.
    for (auto *Param : Params.getAllParams()) {
      if (Params.hasBinding(Param, Block))
        continue;
      Placeholders.insert(Param);
    }

    // In the fallback case all other params are freely accessible, so there
    // is nothing more to placeholder.
    if (Block == BlockKind::FALLBACK)
      return;

    if (HandlerDesc.Type == HandlerType::RESULT) {
      // Any remaining (not replaced) use of the result param in the current
      // body is invalid, so it gets a placeholder.
      auto *ResultParam = Params.getResultParam();
      assert(ResultParam);
      Placeholders.insert(ResultParam);
      return;
    }

    if (HandlerDesc.Type != HandlerType::PARAMS)
      llvm_unreachable("Unhandled handler type");

    auto *ErrParam = Params.getErrParam();
    auto SuccessParams = Params.getSuccessParams();
    switch (Block) {
    case BlockKind::FALLBACK:
      llvm_unreachable("Already handled");
    case BlockKind::ERROR: {
      if (!ErrParam)
        break;
      if (HandlerDesc.shouldUnwrap(ErrParam->getType())) {
        Unwraps.insert(ErrParam);
        Placeholders.insert(ErrParam);
      }
      // The success params are not usable in the error body.
      Placeholders.insert(SuccessParams.begin(), SuccessParams.end());
      break;
    }
    case BlockKind::SUCCESS: {
      for (auto *SuccessParam : SuccessParams) {
        auto ParamTy = SuccessParam->getType();

        // Either unwrap, or replace with a placeholder if there is some
        // other reference.
        if (HandlerDesc.shouldUnwrap(ParamTy)) {
          Placeholders.insert(SuccessParam);
          Unwraps.insert(SuccessParam);
        }

        // Void params are omitted where possible, so a reference to one is
        // unlikely to be what the user wants: make it a placeholder.
        if (HandlerDesc.getSuccessParamAsyncReturnType(ParamTy)->isVoid())
          Placeholders.insert(SuccessParam);
      }
      // The error param is not usable in the success body.
      if (ErrParam)
        Placeholders.insert(ErrParam);
      break;
    }
    }
  }

  /// Add a mapping from each passed parameter to a new name, possibly
  /// synthesizing a new one if hoisting it would cause a redeclaration or
  /// shadowing. If there's no bound name and \c AddIfMissing is false, no
  /// name will be added.
  void prepareNames(const ClassifiedBlock &Block,
                    ArrayRef<const ParamDecl *> Params,
                    const InlinePatternsToPrint &InlinePatterns,
                    bool AddIfMissing = true) {
    for (auto *PD : Params) {
      // If this param is to be replaced by a pattern that binds multiple
      // separate vars, it's not actually going to be added to the scope, and
      // therefore doesn't need naming. This avoids needing to rename a var with
      // the same name later on in the scope, as it's not actually clashing.
      if (auto *P = InlinePatterns.lookup(PD)) {
        if (!P->getSingleVar())
          continue;
      }
      auto Name = Block.boundName(PD);
      if (Name.empty() && !AddIfMissing)
        continue;

      auto Ident = assignUniqueName(PD, Name);

      // Also propagate the name to any aliases.
      for (auto *Alias : Block.getAliasesFor(PD))
        Names[Alias] = Ident;
    }
  }

  /// Returns a unique name using \c Name as base that doesn't clash with any
  /// other names in the current scope.
  Identifier createUniqueName(StringRef Name) {
    Identifier Candidate = getASTContext().getIdentifier(Name);

    // '_' is returned as is, without any uniquing.
    if (Name == "_")
      return Candidate;

    // As long as the candidate is taken in the current scope, retry with
    // 'Name1', 'Name2', ... until a free one is found.
    auto &Taken = Scopes.back().Names;
    llvm::SmallString<32> Numbered;
    for (unsigned Suffix = 1; Taken.count(Candidate); ++Suffix) {
      Numbered = Name;
      Numbered.append(std::to_string(Suffix));
      Candidate = getASTContext().getIdentifier(Numbered);
    }
    return Candidate;
  }

  /// Create a unique name for the variable declared by \p D that doesn't
  /// clash with any other names in scope, using \p BoundName as the base name
  /// if not empty and the name of \p D otherwise. Adds this name to both
  /// \c Names and the current scope's names (\c Scopes.Names).
  Identifier assignUniqueName(const Decl *D, StringRef BoundName) {
    // Without a bound name, fall back to the user-facing name of the decl.
    StringRef Base = BoundName;
    if (Base.empty())
      Base = getDeclName(D).userFacingName();

    // Nothing to derive a name from: return an empty identifier without
    // recording anything.
    if (Base.empty())
      return Identifier();

    // A name starting with '$' has that character replaced by 'val', e.g.
    // '$0' becomes 'val0'.
    llvm::SmallString<8> Renamed;
    if (Base.startswith("$")) {
      Renamed.append("val");
      Renamed.append(Base.drop_front());
      Base = Renamed.str();
    }

    Identifier Unique = createUniqueName(Base);

    // Remember the name for D (an existing entry is left untouched) and
    // reserve it in the current scope.
    Names.try_emplace(D, Unique);
    Scopes.back().Names.insert(Unique);
    return Unique;
  }

  /// Returns the name recorded for \p D in \c Names. If there is none, an
  /// empty string is returned; this is only valid if \p Required is false and
  /// asserts otherwise.
  StringRef newNameFor(const Decl *D, bool Required = true) {
    auto Found = Names.find(D);
    if (Found != Names.end())
      return Found->second.str();

    assert(!Required && "Missing name for decl when one was required");
    return StringRef();
  }

  /// Pushes a new scope onto \c Scopes. The scope inherits the continuation
  /// name of the previously innermost scope (if there is one) and its names
  /// are seeded with the non-empty names of \p Decls.
  void addNewScope(const llvm::DenseSet<const Decl *> &Decls) {
    // If the parent scope is nested in a continuation, so is the new one:
    // carry over the continuation name. The outermost scope starts out with
    // an empty name.
    Identifier InheritedContinuation;
    if (!Scopes.empty())
      InheritedContinuation = Scopes.back().ContinuationName;
    Scopes.emplace_back(InheritedContinuation);

    auto &ScopeNames = Scopes.back().Names;
    for (auto D : Decls) {
      auto Name = getDeclName(D);
      if (Name.empty())
        continue;
      ScopeNames.insert(Name);
    }
  }

  /// Removes every param in \p Params from \c Names, \c Placeholders and
  /// \c Unwraps.
  void clearNames(ArrayRef<const ParamDecl *> Params) {
    for (const ParamDecl *PD : Params) {
      Names.erase(PD);
      Placeholders.erase(PD);
      Unwraps.erase(PD);
    }
  }

  /// Adds a forwarding call to the old completion handler function, with
  /// \p HandlerReplacement that allows for a custom replacement or, if empty,
  /// removal of the completion handler closure.
  void addForwardingCallTo(const FuncDecl *FD, StringRef HandlerReplacement) {
    OS << FD->getBaseName() << tok::l_paren;

    auto *Params = FD->getParameters();
    size_t ConvertedArgsIndex = 0;
    for (size_t I = 0, E = Params->size(); I < E; ++I) {
      if (I == TopHandler.Index) {
        /// If we're not replacing the handler with anything, drop it.
        if (HandlerReplacement.empty())
          continue;

        // Use a trailing closure if the handler is the last param
        if (I == E - 1) {
          OS << tok::r_paren << " ";
          OS << HandlerReplacement;
          return;
        }

        // Otherwise fall through to do the replacement.
      }

      if (ConvertedArgsIndex > 0)
        OS << tok::comma << " ";

      const auto *Param = Params->get(I);
      if (!Param->getArgumentName().empty())
        OS << Param->getArgumentName() << tok::colon << " ";

      if (I == TopHandler.Index) {
        OS << HandlerReplacement;
      } else {
        OS << Param->getParameterName();
      }

      ConvertedArgsIndex++;
    }
    OS << tok::r_paren;
  }

  /// Adds a forwarded error argument to a completion handler call. If the error
  /// type of \p HandlerDesc is more specialized than \c Error, an
  /// 'as! CustomError' cast to the more specialized error type will be added to
  /// the output stream.
  void addForwardedErrorArgument(StringRef ErrorName,
                                 const AsyncHandlerDesc &HandlerDesc) {
    // Note: the handler is expected to have an error type here, it is
    // dereferenced unconditionally.
    auto HandlerErrorTy = *HandlerDesc.getErrorType();

    // An error type that is already Error needs no cast at all.
    if (HandlerErrorTy->getCanonicalType() ==
        getASTContext().getErrorExistentialType()) {
      OS << ErrorName;
      return;
    }

    // Any other error type needs a force cast to the custom error type. When
    // this is for an Error? parameter, the cast is wrapped in parens to
    // silence a compiler warning about force casting never producing nil.
    const bool WrapInParens = HandlerDesc.getErrorParam().has_value();

    if (WrapInParens)
      OS << tok::l_paren;

    // '<error> as! <CustomError>'
    OS << ErrorName << " " << tok::kw_as << tok::exclaim_postfix << " ";
    HandlerErrorTy->lookThroughSingleOptionalType()->print(OS);

    if (WrapInParens)
      OS << tok::r_paren;
  }

  /// Writes a stand-in value for type \p T: \c nil when \p T is optional,
  /// \c () when \p T is \c Void, and otherwise an editor placeholder of the
  /// form `<#T#>` carrying the printed type as its hint.
  void addDefaultValueOrPlaceholder(Type T) {
    if (T->isOptional()) {
      OS << tok::kw_nil;
      return;
    }

    if (T->isVoid()) {
      OS << "()";
      return;
    }

    // No natural default exists; leave a placeholder for the user to fill in.
    OS << "<#";
    T.print(OS);
    OS << "#>";
  }

  /// Adds the \c Index -th argument of a call to the completion handler
  /// described by \p HandlerDesc.
  ///
  /// If \p ResultName is not empty, it is assumed that a variable with that
  /// name contains the result returned from the async alternative. If the
  /// callback also takes an error parameter, a default value (\c nil for an
  /// optional) is passed to the completion handler for the error.
  ///
  /// If \p ResultName is empty, it is assumed that a variable named 'error'
  /// contains the error thrown from the async method; the error is forwarded
  /// and a default value is passed for all result parameters.
  void addCompletionHandlerArgument(size_t Index, StringRef ResultName,
                                    const AsyncHandlerDesc &HandlerDesc) {
    // Evaluated lazily so the parameter list is only indexed on the paths that
    // actually need the parameter's type.
    auto ParamTy = [&] { return HandlerDesc.params()[Index].getPlainType(); };

    // The error parameter is the last argument of the completion handler.
    const bool IsErrorParam =
        HandlerDesc.HasError && Index == HandlerDesc.params().size() - 1;
    if (IsErrorParam) {
      if (ResultName.empty())
        addForwardedErrorArgument("error", HandlerDesc);
      else
        addDefaultValueOrPlaceholder(ParamTy());
      return;
    }

    // Failure path: there is no result to pass for a success parameter.
    if (ResultName.empty()) {
      addDefaultValueOrPlaceholder(ParamTy());
      return;
    }

    // Void return types are not returned by the async function, synthesize a
    // Void instance.
    if (HandlerDesc.getSuccessParamAsyncReturnType(ParamTy())->isVoid()) {
      OS << tok::l_paren << tok::r_paren;
      return;
    }

    // A single success parameter receives the result variable directly.
    if (HandlerDesc.getSuccessParams().size() <= 1) {
      OS << ResultName;
      return;
    }

    // If the async method returns a tuple, its elements have to be passed to
    // the completion handler separately. For example:
    //
    // func foo() async -> (String, Int) {}
    //
    // causes the following legacy body to be created:
    //
    // func foo(completion: (String, Int) -> Void) {
    //   Task {
    //     let result = await foo()
    //     completion(result.0, result.1)
    //   }
    // }
    OS << ResultName << tok::period;

    // Prefer the element's label; fall back to its positional index.
    auto ElementLabel = HandlerDesc.getAsyncReturnTypeLabel(Index);
    if (ElementLabel.empty())
      OS << Index;
    else
      OS << ElementLabel;
  }

  /// Emits a call to the completion handler named \p HandlerName and described
  /// by \p HandlerDesc, passing all the required arguments. For a handler that
  /// takes separate parameters, see \c addCompletionHandlerArgument for how
  /// each argument is synthesized. For a \c Result handler, `.success(...)` is
  /// passed when \p ResultName is non-empty and `.failure(...)` otherwise.
  void addCallToCompletionHandler(StringRef ResultName,
                                  const AsyncHandlerDesc &HandlerDesc,
                                  StringRef HandlerName) {
    OS << HandlerName << tok::l_paren;

    // Build the argument list according to the shape of the handler.
    switch (HandlerDesc.Type) {
    case HandlerType::INVALID:
      llvm_unreachable("Cannot be rewritten");
      break;
    case HandlerType::PARAMS: {
      const size_t NumArgs = HandlerDesc.params().size();
      for (size_t ArgIdx = 0; ArgIdx != NumArgs; ++ArgIdx) {
        if (ArgIdx != 0)
          OS << tok::comma << " ";
        addCompletionHandlerArgument(ArgIdx, ResultName, HandlerDesc);
      }
      break;
    }
    case HandlerType::RESULT: {
      // An empty result name means the error is being reported.
      const bool IsFailure = ResultName.empty();
      OS << tok::period_prefix << (IsFailure ? "failure" : "success")
         << tok::l_paren;
      if (IsFailure) {
        addForwardedErrorArgument("error", HandlerDesc);
      } else if (HandlerDesc.willAsyncReturnVoid()) {
        // Nothing was returned by the async function; pass `()`.
        OS << tok::l_paren << tok::r_paren;
      } else {
        OS << ResultName;
      }
      OS << tok::r_paren;
      break;
    }
    }

    // Close the call to the completion handler.
    OS << tok::r_paren;
  }

  /// Adds the result type of a refactored async function that previously
  /// returned results via a completion handler described by \p HandlerDesc.
  void addAsyncFuncReturnType(const AsyncHandlerDesc &HandlerDesc) {
    // Type or (Type1, Type2, ...)
    SmallVector<LabeledReturnType, 2> Scratch;
    auto ReturnTypes = HandlerDesc.getAsyncReturnTypes(Scratch);
    if (ReturnTypes.empty()) {
      OS << "Void";
    } else {
      addTupleOf(ReturnTypes, OS, [&](LabeledReturnType LabelAndType) {
        if (!LabelAndType.Label.empty()) {
          OS << LabelAndType.Label << tok::colon << " ";
        }
        LabelAndType.Ty->print(OS);
      });
    }
  }

  /// Emits `: ReturnType` for the converted async function, but only when
  /// \p FD is generic. This is used when creating a legacy function that calls
  /// the converted 'async' function, so that the generic parameters of the
  /// legacy function are passed to the generic function. For example, for
  /// \code
  /// func foo<GenericParam>() async -> GenericParam {}
  /// \endcode
  /// we generate
  /// \code
  /// func foo<GenericParam>(completion: (GenericParam) -> Void) {
  ///   Task {
  ///     let result: GenericParam = await foo()
  ///               <------------>
  ///     completion(result)
  ///   }
  /// }
  /// \endcode
  /// and this function writes the range marked by \c <----->
  void addResultTypeAnnotationIfNecessary(const FuncDecl *FD,
                                          const AsyncHandlerDesc &HandlerDesc) {
    // Non-generic functions get no annotation.
    if (!FD->isGeneric())
      return;

    OS << tok::colon << " ";
    addAsyncFuncReturnType(HandlerDesc);
  }
};

} // namespace asyncrefactorings

/// The refactoring applies when the cursor resolves to an outer call whose
/// callee has a valid async completion handler parameter.
bool RefactoringActionConvertCallToAsyncAlternative::isApplicable(
    ResolvedCursorInfoPtr CursorInfo, DiagnosticEngine &Diag) {
  using namespace asyncrefactorings;

  // Currently doesn't check that the call is in an async context. This seems
  // possibly useful in some situations, so we'll see what the feedback is.
  // May need to change in the future
  auto *Call = findOuterCall(CursorInfo);
  if (Call == nullptr)
    return false;

  auto Callee = getUnderlyingFunc(Call->getFn());
  return AsyncHandlerParamDesc::find(Callee, /*RequireAttributeOrName=*/false)
      .isValid();
}

/// Converts a call of a function with a possible async alternative, to use it
/// instead. Currently this is any function that
///   1. has a void return type,
///   2. has a void returning closure as its last parameter, and
///   3. is not already async
///
/// For now the call need not be in an async context, though this may change
/// depending on feedback.
bool RefactoringActionConvertCallToAsyncAlternative::performChange() {
  using namespace asyncrefactorings;

  auto *CE = findOuterCall(CursorInfo);
  assert(CE &&
         "Should not run performChange when refactoring is not applicable");

  // Find the scope this call is in
  ContextFinder Finder(
      *CursorInfo->getSourceFile(), CursorInfo->getLoc(),
      [](ASTNode N) { return N.isStmt(StmtKind::Brace) && !N.isImplicit(); });
  Finder.resolve();
  auto Scopes = Finder.getContexts();
  BraceStmt *Scope = nullptr;
  if (!Scopes.empty())
    Scope = cast<BraceStmt>(Scopes.back().get<Stmt *>());

  AsyncConverter Converter(TheFile, SM, DiagEngine, CE, Scope);
  if (!Converter.convert())
    return true;

  Converter.replace(CE, EditConsumer);
  return false;
}

/// The refactoring applies whenever a function can be found for the cursor.
bool RefactoringActionConvertToAsync::isApplicable(
    ResolvedCursorInfoPtr CursorInfo, DiagnosticEngine &Diag) {
  using namespace asyncrefactorings;

  // As with the call refactoring, should possibly only apply if there's
  // actually calls to async alternatives. At the moment this will just add
  // `async` if there are no calls, which is probably fine.
  auto *Func = findFunction(CursorInfo);
  return Func != nullptr;
}

/// Converts a whole function to async, converting any calls to functions with
/// async alternatives as above.
bool RefactoringActionConvertToAsync::performChange() {
  using namespace asyncrefactorings;

  auto *FD = findFunction(CursorInfo);
  assert(FD &&
         "Should not run performChange when refactoring is not applicable");

  auto HandlerDesc =
      AsyncHandlerParamDesc::find(FD, /*RequireAttributeOrName=*/false);
  AsyncConverter Converter(TheFile, SM, DiagEngine, FD, HandlerDesc);
  if (!Converter.convert())
    return true;

  Converter.replace(FD, EditConsumer, FD->getSourceRangeIncludingAttrs().Start);
  return false;
}

/// The refactoring applies when the cursor resolves to a function that has a
/// valid async completion handler parameter.
bool RefactoringActionAddAsyncAlternative::isApplicable(
    ResolvedCursorInfoPtr CursorInfo, DiagnosticEngine &Diag) {
  using namespace asyncrefactorings;

  auto *Func = findFunction(CursorInfo);
  if (Func == nullptr)
    return false;

  return AsyncHandlerParamDesc::find(Func, /*RequireAttributeOrName=*/false)
      .isValid();
}

/// Adds an async alternative and marks the current function as deprecated.
/// Equivalent to the conversion but
///   1. only works on functions that themselves are a possible async
///      alternative, and
///   2. has extra handling to convert the completion/handler/callback closure
///      parameter to either `return`/`throws`
bool RefactoringActionAddAsyncAlternative::performChange() {
  using namespace asyncrefactorings;

  auto *FD = findFunction(CursorInfo);
  assert(FD &&
         "Should not run performChange when refactoring is not applicable");

  auto HandlerDesc =
      AsyncHandlerParamDesc::find(FD, /*RequireAttributeOrName=*/false);
  assert(HandlerDesc.isValid() &&
         "Should not run performChange when refactoring is not applicable");

  AsyncConverter Converter(TheFile, SM, DiagEngine, FD, HandlerDesc);
  if (!Converter.convert())
    return true;

  // Add a reference to the async function so that warnings appear when the
  // synchronous function is used in an async context
  SmallString<128> AvailabilityAttr = HandlerDesc.buildRenamedAttribute();
  EditConsumer.accept(SM, FD->getAttributeInsertionLoc(false),
                      AvailabilityAttr);

  AsyncConverter LegacyBodyCreator(TheFile, SM, DiagEngine, FD, HandlerDesc);
  if (LegacyBodyCreator.createLegacyBody()) {
    LegacyBodyCreator.replace(FD->getBody(), EditConsumer);
  }

  // Add the async alternative
  Converter.insertAfter(FD, EditConsumer);

  return false;
}

/// The refactoring applies when the cursor resolves to a function that has a
/// valid async completion handler parameter.
bool RefactoringActionAddAsyncWrapper::isApplicable(
    ResolvedCursorInfoPtr CursorInfo, DiagnosticEngine &Diag) {
  using namespace asyncrefactorings;

  auto *Func = findFunction(CursorInfo);
  if (Func == nullptr)
    return false;

  return AsyncHandlerParamDesc::find(Func, /*RequireAttributeOrName=*/false)
      .isValid();
}

bool RefactoringActionAddAsyncWrapper::performChange() {
  using namespace asyncrefactorings;

  auto *FD = findFunction(CursorInfo);
  assert(FD &&
         "Should not run performChange when refactoring is not applicable");

  auto HandlerDesc =
      AsyncHandlerParamDesc::find(FD, /*RequireAttributeOrName=*/false);
  assert(HandlerDesc.isValid() &&
         "Should not run performChange when refactoring is not applicable");

  AsyncConverter Converter(TheFile, SM, DiagEngine, FD, HandlerDesc);
  if (!Converter.createAsyncWrapper())
    return true;

  // Add a reference to the async function so that warnings appear when the
  // synchronous function is used in an async context
  SmallString<128> AvailabilityAttr = HandlerDesc.buildRenamedAttribute();
  EditConsumer.accept(SM, FD->getAttributeInsertionLoc(false),
                      AvailabilityAttr);

  // Add the async wrapper.
  Converter.insertAfter(FD, EditConsumer);

  return false;
}

/// Retrieve the macro expansion buffer for the given macro expansion
/// expression: the buffer containing the start of its rewritten expression,
/// or \c None when the expansion has no rewritten expression.
static Optional<unsigned> getMacroExpansionBuffer(
    SourceManager &sourceMgr, MacroExpansionExpr *expansion) {
  auto rewritten = expansion->getRewritten();
  if (!rewritten)
    return None;

  return sourceMgr.findBufferContainingLoc(rewritten->getStartLoc());
}

/// Retrieve the macro expansion buffer for the given macro expansion
/// declaration by evaluating \c ExpandMacroExpansionDeclRequest, yielding an
/// empty result if the request cannot be evaluated. \p sourceMgr is not used
/// by this overload.
static Optional<unsigned>
getMacroExpansionBuffer(SourceManager &sourceMgr,
                        MacroExpansionDecl *expansion) {
  auto &evaluator = expansion->getASTContext().evaluator;
  return evaluateOrDefault(evaluator,
                           ExpandMacroExpansionDeclRequest{expansion}, {});
}

/// Retrieve the macro expansion buffers for the given attached macro
/// reference: the buffers produced for \p decl by each attached role of
/// \p macro, restricted to those generated for the custom attribute \p attr.
static llvm::SmallVector<unsigned, 2>
getMacroExpansionBuffers(MacroDecl *macro, const CustomAttr *attr, Decl *decl) {
  auto attachedRoles = macro->getMacroRoles() & getAttachedMacroRoles();
  if (!attachedRoles)
    return { };

  ASTContext &ctx = macro->getASTContext();
  llvm::SmallVector<unsigned, 2> collected;

  // Appends every buffer ID of an expansion result to the collected list.
  auto collect = [&](const auto &bufferIDs) {
    collected.append(bufferIDs.begin(), bufferIDs.end());
  };

  // Accessor macros only apply to storage declarations.
  if (attachedRoles.contains(MacroRole::Accessor)) {
    if (auto storage = dyn_cast<AbstractStorageDecl>(decl))
      collect(evaluateOrDefault(
          ctx.evaluator, ExpandAccessorMacros{storage}, { }));
  }

  // Member-attribute macros are expanded once per member of the context.
  if (attachedRoles.contains(MacroRole::MemberAttribute)) {
    if (auto idc = dyn_cast<IterableDeclContext>(decl)) {
      for (auto memberDecl : idc->getAllMembers())
        collect(evaluateOrDefault(
            ctx.evaluator, ExpandMemberAttributeMacros{memberDecl}, { }));
    }
  }

  if (attachedRoles.contains(MacroRole::Member))
    collect(evaluateOrDefault(
        ctx.evaluator, ExpandSynthesizedMemberMacroRequest{decl}, { }));

  if (attachedRoles.contains(MacroRole::Peer))
    collect(evaluateOrDefault(
        ctx.evaluator, ExpandPeerMacroRequest{decl}, { }));

  // Conformance macros only apply to nominal type declarations.
  if (attachedRoles.contains(MacroRole::Conformance)) {
    if (auto nominal = dyn_cast<NominalTypeDecl>(decl))
      collect(evaluateOrDefault(
          ctx.evaluator, ExpandConformanceMacros{nominal}, { }));
  }

  // Drop any buffers that come from other macros. We could eliminate this
  // step by adding more fine-grained requests above, which only expand for a
  // single custom attribute.
  SourceManager &sourceMgr = ctx.SourceMgr;
  auto isFromOtherMacro = [&](unsigned bufferID) {
    auto info = sourceMgr.getGeneratedSourceInfo(bufferID);
    return !info || info->attachedMacroCustomAttr != attr;
  };
  auto firstRemoved =
      std::remove_if(collected.begin(), collected.end(), isFromOtherMacro);
  collected.erase(firstRemoved, collected.end());
  return collected;
}

/// Given a resolved cursor, determine whether it is for a macro expansion and
/// return the list of macro expansion buffer IDs that are associated with the
/// macro reference here.
///
/// The result is empty when the cursor is not a reference to a macro, or when
/// no expansion buffer can be found for it.
static llvm::SmallVector<unsigned, 2>
getMacroExpansionBuffers(SourceManager &sourceMgr, ResolvedCursorInfoPtr Info) {
  auto *refInfo = dyn_cast<ResolvedValueRefCursorInfo>(Info);
  if (!refInfo || !refInfo->isRef())
    return {};

  auto *macro = dyn_cast_or_null<MacroDecl>(refInfo->getValueD());
  if (!macro)
    return {};

  // Attached macros: the reference comes from a custom attribute, so expand
  // for the declaration the attribute is attached to. `macro` has already
  // been checked above, so it is reused instead of casting a second time.
  if (auto customAttrRef = refInfo->getCustomAttrRef()) {
    return getMacroExpansionBuffers(macro, customAttrRef->first,
                                    customAttrRef->second);
  }

  // FIXME: A resolved cursor should contain a slice up to its reference.
  // We shouldn't need to find it again.
  //
  // Look for a macro expansion expression or declaration whose start or macro
  // name is located exactly at the cursor.
  ContextFinder Finder(*Info->getSourceFile(), Info->getLoc(), [&](ASTNode N) {
    if (auto *expr =
            dyn_cast_or_null<MacroExpansionExpr>(N.dyn_cast<Expr *>())) {
      return expr->getStartLoc() == Info->getLoc() ||
             expr->getMacroNameLoc().getBaseNameLoc() == Info->getLoc();
    } else if (auto *decl =
                   dyn_cast_or_null<MacroExpansionDecl>(N.dyn_cast<Decl *>())) {
      return decl->getStartLoc() == Info->getLoc() ||
             decl->getMacroNameLoc().getBaseNameLoc() == Info->getLoc();
    }

    return false;
  });
  Finder.resolve();

  auto contexts = Finder.getContexts();
  if (contexts.empty())
    return {};

  // Only the first matching context is considered.
  Optional<unsigned> bufferID;
  if (auto *target =
          dyn_cast_or_null<MacroExpansionExpr>(contexts[0].dyn_cast<Expr *>())) {
    bufferID = getMacroExpansionBuffer(sourceMgr, target);
  } else if (auto *target = dyn_cast_or_null<MacroExpansionDecl>(
                 contexts[0].dyn_cast<Decl *>())) {
    bufferID = getMacroExpansionBuffer(sourceMgr, target);
  }

  if (bufferID)
    return {*bufferID};

  return {};
}

/// The refactoring applies when at least one macro expansion buffer is
/// associated with the cursor.
bool RefactoringActionExpandMacro::isApplicable(ResolvedCursorInfoPtr Info,
                                                DiagnosticEngine &Diag) {
  auto expansionBuffers = getMacroExpansionBuffers(Diag.SourceMgr, Info);
  return !expansionBuffers.empty();
}

bool RefactoringActionExpandMacro::performChange() {
  auto bufferIDs = getMacroExpansionBuffers(SM, CursorInfo);
  if (bufferIDs.empty())
    return true;

  // Send all of the rewritten buffer snippets.
  CustomAttr *attachedMacroAttr = nullptr;
  for (auto bufferID: bufferIDs) {
    auto generatedInfo = SM.getGeneratedSourceInfo(bufferID);
    if (!generatedInfo || generatedInfo->originalSourceRange.isInvalid())
      continue;

    auto rewrittenBuffer = SM.extractText(generatedInfo->generatedSourceRange);

    // If there's no change, drop the edit entirely.
    if (generatedInfo->originalSourceRange.getStart() ==
          generatedInfo->originalSourceRange.getEnd() &&
        rewrittenBuffer.empty())
      continue;

    // `TheFile` is the file of the actual expansion site, where as
    // `OriginalFile` is the possibly enclosing buffer. Concretely:
    // ```
    // // m.swift
    // @AddMemberAttributes
    // struct Foo {
    //   // --- expanded from @AddMemberAttributes eg. @_someBufferName ---
    //   @AddedAttribute
    //   // ---
    //   let someMember: Int
    // }
    // ```
    //
    // When expanding `AddedAttribute`, the expansion actually applies to the
    // original source (`m.swift`) rather than the buffer of the expansion
    // site (`@_someBufferName`). Thus, we need to include the path to the
    // original source as well. Note that this path could itself be another
    // expansion.
    auto originalSourceRange = generatedInfo->originalSourceRange;
    SourceFile *originalFile =
        MD->getSourceFileContainingLocation(originalSourceRange.getStart());
    StringRef originalPath;
    if (originalFile->getBufferID().hasValue() &&
        TheFile->getBufferID() != originalFile->getBufferID()) {
      originalPath = SM.getIdentifierForBuffer(*originalFile->getBufferID());
    }

    EditConsumer.accept(SM, {originalPath,
                             originalSourceRange,
                             SM.getIdentifierForBuffer(bufferID),
                             rewrittenBuffer,
                             {}});

    if (generatedInfo->attachedMacroCustomAttr && !attachedMacroAttr)
      attachedMacroAttr = generatedInfo->attachedMacroCustomAttr;
  }

  // For an attached macro, remove the custom attribute; it's been fully
  // subsumed by its expansions.
  if (attachedMacroAttr) {
    SourceRange range = attachedMacroAttr->getRangeWithAt();
    auto charRange = Lexer::getCharSourceRangeFromSourceRange(SM, range);
    EditConsumer.accept(SM, charRange, StringRef());
  }

  return false;
}

} // end of anonymous namespace

/// Returns the descriptive name registered in RefactoringKinds.def for the
/// refactoring kind \p Kind. \c RefactoringKind::None is not a valid input.
StringRef swift::ide::
getDescriptiveRefactoringKindName(RefactoringKind Kind) {
  switch (Kind) {
  case RefactoringKind::None:
    llvm_unreachable("Should be a valid refactoring kind");
#define REFACTORING(KIND, NAME, ID)                                            \
  case RefactoringKind::KIND:                                                  \
    return NAME;
#include "swift/Refactoring/RefactoringKinds.def"
  }
  llvm_unreachable("unhandled kind");
}

  /// Returns a human-readable explanation of why a rename is unavailable for
  /// the availability kind \p Kind, or an empty string when the rename is
  /// available.
  StringRef swift::ide::
  getDescriptiveRenameUnavailableReason(RenameAvailableKind Kind) {
    switch (Kind) {
    case RenameAvailableKind::Available:
      return "";
    case RenameAvailableKind::Unavailable_system_symbol:
      return "symbol from system module cannot be renamed";
    case RenameAvailableKind::Unavailable_has_no_location:
      return "symbol without a declaration location cannot be renamed";
    case RenameAvailableKind::Unavailable_has_no_name:
      return "cannot find the name of the symbol";
    case RenameAvailableKind::Unavailable_has_no_accessibility:
      return "cannot decide the accessibility of the symbol";
    case RenameAvailableKind::Unavailable_decl_from_clang:
      return "cannot rename a Clang symbol from its Swift reference";
    }
    llvm_unreachable("unhandled kind");
  }

/// Returns the location of this range's line and column within its buffer.
SourceLoc swift::ide::RangeConfig::getStart(SourceManager &SM) {
  SourceLoc Start = SM.getLocForLineCol(BufferID, Line, Column);
  return Start;
}

/// Returns the start location of this range advanced by its length.
SourceLoc swift::ide::RangeConfig::getEnd(SourceManager &SM) {
  SourceLoc Start = getStart(SM);
  return Start.getAdvancedLoc(Length);
}

/// Private state of FindRenameRangesAnnotatingConsumer. Each rename range is
/// rewritten as its own text wrapped in a tag named after the range's kind,
/// and the rewrite is forwarded to a SourceEditOutputConsumer.
struct swift::ide::FindRenameRangesAnnotatingConsumer::Implementation {
  /// Receives the annotated replacement for every accepted range.
  std::unique_ptr<SourceEditConsumer> pRewriter;

  Implementation(SourceManager &SM, unsigned BufferId, raw_ostream &OS)
      : pRewriter(
            std::make_unique<SourceEditOutputConsumer>(SM, BufferId, OS)) {}

  /// Maps a range kind to the tag name used to annotate it.
  static StringRef tag(RefactoringRangeKind Kind) {
    switch (Kind) {
    case RefactoringRangeKind::BaseName:
      return "base";
    case RefactoringRangeKind::KeywordBaseName:
      return "keywordBase";
    case RefactoringRangeKind::ParameterName:
      return "param";
    case RefactoringRangeKind::NoncollapsibleParameterName:
      return "noncollapsibleparam";
    case RefactoringRangeKind::DeclArgumentLabel:
      return "arglabel";
    case RefactoringRangeKind::CallArgumentLabel:
      return "callarg";
    case RefactoringRangeKind::CallArgumentColon:
      return "callcolon";
    case RefactoringRangeKind::CallArgumentCombined:
      return "callcombo";
    case RefactoringRangeKind::SelectorArgumentLabel:
      return "sel";
    }
    llvm_unreachable("unhandled kind");
  }

  /// Replaces \p Range with `<tag>text</tag>`, where the opening tag carries
  /// an ` index=N` attribute when the range has an index.
  void accept(SourceManager &SM, const RenameRangeDetail &Range) {
    std::string Annotated;
    llvm::raw_string_ostream Stream(Annotated);
    const StringRef TagName = tag(Range.RangeKind);

    // Opening tag, with the optional index attribute.
    Stream << "<" << TagName;
    if (Range.Index.has_value())
      Stream << " index=" << *Range.Index;
    Stream << ">";

    // Original text followed by the closing tag.
    Stream << Range.Range.str() << "</" << TagName << ">";

    pRewriter->accept(SM, {/*Path=*/{}, Range.Range, /*BufferName=*/{},
                           Stream.str(), /*RegionsWorthNote=*/{}});
  }
};

/// Heap-allocates the implementation object; the destructor releases it.
swift::ide::FindRenameRangesAnnotatingConsumer::
    FindRenameRangesAnnotatingConsumer(SourceManager &SM, unsigned BufferId,
                                       raw_ostream &OS)
    : Impl(*new Implementation(SM, BufferId, OS)) {}

/// Destroys the implementation object that the constructor allocated.
swift::ide::FindRenameRangesAnnotatingConsumer::
    ~FindRenameRangesAnnotatingConsumer() {
  Implementation *Owned = &Impl;
  delete Owned;
}

/// Annotates every range in \p Ranges, unless the region is a mismatch or is
/// unmatched, in which case nothing is emitted.
void swift::ide::FindRenameRangesAnnotatingConsumer::
accept(SourceManager &SM, RegionType RegionType,
       ArrayRef<RenameRangeDetail> Ranges) {
  const bool ShouldAnnotate = RegionType != RegionType::Mismatch &&
                              RegionType != RegionType::Unmatched;
  if (!ShouldAnnotate)
    return;

  for (const auto &Detail : Ranges)
    Impl.accept(SM, Detail);
}

/// Determines whether and how \p VD can be renamed.
///
/// Returns \c None for declarations that must never be renamed (accessors,
/// deinitializers, and initializers or 'callAsFunction' methods without
/// renamable argument labels). Otherwise returns the kind of rename to apply
/// (local or global) together with the availability of that rename.
Optional<RenameAvailabilityInfo>
swift::ide::renameAvailabilityInfo(const ValueDecl *VD,
                                   Optional<RenameRefInfo> RefInfo) {
  // Pick the first reason that makes the rename unavailable, if any.
  RenameAvailableKind Availability = RenameAvailableKind::Available;
  if (getRelatedSystemDecl(VD))
    Availability = RenameAvailableKind::Unavailable_system_symbol;
  else if (VD->getClangDecl())
    Availability = RenameAvailableKind::Unavailable_decl_from_clang;
  else if (VD->getLoc().isInvalid())
    Availability = RenameAvailableKind::Unavailable_has_no_location;
  else if (!VD->hasName())
    Availability = RenameAvailableKind::Unavailable_has_no_name;

  // True when the parameter list is empty, or when the reference is not on an
  // argument label and resolving it yields no label ranges.
  auto HasNothingToRename = [&](const auto *Params) -> bool {
    if (!Params->size())
      return true;

    if (RefInfo && !RefInfo->IsArgLabel) {
      NameMatcher Matcher(*(RefInfo->SF));
      auto Resolved = Matcher.resolve({RefInfo->Loc, /*ResolveArgs*/true});
      if (Resolved.LabelRanges.empty())
        return true;
    }
    return false;
  };

  if (isa<AbstractFunctionDecl>(VD)) {
    // Disallow renaming accessors and deinit.
    if (isa<AccessorDecl>(VD) || isa<DestructorDecl>(VD))
      return None;

    // Disallow renaming init with no arguments.
    if (auto Ctor = dyn_cast<ConstructorDecl>(VD)) {
      if (HasNothingToRename(Ctor->getParameters()))
        return None;
    }

    // Disallow renaming 'callAsFunction' method with no arguments.
    if (auto Func = dyn_cast<FuncDecl>(VD)) {
      // FIXME: syntactic rename can only decide by checking the spelling, not
      // whether it's an instance method, so we do the same here for now.
      const bool IsCallAsFunction =
          Func->getBaseIdentifier() == Func->getASTContext().Id_callAsFunction;
      if (IsCallAsFunction && HasNothingToRename(Func->getParameters()))
        return None;
    }
  }

  // Always return local rename for parameters.
  // FIXME: if the cursor is on the argument, we should return global rename.
  //
  // Any other declaration gets a local rename when the indexer considers it a
  // local symbol, and a global rename otherwise.
  if (isa<ParamDecl>(VD) || index::isLocalSymbol(VD))
    return RenameAvailabilityInfo{RefactoringKind::LocalRename, Availability};

  return RenameAvailabilityInfo{RefactoringKind::GlobalRename, Availability};
}

/// Appends to \p Kinds every refactoring that is applicable at the resolved
/// cursor \p CursorInfo. An available rename is reported first unless
/// \p ExcludeRename is set; the cursor refactorings listed in
/// RefactoringKinds.def (other than local rename) follow.
void swift::ide::collectAvailableRefactorings(
    ResolvedCursorInfoPtr CursorInfo, SmallVectorImpl<RefactoringKind> &Kinds,
    bool ExcludeRename) {
  SourceManager &SM = CursorInfo->getSourceFile()->getASTContext().SourceMgr;
  DiagnosticEngine DiagEngine(SM);

  if (!ExcludeRename) {
    auto Rename = getRenameInfo(CursorInfo);
    if (Rename &&
        Rename->Availability.AvailableKind == RenameAvailableKind::Available)
      Kinds.push_back(Rename->Availability.Kind);
  }

#define CURSOR_REFACTORING(KIND, NAME, ID)                                     \
  if (RefactoringKind::KIND != RefactoringKind::LocalRename &&                 \
      RefactoringAction##KIND::isApplicable(CursorInfo, DiagEngine))           \
    Kinds.push_back(RefactoringKind::KIND);
#include "swift/Refactoring/RefactoringKinds.def"
}

/// Appends to \p Kinds every refactoring applicable to the range \p Range of
/// \p SF. A zero-length range is handled as a cursor position, in which case
/// \p CollectRangeStartRefactorings is left untouched. Otherwise the range
/// refactorings listed in RefactoringKinds.def are checked, and
/// \p CollectRangeStartRefactorings is set from the resolved range info.
/// Internal range refactorings are only considered when the environment
/// variable SWIFT_ENABLE_INTERNAL_REFACTORING_ACTIONS is set.
void swift::ide::collectAvailableRefactorings(
    SourceFile *SF, RangeConfig Range, bool &CollectRangeStartRefactorings,
    SmallVectorImpl<RefactoringKind> &Kinds,
    ArrayRef<DiagnosticConsumer *> DiagConsumers) {
  if (Range.Length == 0) {
    collectAvailableRefactoringsAtCursor(SF, Range.Line, Range.Column, Kinds,
                                         DiagConsumers);
    return;
  }

  // Prepare the tool box.
  ASTContext &Ctx = SF->getASTContext();
  SourceManager &SM = Ctx.SourceMgr;
  DiagnosticEngine DiagEngine(SM);
  for (DiagnosticConsumer *Consumer : DiagConsumers)
    DiagEngine.addConsumer(*Consumer);

  // Resolve the selected range.
  ResolvedRangeInfo Result = evaluateOrDefault(
      Ctx.evaluator,
      RangeInfoRequest(
          RangeInfoOwner({SF, Range.getStart(SM), Range.getEnd(SM)})),
      ResolvedRangeInfo());

  bool enableInternalRefactoring =
      getenv("SWIFT_ENABLE_INTERNAL_REFACTORING_ACTIONS") != nullptr;

#define RANGE_REFACTORING(KIND, NAME, ID)                                      \
  if (RefactoringAction##KIND::isApplicable(Result, DiagEngine))               \
    Kinds.push_back(RefactoringKind::KIND);
#define INTERNAL_RANGE_REFACTORING(KIND, NAME, ID)                            \
  if (enableInternalRefactoring)                                              \
    RANGE_REFACTORING(KIND, NAME, ID)
#include "swift/Refactoring/RefactoringKinds.def"

  CollectRangeStartRefactorings = collectRangeStartRefactorings(Result);
}

/// Runs the semantic refactoring selected by \p Opts.Kind on module \p M.
///
/// \p Opts is taken by value, so filling in the default preferred name below
/// does not affect the caller's options.
///
/// \returns the result of the action's \c performChange() when the action is
/// applicable (local rename is always attempted), and \c true when the action
/// is not applicable.
bool swift::ide::
refactorSwiftModule(ModuleDecl *M, RefactoringOptions Opts,
                    SourceEditConsumer &EditConsumer,
                    DiagnosticConsumer &DiagConsumer) {
  assert(Opts.Kind != RefactoringKind::None && "should have a refactoring kind.");

  // Use the default name if not specified.
  if (Opts.PreferredName.empty()) {
    Opts.PreferredName = getDefaultPreferredName(Opts.Kind).str();
  }

  // One `case` is generated per semantic refactoring in RefactoringKinds.def;
  // each constructs the matching RefactoringAction and runs it.
  switch (Opts.Kind) {
#define SEMANTIC_REFACTORING(KIND, NAME, ID)                                   \
case RefactoringKind::KIND: {                                                  \
      RefactoringAction##KIND Action(M, Opts, EditConsumer, DiagConsumer);     \
      if (RefactoringKind::KIND == RefactoringKind::LocalRename ||             \
          Action.isApplicable())                                               \
        return Action.performChange();                                         \
      return true;                                                             \
  }
#include "swift/Refactoring/RefactoringKinds.def"
    // The remaining kinds are not handled as semantic refactorings here.
    case RefactoringKind::GlobalRename:
    case RefactoringKind::FindGlobalRenameRanges:
    case RefactoringKind::FindLocalRenameRanges:
      llvm_unreachable("not a valid refactoring kind");
    case RefactoringKind::None:
      llvm_unreachable("should not enter here.");
  }
  llvm_unreachable("unhandled kind");
}

/// Validates every entry of \p RenameLocs and resolves its line/column
/// position in \p SF.
///
/// An invalid old name, an invalid new name, a differing number of name parts
/// between old and new name, or a call usage of a non-function-like name is
/// diagnosed through \p Diags and makes the function return an empty vector.
/// New-name checks are skipped for entries whose new name is empty.
static std::vector<ResolvedLoc>
resolveRenameLocations(ArrayRef<RenameLoc> RenameLocs, SourceFile &SF,
                       DiagnosticEngine &Diags) {
  SourceManager &SM = SF.getASTContext().SourceMgr;
  const unsigned BufferID = SF.getBufferID().value();

  // An argument label may be empty, otherwise it has to be an identifier.
  auto IsValidLabel = [](StringRef Label) {
    return Label.empty() || Lexer::isIdentifier(Label);
  };

  std::vector<UnresolvedLoc> Pending;
  for (const RenameLoc &Rename : RenameLocs) {
    DeclNameViewer OldName(Rename.OldName);
    SourceLoc Location =
        SM.getLocForLineCol(BufferID, Rename.Line, Rename.Column);

    if (!OldName.isValid()) {
      Diags.diagnose(Location, diag::invalid_name, Rename.OldName);
      return {};
    }

    if (!Rename.NewName.empty()) {
      DeclNameViewer NewName(Rename.NewName);
      ArrayRef<StringRef> NewLabels = NewName.args();
      const bool NewIsOperator = Lexer::isOperator(NewName.base());

      // The base has to be an identifier or an operator, and every label has
      // to be acceptable.
      const bool NewNameIsValid =
          NewName.isValid() &&
          (Lexer::isIdentifier(NewName.base()) || NewIsOperator) &&
          std::all_of(NewLabels.begin(), NewLabels.end(), IsValidLabel);
      if (!NewNameIsValid) {
        Diags.diagnose(Location, diag::invalid_name, Rename.NewName);
        return {};
      }

      if (NewName.partsCount() != OldName.partsCount()) {
        Diags.diagnose(Location, diag::arity_mismatch, Rename.NewName,
                       Rename.OldName);
        return {};
      }

      if (Rename.Usage == NameUsage::Call && !Rename.IsFunctionLike) {
        Diags.diagnose(Location, diag::name_not_functionlike, Rename.NewName);
        return {};
      }
    }

    // The flag is set for unknown usages and for calls of non-operators.
    const bool OldIsOperator = Lexer::isOperator(OldName.base());
    const bool Flag = Rename.Usage == NameUsage::Unknown ||
                      (Rename.Usage == NameUsage::Call && !OldIsOperator);
    Pending.push_back({Location, Flag});
  }

  NameMatcher Resolver(SF);
  return Resolver.resolve(Pending, SF.getAllTokens());
}

/// Computes the text replacements for each entry of \p RenameLocs in \p SF
/// and reports them, one call per entry, to \p EditConsumer.
///
/// Diagnostics are routed to \p DiagConsumer. An entry whose ranges come back
/// as \c RegionType::Mismatch is diagnosed and reported with no replacements;
/// this does not change the return value.
///
/// \returns true if the locations could not all be resolved (already
/// diagnosed, nothing is reported to \p EditConsumer), false otherwise.
int swift::ide::syntacticRename(SourceFile *SF, ArrayRef<RenameLoc> RenameLocs,
                                SourceEditConsumer &EditConsumer,
                                DiagnosticConsumer &DiagConsumer) {
  assert(SF && "null source file");

  SourceManager &SM = SF->getASTContext().SourceMgr;
  DiagnosticEngine Diags(SM);
  Diags.addConsumer(DiagConsumer);

  auto AllResolved = resolveRenameLocations(RenameLocs, *SF, Diags);
  if (AllResolved.size() != RenameLocs.size())
    return true; // Already diagnosed.

  // One set is shared by every renamer created in the loop below.
  llvm::StringSet<> ReplaceTextContext;
  for (size_t I = 0, E = RenameLocs.size(); I != E; ++I) {
    const RenameLoc &Rename = RenameLocs[I];
    ResolvedLoc &Resolved = AllResolved[I];

    TextReplacementsRenamer Renamer(SM, Rename.OldName, Rename.NewName,
                                    ReplaceTextContext);
    RegionType Type = Renamer.addSyntacticRenameRanges(Resolved, Rename);
    if (Type != RegionType::Mismatch) {
      EditConsumer.accept(SM, Type, Renamer.getReplacements());
      continue;
    }

    // Mismatch: diagnose at the resolved range and report no replacements.
    Diags.diagnose(Resolved.Range.getStart(), diag::mismatched_rename,
                   Rename.NewName);
    EditConsumer.accept(SM, Type, None);
  }

  return false;
}

/// Collects the rename ranges for each entry of \p RenameLocs in \p SF and
/// reports them, one call per entry, to \p RenameConsumer.
///
/// Diagnostics are routed to \p DiagConsumer. An entry whose ranges come back
/// as \c RegionType::Mismatch is diagnosed and reported with no ranges; this
/// does not change the return value.
///
/// \returns true if the locations could not all be resolved (already
/// diagnosed, nothing is reported to \p RenameConsumer), false otherwise.
int swift::ide::findSyntacticRenameRanges(
    SourceFile *SF, ArrayRef<RenameLoc> RenameLocs,
    FindRenameRangesConsumer &RenameConsumer,
    DiagnosticConsumer &DiagConsumer) {
  assert(SF && "null source file");

  SourceManager &SM = SF->getASTContext().SourceMgr;
  DiagnosticEngine Diags(SM);
  Diags.addConsumer(DiagConsumer);

  auto AllResolved = resolveRenameLocations(RenameLocs, *SF, Diags);
  if (AllResolved.size() != RenameLocs.size())
    return true; // Already diagnosed.

  for (size_t I = 0, E = RenameLocs.size(); I != E; ++I) {
    const RenameLoc &Rename = RenameLocs[I];
    ResolvedLoc &Resolved = AllResolved[I];

    RenameRangeDetailCollector Collector(SM, Rename.OldName);
    RegionType Type = Collector.addSyntacticRenameRanges(Resolved, Rename);
    if (Type != RegionType::Mismatch) {
      RenameConsumer.accept(SM, Type, Collector.Ranges);
      continue;
    }

    // Mismatch: diagnose at the resolved range and report no ranges.
    Diags.diagnose(Resolved.Range.getStart(), diag::mismatched_rename,
                   Rename.NewName);
    RenameConsumer.accept(SM, Type, None);
  }

  return false;
}

/// Finds the rename ranges for the token at the start of \p Range in \p SF.
///
/// The start of \p Range is first snapped to the beginning of its token, then
/// handed to \c localRenames (with an empty name) to gather the rename
/// locations; those are forwarded to \c findSyntacticRenameRanges, which
/// reports to \p RenameConsumer. Diagnostics go to \p DiagConsumer.
///
/// \returns true if \c localRenames produced no collector; otherwise whatever
/// \c findSyntacticRenameRanges returns.
int swift::ide::findLocalRenameRanges(
    SourceFile *SF, RangeConfig Range,
    FindRenameRangesConsumer &RenameConsumer,
    DiagnosticConsumer &DiagConsumer) {
  assert(SF && "null source file");

  SourceManager &SM = SF->getASTContext().SourceMgr;
  DiagnosticEngine DiagEngine(SM);
  DiagEngine.addConsumer(DiagConsumer);

  auto TokenStart = Lexer::getLocForStartOfToken(SM, Range.getStart(SM));
  Optional<RenameRangeCollector> Collector =
      localRenames(SF, TokenStart, StringRef(), DiagEngine);
  if (Collector)
    return findSyntacticRenameRanges(SF, Collector->results(),
                                     RenameConsumer, DiagConsumer);

  return true;
}
back to top