Change convergence check to verify KKT conditions directly
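
The previous base-class check replaced the gradient at zero-valued
weights by |gradient| - l1 and compared the norm of the result to the
convergence threshold. The base-class check now simply tests whether
the gradient norm is near zero, and ParallelBoostingWithMomentum
overrides ConvergenceCheck to verify the KKT conditions of the
l1-regularized objective directly: at a minimizer, gradient_i == -l1
where weights_i > 0, gradient_i == l1 where weights_i < 0, and
|gradient_i| <= l1 where weights_i == 0.

For illustration, a minimal standalone sketch of the violation norm
the new check computes (a hypothetical helper using std::vector in
place of the Eigen-based Weights type; it assumes <cmath> and
<algorithm> are available, which the new .cc code also needs for
std::sqrt, std::abs and std::max):

  #include <algorithm>
  #include <cmath>
  #include <vector>

  // Root of the summed squared KKT violations, divided by the number
  // of weights; convergence is declared when this value falls below
  // the convergence threshold.
  float KktError(const std::vector<float>& weights,
                 const std::vector<float>& gradient, float l1,
                 float zero_threshold) {
    float error_squared = 0.0f;
    for (size_t i = 0; i < weights.size(); ++i) {
      if (weights[i] > zero_threshold) {
        // Positive weight: the gradient should equal -l1.
        error_squared += (gradient[i] + l1) * (gradient[i] + l1);
      } else if (weights[i] < -zero_threshold) {
        // Negative weight: the gradient should equal +l1.
        error_squared += (gradient[i] - l1) * (gradient[i] - l1);
      } else {
        // (Near-)zero weight: the gradient should lie in [-l1, l1].
        const float err = std::max(std::abs(gradient[i]) - l1, 0.0f);
        error_squared += err * err;
      }
    }
    return std::sqrt(error_squared) / weights.size();
  }
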
Change-Id: Ia80bc87e064ecf6ac414f30a98490a1e31371ae0
diff --git a/lossmin/minimizers/loss-minimizer.cc b/lossmin/minimizers/loss-minimizer.cc
index 427c4ac..cd052d7 100644
--- a/lossmin/minimizers/loss-minimizer.cc
+++ b/lossmin/minimizers/loss-minimizer.cc
@@ -130,13 +130,11 @@
return min_loss_rate;
}
+// By default, only check that the norm of the gradient is close to zero.
void LossMinimizer::ConvergenceCheck(const Weights &weights,
const Weights &gradient) {
- Weights gradient_for_convergence =
- (weights.array() == 0.0f).select(abs(gradient.array()) - l1_, gradient);
- if (gradient_for_convergence.norm() / weights.size() <
- convergence_threshold_) {
- converged_ = true;
+ if (gradient.squaredNorm() / weights.size() < convergence_threshold_) {
+ set_converged(true);
}
}
diff --git a/lossmin/minimizers/loss-minimizer.h b/lossmin/minimizers/loss-minimizer.h
index 48df05a..58ac3b4 100644
--- a/lossmin/minimizers/loss-minimizer.h
+++ b/lossmin/minimizers/loss-minimizer.h
@@ -109,11 +109,10 @@
return loss;
}
- // Checks convergence of deterministic methods based on two criteria:
- // (1) the squared l2 norm of the gradient is below 'convergence_threshold_'
- // (2) for zero-valued weights, the corresponding gradient values are < l1.
+ // Checks convergence of deterministic methods.
// If converged, the flag 'converged_' is set to true.
- void ConvergenceCheck(const Weights &weights, const Weights &gradient);
+ virtual void ConvergenceCheck(const Weights &weights,
+ const Weights &gradient);
// Checks convergence based on the decrease in loss over the last
// 'num_convergence_epochs_' epochs. If converged, the flag 'converged_' is
@@ -126,12 +125,17 @@
void set_use_simple_convergence_check(bool use_simple_convergence_check) {
use_simple_convergence_check_ = use_simple_convergence_check;
}
+ float convergence_threshold() const { return convergence_threshold_; }
void set_convergence_threshold(float convergence_threshold) {
convergence_threshold_ = convergence_threshold;
}
void set_num_convergence_epochs(int num_convergence_epochs) {
num_convergence_epochs_ = num_convergence_epochs;
}
+ float zero_threshold() const { return zero_threshold_; }
+ void set_zero_threshold(float zero_threshold) {
+ zero_threshold_ = zero_threshold;
+ }
// Returns the best initial learning rate for stochastic methods by evaluating
// elements of 'initial_rates_' on a subset of 'num_search_examples_' training
@@ -181,11 +185,13 @@
return gradient_evaluator_;
}
- // Returns the l1 regularization parameter.
+ // Getter/setter for the l1 regularization parameter.
float l1() const { return l1_; }
+ void set_l1(float l1) { l1_ = l1; }
- // Returns the l2 regularization parameter.
+ // Getter/setter for the l2 regularization parameter.
float l2() const { return l2_; }
+ void set_l2(float l2) { l2_ = l2; }
// Getter/setter for the 'num_scale_iterations_', used for sparse updates in
// stochastic methods.
@@ -261,9 +267,13 @@
bool use_simple_convergence_check_ = false; // which convergence check to use
int num_convergence_epochs_ = 5; // used in SimpleConvergenceCheck
+ // The threshold below which a weight's absolute value is treated as zero.
+ // Used in ConvergenceCheck.
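+ // For example, with the default value of 1e-6, a weight of 3e-7 is
+ // treated as exactly zero when checking the KKT conditions.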
+ float zero_threshold_ = 1e-6f;
+
// The number of epochs (iterations) when Run() was executed.
// In other words, each epoch is a step towards minimum during minimization.
- // This variable gets updated when Run() is called
+ // This variable gets updated when Run() is called.
int num_epochs_run_ = 0;
// Initial learning rate, used in stochastic methods.
diff --git a/lossmin/minimizers/parallel-boosting-with-momentum.cc b/lossmin/minimizers/parallel-boosting-with-momentum.cc
index 7254401..694b75c 100644
--- a/lossmin/minimizers/parallel-boosting-with-momentum.cc
+++ b/lossmin/minimizers/parallel-boosting-with-momentum.cc
@@ -28,6 +28,30 @@
phi_center_ = Weights::Zero(gradient_evaluator().NumWeights());
}
+void ParallelBoostingWithMomentum::ConvergenceCheck(const Weights &weights,
+ const Weights &gradient) {
+ float error_squared = 0.0f;
+ for (int i = 0; i < gradient.size(); i++) {
+ if (weights(i) > zero_threshold()) {
+ // For weights > 0 the gradient should be == -l1.
+ error_squared += (gradient(i) + l1()) * (gradient(i) + l1());
+ } else if (weights(i) < -zero_threshold()) {
+ // For weights < 0 the gradient should be == l1.
+ error_squared += (gradient(i) - l1()) * (gradient(i) - l1());
+ } else {
+ // For weights == 0 the gradient should be between -l1 and l1.
+ float err = std::max(std::abs(gradient(i)) - l1(), 0.0f);
+ error_squared += err * err;
+ }
+ }
+
+ if (std::sqrt(error_squared) / weights.size() < convergence_threshold()) {
+ set_converged(true);
+ }
+}
+
void ParallelBoostingWithMomentum::EpochUpdate(Weights *weights, int epoch,
bool check_convergence) {
// Compute the intermediate weight vector y.
@@ -36,6 +60,7 @@
// Compute the gradient of the loss wrt y.
Weights gradient_wrt_y = Weights::Zero(y.size());
gradient_evaluator().Gradient(y, &gradient_wrt_y);
+
if (l2() > 0.0f) gradient_wrt_y += l2() * y;
// Gradient step.
@@ -44,18 +69,24 @@
// l1 shrinkage.
if (l1() > 0.0f) {
L1Prox(l1() * learning_rates_, weights);
- gradient_wrt_y += l1() * weights->unaryExpr(std::ptr_fun(Sign));
}
// Update the approximation function.
phi_center_ -= (1.0 - alpha_) / alpha_ * (y - *weights);
alpha_ =
- -beta_ / 2.0 +
- pow(beta_ + beta_ * beta_ / 4.0, static_cast<float>(0.5));
+ -beta_ / 2.0 + pow(beta_ + beta_ * beta_ / 4.0, static_cast<float>(0.5));
beta_ *= (1.0 - alpha_);
- // Check convergence.
- if (check_convergence) ConvergenceCheck(*weights, gradient_wrt_y);
+ // Compute the gradient of the objective, excluding the l1 penalty, and
+ // check convergence.
+ if (check_convergence) {
+ Weights gradient_wrt_weights = Weights::Zero(weights->size());
+ gradient_evaluator().Gradient(*weights, &gradient_wrt_weights);
+ if (l2() > 0.0f) {
+ gradient_wrt_weights += l2() * *weights;
+ }
+ ConvergenceCheck(*weights, gradient_wrt_weights);
+ }
}
} // namespace lossmin
diff --git a/lossmin/minimizers/parallel-boosting-with-momentum.h b/lossmin/minimizers/parallel-boosting-with-momentum.h
index 619dfe8..11455b3 100644
--- a/lossmin/minimizers/parallel-boosting-with-momentum.h
+++ b/lossmin/minimizers/parallel-boosting-with-momentum.h
@@ -20,8 +20,8 @@
class ParallelBoostingWithMomentum : public LossMinimizer {
public:
- ParallelBoostingWithMomentum(
- float l1, float l2, const GradientEvaluator &gradient_evaluator)
+ ParallelBoostingWithMomentum(float l1, float l2,
+ const GradientEvaluator &gradient_evaluator)
: LossMinimizer(l1, l2, gradient_evaluator) {
Setup();
}
@@ -29,6 +29,22 @@
// Sets learning rates and other parameters.
void Setup() override;
+ // Checks convergence by verifying the KKT conditions directly.
+ // |gradient| is the gradient of the objective at |weights|, not including
+ // the contribution of the l1 penalty term.
+ //
+ // The function checks whether the mean norm of violations of the KKT
+ // conditions is below the convergence threshold. For a convex objective
+ // the KKT conditions are necessary and sufficient for |weights| to be a
+ // minimizer:
+ // If weights_i < 0 then gradient_i == l1()
+ // If weights_i > 0 then gradient_i == -l1()
+ // If weights_i == 0 then -l1() <= gradient_i <= l1().
+ //
+ // The squared violations at each coordinate are summed, and the square
+ // root divided by weights.size() is compared to convergence_threshold().
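+ // Example (hypothetical values): with l1() == 0.1,
+ // weights == (0.5, -0.2, 0.0) and gradient == (-0.1, 0.3, 0.05), the
+ // per-coordinate violations are (0.0, 0.2, 0.0), so the check compares
+ // sqrt(0.04) / 3 (about 0.067) against convergence_threshold().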
+ void ConvergenceCheck(const Weights &weights,
+ const Weights &gradient) override;
+
private:
// Updates 'weights' and the quadratic approximation function phi(w), such
// that at iteration k, loss(weights_k) <= min_w phi_k(w).
@@ -57,4 +73,3 @@
};
} // namespace lossmin
-