package com.whimsy.map.base;


import java.io.InputStream;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Scanner;
import java.util.Set;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.whimsy.map.algo.AStar;

import edu.princeton.cs.algs4.Point2D;
import edu.princeton.cs.algs4.Stopwatch;

import static com.whimsy.map.base.Utils.*;

/**
 * Created by whimsy on 5/27/15.
 */
/**
 * Graph of lat/lon nodes and polyline edges, with a uniform grid index over edge geometry for
 * nearest-edge lookup ({@link #getNearEdges}) and a memoized shortest-path query
 * ({@link #shortestPathLength}) backed by {@link AStar}.
 *
 * <p>Expected call order: {@link #loadNode} and {@link #loadEdge}, then {@link #buildGridIndex}
 * and/or {@link #buildShortestPathAlgorithm}, then queries.
 *
 * <p>Created by whimsy on 5/27/15.
 */
public class Graph {

    static final Logger LOG = LoggerFactory.getLogger(Graph.class);

    // ---- raw data ----

    /** Set once {@link #loadNode} has completed. */
    boolean isNodeLoaded;
    /** Nodes in file order; see {@link #loadNode}. */
    public List<Node> nodes;

    /** Set once {@link #loadEdge} has completed. */
    boolean isEdgeLoaded;
    /** Edges in file order, each carrying its polyline vertices in {@code figures}. */
    public List<Edge> edges;

    // ---- grid index ----

    /** Set once {@link #buildGridIndex} has completed. */
    boolean isGridIndexBuilded = false;

    /** Padded bounding box and cell sizes of the grid; built by {@link #buildArea}. */
    Area area;
    /** Flattened partition x partition grid; cell (x, y) lives at {@code x * partition + y}. */
    Set<Edge>[] grid;

    // ---- shortest path ----

    AStar shortestPathAlgo;

    public Graph() {
    }

    /**
     * Reads nodes from a whitespace-separated stream of {@code id lat lon} records.
     *
     * <p>The id column is consumed but not stored: nodes are appended in file order.
     * NOTE(review): this presumably relies on ids being 0..n-1 in file order so that list
     * position equals id — confirm against the data files.
     *
     * <p>The stream is closed when reading finishes.
     *
     * @param nodeFile stream of node records
     */
    public void loadNode(InputStream nodeFile) {
        LOG.info("Start Loading Node File");
        Stopwatch stopwatch = new Stopwatch();

        List<Node> loaded = Lists.newArrayList();
        Scanner in = new Scanner(nodeFile);
        try {
            while (in.hasNext()) {
                in.nextInt(); // node id: skipped, list position is used instead
                double lat = in.nextDouble();
                double lon = in.nextDouble();
                loaded.add(new Node(lat, lon));
            }
        } finally {
            in.close();
        }

        nodes = loaded;
        isNodeLoaded = true;
        LOG.info("Loaded Node File {} sec", stopwatch.elapsedTime());
    }

    /**
     * Reads edges from a whitespace-separated stream. Each record is
     * {@code edgeId startNodeId endNodeId count} followed by {@code count} {@code lat lon} pairs
     * forming the edge's polyline.
     *
     * <p>The stream is closed when reading finishes, consistent with {@link #loadNode}.
     *
     * @param edgeFile stream of edge records
     */
    public void loadEdge(InputStream edgeFile) {
        LOG.info("Start Loading Edge file");
        Stopwatch stopwatch = new Stopwatch();

        List<Edge> loaded = Lists.newArrayList();
        Scanner in = new Scanner(edgeFile);
        try {
            while (in.hasNext()) {
                Edge edge = new Edge();
                edge.id = in.nextInt();
                edge.sId = in.nextInt();
                edge.eId = in.nextInt();

                int num = in.nextInt();
                for (int i = 0; i < num; ++i) {
                    Edge.Figure figure = new Edge.Figure();
                    figure.lat = in.nextDouble();
                    figure.lon = in.nextDouble();
                    edge.figures.add(figure);
                }

                loaded.add(edge);
            }
        } finally {
            in.close();
        }

        edges = loaded;
        isEdgeLoaded = true;
        LOG.info("Loaded Edge file {} sec", stopwatch.elapsedTime());
    }

    /**
     * Constructs the {@link AStar} instance over the current {@link #nodes} and {@link #edges};
     * {@link #shortestPathLength} delegates uncached queries to it.
     */
    public void buildShortestPathAlgorithm() {
        LOG.debug("Start Build Shortest Path Algorithm");

        Stopwatch stopwatch = new Stopwatch();
        shortestPathAlgo = new AStar(this.nodes, this.edges);
        LOG.debug("Build Shortest Path, Used Time {} sec", stopwatch.elapsedTime());
    }

