[Syntax] Add a helper to find expansion by its first spelled token

Summary: Used in clangd for a code tweak that expands a macro.

Reviewers: sammccall

Reviewed By: sammccall

Subscribers: kadircet, cfe-commits

Tags: #clang

Differential Revision: https://reviews.llvm.org/D62954

git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@363698 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/include/clang/Tooling/Syntax/Tokens.h b/include/clang/Tooling/Syntax/Tokens.h
index 0a0d47b..7c933bd 100644
--- a/include/clang/Tooling/Syntax/Tokens.h
+++ b/include/clang/Tooling/Syntax/Tokens.h
@@ -200,6 +200,25 @@
   llvm::Optional<llvm::ArrayRef<syntax::Token>>
   spelledForExpanded(llvm::ArrayRef<syntax::Token> Expanded) const;
 
+  /// An expansion produced by the preprocessor, includes macro expansions and
+  /// preprocessor directives. Preprocessor always maps a non-empty range of
+  /// spelled tokens to a (possibly empty) range of expanded tokens. Here are a
+  /// few examples of expansions:
+  ///    #pragma once      // Expands to an empty range.
+  ///    #define FOO 1 2 3 // Expands to an empty range.
+  ///    FOO               // Expands to "1 2 3".
+  /// FIXME(ibiryukov): implement this, currently #include expansions are empty.
+  ///    #include <vector> // Expands to tokens produced by the include.
+  struct Expansion {
+    llvm::ArrayRef<syntax::Token> Spelled;
+    llvm::ArrayRef<syntax::Token> Expanded;
+  };
+  /// If \p Spelled starts a mapping (e.g. if it's a macro name or '#' starting
+  /// a preprocessor directive) return the spelled and expanded token ranges of
+  /// that mapping. Returns llvm::None if no mapping starts at \p Spelled.
+  llvm::Optional<Expansion>
+  expansionStartingAt(const syntax::Token *Spelled) const;
+  expansionStartingAt(const syntax::Token *Spelled) const;
+
   /// Lexed tokens of a file before preprocessing. E.g. for the following input
   ///     #define DECL(name) int name = 10
   ///     DECL(a);
diff --git a/lib/Tooling/Syntax/Tokens.cpp b/lib/Tooling/Syntax/Tokens.cpp
index f291f18..e226237 100644
--- a/lib/Tooling/Syntax/Tokens.cpp
+++ b/lib/Tooling/Syntax/Tokens.cpp
@@ -199,6 +199,32 @@
                   : LastSpelled + 1);
 }
 
+llvm::Optional<TokenBuffer::Expansion>
+TokenBuffer::expansionStartingAt(const syntax::Token *Spelled) const {
+  assert(Spelled);
+  assert(Spelled->location().isFileID() && "not a spelled token");
+  auto FileIt = Files.find(SourceMgr->getFileID(Spelled->location()));
+  assert(FileIt != Files.end() && "file not tracked by token buffer");
+  // Spelled must point into the spelled token array of its own file.
+  auto &File = FileIt->second;
+  assert(File.SpelledTokens.data() <= Spelled &&
+         Spelled < (File.SpelledTokens.data() + File.SpelledTokens.size()));
+  // Binary search; assumes File.Mappings is sorted by BeginSpelled.
+  unsigned SpelledIndex = Spelled - File.SpelledTokens.data();
+  auto M = llvm::bsearch(File.Mappings, [&](const Mapping &M) {
+    return SpelledIndex <= M.BeginSpelled;
+  });
+  if (M == File.Mappings.end() || M->BeginSpelled != SpelledIndex)
+    return llvm::None; // No mapping begins exactly at this token.
+  // Convert the mapping's Begin/End indices into token ranges.
+  Expansion E;
+  E.Spelled = llvm::makeArrayRef(File.SpelledTokens.data() + M->BeginSpelled,
+                                 File.SpelledTokens.data() + M->EndSpelled);
+  E.Expanded = llvm::makeArrayRef(ExpandedTokens.data() + M->BeginExpanded,
+                                  ExpandedTokens.data() + M->EndExpanded);
+  return E;
+}
+
 std::vector<syntax::Token> syntax::tokenize(FileID FID, const SourceManager &SM,
                                             const LangOptions &LO) {
   std::vector<syntax::Token> Tokens;
diff --git a/unittests/Tooling/Syntax/TokensTest.cpp b/unittests/Tooling/Syntax/TokensTest.cpp
index 1d931fa..34c80fc 100644
--- a/unittests/Tooling/Syntax/TokensTest.cpp
+++ b/unittests/Tooling/Syntax/TokensTest.cpp
@@ -55,6 +55,7 @@
 using ::testing::AllOf;
 using ::testing::Contains;
 using ::testing::ElementsAre;
+using ::testing::Field;
 using ::testing::Matcher;
 using ::testing::Not;
 using ::testing::StartsWith;
@@ -65,6 +66,13 @@
 MATCHER_P(SameRange, A, "") {
   return A.begin() == arg.begin() && A.end() == arg.end();
 }
+// Matches an Expansion whose Spelled and Expanded fields both match.
+Matcher<TokenBuffer::Expansion>
+IsExpansion(Matcher<llvm::ArrayRef<syntax::Token>> Spelled,
+            Matcher<llvm::ArrayRef<syntax::Token>> Expanded) {
+  return AllOf(Field(&TokenBuffer::Expansion::Spelled, Spelled),
+               Field(&TokenBuffer::Expansion::Expanded, Expanded));
+}
 // Matchers for syntax::Token.
 MATCHER_P(Kind, K, "") { return arg.kind() == K; }
 MATCHER_P2(HasText, Text, SourceMgr, "") {
@@ -629,6 +637,76 @@
               ValueIs(SameRange(findSpelled("not_mapped"))));
 }
 
