/**
 * @file instanta_layer2vi_example.c
 * @brief Provide layer2 sending and receiving cases.
 * @copyright Copyright (c) 2022 YUSUR Technology Co., Ltd. All Rights Reserved. Learn more at www.yusur.tech.
 * @author matianhao (matianhao@yusur.tech)
 * @date 2023-05-22 15:12:46
 * @last_author: matianhao (math@yusur.tech)
 * @last_edit_time: 2023-05-22 15:12:46
 */

#include <errno.h>
#include <signal.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <arpa/inet.h>
#include <getopt.h>
#include <netinet/ip.h>
#include <netinet/tcp.h>
#include <netinet/udp.h>
#include <sys/time.h>
#include <unistd.h>

#include "instanta_layer2vi.h"

// NOTE(review): not referenced by the code shown in this file.
#define USER_INPUT_MAX_LEN    64
// Length in bytes of every frame built by build_send_buf() and transmitted by send_test().
#define SEND_BUF_LEN    1024
// Receive buffer size used by recv_test(); presumably 14-byte Ethernet header + 1500-byte MTU.
#define RECV_BUF_LEN    1514
// Upper bound on the frame length build_send_buf() can produce.
#define SEND_BUF_MAX_LEN    4096
// Payload capacity of struct pseudo_packet, the scratch area used for the TCP checksum.
#define INFO_MAX 2048

// Name passed as the second argument of layer2vi_create() in both tests.
#define BUFFER_NAME    "buffer1"
// Number of frames send_test() transmits.
#define SEND_TIMES    100000 // 100,000
// Pause after each transmitted frame, in microseconds (passed to usleep()).
#define SEND_INTERVAL_USEC    100

// 4-tuple + protocol used both for the receive-side IP filter (init_filter())
// and for the headers of the frames built by build_send_buf().
#define IP_FILTER_SRC_ADDR    "50.50.50.50"
#define IP_FILTER_SRC_PORT    50000
#define IP_FILTER_DST_ADDR    "60.60.60.60"
#define IP_FILTER_DST_PORT    60000
// IP protocol number for TCP (same value as IPPROTO_TCP).
#define IP_FILTER_PROTOCOL_TCP    6

/*
 * Ethernet + IPv4 + TCP header exactly as it appears on the wire at the start
 * of each frame built by build_send_buf(). pack(1) removes compiler padding so
 * the struct can be memcpy'd straight into the frame buffer.
 */
#pragma pack(1)
typedef struct
{
    // Ethernet (MAC) header
    uint8_t  dst_mac[6];
    uint8_t  src_mac[6];
    uint16_t ethertype;

    // IPv4 header (fixed part, no options)
    struct iphdr ip_hdr;

    // Transport-layer header (TCP)
    struct tcphdr tcp_hdr;
} common_packet_header_t;
#pragma pack()

/*
 * IPv4 pseudo header prepended (for checksum purposes only) to the TCP segment:
 * source/destination address, a zero byte, the protocol number and the length
 * of the TCP header plus payload. build_send_buf() fills every field in
 * network byte order.
 */
struct pseudohdr
{
    u_int32_t saddr;
    u_int32_t daddr;
    u_int8_t  padding;   // always 0
    u_int8_t  protocol;
    u_int16_t length;    // TCP header + payload length, network byte order
};

/*
 * Scratch buffer for the TCP checksum: pseudo header, TCP header and payload
 * laid out contiguously (pack(1)), so that one calculate_checksum() call over
 * sizeof(struct pseudohdr) + sizeof(struct tcphdr) + info_len bytes covers the
 * whole checksummed range.
 */
#pragma pack(1)
struct pseudo_packet
{
    struct pseudohdr pseudo_hdr;
    struct tcphdr tcp_hdr;
    char info[INFO_MAX];   // payload bytes; only the first info_len are used
    uint32_t info_len;     // bookkeeping only; lies after info, so it is outside
                           // the checksummed range as long as info_len <= INFO_MAX
};
#pragma pack()

// Filters populated by init_filter(). recv_test() installs them on the device,
// and build_send_buf() derives the outgoing frame headers from them.
layer2vi_filter_t g_mac_filter;   // broadcast destination MAC + IPv4 ethertype
layer2vi_filter_t g_ip_filter;    // TCP flow IP_FILTER_SRC_* -> IP_FILTER_DST_*

