/*
 * Decompiled with CFR 0.152.
 */
package us.ihmc.commonWalkingControlModules.modelPredictiveController.core;

import gnu.trove.list.TIntList;
import org.ejml.data.DMatrix;
import org.ejml.data.DMatrixRMaj;
import us.ihmc.commonWalkingControlModules.modelPredictiveController.commands.ForceObjectiveCommand;
import us.ihmc.commonWalkingControlModules.modelPredictiveController.commands.ForceRateTrackingCommand;
import us.ihmc.commonWalkingControlModules.modelPredictiveController.commands.ForceTrackingCommand;
import us.ihmc.commonWalkingControlModules.modelPredictiveController.commands.MPCCommand;
import us.ihmc.commonWalkingControlModules.modelPredictiveController.commands.MPCCommandList;
import us.ihmc.commonWalkingControlModules.modelPredictiveController.commands.MPCContinuityCommand;
import us.ihmc.commonWalkingControlModules.modelPredictiveController.commands.MPCValueCommand;
import us.ihmc.commonWalkingControlModules.modelPredictiveController.commands.NormalForceBoundCommand;
import us.ihmc.commonWalkingControlModules.modelPredictiveController.commands.RhoBoundCommand;
import us.ihmc.commonWalkingControlModules.modelPredictiveController.commands.RhoObjectiveCommand;
import us.ihmc.commonWalkingControlModules.modelPredictiveController.commands.RhoRateTrackingCommand;
import us.ihmc.commonWalkingControlModules.modelPredictiveController.commands.RhoTrackingCommand;
import us.ihmc.commonWalkingControlModules.modelPredictiveController.commands.VRPTrackingCommand;
import us.ihmc.commonWalkingControlModules.modelPredictiveController.core.BlockInverseCalculator;
import us.ihmc.commonWalkingControlModules.modelPredictiveController.core.LinearMPCIndexHandler;
import us.ihmc.commonWalkingControlModules.modelPredictiveController.core.MPCQPInputCalculator;
import us.ihmc.commonWalkingControlModules.modelPredictiveController.core.MPCQPSolver;
import us.ihmc.commonWalkingControlModules.modelPredictiveController.core.RowMajorNativeMatrixGrower;
import us.ihmc.commonWalkingControlModules.momentumBasedController.optimization.NativeQPInputTypeA;
import us.ihmc.commonWalkingControlModules.momentumBasedController.optimization.NativeQPInputTypeC;
import us.ihmc.convexOptimization.quadraticProgram.InverseMatrixCalculator;
import us.ihmc.matrixlib.NativeMatrix;
import us.ihmc.robotics.time.ExecutionTimer;
import us.ihmc.yoVariables.registry.YoRegistry;
import us.ihmc.yoVariables.variable.YoBoolean;
import us.ihmc.yoVariables.variable.YoDouble;
import us.ihmc.yoVariables.variable.YoInteger;

/**
 * Wraps an {@link MPCQPSolver}: converts MPC commands into QP objectives/constraints (via
 * {@link MPCQPInputCalculator}), adds coefficient value/rate regularization, and solves the problem.
 * <p>
 * Usage inferred from the state this class keeps (confirm against callers): call {@link #initialize()},
 * optionally {@link #setPreviousSolution(DMatrixRMaj)}, submit commands, then call {@link #solve()} and read
 * {@link #getSolution()}.
 */
public class LinearMPCQPSolver {
    private static final boolean debug = false;
    /** Number of CoM coefficients per segment; the rho coefficients of a segment follow them in the block layout. */
    private static final int COM_COEFFICIENTS_PER_SEGMENT = 6;
    /** Only used to pre-allocate {@link #solverOutput}; the real size is applied in {@link #initialize()}. */
    private static final int INITIAL_PROBLEM_SIZE = 114;

