fix: derive the Python runfiles root from the runfiles strategy instead of walking up from __file__
diff --git a/python/runfiles/runfiles.py b/python/runfiles/runfiles.py index ffa2473..1b3e8eb 100644 --- a/python/runfiles/runfiles.py +++ b/python/runfiles/runfiles.py
@@ -95,6 +95,9 @@ raise TypeError() self._runfiles_root = path + def _GetRunfilesDir(self) -> str: + return self._runfiles_root + def RlocationChecked(self, path: str) -> str: # Use posixpath instead of os.path, because Bazel only creates a runfiles # tree on Unix platforms, so `Create()` will only create a directory-based @@ -118,7 +121,7 @@ def __init__(self, strategy: Union[_ManifestBased, _DirectoryBased]) -> None: self._strategy = strategy - self._python_runfiles_root = _FindPythonRunfilesRoot() + self._python_runfiles_root = self._strategy._GetRunfilesDir() self._repo_mapping = _ParseRepoMapping( strategy.RlocationChecked("_repo_mapping") ) @@ -321,19 +324,6 @@ # Support legacy imports by defining a private symbol. _Runfiles = Runfiles - -def _FindPythonRunfilesRoot() -> str: - """Finds the root of the Python runfiles tree.""" - root = __file__ - # Walk up our own runfiles path to the root of the runfiles tree from which - # the current file is being run. This path coincides with what the Bazel - # Python stub sets up as sys.path[0]. Since that entry can be changed at - # runtime, we rederive it here. - for _ in range("rules_python/python/runfiles/runfiles.py".count("/") + 1): - root = os.path.dirname(root) - return root - - def _ParseRepoMapping(repo_mapping_path: Optional[str]) -> Dict[Tuple[str, str], str]: """Parses the repository mapping manifest.""" # If the repository mapping file can't be found, that is not an error: We
diff --git a/tests/runfiles/runfiles_test.py b/tests/runfiles/runfiles_test.py index 03350f3..c6f454f 100644 --- a/tests/runfiles/runfiles_test.py +++ b/tests/runfiles/runfiles_test.py
@@ -527,7 +527,7 @@ expected = "" else: expected = "rules_python" - r = runfiles.Create({"RUNFILES_DIR": "whatever"}) + r = runfiles.Create({"RUNFILES_DIR": os.environ.get("RUNFILES_DIR")}) assert r is not None # mypy doesn't understand the unittest api. self.assertEqual(r.CurrentRepository(), expected)