;;   Copyright (c) 7theta. All rights reserved.
;;   The use and distribution terms for this software are covered by the
;;   MIT License (https://opensource.org/licenses/MIT) which can also be
;;   found in the LICENSE file at the root of this distribution.
;;
;;   By using this software in any fashion, you are agreeing to be bound by
;;   the terms of this license.
;;   You must not remove this notice, or any others, from this software.

(ns via.netty.http2.server
  (:require [utilis.map :refer [map-keys]]
            [utilis.types.number :refer [string->long]]
            [byte-streams.core :as bs]
            [clojure.string :as st]
            [clojure.java.io :as io])
  (:import [org.apache.commons.io IOUtils]
           [java.io
            PipedOutputStream
            PipedInputStream
            FileInputStream
            InputStream
            File]
           [io.netty.buffer ByteBuf]
           [io.netty.channel
            DefaultChannelPromise
            ChannelPromise
            ChannelDuplexHandler
            ChannelPipeline
            ChannelHandlerContext
            ChannelOutboundHandler
            ChannelInboundHandler]
           [java.util.concurrent ExecutorService]
           [io.netty.util ReferenceCounted]
           [io.netty.handler.codec.http.websocketx
            TextWebSocketFrame
            CloseWebSocketFrame
            WebSocket13FrameDecoder
            WebSocket13FrameEncoder
            WebSocketDecoderConfig]
           [io.netty.handler.codec.http2
            Http2Error
            DefaultHttp2DataFrame
            DefaultHttp2Headers
            DefaultHttp2HeadersFrame
            DefaultHttp2ResetFrame
            Http2StreamFrame
            DefaultHttp2WindowUpdateFrame
            Http2ResetFrame
            Http2FrameCodecBuilder
            Http2HeadersFrame
            Http2DataFrame
            Http2Headers
            Http2Settings]))

;; Class object for primitive byte arrays (byte[]); used as the byte-streams
;; conversion target when a response body is not a File/InputStream/String.
(def ByteArray (Class/forName "[B"))

;; HTTP/2 SETTINGS identifier 0x8, SETTINGS_ENABLE_CONNECT_PROTOCOL (RFC 8441).
;; Netty's Http2Settings is keyed by `char`, hence the char value.
(def SETTINGS_ENABLE_CONNECT_PROTOCOL (char 0x8))

;; 2^16 - 1 = 65535. Used for the advertised initial window / max frame size
;; and as the buffer size of the request-body pipe.
(def DEFAULT_INITIAL_WINDOW_SIZE (dec (bit-shift-left 1 16)))
(def DEFAULT_MAX_FRAME_SIZE (dec (bit-shift-left 1 16)))

;; User event fired through the pipeline once a handler response is ready.
;; `headers` holds the HEADERS frame, `body` the DATA frame (or nil) and
;; `sent` an atom guarding against writing the same response twice.
(defrecord ResponseEvent [headers stream body status sent])

(defn with-exec
  "Submits `(apply handler args)` to `exec-service` as a Runnable and returns
  the resulting Future (whose value is nil)."
  [^ExecutorService exec-service handler & args]
  (let [task (reify Runnable
               (run [_]
                 (apply handler args)))]
    (.submit exec-service ^Runnable task)))

(defn write-handler
  "Invokes the `write` method of the pipeline handler named `handler-name`
  directly with `msg` and a fresh promise for the pipeline's channel.
  Does nothing (returns nil) when no handler has that name."
  [^ChannelPipeline pipeline ^String handler-name msg]
  (let [promise (DefaultChannelPromise. (.channel pipeline))
        writer (.get pipeline handler-name)]
    (when writer
      (let [writer-ctx (.context pipeline handler-name)]
        (.write ^ChannelOutboundHandler writer writer-ctx msg promise)))))

(defn read-handler
  "Invokes the `channelRead` method of the pipeline handler named
  `handler-name` directly with `msg`. Does nothing when the handler is absent."
  [^ChannelPipeline pipeline ^String handler-name msg]
  (let [reader (.get pipeline handler-name)]
    (when reader
      (let [reader-ctx (.context pipeline handler-name)]
        (.channelRead ^ChannelInboundHandler reader reader-ctx msg)))))

(defn websocket-request?
  "True when the argument asks for a websocket.
  For a raw Http2HeadersFrame only the `:protocol` pseudo-header is checked
  (case-insensitively). For a request map, `:protocol` must be \"websocket\"
  and `sec-websocket-version` must be \"13\"."
  [req-or-frame]
  (if (instance? Http2HeadersFrame req-or-frame)
    (boolean (.contains (.headers ^Http2HeadersFrame req-or-frame) ":protocol" "websocket" true))
    (let [request-headers (:headers req-or-frame)]
      (and (= "websocket" (get request-headers ":protocol"))
           (= "13" (get request-headers "sec-websocket-version"))))))

(defn add-all-headers
  "Adds every entry of `headers-map` to `headers` and returns `headers`.
  Keys (keywords use their name) are stringified and lower-cased; values are
  stringified with `str`."
  [^Http2Headers headers headers-map]
  (let [normalize-key (fn [k]
                        (-> (if (keyword? k) (name k) k)
                            str
                            st/lower-case))]
    (doseq [[header-name header-value] (map-keys normalize-key headers-map)]
      (.add headers header-name (str header-value))))
  headers)

(defn safe-add-handler
  "Appends `handler` under `handler-name`, or swaps it in place when a handler
  with that name is already in the pipeline."
  [^ChannelPipeline pipeline handler-name handler]
  (if-not (.get pipeline handler-name)
    (.addLast pipeline handler-name handler)
    (.replace pipeline handler-name handler-name handler)))

(defn headers
  "Returns the frame's headers as a map of string name to string value.
  When a name repeats, the last value wins."
  [^Http2HeadersFrame msg]
  (into {}
        (map (fn [entry]
               [(str (key entry)) (str (val entry))]))
        (iterator-seq (.iterator (.headers msg)))))

(defn request
  "Builds a ring-style request map from an inbound HEADERS frame.

  When the request is not a websocket request and is a PUT/POST, or declares
  a positive content-length, the map also carries a connected
  `:body-input-stream` / `:body-output-stream` pipe that DATA frames are
  written into."
  [^Http2HeadersFrame msg]
  (let [headers (headers msg)
        path (get headers ":path")
        ;; Split on the first "?" only, so a query containing "?" stays whole;
        ;; an empty query ("/a?") is reported as nil.
        [uri query-string] (when path
                             (let [[uri query] (st/split path #"\?" 2)]
                               [uri (not-empty query)]))
        ;; content-length is client-supplied: a malformed value is treated as absent.
        content-length (let [raw (get headers "content-length")]
                         (when (string? raw)
                           (try (string->long raw)
                                (catch Exception _ nil))))
        request-method (some-> headers
                               (get ":method")
                               st/lower-case
                               keyword)]
    (merge (when (and (not (websocket-request? msg))
                      (or (#{:put :post} request-method)
                          (and content-length
                               (pos? content-length))))
             (let [in (PipedInputStream. DEFAULT_MAX_FRAME_SIZE)]
               {:body-input-stream in
                :body-output-stream (PipedOutputStream. in)}))
           {:headers headers
            :protocol-version "http/2"
            :uri uri
            :path path
            :scheme (keyword (get headers ":scheme"))
            :query-string query-string
            :request-method request-method
            ;; "a=1&b" -> {"a" "1", "b" nil}; "a=b=c" -> {"a" "b=c"}; empty
            ;; segments (from "&&") are skipped.
            :query-params (when query-string
                            (->> (st/split query-string #"&")
                                 (remove st/blank?)
                                 (map (fn [pair]
                                        (let [[k v] (st/split pair #"=" 2)]
                                          [k v])))
                                 (into {})))})))

(defn safe-remove
  "Removes the handler named `handler-name` if the pipeline contains it.
  Returns the removed handler, or nil when the name is absent."
  [^ChannelPipeline pipeline ^String handler-name]
  (let [present? (.contains (.names pipeline) handler-name)]
    (when present?
      (.remove pipeline handler-name))))

(defn configure-http2-frame-codec-builder
  "Installs a server-side Http2FrameCodec as \"frame-codec\".
  The codec auto-acks PINGs and advertises extended CONNECT, the default
  initial window size and the default max frame size."
  [^ChannelPipeline pipeline]
  (let [settings (doto (Http2Settings/defaultSettings)
                   (.put SETTINGS_ENABLE_CONNECT_PROTOCOL (long 1))
                   (.initialWindowSize (long DEFAULT_INITIAL_WINDOW_SIZE))
                   (.maxFrameSize DEFAULT_MAX_FRAME_SIZE))
        builder (.autoAckPingFrame (Http2FrameCodecBuilder/forServer) true)
        codec (.build (.initialSettings builder settings))]
    (safe-add-handler pipeline "frame-codec" codec)))

(defn websocket-response
  "Turns a handler response into a websocket upgrade response.

  The result is the response's metadata, then the response itself, then a
  forced 200 status, non-nil headers and fresh WebSocket13 frame
  decoder/encoder instances. When the metadata carries a
  `:compression-extension`, it also adds encoder/decoder instances built from
  that extension."
  [{:keys [headers] :as response}]
  (let [extras (meta response)
        {:keys [compression-extension]} extras
        decoder-config (.build (.allowExtensions (WebSocketDecoderConfig/newBuilder) true))]
    (merge extras
           response
           {:status 200
            :headers (or (not-empty headers) {})
            :frame-decoder (WebSocket13FrameDecoder. decoder-config)
            :frame-encoder (WebSocket13FrameEncoder. false)}
           (when compression-extension
             {:compression-encoder (.newExtensionEncoder compression-extension)
              :compression-decoder (.newExtensionDecoder compression-extension)}))))

;; Sets up a websocket on HTTP/2 stream `stream` (extended CONNECT).
;;
;; - Adds one set of per-stream handlers to the pipeline. Each handler name is
;;   suffixed with "-<stream-id>", and entries whose handler is nil (e.g. no
;;   compression) are skipped.
;; - Merges `websocket-response` plus :handle-frame / :on-reset callbacks into
;;   the stream's entry in `streams`.
;; - Calls the user's :on-open handler with :send and :close functions.
;;
;; `handlers` (from the response) may hold :on-open, :on-close and
;; :on-text-message. Returns nil.
;;
;; NOTE(review): only text and close frames are handled. Binary/ping/pong
;; frames reaching "message-handler" are released without further action.
(defn init-websocket!
  [^ChannelPipeline pipeline streams ^ExecutorService exec-service stream
   {:keys [handlers
           frame-encoder
           frame-decoder
           compression-encoder
           compression-decoder]
    :as websocket-response}]
  (let [stream-id (.id stream)
        ;; Executor of the pipeline's first handler context. Outbound writes are
        ;; scheduled here rather than run on the caller's thread (presumably the
        ;; channel's event loop; confirm no handler uses a separate group).
        executor (.executor (.firstContext pipeline))
        run-event (fn [f]
                    (.execute executor
                              (reify Runnable
                                (run [_]
                                  (f)))))
        ;; Sends RST_STREAM(STREAM_CLOSED) for this stream via "frame-writer".
        close-stream (fn []
                       (->> (-> Http2Error/STREAM_CLOSED
                                (DefaultHttp2ResetFrame.)
                                (.stream stream))
                            (write-handler pipeline "frame-writer")))
        {:keys [on-open on-close on-text-message]} handlers
        ;; Filled in below once the handler names are known.
        cleanup (atom nil)
        ;; Ordered [name handler] pairs, added with addLast in this order.
        ;; Inbound: data-frame-content -> frame-decoder -> compression-decoder
        ;; -> message-handler. Outbound from the text writer travels toward the
        ;; head: compression-encoder -> frame-encoder -> websocket-data-frame-writer.
        ws-handlers (->> [;; Wraps encoded websocket bytes in non-final DATA frames for
                          ;; this stream and hands them to "frame-writer" directly.
                          ;; NOTE(review): the incoming promise `p` is never completed.
                          ["websocket-data-frame-writer" (proxy [ChannelDuplexHandler] []
                                                           (write [^ChannelHandlerContext ctx msg ^ChannelPromise p]
                                                             (if (instance? ByteBuf msg)
                                                               (->> (-> msg
                                                                        (DefaultHttp2DataFrame. false)
                                                                        (.stream stream))
                                                                    (write-handler pipeline "frame-writer"))
                                                               (write-handler pipeline "frame-writer" msg))))]
                          ;; Unwraps DATA frames into their ByteBuf content for the
                          ;; websocket frame decoder. Must stay second: :handle-frame
                          ;; below looks it up with (second ws-handlers).
                          ["data-frame-content" (proxy [ChannelDuplexHandler] []
                                                  (channelRead [^ChannelHandlerContext ctx msg]
                                                    (when (instance? Http2DataFrame msg)
                                                      (let [msg ^Http2DataFrame msg
                                                            ^ByteBuf content (.content msg)]
                                                        (.fireChannelRead ctx content)))))]
                          ["frame-decoder" frame-decoder]
                          ["compression-decoder" compression-decoder]
                          ["frame-encoder" frame-encoder]
                          ["compression-encoder" compression-encoder]
                          ;; Dispatches decoded frames to the user callbacks on
                          ;; exec-service, then always releases the frame.
                          ["message-handler" (proxy [ChannelDuplexHandler] []
                                               (channelRead [^ChannelHandlerContext ctx msg]
                                                 (try (cond
                                                        (instance? TextWebSocketFrame msg)
                                                        (let [text (.text ^TextWebSocketFrame msg)]
                                                          (when on-text-message
                                                            (with-exec exec-service
                                                              #(on-text-message text))))

                                                        (instance? CloseWebSocketFrame msg)
                                                        (with-exec exec-service
                                                          #(do (@cleanup)
                                                               (when on-close (on-close)))))
                                                      (finally
                                                        (when (instance? ReferenceCounted msg)
                                                          (.release ^ReferenceCounted msg))))))]
                          ;; Entry point for outbound text: writes TextWebSocketFrames
                          ;; back through the encoders. Must stay last: :send below uses
                          ;; (last ws-handlers).
                          ["websocket-text-frame-writer" (proxy [ChannelDuplexHandler] []
                                                           (write [^ChannelHandlerContext ctx msg ^ChannelPromise p]
                                                             (when (instance? TextWebSocketFrame msg)
                                                               (let [^TextWebSocketFrame msg msg]
                                                                 (.writeAndFlush ctx msg)))))]]
                         (filter second)
                         (mapv (fn [[handler-name handler]]
                                 [(str handler-name "-" stream-id) handler])))]
    ;; Forget the stream and remove every per-stream handler added below.
    (reset! cleanup
            #(do (swap! streams dissoc (.id stream))
                 (doseq [[handler-name _] ws-handlers]
                   (safe-remove pipeline handler-name))))
    ;; :handle-frame is called by the inbound frame handler for each stream
    ;; frame; :on-reset is called when the peer resets the stream.
    (swap! streams update (.id stream) merge
           websocket-response
           {:handle-frame (fn [^Http2DataFrame frame]
                            (read-handler pipeline (first (second ws-handlers)) frame))
            :on-reset (fn [] (@cleanup))})
    (doseq [[handler-name handler] ws-handlers]
      (safe-add-handler pipeline handler-name handler))
    ;; :send schedules a text frame write; :close cleans up immediately and
    ;; schedules the stream reset.
    (when (fn? on-open)
      (let [text-frame-writer-name (first (last ws-handlers))
            write-text-frame #(write-handler pipeline text-frame-writer-name (TextWebSocketFrame. %))]
        (on-open {:send #(run-event (fn [] (write-text-frame %)))
                  :close (fn []
                           (@cleanup)
                           (run-event #(close-stream)))})))

    nil))

(defn close-stream
  "Schedules an RST_STREAM(STREAM_CLOSED) for `stream` on the executor of the
  pipeline's first handler context. The reset frame is built and written from
  that executor via the \"frame-writer\" handler."
  [pipeline stream]
  (let [event-loop (.executor (.firstContext pipeline))
        send-reset! (fn []
                      (write-handler pipeline "frame-writer"
                                     (.stream (DefaultHttp2ResetFrame. Http2Error/STREAM_CLOSED)
                                              stream)))]
    (.execute event-loop
              (reify Runnable
                (run [_]
                  (send-reset!))))))

(defn safe-close
  "Calls `.close` on `closeable` unless it is nil/false."
  [closeable]
  (if-not closeable
    nil
    (.close closeable)))

(defn- response-body->bytes
  "Reads a handler response `body` fully into a byte array.

  Files and input streams are drained, strings are encoded as UTF-8, and any
  other value is converted with byte-streams. A body implementing
  `java.io.Closeable` is closed afterwards, even when reading it fails."
  [body]
  (try
    (cond
      (instance? File body)
      (with-open [fis (FileInputStream. ^File body)]
        (IOUtils/toByteArray fis))

      (instance? InputStream body)
      (IOUtils/toByteArray ^InputStream body)

      (instance? String body)
      (.getBytes ^String body java.nio.charset.StandardCharsets/UTF_8)

      :else (bs/convert body ByteArray))
    (finally
      (when (instance? java.io.Closeable body)
        (.close ^java.io.Closeable body)))))

(defn- response-body-frame
  "Returns an end-of-stream DATA frame on `stream` carrying `body`, or nil when
  `body` is nil/false. The body is read before the buffer is allocated, so a
  failing read cannot leak the buffer."
  [^ChannelHandlerContext ctx stream body]
  (when body
    (let [^bytes data (response-body->bytes body)
          ^ByteBuf content (.buffer (.alloc ctx) (alength data))]
      (.writeBytes content data)
      (-> content
          (DefaultHttp2DataFrame. true)
          (.stream stream)))))

(defn- response-headers-frame
  "Builds the HEADERS frame for a handler `response` on `stream`. END_STREAM
  is set when the response has no body and the stream is not registered as a
  websocket in `streams`."
  [streams stream {:keys [headers status body]}]
  (let [websocket? (boolean (:websocket (get @streams (.id stream))))
        end-stream? (boolean (and (not websocket?) (not body)))
        http2-headers (doto (DefaultHttp2Headers.)
                        (.status (str status))
                        (add-all-headers headers))]
    (.stream (DefaultHttp2HeadersFrame. http2-headers end-stream?) stream)))

(defn configure-http2-inbound-frame-handler
  "Installs the \"inbound-frame-handler\" that turns inbound HTTP/2 frames into
  handler calls.

  - RST_STREAM: runs the stream's :on-reset callback (if any) and forgets it.
  - HEADERS: builds a request, records it in `streams` and runs `handler` on
    `exec-service`. The response is delivered as a ResponseEvent user event.
    A non-2xx status or a handler exception resets the stream.
  - Other stream frames: forwarded to a websocket's :handle-frame when
    present. DATA payloads are piped into the request body and acknowledged
    with a WINDOW_UPDATE. Frames not handed to a websocket are released.
  - Anything else that is reference counted is released."
  [^ChannelPipeline pipeline streams ^ExecutorService exec-service handler]
  (safe-add-handler
   pipeline "inbound-frame-handler"
   (proxy [ChannelDuplexHandler] []
     (channelRead [^ChannelHandlerContext ctx msg]
       (cond
         (instance? Http2ResetFrame msg)
         (let [^Http2ResetFrame msg msg
               stream-id (.id (.stream msg))
               {:keys [on-reset]} (get @streams stream-id)]
           (when on-reset (on-reset))
           (swap! streams dissoc stream-id))

         (instance? Http2HeadersFrame msg)
         (let [websocket? (websocket-request? msg)
               ^Http2HeadersFrame msg msg
               stream (.stream msg)
               stream-id (.id stream)
               end-stream? (.isEndStream msg)
               request (request msg)
               ;; Rejects the stream: reset it, unblock anyone reading the
               ;; request body, and forget the stream so later DATA frames
               ;; are not written into the closed pipe.
               abort! (fn []
                        (close-stream pipeline stream)
                        (safe-close (:body-output-stream request))
                        (swap! streams dissoc stream-id))]
           (swap! streams assoc stream-id
                  (merge {:headers-frame msg
                          :stream stream
                          :request request}
                         (when websocket?
                           {:websocket true})))
           (with-exec exec-service
             (fn []
               (try (let [response (handler (cond-> request
                                              true (dissoc :body-input-stream :body-output-stream)
                                              (:body-input-stream request) (assoc :body (:body-input-stream request))))
                          response (if websocket?
                                     (websocket-response response)
                                     response)
                          ;; Headers first: building them can throw on invalid
                          ;; header names, and must not strand an allocated body buffer.
                          headers-frame (response-headers-frame streams stream response)
                          body-frame (response-body-frame ctx stream (:body response))
                          event (-> response
                                    (assoc :stream stream
                                           :sent (atom false))
                                    map->ResponseEvent
                                    (assoc :body body-frame
                                           :headers headers-frame))]
                      (cond
                        end-stream?
                        (do (safe-close (:body-output-stream request))
                            (.fireUserEventTriggered pipeline event))

                        (<= 200 (:status event) 299)
                        (do (when-not websocket?
                              ;; Fired again when the request body's last DATA
                              ;; frame arrives; the :sent atom prevents a double write.
                              (swap! streams update stream-id assoc :response event))
                            (.fireUserEventTriggered pipeline event))

                        ;; sending 404 here doesn't seem to work
                        :else (abort!)))
                    (catch Exception e
                      ;; Serialize output from concurrent executor threads.
                      (locking Object
                        (println e))
                      ;; Don't leave the client waiting on a stream that will never answer.
                      (abort!)))))))

         (instance? Http2StreamFrame msg)
         (let [^Http2StreamFrame msg msg
               stream-id (.id (.stream msg))
               {:keys [handle-frame request]} (get @streams stream-id)
               ;; A websocket's handle-frame forwards DATA content into the
               ;; websocket decoder, which takes ownership of the buffer.
               handed-off? (and handle-frame (instance? Http2DataFrame msg))]
           (try
             (when handle-frame (handle-frame msg))
             (when (instance? Http2DataFrame msg)
               (let [{:keys [body-output-stream]} request
                     ^Http2DataFrame msg msg]
                 (when body-output-stream
                   (let [^ByteBuf buf (.content msg)
                         chunk (byte-array (.readableBytes buf))]
                     (.readBytes buf chunk)
                     (.write ^PipedOutputStream body-output-stream chunk 0 (count chunk))))
                 (when (.isEndStream msg)
                   (safe-close body-output-stream)
                   (when-let [response (get-in @streams [stream-id :response])]
                     (.fireUserEventTriggered pipeline response)))
                 (->> (-> (.initialFlowControlledBytes msg)
                          (DefaultHttp2WindowUpdateFrame.)
                          (.stream (.stream msg)))
                      (write-handler pipeline "frame-writer"))))
             (finally
               ;; Http2FrameCodec hands us retained frames; release every
               ;; frame that was not passed on to the websocket decoder.
               (when (and (not handed-off?) (instance? ReferenceCounted msg))
                 (.release ^ReferenceCounted msg)))))

         :else (when (instance? ReferenceCounted msg)
                 (.release ^ReferenceCounted msg)))))))

(defn configure-http2-frame-writer
  "Installs the \"frame-writer\" handler that forwards HTTP/2 stream frames
  toward the frame codec.

  Each stream frame that has a stream assigned is written with the caller's
  promise, so the caller observes its completion, and then flushed. An
  end-of-stream DATA frame also removes its stream from `streams`. Any other
  message is dropped; reference-counted ones are released so they do not leak."
  [^ChannelPipeline pipeline streams]
  (safe-add-handler
   pipeline "frame-writer"
   (proxy [ChannelDuplexHandler] []
     (write [^ChannelHandlerContext ctx msg ^ChannelPromise p]
       (if (and (instance? Http2StreamFrame msg)
                (.stream ^Http2StreamFrame msg))
         (do
           (.write ctx msg p)
           (when (instance? Http2DataFrame msg)
             (let [^Http2DataFrame msg msg]
               (when (.isEndStream msg)
                 (swap! streams dissoc (.id (.stream msg))))))
           (.flush ctx))
         (when (instance? ReferenceCounted msg)
           (.release ^ReferenceCounted msg)))))))

(defn configure-http2-response-writer
  "Installs the \"response-writer\" handler. Writing a ResponseEvent-shaped map
  sends its HEADERS frame and, when present, its DATA frame through
  \"frame-writer\", then flushes. `streams` is accepted for signature symmetry
  and is not used."
  [^ChannelPipeline pipeline streams]
  (let [response-writer
        (proxy [ChannelDuplexHandler] []
          (write [^ChannelHandlerContext ctx response ^ChannelPromise p]
            (write-handler pipeline "frame-writer" (:headers response))
            (when-let [body-frame (:body response)]
              (write-handler pipeline "frame-writer" body-frame))
            (.flush ctx)))]
    (safe-add-handler pipeline "response-writer" response-writer)))

(defn configure-http2-user-event-handler
  "Installs the \"user-event-handler\" that sends each ResponseEvent at most
  once. If the stream is marked as a websocket in `streams`, the websocket
  handlers are installed before the response is written. Other user events
  are ignored."
  [^ChannelPipeline pipeline streams ^ExecutorService exec-service]
  (safe-add-handler
   pipeline "user-event-handler"
   (proxy [ChannelDuplexHandler] []
     (userEventTriggered [^ChannelHandlerContext ctx event]
       (when (instance? ResponseEvent event)
         (let [{:keys [stream sent]} event]
           (when-not @sent
             (when (get-in @streams [(.id stream) :websocket])
               (init-websocket! pipeline streams exec-service stream event))
             (write-handler pipeline "response-writer" event)
             (reset! sent true))))))))

(defn configure-http2-pipeline
  "Wires the HTTP/2 server handlers into `pipeline`, sharing one per-connection
  `streams` atom between them, and returns `pipeline`. Handler order: frame
  codec, inbound frame handler, frame writer, response writer, user event
  handler."
  [^ChannelPipeline pipeline ^ExecutorService exec-service handler]
  (let [streams (atom {})]
    (configure-http2-frame-codec-builder pipeline)
    (configure-http2-inbound-frame-handler pipeline streams exec-service handler)
    (configure-http2-frame-writer pipeline streams)
    (configure-http2-response-writer pipeline streams)
    (configure-http2-user-event-handler pipeline streams exec-service)
    pipeline))