    protected final YoRegistry registry = new YoRegistry(this.getClass().getSimpleName());
    private final ExecutionTimer qpSolverTimer = new ExecutionTimer("mpcSolverTimer", 0.5, this.registry);
    public final MPCQPSolver qpSolver;
    private final YoBoolean addRateRegularization = new YoBoolean("AddRateRegularization", this.registry);
    private final YoBoolean foundSolution = new YoBoolean("foundSolution", this.registry);
    protected final NativeMatrix previousSolution;
    public final NativeQPInputTypeA qpInputTypeA = new NativeQPInputTypeA(0);
    public final NativeQPInputTypeC qpInputTypeC = new NativeQPInputTypeC(0);
    protected final NativeMatrix solverOutput;
    private final YoInteger numberOfActiveVariables = new YoInteger("numberOfActiveMPCVariables", this.registry);
    private final YoInteger numberOfIterations = new YoInteger("numberOfMPCIterations", this.registry);
    private final YoInteger numberOfEqualityConstraints = new YoInteger("numberOfMPCEqualityConstraints", this.registry);
    private final YoInteger numberOfInequalityConstraints = new YoInteger("numberOfMPCInequalityConstraints", this.registry);
    private final YoInteger numberOfConstraints = new YoInteger("numberOfMPCConstraints", this.registry);
    private final YoDouble comCoefficientRegularization = new YoDouble("comCoefficientRegularization", this.registry);
    private final YoDouble rhoCoefficientRegularization = new YoDouble("rhoCoefficientRegularization", this.registry);
    private final YoDouble comRateCoefficientRegularization = new YoDouble("comRateCoefficientRegularization", this.registry);
    private final YoDouble rhoRateCoefficientRegularization = new YoDouble("rhoRateCoefficientRegularization", this.registry);
    /** Total number of decision variables; set by {@link #initialize()}. */
    private int problemSize;
    /** Set by {@link #notifyResetActiveSet()}, consumed (and cleared) by the next warm-started {@link #solve()}. */
    private boolean resetActiveSet = false;
    private boolean useWarmStart = false;
    private int maxNumberOfIterations = 100;
    private final LinearMPCIndexHandler indexHandler;
    private final MPCQPInputCalculator inputCalculator;
    protected final double dt;
    protected final double dt2;
    private final RowMajorNativeMatrixGrower nativeMatrixGrower = new RowMajorNativeMatrixGrower();

    /**
     * Creates a solver that uses a {@link BlockInverseCalculator} whose blocks start at each segment's CoM
     * coefficient start index and span the segment's CoM plus rho coefficients.
     *
     * @param indexHandler   maps segments to decision-variable indices
     * @param dt             time step used to scale the rate regularization (divided by dt squared)
     * @param gravityZ       vertical gravity, forwarded to the {@link MPCQPInputCalculator}
     * @param parentRegistry registry that this solver's registry is attached to
     */
    public LinearMPCQPSolver(LinearMPCIndexHandler indexHandler, double dt, double gravityZ, YoRegistry parentRegistry) {
        this(indexHandler, dt, gravityZ,
             new BlockInverseCalculator(indexHandler,
                                        indexHandler::getComCoefficientStartIndex,
                                        i -> indexHandler.getRhoCoefficientsInSegment(i) + COM_COEFFICIENTS_PER_SEGMENT),
             parentRegistry);
    }

    /**
     * Creates a solver with an explicit inverse-Hessian calculator.
     *
     * @param indexHandler            maps segments to decision-variable indices
     * @param dt                      time step used to scale the rate regularization (divided by dt squared)
     * @param gravityZ                vertical gravity, forwarded to the {@link MPCQPInputCalculator}
     * @param inverseMatrixCalculator calculator handed to the QP solver; if {@code null}, the solver's own
     *                                default is left in place
     * @param parentRegistry          registry that this solver's registry is attached to
     */
    public LinearMPCQPSolver(LinearMPCIndexHandler indexHandler, double dt, double gravityZ, InverseMatrixCalculator<NativeMatrix> inverseMatrixCalculator, YoRegistry parentRegistry) {
        this.indexHandler = indexHandler;
        this.dt = dt;
        this.dt2 = dt * dt;
        this.rhoCoefficientRegularization.set(1.0E-5);
        this.comCoefficientRegularization.set(1.0E-5);
        this.rhoRateCoefficientRegularization.set(1.0E-10);
        this.comRateCoefficientRegularization.set(1.0E-10);
        this.qpSolver = new MPCQPSolver();
        this.qpSolver.setConvergenceThreshold(5.0E-6);
        this.qpSolver.setConvergenceThresholdForLagrangeMultipliers(1.0E-4);
        if (inverseMatrixCalculator != null) {
            this.qpSolver.setInverseHessianCalculator(inverseMatrixCalculator);
        }
        this.qpSolver.setResetActiveSetOnSizeChange(false);
        this.inputCalculator = new MPCQPInputCalculator(indexHandler, gravityZ);
        this.previousSolution = new NativeMatrix(0, 0);
        // Pre-allocation only: initialize() reshapes this to the actual problem size.
        this.solverOutput = new NativeMatrix(INITIAL_PROBLEM_SIZE, 1);
        parentRegistry.addChild(this.registry);
    }

