# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""operations to transform features (observations)."""
import json
import os
import re
import functools
from collections.abc import Callable
import numpy as np
import tensorflow.compat.v2 as tf
from tf_agents.typing import types


def build_quantile_map(quantile_file_dir: str) -> dict[str, list[float]]:
  """Builds a feature quantile map from files in quantile_file_dir."""
quantile_map = {}
pattern = os.path.join(quantile_file_dir, '(.*).buckets')
for quantile_file_path in tf.io.gfile.glob(
os.path.join(quantile_file_dir, '*.buckets')):
m = re.fullmatch(pattern, quantile_file_path)
assert m
feature_name = m.group(1)
with tf.io.gfile.GFile(quantile_file_path, 'r') as quantile_file:
raw_quantiles = [float(x) for x in quantile_file]
quantile_map[feature_name] = raw_quantiles
return quantile_map
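

# Hypothetical usage sketch (not part of the original module): assumes a
# directory of '<feature_name>.buckets' files with one float per line, which
# is the layout the glob and regex above expect. Writes a fake bucket file,
# then reads it back.
def _example_build_quantile_map(tmp_dir: str) -> None:
  with tf.io.gfile.GFile(
      os.path.join(tmp_dir, 'edge_count.buckets'), 'w') as f:
    f.write('1.0\n2.0\n4.0\n')
  quantile_map = build_quantile_map(tmp_dir)
  assert quantile_map['edge_count'] == [1.0, 2.0, 4.0]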


def discard_fn(obs: types.Float):
  """Discards the input feature by setting it to 0."""
zeros = tf.zeros_like(obs, dtype=tf.float32)
return tf.expand_dims(zeros, axis=-1)


def identity_fn(obs: types.Float):
  """Returns the input cast to float32, with the last dimension expanded."""
return tf.cast(tf.expand_dims(obs, -1), tf.float32)
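

# Hypothetical usage sketch (not part of the original module): both helpers
# add a trailing axis of size 1; discard_fn additionally zeroes the values.
def _example_identity_and_discard() -> None:
  obs = tf.constant([1, 2, 3])
  assert identity_fn(obs).shape == [3, 1]
  assert discard_fn(obs).shape == [3, 1]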


def get_normalize_fn(quantile: list[float],
                     with_sqrt: bool,
                     with_z_score_normalization: bool,
                     eps: float = 1e-8,
                     preprocessing_fn: Callable[[types.Tensor], types.Float]
                     | None = None):
  """Returns a function that normalizes the input feature via its quantiles.

  The returned function bucketizes the observation against quantile and
  concatenates along a new last axis: the normalized bucket index x, x * x,
  and optionally sqrt(x) and a z-score-normalized copy of the preprocessed
  raw value.
  """
if not preprocessing_fn:
# pylint: disable=unnecessary-lambda-assignment
preprocessing_fn = lambda x: x
processed_quantile = [preprocessing_fn(x) for x in quantile]
mean = np.mean(processed_quantile)
std = np.std(processed_quantile)

  def normalize(obs: types.Float):
    obs = tf.expand_dims(obs, -1)
    # Bucketize maps obs to the index of its quantile bucket; dividing by
    # len(quantile) normalizes the index into [0, 1].
    x = tf.cast(
        tf.raw_ops.Bucketize(input=obs, boundaries=quantile),
        tf.float32) / len(quantile)
features = [x, x * x]
if with_sqrt:
features.append(tf.sqrt(x))
if with_z_score_normalization:
y = preprocessing_fn(tf.cast(obs, tf.float32))
y = (y - mean) / (std + eps)
features.append(y)
return tf.concat(features, axis=-1)
return normalize
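

# Hypothetical usage sketch (not part of the original module): illustrates
# the shape contract of the function returned by get_normalize_fn. The
# quantile boundaries below are invented for illustration.
def _example_normalize_usage() -> tf.Tensor:
  quantile = [1.0, 2.0, 4.0, 8.0]
  normalize = get_normalize_fn(
      quantile, with_sqrt=True, with_z_score_normalization=True)
  obs = tf.constant([0.5, 3.0, 100.0])
  # Shape [3, 4]: bucket index x, x * x, sqrt(x), and a z-scored raw value.
  return normalize(obs)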


def get_ir2vec_normalize_fn(
    ir2vec_with_z_score_normalization: bool = False,
    eps: float = 1e-8) -> Callable[[tf.Tensor], tf.Tensor]:
  """Returns a normalization function for IR2Vec embeddings.

  Args:
    ir2vec_with_z_score_normalization: If True, applies z-score
      standardization (whitening) to the embeddings per batch, transforming
      them to have zero mean and unit variance across the batch dimension.
      This can reduce the impact of outlier embeddings and improve training
      stability, especially when embeddings from different batches have
      significantly different scales.
    eps: Small epsilon added to the standard deviation to prevent division
      by zero during standardization.

  Returns:
    A function that normalizes IR2Vec embeddings according to the specified
    parameters.
  """
if ir2vec_with_z_score_normalization:
def z_score_normalize(batch_embeddings: tf.Tensor) -> tf.Tensor:
"""Whitens the embeddings per batch."""
mean = tf.math.reduce_mean(batch_embeddings, axis=0, keepdims=True)
std = tf.math.reduce_std(batch_embeddings, axis=0, keepdims=True)
standardized_embeddings = (batch_embeddings - mean) / (std + eps)
return standardized_embeddings
return z_score_normalize
  # Otherwise just cast the embeddings to float32; this can be extended with
  # other normalizations, like L2 normalization, if needed.
  return functools.partial(tf.cast, dtype=tf.float32)
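

# Hypothetical usage sketch (not part of the original module): shows the two
# behaviors of get_ir2vec_normalize_fn. The embedding values are invented.
def _example_ir2vec_normalize() -> None:
  batch = tf.constant([[1.0, 2.0], [3.0, 6.0]])
  whiten = get_ir2vec_normalize_fn(ir2vec_with_z_score_normalization=True)
  whitened = whiten(batch)  # Per-dimension zero mean, ~unit variance.
  passthrough = get_ir2vec_normalize_fn()
  unchanged = passthrough(batch)  # Only tf.cast(batch, tf.float32).
  del whitened, unchanged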


def get_ir2vec_dimensions_from_vocab_file(vocab_file_path: str) -> int:
  """Reads the IR2Vec vocabulary file and returns its embedding dimensions.

  The dimensionality is taken from the first embedding in the first section.

  Args:
    vocab_file_path: Path to the IR2Vec vocabulary JSON file.

  Returns:
    The number of dimensions in the embeddings.

  Raises:
    ValueError: If the vocabulary file structure is invalid.
  """
  # Load the vocabulary file and get the length of the first embedding.
  # Robust structure checks are done by IR2Vec within LLVM.
  # This method could be replaced to use IR2Vec Python APIs, when available.
with tf.io.gfile.GFile(vocab_file_path, 'r') as f:
vocab_data = json.load(f)
# Check if vocab_data is a dict with sections
if not isinstance(vocab_data, dict) or not vocab_data:
raise ValueError('Vocabulary file must contain a non-empty dictionary')
# Get the first section
sections = vocab_data.values()
first_section = next(iter(sections), None)
if not isinstance(first_section, dict) or not first_section:
raise ValueError('Vocabulary file sections must be non-empty '
'dictionaries')
# Get the first embedding
embeddings = first_section.values()
first_embedding = next(iter(embeddings), None)
if not isinstance(first_embedding, list):
raise ValueError('Vocabulary file embeddings must be lists')
if not first_embedding:
raise ValueError('Vocabulary file embeddings must be non-empty')
  # Return the dimensionality, i.e. the length of the first embedding.
return len(first_embedding)
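

# Hypothetical usage sketch (not part of the original module): the section
# and entity names below are invented; only the nesting (sections ->
# entities -> embedding lists) matches what the reader above expects.
def _example_ir2vec_vocab_dimensions(tmp_path: str) -> None:
  vocab = {'Opcodes': {'add': [0.1, 0.2, 0.3], 'sub': [0.4, 0.5, 0.6]}}
  with tf.io.gfile.GFile(tmp_path, 'w') as f:
    json.dump(vocab, f)
  assert get_ir2vec_dimensions_from_vocab_file(tmp_path) == 3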