/*
 * Decompiled with CFR 0.152.
 */
package us.ihmc.commonWalkingControlModules.orientationControl;

import java.util.ArrayList;
import java.util.List;
import org.ejml.data.DMatrix1Row;
import org.ejml.data.DMatrixD1;
import org.ejml.data.DMatrixRMaj;
import org.ejml.dense.row.CommonOps_DDRM;
import us.ihmc.commonWalkingControlModules.orientationControl.VariationalCommonValues;
import us.ihmc.commonWalkingControlModules.orientationControl.VariationalDynamicsCalculator;
import us.ihmc.commonWalkingControlModules.orientationControl.VariationalFunction;
import us.ihmc.commons.lists.RecyclingArrayList;
import us.ihmc.euclid.referenceFrame.interfaces.FrameQuaternionReadOnly;
import us.ihmc.euclid.tuple3D.Vector3D;
import us.ihmc.euclid.tuple3D.interfaces.Tuple3DBasics;
import us.ihmc.euclid.tuple3D.interfaces.Tuple3DReadOnly;
import us.ihmc.euclid.tuple3D.interfaces.Vector3DBasics;
import us.ihmc.euclid.tuple3D.interfaces.Vector3DReadOnly;
import us.ihmc.euclid.tuple4D.interfaces.QuaternionReadOnly;
import us.ihmc.matrixlib.NativeCommonOps;
import us.ihmc.robotics.math.trajectories.generators.MultipleWaypointsOrientationTrajectoryGenerator;
import us.ihmc.robotics.math.trajectories.generators.MultipleWaypointsPositionTrajectoryGenerator;

/**
 * One time segment of a time-varying LQR solution used by the variational orientation controller.
 *
 * <p>{@link #set} integrates the Riccati differential equation
 * {@code Pdot = P B R^-1 B^T P - Q - P A - A^T P} backwards in time with a fixed explicit Euler
 * step of {@code dt}, starting from a terminal matrix {@code PAtEnd}. The matrices {@code A} and
 * {@code B} come from a {@link VariationalDynamicsCalculator} that is re-evaluated along the
 * desired orientation and angular momentum trajectories at every step. The resulting matrices
 * {@code P} and feedback gains {@code K = R^-1 B^T P} are stored in chronological order and
 * linearly interpolated by {@link #compute}.
 *
 * <p>Scratch matrices are shared fields that are mutated on every call, so instances must not be
 * used from several threads concurrently.
 */
public class DifferentialVariationalSegment implements VariationalFunction {
    /** Integration step, and spacing between stored samples, in the same unit as the trajectory times. */
    private final double dt;
    // Scratch buffers. CommonOps_DDRM resizes its outputs, so these initial sizes are only hints.
    private final DMatrixRMaj PB = new DMatrixRMaj(3, 3);
    private final DMatrixRMaj BTransposeP = new DMatrixRMaj(3, 6);
    private final DMatrixRMaj PDot = new DMatrixRMaj(6, 6);
    private final VariationalDynamicsCalculator dynamicsCalculator = new VariationalDynamicsCalculator();
    // Filled backwards in time (index 0 holds the value at endTime).
    private final RecyclingArrayList<DMatrixRMaj> PReverseTrajectory = new RecyclingArrayList<>(() -> new DMatrixRMaj(6, 6));
    private final RecyclingArrayList<DMatrixRMaj> KReverseTrajectory = new RecyclingArrayList<>(() -> new DMatrixRMaj(6, 6));
    // Chronological order (index 0 = earliest sample). These share matrix instances with the reverse lists.
    final List<DMatrixRMaj> PTrajectory = new ArrayList<>();
    final List<DMatrixRMaj> KTrajectory = new ArrayList<>();
    private final Vector3DBasics angularVelocityInBodyFrame = new Vector3D();

    /**
     * @param dt integration step and sample spacing; must be strictly positive
     * @throws IllegalArgumentException if {@code dt} is not strictly positive (or is NaN), which would
     *         otherwise make the backward integration loop in {@link #set} never terminate
     */
    public DifferentialVariationalSegment(double dt) {
        if (!(dt > 0.0)) {
            throw new IllegalArgumentException("dt must be strictly positive, got " + dt);
        }
        this.dt = dt;
    }

    /**
     * Convenience overload that reads {@code Q}, {@code R^-1}, the inertia and its inverse from
     * {@code commonValues}. See the full overload for details.
     */
    public void set(VariationalCommonValues commonValues, MultipleWaypointsOrientationTrajectoryGenerator orientationTrajectory, MultipleWaypointsPositionTrajectoryGenerator angularMomentumTrajectory, DMatrixRMaj PAtEnd, double startTime, double endTime) {
        set(commonValues.getQ(), commonValues.getRInverse(), commonValues.getInertia(), commonValues.getInertiaInverse(), orientationTrajectory, angularMomentumTrajectory, PAtEnd, startTime, endTime);
    }

