;;   Copyright (c) 7theta. All rights reserved.
;;   The use and distribution terms for this software are covered by the
;;   MIT License (https://opensource.org/licenses/MIT) which can also be
;;   found in the LICENSE file at the root of this distribution.
;;
;;   By using this software in any fashion, you are agreeing to be bound by
;;   the terms of this license.
;;   You must not remove this notice, or any others, from this software.

(ns via.netty.http2.server
  (:require [utilis.map :refer [map-keys map-vals compact]]
            [byte-streams.core :as bs]
            [manifold.deferred :as d]
            [manifold.executor :as e]
            [manifold.stream :as s]
            [manifold.stream.core :as manifold]
            [clojure.string :as st]
            [com.brunobonacci.mulog :as u])
  (:import [java.io IOException]
           [io.netty.bootstrap Bootstrap ServerBootstrap]
           [io.netty.buffer ByteBuf Unpooled]
           [io.netty.channel
            ChannelPromise
            ChannelDuplexHandler
            Channel ChannelFuture ChannelOption
            ChannelPipeline EventLoopGroup
            DefaultChannelPipeline
            ChannelHandler FileRegion
            ChannelInboundHandler
            ChannelOutboundHandler
            ChannelHandlerContext
            ChannelInitializer]
           [java.util.concurrent
            ConcurrentHashMap
            CancellationException
            ScheduledFuture
            TimeUnit
            ThreadFactory
            Executors
            ExecutorService]
           [io.netty.channel.epoll Epoll EpollEventLoopGroup
            EpollServerSocketChannel
            EpollSocketChannel
            EpollDatagramChannel]
           [io.netty.util
            Attribute
            AttributeKey
            ReferenceCounted]
           [io.netty.handler.codec Headers]
           [io.netty.handler.codec.http
            HttpResponseStatus
            HttpRequest
            HttpServerCodec]
           [io.netty.handler.codec.http.websocketx
            TextWebSocketFrame
            WebSocketFrame
            CloseWebSocketFrame
            WebSocket13FrameDecoder
            WebSocket13FrameEncoder
            WebSocketDecoderConfig]
           [io.netty.handler.codec.http.websocketx.extensions
            WebSocketServerExtension
            WebSocketExtensionData
            WebSocketExtensionDecoder
            WebSocketExtensionEncoder]
           [io.netty.handler.codec.http.websocketx.extensions.compression
            PerMessageDeflateServerExtensionHandshaker]
           [io.netty.handler.codec.http2
            DefaultHttp2DataFrame
            DefaultHttp2Headers
            DefaultHttp2HeadersFrame
            EmptyHttp2Headers
            Http2FrameCodecBuilder
            Http2HeadersFrame
            Http2DataFrame
            Http2FrameStream
            Http2Headers
            Http2Settings]
           [io.netty.channel.nio NioEventLoopGroup]
           [io.netty.channel.socket
            ServerSocketChannel
            SocketChannel]
           [io.netty.channel.socket.nio
            NioServerSocketChannel
            NioSocketChannel
            NioDatagramChannel]
           [io.netty.handler.ssl
            SslContext
            SslContextBuilder
            SslHandler
            SslHandshakeCompletionEvent
            ApplicationProtocolNegotiationHandler
            ApplicationProtocolNames]
           [java.net URI SocketAddress InetSocketAddress]
           [java.nio ByteBuffer]))

(def ByteArray
  "The JVM class of a primitive byte array (`byte[]`); used as the target
  type when converting response bodies with `byte-streams`."
  (Class/forName "[B"))

(def SETTINGS_ENABLE_CONNECT_PROTOCOL
  "HTTP/2 setting identifier 0x8 (SETTINGS_ENABLE_CONNECT_PROTOCOL, RFC 8441),
  expressed as the `char` key type that `Http2Settings` is keyed by."
  (char 0x8))

(defn websocket-request?
  "Returns true when the request map `req` asks for a WebSocket: its
  \":protocol\" header is exactly \"websocket\" and its
  \"sec-websocket-version\" header is exactly \"13\". Returns false otherwise,
  including when `req` or its `:headers` are nil."
  [req]
  (let [hdrs (:headers req)
        protocol (get hdrs ":protocol")
        version (get hdrs "sec-websocket-version")]
    (and (= "websocket" protocol)
         (= "13" version))))

(defn add-all-headers
  "Adds every entry of `headers-map` to the Netty `headers` object and returns
  `headers`.

  Header names may be keywords (their `name` is used) or anything else
  (stringified); either way they are lower-cased, since HTTP/2 header names
  must be lower case. Values are stringified with `str`. A sequential value
  (vector, list, seq) is treated as Ring treats response header values: one
  header is added per element, rather than a single header holding the
  printed collection.

  `req` is accepted for call-site compatibility but is not used."
  [^Http2Headers headers req headers-map]
  (doseq [[k v] (map-keys #(st/lower-case
                            (str (if (keyword? %)
                                   (name %)
                                   %)))
                          headers-map)
          value (if (sequential? v) v [v])]
    (.add headers k (str value)))
  headers)

