/*
 * Decompiled with CFR 0.152.
 */
package pattern.model.glm;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import javax.xml.xpath.XPathConstants;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.w3c.dom.Element;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;
import pattern.PMML;
import pattern.PatternException;
import pattern.model.Model;
import pattern.model.glm.LinkFunction;
import pattern.model.glm.PCell;
import pattern.model.glm.PPCell;
import pattern.model.glm.PPMatrix;
import pattern.model.glm.ParamMatrix;
import storm.trident.tuple.TridentTuple;

/**
 * Scores tuples against a PMML {@code GeneralRegressionModel}.
 *
 * <p>The constructor reads the following from the PMML document:
 * <ul>
 *   <li>the mining schema;</li>
 *   <li>the predictor-to-parameter matrix ({@code PPMatrix});</li>
 *   <li>the coefficient matrix ({@code ParamMatrix});</li>
 *   <li>the parameter, covariate and factor name lists;</li>
 *   <li>the {@code linkFunction} attribute.</li>
 * </ul>
 *
 * <p>{@link #classifyTuple(TridentTuple)} sums, over every parameter in the
 * {@code ParamMatrix}, the parameter's coefficient times the product of its PPCell terms.
 * It then passes that sum through the link function.
 */
public class GeneralizedRegressionModel
extends Model
implements Serializable {
    private static final Logger LOG = LoggerFactory.getLogger(GeneralizedRegressionModel.class);
    /** XPath prefix shared by every element this model reads. */
    private static final String GRM_ROOT = "/PMML/GeneralRegressionModel";
    PPMatrix ppmatrix = new PPMatrix();
    ParamMatrix paramMatrix = new ParamMatrix();
    /** Predictor names listed under {@code CovariateList/Predictor}. */
    HashSet<String> covariate = new HashSet<String>();
    /** Predictor names listed under {@code FactorList/Predictor}. */
    HashSet<String> factors = new HashSet<String>();
    /** Parameter names listed under {@code ParameterList/Parameter}. */
    HashSet<String> parameterList = new HashSet<String>();
    LinkFunction linkFunction;

    /**
     * Builds the model from a parsed PMML document.
     *
     * @param pmml the PMML document containing a {@code GeneralRegressionModel}
     * @throws PatternException if the PMML accessors used here report a failure
     */
    public GeneralizedRegressionModel(PMML pmml) throws PatternException {
        this.schema = pmml.getSchema();
        this.schema.parseMiningSchema(pmml.getNodeList(GRM_ROOT + "/MiningSchema/MiningField"));
        this.ppmatrix.parsePPCell(pmml.getNodeList(GRM_ROOT + "/PPMatrix/PPCell"));
        LOG.debug("{}", this.ppmatrix);
        this.paramMatrix.parsePCell(pmml.getNodeList(GRM_ROOT + "/ParamMatrix/PCell"));
        LOG.debug("{}", this.paramMatrix);
        collectNameAttributes(pmml.getNodeList(GRM_ROOT + "/ParameterList/Parameter"), this.parameterList);
        collectNameAttributes(pmml.getNodeList(GRM_ROOT + "/CovariateList/Predictor"), this.covariate);
        collectNameAttributes(pmml.getNodeList(GRM_ROOT + "/FactorList/Predictor"), this.factors);
        String linkFunctionStr = pmml.getReader().read(GRM_ROOT + "/@linkFunction", XPathConstants.STRING).toString();
        this.linkFunction = LinkFunction.getFunction(linkFunctionStr);
    }

    /**
     * Adds the {@code name} attribute of every element node in {@code nodes} to {@code target}.
     * Non-element nodes (text, comments) are skipped.
     */
    private static void collectNameAttributes(NodeList nodes, Set<String> target) {
        for (int i = 0; i < nodes.getLength(); i++) {
            Node child = nodes.item(i);
            if (child.getNodeType() == Node.ELEMENT_NODE) {
                target.add(((Element) child).getAttribute("name"));
            }
        }
    }

    /** Nothing to prepare: all model state is populated by the constructor. */
    @Override
    public void prepare() {
    }

    /**
     * Scores one tuple.
     *
     * <p>Each parameter contributes as follows:
     * <ul>
     *   <li>A parameter with PPCells contributes {@code beta * product(term)}.</li>
     *   <li>For a factor predictor, a term is 1 when the tuple value equals the cell value and 0
     *       otherwise.</li>
     *   <li>For any other predictor, a term is {@code data ^ cellValue}.</li>
     *   <li>A parameter without PPCells contributes {@code beta} alone.</li>
     * </ul>
     *
     * @param values the tuple; predictor values are read by field name as strings
     * @return the link function applied to the summed linear predictor
     * @throws PatternException if a predictor referenced by the PPMatrix has a null value in the
     *     tuple, or a parameter has no PCell
     */
    @Override
    public String classifyTuple(TridentTuple values) throws PatternException {
        double result = 0.0;
        for (String param : this.paramMatrix.keySet()) {
            double beta = firstBeta(param);
            if (this.ppmatrix.containsKey(param)) {
                result += evaluateTerm(param, values) * beta;
            } else {
                // No PPCells for this parameter: it contributes its coefficient alone.
                result += beta;
            }
        }
        String linkResult = this.linkFunction.calc(result);
        LOG.debug("result: {}", linkResult);
        return linkResult;
    }

    /** Returns the coefficient (beta) of the first PCell recorded for {@code param}. */
    private double firstBeta(String param) throws PatternException {
        List<?> pCells = (List<?>) this.paramMatrix.get(param);
        if (pCells == null || pCells.isEmpty()) {
            throw new PatternException("no PCell for parameter " + param);
        }
        PCell first = (PCell) pCells.get(0);
        return Double.parseDouble(first.getBeta());
    }

    /** Multiplies together the per-predictor terms of every PPCell attached to {@code param}. */
    private double evaluateTerm(String param, TridentTuple values) throws PatternException {
        List<?> ppCells = (List<?>) this.ppmatrix.get(param);
        double term = 1.0;
        for (Object cell : ppCells) {
            PPCell pc = (PPCell) cell;
            String predictor = pc.getPredictorName();
            String data = values.getStringByField(predictor);
            if (data == null) {
                throw new PatternException("XML and tuple fields mismatch");
            }
            // Factor membership is keyed by predictor name (FactorList/Predictor/@name),
            // not by the parameter name.
            if (this.factors.contains(predictor)) {
                // The cell value is a category level: indicator of a match.
                term *= pc.getValue().equals(data) ? 1.0 : 0.0;
            } else {
                // The cell value is the exponent applied to the numeric input. It is parsed only
                // here, because factor levels are generally not numeric.
                term *= Math.pow(Double.parseDouble(data), Double.parseDouble(pc.getValue()));
            }
        }
        return term;
    }

    @Override
    public String toString() {
        return "GLM";
    }
}

