/*
 * Decompiled with CFR 0.152.
 */
package hex.optimization;

import hex.optimization.L_BFGS;
import hex.optimization.OptimizationUtils;
import water.H2O;
import water.MemoryManager;
import water.util.ArrayUtils;
import water.util.Log;
import water.util.MathUtils;

public class ADMM {

    /**
     * Soft-thresholding operator.
     *
     * <p>Returns {@code 0} when {@code |x| <= kappa}, otherwise {@code sign(x) * (|x| - kappa)}.
     * A value of {@code -0.0} is treated as non-negative.
     *
     * @param x     value to shrink
     * @param kappa shrinkage threshold
     * @return the soft-thresholded value
     */
    public static double shrinkage(double x, double kappa) {
        double sign = x < 0.0 ? -1.0 : 1.0;
        double absX = x * sign;
        return absX <= kappa ? 0.0 : sign * (absX - kappa);
    }

    /**
     * Replaces {@code grad} in place with a subgradient of {@code f(beta) + lambda * ||beta||_1}.
     *
     * <p>For a non-zero coefficient the L1 term contributes {@code sign(beta) * lambda}. The result
     * is then shrunk by a small tolerance ({@code lambda * 1e-4}). For a zero coefficient the L1
     * subdifferential is {@code [-lambda, lambda]}, so the gradient is soft-thresholded by
     * {@code lambda}. The last entry of {@code grad} is never modified. Presumably it is the
     * (unpenalized) intercept; TODO confirm against callers.
     *
     * @param lambda L1 penalty strength
     * @param beta   current coefficients; when {@code null} the method is a no-op
     * @param grad   gradient of the smooth part, overwritten with the subgradient
     */
    public static void subgrad(double lambda, double[] beta, double[] grad) {
        if (beta == null) {
            return;
        }
        final double tol = lambda * 1.0E-4;
        for (int i = 0; i < grad.length - 1; ++i) {
            if (beta[i] < 0.0) {
                grad[i] = ADMM.shrinkage(grad[i] - lambda, tol);
            } else if (beta[i] > 0.0) {
                grad[i] = ADMM.shrinkage(grad[i] + lambda, tol);
            } else {
                grad[i] = ADMM.shrinkage(grad[i], lambda);
            }
        }
    }

    /**
     * ADMM driver for problems with an L1 penalty and/or box constraints.
     *
     * <p>The smooth sub-problem (the "x" update) is delegated to a {@link ProximalSolver}. This
     * class performs the "z" update: soft-thresholding followed by projection onto
     * {@code [lb, ub]}. It also performs the scaled dual ("u") update and the stopping test.
     */
    public static class L1Solver {
        final double RELTOL;
        final double ABSTOL;
        /** Norm of the (projected) subgradient computed by the last call to {@code computeErr}. */
        double gerr;
        /** Iteration count recorded by the last call to {@code solve}. */
        int iter;
        /** Convergence threshold on {@link #gerr}. */
        final double _eps;
        final int max_iter;
        /** Norm used to turn the subgradient vector into the scalar {@link #gerr}. */
        MathUtils.Norm _gradientNorm = MathUtils.Norm.L_Infinite;
        /** Scaled dual variable. Allocated lazily by {@code solve} and reused (warm start) across calls. */
        public double[] _u;
        public static double DEFAULT_RELTOL = 0.01;
        public static double DEFAULT_ABSTOL = 1.0E-4;
        /** Optional progress callback, invoked every 5th iteration. */
        public L_BFGS.ProgressMonitor _pm;

        /**
         * Selects the norm used to measure the subgradient error.
         *
         * @param n norm to use
         * @return this solver, for chaining
         */
        public L1Solver setGradientNorm(MathUtils.Norm n) {
            this._gradientNorm = n;
            return this;
        }

        /**
         * Creates a solver with the default relative and absolute tolerances.
         *
         * @param eps      convergence threshold on the subgradient error
         * @param max_iter maximum number of ADMM iterations
         * @param u        initial scaled dual variable, or {@code null} to start from zero
         */
        public L1Solver(double eps, int max_iter, double[] u) {
            this(eps, max_iter, DEFAULT_RELTOL, DEFAULT_ABSTOL, u);
        }

