;;   Copyright (c) 7theta. All rights reserved.
;;   The use and distribution terms for this software are covered by the
;;   MIT License (https://opensource.org/licenses/MIT) which can also be
;;   found in the LICENSE file at the root of this distribution.
;;
;;   By using this software in any fashion, you are agreeing to be bound by
;;   the terms of this license.
;;   You must not remove this notice, or any others, from this software.

(ns via.netty.http2.handlers.websocket
  (:require [via.util.netty :as n])
  (:import [java.util.concurrent ExecutorService]
           [io.netty.buffer ByteBuf]
           [io.netty.util ReferenceCounted]
           [io.netty.channel
            ChannelHandlerContext
            ChannelDuplexHandler
            ChannelPromise]
           [io.netty.handler.codec.http2
            DefaultHttp2DataFrame
            Http2HeadersFrame
            Http2DataFrame]
           [io.netty.handler.codec.http.websocketx.extensions
            WebSocketServerExtension]
           [io.netty.handler.codec.http.websocketx
            TextWebSocketFrame
            ContinuationWebSocketFrame
            CloseWebSocketFrame
            WebSocket13FrameDecoder
            WebSocket13FrameEncoder
            WebSocketDecoderConfig]))

(defn websocket-request?
  "Returns true when `req-or-frame` looks like a WebSocket request.

  Two input shapes are accepted:
  - a Netty `Http2HeadersFrame`: true when its `:protocol` pseudo-header
    equals `websocket`, compared case-insensitively. The version header is
    not inspected in this case.
  - anything else (normally a request map): true only when its `:headers`
    map has `:protocol` equal to `websocket` and `sec-websocket-version`
    equal to `13`, both compared exactly.

  Always returns a boolean."
  [req-or-frame]
  (if (instance? Http2HeadersFrame req-or-frame)
    (let [^Http2HeadersFrame headers-frame req-or-frame]
      (boolean (.contains (.headers headers-frame) ":protocol" "websocket" true)))
    (let [hdrs (:headers req-or-frame)]
      (boolean (and (= "websocket" (get hdrs ":protocol"))
                    (= "13" (get hdrs "sec-websocket-version")))))))

(defn init-handler
  "Installs the per-stream WebSocket handler chain for the HTTP/2 stream
  carried by `msg` into the pipeline of `ctx`. Returns a function that feeds
  incoming `Http2DataFrame`s for that stream into the chain.

  Parameters:
  - `ctx`: channel handler context whose pipeline receives the handlers.
  - `connection-state`: atom. This fn reads `:exec-service`,
    `[:settings :settings-max-frame-size]` and `[:streams <id> :close]`, and
    calls `swap!` on the atom stored at `[:streams <id> :cleanup-handlers]`.
  - `msg`: the headers frame for the stream. Only its `.stream` is used.
  - `response`: a value whose metadata supplies `:compression-extension`
    (a `WebSocketServerExtension` or nil) and `:handlers`, a map of optional
    `:on-open`, `:on-close` and `:on-text-message` callbacks.

  `on-open`, if present, is called synchronously with a map containing:
  - `:send`: a fn taking a String.
  - `:close`: a no-arg fn.
  Both become no-ops once the registered cleanup fn has run.

  `on-text-message` is called with each decoded text payload, submitted to
  `exec-service` via `n/run`.

  `on-close` is called from the cleanup fn conj'ed onto
  `:cleanup-handlers`. That fn also removes every handler added here.

  Returns a one-arg fn of an `Http2DataFrame`."
  [^ChannelHandlerContext ctx connection-state ^Http2HeadersFrame msg response]
  (let [stream (.stream msg)
        stream-id (.id stream)
        pipeline (.pipeline ctx)
        {:keys [exec-service]} @connection-state
        ;; Re-bound with an ExecutorService type hint (presumably to avoid
        ;; reflection where it is used below).
        ^ExecutorService exec-service exec-service
        ;; Callback and extension configuration travel as metadata on `response`.
        {:keys [compression-extension handlers]} (meta response)
        {:keys [on-open on-close on-text-message]} handlers
        ;; Used in two places:
        ;; - as the decoder's max frame payload length;
        ;; - as the slice size for outgoing text messages
        ;;   (see message-outbound-handler).
        max-frame-size (get-in @connection-state [:settings :settings-max-frame-size])
        decoder-config (-> (WebSocketDecoderConfig/newBuilder)
                           (.allowExtensions true)
                           (.maxFramePayloadLength max-frame-size)
                           (.build))
        frame-decoder (WebSocket13FrameDecoder. decoder-config)
        ;; `false` = maskPayload: outgoing frames are not masked.
        frame-encoder (WebSocket13FrameEncoder. false)
        ;; Both compression handlers are nil when `compression-extension` is
        ;; nil. Nil entries are dropped by `(filter second)` below.
        compression-encoder (when compression-extension
                              (.newExtensionEncoder ^WebSocketServerExtension compression-extension))
        compression-decoder (when compression-extension
                              (.newExtensionDecoder ^WebSocketServerExtension compression-extension))
        ;; Set to true by the cleanup fn registered below. It turns the
        ;; :send and :close fns given to on-open into no-ops.
        closed? (atom false)
        ;; Invokes the stream's :close fn stored in connection-state, if any.
        close-stream #(when-let [close (get-in @connection-state [:streams stream-id :close])]
                        (close))
        ;; [name handler] pairs, added to the pipeline in this order.
        ;; Names are suffixed with the stream id below (presumably so several
        ;; streams can install handlers in one connection pipeline).
        ;; NOTE: this rebinding shadows the callback map `handlers`
        ;; destructured from (meta response) above.
        handlers (->> [["data-frame-outbound-handler" (proxy [ChannelDuplexHandler] []
                                                        ;; Wraps each outgoing ByteBuf in a DATA frame for this
                                                        ;; stream (endStream = false) and hands it to "frame-writer".
                                                        ;; NOTE(review): the promise `_p` is ignored, never completed
                                                        ;; or forwarded. Confirm nothing waits on it.
                                                        (write [^ChannelHandlerContext ctx
                                                                ^ByteBuf msg
                                                                ^ChannelPromise _p]
                                                          (->> (-> msg
                                                                   (DefaultHttp2DataFrame. false)
                                                                   (.stream stream))
                                                               (n/invoke-write pipeline "frame-writer"))))]
                       ["data-frame-inbound-handler" (proxy [ChannelDuplexHandler] []
                                                       ;; Retains the incoming DATA frame via n/acquire.
                                                       ;; Forwards only its content ByteBuf downstream through the
                                                       ;; superclass channelRead, which fires it to the next
                                                       ;; inbound handler.
                                                       (channelRead [^ChannelHandlerContext ctx ^Http2DataFrame msg]
                                                         (n/acquire msg)
                                                         (let [^ChannelDuplexHandler this this]
                                                           (proxy-super channelRead ctx (.content msg)))))]
                       ["frame-decoder" frame-decoder]
                       ["compression-decoder" compression-decoder]
                       ["frame-encoder" frame-encoder]
                       ["compression-encoder" compression-encoder]
                       ["message-inbound-handler" (proxy [ChannelDuplexHandler] []
                                                    ;; Dispatches by frame type; `msg` is released in every case:
                                                    ;; - text frame: on-text-message, run on exec-service;
                                                    ;; - close frame: close the stream;
                                                    ;; - anything else: throw ex-info.
                                                    ;; NOTE(review): Ping, Pong, Binary and inbound
                                                    ;; ContinuationWebSocketFrame (fragmented text) all reach the
                                                    ;; :else branch and throw. Confirm this is intended.
                                                    (channelRead [^ChannelHandlerContext ctx msg]
                                                      (try (cond
                                                             (instance? TextWebSocketFrame msg)
                                                             ;; Copy the payload to a String now. `msg` is released in the
                                                             ;; finally below, possibly before the executor task runs.
                                                             (let [text (.text ^TextWebSocketFrame msg)]
                                                               (when on-text-message
                                                                 (n/run exec-service
                                                                   (fn []
                                                                     (try
                                                                       (on-text-message text)
                                                                       (catch Exception e
                                                                         (.fireExceptionCaught ctx e)))))))

                                                             (instance? CloseWebSocketFrame msg)
                                                             (close-stream)

                                                             :else (throw (ex-info "Unhandled WebSocket frame" {:frame msg})))
                                                           (finally
                                                             (n/release ^ReferenceCounted msg)))))]
                       ["message-outbound-handler" (proxy [ChannelDuplexHandler] []
                                                     ;; Encodes a String into a ByteBuf and slices it into
                                                     ;; max-frame-size chunks. Writes one TextWebSocketFrame
                                                     ;; followed by ContinuationWebSocketFrames; only the last
                                                     ;; frame has the final-fragment flag set.
                                                     ;; NOTE(review): the promise `p` is never completed or
                                                     ;; forwarded. Confirm callers do not wait on it.
                                                     ;; NOTE(review): if n/slice-byte-buf yields no slices
                                                     ;; (e.g. empty message), nothing is written. Confirm.
                                                     (write [^ChannelHandlerContext ctx
                                                             ^String msg
                                                             ^ChannelPromise p]
                                                       ;; Each slice goes through n/acquire before the parent buffer
                                                       ;; is passed to n/release (presumably retaining the slices).
                                                       ;; NOTE: `vec` after `mapv` is redundant.
                                                       (let [^ByteBuf byte-buf (n/to-byte-buf ctx msg)
                                                             frames (->> max-frame-size
                                                                         (n/slice-byte-buf byte-buf)
                                                                         (mapv n/acquire)
                                                                         vec)]
                                                         (n/release byte-buf)
                                                         ;; Bytes written since the last flush. The loop flushes once
                                                         ;; this exceeds 1E5, and always flushes at the end.
                                                         (let [byte-counter (volatile! 0)]
                                                           (dotimes [i (count frames)]
                                                             (let [^ByteBuf frame (nth frames i)
                                                                   final-fragment? (= i (dec (count frames)))
                                                                   msg (if (zero? i)
                                                                         (TextWebSocketFrame. final-fragment? 0 frame)
                                                                         (ContinuationWebSocketFrame. final-fragment? 0 frame))]
                                                               (.write ctx msg)
                                                               (vswap! byte-counter + (.capacity frame))
                                                               (when (> @byte-counter 1E5)
                                                                 (vreset! byte-counter 0)
                                                                 (.flush ctx))))
                                                           (.flush ctx)))))]]
                      (filter second)
                      (mapv (fn [[handler-name handler]]
                              [(str handler-name "-" stream-id) handler])))

        ;; NOTE(review): the first entry is "data-frame-outbound-handler-<id>",
        ;; not the inbound handler this local's name suggests. Presumably this
        ;; relies on n/invoke-channel-read either firing to the next inbound
        ;; handler, or calling the duplex handler's pass-through channelRead.
        ;; Confirm.
        data-frame-inbound-handler-name (ffirst handlers)
        ;; The last entry is always "message-outbound-handler-<id>". It is a
        ;; non-nil proxy, so `(filter second)` never removes it.
        message-outbound-handler-name (first (last handlers))]
    ;; Register the stream cleanup fn. It runs these steps in order:
    ;; 1. mark the socket closed;
    ;; 2. run on-close, printing (not rethrowing) any Exception;
    ;; 3. remove every handler added below.
    (swap! (get-in @connection-state [:streams stream-id :cleanup-handlers]) conj
           #(do (reset! closed? true)
                (when on-close
                  (try (on-close)
                       (catch Exception e
                         (println "Exception occurred in on-close handler" e))))
                (doseq [[handler-name _] handlers]
                  (n/safe-remove-handler pipeline handler-name))))
    ;; Add the handlers in vector order.
    (doseq [[handler-name handler] handlers]
      (n/safe-add-handler pipeline handler-name handler))
    ;; Give the application its send/close capabilities. Both are no-ops
    ;; after cleanup has set `closed?`.
    (when on-open
      (on-open
       {:send #(when (not @closed?)
                 (n/invoke-write pipeline message-outbound-handler-name ^String %))
        :close (fn []
                 (when (not @closed?)
                   (close-stream)))}))
    ;; Entry point used by the caller for this stream's inbound DATA frames.
    (fn [^Http2DataFrame data-frame]
      (n/invoke-channel-read pipeline data-frame-inbound-handler-name data-frame))))
