/*
 * Copyright (C) 2003, 2004 Bjørn-Ove Heimsund
 * 
 * This file is part of SMT.
 * 
 * This library is free software; you can redistribute it and/or modify it
 * under the terms of the GNU Lesser General Public License as published by the
 * Free Software Foundation; either version 2.1 of the License, or (at your
 * option) any later version.
 * 
 * This library is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License
 * for more details.
 * 
 * You should have received a copy of the GNU Lesser General Public License
 * along with this library; if not, write to the Free Software Foundation,
 * Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
 */

package smt.test;

import junit.framework.TestCase;

import mt.DenseMatrix;
import mt.LowerSPDDenseMatrix;
import mt.Matrix;
import mt.MatrixEntry;
import mt.Vector;
import mt.fact.DenseCholesky;
import mt.fact.DenseLU;
import mt.test.Utilities;
import mt.util.Matrices;
import smt.CompRowMatrix;
import smt.iter.IterativeSolverNotConvergedException;
import smt.iter.mixed.MixedSolver;

/**
 * Test of the mixed solvers
 */
public abstract class MixedSolverTest extends TestCase {

    /**
     * Max sizes
     */
    protected final int max = 50, bmax = 10;

    /**
     * Numerical comparison tolerance
     */
    protected final double tol = 1e-4;

    /**
     * Diagonal shift for singularity handling
     */
    protected double shift = 100;

    /**
     * System matrices
     */
    protected Matrix A, B, Bt, C;

    /**
     * Solution and right hand side vectors
     */
    protected Vector u, q, f, g;

    /**
     * Solver to test
     */
    protected MixedSolver solver;

    /**
     * The whole system matrix (A, B and C)
     */
    protected Matrix D;

    /**
     * Solution arrays
     */
    protected double[] ud, qd;

    /**
     * Constructor for MixedSolverTest
     */
    public MixedSolverTest(String arg0) {
        super(arg0);
    }

    protected void setUp() throws Exception {
        int n = Utilities.getInt(max), m = Utilities.getInt(max), b = Utilities
                .getInt(bmax);

        A = new CompRowMatrix(n, n, b);
        B = new CompRowMatrix(m, n, b);
        Bt = new CompRowMatrix(n, m, b);
        C = new CompRowMatrix(m, m, b);

        Utilities.rowPopulate(B, b);
        Bt = B.transpose(Bt);

        Utilities.rowPopulate(A, b);
        Utilities.rowPopulate(C, b);

        // Symmetrize A
        for (MatrixEntry e : A)
            if (e.column() > e.row())
                e.set(A.get(e.row(), e.column()));

        // Make it positive definite
        A.addDiagonal(shift);
        DenseCholesky dc = new DenseCholesky(new LowerSPDDenseMatrix(A));
        while (!dc.isSPD()) {
            A.addDiagonal(shift);
            dc = new DenseCholesky(new LowerSPDDenseMatrix(A));
        }

        // Copy the whole thing into D
        D = new CompRowMatrix(n + m, n + m, 2 * b);
        for (MatrixEntry e : A)
            D.set(e.row(), e.column(), e.get());
        for (MatrixEntry e : B)
            D.set(e.row() + n, e.column(), e.get());
        for (MatrixEntry e : Bt)
            D.set(e.row(), e.column() + n, e.get());
        for (MatrixEntry e : C)
            D.set(e.row() + n, e.column() + n, e.get());

        // Ensure the whole matrix is non-singular
        A.addDiagonal(shift);
        C.addDiagonal(shift);
        D.addDiagonal(shift);
        DenseLU lu = new DenseLU(new DenseMatrix(D));
        while (lu.isSingular()) {
            A.addDiagonal(shift);
            C.addDiagonal(shift);
            D.addDiagonal(shift);
            lu = new DenseLU(new DenseMatrix(D));
        }

        q = Matrices.random(n);
        u = Matrices.random(m);
        f = Matrices.random(n);
        g = Matrices.random(m);

        createSolver();

        // Compute the correct right hand sides
        B.transMultAdd(u, A.mult(q, f));
        C.multAdd(u, B.mult(q, g));

        // Store for later comparisons
        ud = Matrices.getArray(u);
        qd = Matrices.getArray(q);

        // Randomize the inital solution vectors
        Matrices.random(u);
        Matrices.random(q);
    }

    protected abstract void createSolver() throws Exception;

    protected void tearDown() throws Exception {
        A = B = Bt = C = null;
        f = g = q = u = null;
        qd = ud = null;
        solver = null;
    }

    public void testSolve() {
        try {
            solver.solve(A, B, Bt, C, q, u, f, g);

            int n = q.size(), m = u.size();
            for (int i = 0; i < n; ++i)
                assertEquals(qd[i], q.get(i), tol);
            for (int i = 0; i < m; ++i)
                assertEquals(ud[i], u.get(i), tol);
        } catch (IterativeSolverNotConvergedException e) {
            fail("Solver did not converge: " + e.getReason() + ". Residual="
                    + e.getResidual());
        }
    }

}
