;;   Copyright (c) 7theta. All rights reserved.
;;   The use and distribution terms for this software are covered by the
;;   MIT License (https://opensource.org/licenses/MIT) which can also be
;;   found in the LICENSE file at the root of this distribution.
;;
;;   By using this software in any fashion, you are agreeing to be bound by
;;   the terms of this license.
;;   You must not remove this notice, or any others, from this software.

(ns vectio.netty.http1.websocket
  (:refer-clojure :exclude [send])
  (:require [vectio.netty :as n]
            [vectio.netty.websocket :as nws]
            [fluxus.flow :as f]
            [fluxus.promise :as p]
            [clojure.java.io :as io])
  (:import [io.netty.handler.codec.http2 Http2SecurityUtil]
           [io.netty.bootstrap Bootstrap]
           [io.netty.channel.socket SocketChannel]
           [io.netty.channel
            SimpleChannelInboundHandler
            ChannelInitializer
            ChannelDuplexHandler
            ChannelHandler
            ChannelHandlerContext
            ChannelPromise]
           [io.netty.handler.codec.http
            FullHttpResponse
            DefaultHttpHeaders
            HttpClientCodec
            HttpObjectAggregator]
           [io.netty.channel.nio NioEventLoopGroup]
           [io.netty.channel.socket.nio NioSocketChannel]
           [io.netty.handler.codec.http.websocketx
            WebSocketHandshakeException
            WebSocketClientHandshakerFactory
            WebSocketClientHandshaker
            WebSocketVersion]
           [io.netty.handler.ssl
            ClientAuth
            SslContextBuilder
            SslContext
            SslProvider
            SupportedCipherSuiteFilter
            ApplicationProtocolConfig
            ApplicationProtocolConfig$Protocol
            ApplicationProtocolConfig$SelectorFailureBehavior
            ApplicationProtocolConfig$SelectedListenerFailureBehavior
            ApplicationProtocolNames]
           [io.netty.handler.ssl.util
            InsecureTrustManagerFactory]
           [java.net URI]
           [java.io InputStream ByteArrayInputStream File Closeable]))

;; Forward declarations: `websocket-client` below calls these two helpers,
;; which are defined further down in the ";;; Private" section.
(declare websocket-client-handler channel-initializer)

(defn- tls->ssl-context
  "Builds a client `SslContext` from a `:tls` option map.

  The CA material used to verify the server is read from `[:server :ca]`;
  the client certificate chain and private key from `[:client :cert]` and
  `[:client :key]`. Each value may be an `InputStream` (used as-is), a PEM
  `String` (encoded as UTF-8), or a `File`; any other value, including nil,
  is passed to `SslContextBuilder` as nil.

  Streams opened here for `File` values are closed once the context has been
  built (or building failed); caller-supplied `InputStream`s are not closed."
  ^SslContext [tls]
  (let [opened (atom [])
        coerce-is (fn [v]
                    (cond
                      (instance? InputStream v) v
                      (string? v) (ByteArrayInputStream. (.getBytes ^String v "UTF-8"))
                      (instance? File v) (let [in (io/input-stream v)]
                                           ;; Remember streams we opened so they
                                           ;; can be closed after building.
                                           (swap! opened conj in)
                                           in)))]
    (try
      (-> (SslContextBuilder/forClient)
          (.trustManager ^InputStream (coerce-is (get-in tls [:server :ca])))
          (.keyManager ^InputStream (coerce-is (get-in tls [:client :cert]))
                       ^InputStream (coerce-is (get-in tls [:client :key])))
          (.build))
      (finally
        (doseq [^Closeable in @opened]
          (.close in))))))

(defn websocket-client
  "Connects to a WebSocket server from a background thread.

  Returns a promise that resolves to a client map `{:send f :close f}`:
    :send  - called with one argument; converts it to WebSocket frames with
             `nws/data->websocket-frames` and writes them with
             `nws/send-websocket-frames`, run via `n/safe-execute` on the
             channel.
    :close - shuts down the client's event loop group.

  The promise is rejected with the exception that aborted the attempt
  (connect timeout, inactive channel, handshake failure or timeout, ...);
  the event loop group is shut down in that case.

  Options (`config`):
    :host, :port          - server to connect to.
    :path, :query-string  - request target; \"/\" is prepended to `path`
                            when it does not already start with one.
    :ssl-context          - a Netty `SslContext`; takes precedence over :tls.
    :tls                  - PEM material (see `tls->ssl-context`). A
                            non-empty :tls or an :ssl-context selects wss://.
    :connect-timeout      - ms to wait for the TCP connect (default 10000).
    :handshake-timeout    - ms to wait for the WebSocket handshake once
                            connected (defaults to :connect-timeout).
    :on-text-message, :on-binary-message, :on-ping-message,
    :on-pong-message, :on-close
                          - callbacks passed to `websocket-client-handler`."
  [{:keys [host port tls ssl-context path query-string
           connect-timeout
           handshake-timeout
           on-text-message
           on-binary-message
           on-ping-message
           on-pong-message
           on-close]
    :or {connect-timeout 10000}
    :as config}]
  (let [client (p/promise)
        worker-group (NioEventLoopGroup.)]
    (future
      (try
        (let [handshake-future (atom nil)
              url (str (if (or ssl-context (seq tls)) "wss" "ws")
                       "://" host ":" port
                       (when-not (re-find #"^/" (str path)) "/")
                       path
                       (when query-string (str "?" query-string)))
              ^ChannelHandler handler (websocket-client-handler
                                       (WebSocketClientHandshakerFactory/newHandshaker
                                        (URI. url)
                                        WebSocketVersion/V13
                                        nil
                                        false
                                        (DefaultHttpHeaders.))
                                       handshake-future
                                       {:on-text-message on-text-message
                                        :on-binary-message on-binary-message
                                        :on-ping-message on-ping-message
                                        :on-pong-message on-pong-message
                                        :on-close on-close})
              ssl-context (or ssl-context
                              (when (seq tls)
                                (tls->ssl-context tls)))
              b (doto (Bootstrap.)
                  (.group worker-group)
                  (.channel ^NioSocketChannel NioSocketChannel)
                  (.handler (channel-initializer host port ssl-context handler)))
              channel-future (.connect b (str host) (int port))
              _ (when-not (.await channel-future (long connect-timeout))
                  (throw (ex-info "Connection timed out"
                                  {:config config})))
              channel (.channel channel-future)
              max-flush-size (dec (int (Math/pow 2 16)))
              max-frame-size (dec (int (Math/pow 2 16)))]
          (when-not (.isActive channel)
            (throw (ex-info "Unable to establish connection to server"
                            {:config config
                             :channel channel}
                            ;; Keep the underlying connect failure (may be nil).
                            (.cause channel-future))))
          (let [^ChannelPromise handshake @handshake-future]
            ;; Bound the wait: a server that accepts the TCP connection but
            ;; never answers the upgrade request would otherwise block this
            ;; thread forever and leave `client` unresolved.
            (when-not (.await handshake (long (or handshake-timeout connect-timeout)))
              (throw (ex-info "WebSocket handshake timed out"
                              {:config config
                               :channel channel})))
            ;; On an already-completed promise, `sync` rethrows its failure
            ;; cause (if any) and otherwise returns immediately.
            (.sync handshake))
          (p/resolve! client
                      {:close (fn [] (.shutdownGracefully worker-group))
                       :send #(n/safe-execute
                               channel (fn []
                                         (->> %
                                              (nws/data->websocket-frames channel max-frame-size)
                                              (nws/send-websocket-frames channel max-flush-size))))}))
        ;; Throwable rather than Exception: `.sync` can rethrow any Throwable
        ;; stored as the handshake failure, and the promise must always settle.
        (catch Throwable e
          (try (.shutdownGracefully worker-group)
               (catch Exception _))
          (p/reject! client e))))
    client))

(defn websocket-client-stream
  "Opens a connection with `websocket-client` and returns a promise of a
  fluxus stream for it.

  `args` is passed to `websocket-client` with its callback keys replaced:
  inbound text and binary messages are put as-is onto the internal end of an
  `f/entangled` pair, pings and pongs as `:ping` / `:pong`, and every put is
  dereferenced (waited on) before the callback returns. When the connection
  reports closing, the internal end is closed.

  Once connected, the internal end is handed to `f/consume` together with the
  client's `:send` fn (presumably forwarding what the caller writes to the
  returned stream - confirm against fluxus), closing the returned stream
  calls the client's `:close`, and the promise resolves to the stream. If
  connecting fails, the internal end is closed and the promise is rejected
  with the error."
  [{:keys [host port ssl-context path] :as args}]
  (let [[client-stream internal] (f/entangled)
        message-stream (p/promise)
        put-and-wait! (fn [value] @(f/put! internal value))
        callbacks {:on-close (fn [& _] (f/close! internal))
                   :on-text-message put-and-wait!
                   :on-binary-message put-and-wait!
                   :on-ping-message (fn [_] (put-and-wait! :ping))
                   :on-pong-message (fn [_] (put-and-wait! :pong))}
        on-connected (fn [{:keys [send close]}]
                       (f/consume send internal)
                       (f/on-close client-stream (fn [_] (close)))
                       (p/resolve! message-stream client-stream))
        on-failed (fn [error]
                    (f/close! internal)
                    (p/reject! message-stream error))]
    (p/catch (p/then (websocket-client (merge args callbacks)) on-connected)
             on-failed)
    message-stream))

(defn send
  "Sends `text-message` through a client map by calling its `:send` fn with
  the message; returns whatever that fn returns."
  [{send-fn :send} text-message]
  (send-fn text-message))

(defn close
  "Closes a client map by calling its zero-argument `:close` fn; returns
  whatever that fn returns."
  [{close-fn :close}]
  (close-fn))

(defn unsafe-self-signed-ssl-context
  "Returns a client `SslContext` that accepts ANY server certificate via
  `InsecureTrustManagerFactory`, e.g. for reaching a server with a
  self-signed certificate during development. Certificate validation is
  effectively disabled, so do not use this against untrusted networks.

  The context uses the JDK TLS provider, `Http2SecurityUtil/CIPHERS`
  filtered by `SupportedCipherSuiteFilter`, and an ALPN configuration whose
  only listed protocol is HTTP/1.1."
  []
  (let [alpn-config (ApplicationProtocolConfig.
                     ApplicationProtocolConfig$Protocol/ALPN
                     ApplicationProtocolConfig$SelectorFailureBehavior/NO_ADVERTISE
                     ApplicationProtocolConfig$SelectedListenerFailureBehavior/ACCEPT
                     ^"[Ljava.lang.String;" (into-array [ApplicationProtocolNames/HTTP_1_1]))
        builder (doto (SslContextBuilder/forClient)
                  (.sslProvider SslProvider/JDK)
                  (.ciphers Http2SecurityUtil/CIPHERS
                            SupportedCipherSuiteFilter/INSTANCE)
                  (.trustManager InsecureTrustManagerFactory/INSTANCE)
                  (.applicationProtocolConfig alpn-config))]
    (.build builder)))

;;; Private

(defn websocket-client-handler
  "Returns the inbound channel handler for a WebSocket client connection.

  `handshaker` drives the opening handshake: the upgrade request is sent when
  the channel becomes active, and while `isHandshakeComplete` is false each
  inbound message is passed to `finishHandshake`.

  `handshake-future` is an atom that is reset to a fresh `ChannelPromise`
  when the handler is added to a pipeline. That promise succeeds once
  `finishHandshake` returns. It fails on a `WebSocketHandshakeException`, on
  any exception reaching `exceptionCaught`, or when the channel goes inactive
  before the handshake finished.

  After the handshake, each message is passed (with the context) to the
  `nws/inbound-collector` built from the callbacks; if that returns a
  function, the function is called with no arguments. An unexpected
  `FullHttpResponse` at that stage throws an ex-info. `on-close`, when
  given, is called with no arguments when the channel becomes inactive.
  `exceptionCaught` closes the context."
  ^ChannelHandler
  [^WebSocketClientHandshaker handshaker
   handshake-future
   {:keys [on-text-message
           on-binary-message
           on-ping-message
           on-pong-message
           on-close]}]
  (let [frame-handler (nws/inbound-collector
                       {:on-text-message on-text-message
                        :on-binary-message on-binary-message
                        :on-ping-message on-ping-message
                        :on-pong-message on-pong-message
                        :on-close on-close})]
    (proxy [SimpleChannelInboundHandler] []
      (handlerAdded [^ChannelHandlerContext ctx]
        (reset! handshake-future (.newPromise ctx)))
      (channelActive [^ChannelHandlerContext ctx]
        (.handshake handshaker (.channel ctx)))
      (channelInactive [^ChannelHandlerContext ctx]
        ;; If the peer disconnects before the handshake finishes, nothing else
        ;; would complete the promise and anyone waiting on it would block
        ;; indefinitely.
        (let [^ChannelPromise hs-promise @handshake-future]
          (when-not (.isDone hs-promise)
            (.tryFailure hs-promise
                         (ex-info "Connection closed before WebSocket handshake completed"
                                  {}))))
        ;; client disconnected
        (when on-close
          (on-close)))
      (channelRead0 [^ChannelHandlerContext ctx msg]
        (let [ch (.channel ctx)]
          (cond
            (not (.isHandshakeComplete handshaker))
            (try
              (.finishHandshake handshaker ch ^FullHttpResponse msg)
              (.setSuccess ^ChannelPromise @handshake-future)
              (catch WebSocketHandshakeException e
                (.setFailure ^ChannelPromise @handshake-future e)))

            (instance? FullHttpResponse msg)
            (throw (ex-info "Unexpected FullHttpResponse" {:msg msg}))

            :else (when-let [handler (frame-handler ctx msg)]
                    (handler)))))
      (exceptionCaught [^ChannelHandlerContext ctx ^Throwable cause]
        (when (not (.isDone ^ChannelPromise @handshake-future))
          (.setFailure ^ChannelPromise @handshake-future cause))
        (.close ctx)))))

(defn channel-initializer
  "Returns a `ChannelInitializer` that sets up each new client channel's
  pipeline in this order:
    1. \"ssl-handler\" - only when `ssl-context` is non-nil; created with
       `host` and `port` as the peer.
    2. \"client-codec\" - an `HttpClientCodec`.
    3. \"http-object-aggregator\" - an `HttpObjectAggregator` with an
       8192-byte max content length.
    4. \"client-handler\" - the given `handler`."
  ^ChannelInitializer [^String host ^long port ^SslContext ssl-context ^ChannelHandler handler]
  (proxy [ChannelInitializer] []
    (initChannel [^SocketChannel ch]
      (let [pipeline (.pipeline ch)]
        ;; TLS goes first so the handlers after it see decrypted bytes.
        (when ssl-context
          (.addLast pipeline "ssl-handler"
                    (.newHandler ssl-context (.alloc ch) host (int port))))
        (doto pipeline
          (.addLast "client-codec" (HttpClientCodec.))
          (.addLast "http-object-aggregator" (HttpObjectAggregator. 8192))
          (.addLast "client-handler" handler))))))