    /**
     * Integrates the Riccati equation backwards from {@code endTime} towards {@code startTime} and
     * stores the {@code P} and {@code K} samples for later interpolation.
     *
     * @param Q state weight matrix
     * @param RInverse inverse of the input weight matrix
     * @param inertia inertia passed to the dynamics calculator
     * @param inertiaInverse inverse inertia passed to the dynamics calculator
     * @param orientationTrajectory desired orientation trajectory; its {@code compute(time)} is called for every sample
     * @param angularMomentumTrajectory desired angular momentum trajectory; its {@code compute(time)} is called for every sample
     * @param PAtEnd terminal value of {@code P} at {@code endTime}; copied, not retained
     * @param startTime earliest time of the segment
     * @param endTime latest time of the segment
     */
    public void set(DMatrixRMaj Q, DMatrixRMaj RInverse, DMatrixRMaj inertia, DMatrixRMaj inertiaInverse, MultipleWaypointsOrientationTrajectoryGenerator orientationTrajectory, MultipleWaypointsPositionTrajectoryGenerator angularMomentumTrajectory, DMatrixRMaj PAtEnd, double startTime, double endTime) {
        PTrajectory.clear();
        KTrajectory.clear();
        PReverseTrajectory.clear();
        KReverseTrajectory.clear();

        // Terminal condition at endTime.
        PReverseTrajectory.add().set(PAtEnd);
        computeDesireds(endTime, inertia, inertiaInverse, orientationTrajectory, angularMomentumTrajectory);
        computeGainMatrix(dynamicsCalculator.getB(), PAtEnd, RInverse, KReverseTrajectory.add());

        // Explicit Euler step backwards in time: P(t) = P(t + dt) - dt * Pdot, where Pdot is evaluated
        // with P(t + dt) and with A, B computed at t. The dt / 10 margin absorbs rounding error in the
        // loop variable. NOTE(review): it also means startTime itself is never sampled, so index 0 holds
        // the value at roughly startTime + dt when the duration is a multiple of dt. Confirm this
        // matches how callers measure timeInState.
        for (double time = endTime - dt; time >= startTime + dt / 10.0; time -= dt) {
            computeDesireds(time, inertia, inertiaInverse, orientationTrajectory, angularMomentumTrajectory);
            DMatrixRMaj previousP = PReverseTrajectory.getLast();
            DMatrixRMaj newP = PReverseTrajectory.add();
            CommonOps_DDRM.mult(previousP, dynamicsCalculator.getB(), PB);
            computePDot(Q, PB, RInverse, previousP, dynamicsCalculator.getA());
            CommonOps_DDRM.add(previousP, -dt, PDot, newP);
            computeGainMatrix(dynamicsCalculator.getB(), newP, RInverse, KReverseTrajectory.add());
        }

        // Expose the samples in chronological order.
        for (int i = PReverseTrajectory.size() - 1; i >= 0; --i) {
            PTrajectory.add(PReverseTrajectory.get(i));
            KTrajectory.add(KReverseTrajectory.get(i));
        }
    }

    /**
     * Evaluates both desired trajectories at {@code time} and updates {@link #dynamicsCalculator},
     * so that its {@code getA()}/{@code getB()} reflect that instant.
     */
    private void computeDesireds(double time, DMatrixRMaj inertia, DMatrixRMaj inertiaInverse, MultipleWaypointsOrientationTrajectoryGenerator orientationTrajectory, MultipleWaypointsPositionTrajectoryGenerator angularMomentumTrajectory) {
        orientationTrajectory.compute(time);
        angularMomentumTrajectory.compute(time);
        FrameQuaternionReadOnly desiredOrientation = orientationTrajectory.getOrientation();
        angularVelocityInBodyFrame.set(orientationTrajectory.getAngularVelocity());
        // NOTE(review): transform() applies the desired rotation to the vector. If the trajectory's angular
        // velocity is expressed in world frame, expressing it in body frame would need inverseTransform().
        // Confirm the frame convention expected by VariationalDynamicsCalculator.
        desiredOrientation.transform(angularVelocityInBodyFrame);
        dynamicsCalculator.compute(desiredOrientation, angularVelocityInBodyFrame, angularMomentumTrajectory.getVelocity(), inertia, inertiaInverse);
    }

