// Copyright 2014 Google Inc. All Rights Reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
// Author: nevena@google.com (Nevena Lazic)
//
// Interface for computing the value and gradient of a loss function given
// parameters 'weights' and labeled examples {(instance_i, label_i)}.

#pragma once

#include <math.h>
#include <utility>
#include <vector>

#include "lossmin/eigen-types.h"

namespace lossmin {
// Abstract interface for computing the value and gradient of a loss function.
// Derived classes must provide the following methods: LossCurvature,
// PerCoordinateCurvature, ExampleLoss, AddExampleGradient, and
// PredictLabels.
//
// If multiple threads concurrently call the AddExampleGradient() method of
// the same LossFunction object to update the same 'gradient' vector, the
// implementation must protect changes to 'gradient' with a mutex. If each
// thread instead updates its own 'gradient' vector through the same object,
// that locking only adds unnecessary synchronization cost. The
// 'synchronous_update_' flag indicates whether to use synchronization.
// If multiple threads concurrently call the AddExampleGradient() method of
// different LossFunction objects to update the same 'gradient' vector, race
// conditions may result; in that case the caller is responsible for
// protecting 'gradient'.
class LossFunction {
public:
LossFunction() {}
virtual ~LossFunction() {}
// Returns the loss for a single example (row 'example' of 'instances' and
// 'labels').
virtual double ExampleLoss(
const Weights &weights, const InstanceSet &instances,
const LabelSet &labels, int example) const = 0;
// Computes the gradient of the loss for a single example (row 'example' of
// 'instances' and 'labels') with respect to 'weights_scale * weights'. Adds
// 'example_scale * example_gradient' to 'gradient'.
// If multiple threads call the AddExampleGradient method of the same object
// to update the same 'gradient' vector, updates to 'gradient' need to be
// protected by a mutex in the implementation.
virtual void AddExampleGradient(
const Weights &weights, const InstanceSet &instances,
const LabelSet &labels, int example, double weights_scale,
double example_scale, Weights *gradient) const = 0;
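
// A sketch of the locking pattern an implementation of AddExampleGradient()
// might use when 'synchronous_update()' is true ('mutex_' is a hypothetical
// member of the derived class, not part of this interface):
//
//   Weights example_gradient = ...;  // this example's gradient contribution
//   if (synchronous_update()) {
//     std::lock_guard<std::mutex> lock(mutex_);  // requires <mutex>
//     *gradient += example_scale * example_gradient;
//   } else {
//     *gradient += example_scale * example_gradient;
//   }
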
// Computes the gradient of the loss for a single example and writes it to
// 'example_gradient' as (coordinate, value) pairs. Used in AdaGrad.
virtual void ExampleGradient(
const Weights &weights, const InstanceSet &instances,
const LabelSet &labels, int example, double weights_scale,
double example_scale,
std::vector<std::pair<int, double>> *example_gradient) const = 0;
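
// Illustrative sketch of filling the sparse result, assuming InstanceSet is
// an Eigen row-major sparse matrix (per eigen-types.h) and 'derivative' is a
// hypothetical scalar: the derivative of this example's loss with respect to
// its prediction, as in a generalized linear loss:
//
//   example_gradient->clear();
//   for (InstanceSet::InnerIterator it(instances, example); it; ++it) {
//     example_gradient->push_back(std::make_pair(
//         static_cast<int>(it.col()), example_scale * derivative * it.value()));
//   }
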
// Predicts 'labels' for 'instances' given 'weights'.
virtual void PredictLabels(
const Weights &weights, const InstanceSet &instances,
LabelSet *labels) const = 0;
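
// For a purely linear model this could be as simple as the sketch below;
// concrete losses may instead apply a link function (e.g. a sigmoid for
// logistic loss):
//
//   *labels = instances * weights;
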
// Returns an upper bound on the loss curvature (the maximum eigenvalue of
// the loss Hessian matrix). Required by DeterministicGradientDescent.
// Optionally used by StochasticVarianceReducedGradient (for the default
// learning rate) and by StochasticGradientDescent (for CURVATURE_BASED
// learning rate scheduling).
virtual double LossCurvature(const InstanceSet &instances) const = 0;
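
// For example, for squared loss the Hessian is proportional to
// instances^T * instances, so any upper bound on that matrix's largest
// eigenvalue (e.g. its trace) is a valid return value. (Illustrative only;
// the exact constant depends on how the loss is defined.)
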
// Returns an upper bound on the curvature of the loss along each coordinate
// of the data (the maximum absolute value of the corresponding second
// derivative). Required by ParallelBoostingWithMomentum.
virtual void PerCoordinateCurvature(
const InstanceSet &instances,
VectorXd *per_coordinate_curvature) const = 0;
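
// For example, for squared loss the j-th entry could be proportional to the
// squared norm of column j of 'instances' (illustrative only; the exact
// constant depends on how the loss is defined).
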
// Initializes parameters to a suggested setting for this loss if appropriate.
virtual void Init(Weights *weights) const {}
// Returns the number of model parameters for the given number of features.
// Needs to be overridden when #parameters != #features.
virtual int NumWeights(int num_features) const { return num_features; }
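
// For example, a hypothetical multiclass loss with K classes over linear
// per-class scores would override this to return num_features * K.
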
// Returns the total loss for a set of examples. Default implementation runs
// through the examples and calls ExampleLoss on each.
virtual double BatchLoss(const Weights &weights, const InstanceSet &instances,
const LabelSet &labels) const;
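
// A sketch of the default behavior described above:
//
//   double loss = 0.0;
//   for (int i = 0; i < instances.rows(); ++i)
//     loss += ExampleLoss(weights, instances, labels, i);
//   return loss;
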
// Returns the total gradient for a set of examples. Default implementation
// runs through the examples and calls AddExampleGradient on each. Should
// always be called in a single-threaded context.
virtual void BatchGradient(
const Weights &weights, const InstanceSet &instances,
const LabelSet &labels, Weights *gradient) const;
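
// A sketch of the default behavior described above (zeroing 'gradient' first
// is an assumption of this sketch):
//
//   gradient->setZero();
//   for (int i = 0; i < instances.rows(); ++i)
//     AddExampleGradient(weights, instances, labels, i,
//                        /*weights_scale=*/1.0, /*example_scale=*/1.0,
//                        gradient);
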
// Sets the 'synchronous_update_' flag.
void set_synchronous_update(bool is_sync) { synchronous_update_ = is_sync; }
// Returns the 'synchronous_update_' flag.
bool synchronous_update() const { return synchronous_update_; }
private:
// Flag indicating whether to use synchronous or asynchronous updates in
// AddExampleGradient().
bool synchronous_update_ = false;
//DISALLOW_COPY_AND_ASSIGN(LossFunction);
};
} // namespace lossmin
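
// Usage sketch ('MyLoss' stands in for any concrete LossFunction subclass;
// 'instances', 'labels', and 'num_features' are assumed to be defined):
//
//   lossmin::MyLoss loss;
//   lossmin::Weights weights =
//       lossmin::Weights::Zero(loss.NumWeights(num_features));
//   loss.Init(&weights);
//   lossmin::Weights gradient = lossmin::Weights::Zero(weights.size());
//   double objective = loss.BatchLoss(weights, instances, labels);
//   loss.BatchGradient(weights, instances, labels, &gradient);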