+TEST_F(TokenBufferTest, ExpansionStartingAt) {
+  // Object-like macro: each use of FOO is its own expansion to "3 + 4".
+  recordTokens(R"cpp(
+    #define FOO 3+4
+    int a = FOO 1;
+    int b = FOO 2;
+  )cpp");
+
+  llvm::ArrayRef<syntax::Token> Foo1 = findSpelled("FOO 1").drop_back();
+  EXPECT_THAT(
+      Buffer.expansionStartingAt(Foo1.data()),
+      ValueIs(IsExpansion(SameRange(Foo1),
+                          SameRange(findExpanded("3 + 4 1").drop_back()))));
+
+  llvm::ArrayRef<syntax::Token> Foo2 = findSpelled("FOO 2").drop_back();
+  EXPECT_THAT(
+      Buffer.expansionStartingAt(Foo2.data()),
+      ValueIs(IsExpansion(SameRange(Foo2),
+                          SameRange(findExpanded("3 + 4 2").drop_back()))));
+
+  // Function-like macro: the spelled range covers the name and its arguments.
+  recordTokens(R"cpp(
+    #define ID(X) X
+    int a = ID(1+2+3);
+    int b = ID(ID(2+3+4));
+  )cpp");
+
+  llvm::ArrayRef<syntax::Token> ID1 = findSpelled("ID ( 1 + 2 + 3 )");
+  EXPECT_THAT(Buffer.expansionStartingAt(&ID1.front()),
+              ValueIs(IsExpansion(SameRange(ID1),
+                                  SameRange(findExpanded("1 + 2 + 3")))));
+  // Spelled tokens after the first one must not start an expansion.
+  for (const auto &T : ID1.drop_front())
+    EXPECT_EQ(Buffer.expansionStartingAt(&T), llvm::None);
+
+  llvm::ArrayRef<syntax::Token> ID2 = findSpelled("ID ( ID ( 2 + 3 + 4 ) )");
+  EXPECT_THAT(Buffer.expansionStartingAt(&ID2.front()),
+              ValueIs(IsExpansion(SameRange(ID2),
+                                  SameRange(findExpanded("2 + 3 + 4")))));
+  // Spelled tokens after the first one (incl. the nested ID) must not match.
+  for (const auto &T : ID2.drop_front())
+    EXPECT_EQ(Buffer.expansionStartingAt(&T), llvm::None);
+
+  // PP directives: expected to map to an empty range of expanded tokens.
+  recordTokens(R"cpp(
+#define FOO 1
+int a = FOO;
+#pragma once
+int b = 1;
+  )cpp");
+
+  llvm::ArrayRef<syntax::Token> DefineFoo = findSpelled("# define FOO 1");
+  EXPECT_THAT(
+      Buffer.expansionStartingAt(&DefineFoo.front()),
+      ValueIs(IsExpansion(SameRange(DefineFoo),
+                          SameRange(findExpanded("int a").take_front(0)))));
+  // Spelled tokens after the leading '#' must not start an expansion.
+  for (const auto &T : DefineFoo.drop_front())
+    EXPECT_EQ(Buffer.expansionStartingAt(&T), llvm::None);
+
+  llvm::ArrayRef<syntax::Token> PragmaOnce = findSpelled("# pragma once");
+  EXPECT_THAT(
+      Buffer.expansionStartingAt(&PragmaOnce.front()),
+      ValueIs(IsExpansion(SameRange(PragmaOnce),
+                          SameRange(findExpanded("int b").take_front(0)))));
+  // Spelled tokens after the leading '#' must not start an expansion.
+  for (const auto &T : PragmaOnce.drop_front())
+    EXPECT_EQ(Buffer.expansionStartingAt(&T), llvm::None);
+}
+
 TEST_F(TokenBufferTest, TokensToFileRange) {
   addFile("./foo.h", "token_from_header");
   llvm::Annotations Code(R"cpp(