| // Copyright 2014 Google Inc. All Rights Reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| // Author: nevena@google.com (Nevena Lazic) |
| // |
| // In InnerProductLossFunction, the loss of a labeled example |
| // (instance_i, label_i) is only a function of the inner product |
| // <weights, instance_i> and label_i. label_i is a scalar (single element |
| // VectorXf). |
| // |
| // Derived classes need to provide InnerProductExampleLoss, |
| // InnerProductExampleGradient, and InnerProductPredictLabel, and set the |
| // 'curvature_' parameter. |
| // |
| // Current implementations include logistic-regression, averaged-logistic, |
| // linear-regression, poisson-regression, smooth-hinge. |
| |
| #pragma once |
| |
#include <float.h>
#include <math.h>

#include <functional>
#include <mutex>
#include <utility>
#include <vector>

#include "lossmin/eigen-types.h"
#include "lossmin/losses/loss-function.h"
| |
| namespace lossmin { |
| |
| class InnerProductLossFunction : public LossFunction { |
| public: |
| // Returns the loss for a single example. |
| double ExampleLoss( |
| const Weights &weights, const InstanceSet &instances, |
| const LabelSet &labels, int example) const override; |
| |
| // Adds the gradient of a single example to 'gradient'. |
| void AddExampleGradient( |
| const Weights &weights, const InstanceSet &instances, |
| const LabelSet &labels, int example, double weights_scale, |
| double example_scale, Weights *gradient) const override; |
| |
| // Returns the gradient of a single example. |
| void ExampleGradient( |
| const Weights &weights, const InstanceSet &instances, |
| const LabelSet &labels, int example, double weights_scale, |
| double example_scale, |
| std::vector<std::pair<int, double>> *example_gradient) const override; |
| |
| // Assigns labels to 'instances' given 'weights'. |
| void PredictLabels(const Weights &weights, const InstanceSet &instances, |
| LabelSet *labels) const override; |
| |
| // Returns an upper bound on the loss curvature. |
| double LossCurvature(const InstanceSet &instances) const override; |
| |
| // Returns an upper bound on the per-coordinate curvature. |
| void PerCoordinateCurvature( |
| const InstanceSet &instances, |
| VectorXd *per_coordinate_curvature) const override; |
| |
| virtual double InnerProductExampleLoss(double inner_product, double label) |
| const = 0; |
| |
| virtual double InnerProductExampleGradient(double inner_product, double label) |
| const = 0; |
| |
| virtual double InnerProductPredictLabel(double inner_product) const = 0; |
| |
| // Returns 'curvature_'. |
| virtual double InnerProductCurvature(double inner_product, double label) const { |
| return curvature_; |
| } |
| |
| protected: |
| // Sets the upper bound on the curvature of the loss function. |
| void set_curvature(double curvature) { curvature_ = curvature; } |
| |
| private: |
| // Mutex for synchronous updates of the gradient vector. |
| mutable std::mutex gradient_update_mutex_; |
| |
| // Upper bound on the absolute value of the second derivative of the loss: |
| // |d^2 loss(x) / dx^2| <= curvature_, where 'x' is the inner product |
| // <instance, weights>. Should be set by derived classes. |
| double curvature_; |
| }; |
| |
| // Linear regression with squared error loss. |
| class LinearRegressionLossFunction : public InnerProductLossFunction { |
| public: |
| LinearRegressionLossFunction() { set_curvature(1.0); } |
| |
| // Returns the squared error loss. |
| double InnerProductExampleLoss(double inner_product, double label) |
| const override { |
| return 0.5 * (inner_product - label) * (inner_product - label); |
| } |
| |
| // Returns the gradient of the squared error loss wrt 'inner_product'. |
| double InnerProductExampleGradient(double inner_product, double label) |
| const override { |
| return inner_product - label; |
| } |
| |
| // Assigns a label given 'inner_product'. |
| double InnerProductPredictLabel(double inner_product) const override { |
| return inner_product; |
| } |
| }; |
| |
| } // namespace lossmin |