/*
 * Decompiled with CFR 0.152.
 */
package zeph.http;

import java.io.InputStream;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Function;
import zeph.buffer.BufferPool;
import zeph.http.HttpParser;
import zeph.http.HttpRequest;
import zeph.http.HttpResponse;
import zeph.http2.Http2FrameReader;
import zeph.http2.Http2ServerHandler;
import zeph.uring.IoUring;
import zeph.uring.Socket;

/**
 * HTTP/1.1 server (with h2c upgrade to HTTP/2) built on io_uring, using one acceptor ring and
 * {@code ringCount} worker rings.
 *
 * <p>The acceptor thread accepts connections on a listening socket and hands each accepted fd to
 * a worker chosen round-robin. Every worker owns its own ring, buffer pool and connection table
 * and is woken through an eventfd. Submissions to a ring are only made from the thread that
 * drives that ring; other threads reach a ring only through eventfd writes ({@link #stop()}), or
 * after the owning thread has exited ({@link #close()}).
 *
 * <p>Each SQE's user data carries an operation tag in its top byte ({@code OP_*}) and the target
 * file descriptor in its low 56 bits.
 */
public class HttpServerMultiRing implements AutoCloseable {
    /** Operation tags stored in the top byte of each SQE's user data. */
    private static final long OP_ACCEPT = 0x0100000000000000L;
    private static final long OP_READ = 0x0200000000000000L;
    private static final long OP_WRITE = 0x0300000000000000L;
    private static final long OP_CLOSE = 0x0400000000000000L;
    private static final long OP_EVENTFD = 0x0500000000000000L;
    /** Selects the operation tag from user data. */
    private static final long OP_MASK = 0xFF00000000000000L;
    /** Selects the file descriptor from user data. */
    private static final long FD_MASK = 0x00FFFFFFFFFFFFFFL;

    /** Flags for accepted sockets; 0x80800 appears to be SOCK_NONBLOCK | SOCK_CLOEXEC on Linux. */
    private static final int ACCEPT_FLAGS = 0x80800;
    /** io_uring_setup flag IORING_SETUP_SQPOLL. */
    private static final int IORING_SETUP_SQPOLL = 2;
    /** io_uring_enter flag IORING_ENTER_GETEVENTS. */
    private static final int IORING_ENTER_GETEVENTS = 1;
    /** io_uring_enter flag IORING_ENTER_SQ_WAKEUP. */
    private static final int IORING_ENTER_SQ_WAKEUP = 2;
    /** Linux EAGAIN/EWOULDBLOCK. */
    private static final int EAGAIN = 11;
    /**
     * BSD/macOS value of EAGAIN, kept from the original retry check. NOTE(review): io_uring is
     * Linux-only, so this branch is presumably vestigial.
     */
    private static final int EAGAIN_BSD = 35;
    private static final int LISTEN_BACKLOG = 4096;
    /** An eventfd is read and written as a single 8-byte counter. */
    private static final int EVENTFD_VALUE_BYTES = 8;
    /** Size of each pooled read/write buffer, and the maximum bytes per recv/send. */
    private static final int BUFFER_SIZE = 16384;
    /** Number of buffers pre-allocated per worker pool. */
    private static final int POOL_SIZE = 256;
    private static final long ACCEPTOR_JOIN_MILLIS = 2000L;
    private static final long WORKER_JOIN_MILLIS = 1000L;

    private final int port;
    private final int ringCount;
    private final WorkerRing[] workers;
    private final AtomicBoolean running = new AtomicBoolean(true);
    private final AtomicBoolean closed = new AtomicBoolean(false);
    private final AtomicInteger nextWorker = new AtomicInteger(0);
    private final Function<HttpRequest, HttpResponse> handler;
    private final IoUring acceptorRing;
    private final Arena acceptorArena;
    private final int serverFd;
    /** Eventfd used by {@link #stop()} to wake the acceptor without touching its ring. */
    private final int acceptorEventFd;
    private final MemorySegment acceptorEventFdBuffer;
    private final MemorySegment acceptorEventFdWriteBuffer;
    private final Thread acceptorThread;
    private final boolean sqPoll;
    private final int sqPollIdleMs;
    private final AtomicLong connectionIdGenerator = new AtomicLong(0L);
    /** SQEs queued on the acceptor ring since the last enter; touched only by the acceptor thread. */
    private int acceptorPendingSubmits = 0;

    /**
     * Creates a server without SQ polling (idle time 1000 ms is passed but unused).
     *
     * @throws Exception if any ring, eventfd or the listening socket cannot be set up; every
     *     resource acquired before the failure is released
     */
    public HttpServerMultiRing(String ip, int port, int ringCount, int ringSize, Function<HttpRequest, HttpResponse> handler) throws Exception {
        this(ip, port, ringCount, ringSize, handler, false, 1000);
    }