    /** Sets the weight of the value regularization applied to each segment's CoM coefficients. */
    public void setComCoefficientRegularizationWeight(double weight) {
        this.comCoefficientRegularization.set(weight);
    }

    /** Sets the weight of the value regularization applied to each segment's rho coefficients. */
    public void setRhoCoefficientRegularizationWeight(double weight) {
        this.rhoCoefficientRegularization.set(weight);
    }

    /** Sets the CoM rate regularization weight (divided by dt squared when applied). */
    public void setComRateCoefficientRegularizationWeight(double weight) {
        this.comRateCoefficientRegularization.set(weight);
    }

    /** Sets the rho rate regularization weight (divided by dt squared when applied). */
    public void setRhoRateCoefficientRegularizationWeight(double weight) {
        this.rhoRateCoefficientRegularization.set(weight);
    }

    /** Enables or disables warm starting of the QP solver on the next {@link #solve()}. */
    public void setUseWarmStart(boolean useWarmStart) {
        this.useWarmStart = useWarmStart;
    }

    /** Sets the iteration cap forwarded to the QP solver on the next {@link #solve()}. */
    public void setMaxNumberOfIterations(int maxNumberOfIterations) {
        this.maxNumberOfIterations = maxNumberOfIterations;
    }

    /** Requests that the active set be reset on the next {@link #solve()}; only honored when warm starting. */
    public void notifyResetActiveSet() {
        this.resetActiveSet = true;
    }

    /**
     * Copies the given solution into {@link #previousSolution} and enables rate regularization toward it.
     *
     * @param previousSolution solution vector used as the reference for the rate regularization
     */
    public void setPreviousSolution(DMatrixRMaj previousSolution) {
        this.previousSolution.set(previousSolution);
        this.addRateRegularization.set(true);
    }

    /** Returns the pending active-set reset request and clears it. */
    private boolean pollResetActiveSet() {
        boolean ret = this.resetActiveSet;
        this.resetActiveSet = false;
        return ret;
    }

    /**
     * Sizes the QP inputs, output and solver to the index handler's current total problem size, and disables
     * rate regularization until a previous solution is provided.
     */
    public void initialize() {
        this.problemSize = this.indexHandler.getTotalProblemSize();
        this.qpInputTypeA.setNumberOfVariables(this.problemSize);
        this.qpInputTypeC.setNumberOfVariables(this.problemSize);
        this.solverOutput.reshape(this.problemSize, 1);
        this.resetRateRegularization();
        this.qpSolver.initialize(this.problemSize);
    }

    /** Disables rate regularization for the next {@link #solve()}. */
    public void resetRateRegularization() {
        this.addRateRegularization.set(false);
    }

    /** Adds the value regularization and, when enabled, the rate regularization. */
    private void addCoefficientRegularization() {
        this.addValueRegularization();
        if (this.addRateRegularization.getBooleanValue()) {
            this.addRateRegularization();
        }
    }

    /** Adds value regularization over every segment's CoM and rho coefficient ranges. */
    public void addValueRegularization() {
        for (int segmentId = 0; segmentId < this.indexHandler.getNumberOfSegments(); ++segmentId) {
            int comStart = this.indexHandler.getComCoefficientStartIndex(segmentId);
            this.qpSolver.addRegularization(comStart, COM_COEFFICIENTS_PER_SEGMENT, this.comCoefficientRegularization.getValue());
            int rhoStart = this.indexHandler.getRhoCoefficientStartIndex(segmentId);
            this.qpSolver.addRegularization(rhoStart, this.indexHandler.getRhoCoefficientsInSegment(segmentId), this.rhoCoefficientRegularization.getValue());
        }
    }

