#!/usr/bin/env python3

import argparse
from collections import namedtuple
import ctypes
from io import BytesIO
import json
from math import floor
import logging
from pathlib import Path
import struct
import sys
import tarfile
from tarfile import BLOCKSIZE
from time import time
from typing import List, Tuple, Union, Set

# "<" prefix means little-endian and no alignment
# order is important! if uint64_t is not first, c++ will use padding bytes to unpack
INDEX_BIN_FORMAT = '<QLL'
# packed size in bytes of one index.bin record
INDEX_BIN_SIZE = struct.Struct(INDEX_BIN_FORMAT).size
# name of the index member stored in the tar
INDEX_FILE = "index.bin"
# skip the first 40 bytes of the tile header
GRAPHTILE_SKIP_BYTES = struct.Struct('<Q2f16cQ').size
# packed sizes in bytes used to dimension the fixed-size traffic tiles
TRAFFIC_HEADER_SIZE = struct.Struct('<2Q4I').size
TRAFFIC_SPEED_SIZE = struct.Struct('<Q').size

# axis-aligned bounding box
Bbox = namedtuple("Bbox", ["min_x", "min_y", "max_x", "max_y"])
# edge length of a tile in degrees, keyed by the level taken from the tile path
TILE_SIZES = {0: 4, 1: 1, 2: 0.25}

# hack so ArgumentParser can accept negative numbers
# see https://github.com/valhalla/valhalla/issues/3426
# every argument of the form "-<digit>..." gets a blank prepended in place
for _position, _argument in enumerate(sys.argv):
    if len(_argument) > 1 and _argument[0] == '-' and _argument[1].isdigit():
        sys.argv[_position] = ' ' + _argument


class TileHeader(ctypes.Structure):
    """
    Mirror of the uint64_t bit field found at bytes 40 - 48 of the
    graphtileheader; it is read to obtain the directededgecount_.
    """

    # all members share one unsigned 64-bit storage unit: 21 + 21 + 21 + 1 bits
    _fields_ = [
        (member_name, ctypes.c_ulonglong, bit_width)
        for member_name, bit_width in (
            ("nodecount_", 21),
            ("directededgecount_", 21),
            ("predictedspeeds_count_", 21),
            ("spare1_", 1),
        )
    ]


description = (
    "Builds a tar extract from the tiles in mjolnir.tile_dir "
    "to the path specified in mjolnir.tile_extract."
)

parser = argparse.ArgumentParser(description=description)

# where the Valhalla configuration comes from
parser.add_argument(
    "-c",
    "--config",
    type=Path,
    help="Absolute or relative path to the Valhalla config JSON.",
)
parser.add_argument(
    "-i",
    "--inline-config",
    type=str,
    default='{}',
    help="Inline JSON config, will override --config JSON if present",
)

# optional second archive
parser.add_argument(
    "-t",
    "--with-traffic",
    action="store_true",
    default=False,
    help="Flag to add a traffic.tar skeleton",
)

# optional spatial filter: a bbox and a GeoJSON directory exclude each other
geom_type = parser.add_mutually_exclusive_group()
geom_type.add_argument(
    "-b",
    "--bbox",
    type=str,
    default="",
    help="If specified, will only archive tiles which are intersecting this bbox. In 'minx,miny,maxx,maxy' format.",
)
geom_type.add_argument(
    "-g",
    "--geojson-dir",
    type=Path,
    help="Absolute or relative path to directory with .geojson files used as input to tile intersection. Requires shapely.",
)

# every additional -v increments the counter
parser.add_argument(
    "-v",
    "--verbosity",
    action='count',
    default=0,
    help="Accumulative verbosity flags; -v: INFO, -vv: DEBUG",
)

# set up the logger basics: one stream handler with a timestamped format,
# the level is chosen later from the verbosity flags
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter(fmt="%(asctime)s %(levelname)5s: %(message)s"))
LOGGER = logging.getLogger(__name__)
LOGGER.addHandler(handler)


