#!/usr/bin/env python3
import json
import argparse
import importlib
from types import NoneType
from typing import List, Dict, Union, Tuple, Optional, Any

import rclpy
import setproctitle
from rclpy.executors import MultiThreadedExecutor
from rclpy.utilities import try_shutdown

from agents import config as all_configs
from agents import components as all_components
from agents import clients
from agents.clients.model_base import ModelClient
from agents.ros import (
    Topic,
    FixedInput,
    MapLayer,
    Route,
    QoSConfig,
    SupportedType,
    Event,
)


def _parse_args() -> Tuple[argparse.Namespace, List[str]]:
    """Parse arguments."""
    parser = argparse.ArgumentParser(description="Component Executable Config")
    parser.add_argument(
        "--config_type", type=str, help="Component configuration class name"
    )
    parser.add_argument("--component_type", type=str, help="Component class name")
    parser.add_argument(
        "--node_name",
        type=str,
        help="Component ROS2 node name",
    )
    parser.add_argument("--config", type=str, help="Component configuration object")
    parser.add_argument(
        "--inputs",
        type=str,
        help="Component input topics",
    )
    parser.add_argument(
        "--outputs",
        type=str,
        help="Component output topics",
    )
    parser.add_argument(
        "--routes",
        type=str,
        help="Semantic router routes",
    )
    parser.add_argument(
        "--layers",
        type=str,
        help="Map Encoding layers",
    )
    parser.add_argument(
        "--trigger",
        type=str,
        help="Component trigger",
    )
    parser.add_argument(
        "--model_client",
        type=str,
        help="Model Client",
    )
    parser.add_argument(
        "--db_client",
        type=str,
        help="DB Client",
    )
    parser.add_argument(
        "--additional_model_clients",
        type=str,
        help="Additional model clients",
    )
    parser.add_argument(
        "--config_file", type=str, help="Path to configuration YAML file"
    )
    parser.add_argument(
        "--events", type=str, help="Events to be monitored by the component"
    )
    parser.add_argument(
        "--actions", type=str, help="Actions associated with the component Events"
    )
    parser.add_argument(
        "--fallbacks", type=str, help="Fallbacks to be executed on component Failure"
    )
    parser.add_argument(
        "--external_processors",
        type=str,
        help="External processors associated with the component input and output topics",
    )
    parser.add_argument(
        "--additional_types",
        type=str,
        help="Additional type modules from derived packages",
    )

    return parser.parse_known_args()


def _parse_component_config(
    args: argparse.Namespace,
) -> all_configs.BaseComponentConfig:
    """Parse the component config object.

    :param args: Command line arguments. ``args.config_type`` is the name of a
        config class exported by ``agents.config``. ``args.config`` is a JSON
        object whose entries are passed to that class as keyword arguments;
        when it is absent the class is constructed with no arguments.
    :type args: argparse.Namespace

    :raises ValueError: If ``config_type`` is not provided
    :raises TypeError: If ``config_type`` does not name a known config class

    :return: Component config object
    :rtype: all_configs.BaseComponentConfig
    """
    config_type = args.config_type or None
    if not config_type:
        raise ValueError("config_type must be provided")

    # Default of None so an unknown name reaches the descriptive TypeError
    # below instead of escaping as a bare AttributeError
    config_class = getattr(all_configs, config_type, None)
    if not config_class:
        raise TypeError(
            f"Unknown config_type '{config_type}'. Known types are {all_configs.__all__}"
        )

    # json.loads(None) would raise; a missing --config means "no overrides"
    config_kwargs = json.loads(args.config) if args.config else {}

    return config_class(**config_kwargs)


def _parse_additional_types(value: str):
    """Get additional types"""
    serialized_types = json.loads(value)
    _additional_types = []
    for s_t in serialized_types:
        module_name, _, class_name = s_t.rpartition(".")
        if not module_name:
            continue
        module = importlib.import_module(module_name)
        new_type = getattr(module, class_name)
        if issubclass(new_type, SupportedType):
            _additional_types.append(new_type)
    return _additional_types