    /**
     * Adds rate regularization toward {@link #previousSolution} over every segment's CoM and rho coefficient
     * ranges. Each weight is divided by dt squared.
     * <p>
     * NOTE(review): assumes {@link #previousSolution} has been populated (e.g. via
     * {@link #setPreviousSolution(DMatrixRMaj)}) with at least the problem size.
     */
    public void addRateRegularization() {
        double comCoefficientFactor = this.comRateCoefficientRegularization.getDoubleValue() / this.dt2;
        double rhoCoefficientFactor = this.rhoRateCoefficientRegularization.getDoubleValue() / this.dt2;
        for (int segmentId = 0; segmentId < this.indexHandler.getNumberOfSegments(); ++segmentId) {
            int comStart = this.indexHandler.getComCoefficientStartIndex(segmentId);
            this.qpSolver.addRateRegularization(comStart, COM_COEFFICIENTS_PER_SEGMENT, comCoefficientFactor, (DMatrix)this.previousSolution);
            // Use the index handler for the rho start (as addValueRegularization does) instead of assuming it
            // immediately follows the CoM coefficients.
            int rhoStart = this.indexHandler.getRhoCoefficientStartIndex(segmentId);
            this.qpSolver.addRateRegularization(rhoStart, this.indexHandler.getRhoCoefficientsInSegment(segmentId), rhoCoefficientFactor, (DMatrix)this.previousSolution);
        }
    }

    /** Submits every command in the list, in order. */
    public void submitMPCCommandList(MPCCommandList commandList) {
        for (int i = 0; i < commandList.getNumberOfCommands(); ++i) {
            MPCCommand<?> command = commandList.getCommand(i);
            this.submitMPCCommand(command);
        }
    }

    /**
     * Dispatches a command to the matching {@code submit...} method based on its command type.
     *
     * @throws RuntimeException if the command type is not handled
     */
    public void submitMPCCommand(MPCCommand<?> command) {
        switch (command.getCommandType()) {
            case VALUE:
                this.submitMPCValueObjective((MPCValueCommand)command);
                break;
            case CONTINUITY:
                this.submitContinuityObjective((MPCContinuityCommand)command);
                break;
            case LIST:
                this.submitMPCCommandList((MPCCommandList)command);
                break;
            case RHO_VALUE:
                this.submitRhoValueCommand((RhoObjectiveCommand)command);
                break;
            case VRP_TRACKING:
                this.submitVRPTrackingCommand((VRPTrackingCommand)command);
                break;
            case RHO_BOUND:
                this.submitRhoBoundCommand((RhoBoundCommand)command);
                break;
            case NORMAL_FORCE_BOUND:
                this.submitNormalForceBoundCommand((NormalForceBoundCommand)command);
                break;
            case FORCE_VALUE:
                this.submitForceValueCommand((ForceObjectiveCommand)command);
                break;
            case FORCE_TRACKING:
                this.submitForceTrackingCommand((ForceTrackingCommand)command);
                break;
            case FORCE_RATE_TRACKING:
                this.submitForceRateTrackingCommand((ForceRateTrackingCommand)command);
                break;
            case RHO_TRACKING:
                this.submitRhoTrackingCommand((RhoTrackingCommand)command);
                break;
            case RHO_RATE_TRACKING:
                this.submitRhoRateTrackingCommand((RhoRateTrackingCommand)command);
                break;
            default:
                throw new RuntimeException("The command type: " + command.getCommandType() + " is not handled.");
        }
    }

    /** Builds a compact rho value input and adds it at the returned offset; an offset of -1 means skip. */
    public void submitRhoValueCommand(RhoObjectiveCommand command) {
        int offset = this.inputCalculator.calculateCompactRhoValueCommand(this.qpInputTypeA, command);
        if (offset != -1) {
            this.addInput(this.qpInputTypeA, offset);
        }
    }

    /** Builds a compact value input and adds it at the returned offset; an offset of -1 means skip. */
    public void submitMPCValueObjective(MPCValueCommand command) {
        int offset = this.inputCalculator.calculateCompactValueObjective(this.qpInputTypeA, command);
        if (offset != -1) {
            this.addInput(this.qpInputTypeA, offset);
        }
    }

    /** Builds a continuity input and adds it at the returned offset; an offset of -1 means skip. */
    public void submitContinuityObjective(MPCContinuityCommand command) {
        int offset = this.inputCalculator.calculateContinuityObjective(this.qpInputTypeA, command);
        if (offset != -1) {
            this.addInput(this.qpInputTypeA, offset);
        }
    }