def get_tile_bbox(tile_path_id: str) -> Tuple[float, float, float, float]:
    """Returns a tile's bounding box in minx,miny,maxx,maxy format

    :param tile_path_id: tile path relative to the tile directory, i.e.
        "<level>/<tile index split over directories>.gph" with "/" separators.
    :returns: (min_x, min_y, max_x, max_y) of the tile in degrees.
    :raises SystemExit: with code 1 if level/tile index can't be parsed from
        the path or the level has no entry in TILE_SIZES.
    """
    try:
        level, tile_idx = [int(x.replace("/", "")) for x in get_tile_level_id(tile_path_id)]
    except ValueError:
        LOGGER.critical(f"Couldn't get level and tile ID for tile {tile_path_id}")
        sys.exit(1)

    # an unknown level would otherwise surface as a bare KeyError traceback
    try:
        tile_size = TILE_SIZES[level]
    except KeyError:
        LOGGER.critical(f"Tile {tile_path_id} has an unsupported level {level}")
        sys.exit(1)

    # tiles are numbered row by row starting at (-180, -90); use integer
    # arithmetic so row/column never suffer from float rounding
    tiles_per_row = round(360 / tile_size)
    row, col = divmod(tile_idx, tiles_per_row)

    tile_base_y = (row * tile_size) - 90
    tile_base_x = (col * tile_size) - 180

    return tile_base_x, tile_base_y, tile_base_x + tile_size, tile_base_y + tile_size


def get_tiles_with_geojson(all_tile_paths: List[Path], geojson_dir: Path, tiles_dir_: Path) -> Set[Path]:
    """Returns all tile paths intersecting with the GeoJSON (multi)polygons

    Only the outer ring of each Polygon / MultiPolygon part is used; other
    geometry types and features with a null geometry are ignored.

    :param all_tile_paths: candidate tile paths, all located below tiles_dir_.
    :param geojson_dir: directory which is searched for "*.geojson" files.
    :param tiles_dir_: root directory the tile paths are made relative to.
    :returns: the subset of all_tile_paths whose bbox intersects any polygon.
    :raises SystemExit: with code 1 if shapely is missing or geojson_dir is
        not a directory / contains no GeoJSON file.
    """
    try:
        from shapely.geometry import Polygon, box
    except ImportError:
        LOGGER.critical(
            "Could not import shapely. Please install shapely or use another download method instead."
        )
        sys.exit(1)

    if not geojson_dir.is_dir() or not any(geojson_dir.glob('*.geojson')):
        LOGGER.critical(
            f"Geojson directory does not exist or contains no GeoJSON files: {geojson_dir.resolve()}"
        )
        sys.exit(1)

    def get_outer_rings(input_dir: Path) -> List[Polygon]:
        """Collects the outer rings of all (multi)polygons as shapely Polygons"""
        polygons = list()
        for file in input_dir.glob("*.geojson"):
            # binary mode: json detects the UTF encoding itself instead of
            # depending on the platform's locale encoding
            with open(file, "rb") as gj_f:
                geojson = json.load(gj_f)
            for feature in geojson["features"]:
                geometry = feature["geometry"]
                # "geometry": null is valid GeoJSON for unlocated features
                if geometry is None:
                    continue
                if geometry["type"] == "Polygon":
                    polygons.append(Polygon(geometry["coordinates"][0]))
                elif geometry["type"] == "MultiPolygon":
                    for single_polygon in geometry["coordinates"]:
                        polygons.append(Polygon(single_polygon[0]))
        return polygons

    tile_paths_ = set()
    geojson_polys = get_outer_rings(geojson_dir)
    for tile_path in all_tile_paths:
        # as_posix() keeps "/" separators on every platform, which is what
        # the tile path parsing expects
        tile_path_id = tile_path.relative_to(tiles_dir_).as_posix()
        tile_bbox = box(*get_tile_bbox(tile_path_id))
        # stop at the first intersecting polygon
        if any(poly.intersects(tile_bbox) for poly in geojson_polys):
            tile_paths_.add(tile_path)

    return tile_paths_


def get_tiles_with_bbox(all_tile_paths: List[Path], bbox_str: str, tiles_dir_: Path) -> Set[Path]:
    """Returns all tile paths intersecting with the bbox

    :param all_tile_paths: candidate tile paths, all located below tiles_dir_.
    :param bbox_str: bounding box as "minx,miny,maxx,maxy".
    :param tiles_dir_: root directory the tile paths are made relative to.
    :returns: the subset of all_tile_paths whose bbox intersects or touches the bbox.
    :raises SystemExit: with code 1 if bbox_str can't be parsed or is invalid.
    """
    try:
        bbox = Bbox(*[float(x) for x in bbox_str.split(",")])
    except (ValueError, TypeError):
        # ValueError: a part is not a number; TypeError: not exactly 4 parts
        LOGGER.critical(f"BBOX {bbox_str} is not a comma-separated string of coordinates.")
        sys.exit(1)

    # validate bbox
    if (bbox.min_x >= bbox.max_x or bbox.min_x < -180 or bbox.max_x > 180) or (
        bbox.min_y >= bbox.max_y or bbox.min_y < -90 or bbox.max_y > 90
    ):
        LOGGER.critical(f"Bbox invalid: {list(bbox)}")
        sys.exit(1)

    tile_paths_ = set()
    for tile_path in all_tile_paths:
        # as_posix() keeps "/" separators on every platform, which is what
        # the tile path parsing expects
        tile_path_id = tile_path.relative_to(tiles_dir_).as_posix()
        tile_bbox = Bbox(*get_tile_bbox(tile_path_id))
        # a tile's min is never greater than its max, so one comparison per
        # side is enough to tell whether it lies completely outside
        is_outside = (
            tile_bbox.max_x < bbox.min_x  # left of bbox
            or tile_bbox.max_y < bbox.min_y  # below bbox
            or tile_bbox.min_x > bbox.max_x  # right of bbox
            or tile_bbox.min_y > bbox.max_y  # above bbox
        )
        if not is_outside:
            tile_paths_.add(tile_path)

    return tile_paths_