    /**
     * Creates the worker rings, the acceptor ring and a bound, listening socket. No thread is
     * started until {@link #run()}.
     *
     * @param ip address to bind
     * @param port port to bind
     * @param ringCount number of worker rings; must be positive
     * @param ringSize entry count passed to each {@code IoUring}
     * @param handler request handler; exceptions it throws become 500 responses
     * @param sqPoll whether worker rings are created with IORING_SETUP_SQPOLL
     * @param sqPollIdleMs SQ poll idle time passed to worker rings when {@code sqPoll} is set
     * @throws IllegalArgumentException if {@code ringCount <= 0}
     * @throws Exception if any native setup step fails; every resource acquired before the
     *     failure is released and any cleanup failures are attached as suppressed exceptions
     */
    public HttpServerMultiRing(String ip, int port, int ringCount, int ringSize, Function<HttpRequest, HttpResponse> handler, boolean sqPoll, int sqPollIdleMs) throws Exception {
        if (ringCount <= 0) {
            // Worker selection is "counter mod ringCount"; zero rings would fail on every accept.
            throw new IllegalArgumentException("ringCount must be positive: " + ringCount);
        }
        this.port = port;
        this.ringCount = ringCount;
        this.handler = Objects.requireNonNull(handler, "handler");
        this.sqPoll = sqPoll;
        this.sqPollIdleMs = sqPollIdleMs;
        this.workers = new WorkerRing[ringCount];
        this.acceptorArena = Arena.ofShared();
        IoUring ring = null;
        int eventFd = -1;
        int listenFd = -1;
        try {
            for (int i = 0; i < ringCount; ++i) {
                this.workers[i] = new WorkerRing(i, ringSize);
            }
            ring = new IoUring(ringSize, 0);
            eventFd = IoUring.createEventFd(0, 0);
            listenFd = Socket.createServerSocket();
            Socket.setReuseAddr(listenFd, this.acceptorArena);
            Socket.setReusePort(listenFd, this.acceptorArena);
            Socket.bind(listenFd, ip, port, this.acceptorArena);
            Socket.listen(listenFd, LISTEN_BACKLOG);
        } catch (Throwable t) {
            this.releaseAfterFailedStartup(ring, eventFd, listenFd, t);
            throw t;
        }
        this.acceptorRing = ring;
        this.acceptorEventFd = eventFd;
        this.serverFd = listenFd;
        this.acceptorEventFdBuffer = this.acceptorArena.allocate(8L, 8L);
        this.acceptorEventFdWriteBuffer = this.acceptorArena.allocate(8L, 8L);
        // Each wakeup adds 1 to the eventfd counter.
        this.acceptorEventFdWriteBuffer.set(ValueLayout.JAVA_LONG, 0L, 1L);
        this.acceptorThread = new Thread(this::acceptLoop, "zeph-acceptor");
        System.out.println("HTTP Server (multi-ring) listening on " + ip + ":" + port + " with " + ringCount + " workers");
    }

    /**
     * Releases whatever the constructor managed to acquire before {@code cause} was thrown.
     * Cleanup failures are attached to {@code cause} as suppressed exceptions.
     */
    private void releaseAfterFailedStartup(IoUring ring, int eventFd, int listenFd, Throwable cause) {
        if (listenFd >= 0) {
            try {
                Socket.close(listenFd);
            } catch (Exception e) {
                cause.addSuppressed(e);
            }
        }
        if (eventFd >= 0) {
            try {
                IoUring.closeEventFd(eventFd);
            } catch (RuntimeException e) {
                cause.addSuppressed(e);
            }
        }
        if (ring != null) {
            try {
                ring.close();
            } catch (RuntimeException e) {
                cause.addSuppressed(e);
            }
        }
        for (WorkerRing worker : this.workers) {
            if (worker == null) {
                break; // workers are created in index order; the rest were never built
            }
            try {
                worker.close();
            } catch (RuntimeException e) {
                cause.addSuppressed(e);
            }
        }
        try {
            this.acceptorArena.close();
        } catch (RuntimeException e) {
            cause.addSuppressed(e);
        }
    }

    /**
     * Queues one accept on the listening socket, tagged with {@code OP_ACCEPT}.
     * NOTE(review): if the submission queue is full the accept is silently not queued.
     */
    private void submitAccept() {
        MemorySegment sqe = this.acceptorRing.getSqe();
        if (sqe != null) {
            this.acceptorRing.prepareAccept(sqe, this.serverFd, null, null, ACCEPT_FLAGS, OP_ACCEPT | (long) this.serverFd);
            this.acceptorRing.submit();
            ++this.acceptorPendingSubmits;
        }
    }

