[OpenMP] Make OpenMPOpt aware of the OpenMP runtime's status

The `OpenMPOpt` pass contains optimizations that generate new calls into
the OpenMP runtime. This causes problems if we are in a state where the
runtime has already been linked statically. Generating these new calls
will result in them never being resolved. We should indicate if we are
in a "post-link" LTO phase and prevent OpenMPOpt from generating new
runtime calls.

Generally, it's not desireable for passes to maintain state about the
context in which they're called. But this is the only reasonable
solution to static linking when we have a pass that generates new
runtime calls.

Reviewed By: jdoerfert

Differential Revision: https://reviews.llvm.org/D142646

(cherry picked from commit 0bdde9dfb9b1dbfabee147c196db820e1f5dca1f)
diff --git a/llvm/include/llvm/Transforms/IPO/OpenMPOpt.h b/llvm/include/llvm/Transforms/IPO/OpenMPOpt.h
index bf08336..73aee47 100644
--- a/llvm/include/llvm/Transforms/IPO/OpenMPOpt.h
+++ b/llvm/include/llvm/Transforms/IPO/OpenMPOpt.h
@@ -37,13 +37,25 @@
 /// OpenMP optimizations pass.
 class OpenMPOptPass : public PassInfoMixin<OpenMPOptPass> {
 public:
+  OpenMPOptPass() : LTOPhase(ThinOrFullLTOPhase::None) {}
+  OpenMPOptPass(ThinOrFullLTOPhase LTOPhase) : LTOPhase(LTOPhase) {}
+
   PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM);
+
+private:
+  const ThinOrFullLTOPhase LTOPhase = ThinOrFullLTOPhase::None;
 };
 
 class OpenMPOptCGSCCPass : public PassInfoMixin<OpenMPOptCGSCCPass> {
 public:
+  OpenMPOptCGSCCPass() : LTOPhase(ThinOrFullLTOPhase::None) {}
+  OpenMPOptCGSCCPass(ThinOrFullLTOPhase LTOPhase) : LTOPhase(LTOPhase) {}
+
   PreservedAnalyses run(LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM,
                         LazyCallGraph &CG, CGSCCUpdateResult &UR);
+
+private:
+  const ThinOrFullLTOPhase LTOPhase = ThinOrFullLTOPhase::None;
 };
 
 } // end namespace llvm
diff --git a/llvm/lib/Passes/PassBuilderPipelines.cpp b/llvm/lib/Passes/PassBuilderPipelines.cpp
index 75b34b8..eed29c2 100644
--- a/llvm/lib/Passes/PassBuilderPipelines.cpp
+++ b/llvm/lib/Passes/PassBuilderPipelines.cpp
@@ -1604,7 +1604,7 @@
   }
 
   // Try to run OpenMP optimizations, quick no-op if no OpenMP metadata present.
-  MPM.addPass(OpenMPOptPass());
+  MPM.addPass(OpenMPOptPass(ThinOrFullLTOPhase::FullLTOPostLink));
 
   // Remove unused virtual tables to improve the quality of code generated by
   // whole-program devirtualization and bitset lowering.
@@ -1811,7 +1811,8 @@
   addVectorPasses(Level, MainFPM, /* IsFullLTO */ true);
 
   // Run the OpenMPOpt CGSCC pass again late.
-  MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(OpenMPOptCGSCCPass()));
+  MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(
+      OpenMPOptCGSCCPass(ThinOrFullLTOPhase::FullLTOPostLink)));
 
   invokePeepholeEPCallbacks(MainFPM, Level);
   MainFPM.addPass(JumpThreadingPass());
diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def
index ad44d86..10af416 100644
--- a/llvm/lib/Passes/PassRegistry.def
+++ b/llvm/lib/Passes/PassRegistry.def
@@ -44,6 +44,7 @@
 MODULE_PASS("attributor", AttributorPass())
 MODULE_PASS("annotation2metadata", Annotation2MetadataPass())
 MODULE_PASS("openmp-opt", OpenMPOptPass())
+MODULE_PASS("openmp-opt-postlink", OpenMPOptPass(ThinOrFullLTOPhase::FullLTOPostLink))
 MODULE_PASS("called-value-propagation", CalledValuePropagationPass())
 MODULE_PASS("canonicalize-aliases", CanonicalizeAliasesPass())
 MODULE_PASS("cg-profile", CGProfilePass())
diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
index bee154d..ee90bf8 100644
--- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -188,9 +188,9 @@
 struct OMPInformationCache : public InformationCache {
   OMPInformationCache(Module &M, AnalysisGetter &AG,
                       BumpPtrAllocator &Allocator, SetVector<Function *> *CGSCC,
-                      KernelSet &Kernels)
+                      KernelSet &Kernels, bool OpenMPPostLink)
       : InformationCache(M, AG, Allocator, CGSCC), OMPBuilder(M),
-        Kernels(Kernels) {
+        Kernels(Kernels), OpenMPPostLink(OpenMPPostLink) {
 
     OMPBuilder.initialize();
     initializeRuntimeFunctions(M);
@@ -448,6 +448,24 @@
       CI->setCallingConv(Fn->getCallingConv());
   }
 
+  // Helper function to determine if it's legal to create a call to the runtime
+  // functions.
+  bool runtimeFnsAvailable(ArrayRef<RuntimeFunction> Fns) {
+    // We can always emit calls if we haven't yet linked in the runtime.
+    if (!OpenMPPostLink)
+      return true;
+
+    // Once the runtime has been already been linked in we cannot emit calls to
+    // any undefined functions.
+    for (RuntimeFunction Fn : Fns) {
+      RuntimeFunctionInfo &RFI = RFIs[Fn];
+
+      if (RFI.Declaration && RFI.Declaration->isDeclaration())
+        return false;
+    }
+    return true;
+  }
+
   /// Helper to initialize all runtime function information for those defined
   /// in OpenMPKinds.def.
   void initializeRuntimeFunctions(Module &M) {
@@ -523,6 +541,9 @@
 
   /// Collection of known OpenMP runtime functions..
   DenseSet<const Function *> RTLFunctions;
+
+  /// Indicates if we have already linked in the OpenMP device library.
+  bool OpenMPPostLink = false;
 };
 
 template <typename Ty, bool InsertInvalidates = true>
@@ -1412,7 +1433,10 @@
       Changed |= WasSplit;
       return WasSplit;
     };
-    RFI.foreachUse(SCC, SplitMemTransfers);
+    if (OMPInfoCache.runtimeFnsAvailable(
+            {OMPRTL___tgt_target_data_begin_mapper_issue,
+             OMPRTL___tgt_target_data_begin_mapper_wait}))
+      RFI.foreachUse(SCC, SplitMemTransfers);
 
     return Changed;
   }
@@ -3914,6 +3938,12 @@
   bool changeToSPMDMode(Attributor &A, ChangeStatus &Changed) {
     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
 
+    // We cannot change to SPMD mode if the runtime functions aren't availible.
+    if (!OMPInfoCache.runtimeFnsAvailable(
+            {OMPRTL___kmpc_get_hardware_thread_id_in_block,
+             OMPRTL___kmpc_barrier_simple_spmd}))
+      return false;
+
     if (!SPMDCompatibilityTracker.isAssumed()) {
       for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) {
         if (!NonCompatibleI)
@@ -4021,6 +4051,13 @@
     if (!ReachedKnownParallelRegions.isValidState())
       return ChangeStatus::UNCHANGED;
 
+    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
+    if (!OMPInfoCache.runtimeFnsAvailable(
+            {OMPRTL___kmpc_get_hardware_num_threads_in_block,
+             OMPRTL___kmpc_get_warp_size, OMPRTL___kmpc_barrier_simple_generic,
+             OMPRTL___kmpc_kernel_parallel, OMPRTL___kmpc_kernel_end_parallel}))
+      return ChangeStatus::UNCHANGED;
+
     const int InitModeArgNo = 1;
     const int InitUseStateMachineArgNo = 2;
 
@@ -4167,7 +4204,6 @@
     BranchInst::Create(IsWorkerCheckBB, UserCodeEntryBB, IsWorker, InitBB);
 
     Module &M = *Kernel->getParent();
-    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
     FunctionCallee BlockHwSizeFn =
         OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
             M, OMPRTL___kmpc_get_hardware_num_threads_in_block);
@@ -5343,7 +5379,10 @@
   BumpPtrAllocator Allocator;
   CallGraphUpdater CGUpdater;
 
-  OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ nullptr, Kernels);
+  bool PostLink = LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink ||
+                  LTOPhase == ThinOrFullLTOPhase::ThinLTOPreLink;
+  OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ nullptr, Kernels,
+                                PostLink);
 
   unsigned MaxFixpointIterations =
       (isOpenMPDevice(M)) ? SetFixpointIterations : 32;