def get_tile_level_id(path: str) -> List[str]:
    """Splits a tile path into its level and the remaining tile ID part

    The last 4 characters (the file extension) are dropped, then the rest is
    split at the first "/" only.
    """
    without_extension = path[:-4]
    return without_extension.split('/', maxsplit=1)


def get_tile_id(path: str) -> int:
    """Converts a tile path to its numeric GraphId with the level encoded"""
    level_part, index_part = get_tile_level_id(path)

    # the level occupies the lowest 3 bits, the tile index sits above them
    level_bits = int(level_part)
    index_bits = int(index_part.replace('/', '')) << 3

    return level_bits | index_bits


def get_tar_info(name: str, size: int) -> tarfile.TarInfo:
    """Builds the tarinfo of a regular file member, stamped with the current time"""
    info = tarfile.TarInfo(name)
    info.type = tarfile.REGTYPE
    info.mtime = int(time())
    info.size = size

    return info


def write_index_to_tar(tar_fp_: Path):
    """Collects offset, tile ID and size of every tile member and writes them into index.bin"""
    # first pass: read the member headers to learn where each tile's data lives
    records: List[Tuple[int, int, int]] = []
    with tarfile.open(tar_fp_, 'r|') as archive:
        for entry in archive.getmembers():
            if not entry.name.endswith('.gph'):
                continue
            LOGGER.debug(
                f"Tile {entry.name} with offset: {entry.offset_data}, size: {entry.size}"
            )
            records.append((entry.offset_data, get_tile_id(entry.name), entry.size))

    # second pass: overwrite the placeholder in place; index.bin is the first
    # member, so its data starts right after the first header block
    with open(tar_fp_, 'r+b') as raw_tar:
        raw_tar.seek(BLOCKSIZE)
        for record in records:
            raw_tar.write(struct.pack(INDEX_BIN_FORMAT, *record))