    /**
     * Computes the bounding box of all edge vertices and stores it, padded and split into
     * {@code partition} steps per axis, in {@link #area}.
     *
     * @param partition number of grid cells along each axis
     */
    void buildArea(int partition) {
        double minlat = Double.POSITIVE_INFINITY;
        double minlon = Double.POSITIVE_INFINITY;
        double maxlat = Double.NEGATIVE_INFINITY;
        double maxlon = Double.NEGATIVE_INFINITY;

        for (Edge edge : edges) {
            for (Edge.Figure figure : edge.figures) {
                minlat = Math.min(figure.lat, minlat);
                maxlat = Math.max(figure.lat, maxlat);
                minlon = Math.min(figure.lon, minlon);
                maxlon = Math.max(figure.lon, maxlon);
            }
        }

        area = new Area();
        area.minlat = minlat;
        area.maxlat = maxlat;
        area.minlon = minlon;
        area.maxlon = maxlon;
        area.partition = partition;

        // pads the box by EPS on every side and derives the per-cell steps
        area.refinement();

        LOG.trace("Bound Calced : minlat = {}, minlon = {}, maxlat = {}, maxlon = {}", minlat, minlon, maxlat, maxlon);
    }

    /**
     * Stub: always returns 0 and ignores its arguments.
     * NOTE(review): appears unused within this class — confirm before relying on it.
     */
    public int gridId(Edge.Figure figure, int size) {
        return 0;
    }

    /**
     * Builds the grid index: every edge is added to each cell containing one of its vertices and
     * to each cell crossed by one of its segments.
     *
     * <p>NOTE(review): an earlier comment suggested {@code partition = 10} as a typical value.
     *
     * @param partition number of grid cells along each axis; must be positive
     * @throws IllegalArgumentException if {@code partition <= 0}
     * @throws IllegalStateException if edges have not been loaded
     */
    @SuppressWarnings("unchecked") // generic array creation for the per-cell sets
    public void buildGridIndex(int partition) {
        if (partition <= 0) {
            throw new IllegalArgumentException("partition must be positive, got " + partition);
        }
        if (edges == null) {
            throw new IllegalStateException("edges must be loaded before building the grid index");
        }

        Stopwatch stopwatch = new Stopwatch();
        LOG.info("Start buildGridIndex partition = {}", partition);

        buildArea(partition);

        // flattened grid: cell (x, y) lives at x * partition + y
        grid = new HashSet[partition * partition];
        for (int i = 0; i < grid.length; ++i) {
            grid[i] = new HashSet<Edge>();
        }

        for (Edge edge : edges) {
            // cells holding a vertex
            for (Edge.Figure figure : edge.figures) {
                grid[area.getPartId(figure.lat, figure.lon)].add(edge);
            }
            // cells crossed by a segment between consecutive vertices
            for (int i = 0; i + 1 < edge.figures.size(); ++i) {
                indexSegment(edge, edge.figures.get(i), edge.figures.get(i + 1));
            }
        }

        isGridIndexBuilded = true;
        LOG.info("Grid Index Builded in {} sec", stopwatch.elapsedTime());
    }