static void init_filter()
{
    memset(&g_mac_filter, 0, sizeof(g_mac_filter));
    memset(&g_ip_filter, 0, sizeof(g_ip_filter));

    g_mac_filter.rule_type = LAYER2VI_FILTER_TYPE_MAC;
    memset(g_mac_filter.u.mac_filter.dst_mac, 0xff, 6);
    g_mac_filter.u.mac_filter.ethertype = htons(0x0800);
    g_mac_filter.u.mac_filter.vlan = 0;
    g_mac_filter.u.mac_filter.vlan_match_method = 0;

    g_ip_filter.rule_type = LAYER2VI_FILTER_TYPE_IP;
    g_ip_filter.u.ip_filter.dst_addr = inet_addr(IP_FILTER_DST_ADDR);
    g_ip_filter.u.ip_filter.dst_port = htons(IP_FILTER_DST_PORT);
    g_ip_filter.u.ip_filter.src_addr = inet_addr(IP_FILTER_SRC_ADDR);
    g_ip_filter.u.ip_filter.src_port = htons(IP_FILTER_SRC_PORT);
    g_ip_filter.u.ip_filter.protocol = IP_FILTER_PROTOCOL_TCP;
    g_ip_filter.u.ip_filter.attr.priority = ACCURATE_IP_FILTER;
}

/*
 * Internet checksum (RFC 1071) over len bytes starting at addr.
 *
 * The data is summed as 16-bit words in host order, so the returned value can
 * be stored directly into a network-order header field (e.g. iphdr.check or
 * tcphdr.check). If len is odd, the trailing byte is padded with a zero byte,
 * as RFC 1071 requires; no byte past addr[len - 1] is read. Words are loaded
 * with memcpy, so addr may point into a packed/misaligned structure.
 *
 * @param addr  Start of the data (any alignment).
 * @param len   Number of bytes to sum; values <= 0 yield 0xffff.
 * @return One's complement of the one's-complement sum.
 */
unsigned short calculate_checksum(unsigned short *addr, int len)
{
    const unsigned char *bytes = (const unsigned char *)addr;
    uint64_t sum = 0; // wide enough that no carry is lost for any int len

    while (len > 1)
    {
        uint16_t word;
        memcpy(&word, bytes, sizeof(word));
        sum += word;
        bytes += sizeof(word);
        len -= (int)sizeof(word);
    }

    if (len > 0)
    {
        // Odd trailing byte: treat it as the first byte of a word whose second byte is 0.
        uint16_t word = 0;
        memcpy(&word, bytes, 1);
        sum += word;
    }

    // Fold carries back into the low 16 bits.
    while (sum >> 16)
    {
        sum = (sum & 0xffff) + (sum >> 16);
    }

    return (unsigned short)(~sum);
}