    /** Builds a compact VRP tracking cost and adds it at the returned offset; an offset of -1 means skip. */
    public void submitVRPTrackingCommand(VRPTrackingCommand command) {
        int offset = this.inputCalculator.calculateCompactVRPTrackingObjective(this.qpInputTypeC, command);
        if (offset != -1) {
            this.addInput(this.qpInputTypeC, offset);
        }
    }

    /** Builds a compact rho bound input and adds it with the command's slack weight; -1 means skip. */
    public void submitRhoBoundCommand(RhoBoundCommand command) {
        int offset = this.inputCalculator.calculateRhoBoundCommandCompact(this.qpInputTypeA, command);
        if (offset != -1) {
            this.addInput(this.qpInputTypeA, offset, command.getSlackVariableWeight());
        }
    }

    /** Builds a compact normal force bound input and adds it at the returned offset; -1 means skip. */
    public void submitNormalForceBoundCommand(NormalForceBoundCommand command) {
        int offset = this.inputCalculator.calculateNormalForceBoundCommandCompact(this.qpInputTypeA, command);
        if (offset != -1) {
            this.addInput(this.qpInputTypeA, offset);
        }
    }

    /** Builds a force minimization cost and, on success, adds it at offset 0. */
    public void submitForceValueCommand(ForceObjectiveCommand command) {
        boolean success = this.inputCalculator.calculateForceMinimizationObjective(this.qpInputTypeC, command);
        if (success) {
            this.addInput(this.qpInputTypeC);
        }
    }

    /**
     * Builds a force tracking cost and, unless the calculator returns -1, adds it at offset 0.
     * <p>
     * NOTE(review): unlike the rho tracking submitters, the returned offset is not forwarded to
     * {@link #addInput(NativeQPInputTypeC, int)}. Confirm whether the calculator produces full-size matrices
     * (offset 0 correct) or compact ones (offset should be passed).
     */
    public void submitForceTrackingCommand(ForceTrackingCommand command) {
        int offset = this.inputCalculator.calculateForceTrackingObjective(this.qpInputTypeC, command);
        if (offset != -1) {
            this.addInput(this.qpInputTypeC);
        }
    }

    /**
     * Builds a force rate tracking cost and, unless the calculator returns -1, adds it at offset 0.
     * <p>
     * NOTE(review): the returned offset is not forwarded; see {@link #submitForceTrackingCommand}.
     */
    public void submitForceRateTrackingCommand(ForceRateTrackingCommand command) {
        int offset = this.inputCalculator.calculateForceRateTrackingObjective(this.qpInputTypeC, command);
        if (offset != -1) {
            this.addInput(this.qpInputTypeC);
        }
    }

    /** Builds a rho tracking cost and adds it at the returned offset; an offset of -1 means skip. */
    public void submitRhoTrackingCommand(RhoTrackingCommand command) {
        int offset = this.inputCalculator.calculateRhoTrackingObjective(this.qpInputTypeC, command);
        if (offset != -1) {
            this.addInput(this.qpInputTypeC, offset);
        }
    }

    /** Builds a rho rate tracking cost and adds it at the returned offset; an offset of -1 means skip. */
    public void submitRhoRateTrackingCommand(RhoRateTrackingCommand command) {
        int offset = this.inputCalculator.calculateRhoRateTrackingObjective(this.qpInputTypeC, command);
        if (offset != -1) {
            this.addInput(this.qpInputTypeC, offset);
        }
    }

    /** Adds a type-A input at offset 0 with no slack weight. */
    public void addInput(NativeQPInputTypeA input) {
        this.addInput(input, 0);
    }

    /** Adds a type-A input at the given offset with no slack weight ({@code Double.NaN}). */
    public void addInput(NativeQPInputTypeA input, int offset) {
        this.addInput(input, offset, Double.NaN);
    }

