| // Copyright 2014 Google Inc. All Rights Reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| // Author: nevena@google.com (Nevena Lazic) |
| |
| #include "lossmin/minimizers/parallel-boosting-with-momentum.h" |
| |
| #include <math.h> |
| #include <functional> |
| |
| #include "lossmin/minimizers/gradient-evaluator.h" |
| #include "third_party/eigen/Eigen/Core" |
| |
| namespace lossmin { |
| |
| void ParallelBoostingWithMomentum::Setup() { |
| // Per-coordinate learning rates are learning_rates[j] = 1 / sparsity * Lj, |
| // where sparsity is the maximum instance l0 norm and Lj is upper bound on |
| // the loss curvature along coordinate j. |
| float sparsity = gradient_evaluator().Sparsity(); |
| gradient_evaluator().PerCoordinateCurvature(&learning_rates_); |
| learning_rates_ = |
| (learning_rates_.array() + l2()).inverse().matrix() / sparsity; |
| |
| // Initialize the approximating function parameters. |
| alpha_ = 0.5f; |
| beta_ = 1.0f - alpha_; |
| phi_center_ = Weights::Zero(gradient_evaluator().NumWeights()); |
| } |
| |
| void ParallelBoostingWithMomentum::EpochUpdate(Weights *weights, int epoch, |
| bool check_convergence) { |
| // Compute the intermediate weight vector y. |
| Weights y = (1.0f - alpha_) * *weights + alpha_ * phi_center_; |
| |
| // Compute the gradient of the loss wrt y. |
| Weights gradient_wrt_y = Weights::Zero(y.size()); |
| gradient_evaluator().Gradient(y, &gradient_wrt_y); |
| if (l2() > 0.0f) gradient_wrt_y += l2() * y; |
| |
| // Gradient step. |
| *weights -= gradient_wrt_y.cwiseProduct(learning_rates_); |
| |
| // l1 shrinkage. |
| if (l1() > 0.0f) { |
| L1Prox(l1() * learning_rates_, weights); |
| gradient_wrt_y += l1() * weights->unaryExpr(std::ptr_fun(Sign)); |
| } |
| |
| // Update the approximation function. |
| phi_center_ -= (1.0 - alpha_) / alpha_ * (y - *weights); |
| alpha_ = |
| -beta_ / 2.0 + |
| pow(beta_ + beta_ * beta_ / 4.0, static_cast<float>(0.5)); |
| beta_ *= (1.0 - alpha_); |
| |
| // Check convergence. |
| if (check_convergence) ConvergenceCheck(*weights, gradient_wrt_y); |
| } |
| |
| } // namespace lossmin |