        /**
         * Creates a solver with explicit tolerances.
         *
         * @param eps      convergence threshold on the subgradient error
         * @param max_iter maximum number of ADMM iterations
         * @param reltol   relative tolerance of the primal/dual residual test
         * @param abstol   absolute tolerance of the residual test (scaled by sqrt(N) in {@code solve})
         * @param u        initial scaled dual variable, or {@code null} to start from zero
         */
        public L1Solver(double eps, int max_iter, double reltol, double abstol, double[] u) {
            this._eps = eps;
            this.max_iter = max_iter;
            this._u = u;
            this.RELTOL = reltol;
            this.ABSTOL = abstol;
        }

        /**
         * Solves without box constraints.
         *
         * @see #solve(ProximalSolver, double[], double, boolean, double[], double[])
         */
        public boolean solve(ProximalSolver solver, double[] res, double lambda, boolean hasIntercept) {
            return this.solve(solver, res, lambda, hasIntercept, null, null);
        }

        /**
         * Computes the norm of the subgradient at {@code z}, stores it in {@link #gerr} and returns it.
         *
         * <p>Before the L1 subgradient is formed, a gradient component is overwritten with
         * {@code z[j] >= 0 ? -lambda : lambda} in two cases: when the coefficient sits on its lower
         * bound and the gradient pushes it further down, or when it sits on its upper bound and the
         * gradient pushes it further up. The caller's {@code grad} array is not modified.
         */
        private double computeErr(double[] z, double[] grad, double lambda, double[] lb, double[] ub) {
            final double[] g = grad.clone();
            this.gerr = 0.0;
            if (lb != null) {
                for (int j = 0; j < z.length; ++j) {
                    if (z[j] == lb[j] && g[j] > 0.0) {
                        g[j] = z[j] >= 0.0 ? -lambda : lambda;
                    }
                }
            }
            if (ub != null) {
                for (int j = 0; j < z.length; ++j) {
                    if (z[j] == ub[j] && g[j] < 0.0) {
                        g[j] = z[j] >= 0.0 ? -lambda : lambda;
                    }
                }
            }
            ADMM.subgrad(lambda, z, g);
            switch (this._gradientNorm) {
                case L_Infinite:
                    this.gerr = ArrayUtils.linfnorm(g, false);
                    break;
                case L2_2:
                    this.gerr = ArrayUtils.l2norm2(g, false);
                    break;
                case L2:
                    this.gerr = Math.sqrt(ArrayUtils.l2norm2(g, false));
                    break;
                case L1:
                    this.gerr = ArrayUtils.l1norm(g, false);
                    break;
                default:
                    throw H2O.unimpl();
            }
            return this.gerr;
        }

