#include "manager.h"
#include "graph_executor.h"
#include <iostream>
#include <vector>
#include <string>
#include <stdexcept>
#include <iomanip> // For std::setw
#include <chrono>  // For high-resolution timing
#include <cmath>   // For pow

/**
 * @brief Elapsed time between two clock readings, in whole milliseconds.
 *
 * @param start Earlier time point.
 * @param end   Later time point.
 * @return (end - start) converted to milliseconds; the fractional part is
 *         discarded by std::chrono::duration_cast.
 */
long long get_time_ms(const std::chrono::high_resolution_clock::time_point& start,
                      const std::chrono::high_resolution_clock::time_point& end) {
    using std::chrono::milliseconds;
    const auto elapsed = end - start;
    const milliseconds whole_ms = std::chrono::duration_cast<milliseconds>(elapsed);
    return whole_ms.count();
}
/**
 * @brief Elapsed time between two clock readings, in whole microseconds.
 *
 * @param start Earlier time point.
 * @param end   Later time point.
 * @return (end - start) converted to microseconds; the fractional part is
 *         discarded by std::chrono::duration_cast.
 */
long long get_time_us(const std::chrono::high_resolution_clock::time_point& start,
                      const std::chrono::high_resolution_clock::time_point& end) {
    using std::chrono::microseconds;
    const auto elapsed = end - start;
    const microseconds whole_us = std::chrono::duration_cast<microseconds>(elapsed);
    return whole_us.count();
}

/**
 * @brief Creates a simple graph of N 'add' operations.
 */
std::vector<GraphInstruction> create_simple_graph(int num_ops) {
    std::vector<GraphInstruction> graph;
    graph.reserve(num_ops);
    for (int i = 0; i < num_ops; ++i) {
        graph.push_back(GraphInstruction{
            "add", {"A", "B"}, {TensorSpec{"A", {1}}}, {}
        });
    }
    return graph;
}


/**
 * @brief Entry point for the CudaManager / GraphExecutor scaling benchmark.
 *
 * Allocates two single-element tensors "A" and "B", runs a one-instruction
 * warm-up graph, then prints two timing tables:
 *   - Test 1: 1000 samples against graphs of 10, 20, ..., 10240 instructions
 *     (the op count doubles while it is <= 12800).
 *   - Test 2: a fixed 100-instruction graph run for 1000, 2000, ..., 2048000
 *     samples (the sample count doubles while it is <= 2048000).
 *
 * NOTE(review): the timings assume executor.run() returns only after the work
 * it started has finished; if it is asynchronous the numbers would only cover
 * launch overhead — confirm against GraphExecutor.
 *
 * @return 0 on success, 1 if a std::exception is caught.
 */
int main() {
    try {
        std::cout << "--- CudaManager & GraphExecutor Scaling Benchmark ---" << std::endl;

        // 1. Initialize
        CudaManager manager;
        GraphExecutor executor(manager);

        // 2. Allocate persistent tensors for benchmarks
        manager.allocate("A", {1});
        manager.allocate("B", {1});

        // 3. Warm-up: Run a small operation to initialize CUDA context
        std::cout << "1. Performing CUDA warm-up..." << std::endl;
        // Host staging buffers live on the stack for the whole try-block, so
        // nothing is leaked (the previous `new float[1]` was never deleted).
        float host_a[1] = {1.0f};
        float host_b[1] = {2.0f};
        manager.copyHostToDevice("A", host_a);
        manager.copyHostToDevice("B", host_b);
        executor.load_graph(create_simple_graph(1));
        executor.run(1);
        std::cout << "   Warm-up complete." << std::endl;

        // std::fixed / std::setprecision are sticky stream state. Set them once
        // here so every floating-point cell, including the very first row,
        // is printed with two decimals.
        std::cout << std::fixed << std::setprecision(2);

        // ---
        // Test 1: Scaling Model Size (Graph Depth)
        // ---
        std::cout << "\n--- Test 1: Scaling Model Size (Fixed Data Size) ---" << std::endl;
        std::cout << "Running 1000 samples against an increasingly deep graph..." << std::endl;
        std::cout << std::setw(12) << "Graph Ops" << " | "
                  << std::setw(15) << "Total Time (ms)" << " | "
                  << std::setw(20) << "Avg Time per Op (us)" << std::endl;
        std::cout << "------------------------------------------------------------------" << std::endl;

        const int num_samples_fixed = 1000;
        const int max_num_ops = 12800; // Test up to a very deep graph
        for (int num_ops = 10; num_ops <= max_num_ops; num_ops *= 2) {
            // 1. Create the graph
            std::vector<GraphInstruction> graph = create_simple_graph(num_ops);
            executor.load_graph(graph);

            // 2. Time the run
            const auto start = std::chrono::high_resolution_clock::now();
            executor.run(num_samples_fixed);
            const auto stop = std::chrono::high_resolution_clock::now();

            // 3. Report
            const long long total_us = get_time_us(start, stop);
            const long long total_ops_executed =
                static_cast<long long>(num_samples_fixed) * num_ops;
            const double time_per_op_us =
                static_cast<double>(total_us) / static_cast<double>(total_ops_executed);

            std::cout << std::setw(12) << num_ops << " | "
                      << std::setw(15) << (total_us / 1000.0) << " | "
                      << std::setw(20) << time_per_op_us << std::endl;
        }


        // ---
        // Test 2: Scaling Data Size (Batch Iterations)
        // ---
        std::cout << "\n--- Test 2: Scaling Data Size (Fixed Model Size) ---" << std::endl;
        std::cout << "Running an increasingly large dataset against a fixed-size graph (100 ops)..." << std::endl;
        // The last header is 24 characters wide; header and data rows use the
        // same width so the column lines up.
        std::cout << std::setw(12) << "Num Samples" << " | "
                  << std::setw(15) << "Total Time (ms)" << " | "
                  << std::setw(24) << "Avg Time per Sample (us)" << std::endl;
        std::cout << "----------------------------------------------------------------------" << std::endl;

        const int num_ops_fixed = 100;
        std::vector<GraphInstruction> fixed_graph = create_simple_graph(num_ops_fixed);
        executor.load_graph(fixed_graph);

        const int max_num_samples = 2048000; // Test up to ~2 million samples
        for (int num_samples = 1000; num_samples <= max_num_samples; num_samples *= 2) {
            // 1. Time the run
            const auto start = std::chrono::high_resolution_clock::now();
            executor.run(num_samples);
            const auto stop = std::chrono::high_resolution_clock::now();

            // 2. Report
            const long long total_us = get_time_us(start, stop);
            const double time_per_sample_us = static_cast<double>(total_us) / num_samples;

            std::cout << std::setw(12) << num_samples << " | "
                      << std::setw(15) << (total_us / 1000.0) << " | "
                      << std::setw(24) << time_per_sample_us << std::endl;
        }

        std::cout << "\n--- Benchmark Complete ---" << std::endl;

        // Cleanup
        manager.free("A");
        manager.free("B");

    } catch (const std::exception& e) {
        std::cerr << "An error occurred: " << e.what() << std::endl;
        return 1;
    }
    return 0;
}