    /**
     * Adds {@code edge} to every grid cell crossed by the segment {@code fig1 -> fig2}.
     *
     * <p>The segment is walked along its dominant axis one grid line at a time. For each strip
     * the segment is intersected with the strip's two bounding lines, which gives the range of
     * cells along the other axis.
     */
    private void indexSegment(Edge edge, Edge.Figure fig1, Edge.Figure fig2) {
        int part1 = area.getPartId(fig1.lat, fig1.lon);
        int part2 = area.getPartId(fig2.lat, fig2.lon);
        if (part1 == part2) {
            // Both ends are in one cell, which the vertex pass already indexed. Skipping also
            // avoids intersecting a zero-length segment when consecutive vertices coincide.
            return;
        }

        int xLo = Math.min(part1 / area.partition, part2 / area.partition);
        int xHi = Math.max(part1 / area.partition, part2 / area.partition);
        int yLo = Math.min(part1 % area.partition, part2 % area.partition);
        int yHi = Math.max(part1 % area.partition, part2 % area.partition);

        Point2D p1 = new Point2D(fig1.lat, fig1.lon);
        Point2D p2 = new Point2D(fig2.lat, fig2.lon);

        if (Math.abs(fig2.lon - fig1.lon) > Math.abs(fig2.lat - fig1.lat)) {
            // lon-dominant: walk columns y, intersecting with constant-lon lines
            for (int y = yLo; y <= yHi; ++y) {
                double lon1 = y * area.lonStep + area.minlon;
                double lon2 = (y + 1) * area.lonStep + area.minlon;

                Point2D a = intersectPoint(p1, p2,
                                           new Point2D(area.minlat - 100, lon1),
                                           new Point2D(area.maxlat + 100, lon1));
                Point2D b = intersectPoint(p1, p2,
                                           new Point2D(area.minlat - 100, lon2),
                                           new Point2D(area.maxlat + 100, lon2));

                // Computed cell ids can fall outside the segment's own bounds; clamp to them.
                int x1 = checkAndRefine(xLo, xHi, area.getPartId(a.x(), a.y()) / area.partition);
                int x2 = checkAndRefine(xLo, xHi, area.getPartId(b.x(), b.y()) / area.partition);

                for (int x = Math.min(x1, x2); x <= Math.max(x1, x2); ++x) {
                    grid[area.index(x, y)].add(edge);
                }
            }
        } else {
            // lat-dominant: walk rows x, intersecting with constant-lat lines spanning all lons
            for (int x = xLo; x <= xHi; ++x) {
                double lat1 = x * area.latStep + area.minlat;
                double lat2 = (x + 1) * area.latStep + area.minlat;

                Point2D a = intersectPoint(p1, p2,
                                           new Point2D(lat1, area.minlon - 100),
                                           new Point2D(lat1, area.maxlon + 100));
                Point2D b = intersectPoint(p1, p2,
                                           new Point2D(lat2, area.minlon - 100),
                                           new Point2D(lat2, area.maxlon + 100));

                // Computed cell ids can fall outside the segment's own bounds; clamp to them.
                int y1 = checkAndRefine(yLo, yHi, area.getPartId(a.x(), a.y()) % area.partition);
                int y2 = checkAndRefine(yLo, yHi, area.getPartId(b.x(), b.y()) % area.partition);

                for (int y = Math.min(y1, y2); y <= Math.max(y1, y2); ++y) {
                    grid[area.index(x, y)].add(edge);
                }
            }
        }
    }

    /** Clamps {@code v} into the inclusive range {@code [b1, b2]}. */
    private int checkAndRefine(int b1, int b2, int v) {
        if (v < b1) return b1;
        if (v > b2) return b2;
        return v;
    }

    // ---- shortest path ----

    /** Memoized shortest-path lengths keyed by (source id, target id). */
    public Map<Pair<Integer, Integer>, Double> shorestPairCache = Maps.newHashMap();

    /**
     * Returns the shortest-path length from {@code sId} to {@code tId}. Results are cached in
     * {@link #shorestPairCache}, and only cache misses are logged.
     *
     * @throws IllegalStateException on a cache miss if {@link #buildShortestPathAlgorithm} has
     *     not been called
     */
    public double shortestPathLength(int sId, int tId) {
        Pair<Integer, Integer> key = new Pair<Integer, Integer>(sId, tId);

        Double cached = shorestPairCache.get(key);
        if (cached != null) {
            return cached;
        }

        if (shortestPathAlgo == null) {
            throw new IllegalStateException("call buildShortestPathAlgorithm() before querying");
        }
        double dist = shortestPathAlgo.query(sId, tId);
        shorestPairCache.put(key, dist);

        LOG.info("Query shortestPath bewteen {} to {}, length = {}", sId, tId, dist);
        return dist;
    }

    private static double sqr(double x) {
        return x * x;
    }

