package org.logolith.kzgo;

import java.math.BigInteger;

/**
 * Performs polynomial interpolation over a finite field using BigIntegers.
 *
 * <p>The evaluation points are fixed to {@code x = 0, 1, ..., degree}. All arithmetic is carried
 * out modulo the supplied modulus, and every returned coefficient lies in {@code [0, modulus)}.
 */
public class PolynomialInterpolator {

    /**
     * Interpolates a polynomial P(x) given its evaluations P(0), P(1), ..., P(degree).
     * Uses Lagrange interpolation: P(x) = sum_{j=0}^{degree} evaluations[j] * L_j(x), where
     * L_j(x) = M(x) / ((x - j) * w_j), M(x) = product_{i=0}^{degree} (x - i) and
     * w_j = product_{i != j} (j - i).
     *
     * <p>M(x) is built once and each numerator M(x) / (x - j) is obtained by synthetic division,
     * so the whole interpolation takes O(n^2) BigInteger operations rather than O(n^3).
     *
     * @param evaluations Array of evaluations P(0), P(1), ... P(degree); values may be negative or
     *     unreduced, they are taken modulo {@code modulus}.
     * @param modulus The field modulus. It must be positive and coprime to every difference
     *     {@code j - i} that is actually needed (a prime larger than {@code degree} always works).
     * @return The coefficients of the interpolated polynomial [c0, c1, ..., c_degree], each in
     *     {@code [0, modulus)}. An empty input yields an empty array.
     * @throws ArithmeticException if {@code modulus} is not positive, or if a required Lagrange
     *     denominator has no inverse modulo {@code modulus}.
     */
    public static BigInteger[] interpolate(BigInteger[] evaluations, BigInteger modulus) {
        int n = evaluations.length;
        if (n == 0) {
            return new BigInteger[0];
        }
        if (modulus.signum() <= 0) {
            throw new ArithmeticException("modulus must be positive: " + modulus);
        }
        int degree = n - 1;

        BigInteger[] resultCoeffs = new BigInteger[n];
        for (int k = 0; k < n; k++) {
            resultCoeffs[k] = BigInteger.ZERO;
        }

        // M(x) is the same for every basis polynomial, so it is computed once (O(n^2)).
        BigInteger[] vanishing = vanishingPolynomial(degree, modulus);

        for (int j = 0; j < n; j++) {
            // Reduce first so values congruent to 0 (e.g. modulus itself) are skipped too; their
            // contribution is zero and inverting their denominator could fail needlessly.
            BigInteger y = evaluations[j].mod(modulus);
            if (y.signum() == 0) {
                continue;
            }

            BigInteger[] numerator = divideByLinearFactor(vanishing, j, modulus);
            BigInteger invDenominator =
                    invertDenominator(lagrangeDenominator(j, degree, modulus), j, modulus);
            BigInteger scale = y.multiply(invDenominator).mod(modulus);

            // Add y_j * L_j(x) to the result polynomial.
            for (int k = 0; k < n; k++) {
                resultCoeffs[k] = resultCoeffs[k].add(scale.multiply(numerator[k])).mod(modulus);
            }
        }

        return resultCoeffs;
    }

    /**
     * Computes the coefficients of M(x) = product_{i=0}^{degree} (x - i) modulo {@code modulus}.
     *
     * @return coefficients [m_0, ..., m_{degree+1}] in ascending order of power, each reduced
     */
    private static BigInteger[] vanishingPolynomial(int degree, BigInteger modulus) {
        BigInteger[] coeffs = new BigInteger[degree + 2];
        coeffs[0] = BigInteger.ONE;
        for (int k = 1; k < coeffs.length; k++) {
            coeffs[k] = BigInteger.ZERO;
        }
        // Multiply in one factor (x - i) at a time; before step i the polynomial has degree i.
        for (int i = 0; i <= degree; i++) {
            BigInteger root = BigInteger.valueOf(i);
            // New coefficient of x^k is a_{k-1} - i * a_k. Walking from the top keeps a_{k-1}
            // unmodified when it is read.
            for (int k = i + 1; k > 0; k--) {
                coeffs[k] = coeffs[k - 1].subtract(root.multiply(coeffs[k])).mod(modulus);
            }
            coeffs[0] = coeffs[0].multiply(root).negate().mod(modulus);
        }
        return coeffs;
    }

    /**
     * Divides {@code poly} by (x - root) using synthetic division. {@code root} must be a root of
     * {@code poly} (true for M(x) and any j in 0..degree), so the remainder is zero and dropped.
     *
     * @return quotient coefficients in ascending order of power, one fewer than {@code poly}
     */
    private static BigInteger[] divideByLinearFactor(BigInteger[] poly, int root, BigInteger modulus) {
        int quotientDegree = poly.length - 2;
        BigInteger r = BigInteger.valueOf(root);
        BigInteger[] quotient = new BigInteger[quotientDegree + 1];
        quotient[quotientDegree] = poly[quotientDegree + 1];
        for (int k = quotientDegree - 1; k >= 0; k--) {
            quotient[k] = poly[k + 1].add(r.multiply(quotient[k + 1])).mod(modulus);
        }
        return quotient;
    }

    /** Computes w_j = product_{i=0, i!=j}^{degree} (j - i) modulo {@code modulus}. */
    private static BigInteger lagrangeDenominator(int j, int degree, BigInteger modulus) {
        BigInteger product = BigInteger.ONE;
        for (int i = 0; i <= degree; i++) {
            if (i != j) {
                product = product.multiply(BigInteger.valueOf(j - i)).mod(modulus);
            }
        }
        return product;
    }

    /** Inverts w_j modulo {@code modulus}, failing with a descriptive ArithmeticException. */
    private static BigInteger invertDenominator(BigInteger denominator, int j, BigInteger modulus) {
        if (!denominator.gcd(modulus).equals(BigInteger.ONE)) {
            throw new ArithmeticException("Lagrange denominator for point x = " + j
                    + " is not invertible modulo " + modulus
                    + " (modulus must be coprime to every difference j - i)");
        }
        return denominator.modInverse(modulus);
    }

}