def create_extracts(config_: dict, do_traffic: bool, tile_paths_: Union[Set[Path], List[Path]]):
    """Actually creates the tar ball. Break out of main function for testability.

    Writes the tile extract (an index.bin placeholder followed by the tiles)
    and then fills in the real index. With do_traffic a second tar is written
    which holds the index plus one zero-filled, fixed-size member per tile.

    :param config_: Valhalla config; only the "mjolnir" section is read
        ("tile_dir", "tile_extract", "traffic_extract").
    :param do_traffic: whether to also create the traffic extract.
    :param tile_paths_: tile paths to archive, all located below mjolnir.tile_dir.
        A list is archived in the given order, a set is sorted first.
    :raises SystemExit: with code 1 if tile_paths_ is empty; with code 0 once
        the tile extract is written and no traffic extract is wanted.
    """
    tiles_fp: Path = Path(config_["mjolnir"].get("tile_dir", '/dev/null'))
    extract_fp: Path = Path(
        config_["mjolnir"].get("tile_extract") or tiles_fp.parent.joinpath('tiles.tar')
    )
    traffic_fp: Path = Path(
        config_["mjolnir"].get("traffic_extract") or tiles_fp.parent.joinpath('traffic.tar')
    )

    tiles_count = len(tile_paths_)
    if not tiles_count:
        LOGGER.critical(f"Directory {tiles_fp.resolve()} does not contain any usable graph tiles.")
        sys.exit(1)

    # sets have no defined iteration order: sort them so the archive layout is
    # reproducible, while an explicitly ordered list is kept as passed in
    if isinstance(tile_paths_, (set, frozenset)):
        ordered_tiles = sorted(tile_paths_)
    else:
        ordered_tiles = list(tile_paths_)

    # reserve the space of the index with a zero-filled placeholder; the real
    # records are written by write_index_to_tar once the offsets are known
    index_size = INDEX_BIN_SIZE * tiles_count
    index_fd = BytesIO(bytes(index_size))

    # first add the index file, then the sorted tiles to the tarfile
    # TODO: come up with a smarter strategy to cluster the tiles in the tar
    with tarfile.open(extract_fp, 'w') as tar:
        tar.addfile(get_tar_info(INDEX_FILE, index_size), index_fd)
        for t in ordered_tiles:
            rel_path = str(t.relative_to(tiles_fp))
            LOGGER.debug(f"Adding tile {rel_path} to the path")
            tar.add(str(t.resolve()), arcname=rel_path)

    write_index_to_tar(extract_fp)

    LOGGER.info(f"Finished tarring {tiles_count} tiles to {extract_fp}")

    # exit if no traffic extract wanted
    if not do_traffic:
        index_fd.close()
        sys.exit(0)

    LOGGER.info("Start creating traffic extract...")

    # we already have the right size of the index file, simply reset it
    index_fd.seek(0)
    with tarfile.open(extract_fp) as tar_in, tarfile.open(traffic_fp, 'w') as tar_traffic:
        # this will let us do seeks
        in_fileobj = tar_in.fileobj

        # add the index file as first data
        tar_traffic.addfile(get_tar_info(INDEX_FILE, index_size), index_fd)
        index_fd.close()

        # loop over all routing tiles and create fixed-size traffic tiles
        # based on the directed edge count
        for tile_in in tar_in.getmembers():
            if not tile_in.name.endswith('.gph'):
                continue
            # jump to the data's offset and skip the uninteresting bytes
            in_fileobj.seek(tile_in.offset_data + GRAPHTILE_SKIP_BYTES)

            # read the appropriate size of bytes from the tar into the TileHeader struct
            tile_header = TileHeader()
            b = BytesIO(in_fileobj.read(ctypes.sizeof(TileHeader)))
            b.readinto(tile_header)
            b.close()

            # create the traffic tile: a header plus one speed record per directed edge
            traffic_size = TRAFFIC_HEADER_SIZE + TRAFFIC_SPEED_SIZE * tile_header.directededgecount_
            tar_traffic.addfile(get_tar_info(tile_in.name, traffic_size), BytesIO(b'\0' * traffic_size))

            LOGGER.debug(f"Tile {tile_in.name} has {tile_header.directededgecount_} directed edges")

    write_index_to_tar(traffic_fp)

    LOGGER.info(f"Finished creating the traffic extract at {traffic_fp}")


if __name__ == '__main__':
    # script entry point: load the config, select the tiles, build the extract(s)
    args = parser.parse_args()

    if not args.config and not args.inline_config:
        LOGGER.critical("No valid config file or inline config used.")
        sys.exit(1)

    config = dict()
    if args.config is not None:
        try:
            with open(args.config) as f:
                config = json.load(f)
        except (OSError, ValueError) as e:
            # unreadable file or invalid JSON (JSONDecodeError is a ValueError)
            LOGGER.critical(f"Could not load config file {args.config}: {e}")
            sys.exit(1)
    else:
        LOGGER.warning("Only inline-config will be used.")

    # override with inline-config
    try:
        config.update(**json.loads(args.inline_config))
    except (ValueError, TypeError) as e:
        # invalid JSON, or JSON which is not an object
        LOGGER.critical(f"Inline config is not a valid JSON object: {e}")
        sys.exit(1)

    # --inline-config defaults to '{}', so the check above can't detect a run
    # without any usable configuration; validate the merged result instead
    if not isinstance(config.get("mjolnir"), dict):
        LOGGER.critical("No valid config file or inline config used.")
        sys.exit(1)

    # get and validate the tiles directory
    tiles_dir: Path = Path(config["mjolnir"].get("tile_dir", '/dev/null'))
    if not tiles_dir.is_dir():
        LOGGER.critical(
            f"Directory 'mjolnir.tile_dir': {tiles_dir.resolve()} was not found on the filesystem."
        )
        sys.exit(1)

    # get the tile paths which intersect with the geom, if any
    tile_paths: Union[Set[Path], List[Path]] = sorted(tiles_dir.rglob('*.gph'))
    if args.bbox:
        tile_paths = get_tiles_with_bbox(tile_paths, args.bbox, tiles_dir)
    elif args.geojson_dir:
        tile_paths = get_tiles_with_geojson(tile_paths, args.geojson_dir, tiles_dir)

    # set the right logger level
    if args.verbosity == 0:
        LOGGER.setLevel(logging.CRITICAL)
    elif args.verbosity == 1:
        LOGGER.setLevel(logging.INFO)
    elif args.verbosity >= 2:
        LOGGER.setLevel(logging.DEBUG)

    create_extracts(config, args.with_traffic, tile_paths)