    /**
     * Adds a type-A input to the QP as an objective, equality, or inequality according to its constraint type.
     *
     * @param input               task Jacobian/objective (and weight for objectives)
     * @param offset              variable index at which the input's columns start
     * @param slackVariableWeight forwarded to inequality constraints only
     * @throws RuntimeException if the constraint type is not handled
     */
    public void addInput(NativeQPInputTypeA input, int offset, double slackVariableWeight) {
        switch (input.getConstraintType()) {
            case OBJECTIVE:
                if (input.useWeightScalar()) {
                    this.qpSolver.addObjective(input.taskJacobian, input.taskObjective, input.getWeightScalar(), offset);
                } else {
                    this.qpSolver.addObjective(input.taskJacobian, input.taskObjective, input.getTaskWeightMatrix(), offset);
                }
                break;
            case EQUALITY:
                this.qpSolver.addEqualityConstraint(input.taskJacobian, input.taskObjective, this.problemSize, offset);
                break;
            case LEQ_INEQUALITY:
                this.qpSolver.addMotionLesserOrEqualInequalityConstraint(input.taskJacobian, input.taskObjective, slackVariableWeight, this.problemSize, offset);
                break;
            case GEQ_INEQUALITY:
                this.qpSolver.addMotionGreaterOrEqualInequalityConstraint(input.taskJacobian, input.taskObjective, slackVariableWeight, this.problemSize, offset);
                break;
            default:
                throw new RuntimeException("Unexpected constraint type: " + input.getConstraintType());
        }
    }

    /** Adds a type-C (direct Hessian/gradient) input at offset 0. */
    public void addInput(NativeQPInputTypeC input) {
        this.addInput(input, 0);
    }

    /**
     * Adds a type-C direct cost (Hessian and gradient scaled by the input's scalar weight) at the given offset.
     *
     * @throws IllegalArgumentException if the input does not use a scalar weight (not supported yet)
     */
    public void addInput(NativeQPInputTypeC input, int offset) {
        if (!input.useWeightScalar()) {
            throw new IllegalArgumentException("Not yet implemented.");
        }
        this.qpSolver.addDirectObjective(input.directCostHessian, input.directCostGradient, input.getWeightScalar(), offset);
    }

    /**
     * Adds regularization, records constraint counts, and solves the QP into {@link #solverOutput}.
     * <p>
     * NOTE(review): on success this enables rate regularization but does not copy the output into
     * {@link #previousSolution}. A later solve without {@link #initialize()} in between would regularize toward
     * whatever was last passed to {@link #setPreviousSolution(DMatrixRMaj)}.
     *
     * @return {@code false} if the solution contains NaN (iterations are then reported as -1), {@code true} otherwise
     */
    public boolean solve() {
        this.addCoefficientRegularization();
        this.numberOfEqualityConstraints.set(this.qpSolver.getNumberOfEqualityConstraints());
        this.numberOfInequalityConstraints.set(this.qpSolver.getNumberOfInequalityConstraints());
        this.numberOfConstraints.set(this.numberOfEqualityConstraints.getIntegerValue() + this.numberOfInequalityConstraints.getIntegerValue());
        this.qpSolverTimer.startMeasurement();
        try {
            this.qpSolver.setUseWarmStart(this.useWarmStart);
            this.qpSolver.setMaxNumberOfIterations(this.maxNumberOfIterations);
            // The reset request is only consumed when warm starting; otherwise it stays pending.
            if (this.useWarmStart && this.pollResetActiveSet()) {
                this.qpSolver.resetActiveSet();
            }
            this.numberOfActiveVariables.set(this.problemSize);
            this.numberOfIterations.set(this.qpSolver.solve((DMatrix)this.solverOutput));
        } finally {
            // Always stop the timer, even if the QP solver throws.
            this.qpSolverTimer.stopMeasurement();
        }
        if (this.solverOutput.containsNaN()) {
            this.addRateRegularization.set(false);
            this.numberOfIterations.set(-1);
            this.foundSolution.set(false);
            return false;
        }
        this.foundSolution.set(true);
        this.addRateRegularization.set(true);
        return true;
    }

    /** Returns the internal solution vector (not a copy); its contents are overwritten by the next {@link #solve()}. */
    public NativeMatrix getSolution() {
        return this.solverOutput;
    }

    /** Forwards the active inequality indices to the QP solver. */
    public void setActiveInequalityIndices(TIntList activeInequalityIndices) {
        this.qpSolver.setActiveInequalityIndices(activeInequalityIndices);
    }

    /** Returns the QP solver's active inequality indices. */
    public TIntList getActiveInequalityIndices() {
        return this.qpSolver.getActiveInequalityIndices();
    }
}