    /**
     * Computes the LQR feedback gain {@code K = R^-1 B^T P}.
     *
     * <p>{@code B^T} (not {@code B}) is required here to be consistent with the
     * {@code P B R^-1 B^T P} term of the Riccati equation assembled in {@link #computePDot}.
     */
    private void computeGainMatrix(DMatrixRMaj B, DMatrixRMaj P, DMatrixRMaj RInverse, DMatrixRMaj KMatrixToPack) {
        CommonOps_DDRM.multTransA(B, P, BTransposeP);
        CommonOps_DDRM.mult(RInverse, BTransposeP, KMatrixToPack);
    }

    /**
     * Packs the {@code P} and {@code K} matrices at {@code timeInState}, where {@code timeInState = 0}
     * corresponds to stored sample 0.
     *
     * <p>Times between samples are linearly interpolated. Times before the first sample hold the first
     * sample; times at or after the last sample hold the last sample.
     *
     * @throws IllegalStateException if {@link #set} has not been called yet
     */
    @Override
    public void compute(double timeInState, DMatrixRMaj PToPack, DMatrixRMaj KToPack) {
        if (PTrajectory.isEmpty()) {
            throw new IllegalStateException("set() must be called before compute().");
        }
        int lastIndex = PTrajectory.size() - 1;
        int startIndex = getStartIndex(timeInState);
        if (startIndex < 0) {
            packSample(0, PToPack, KToPack);
            return;
        }
        if (startIndex >= lastIndex) {
            packSample(lastIndex, PToPack, KToPack);
            return;
        }
        double alpha = getAlphaBetweenSegments(timeInState);
        interpolate(PTrajectory.get(startIndex), PTrajectory.get(startIndex + 1), alpha, PToPack);
        interpolate(KTrajectory.get(startIndex), KTrajectory.get(startIndex + 1), alpha, KToPack);
    }

    /** Copies the stored sample at {@code index} into the output matrices. */
    private void packSample(int index, DMatrixRMaj PToPack, DMatrixRMaj KToPack) {
        PToPack.set(PTrajectory.get(index));
        KToPack.set(KTrajectory.get(index));
    }

    /**
     * Returns the index of the stored sample at or immediately before {@code timeInState}.
     *
     * <p>A small tolerance is added so that a time landing on a sample boundary, up to rounding error,
     * maps to that sample. NOTE(review): the tolerance {@code dt / 10} is added in index units rather
     * than time units, so its effective size scales with {@code dt}. Confirm whether
     * {@code (timeInState + dt / 10) / dt} was intended.
     */
    int getStartIndex(double timeInState) {
        return (int) Math.floor(timeInState / dt + dt / 10.0);
    }

    /**
     * Returns the interpolation fraction in [0, 1] between sample {@link #getStartIndex} and the next one.
     *
     * <p>It is derived from the same start index so that the two stay consistent at sample boundaries.
     * Otherwise a time just below {@code k * dt} would get index {@code k} but a fraction close to 1.
     */
    double getAlphaBetweenSegments(double timeInState) {
        double alpha = timeInState / dt - getStartIndex(timeInState);
        return Math.min(1.0, Math.max(0.0, alpha));
    }

    /** Writes {@code (1 - alpha) * start + alpha * end} into {@code ret}. */
    private static void interpolate(DMatrixRMaj start, DMatrixRMaj end, double alpha, DMatrixRMaj ret) {
        CommonOps_DDRM.scale(1.0 - alpha, start, ret);
        CommonOps_DDRM.addEquals(ret, alpha, end);
    }

    /** Computes the Riccati derivative into the {@link #PDot} scratch buffer. */
    private void computePDot(DMatrixRMaj Q, DMatrixRMaj PB, DMatrixRMaj RInverse, DMatrixRMaj P, DMatrixRMaj A) {
        computePDot(Q, PB, RInverse, P, A, PDot);
    }

    /**
     * Computes {@code Pdot = (PB) R^-1 (PB)^T - Q - P A - A^T P}.
     *
     * <p>The first term assumes {@code NativeCommonOps.multQuad(X, Y, Z)} computes {@code Z = X Y X^T}.
     *
     * @param PB precomputed product {@code P * B}
     */
    static void computePDot(DMatrixRMaj Q, DMatrixRMaj PB, DMatrixRMaj RInverse, DMatrixRMaj P, DMatrixRMaj A, DMatrixRMaj PDotToPack) {
        NativeCommonOps.multQuad(PB, RInverse, PDotToPack);
        CommonOps_DDRM.addEquals(PDotToPack, -1.0, Q);
        CommonOps_DDRM.multAdd(-1.0, P, A, PDotToPack);
        CommonOps_DDRM.multAddTransA(-1.0, A, P, PDotToPack);
    }
}