(defn headers
  "Returns the headers of an HTTP/2 HEADERS frame as a map of string name to
  string value, including pseudo-headers such as \":path\" and \":method\".

  When a name occurs more than once, the values are combined rather than the
  last one silently winning. \"cookie\" values are joined with \"; \", as
  RFC 7540 §8.1.2.5 requires for cookies that an HTTP/2 client split into
  separate fields. Any other repeated header is joined with \",\"."
  [^Http2HeadersFrame msg]
  (reduce (fn [acc [k v]]
            (let [header-name (str k)
                  header-value (str v)]
              (if (contains? acc header-name)
                (assoc acc header-name
                       (str (get acc header-name)
                            (if (= "cookie" header-name) "; " ",")
                            header-value))
                (assoc acc header-name header-value))))
          {}
          (iterator-seq (.iterator (.headers msg)))))

(defn- parse-query-params
  "Parses a raw query string such as \"a=1&b=2\" into a map of string keys to
  string values. Keys and values are NOT URL-decoded.

  Rules:
  - only the first `=` separates key from value;
  - a parameter with no `=` maps to nil;
  - empty segments (e.g. from \"a=1&&b=2\") are skipped;
  - for a repeated key, the last occurrence wins."
  [^String query-string]
  (->> (st/split query-string #"&")
       (remove empty?)
       (map (fn [param]
              (let [[k v] (st/split param #"=" 2)]
                [k v])))
       (into {})))

(defn request
  "Builds a Ring-style request map from an HTTP/2 HEADERS frame.

  Keys:
  - `:headers` — the map from `headers`;
  - `:path` — the raw \":path\" pseudo-header;
  - `:uri` — the path up to the first `?`;
  - `:query-string` — everything after the first `?` (nil when absent or empty);
  - `:query-params` — the parsed query string (nil without a query string);
  - `:scheme` and `:request-method` — keywords (the method lower-cased);
  - `:protocol-version` — always \"http/2\".

  Missing \":path\" or \":method\" pseudo-headers yield nil entries instead of
  throwing. Previously such an exception escaped on the event loop and the
  stream was never answered."
  [^Http2HeadersFrame msg]
  (let [headers (headers msg)
        path (get headers ":path")
        ;; Limit 2 keeps any later `?` inside the query string.
        [uri raw-query] (when path (st/split path #"\?" 2))
        query-string (not-empty raw-query)]
    {:headers headers
     :protocol-version "http/2"
     :uri uri
     :path path
     :scheme (keyword (get headers ":scheme"))
     :query-string query-string
     :request-method (some-> (get headers ":method")
                             st/lower-case
                             keyword)
     :query-params (when query-string
                     (parse-query-params query-string))}))

(defn handle-http2-headers-frame
  "Builds a request map from `headers-frame` and runs `handler` on it on
  `exec-service`.

  The handler's response is fired back into the pipeline as the user event
  `{:req req :res res :headers-frame headers-frame}`, which
  `http2-response-handler` writes out. If the handler throws, the failure is
  logged and a 500 response is fired instead.

  Only two kinds of HEADERS frame are dispatched:
  - frames with END_STREAM set;
  - frames whose \":protocol\" header equals \"websocket\", with the value
    compared case-insensitively.
  Other HEADERS frames are ignored.

  Returns nil."
  [^ChannelHandlerContext ctx ^Http2HeadersFrame headers-frame ^ExecutorService exec-service handler]
  (let [websocket? (.contains (.headers headers-frame) ":protocol" "websocket" true)]
    ;; NOTE(review): a non-WebSocket HEADERS frame without END_STREAM (a request
    ;; that has a body) is not dispatched here. Nothing in this namespace
    ;; collects the DATA frames that follow, so such requests appear never to
    ;; get a response. Confirm whether request bodies are meant to be supported.
    (when (or websocket? (.isEndStream headers-frame))
      (let [req (request headers-frame)]
        (.submit exec-service
                 (reify Runnable
                   (run [_]
                     (let [res (try
                                 (handler req)
                                 ;; Catch Throwable, not Exception. An Error (e.g. the
                                 ;; AssertionError from a failed `assert` or `:pre`)
                                 ;; would otherwise vanish into the executor's Future.
                                 ;; No event would fire and the stream would hang.
                                 (catch Throwable e
                                   (u/log ::http2-headers-frame :exception e)
                                   {:status 500
                                    :body "Internal Error Occurred"
                                    :headers {}}))]
                       (.fireUserEventTriggered
                        ctx {:req req
                             :res res
                             :headers-frame headers-frame})))))
        nil))))

(defn safe-remove
  "Removes the handler registered as `handler-name` from `pipeline`, but only
  if one is present.

  Returns the removed handler, or nil when there was none. Checking first
  avoids the NoSuchElementException that `ChannelPipeline.remove(String)`
  throws for an unknown name."
  [^ChannelPipeline pipeline ^String handler-name]
  (when (some? (.get pipeline handler-name))
    (.remove pipeline handler-name)))

(defn configure-websocket-inbound-handlers
  "Installs the inbound WebSocket chain at the end of `pipeline`. Any
  previously installed handlers with the same names are removed first.

  The chain has three handlers, in order:
  - \"frame-decoder\": unwraps HTTP/2 DATA frames and feeds their bytes to
    `frame-decoder`, which yields WebSocket frames. It also returns the
    consumed bytes to the HTTP/2 stream flow controller (see below).
  - \"frame-compression-decoder\": passes WebSocket frames through
    `compression-decoder`, or straight on when it is nil (no compression
    extension negotiated).
  - \"frame-handler\": on `exec-service`, calls `on-text-message` with the text
    of each text frame, or `on-close` (no args) for a close frame. Exceptions
    from these callbacks are logged. Every frame is released afterwards.

  Messages of unexpected types are released and dropped."
  [^ChannelPipeline pipeline
   ^WebSocket13FrameDecoder frame-decoder
   ^WebSocketExtensionDecoder compression-decoder
   {:keys [on-text-message on-close]}
   ^ExecutorService exec-service]
  (doto pipeline
    (safe-remove "frame-decoder")
    (safe-remove "frame-compression-decoder")
    (safe-remove "frame-handler"))
  ;; NOTE(review): these handlers sit on the connection pipeline, so DATA
  ;; frames from any other stream on the same connection also appear to reach
  ;; the WebSocket decoder. Ping frames are not answered with a pong, either.
  ;; Confirm both are acceptable.
  (doto pipeline
    (.addLast "frame-decoder"
              (proxy [ChannelDuplexHandler] []
                (channelRead [^ChannelHandlerContext ctx msg]
                  (cond
                    (instance? Http2DataFrame msg)
                    (let [^Http2DataFrame data-frame msg
                          consumed (.initialFlowControlledBytes data-frame)]
                      ;; Http2FrameCodec does not consume stream-level DATA bytes
                      ;; by itself. They must be returned with a WINDOW_UPDATE
                      ;; frame, or the peer stalls once the initial window is used.
                      (when (pos? consumed)
                        (.writeAndFlush
                         ctx
                         (.stream (io.netty.handler.codec.http2.DefaultHttp2WindowUpdateFrame.
                                   (int consumed))
                                  (.stream data-frame))))
                      (.channelRead frame-decoder ctx (.content data-frame)))

                    (instance? ReferenceCounted msg)
                    (.release ^ReferenceCounted msg)))))
    (.addLast "frame-compression-decoder"
              (proxy [ChannelDuplexHandler] []
                (channelRead [^ChannelHandlerContext ctx msg]
                  (cond
                    ;; Hand over every WebSocket frame, not just text frames.
                    ;; The extension decoder passes frames it does not accept
                    ;; (e.g. close frames) on unchanged.
                    (and compression-decoder (instance? WebSocketFrame msg))
                    (.channelRead compression-decoder ctx msg)

                    (instance? WebSocketFrame msg)
                    (.fireChannelRead ctx msg)

                    (instance? ReferenceCounted msg)
                    (.release ^ReferenceCounted msg)))))
    (.addLast "frame-handler"
              (proxy [ChannelDuplexHandler] []
                (channelRead [^ChannelHandlerContext ctx msg]
                  (let [handler (cond
                                  (instance? TextWebSocketFrame msg)
                                  (when on-text-message
                                    ;; Read the text now. `msg` is released in
                                    ;; the finally below, before the task runs.
                                    (let [text (.text ^TextWebSocketFrame msg)]
                                      #(on-text-message text)))

                                  (instance? CloseWebSocketFrame msg)
                                  on-close

                                  :else nil)]
                    (try (when handler
                           (.submit exec-service
                                    (reify Runnable
                                      (run [_]
                                        (try
                                          (handler)
                                          (catch Exception e
                                            (u/log ::websocket-inbound-handler :exception e)))))))
                         (finally
                           (when (instance? ReferenceCounted msg)
                             (.release ^ReferenceCounted msg))))))))))

(defn configure-websocket-outbound-handlers
  "Installs the outbound WebSocket chain at the end of `pipeline`. Any
  previously installed handlers with the same names are removed first.

  Outbound writes travel from the tail towards the head, so a message written
  to the channel passes through these handlers in this order:
  - \"frame-compression-encoder\": runs WebSocket frames through
    `compression-encoder`, or passes them on when it is nil.
  - \"frame-encoder\": encodes WebSocket frames to bytes with `frame-encoder`.
  - \"frame-writer\": wraps each encoded ByteBuf in an HTTP/2 DATA frame
    (without END_STREAM) on `stream` and flushes it.

  The caller's promise is passed along at every step, so write futures
  complete. Messages a handler does not recognise are passed on unchanged
  rather than dropped."
  [^ChannelPipeline pipeline
   ^WebSocket13FrameEncoder frame-encoder
   ^WebSocketExtensionEncoder compression-encoder
   ^Http2FrameStream stream]
  (doto pipeline
    (safe-remove "frame-writer")
    (safe-remove "frame-encoder")
    (safe-remove "frame-compression-encoder"))
  (doto pipeline
    (.addLast "frame-writer"
              (proxy [ChannelDuplexHandler] []
                (write [^ChannelHandlerContext ctx msg ^ChannelPromise p]
                  (if (instance? ByteBuf msg)
                    (.writeAndFlush ctx
                                    (.stream (DefaultHttp2DataFrame. ^ByteBuf msg false) stream)
                                    p)
                    (.write ctx msg p)))))
    (.addLast "frame-encoder"
              (proxy [ChannelDuplexHandler] []
                (write [^ChannelHandlerContext ctx msg ^ChannelPromise p]
                  (if (instance? WebSocketFrame msg)
                    (.write frame-encoder ctx msg p)
                    (.write ctx msg p)))))
    (.addLast "frame-compression-encoder"
              (proxy [ChannelDuplexHandler] []
                (write [^ChannelHandlerContext ctx msg ^ChannelPromise p]
                  (if (and compression-encoder (instance? WebSocketFrame msg))
                    ;; The extension encoder passes frames it does not
                    ;; compress (e.g. control frames) on unchanged.
                    (.write compression-encoder ctx msg p)
                    (.write ctx msg p)))))))

(defn configure-websocket
  "Accepts a WebSocket request carried on the HTTP/2 stream of `msg`.

  Steps:
  1. Writes a 200 HEADERS frame (without END_STREAM) containing `res`'s
     `:headers`.
  2. Installs the inbound and outbound WebSocket handler chains on the pipeline.
  3. Calls the optional `:on-open` callback with a map of `:send` (writes a text
     frame) and `:close` (closes the channel).

  Callbacks and an optional `:compression-extension` (a negotiated server
  extension) are read from `res`'s metadata under `:handlers` and
  `:compression-extension`. Exceptions are logged, not rethrown. Returns nil."
  [^ChannelHandlerContext ctx ^Http2HeadersFrame msg req {:keys [headers] :as res} ^ExecutorService exec-service]
  (try
    (let [pipeline (.pipeline ctx)]
      ;; NOTE(review): nothing in this namespace adds a handler named
      ;; "websocket-handler", so this guard currently always passes. A second
      ;; WebSocket on the same connection replaces the first one's handlers.
      ;; Confirm the intended behavior.
      (when-not (.get pipeline "websocket-handler")
        (let [decoder-config (-> (WebSocketDecoderConfig/newBuilder)
                                 (.allowExtensions true)
                                 (.build))
              {:keys [compression-extension handlers]} (meta res)
              compression-encoder (when compression-extension
                                    (.newExtensionEncoder compression-extension))
              compression-decoder (when compression-extension
                                    (.newExtensionDecoder compression-extension))
              response-headers (doto (DefaultHttp2Headers.)
                                 (.status (str 200))
                                 (add-all-headers req (or (not-empty headers) {})))
              frame-decoder (WebSocket13FrameDecoder. decoder-config)
              frame-encoder (WebSocket13FrameEncoder. false)
              http2-stream (.stream msg)
              {:keys [on-open]} handlers
              channel (.channel ctx)]
          ;; END_STREAM is not set, so the stream stays open to carry
          ;; WebSocket frames as DATA.
          (.write ctx (.stream (DefaultHttp2HeadersFrame. response-headers) http2-stream))
          (.flush ctx)
          (doto pipeline
            (configure-websocket-inbound-handlers frame-decoder compression-decoder handlers exec-service)
            (configure-websocket-outbound-handlers frame-encoder compression-encoder http2-stream))
          ;; `on-open` is optional. Calling a nil callback used to throw after
          ;; the pipeline had already been reconfigured.
          (when on-open
            (on-open {:send #(.writeAndFlush channel (TextWebSocketFrame. %))
                      ;; NOTE(review): this closes the channel that owns the
                      ;; HTTP/2 codec, which appears to drop every stream on
                      ;; the connection, not just this WebSocket. Confirm.
                      :close (fn [] (.close channel))}))
          nil)))
    (catch Exception e
      (u/log ::websocket-configuration :exception e))))

(defn ^ChannelHandler http2-response-handler
  "Returns a handler that writes responses for the user events fired by
  `handle-http2-headers-frame`, i.e. maps with `:req`, `:res` and
  `:headers-frame`.

  A WebSocket request answered with status 200 is handed to
  `configure-websocket`. Any other response is written as follows:
  - a HEADERS frame with `:status` and `:headers`;
  - if there is a `:body`, a single DATA frame with END_STREAM. The body is
    converted to bytes with `byte-streams`, and closed afterwards if Closeable.
  Without a body, END_STREAM is set on the HEADERS frame itself, so the
  stream is always finished.

  Any other user event is forwarded down the pipeline. Write failures are
  logged. `handler` is accepted for call-site compatibility but is not used."
  [handler ^ExecutorService exec-service]
  (proxy [ChannelDuplexHandler] []
    (userEventTriggered [^ChannelHandlerContext ctx event]
      (if-not (and (map? event) (:req event) (:res event) (:headers-frame event))
        ;; Not one of our responses: let later handlers see it.
        (.fireUserEventTriggered ctx event)
        (try (let [{:keys [res req headers-frame]} event
                   ^Http2HeadersFrame headers-frame headers-frame
                   {:keys [status body headers]} res]
               (if (and (websocket-request? req)
                        (= 200 status))
                 (configure-websocket ctx headers-frame req res exec-service)
                 (let [stream (.stream headers-frame)
                       ;; Convert before writing or allocating anything, so a
                       ;; failure leaves no leaked buffer. `finally` still
                       ;; closes the body source when conversion throws.
                       ^bytes body-bytes (when body
                                           (try
                                             (bs/convert body ByteArray)
                                             (finally
                                               (when (instance? java.io.Closeable body)
                                                 (.close ^java.io.Closeable body)))))
                       response-headers (doto (DefaultHttp2Headers.)
                                          (.status (str status))
                                          (add-all-headers req headers))]
                   ;; Without a body, END_STREAM must be on the HEADERS frame,
                   ;; or the client waits forever.
                   (.write ctx (.stream (DefaultHttp2HeadersFrame. response-headers (nil? body-bytes))
                                        stream))
                   (when body-bytes
                     (.write ctx (.stream (DefaultHttp2DataFrame. (Unpooled/wrappedBuffer body-bytes) true)
                                          stream)))
                   (.flush ctx))))
             (catch Exception e
               (u/log ::http2-response :exception e)))))))

(defn ^ChannelHandler http2-request-handler
  "Returns an inbound handler that passes each HTTP/2 HEADERS frame to
  `handle-http2-headers-frame`, which runs `handler` on `exec-service`. Every
  other message is forwarded unchanged to the next handler."
  [^ExecutorService exec-service handler]
  (proxy [ChannelDuplexHandler] []
    (channelRead [^ChannelHandlerContext ctx msg]
      (if-not (instance? Http2HeadersFrame msg)
        ;; Same thing the inherited channelRead does.
        (.fireChannelRead ctx msg)
        (handle-http2-headers-frame ctx msg exec-service handler)))))

(defn configure-http2-request-handler
  "Appends the request handler (\"request-handler\") and then the response
  handler (\"response-handler\") to `pipeline`. Returns `pipeline`."
  [^ChannelPipeline pipeline ^ExecutorService exec-service handler]
  (let [request-stage (http2-request-handler exec-service handler)
        response-stage (http2-response-handler handler exec-service)]
    (.addLast pipeline "request-handler" request-stage)
    (.addLast pipeline "response-handler" response-stage)
    pipeline))

(defn configure-http2-frame-codec-builder
  "Appends a server-side HTTP/2 frame codec to `pipeline` as \"frame-codec\"
  and returns `pipeline`.

  On top of Netty's default settings, the codec advertises
  SETTINGS_ENABLE_CONNECT_PROTOCOL = 1, the RFC 8441 setting that allows
  extended CONNECT requests carrying a \":protocol\" header."
  [^ChannelPipeline pipeline]
  (let [settings (doto (Http2Settings/defaultSettings)
                   (.put SETTINGS_ENABLE_CONNECT_PROTOCOL (long 1)))
        codec (.build (.initialSettings (Http2FrameCodecBuilder/forServer) settings))]
    (.addLast pipeline "frame-codec" codec)
    pipeline))

(defn configure-http2-pipeline
  "Installs the complete HTTP/2 server chain on `pipeline` and returns
  `pipeline`: first the frame codec, then the request and response handlers.
  Requests are dispatched to `handler` on `exec-service`."
  [^ChannelPipeline pipeline ^ExecutorService exec-service handler]
  (configure-http2-frame-codec-builder pipeline)
  (configure-http2-request-handler pipeline exec-service handler)
  pipeline)