def _parse_trigger(
    trigger_str: Optional[str],
) -> Union[Topic, List[Topic], float, Event, NoneType]:
    """Parse component trigger json string

    :param trigger_str: Trigger JSON string. It is either a JSON number/``null``,
        or a JSON object with a ``trigger_type`` (``"List"``, ``"Topic"`` or
        ``"Event"``) and a JSON-encoded ``trigger`` payload. ``None`` or an
        empty string (trigger not given on the command line) yields ``None``.
    :type trigger_str: str | None

    :return: Trigger topic(s), event or float; ``None`` when no trigger is
        given or the ``trigger_type`` is not recognized
    :rtype: Topic | List[Topic] | float | Event | None
    """
    # TODO: Handle additional types here

    # json.loads(None) raises, and --trigger is optional on the command line
    if not trigger_str:
        return None

    # Deserialize main dict or float value
    trigger_deserialized = json.loads(trigger_str)
    if not isinstance(trigger_deserialized, dict):
        # return float or None
        return trigger_deserialized

    # Deserialize internal trigger content; the payload is only read inside
    # the matching branch so an unknown type without a payload still yields None
    trigger_type = trigger_deserialized["trigger_type"]
    if trigger_type == "List":
        # List always contains topics
        return [
            Topic(**json.loads(t))
            for t in json.loads(trigger_deserialized["trigger"])
        ]
    if trigger_type == "Topic":
        return Topic(**json.loads(trigger_deserialized["trigger"]))
    if trigger_type == "Event":
        return Event.from_json(trigger_deserialized["trigger"])

    # Unrecognized trigger_type
    return None


def _deserialize_topics(
    serialized_topics: str, additional_types: Optional[List] = None
) -> List[Dict]:
    list_of_str = json.loads(serialized_topics)
    topic_dicts = []
    for t in list_of_str:
        topic_dict = json.loads(t)
        topic_dict["qos_profile"] = QoSConfig(**topic_dict.get("qos_profile", {}))
        topic_dict["additional_types"] = (
            additional_types if additional_types else []
        )  # Add any additional types
        topic_dicts.append(topic_dict)
    return topic_dicts


def _load_primary_clients(
    args: argparse.Namespace,
) -> Tuple[Optional[Any], Optional[Any]]:
    """Instantiates Model and DB clients."""
    model_client = None
    db_client = None

    if args.model_client:
        mc_json = json.loads(args.model_client)
        model_client = getattr(clients, mc_json["client_type"])(**mc_json)

    if args.db_client:
        dbc_json = json.loads(args.db_client)
        db_client = getattr(clients, dbc_json["client_type"])(**dbc_json)

    return model_client, db_client


def _load_additional_model_clients(
    additional_clients_json: str,
) -> Dict[str, ModelClient]:
    """Instantiate every additional model client described in a JSON mapping.

    :param additional_clients_json: JSON object mapping a name to a client
        spec. Each spec's ``client_type`` names a class in ``agents.clients``,
        and the whole spec is passed to that class as keyword arguments.
    :type additional_clients_json: str

    :return: The same names, in the same order, mapped to client instances
    :rtype: Dict[str, ModelClient]
    """
    client_specs = json.loads(additional_clients_json)
    return {
        name: getattr(clients, spec["client_type"])(**spec)
        for name, spec in client_specs.items()
    }


def _parse_ros_args(args_names: List[str]) -> List[str]:
    """Parse ROS arguments from command line arguments

    :param args_names: List of all parsed arguments
    :type args_names: list[str]

    :return: List ROS parsed arguments
    :rtype: list[str]
    """
    # Look for --ros-args in ros_args
    ros_args_start = None
    if "--ros-args" in args_names:
        ros_args_start = args_names.index("--ros-args")

    if ros_args_start is not None:
        ros_specific_args = args_names[ros_args_start:]
    else:
        ros_specific_args = []
    return ros_specific_args


def _setup_component_post_init(component: Any, args: argparse.Namespace) -> None:
    """Perform post-initialization setup on the component instance.

    :param component: The instantiated component object
    :param args: Parsed command line arguments
    """
    # Init the node with rclpy
    component.rclpy_init_node()

    # Set events/actions
    if args.events and args.actions:
        component._events_json = args.events
        component._actions_json = args.actions

    # Set fallbacks
    if fallbacks_json := args.fallbacks:
        component._fallbacks_json = fallbacks_json

    # Set external processors
    if args.external_processors:
        component._external_processors_json = args.external_processors

    # Set additional model clients if any
    component.additional_model_clients = (
        _load_additional_model_clients(args.additional_model_clients)
        if args.additional_model_clients
        else None
    )


