Change convergence check to verify KKT conditions directly

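The objective minimized by ParallelBoostingWithMomentum is the loss plus
the l2 and l1 penalties. Writing g for the gradient of the smooth part
(loss plus l2 penalty), the KKT conditions at a minimizer are, for each
coordinate i:

  g_i == -l1   if w_i > 0
  g_i ==  l1   if w_i < 0
  |g_i| <= l1  if w_i == 0

ParallelBoostingWithMomentum::ConvergenceCheck now sums the squared
violations of these conditions (treating coordinates with |w_i| below
zero_threshold_ as zero) and declares convergence when the square root
of that sum, divided by the number of weights, drops below
convergence_threshold_. The base-class check remains the default and
simply tests that the gradient is (approximately) zero.
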
Change-Id: Ia80bc87e064ecf6ac414f30a98490a1e31371ae0
diff --git a/lossmin/minimizers/loss-minimizer.cc b/lossmin/minimizers/loss-minimizer.cc
index 427c4ac..cd052d7 100644
--- a/lossmin/minimizers/loss-minimizer.cc
+++ b/lossmin/minimizers/loss-minimizer.cc
@@ -130,13 +130,11 @@
   return min_loss_rate;
 }
 
+// By default, only check that the gradient is (approximately) zero.
 void LossMinimizer::ConvergenceCheck(const Weights &weights,
                                      const Weights &gradient) {
-  Weights gradient_for_convergence =
-      (weights.array() == 0.0f).select(abs(gradient.array()) - l1_, gradient);
-  if (gradient_for_convergence.norm() / weights.size() <
-      convergence_threshold_) {
-    converged_ = true;
+  if (gradient.squaredNorm() / weights.size() < convergence_threshold_) {
+    set_converged(true);
   }
 }
 
diff --git a/lossmin/minimizers/loss-minimizer.h b/lossmin/minimizers/loss-minimizer.h
index 48df05a..58ac3b4 100644
--- a/lossmin/minimizers/loss-minimizer.h
+++ b/lossmin/minimizers/loss-minimizer.h
@@ -109,11 +109,10 @@
     return loss;
   }
 
-  // Checks convergence of deterministic methods based on two criteria:
-  // (1) the squared l2 norm of the gradient is below 'convergence_threshold_'
-  // (2) for zero-valued weights, the corresponding gradient values are < l1.
+  // Checks convergence of deterministic methods; subclasses may override.
   // If converged, the flag 'converged_' is set to true.
-  void ConvergenceCheck(const Weights &weights, const Weights &gradient);
+  virtual void ConvergenceCheck(const Weights &weights,
+                                const Weights &gradient);
 
   // Checks convergence based on the decrease in loss over the last
   // 'num_convergence_epochs_' epochs. If converged, the flag 'converged_' is
@@ -126,12 +125,17 @@
   void set_use_simple_convergence_check(bool use_simple_convergence_check) {
     use_simple_convergence_check_ = use_simple_convergence_check;
   }
+  float convergence_threshold() const { return convergence_threshold_; }
   void set_convergence_threshold(float convergence_threshold) {
     convergence_threshold_ = convergence_threshold;
   }
   void set_num_convergence_epochs(int num_convergence_epochs) {
     num_convergence_epochs_ = num_convergence_epochs;
   }
+  float zero_threshold() const { return zero_threshold_; }
+  void set_zero_threshold(float zero_threshold) {
+    zero_threshold_ = zero_threshold;
+  }
 
   // Returns the best initial learning rate for stochastic methods by evaluating
   // elements of 'initial_rates_' on a subset of 'num_search_examples_' training
@@ -181,11 +185,13 @@
     return gradient_evaluator_;
   }
 
-  // Returns the l1 regularization parameter.
+  // Getter/setter for the l1 regularization parameter.
   float l1() const { return l1_; }
+  void set_l1(float l1) { l1_ = l1; }
 
-  // Returns the l2 regularization parameter.
+  // Getter/setter for the l2 regularization parameter.
   float l2() const { return l2_; }
+  void set_l2(float l2) { l2_ = l2; }
 
   // Getter/setter for the 'num_scale_iterations_', used for sparse updates in
   // stochastic methods.
@@ -261,9 +267,13 @@
   bool use_simple_convergence_check_ = false;  // which convergence check to use
   int num_convergence_epochs_ = 5;             // used in SimpleConvergenceCheck
 
+  // zero_threshold_ is the threshold below which a coordinate's absolute
+  // value is treated as zero. This is used in ConvergenceCheck.
+  float zero_threshold_ = 1e-6;
+
   // The number of epochs (iterations) when Run() was executed.
   // In other words, each epoch is a step towards minimum during minimization.
-  // This variable gets updated when Run() is called
+  // This variable gets updated when Run() is called.
   int num_epochs_run_ = 0;
 
   // Initial learning rate, used in stochastic methods.
diff --git a/lossmin/minimizers/parallel-boosting-with-momentum.cc b/lossmin/minimizers/parallel-boosting-with-momentum.cc
index 7254401..694b75c 100644
--- a/lossmin/minimizers/parallel-boosting-with-momentum.cc
+++ b/lossmin/minimizers/parallel-boosting-with-momentum.cc
@@ -28,6 +28,30 @@
   phi_center_ = Weights::Zero(gradient_evaluator().NumWeights());
 }
 
+void ParallelBoostingWithMomentum::ConvergenceCheck(const Weights &weights,
+                                                    const Weights &gradient) {
+  float error_squared = 0.0f;
+  for (int i = 0; i < gradient.size(); i++) {
+    // For weights > 0, the gradient should be == -l1.
+    if (weights(i) > zero_threshold()) {
+      error_squared += (gradient(i) + l1()) * (gradient(i) + l1());
+    }
+    // For weights < 0, the gradient should be == l1.
+    else if (weights(i) < -zero_threshold()) {
+      error_squared += (gradient(i) - l1()) * (gradient(i) - l1());
+    }
+    // For weights == 0, the gradient should be between -l1 and l1.
+    else {
+      float err = std::max(std::abs(gradient(i)) - l1(), 0.0f);
+      error_squared += err * err;
+    }
+  }
+
+  if (std::sqrt(error_squared) / weights.size() < convergence_threshold()) {
+    set_converged(true);
+  }
+}
+
 void ParallelBoostingWithMomentum::EpochUpdate(Weights *weights, int epoch,
                                                bool check_convergence) {
   // Compute the intermediate weight vector y.
@@ -36,6 +60,7 @@
   // Compute the gradient of the loss wrt y.
   Weights gradient_wrt_y = Weights::Zero(y.size());
   gradient_evaluator().Gradient(y, &gradient_wrt_y);
+
   if (l2() > 0.0f) gradient_wrt_y += l2() * y;
 
   // Gradient step.
@@ -44,18 +69,24 @@
   // l1 shrinkage.
   if (l1() > 0.0f) {
     L1Prox(l1() * learning_rates_, weights);
-    gradient_wrt_y += l1() * weights->unaryExpr(std::ptr_fun(Sign));
   }
 
   // Update the approximation function.
   phi_center_ -= (1.0 - alpha_) / alpha_ * (y - *weights);
   alpha_ =
-      -beta_ / 2.0  +
-      pow(beta_ + beta_ * beta_ / 4.0, static_cast<float>(0.5));
+      -beta_ / 2.0 + pow(beta_ + beta_ * beta_ / 4.0, static_cast<float>(0.5));
   beta_ *= (1.0 - alpha_);
 
-  // Check convergence.
-  if (check_convergence) ConvergenceCheck(*weights, gradient_wrt_y);
+  // Compute the gradient of the objective excluding the l1 part, then
+  // check convergence.
+  if (check_convergence) {
+    Weights gradient_wrt_weights = Weights::Zero(weights->size());
+    gradient_evaluator().Gradient(*weights, &gradient_wrt_weights);
+    if (l2() > 0.0f) {
+      gradient_wrt_weights += l2() * *weights;
+    }
+    ConvergenceCheck(*weights, gradient_wrt_weights);
+  }
 }
 
 }  // namespace lossmin
diff --git a/lossmin/minimizers/parallel-boosting-with-momentum.h b/lossmin/minimizers/parallel-boosting-with-momentum.h
index 619dfe8..11455b3 100644
--- a/lossmin/minimizers/parallel-boosting-with-momentum.h
+++ b/lossmin/minimizers/parallel-boosting-with-momentum.h
@@ -20,8 +20,8 @@
 
 class ParallelBoostingWithMomentum : public LossMinimizer {
  public:
-  ParallelBoostingWithMomentum(
-      float l1, float l2, const GradientEvaluator &gradient_evaluator)
+  ParallelBoostingWithMomentum(float l1, float l2,
+                               const GradientEvaluator &gradient_evaluator)
       : LossMinimizer(l1, l2, gradient_evaluator) {
     Setup();
   }
@@ -29,6 +29,22 @@
   // Sets learning rates and other parameters.
   void Setup() override;
 
+  // Checks convergence by verifying the KKT conditions directly.
+  // |gradient| is the gradient of the objective at |weights|, not including
+  // the contribution of the l1 penalty part of the objective.
+  //
+  // The function checks whether the l2 norm of the KKT violations, divided
+  // by the number of weights, is below the convergence threshold. The KKT
+  // conditions are necessary and sufficient for |weights| to be a minimizer:
+  // If weights_i < 0, then gradient_i == l1().
+  // If weights_i > 0, then gradient_i == -l1().
+  // If weights_i == 0, then -l1() <= gradient_i <= l1().
+  //
+  // The squared violations at each coordinate are summed, and the square
+  // root divided by weights.size() is compared to convergence_threshold().
+  void ConvergenceCheck(const Weights &weights,
+                        const Weights &gradient) override;
+
  private:
   // Updates 'weights' and the quadratic approximation function phi(w), such
   // that at iteration k, loss(weights_k) <= min_w phi_k(w).
@@ -57,4 +73,3 @@
 };
 
 }  // namespace lossmin
-