;;   Copyright (c) 7theta. All rights reserved.
;;   The use and distribution terms for this software are covered by the
;;   MIT License (https://opensource.org/licenses/MIT) which can also be
;;   found in the LICENSE file at the root of this distribution.
;;
;;   By using this software in any fashion, you are agreeing to be bound by
;;   the terms of this license.
;;   You must not remove this notice, or any others, from this software.

(ns vectio.netty.server
  (:require [vectio.netty.h2.server :as h2]
            [vectio.netty.h2.handlers.websocket :as ws]
            [vectio.netty.http1.server :as http1]
            [vectio.netty :as n]
            [spectator.log :as log]
            [utilis.map :refer [map-vals compact]]
            [fluxus.flow :as f]
            [clojure.string :as st])
  (:import [io.netty.bootstrap Bootstrap ServerBootstrap]
           [io.netty.channel
            ChannelDuplexHandler
            ChannelOption
            EventLoopGroup
            ChannelHandlerContext
            ChannelInitializer]
           [java.util.concurrent
            ThreadFactory
            Executors
            ExecutorService]
           [io.netty.handler.codec DecoderException]
           [io.netty.handler.codec.http.websocketx.extensions WebSocketExtensionData]
           [io.netty.handler.codec.http.websocketx.extensions.compression
            PerMessageDeflateServerExtensionHandshaker]
           [io.netty.channel.nio NioEventLoopGroup]
           [io.netty.channel.socket
            ServerSocketChannel
            SocketChannel
            DatagramChannel]
           [io.netty.channel.socket.nio
            NioServerSocketChannel
            NioDatagramChannel]
           [io.netty.handler.ssl
            SslContext
            ApplicationProtocolNegotiationHandler
            ApplicationProtocolNames
            NotSslRecordException]
           [java.net InetSocketAddress]))

;; Forward declarations: both helpers are defined in the ";;; Private"
;; section at the bottom of this file but are referenced earlier --
;; enumerating-thread-factory by the server starters and init-channel by
;; start-http-server.
(declare enumerating-thread-factory init-channel)

(defn start-udp-server
  "Binds a Netty UDP (datagram) server on `:socket-address` and returns a
  `java.io.Closeable` handle whose `toString` describes the bound channel.

  Options:
  - `:socket-address`      `InetSocketAddress` to bind.
  - `:init-channel`        `(fn [exec-service ^DatagramChannel ch])` invoked
                           from the channel's `ChannelInitializer` to set up
                           its pipeline; `exec-service` is the handler pool.
  - `:num-threads`         size of the fixed handler thread pool
                           (default 128).
  - `:leak-detector-level` Netty buffer leak-detector level
                           (default `:disabled`).
  - `:server-prefix`       label used by `toString`
                           (default \"VectioUDPServer\").

  The NIO event loop group is sized at 2 x available processors. If setup
  or binding fails, both thread pools are shut down and the exception is
  rethrown. `close` is idempotent; it blocks until the channel is closed
  and always releases both thread pools."
  [{:keys [socket-address
           leak-detector-level
           num-threads
           init-channel
           server-prefix]
    :or {leak-detector-level :disabled
         num-threads 128
         server-prefix "VectioUDPServer"}}]
  (when (not= leak-detector-level :disabled)
    (log/debug [::start-server :leak-detector-level leak-detector-level]))
  (n/leak-detector-level! leak-detector-level)
  (let [closed?        (atom false)
        transport      :nio
        ^Class channel NioDatagramChannel
        num-cores      (.availableProcessors (Runtime/getRuntime))
        ^EventLoopGroup group (NioEventLoopGroup.
                               (long (* 2 num-cores))
                               ^ThreadFactory (enumerating-thread-factory "vectio-netty-nio-event-loop-group-pool" false))
        ^ExecutorService handler-exec-service
        (Executors/newFixedThreadPool
         num-threads
         (enumerating-thread-factory "vectio-netty-handler-exec-service-pool" false))]
    (try
      (let [b (doto (Bootstrap.)
                (.group group)
                (.channel channel)
                (.handler (proxy [ChannelInitializer] []
                            (initChannel [^DatagramChannel ch]
                              (init-channel handler-exec-service ch)))))
            ^DatagramChannel ch (-> b
                                    (.bind ^InetSocketAddress socket-address)
                                    .sync
                                    .channel)]
        (reify
          java.io.Closeable
          (close [_]
            (when (compare-and-set! closed? false true)
              ;; Release the pools even if closing the channel fails or the
              ;; wait is interrupted: their threads are non-daemon and would
              ;; otherwise leak (and can keep the JVM alive).
              (try
                (-> ch .close .sync)
                (finally
                  (.shutdownGracefully group)
                  (.shutdown handler-exec-service)))))
          Object
          (toString [_]
            (format "%s[channel:%s, transport:%s]" server-prefix ch transport))))
      (catch Exception e
        ;; Setup failed: tear down the pools before propagating.
        @(.shutdownGracefully group)
        (.shutdown handler-exec-service)
        (throw e)))))

(defn start-tcp-server
  "Binds a Netty TCP server on `:socket-address` and returns a
  `java.io.Closeable` handle whose `toString` describes the bound channel.

  Options:
  - `:socket-address`      `InetSocketAddress` to bind.
  - `:init-channel`        `(fn [exec-service ^SocketChannel ch])` invoked
                           from the child `ChannelInitializer` for each
                           accepted connection; `exec-service` is the
                           handler pool.
  - `:num-threads`         size of the fixed handler thread pool
                           (default 128).
  - `:leak-detector-level` Netty buffer leak-detector level
                           (default `:disabled`).
  - `:server-prefix`       label used by `toString`
                           (default \"VectioTCPServer\").

  The server socket uses SO_BACKLOG 1024 and SO_REUSEADDR. Both server and
  child channels set MAX_MESSAGES_PER_READ to Integer/MAX_VALUE. The NIO
  event loop group is sized at 2 x available processors. If setup or
  binding fails, both thread pools are shut down and the exception is
  rethrown. `close` is idempotent; it blocks until the server channel is
  closed and always releases both thread pools."
  [{:keys [socket-address
           leak-detector-level
           num-threads
           init-channel
           server-prefix]
    :or {leak-detector-level :disabled
         num-threads 128
         server-prefix "VectioTCPServer"}}]
  (when (not= leak-detector-level :disabled)
    (log/debug [::start-server :leak-detector-level leak-detector-level]))
  (n/leak-detector-level! leak-detector-level)
  (let [closed?        (atom false)
        transport      :nio
        ^Class channel NioServerSocketChannel
        num-cores      (.availableProcessors (Runtime/getRuntime))
        ^EventLoopGroup group (NioEventLoopGroup.
                               (long (* 2 num-cores))
                               ^ThreadFactory (enumerating-thread-factory "vectio-netty-nio-event-loop-group-pool" false))
        ^ExecutorService handler-exec-service
        (Executors/newFixedThreadPool
         num-threads
         (enumerating-thread-factory "vectio-netty-handler-exec-service-pool" false))]
    (try
      (let [b (doto (ServerBootstrap.)
                (.option ChannelOption/SO_BACKLOG (int 1024))
                (.option ChannelOption/SO_REUSEADDR true)
                (.option ChannelOption/MAX_MESSAGES_PER_READ Integer/MAX_VALUE)
                (.group group)
                (.channel channel)
                (.childHandler (proxy [ChannelInitializer] []
                                 (initChannel [^SocketChannel ch]
                                   (init-channel handler-exec-service ch))))
                (.childOption ChannelOption/SO_REUSEADDR true)
                (.childOption ChannelOption/MAX_MESSAGES_PER_READ Integer/MAX_VALUE))
            ^ServerSocketChannel ch (-> b
                                        (.bind ^InetSocketAddress socket-address)
                                        .sync
                                        .channel)]
        (reify
          java.io.Closeable
          (close [_]
            (when (compare-and-set! closed? false true)
              ;; Release the pools even if closing the channel fails or the
              ;; wait is interrupted: their threads are non-daemon and would
              ;; otherwise leak (and can keep the JVM alive).
              (try
                (-> ch .close .sync)
                (finally
                  (.shutdownGracefully group)
                  (.shutdown handler-exec-service)))))
          Object
          (toString [_]
            (format "%s[channel:%s, transport:%s]" server-prefix ch transport))))
      (catch Exception e
        ;; Setup failed: tear down the pools before propagating.
        @(.shutdownGracefully group)
        (.shutdown handler-exec-service)
        (throw e)))))

(defn start-http-server
  "Starts an HTTP server on top of `start-tcp-server` and returns its
  `java.io.Closeable` handle.

  With `:ssl-context`, every connection gets a TLS handler followed by ALPN
  negotiation, which selects :h2 or :http1 according to `:protocols`
  (default #{:http1 :h2}). Without `:ssl-context`, every connection is
  served as plain HTTP/1.1. Because :h2 requires TLS, an ex-info is thrown
  when :h2 is among `:protocols` but no `:ssl-context` is given. That
  includes the default `:protocols`, so plaintext servers must pass
  `:protocols #{:http1}`.

  `:handler` is handed to the negotiated protocol pipeline. The tuning keys
  below, when present in the options, are forwarded as the HTTP/2 pipeline
  settings. `:num-threads` (default 128) and `:leak-detector-level`
  (default :disabled) are passed through to `start-tcp-server`."
  [{:keys [ssl-context socket-address handler leak-detector-level protocols
           num-threads]
    :or {leak-detector-level :disabled
         protocols #{:http1 :h2}
         num-threads 128}
    :as args}]
  (let [h2-requested? (get protocols :h2)]
    (when (and h2-requested? (not ssl-context))
      (throw (ex-info "An :h2 server requires an :ssl-context"
                      {:protocols protocols
                       :ssl-context ssl-context}))))
  (let [settings (select-keys args [:default-outbound-max-frame-size
                                    :max-flush-size
                                    :initial-window-size
                                    :websocket-max-frame-size
                                    :max-frame-size
                                    :max-concurrent-streams
                                    :max-header-list-size
                                    :push-enabled])
        setup-connection (fn [^ExecutorService exec-service ^SocketChannel ch]
                           (init-channel protocols ssl-context exec-service
                                         ch settings handler))]
    (start-tcp-server {:socket-address socket-address
                       :num-threads num-threads
                       :leak-detector-level leak-detector-level
                       :server-prefix "VectioHttpServer"
                       :init-channel setup-connection})))

(defn websocket-request?
  "True when `req` is a WebSocket upgrade request. Detection is delegated
  by `:protocol-version`: \"http/2\" requests go to the HTTP/2 WebSocket
  handler's check, \"http/1.1\" requests to the HTTP/1 server's check.
  Any other protocol version yields false."
  [req]
  (let [version (:protocol-version req)]
    (boolean
     (case version
       "http/2"   (ws/websocket-request? req)
       "http/1.1" (http1/websocket-request? req)
       false))))

(defn- parse-websocket-extensions
  "Parses a `sec-websocket-extensions` header value such as
  \"permessage-deflate; client_max_window_bits, x-foo\" into a map of
  extension name -> {parameter-name parameter-value}.

  Extension names, parameter names and values are trimmed. A value wrapped
  in double quotes is unquoted, and a parameter given without `=value` maps
  to nil. Blank offers and blank parameters are skipped. If the same
  extension is offered more than once, the last offer wins."
  [^String header]
  (->> (st/split header #",")
       (keep (fn [offer]
               (let [[ext-name & params] (map st/trim (st/split offer #";"))]
                 (when-not (st/blank? ext-name)
                   [ext-name
                    (into {}
                          (comp (remove st/blank?)
                                (map (fn [param]
                                       ;; Limit 2 so values containing '=' survive.
                                       (let [[k v] (map st/trim (st/split param #"=" 2))]
                                         [k (some-> v (st/replace #"^\"(.*)\"$" "$1"))]))))
                          params)]))))
       (into {})))

(defn- format-extension-header
  "Renders `ext-name` and its parameter map (a java.util.Map whose values may
  be nil) as a `sec-websocket-extensions` header value, e.g.
  \"permessage-deflate; server_no_context_takeover; client_max_window_bits=10\".
  Parameters with a nil value are emitted as the bare name, not `name=`."
  [ext-name ^java.util.Map parameters]
  (->> parameters
       (map (fn [[k v]] (if (some? v) (str k "=" v) (str k))))
       (cons ext-name)
       (st/join "; ")))

(defn websocket-accept-response
  "Builds the response that accepts the WebSocket upgrade in `req`. Returns
  nil when `req` is not a WebSocket request (see `websocket-request?`) or
  uses an unrecognised protocol version.

  - \"http/2\": a 200 response. If the client offered `permessage-deflate`
    and Netty's PerMessageDeflateServerExtensionHandshaker accepts the
    offer, the negotiated parameters are echoed in
    `sec-websocket-extensions`.
  - \"http/1.1\": a 101 response.

  `handlers` and the original request are attached as metadata
  (`:handlers`, `:initial-request`). The HTTP/2 branch also attaches
  `:compression-extension` (nil when not negotiated). Presumably the
  protocol pipeline reads this metadata; confirm against callers."
  [req handlers]
  (when (websocket-request? req)
    (condp = (:protocol-version req)
      "http/2"
      (let [extensions (some-> (get-in req [:headers "sec-websocket-extensions"])
                               not-empty
                               parse-websocket-extensions)
            ;; handshakeExtension may yield nil when it rejects the offered
            ;; parameters; compression is then simply not negotiated.
            compression-extension
            (when-let [parameters (get extensions "permessage-deflate")]
              (.handshakeExtension (PerMessageDeflateServerExtensionHandshaker.)
                                   (WebSocketExtensionData. "permessage-deflate" parameters)))]
        (with-meta
          (compact
           {:status 200
            :headers (when compression-extension
                       ;; `newReponseData` is Netty's own (misspelled) method name.
                       {"sec-websocket-extensions"
                        (format-extension-header
                         "permessage-deflate"
                         (.parameters (.newReponseData compression-extension)))})})
          {:compression-extension compression-extension
           :handlers handlers
           :initial-request req}))

      "http/1.1"
      (with-meta {:status 101}
        {:handlers handlers
         :initial-request req})

      nil)))

(defn websocket-stream-response
  "Accepts a WebSocket upgrade for `request` and exposes the connection as a
  fluxus stream.

  Returns {:stream message-stream :response response}, where `response` comes
  from `websocket-accept-response`. When the request is not an acceptable
  WebSocket upgrade, the internal end of the stream pair is closed and nil is
  returned.

  Incoming text and binary messages are put (blocking on the put) onto the
  internal end. Ping and pong frames are put as the keywords :ping and :pong.
  On open, everything consumed from the internal end is passed to `send`,
  and closing the internal end calls the socket's `close`. Presumably
  `f/entangled` pairs the two ends so that values written on one side appear
  on the other; verify against fluxus."
  [request]
  (let [[message-stream internal] (f/entangled)
        forward! (fn [message] @(f/put! internal message))
        handlers {:on-open (fn [{:keys [send close]}]
                             (f/consume send internal)
                             (f/on-close internal (fn [_] (close))))
                  :on-close #(f/close! internal)
                  :on-text-message forward!
                  :on-binary-message forward!
                  :on-ping-message (fn [& _] (forward! :ping))
                  :on-pong-message (fn [& _] (forward! :pong))}
        response (websocket-accept-response request handlers)]
    (if-not response
      (do (f/close! internal)
          nil)
      {:stream message-stream
       :response response})))


;;; Private

(defn apn-handler
  "Returns an ApplicationProtocolNegotiationHandler (falling back to
  HTTP/1.1) that configures the pipeline for the protocol chosen by ALPN.

  - h2 is accepted only when `protocols` contains :h2. The HTTP/2 pipeline
    is then built with `exec-service`, `settings` and `handler`.
  - http/1.1 is accepted only when `protocols` contains :http1. The HTTP/1
    pipeline is then built with `exec-service` and `handler`.
  - For anything else, the channel is closed and an IllegalStateException
    is thrown."
  ^ApplicationProtocolNegotiationHandler
  [protocols ^ExecutorService exec-service settings handler]
  (proxy [ApplicationProtocolNegotiationHandler] [ApplicationProtocolNames/HTTP_1_1]
    (configurePipeline [^ChannelHandlerContext ctx ^String protocol]
      (let [selected? (fn [^String alpn-name protocol-key]
                        (and (.equals alpn-name protocol)
                             (get protocols protocol-key)))]
        (cond
          (selected? ApplicationProtocolNames/HTTP_2 :h2)
          (do (log/debug ["Configuring new :h2 stream" settings])
              (h2/configure-http2-pipeline (.pipeline ctx) exec-service settings handler))

          (selected? ApplicationProtocolNames/HTTP_1_1 :http1)
          (do (log/debug ["Configuring new :http1 stream"])
              (http1/configure-http1-pipeline (.pipeline ctx) exec-service handler))

          :else
          (let [unsupported (IllegalStateException.
                             (str "Protocol: " protocol " not supported."))]
            (.close (.channel ctx))
            (throw unsupported)))))))

(defn enumerating-thread-factory
  "Returns a ThreadFactory that names its threads `<prefix>-1`,
  `<prefix>-2`, ... in creation order and marks each one daemon according
  to `daemon?`. Threads are created with no explicit ThreadGroup and a
  stack size of 0, and are returned unstarted."
  ^ThreadFactory [prefix daemon?]
  (let [counter (atom 0)]
    (reify ThreadFactory
      (newThread [_ task]
        (doto (Thread. nil
                       (fn [] (.run ^Runnable task))
                       (str prefix "-" (swap! counter inc))
                       0)
          (.setDaemon daemon?))))))

(defn ssl-exception-handler
  "Returns a Netty handler, meant to sit right after the SSL handler, that
  filters noisy TLS-level exceptions:

  - a DecoderException caused by NotSslRecordException (a non-TLS record
    arriving on the TLS port) closes the connection without further noise;
  - a SocketException whose message contains \"Connection reset\" is logged
    at debug level and otherwise ignored;
  - anything else is forwarded to the next handler via
    `fireExceptionCaught`."
  ^ChannelDuplexHandler []
  (proxy [ChannelDuplexHandler] []
    (exceptionCaught [^ChannelHandlerContext ctx ^Throwable cause]
      (cond
        (and (instance? DecoderException cause)
             (instance? NotSslRecordException (.getCause cause)))
        (.close ctx)

        ;; getMessage may be nil; some->> avoids an NPE from re-find that
        ;; would otherwise mask the original exception.
        (and (instance? java.net.SocketException cause)
             (some->> (.getMessage cause) (re-find #"Connection reset")))
        (log/debug cause)

        ;; Netty does not propagate exceptions thrown from exceptionCaught
        ;; (it only logs them), so forward explicitly to let downstream
        ;; handlers react.
        :else (.fireExceptionCaught ctx cause)))))

(defn- init-channel
  "Configures the pipeline of a freshly accepted connection `ch`.

  Without `ssl-context`, the HTTP/1 pipeline is installed directly. With
  `ssl-context`, three handlers are appended in order:
  \"ssl-handler\" (a TLS handler from `ssl-context`), then
  \"ssl-exception-handler\", then \"apn-handler\", which picks the HTTP/2 or
  HTTP/1 pipeline according to `protocols` and ALPN. `settings` is only
  consulted by the HTTP/2 path."
  [protocols
   ^SslContext ssl-context
   ^ExecutorService exec-service
   ^SocketChannel ch
   settings
   ^clojure.lang.IFn handler]
  (let [pipeline (.pipeline ch)]
    (if-not ssl-context
      (http1/configure-http1-pipeline pipeline exec-service handler)
      (doto pipeline
        (.addLast "ssl-handler" (.newHandler ssl-context (.alloc ch)))
        (.addLast "ssl-exception-handler" (ssl-exception-handler))
        (.addLast "apn-handler" (apn-handler protocols exec-service settings handler))))))
