/*
 * Decompiled with CFR 0.152.
 */
package org.hipparchus.filtering.kalman.unscented;

import org.hipparchus.exception.Localizable;
import org.hipparchus.exception.LocalizedCoreFormats;
import org.hipparchus.exception.MathIllegalArgumentException;
import org.hipparchus.exception.MathRuntimeException;
import org.hipparchus.filtering.kalman.KalmanFilter;
import org.hipparchus.filtering.kalman.Measurement;
import org.hipparchus.filtering.kalman.ProcessEstimate;
import org.hipparchus.filtering.kalman.unscented.UnscentedEvolution;
import org.hipparchus.filtering.kalman.unscented.UnscentedProcess;
import org.hipparchus.linear.MatrixDecomposer;
import org.hipparchus.linear.MatrixUtils;
import org.hipparchus.linear.RealMatrix;
import org.hipparchus.linear.RealVector;
import org.hipparchus.util.UnscentedTransformProvider;

public class UnscentedKalmanFilter<T extends Measurement>
implements KalmanFilter<T> {
    private UnscentedProcess<T> process;
    private ProcessEstimate predicted;
    private ProcessEstimate corrected;
    private final MatrixDecomposer decomposer;
    private final int n;
    private final UnscentedTransformProvider utProvider;

    public UnscentedKalmanFilter(MatrixDecomposer decomposer, UnscentedProcess<T> process, ProcessEstimate initialState, UnscentedTransformProvider utProvider) {
        this.decomposer = decomposer;
        this.process = process;
        this.corrected = initialState;
        this.n = this.corrected.getState().getDimension();
        this.utProvider = utProvider;
        if (this.n == 0) {
            throw new MathIllegalArgumentException((Localizable)LocalizedCoreFormats.ZERO_STATE_SIZE, new Object[0]);
        }
    }

    @Override
    public ProcessEstimate estimationStep(T measurement) throws MathRuntimeException {
        RealVector[] sigmaPoints = this.utProvider.unscentedTransform(this.corrected.getState(), this.corrected.getCovariance());
        return this.predictionAndCorrectionSteps(measurement, sigmaPoints);
    }

    public ProcessEstimate predictionAndCorrectionSteps(T measurement, RealVector[] sigmaPoints) throws MathRuntimeException {
        UnscentedEvolution evolution = this.process.getEvolution(this.getCorrected().getTime(), sigmaPoints, measurement);
        this.predict(evolution.getCurrentTime(), evolution.getCurrentStates(), evolution.getProcessNoiseMatrix());
        RealVector[] predictedSigmaPoints = this.utProvider.unscentedTransform(this.predicted.getState(), this.predicted.getCovariance());
        RealVector[] predictedMeasurements = this.process.getPredictedMeasurements(predictedSigmaPoints, measurement);
        RealVector predictedMeasurement = this.utProvider.getUnscentedMeanState(predictedMeasurements);
        RealMatrix r = this.computeInnovationCovarianceMatrix(predictedMeasurements, predictedMeasurement, measurement.getCovariance());
        RealMatrix crossCovarianceMatrix = this.computeCrossCovarianceMatrix(predictedSigmaPoints, this.predicted.getState(), predictedMeasurements, predictedMeasurement);
        RealVector innovation = r == null ? null : this.process.getInnovation(measurement, predictedMeasurement, this.predicted.getState(), r);
        this.correct(measurement, r, crossCovarianceMatrix, innovation);
        return this.getCorrected();
    }

    private void predict(double time, RealVector[] predictedStates, RealMatrix noise) {
        RealVector predictedState = this.utProvider.getUnscentedMeanState(predictedStates);
        RealMatrix predictedCovariance = this.utProvider.getUnscentedCovariance(predictedStates, predictedState).add(noise);
        this.predicted = new ProcessEstimate(time, predictedState, predictedCovariance);
        this.corrected = null;
    }

    private void correct(T measurement, RealMatrix innovationCovarianceMatrix, RealMatrix crossCovarianceMatrix, RealVector innovation) throws MathIllegalArgumentException {
        if (innovation == null) {
            this.corrected = this.predicted;
            return;
        }
        RealMatrix k = this.decomposer.decompose(innovationCovarianceMatrix).solve(crossCovarianceMatrix.transpose()).transpose();
        RealVector correctedState = this.predicted.getState().add(k.operate(innovation));
        RealMatrix correctedCovariance = this.predicted.getCovariance().subtract(k.multiply(innovationCovarianceMatrix).multiplyTransposed(k));
        this.corrected = new ProcessEstimate(measurement.getTime(), correctedState, correctedCovariance, null, null, innovationCovarianceMatrix, k);
    }

    @Override
    public ProcessEstimate getPredicted() {
        return this.predicted;
    }

    @Override
    public ProcessEstimate getCorrected() {
        return this.corrected;
    }

    public UnscentedTransformProvider getUnscentedTransformProvider() {
        return this.utProvider;
    }

    private RealMatrix computeInnovationCovarianceMatrix(RealVector[] predictedMeasurements, RealVector predictedMeasurement, RealMatrix r) {
        if (predictedMeasurement == null) {
            return null;
        }
        RealMatrix innovationCovarianceMatrix = this.utProvider.getUnscentedCovariance(predictedMeasurements, predictedMeasurement);
        return innovationCovarianceMatrix.add(r);
    }

    private RealMatrix computeCrossCovarianceMatrix(RealVector[] predictedStates, RealVector predictedState, RealVector[] predictedMeasurements, RealVector predictedMeasurement) {
        RealMatrix crossCovarianceMatrix = MatrixUtils.createRealMatrix((int)predictedState.getDimension(), (int)predictedMeasurement.getDimension());
        RealVector wc = this.utProvider.getWc();
        for (int i = 0; i <= 2 * this.n; ++i) {
            RealVector stateDiff = predictedStates[i].subtract(predictedState);
            RealVector measDiff = predictedMeasurements[i].subtract(predictedMeasurement);
            crossCovarianceMatrix = crossCovarianceMatrix.add(this.outer(stateDiff, measDiff).scalarMultiply(wc.getEntry(i)));
        }
        return crossCovarianceMatrix;
    }

    private RealMatrix outer(RealVector a, RealVector b) {
        RealMatrix outMatrix = MatrixUtils.createRealMatrix((int)a.getDimension(), (int)b.getDimension());
        for (int row = 0; row < outMatrix.getRowDimension(); ++row) {
            for (int col = 0; col < outMatrix.getColumnDimension(); ++col) {
                outMatrix.setEntry(row, col, a.getEntry(row) * b.getEntry(col));
            }
        }
        return outMatrix;
    }
}