    /**
     * Returns up to {@code numOfCandidateEdges} edges closest to ({@code lat}, {@code lon}),
     * ordered by {@code distM}. Distances within {@code Constant.EPS} of each other count as
     * ties.
     *
     * <p>Square rings of grid cells around the query cell are scanned outward. Rings 0 and 1 are
     * always scanned; after that, scanning stops once enough candidates were gathered or the
     * grid is exhausted.
     *
     * @throws IllegalArgumentException if {@code numOfCandidateEdges < 0}
     * @throws IllegalStateException if the grid index has not been built
     */
    public List<Edge> getNearEdges(final double lat, final double lon, int numOfCandidateEdges) {
        if (numOfCandidateEdges < 0) {
            throw new IllegalArgumentException("numOfCandidateEdges must be >= 0, got " + numOfCandidateEdges);
        }
        if (area == null || grid == null) {
            throw new IllegalStateException("grid index has not been built; call buildGridIndex first");
        }

        // PriorityQueue rejects an initial capacity below 1, so guard the 0-candidate case.
        PriorityQueue<Edge> heap = new PriorityQueue<Edge>(Math.max(1, numOfCandidateEdges * 2), new Comparator<Edge>() {
            @Override
            public int compare(Edge o1, Edge o2) {
                double dist1 = distM(lat, lon, o1);
                double dist2 = distM(lat, lon, o2);

                return (dist1 < dist2 - Constant.EPS) ? -1 :  ((dist1 - Constant.EPS > dist2) ? 1 : 0);
            }
        });

        Set<Edge> hasAdded = new HashSet<Edge>();

        int curGridId = area.getPartId(lat, lon);
        int gx = curGridId / area.partition;
        int gy = curGridId % area.partition;

        for (int step = 0; step < area.partition; ++step) {
            Set<Edge> ring = new HashSet<Edge>();

            // ring cells with y = gy - step or y = gy + step (corners included)
            for (int i = -step; i <= step; ++i) {
                collectCell(ring, gx + i, gy - step);
                collectCell(ring, gx + i, gy + step);
            }
            // ring cells with x = gx - step or x = gx + step (corners excluded)
            for (int i = -step + 1; i <= step - 1; ++i) {
                collectCell(ring, gx - step, gy + i);
                collectCell(ring, gx + step, gy + i);
            }

            for (Edge e : ring) {
                if (hasAdded.add(e)) {
                    heap.add(e);
                }
            }

            // Always scan at least rings 0 and 1 to avoid a naive wrong case (see thesis).
            if (step >= 1 && heap.size() >= numOfCandidateEdges) {
                break;
            }
        }

        if (heap.size() < numOfCandidateEdges) {
            LOG.warn("candidate too few! expected = {}, got {}", numOfCandidateEdges, heap.size());
        }

        List<Edge> resEdge = Lists.newArrayList();
        while (!heap.isEmpty() && resEdge.size() < numOfCandidateEdges) {
            resEdge.add(heap.poll());
        }
        return resEdge;
    }

    /** Adds the edges of grid cell (x, y) to {@code into}; cells outside the grid are ignored. */
    private void collectCell(Set<Edge> into, int x, int y) {
        if (area.isLegal(x) && area.isLegal(y)) {
            into.addAll(grid[area.index(x, y)]);
        }
    }

    /**
     * Bounding box of the grid plus its per-cell step sizes. The x index runs along latitude,
     * the y index along longitude.
     */
    static class Area {

        static Logger LOG = LoggerFactory.getLogger(Area.class);

        public double minlon;
        public double maxlon;
        public double minlat;
        public double maxlat;

        /** Number of cells along each axis. */
        public int partition;

        public double latStep;
        public double lonStep;

        /** Pads the box by {@code Constant.EPS} on every side and computes the cell steps. */
        public void refinement() {
            minlat -= Constant.EPS;
            maxlat += Constant.EPS;
            minlon -= Constant.EPS;
            maxlon += Constant.EPS;

            latStep = (maxlat - minlat) / partition;
            lonStep = (maxlon - minlon) / partition;

            LOG.trace("Area has been refined");
        }

        /**
         * Returns the flattened id {@code latIndex * partition + lonIndex} of the cell holding
         * the point. Because of the {@code EPS} shift, a coordinate exactly on a cell boundary is
         * assigned to the lower cell, i.e. cells are left-open, right-closed intervals. Points
         * outside the box are not range-checked.
         */
        public int getPartId(double lat, double lon) {
            return (int) ((lat - minlat - Constant.EPS) / latStep) * partition +
                       (int) ((lon - minlon - Constant.EPS) / lonStep);
        }

        /** True if {@code x} is a valid cell index along either axis. */
        boolean isLegal(int x) {
            return x >= 0 && x < this.partition;
        }

        /** Flattens cell coordinates (x, y) into a grid array index. */
        int index(int x, int y) {
            return x * partition + y;
        }
    }

    /** Logs the contents of every grid cell at debug level, or an error if the index is missing. */
    public void inspectGrid() {
        if (isGridIndexBuilded) {
            for (int i = 0; i < grid.length; ++i) {
                LOG.debug("{} {}", i, grid[i]);
            }
        } else {
            LOG.error("Gird Index Haven't build");
        }
    }
}