    /** Queues a read on the acceptor eventfd so a {@link #stop()} call can wake the acceptor. */
    private void submitAcceptorWakeupRead() {
        MemorySegment sqe = this.acceptorRing.getSqe();
        if (sqe != null) {
            this.acceptorRing.prepareRead(sqe, this.acceptorEventFd, this.acceptorEventFdBuffer, EVENTFD_VALUE_BYTES, 0L, OP_EVENTFD);
            this.acceptorRing.submit();
            ++this.acceptorPendingSubmits;
        }
    }

    /**
     * Acceptor thread body. It submits pending SQEs, waits for at least one completion, and hands
     * each accepted fd to a worker. The accept is re-armed after every accept completion,
     * successful or not. Other completions (the eventfd wakeup) just fall through to the
     * {@code running} check.
     */
    private void acceptLoop() {
        this.submitAcceptorWakeupRead();
        this.submitAccept();
        while (this.running.get()) {
            try {
                int toSubmit = this.acceptorPendingSubmits;
                this.acceptorPendingSubmits = 0;
                this.acceptorRing.enter(toSubmit, 1, IORING_ENTER_GETEVENTS);
                MemorySegment cqe;
                while ((cqe = this.acceptorRing.peekCqe()) != null) {
                    long userData = IoUring.getUserData(cqe);
                    int result = IoUring.getResult(cqe);
                    this.acceptorRing.advanceCq();
                    if ((userData & OP_MASK) != OP_ACCEPT) {
                        continue;
                    }
                    if (result >= 0) {
                        this.dispatchConnection(result);
                    }
                    // NOTE(review): a persistent accept error (e.g. EMFILE) re-arms immediately.
                    this.submitAccept();
                }
            } catch (Exception e) {
                if (!this.running.get()) {
                    continue;
                }
                System.err.println("Acceptor error: " + e.getMessage());
                e.printStackTrace();
            }
        }
    }

    /** Assigns a connection id and hands {@code clientFd} to the next worker round-robin. */
    private void dispatchConnection(int clientFd) {
        long connId = this.connectionIdGenerator.incrementAndGet();
        // floorMod keeps the index in range even after the counter wraps negative.
        int workerIdx = Math.floorMod(this.nextWorker.getAndIncrement(), this.ringCount);
        this.workers[workerIdx].addConnection(clientFd, connId);
    }