static int build_send_buf(char *send_buf, int send_buf_len)
{
    if (NULL == send_buf)
    {
        printf("send_buf is NULL.\n");
        return -1;
    }

    char send_buf_temporary[SEND_BUF_MAX_LEN] = {0};
    common_packet_header_t send_pkt_header;

    memset(send_buf_temporary, 0x88, SEND_BUF_MAX_LEN);
    memset(&send_pkt_header, 0, sizeof(send_pkt_header));

    uint32_t info_len = send_buf_len - sizeof(common_packet_header_t);//过短时需要处理一下

    // MAC包头
    memset(send_pkt_header.src_mac, 0, 6);
    memcpy(send_pkt_header.dst_mac, g_mac_filter.u.mac_filter.dst_mac, 6);
    send_pkt_header.ethertype = g_mac_filter.u.mac_filter.ethertype;
    // IP包头
    send_pkt_header.ip_hdr.ihl = 5;
    send_pkt_header.ip_hdr.version = 4;
    send_pkt_header.ip_hdr.tos = 0;
    send_pkt_header.ip_hdr.tot_len = htons(sizeof(struct iphdr) + sizeof(struct udphdr) + info_len);
    send_pkt_header.ip_hdr.id = htons(1);
    send_pkt_header.ip_hdr.frag_off = htons(0x4000);
    send_pkt_header.ip_hdr.ttl = 64;
    send_pkt_header.ip_hdr.protocol = g_ip_filter.u.ip_filter.protocol;
    send_pkt_header.ip_hdr.check = 0;
    send_pkt_header.ip_hdr.saddr = g_ip_filter.u.ip_filter.src_addr;
    send_pkt_header.ip_hdr.daddr = g_ip_filter.u.ip_filter.dst_addr;
    // Set TCP header of send_pkt_header.
    send_pkt_header.tcp_hdr.source = g_ip_filter.u.ip_filter.src_port;
    send_pkt_header.tcp_hdr.dest = g_ip_filter.u.ip_filter.dst_port;
    send_pkt_header.tcp_hdr.seq = htonl(1);
    send_pkt_header.tcp_hdr.ack_seq = htonl(2);
    send_pkt_header.tcp_hdr.res1 = 0;
    send_pkt_header.tcp_hdr.doff = 5;
    send_pkt_header.tcp_hdr.fin = 0;
    send_pkt_header.tcp_hdr.syn = 0;
    send_pkt_header.tcp_hdr.rst = 0;
    send_pkt_header.tcp_hdr.psh = 1;
    send_pkt_header.tcp_hdr.ack = 1;
    send_pkt_header.tcp_hdr.urg = 0;
    send_pkt_header.tcp_hdr.res2 = 0;
    send_pkt_header.tcp_hdr.window = htons(10000);
    send_pkt_header.tcp_hdr.check = 0;
    send_pkt_header.tcp_hdr.urg_ptr = htons(0);
    // Create pseudo packet of ACK to calculate checksum.
    struct pseudo_packet pseudo_send_packet;
    memset(pseudo_send_packet.info, 0x88, info_len);
    pseudo_send_packet.info_len = info_len;
    // // Set pseudo header of ACK of send_pkt_header.
    pseudo_send_packet.pseudo_hdr.saddr = send_pkt_header.ip_hdr.saddr;
    pseudo_send_packet.pseudo_hdr.daddr = send_pkt_header.ip_hdr.daddr;
    pseudo_send_packet.pseudo_hdr.padding = 0;
    pseudo_send_packet.pseudo_hdr.protocol = send_pkt_header.ip_hdr.protocol;
    pseudo_send_packet.pseudo_hdr.length = htons(sizeof(struct tcphdr) + info_len);
    // Set TCP header of pseudo_send_packet.
    pseudo_send_packet.tcp_hdr = send_pkt_header.tcp_hdr;
    // Calculate the checksum.
    send_pkt_header.ip_hdr.check = calculate_checksum((unsigned short *)&send_pkt_header.ip_hdr, sizeof(struct iphdr));
    send_pkt_header.tcp_hdr.check = calculate_checksum(
        (unsigned short *)&pseudo_send_packet,
        sizeof(struct pseudohdr) + sizeof(struct tcphdr) + pseudo_send_packet.info_len);

    memcpy(send_buf_temporary, &send_pkt_header, sizeof(send_pkt_header));
    memcpy(send_buf, send_buf_temporary, send_buf_len);

    return 0;
}

/*
 * Open netdev_name and transmit SEND_TIMES copies of the frame produced by
 * build_send_buf(), sleeping SEND_INTERVAL_USEC microseconds after each one.
 * Progress is printed every 1000 frames. A failed transmit is reported but
 * does not stop the loop.
 *
 * @param netdev_name  Interface name passed to layer2vi_create().
 * @return 0 once all frames were attempted, -1 on a setup failure.
 */
static int send_test(char *netdev_name)
{
    char *buf_name = BUFFER_NAME;
    char frame[SEND_BUF_LEN] = {0};
    int sent = 0;
    int ret = -1;
    LAYER2VI layer2vi = NULL;

    if (NULL == netdev_name)
    {
        printf("netdev_name is NULL.\n");
        return -1;
    }

    // Open the device.
    layer2vi = layer2vi_create(netdev_name, buf_name);
    if (NULL == layer2vi)
    {
        printf("Create layer2vi fail!\n");
        return -1;
    }

    if (build_send_buf(frame, SEND_BUF_LEN) < 0)
    {
        printf("build_send_buf error.\n");
        goto out;
    }

    // Transmit loop.
    while (sent < SEND_TIMES)
    {
        if (layer2vi_transmit_frame(layer2vi, frame, SEND_BUF_LEN) < 0)
        {
            printf("send failed.\n");
        }

        if (++sent % 1000 == 0)
        {
            printf("Send %d times\n", sent);
        }

        usleep(SEND_INTERVAL_USEC);
    }

    printf("Send %d times\n", sent);
    ret = 0;

out:
    layer2vi_destroy(layer2vi);
    return ret;
}

/* Set asynchronously by recv_stop_handler() to make recv_test() leave its receive loop. */
static volatile sig_atomic_t g_recv_stop = 0;

/*
 * SIGINT/SIGTERM handler for recv_test(). It requests a graceful stop and then
 * restores the default action, so a second signal still terminates the
 * process if the receive call does not return. Only async-signal-safe
 * operations are used.
 */
static void recv_stop_handler(int sig)
{
    g_recv_stop = 1;
    signal(sig, SIG_DFL);
}