        /**
         * Runs ADMM until the subgradient error drops to {@code _eps}, the iteration limit is hit,
         * or the proximal solver reports failure.
         *
         * <p>If there is no L1 penalty and no bounds, the proximal solver is called once with a
         * {@code null} proximal point and the method returns {@code true}. In that case the
         * solver's own return value is ignored.
         *
         * @param solver       proximal solver for the smooth sub-problem (the x update)
         * @param z            initial coefficients; overwritten in place with the solution
         * @param l1pen        L1 penalty strength
         * @param hasIntercept if true, the last coefficient is only projected onto its bounds, never shrunk
         * @param lb           per-coefficient lower bounds, or {@code null}
         * @param ub           per-coefficient upper bounds, or {@code null}
         * @return {@code true} if the subgradient error reached {@code _eps} (or on the trivial path),
         *         {@code false} otherwise
         */
        public boolean solve(ProximalSolver solver, double[] z, double l1pen, boolean hasIntercept, double[] lb, double[] ub) {
            this.gerr = Double.POSITIVE_INFINITY;
            this.iter = 0;
            if (l1pen == 0.0 && lb == null && ub == null) {
                solver.solve(null, z);
                return true;
            }
            final int hasIcpt = hasIntercept ? 1 : 0;
            final int N = z.length;
            double abstol = this.ABSTOL * Math.sqrt(N);
            final double[] rho = solver.rho();
            final double[] x = z.clone();
            // Proximal point z - u handed to the x update.
            final double[] beta_given = MemoryManager.malloc8d(N);
            final double[] u;
            if (this._u != null) {
                // Warm start from the dual variable kept from a previous call.
                u = this._u;
                for (int i = 0; i < beta_given.length - hasIcpt; ++i) {
                    beta_given[i] = z[i] - u[i];
                }
            } else {
                u = this._u = MemoryManager.malloc8d(N);
            }
            // Per-coefficient soft-threshold l1pen / rho (0 where rho == 0); the intercept is never shrunk.
            final double[] kappa = MemoryManager.malloc8d(rho.length);
            if (l1pen > 0.0) {
                for (int i = 0; i < N - hasIcpt; ++i) {
                    kappa[i] = rho[i] != 0.0 ? l1pen / rho[i] : 0.0;
                }
            }
            // Over-relaxation factor: x_hat = orlx * x + (1 - orlx) * z; 1.0 means no relaxation.
            final double orlx = 1.0;
            double reltol = this.RELTOL;
            int i;
            for (i = 0; i < this.max_iter && solver.solve(beta_given, x); ++i) {
                if (this._pm != null && (i + 1) % 5 == 0) {
                    this._pm.progress(z, solver.gradient(z));
                }
                // NOTE(review): these accumulate *squared* norms, which are then compared against
                // linear tolerances below. Kept as-is to preserve convergence behaviour.
                double rnorm = 0.0;
                double snorm = 0.0;
                double unorm = 0.0;
                double xnorm = 0.0;
                for (int j = 0; j < N - hasIcpt; ++j) {
                    final double xj = x[j];
                    final double zjold = z[j];
                    final double x_hat = xj * orlx + (1.0 - orlx) * zjold;
                    // z update: soft-threshold, then project onto [lb, ub].
                    double zj = ADMM.shrinkage(x_hat + u[j], kappa[j]);
                    if (lb != null && zj < lb[j]) {
                        zj = lb[j];
                    }
                    if (ub != null && zj > ub[j]) {
                        zj = ub[j];
                    }
                    u[j] += x_hat - zj;
                    beta_given[j] = zj - u[j];
                    final double r = xj - zj;      // primal residual component
                    final double s = zj - zjold;   // change in z (dual residual component)
                    rnorm += r * r;
                    snorm += s * s;
                    xnorm += xj * xj;
                    unorm += rho[j] * rho[j] * u[j] * u[j];
                    z[j] = zj;
                }
                if (hasIntercept) {
                    // The intercept is only clamped to its bounds, never shrunk.
                    final int idx = x.length - 1;
                    double icpt = x[idx];
                    if (lb != null && icpt < lb[idx]) {
                        icpt = lb[idx];
                    }
                    if (ub != null && icpt > ub[idx]) {
                        icpt = ub[idx];
                    }
                    final double r = x[idx] - icpt;
                    final double s = icpt - z[idx];
                    u[idx] += r;
                    beta_given[idx] = icpt - u[idx];
                    rnorm += r * r;
                    snorm += s * s;
                    xnorm += icpt * icpt;
                    unorm += rho[idx] * rho[idx] * u[idx] * u[idx];
                    z[idx] = icpt;
                }
                final boolean primalOk = rnorm < abstol + reltol * Math.sqrt(xnorm);
                final boolean dualOk = snorm < abstol + reltol * Math.sqrt(unorm);
                if (!(primalOk && dualOk)) {
                    continue;
                }
                // Residuals are small; confirm with the (more expensive) subgradient test.
                final double oldGerr = this.gerr;
                this.computeErr(z, solver.gradient(z)._gradient, l1pen, lb, ub);
                if (this.gerr > this._eps) {
                    Log.debug("ADMM.L1Solver: iter = " + i + " , gerr =  " + this.gerr + ", oldGerr = " + oldGerr + ", rnorm = " + rnorm + ", snorm  " + snorm);
                    // Not converged: tighten the residual tolerances and keep iterating.
                    if (abstol > 1.0E-12) {
                        abstol *= 0.1;
                    }
                    // NOTE(review): reltol is reduced twice here, so the 1e-10 floor is not actually
                    // enforced. Kept as-is to preserve convergence behaviour.
                    if (reltol > 1.0E-10) {
                        reltol *= 0.1;
                    }
                    reltol *= 0.1;
                    continue;
                }
                this.iter = i;
                if (this._pm != null && (i + 1) % 5 == 0) {
                    this._pm.progress(z, solver.gradient(z));
                }
                return true;
            }
            this.computeErr(z, solver.gradient(z)._gradient, l1pen, lb, ub);
            // Use the loop counter: this.iter is still 0 here, so testing it never detected the limit.
            if (i >= this.max_iter) {
                Log.warn("ADMM solver reached maximum number of iterations (" + this.max_iter + ")");
            } else {
                Log.warn("ADMM solver stopped after " + i + " iterations. (max_iter=" + this.max_iter + ")");
            }
            if (this.gerr > this._eps) {
                Log.warn("ADMM solver finished with gerr = " + this.gerr + " >  eps = " + this._eps);
            }
            this.iter = this.max_iter;
            if (this._pm != null && (i + 1) % 5 == 0) {
                this._pm.progress(z, solver.gradient(z));
            }
            return false;
        }