def _create_component(args: argparse.Namespace) -> Any:
    """Validate the launch arguments and instantiate the requested component.

    Also sets the process title to the node name.

    :param args: Parsed command line arguments
    :type args: argparse.Namespace
    :raises ValueError: If component_type or node_name is missing, or
        component_type is not exported by ``agents.components``
    :raises RuntimeError: If the SemanticRouter gets neither a model nor a DB
        client, or MapEncoding gets no DB client
    :return: The constructed (not yet post-initialized) component
    """
    component_type = args.component_type or None
    if not component_type:
        raise ValueError("Cannot launch without providing a component_type")

    # Default of None so an unknown name reaches the descriptive ValueError
    # below instead of escaping as a bare AttributeError
    comp_class = getattr(all_components, component_type, None)
    if not comp_class:
        raise ValueError(
            f"Cannot launch unknown component type '{component_type}'. Known types are: '{all_components.__all__}'"
        )

    # Get name
    component_name = args.node_name or None
    if not component_name:
        raise ValueError("Cannot launch component without specifying a name")

    # SET PROCESS NAME
    setproctitle.setproctitle(component_name)

    config = _parse_component_config(args)

    # Get Yaml config file if provided
    config_file = args.config_file or None

    additional_types = (
        _parse_additional_types(args.additional_types)
        if args.additional_types
        else None
    )

    # Get inputs/outputs/layers/routes
    inputs = (
        [
            FixedInput(**i) if i.get("fixed") else Topic(**i)
            for i in _deserialize_topics(args.inputs, additional_types)
        ]
        if args.inputs
        else None
    )
    outputs = (
        [Topic(**o) for o in _deserialize_topics(args.outputs, additional_types)]
        if args.outputs
        else None
    )

    # TODO: Handle additional types and qos in deserialization
    layers = (
        [MapLayer(**json.loads(i)) for i in json.loads(args.layers)]
        if args.layers
        else None
    )

    # TODO: Handle additional types and qos in deserialization
    routes = (
        [Route(**json.loads(r)) for r in json.loads(args.routes)]
        if args.routes
        else None
    )

    # --trigger is optional; parsing None would fail in json.loads
    trigger = _parse_trigger(args.trigger) if args.trigger else None

    # Initialize clients
    model_client, db_client = _load_primary_clients(args)

    # Semantic Router Component
    if component_type == all_components.SemanticRouter.__name__:
        if not model_client and not db_client:
            raise RuntimeError(
                "The router component expects at least one client (model or vectorDB)"
            )
        # The default route is not passed here: it is already part of the
        # config and will be set from there
        return comp_class(
            inputs=inputs,
            routes=routes,
            db_client=db_client,
            model_client=model_client,
            config=config,
            component_name=component_name,
            config_file=config_file,
        )

    # Map Encoding Component
    if component_type == all_components.MapEncoding.__name__:
        # Reuse the DB client built by _load_primary_clients. Re-parsing
        # args.db_client would create a second client, and would raise a
        # TypeError before this check when --db_client is absent.
        if not db_client:
            raise RuntimeError("The map encoding component expects a vectorDB client")
        return comp_class(
            layers=layers,
            position=config._position,
            map_topic=config._map_topic,
            db_client=db_client,
            config=config,
            trigger=trigger,
            component_name=component_name,
            config_file=config_file,
        )

    # All other components
    return comp_class(
        inputs=inputs,
        outputs=outputs,
        model_client=model_client,
        db_client=db_client,
        trigger=trigger,
        config=config,
        component_name=component_name,
        config_file=config_file,
    )


def main():
    """Executable main function to run a component as a ROS2 node in a new process.
    Used to start a node using Sugarcoat Launcher. Extends functionality from ROS Sugar

    Parses the command line, initializes rclpy with the ``--ros-args`` part of
    it, and builds and post-initializes the requested component. The
    component is then spun on a ``MultiThreadedExecutor`` until interrupted.
    The rclpy context is shut down on exit, including when building the
    component fails.

    :raises ValueError: If component_type/node_name is missing or
        component_type is an unknown class
    :raises RuntimeError: If a required model/DB client is missing
    """
    args, args_names = _parse_args()

    # Initialize rclpy with the ros-specific arguments
    rclpy.init(args=_parse_ros_args(args_names))

    try:
        component = _create_component(args)
        # Run Post-Init Setup
        _setup_component_post_init(component, args)
    except BaseException:
        # Do not leave the rclpy context initialized when setup fails
        try_shutdown()
        raise

    executor = MultiThreadedExecutor()
    executor.add_node(component)

    try:
        executor.spin()

    except KeyboardInterrupt:
        pass

    finally:
        executor.remove_node(component)
        try_shutdown()


if __name__ == "__main__":
    main()