/*
 * Open netdev_name, install g_mac_filter and g_ip_filter (init_filter() must
 * have run first), and receive frames into a RECV_BUF_LEN buffer until
 * SIGINT or SIGTERM arrives.
 *
 * Only calls to layer2vi_receive_frame() that do not return a negative value
 * are counted, and progress is printed every 1000 frames. On stop, the final
 * count is printed and the device is destroyed.
 * NOTE(review): if layer2vi_receive_frame() blocks inside a restartable
 * syscall, the stop takes effect after the next frame (or a second signal).
 *
 * @param netdev_name  Interface name passed to layer2vi_create().
 * @return 0 after a signal-requested stop, -1 on a setup failure.
 */
static int recv_test(char *netdev_name)
{
    if (NULL == netdev_name)
    {
        printf("netdev_name is NULL.\n");
        return -1;
    }

    char *buffer_name = BUFFER_NAME;
    char recv_buf[RECV_BUF_LEN] = {0};
    // Unsigned 64-bit so a long-running, high-rate receiver cannot hit signed overflow.
    unsigned long long recv_pkts = 0;

    // Open the device.
    LAYER2VI layer2vi = layer2vi_create(netdev_name, buffer_name);
    if (NULL == layer2vi)
    {
        printf("Create layer2vi fail!\n");
        return -1;
    }

    // Install the receive filters.
    if (layer2vi_add_filter(layer2vi, g_mac_filter) < 0)
    {
        printf("add mac filter failed.\n");
        layer2vi_destroy(layer2vi);
        return -1;
    }
    if (layer2vi_add_filter(layer2vi, g_ip_filter) < 0)
    {
        printf("add ip filter failed.\n");
        layer2vi_destroy(layer2vi);
        return -1;
    }

    // Let Ctrl-C / SIGTERM end the loop so the summary and cleanup below run.
    g_recv_stop = 0;
    signal(SIGINT, recv_stop_handler);
    signal(SIGTERM, recv_stop_handler);

    // Receive loop.
    while (!g_recv_stop)
    {
        if (layer2vi_receive_frame(layer2vi, recv_buf, RECV_BUF_LEN) < 0)
        {
            printf("recv failed.\n");
            continue; // a failed receive is not a received packet
        }

        recv_pkts++;

        if (recv_pkts % 1000 == 0)
        {
            printf("recv %llu packets\n", recv_pkts);
        }
    }

    signal(SIGINT, SIG_DFL);
    signal(SIGTERM, SIG_DFL);

    printf("recv %llu packets\n", recv_pkts);

    layer2vi_destroy(layer2vi);

    return 0;
}

/* Print usage examples and the supported command-line options to stdout. */
static void help_info()
{
    printf("\nexample:\n"
           "  instanta_layer2vi_example -i swift1f0 -m send\n"
           "  instanta_layer2vi_example -i swift1f0 -m recv\n"
           "\noptions:\n"
           "  -i <netdev_interface>    - netdev interface\n"
           "  -m <test_mode>           - test_mode, send or recv\n"
           "\n");
}

/*
 * Entry point: parse -i/--interface and -m/--mode, initialise the filters and
 * run either send_test() or recv_test().
 *
 * @return 0 on success, -1 on bad options or when the selected test fails.
 */
int main(int argc, char *argv[])
{
    char *netdev_name = NULL;
    char *test_mode = NULL;

    static struct option long_opts[] = {
        {"interface", required_argument, NULL, 'i'},
        {"mode", required_argument, NULL, 'm'},
        {0, 0, 0, 0}};

    int c = -1;
    while ((c = getopt_long(argc, argv, "i:m:", long_opts, NULL)) != -1)
    {
        switch (c)
        {
        case 'i':
            netdev_name = optarg;
            break;

        case 'm':
            test_mode = optarg;
            break;

        default:
            printf("\nPlease check your options\n\n");
            help_info();
            return -1;
        }
    }

    // Both options are mandatory. Checking the parsed values rather than argc
    // accepts the "-iX -mY" and "--interface=X" forms. It also rejects e.g.
    // "-i a -i b -i c", which would otherwise reach strcmp(NULL, ...).
    if (NULL == netdev_name || NULL == test_mode)
    {
        help_info();
        return -1;
    }

    // Both test modes rely on the global filters.
    init_filter();

    if (0 == strcmp(test_mode, "send"))
    {
        if (send_test(netdev_name) < 0)
        {
            printf("send test failed.\n");
            return -1;
        }
    }
    else if (0 == strcmp(test_mode, "recv"))
    {
        if (recv_test(netdev_name) < 0)
        {
            printf("recv test failed.\n");
            return -1;
        }
    }
    else
    {
        help_info();
        return -1;
    }

    return 0;
}
