// Copyright 2014 Google Inc. All Rights Reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
// Author: nevena@google.com (Nevena Lazic)
//
// In InnerProductLossFunction, the loss of a labeled example
// (instance_i, label_i) is only a function of the inner product
// <weights, instance_i> and label_i. label_i is a scalar (single element
// VectorXf).
//
// Derived classes need to provide InnerProductExampleLoss,
// InnerProductExampleGradient, and InnerProductPredictLabel, and must set the
// curvature upper bound by calling set_curvature() (typically from their
// constructor).
//
// Current implementations include logistic-regression, averaged-logistic,
// linear-regression, poisson-regression, smooth-hinge.

#pragma once

#include <float.h>
#include <math.h>
#include <functional>
#include <mutex>

#include "lossmin/eigen-types.h"
#include "lossmin/losses/loss-function.h"

namespace lossmin {

// Base class for losses whose value on a labeled example depends only on the
// inner product <weights, instance> and the example's scalar label. Derived
// classes implement the three InnerProduct* pure-virtual hooks and call
// set_curvature() to provide an upper bound on the loss curvature.
class InnerProductLossFunction : public LossFunction {
 public:
  // Returns the loss for a single example.
  float ExampleLoss(
      const Weights &weights, const InstanceSet &instances,
      const LabelSet &labels, int example) const override;

  // Adds the gradient of a single example to 'gradient'.
  void AddExampleGradient(
      const Weights &weights, const InstanceSet &instances,
      const LabelSet &labels, int example, float weights_scale,
      float example_scale, Weights *gradient) const override;

  // Writes the gradient of a single example into 'example_gradient'.
  // NOTE(review): the (int, float) pairs are presumably sparse
  // (coordinate, value) entries; confirm against the out-of-line definition.
  void ExampleGradient(
      const Weights &weights, const InstanceSet &instances,
      const LabelSet &labels, int example, float weights_scale,
      float example_scale,
      std::vector<std::pair<int, float>> *example_gradient) const override;

  // Assigns labels to 'instances' given 'weights'.
  void PredictLabels(const Weights &weights, const InstanceSet &instances,
                     LabelSet *labels) const override;

  // Returns an upper bound on the loss curvature.
  float LossCurvature(const InstanceSet &instances) const override;

  // Returns an upper bound on the per-coordinate curvature.
  void PerCoordinateCurvature(
      const InstanceSet &instances,
      VectorXf *per_coordinate_curvature) const override;

  // Returns the loss of one example as a function of the inner product
  // <weights, instance> and the example's scalar label.
  virtual float InnerProductExampleLoss(float inner_product, float label)
      const = 0;

  // Returns the derivative of the loss with respect to 'inner_product'.
  virtual float InnerProductExampleGradient(float inner_product, float label)
      const = 0;

  // Returns the predicted label for the given inner product.
  virtual float InnerProductPredictLabel(float inner_product) const = 0;

  // Returns an upper bound on the curvature of the loss with respect to
  // 'inner_product'. This default ignores both arguments and returns the
  // constant bound stored by set_curvature(); derived classes may override it
  // with an input-dependent bound.
  virtual float InnerProductCurvature(float inner_product, float label) const {
    return curvature_;
  }

 protected:
  // Sets the upper bound on the curvature of the loss function. Derived
  // classes should call this, typically from their constructor.
  void set_curvature(float curvature) { curvature_ = curvature; }

 private:
  // Mutex for synchronous updates of the gradient vector.
  // NOTE(review): the locking itself lives in the out-of-line definitions and
  // is not visible in this header.
  mutable std::mutex gradient_update_mutex_;

  // Upper bound on the absolute value of the second derivative of the loss:
  // |d^2 loss(x) / dx^2| <= curvature_, where 'x' is the inner product
  // <instance, weights>. Should be set by derived classes via set_curvature().
  // Default-initialized so a derived class that forgets to set it yields a
  // deterministic value rather than reading an uninitialized float.
  float curvature_ = 0.0f;
};

// Linear regression with squared error loss:
//   loss(x, y) = 0.5 * (x - y)^2,  where x = <weights, instance>.
// The second derivative with respect to x is exactly 1, so the curvature
// bound is set to 1.
class LinearRegressionLossFunction : public InnerProductLossFunction {
 public:
  LinearRegressionLossFunction() { set_curvature(1.0f); }

  // Returns the squared error loss 0.5 * (inner_product - label)^2.
  float InnerProductExampleLoss(float inner_product, float label)
      const override {
    // Float literal keeps the arithmetic in single precision instead of
    // promoting to double and implicitly narrowing back to float on return.
    const float residual = inner_product - label;
    return 0.5f * residual * residual;
  }

  // Returns the gradient of the squared error loss wrt 'inner_product',
  // i.e. the residual (inner_product - label).
  float InnerProductExampleGradient(float inner_product, float label)
      const override {
    return inner_product - label;
  }

  // Assigns a label given 'inner_product': the prediction is the inner
  // product itself.
  float InnerProductPredictLabel(float inner_product) const override {
    return inner_product;
  }
};

}  // namespace lossmin