@@ -5417,9 +5456,11 @@
   CallGraphUpdater CGUpdater;
   CGUpdater.initialize(CG, C, AM, UR);
 
+  bool PostLink = LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink ||
+                  LTOPhase == ThinOrFullLTOPhase::ThinLTOPreLink;
   SetVector<Function *> Functions(SCC.begin(), SCC.end());
   OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator,
-                                /*CGSCC*/ &Functions, Kernels);
+                                /*CGSCC*/ &Functions, Kernels, PostLink);
 
   unsigned MaxFixpointIterations =
       (isOpenMPDevice(M)) ? SetFixpointIterations : 32;
diff --git a/llvm/test/Transforms/OpenMP/custom_state_machines_pre_lto.ll b/llvm/test/Transforms/OpenMP/custom_state_machines_pre_lto.ll
index 6dae173..eb83c59 100644
--- a/llvm/test/Transforms/OpenMP/custom_state_machines_pre_lto.ll
+++ b/llvm/test/Transforms/OpenMP/custom_state_machines_pre_lto.ll
@@ -2,7 +2,9 @@
 ; RUN: opt --mtriple=amdgcn-amd-amdhsa --data-layout=A5 -S -passes=openmp-opt < %s | FileCheck %s --check-prefix=AMDGPU
 ; RUN: opt --mtriple=nvptx64--         -S -passes=openmp-opt < %s | FileCheck %s --check-prefix=NVPTX
 ; RUN: opt --mtriple=amdgcn-amd-amdhsa --data-layout=A5 -openmp-opt-disable-state-machine-rewrite -S -passes=openmp-opt < %s | FileCheck %s --check-prefix=AMDGPU
+; RUN: opt --mtriple=amdgcn-amd-amdhsa --data-layout=A5 -S -passes=openmp-opt-postlink < %s | FileCheck %s --check-prefix=AMDGPU
 ; RUN: opt --mtriple=nvptx64--         -openmp-opt-disable-state-machine-rewrite -S -passes=openmp-opt < %s | FileCheck %s --check-prefix=NVPTX
+; RUN: opt --mtriple=nvptx64--         -S -passes=openmp-opt-postlink < %s | FileCheck %s --check-prefix=NVPTX
 
 ;; void p0(void);
 ;; void p1(void);
diff --git a/llvm/test/Transforms/OpenMP/spmdization.ll b/llvm/test/Transforms/OpenMP/spmdization.ll
index e7100db..277d092 100644
--- a/llvm/test/Transforms/OpenMP/spmdization.ll
+++ b/llvm/test/Transforms/OpenMP/spmdization.ll
@@ -2,7 +2,9 @@
 ; RUN: opt --mtriple=amdgcn-amd-amdhsa --data-layout=A5 -S -passes=openmp-opt < %s | FileCheck %s --check-prefixes=AMDGPU
 ; RUN: opt --mtriple=nvptx64-- -S -passes=openmp-opt < %s | FileCheck %s --check-prefixes=NVPTX
 ; RUN: opt --mtriple=amdgcn-amd-amdhsa --data-layout=A5 -S -passes=openmp-opt -openmp-opt-disable-spmdization < %s | FileCheck %s --check-prefix=AMDGPU-DISABLED
+; RUN: opt --mtriple=amdgcn-amd-amdhsa --data-layout=A5 -S -passes=openmp-opt-postlink < %s | FileCheck %s --check-prefix=AMDGPU-DISABLED
 ; RUN: opt --mtriple=nvptx64-- -S -passes=openmp-opt -openmp-opt-disable-spmdization < %s | FileCheck %s --check-prefix=NVPTX-DISABLED
+; RUN: opt --mtriple=nvptx64-- -S -passes=openmp-opt-postlink < %s | FileCheck %s --check-prefix=NVPTX-DISABLED
 
 ;; void unknown(void);
 ;; void spmd_amenable(void) __attribute__((assume("ompx_spmd_amenable")));