    /**
     * Starts all worker threads and the acceptor, then blocks until the acceptor exits (after
     * {@link #stop()}). If the calling thread is interrupted while waiting, this returns early
     * with the interrupt flag restored.
     */
    public void run() {
        for (WorkerRing worker : this.workers) {
            worker.start();
        }
        this.acceptorThread.start();
        System.out.println("HTTP Server (multi-ring) started with " + this.ringCount + " workers");
        try {
            this.acceptorThread.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Asks the acceptor and all workers to exit their loops. This only flips an atomic flag and
     * writes to eventfds, so it is intended to be callable from another thread (e.g. a shutdown
     * hook). It does nothing after {@link #close()}, because the eventfds may already be released.
     */
    public void stop() {
        if (this.closed.get()) {
            return;
        }
        this.signalShutdown();
    }

    /** Clears the running flag and wakes every ring through its eventfd (best-effort). */
    private void signalShutdown() {
        this.running.set(false);
        for (WorkerRing worker : this.workers) {
            worker.stop();
        }
        try {
            IoUring.writeEventFd(this.acceptorEventFd, this.acceptorEventFdWriteBuffer);
        } catch (Exception ignored) {
            // Best-effort wakeup; the acceptor also exits on its next completion.
        }
    }

    /**
     * Stops the server, waits briefly for its threads, and releases all native resources.
     * Idempotent. A ring whose thread is still running after the wait is left open rather than
     * freed underneath it. An interrupt received while waiting is re-asserted on return.
     */
    @Override
    public void close() {
        if (!this.closed.compareAndSet(false, true)) {
            return;
        }
        this.signalShutdown();
        boolean interrupted = joinQuietly(this.acceptorThread, ACCEPTOR_JOIN_MILLIS);
        for (WorkerRing worker : this.workers) {
            worker.thread.interrupt();
            interrupted |= joinQuietly(worker.thread, WORKER_JOIN_MILLIS);
        }
        for (WorkerRing worker : this.workers) {
            if (worker.thread.isAlive()) {
                System.err.println("Worker " + worker.id + " did not stop; leaving its ring open");
                continue;
            }
            worker.close();
        }
        closeSocketQuietly(this.serverFd);
        if (this.acceptorThread.isAlive()) {
            System.err.println("Acceptor did not stop; leaving its ring open");
        } else {
            IoUring.closeEventFd(this.acceptorEventFd);
            this.acceptorRing.close();
            this.acceptorArena.close();
        }
        if (interrupted) {
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Joins {@code thread} for at most {@code millis}.
     *
     * @return true if the wait was cut short by an interrupt (the interrupt flag is cleared)
     */
    private static boolean joinQuietly(Thread thread, long millis) {
        try {
            thread.join(millis);
            return false;
        } catch (InterruptedException e) {
            return true;
        }
    }

    /** Closes a socket fd, ignoring failures (used only on best-effort cleanup paths). */
    private static void closeSocketQuietly(int fd) {
        try {
            Socket.close(fd);
        } catch (Exception ignored) {
            // Nothing useful to do if closing fails during cleanup.
        }
    }

    /** True for a negated recv/send result that means "try again". */
    private static boolean isRetryable(int err) {
        return err == EAGAIN || err == EAGAIN_BSD;
    }

    /** Frames {@code n} bytes of {@code src} as one HTTP/1.1 chunk: hex size, CRLF, data, CRLF. */
    private static byte[] frameChunk(byte[] src, int n) {
        byte[] header = (Integer.toHexString(n) + "\r\n").getBytes(StandardCharsets.US_ASCII);
        byte[] chunk = new byte[header.length + n + 2];
        System.arraycopy(header, 0, chunk, 0, header.length);
        System.arraycopy(src, 0, chunk, header.length, n);
        chunk[header.length + n] = '\r';
        chunk[header.length + n + 1] = '\n';
        return chunk;
    }

    public int getPort() {
        return this.port;
    }

    /**
     * Demo entry point: {@code [port] [workers] [sqPoll]} (defaults 8080, available processors,
     * false). It serves "/", "/hello" and "/json" and answers 404 otherwise.
     */
    public static void main(String[] args) {
        int port = args.length > 0 ? Integer.parseInt(args[0]) : 8080;
        int workers = args.length > 1 ? Integer.parseInt(args[1]) : Runtime.getRuntime().availableProcessors();
        boolean sqPoll = args.length > 2 && Boolean.parseBoolean(args[2]);
        System.out.println("Starting Multi-Ring HTTP Server");
        System.out.println("Port: " + port + ", Workers: " + workers + ", SQ Poll: " + sqPoll);
        Function<HttpRequest, HttpResponse> routes = request -> switch (request.getPath()) {
            case "/", "/hello" -> HttpResponse.ok("Hello from Zeph!");
            case "/json" -> HttpResponse.json("{\"message\": \"Hello!\"}");
            default -> HttpResponse.notFound();
        };
        try (HttpServerMultiRing server = new HttpServerMultiRing("0.0.0.0", port, workers, 256, routes, sqPoll, 1000)) {
            Runtime.getRuntime().addShutdownHook(new Thread(() -> {
                System.out.println("\nShutting down...");
                server.stop();
            }));
            server.run();
        } catch (Exception e) {
            System.err.println("Error: " + e.getMessage());
            e.printStackTrace();
        }
    }

    /** A freshly accepted socket waiting to be adopted by a worker. */
    private record NewConnection(int fd, long id) {
    }

    /**
     * One io_uring event loop serving a subset of connections. All submissions happen on this
     * worker's own thread. Other threads only call {@link #addConnection}/{@link #stop} (queue +
     * eventfd write), or {@link #close} after the thread has exited.
     */
    private final class WorkerRing implements Runnable {
        final int id;
        final IoUring ring;
        final Arena arena;
        final BufferPool bufferPool;
        /** Live connections keyed by fd; entries are removed when their close completes. */
        final ConcurrentHashMap<Integer, HttpConnection> connections = new ConcurrentHashMap<>();
        /** Accepted fds handed over by the acceptor, adopted on the next eventfd wakeup. */
        final ConcurrentLinkedQueue<NewConnection> pendingConnections = new ConcurrentLinkedQueue<>();
        final Thread thread;
        final int eventFd;
        final MemorySegment eventFdBuffer;
        final MemorySegment eventFdWriteBuffer;
        /** SQEs queued since the last enter; touched only by this worker's thread. */
        int pendingSubmits;

        /**
         * Creates the worker's ring (SQPOLL if the server was configured so), buffer pool and
         * eventfd. On failure, everything already acquired is released before rethrowing.
         */
        WorkerRing(int id, int ringSize) throws Exception {
            this.id = id;
            this.arena = Arena.ofShared();
            IoUring createdRing = null;
            BufferPool createdPool = null;
            int createdEventFd;
            try {
                createdRing = sqPoll
                        ? new IoUring(ringSize, IORING_SETUP_SQPOLL, id, sqPollIdleMs)
                        : new IoUring(ringSize, 0);
                createdPool = new BufferPool(BUFFER_SIZE, POOL_SIZE);
                createdEventFd = IoUring.createEventFd(0, 0);
            } catch (Throwable t) {
                if (createdPool != null) {
                    createdPool.close();
                }
                if (createdRing != null) {
                    createdRing.close();
                }
                this.arena.close();
                throw t;
            }
            this.ring = createdRing;
            this.bufferPool = createdPool;
            this.eventFd = createdEventFd;
            this.thread = new Thread(this, "zeph-worker-" + id);
            this.eventFdBuffer = this.arena.allocate(8L, 8L);
            this.eventFdWriteBuffer = this.arena.allocate(8L, 8L);
            this.eventFdWriteBuffer.set(ValueLayout.JAVA_LONG, 0L, 1L);
        }

        void start() {
            this.thread.start();
        }

        /** Called from the acceptor thread: queue the fd and wake this worker. */
        void addConnection(int fd, long connId) {
            this.pendingConnections.offer(new NewConnection(fd, connId));
            this.wakeup();
        }

        /** Adds 1 to the eventfd counter, completing the worker's pending eventfd read. */
        private void wakeup() {
            try {
                IoUring.writeEventFd(this.eventFd, this.eventFdWriteBuffer);
            } catch (Exception ignored) {
                // Best-effort: a missed wakeup is only observed as delayed adoption/shutdown.
            }
        }

        private void submitEventFdRead() {
            MemorySegment sqe = this.ring.getSqe();
            if (sqe != null) {
                this.ring.prepareRead(sqe, this.eventFd, this.eventFdBuffer, EVENTFD_VALUE_BYTES, 0L, OP_EVENTFD);
                this.ring.submit();
                ++this.pendingSubmits;
            }
        }

        /** Adopts every queued fd: gives it pooled buffers, registers it and arms the first recv. */
        private void processPendingConnections() {
            NewConnection nc;
            while ((nc = this.pendingConnections.poll()) != null) {
                BufferPool.PooledBuffer readBuf = this.bufferPool.acquireOrAllocate();
                BufferPool.PooledBuffer writeBuf = this.bufferPool.acquireOrAllocate();
                HttpConnection conn = new HttpConnection(nc.fd(), nc.id(), readBuf, writeBuf);
                this.connections.put(nc.fd(), conn);
                this.submitRead(conn);
            }
        }

        // NOTE(review): the submit* helpers silently drop the operation when getSqe() returns
        // null (submission queue full), which leaves that connection without an outstanding op.

        private void submitRead(HttpConnection conn) {
            MemorySegment sqe = this.ring.getSqe();
            if (sqe != null) {
                this.ring.prepareRecv(sqe, conn.fd, conn.readBuffer, BUFFER_SIZE, 0, OP_READ | (long) conn.fd);
                this.ring.submit();
                ++this.pendingSubmits;
            }
        }

        /** Copies the next slice (at most BUFFER_SIZE bytes) of {@code pendingWrite} and sends it. */
        private void submitWrite(HttpConnection conn) {
            if (conn.pendingWrite == null || conn.writeOffset >= conn.pendingWrite.length) {
                return;
            }
            int toWrite = Math.min(conn.pendingWrite.length - conn.writeOffset, BUFFER_SIZE);
            conn.writeBuffer.asSlice(0L, toWrite).asByteBuffer().put(conn.pendingWrite, conn.writeOffset, toWrite);
            MemorySegment sqe = this.ring.getSqe();
            if (sqe != null) {
                this.ring.prepareSend(sqe, conn.fd, conn.writeBuffer, toWrite, 0, OP_WRITE | (long) conn.fd);
                this.ring.submit();
                ++this.pendingSubmits;
            }
        }

        /** Replaces the connection's outgoing payload with {@code bytes} and starts sending it. */
        private void queueWrite(HttpConnection conn, byte[] bytes) {
            conn.pendingWrite = bytes;
            conn.writeOffset = 0;
            this.submitWrite(conn);
        }

        private void submitClose(int fd) {
            MemorySegment sqe = this.ring.getSqe();
            if (sqe != null) {
                this.ring.prepareClose(sqe, fd, OP_CLOSE | (long) fd);
                this.ring.submit();
                ++this.pendingSubmits;
            }
        }

        /** Event loop: submit, wait for at least one completion, dispatch all completions. */
        @Override
        public void run() {
            this.submitEventFdRead();
            while (running.get()) {
                try {
                    this.enterRing();
                    this.drainCompletions();
                } catch (Exception e) {
                    if (!running.get()) {
                        continue;
                    }
                    System.err.println("Worker " + this.id + " error: " + e.getMessage());
                }
            }
        }

        /** Submits queued SQEs and blocks for at least one completion. */
        private void enterRing() throws Exception {
            int toSubmit = this.pendingSubmits;
            this.pendingSubmits = 0;
            int enterFlags = IORING_ENTER_GETEVENTS;
            if (this.ring.isSqPollMode()) {
                // In SQPOLL mode nothing is submitted here; the SQ thread is only woken if the
                // ring reports it needs a wakeup.
                if (this.ring.needsWakeup()) {
                    enterFlags |= IORING_ENTER_SQ_WAKEUP;
                }
                toSubmit = 0;
            }
            this.ring.enter(toSubmit, 1, enterFlags);
        }

        private void drainCompletions() throws Exception {
            MemorySegment cqe;
            while ((cqe = this.ring.peekCqe()) != null) {
                long userData = IoUring.getUserData(cqe);
                int result = IoUring.getResult(cqe);
                this.ring.advanceCq();
                this.dispatch(userData & OP_MASK, (int) (userData & FD_MASK), result);
            }
        }

        /**
         * Routes one completion. An unchecked exception while handling a connection's read or
         * write closes that connection. Otherwise it would be left with no outstanding operation
         * and never be revisited.
         */
        private void dispatch(long op, int fd, int result) {
            if (op == OP_EVENTFD) {
                this.submitEventFdRead();
                this.processPendingConnections();
            } else if (op == OP_READ || op == OP_WRITE) {
                try {
                    if (op == OP_READ) {
                        this.handleRead(fd, result);
                    } else {
                        this.handleWrite(fd, result);
                    }
                } catch (RuntimeException e) {
                    System.err.println("Worker " + this.id + " error on fd " + fd + ": " + e);
                    this.closeConnection(fd);
                }
            } else if (op == OP_CLOSE) {
                this.handleClose(fd);
            }
        }

        /** Handles a recv completion: EOF/errors close, data is routed to HTTP/1 or HTTP/2. */
        private void handleRead(int fd, int result) {
            HttpConnection conn = this.connections.get(fd);
            if (conn == null) {
                return;
            }
            if (result == 0) {
                this.closeConnection(fd);
                return;
            }
            if (result < 0) {
                if (isRetryable(-result)) {
                    this.submitRead(conn);
                } else {
                    this.closeConnection(fd);
                }
                return;
            }
            byte[] data = new byte[result];
            conn.readBuffer.asSlice(0L, result).asByteBuffer().get(data);
            if (conn.isHttp2) {
                this.handleHttp2Read(conn, data);
                return;
            }
            HttpParser.Result parseResult = conn.parser.parse(data, 0, result);
            if (parseResult == HttpParser.Result.COMPLETE) {
                this.handleHttp1Request(conn, conn.parser.getRequest());
            } else if (parseResult == HttpParser.Result.ERROR) {
                HttpResponse response = HttpResponse.badRequest();
                response.setHeader("Connection", "close");
                conn.closeAfterWrite = true;
                this.queueWrite(conn, response.encode());
            } else {
                this.submitRead(conn); // request incomplete: keep reading
            }
        }

        /** Serves one parsed HTTP/1.1 request (or performs an h2c upgrade). */
        private void handleHttp1Request(HttpConnection conn, HttpRequest request) {
            request.setServerPort(port);
            if (this.isH2cUpgradeRequest(request)) {
                this.handleH2cUpgrade(conn, request);
                return;
            }
            conn.keepAlive = request.isKeepAlive();
            HttpResponse response;
            try {
                response = handler.apply(request);
            } catch (Exception e) {
                response = HttpResponse.serverError();
            }
            if (conn.keepAlive) {
                response.setHeader("Connection", "keep-alive");
            } else {
                response.setHeader("Connection", "close");
                conn.closeAfterWrite = true;
            }
            if (response.isStreaming()) {
                this.startStreamingResponse(conn, response);
            } else {
                this.queueWrite(conn, response.encode());
            }
            conn.parser.reset();
        }

        /** True for "Upgrade: h2c" with "Connection" containing "upgrade" and an HTTP2-Settings header. */
        private boolean isH2cUpgradeRequest(HttpRequest request) {
            String upgrade = request.getHeader("upgrade");
            String connection = request.getHeader("connection");
            String http2Settings = request.getHeader("http2-settings");
            return upgrade != null && upgrade.equalsIgnoreCase("h2c") && connection != null && connection.toLowerCase().contains("upgrade") && http2Settings != null;
        }

        /**
         * Switches the connection to HTTP/2. It sends "101 Switching Protocols" first; the server
         * preface is stashed in {@code pendingHttp2Data} and sent once the 101 is fully written.
         */
        private void handleH2cUpgrade(HttpConnection conn, HttpRequest request) {
            conn.http2Handler = new Http2ServerHandler(handler, port, false);
            conn.isHttp2 = true;
            conn.keepAlive = true;
            conn.http2Handler.setupUpgradeStream(request);
            String switching = "HTTP/1.1 101 Switching Protocols\r\n" + "Connection: Upgrade\r\n" + "Upgrade: h2c\r\n" + "\r\n";
            conn.parser.reset();
            conn.pendingHttp2Data = this.buildServerPreface(conn);
            this.queueWrite(conn, switching.getBytes(StandardCharsets.US_ASCII));
        }

        /** Serializes the HTTP/2 server preface; returns an empty array if that fails. */
        private byte[] buildServerPreface(HttpConnection conn) {
            try {
                ByteBuffer output = ByteBuffer.allocate(1024);
                conn.http2Handler.getConnection().writeServerPreface(output);
                output.flip();
                byte[] result = new byte[output.remaining()];
                output.get(result);
                return result;
            } catch (Exception e) {
                return new byte[0];
            }
        }

        /** Feeds bytes to the HTTP/2 handler and writes its output, keeps reading, or closes. */
        private void handleHttp2Read(HttpConnection conn, byte[] data) {
            try {
                byte[] response = conn.http2Handler.processData(data);
                if (response != null && response.length > 0) {
                    this.queueWrite(conn, response);
                } else if (conn.http2Handler.isOpen()) {
                    this.submitRead(conn);
                } else {
                    this.closeConnection(conn.fd);
                }
            } catch (Http2FrameReader.Http2Exception e) {
                this.closeConnection(conn.fd);
            }
        }

        /** Sends the headers of a streaming response; the body follows chunk by chunk in handleWrite. */
        private void startStreamingResponse(HttpConnection conn, HttpResponse response) {
            InputStream stream = response.getBodyStream();
            if (stream == null) {
                this.queueWrite(conn, response.encode());
                return;
            }
            conn.streamingBody = stream;
            // Without a known Content-Length the body is sent with chunked transfer encoding.
            conn.useChunkedEncoding = response.getContentLength() < 0L;
            conn.streamingBuffer = new byte[HttpConnection.STREAMING_BUFFER_SIZE];
            this.queueWrite(conn, response.encodeStreamingHeaders());
        }

        /**
         * Loads the next piece of a streaming body into {@code pendingWrite}.
         *
         * @return true if there is something to send (a data chunk or the chunked terminator);
         *     false when the body is finished or failed
         */
        private boolean continueStreamingResponse(HttpConnection conn) {
            if (conn.streamingBody == null) {
                return false;
            }
            int n;
            try {
                n = conn.streamingBody.read(conn.streamingBuffer);
            } catch (Exception e) {
                // The body is truncated, so its framing can no longer be trusted. Do not reuse the
                // connection for another request.
                conn.closeStreamingBody();
                conn.closeAfterWrite = true;
                return false;
            }
            if (n <= 0) {
                // End of body. A failure closing the stream must not suppress the chunked terminator.
                conn.closeStreamingBody();
                if (!conn.useChunkedEncoding) {
                    conn.pendingWrite = null;
                    return false;
                }
                conn.pendingWrite = "0\r\n\r\n".getBytes(StandardCharsets.US_ASCII);
                conn.writeOffset = 0;
                return true;
            }
            if (conn.useChunkedEncoding) {
                conn.pendingWrite = frameChunk(conn.streamingBuffer, n);
            } else {
                conn.pendingWrite = new byte[n];
                System.arraycopy(conn.streamingBuffer, 0, conn.pendingWrite, 0, n);
            }
            conn.writeOffset = 0;
            return true;
        }

        /**
         * Handles a send completion. It continues a partial write, then moves on to the next
         * stream chunk, pending HTTP/2 data, or the next request (or closes).
         */
        private void handleWrite(int fd, int result) {
            HttpConnection conn = this.connections.get(fd);
            if (conn == null) {
                return;
            }
            if (result < 0) {
                if (isRetryable(-result)) {
                    this.submitWrite(conn);
                } else {
                    this.closeConnection(fd);
                }
                return;
            }
            if (result == 0) {
                // No progress; nothing else would ever revisit this connection, so close it.
                this.closeConnection(fd);
                return;
            }
            conn.writeOffset += result;
            if (conn.writeOffset < conn.pendingWrite.length) {
                this.submitWrite(conn);
                return;
            }
            conn.pendingWrite = null;
            conn.writeOffset = 0;
            if (conn.streamingBody != null) {
                if (this.continueStreamingResponse(conn)) {
                    this.submitWrite(conn);
                } else {
                    this.readNextOrClose(conn);
                }
                return;
            }
            if (conn.isHttp2 && conn.http2Handler != null && conn.http2Handler.hasPendingStreaming()) {
                byte[] nextChunk = conn.http2Handler.continueStreaming();
                if (nextChunk != null && nextChunk.length > 0) {
                    this.queueWrite(conn, nextChunk);
                } else if (conn.http2Handler.isOpen()) {
                    this.submitRead(conn);
                } else {
                    this.closeConnection(fd);
                }
                return;
            }
            if (conn.isHttp2 && conn.pendingHttp2Data != null && conn.pendingHttp2Data.length > 0) {
                byte[] preface = conn.pendingHttp2Data;
                conn.pendingHttp2Data = null;
                this.queueWrite(conn, preface);
                return;
            }
            this.readNextOrClose(conn);
        }

        /** After a complete response: close if requested, otherwise wait for the next request. */
        private void readNextOrClose(HttpConnection conn) {
            if (conn.closeAfterWrite) {
                this.closeConnection(conn.fd);
            } else if (conn.keepAlive || conn.isHttp2) {
                this.submitRead(conn);
            } else {
                this.closeConnection(conn.fd);
            }
        }

        /** Close completed: forget the connection and return its buffers to the pool. */
        private void handleClose(int fd) {
            HttpConnection conn = this.connections.remove(fd);
            if (conn != null) {
                conn.close();
            }
        }

        /** Queues an asynchronous close; the connection is unregistered when it completes. */
        private void closeConnection(int fd) {
            if (this.connections.containsKey(fd)) {
                this.submitClose(fd);
            }
        }

        void stop() {
            this.wakeup();
        }

        /**
         * Releases every fd and native resource owned by this worker, including fds still queued
         * for adoption. Call only after the worker thread has exited (or was never started).
         */
        void close() {
            NewConnection queued;
            while ((queued = this.pendingConnections.poll()) != null) {
                closeSocketQuietly(queued.fd());
            }
            for (HttpConnection conn : this.connections.values()) {
                closeSocketQuietly(conn.fd);
                conn.close();
            }
            this.connections.clear();
            IoUring.closeEventFd(this.eventFd);
            this.ring.close();
            this.bufferPool.close();
            this.arena.close();
        }
    }

    /** Per-connection state; confined to the owning worker's thread while the server runs. */
    private static final class HttpConnection {
        static final int STREAMING_BUFFER_SIZE = 16384;
        final int fd;
        final long id;
        final BufferPool.PooledBuffer readPooledBuffer;
        final BufferPool.PooledBuffer writePooledBuffer;
        final MemorySegment readBuffer;
        final MemorySegment writeBuffer;
        final HttpParser parser;
        /** Bytes currently being sent and how many of them have been sent so far. */
        byte[] pendingWrite;
        int writeOffset;
        boolean keepAlive = true;
        boolean closeAfterWrite = false;
        boolean isHttp2 = false;
        Http2ServerHandler http2Handler;
        /** HTTP/2 server preface waiting for the 101 response to finish sending. */
        byte[] pendingHttp2Data;
        InputStream streamingBody;
        boolean useChunkedEncoding;
        byte[] streamingBuffer;

        HttpConnection(int fd, long id, BufferPool.PooledBuffer readBuf, BufferPool.PooledBuffer writeBuf) {
            this.fd = fd;
            this.id = id;
            this.readPooledBuffer = readBuf;
            this.writePooledBuffer = writeBuf;
            this.readBuffer = readBuf.segment();
            this.writeBuffer = writeBuf.segment();
            this.parser = new HttpParser();
        }

        /** Closes the streaming body (ignoring failures) and drops the streaming state. */
        void closeStreamingBody() {
            if (this.streamingBody != null) {
                try {
                    this.streamingBody.close();
                } catch (Exception ignored) {
                    // Nothing more can be done for this body; the response is being finished anyway.
                }
                this.streamingBody = null;
            }
            this.streamingBuffer = null;
        }

        /** Returns both pooled buffers and closes any in-progress streaming body. */
        void close() {
            this.readPooledBuffer.release();
            this.writePooledBuffer.release();
            this.closeStreamingBody();
        }
    }
}

