# blob: a4035b7016dcee00ffbf53aad4a9faa6ef7f93e1 [file] [log] [blame] [edit]
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Manage baseline scores (the reference from which we form a reward)."""
from typing import Any, Generic, TypeVar
from collections.abc import Callable
T = TypeVar("T")
class BaselineCache(Generic[T]):
"""Manages a cache of baseline scores."""
def __init__(self, *, get_key: Callable[[T], Any]):
"""Constructor.
Args:
get_key: A callable that returns the key for an item.
"""
self._get_key = get_key
self._cache = {}
def get_score(self, items: list[T | None],
get_scores_func: Callable[[list[T]], list[float]]):
"""Get the scores for a batch of items.
The scores are returned in the same order as the provided items. A None
result indicates the score could not be obtained.
Args:
items: A list of items to get scores for.
get_scores_func: A callable that returns the scores for a batch of
items.
get_scores_func: Responsible for timely completion. It must not
raise, and it must return results in the order of the items
provided. A None value is expected for items that could not
produce a value.
"""
todo = {i for i in items if self._get_key(i) not in self._cache}
scores = get_scores_func(list(todo))
if len(scores) != len(todo):
raise ValueError(
"got a different number of results for the requested items")
for i, s in zip(todo, scores):
self._cache[self._get_key(i)] = s
return [self._cache[self._get_key(i)] for i in items]
def get_cache(self):
"""Intended for testing."""
return self._cache