blob: 46608ec3a35abf8c7f759a34041cb5b4f399ec0d [file] [log] [blame]
# Copyright 2023 The Fuchsia Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
from functools import cached_property, lru_cache
from recipe_engine import recipe_api
class TensorFlowApi(recipe_api.RecipeApi):
    """Provides functions to interact with TensorFlow installed via vpython."""

    @property
    def aot_compiler(self):
        """Path to the `saved_model_cli` tool inside the TensorFlow virtualenv.

        Accessing this property ensures TensorFlow has been installed first.
        """
        self._install()
        return self._venv_dir.join("bin", "saved_model_cli")

    @property
    def path(self):
        """Path to the TensorFlow installation directory."""
        return self._install()

    @property
    def vpython_spec(self):
        """Path to a vpython spec file for scripts that depend on TF."""
        return self.resource("tensorflow.vpython")

    @cached_property
    def _venv_dir(self):
        """Path to the root of the virtualenv containing TensorFlow."""
        return self.m.path.mkdtemp("tensorflow-venv")

    @cached_property
    def _install_path(self):
        """Runs the install step once and caches the resulting path.

        The cache lives on the instance (via `cached_property`) rather than in
        a class-level `lru_cache` keyed on `self`, so the cache does not keep
        instances alive and does not require them to be hashable.
        """
        install_step = self.m.step(
            "get tensorflow",
            [
                "vpython3",
                # Point vpython at our own temp dir so the location of the
                # virtualenv (and thus of `bin/saved_model_cli`) is known.
                "-vpython-root",
                self._venv_dir,
                self.resource("get_tensorflow.py"),
            ],
            # The step's stdout is captured and interpreted as the
            # installation path below.
            stdout=self.m.raw_io.output_text(add_output_log=True),
            # In tests, simulate the step printing a site-packages path under
            # the virtualenv.
            step_test_data=lambda: self.m.raw_io.test_api.stream_output_text(
                str(
                    self._venv_dir.join(
                        "lib", "python3.8", "site-packages", "tensorflow"
                    )
                )
            ),
        )
        # Strip surrounding whitespace (e.g. a trailing newline) before
        # converting the printed string into a Path object.
        return self.m.path.abs_to_path(install_step.stdout.strip())

    def _install(self):
        """Installs TensorFlow and returns the path to the installation.

        The underlying step runs at most once per API instance; subsequent
        calls return the cached path.
        """
        return self._install_path