        @Override
        public String toString() {
            return "iter = " + this.iter + ", gerr = " + this.gerr;
        }

        /**
         * Heuristic per-coefficient estimate of the ADMM penalty parameter rho.
         *
         * <p>With an L1 penalty and a non-zero {@code x}, rho is 0.25 times the positive root of
         * {@code |x| * r^2 - l1pen * r - l1pen = 0} (0 when that root is not positive). If either
         * bound is finite, this estimate is replaced by the following:
         * <ul>
         *   <li>{@code 10}, when {@code x} lies within {@code 1e-4} of a bound or outside the bounds;</li>
         *   <li>{@code 0.1}, otherwise.</li>
         * </ul>
         *
         * @param x     value the estimate is based on; an infinite value yields 0
         * @param l1pen L1 penalty strength
         * @param lb    lower bound ({@code -Infinity} if unbounded)
         * @param ub    upper bound ({@code +Infinity} if unbounded)
         * @return the rho estimate
         */
        public static double estimateRho(double x, double l1pen, double lb, double ub) {
            if (Double.isInfinite(x)) {
                return 0.0;
            }
            double rho = 0.0;
            if (l1pen != 0.0 && x != 0.0) {
                if (x > 0.0) {
                    final double d = l1pen * (l1pen + 4.0 * x);
                    if (d >= 0.0) {
                        final double r = (l1pen + Math.sqrt(d)) / (2.0 * x);
                        if (r > 0.0) {
                            rho = r;
                        } else {
                            Log.warn("negative rho estimate(1)! r = " + r);
                        }
                    }
                } else if (x < 0.0) {
                    final double d = l1pen * (l1pen - 4.0 * x);
                    if (d >= 0.0) {
                        final double r = -(l1pen + Math.sqrt(d)) / (2.0 * x);
                        if (r > 0.0) {
                            rho = r;
                        } else {
                            Log.warn("negative rho estimate(2)!  r = " + r);
                        }
                    }
                }
                rho *= 0.25;
            }
            if (!Double.isInfinite(lb) || !Double.isInfinite(ub)) {
                final boolean nearOrOutOfBounds = -Math.min(x - lb, ub - x) > -1.0E-4;
                rho = nearOrOutOfBounds ? 10.0 : 0.1;
            }
            return rho;
        }
    }

    /**
     * Solver for the smooth (x-update) sub-problem of {@link L1Solver}.
     */
    public static interface ProximalSolver {
        /** Per-coefficient ADMM penalty parameters; also used to derive the soft-threshold. */
        public double[] rho();

        /**
         * Solves the proximal sub-problem.
         *
         * @param beta_given proximal point {@code z - u}, or {@code null} when there is no
         *                   penalty/constraint (presumably a plain solve — TODO confirm)
         * @param result     initial guess, overwritten with the solution
         * @return {@code false} to make the ADMM loop stop
         */
        public boolean solve(double[] beta_given, double[] result);

        public boolean hasGradient();

        /** Gradient of the smooth objective at {@code beta}. */
        public OptimizationUtils.GradientInfo gradient(double[] beta);

        public int iter();
    }
}

