(ns circle-util.ssh
  (:require [clojure.core.memoize :as memo]
            [clojure.core.typed :as t]
            [clojure.java.io :as io]
            [clojure.string :as str]
            [clojure.tools.logging :refer (infof errorf)]
            [slingshot.slingshot :refer (throw+)]
            [clj-time.core :as time]
            [clj-ssh.ssh :as ssh]
            [circle-util.sh :as sh]
            [circle-util.middleware :as middleware]
            [circle-util.except :refer (throwf throw-if-not)]
            [circle-util.retry :refer (wait-for)]
            [circle-util.core :refer (apply-map)]
            [circle-util.fs :refer (tmp-spit with-temp-file with-delete)]
            [circle-util.time :refer (from-now to-millis)]
            [circle-util.sed :as sed]
            [circle-util.ssh-keys :as ssh-keys]
            [circle-util.datadog :refer (with-timing-metric)]
            [clj-statsd :as statsd])
  (:import (com.jcraft.jsch Session
                            Channel
                            JSch)
           (java.net InetAddress)))

;; Ask core.typed to warn about any var in this namespace that lacks a type annotation.
(t/warn-on-unannotated-vars)

;; A node as consumed by `session`: login user, IP address and private key are
;; mandatory; a public key and an SSH port may also be supplied.
(t/defalias Node (t/HMap :mandatory {:username String
                                     :ip-addr String
                                     :private-key String}
                         :optional {:public-key String
                                    :port t/Int}))

;; Annotations for the third-party clj-ssh functions that `session` calls.
(t/ann clj-ssh.ssh/ssh-agent [(t/HMap :optional {:use-system-ssh-agent Boolean
                                                 :known-hosts-path (t/Option String)})
                              -> JSch])

(t/ann clj-ssh.ssh/add-identity [JSch (t/HMap :mandatory {:private-key String}
                                              :optional {:public-key (t/Option String)})
                                 -> t/Any])


;; Ideally the session options would be typed as an HMap (keyword keys below)
;; that ALSO admits arbitrary string keys (raw JSch config options such as
;; "StrictHostKeyChecking"). core.typed doesn't appear to support that
;; combination at the moment, so a looser Map type is used instead.
;; The desired (currently unsupported) shape, for reference:

;; (HMap :mandatory {:username String}
;;                                               :optional {:port t/Int
;;                                                          :password String})
;;                             (t/Map String String)

(t/ann clj-ssh.ssh/session [JSch String (t/Map (t/U t/Keyword String) t/Any) -> Session])

(defn non-blocking-slurp
  "Reads whatever is currently available on `stream`, without waiting for more.

  Returns the bytes read decoded as UTF-8, or nil when the stream reports
  nothing available. At most (min available (alength buffer)) bytes are
  requested, so the read never asks for more than the stream has said is
  ready: a generic InputStream's read(byte[],int,int) may otherwise block until
  the full length arrives. `buffer` is scratch space and is overwritten.

  Caveat: every chunk is decoded on its own, so a multi-byte UTF-8 character
  split across two reads decodes as replacement characters."
  [^java.io.InputStream stream ^bytes buffer]
  (let [available (.available stream)]
    (when (pos? available)
      (let [num-read (.read stream buffer 0 (int (min available (alength buffer))))]
        ;; -1 (end of stream) or 0 (empty buffer) means nothing to decode
        (when (pos? num-read)
          (String. buffer 0 num-read "UTF-8"))))))

(defn- process-exec
  "Pumps output from an opened exec channel until the channel disconnects.

  Every ~100ms while (ssh/connected-channel? channel) holds, it first calls
  (abort? last-output-time) -- last-output-time being the clj-time instant at
  which output was last seen (initially the moment this fn was entered) -- and
  throws any truthy result with slingshot's throw+; it then drains both
  streams. Each chunk read is folded into a per-stream accumulator (starting
  at nil) via (handle-out acc chunk) / (handle-err acc chunk).

  Returns {:exit (ssh/exit-status channel), which can be -1,
           :out  (str stdout-accumulator)
           :err  (str stderr-accumulator)}."
  [channel stdout-stream stderr-stream abort? handle-out handle-err]
  (let [last-output (atom (time/now))
        reset-last-output (fn [] (reset! last-output (time/now)))
        stdout (atom nil)
        stderr (atom nil)
        buffer-size ssh/*piped-stream-buffer-size*
        ;; one scratch buffer shared by both streams; the reads are sequential
        buffer (byte-array buffer-size)
        ;; read what is available on one stream into its accumulator;
        ;; returns true iff something was read (and bumps last-output)
        slurp-stream (fn [stream output output-handler]
                       (let [s (non-blocking-slurp stream buffer)]
                         (when (seq s)
                           (reset-last-output)
                           (swap! output output-handler s)
                           true)))
        ;; keep draining until one pass reads nothing from either stream
        slurp-streams (fn slurp-streams []
                        (let [out? (slurp-stream stdout-stream stdout handle-out)
                              err? (slurp-stream stderr-stream stderr handle-err)]
                          (when (or out? err?)
                            (recur))))]
    (while (ssh/connected-channel? channel)
      ;; abort? returns nil/false to continue, or the value to throw
      (when-let [exception (abort? @last-output)]
        (throw+ exception))
      (slurp-streams)
      (Thread/sleep 100))
    ;; wait 100 ms for final output to show up, then leave. Don't use
    ;; slurp here, because it may hang forever(?!). Don't know why,
    ;; probably has to do w/ JSch being crap.
    (Thread/sleep 100)
    (slurp-streams)
    {:exit (ssh/exit-status channel) ;; n.b. can be -1
     :out (str @stdout)
     :err (str @stderr)}))

(t/ann session [Node -> Session])
(defn ^Session session
  "Builds a JSch Session for an arbitrary box (connecting it is left to the
  caller, e.g. retry-connect calls ssh/connect on the result).

  Required keys: :username, :private-key, and at least one of :dns-name or
  :ip-addr (:dns-name wins when both are present). Optional: :public-key, and
  :port which is 22 when absent. A private ssh-agent is used (the system
  agent is not consulted) and GSSAPI auth plus strict host-key checking are
  switched off."
  [{:keys [username ip-addr dns-name public-key private-key port]
    :or {port 22}}]
  {:pre [username (or dns-name ip-addr) private-key]}
  (let [host (or dns-name ip-addr)
        ident (cond-> {:private-key private-key}
                public-key (assoc :public-key public-key))
        jsch (doto (ssh/ssh-agent {:use-system-ssh-agent false
                                   :known-hosts-path nil})
               (ssh/add-identity ident))
        options {:username username
                 :port port
                 "GSSAPIAuthentication" "no"
                 "StrictHostKeyChecking" "no"}]
    (ssh/session jsch host options)))

(defn- sanitise-node
  "Returns `node` with private key material removed, suitable for logging.

  Drops :private-key at the top level and under [:owner :ssh-node], and
  :encrypted-private-key under [:vm :host :encrypted-ssh-keys]. Nested maps
  are only touched when present, so the result never gains keys that the
  input did not have."
  [node]
  (letfn [(dissoc-at [m path k]
            ;; update-in on a missing path would create it (as {... nil}),
            ;; adding noise to the logged node, so only strip existing maps.
            (if (map? (get-in m path))
              (update-in m path dissoc k)
              m))]
    (-> node
        ;; the encrypted key is noisy and confuses users
        (dissoc-at [:vm :host :encrypted-ssh-keys] :encrypted-private-key)
        ;; actual private keys in the node object
        (dissoc-at [:owner :ssh-node] :private-key)
        (dissoc :private-key))))

(defn retry-connect
  "Opens an SSH session to `node` and connects it, retrying on failure.

  Attempts go through wait-for (1s sleep, 60s timeout, catching
  JSchException). Each failed attempt is logged along with the node (private
  keys stripped by sanitise-node) and what the host name resolves to. Every
  session gets a server-alive interval of 5000ms with a max count of 600, and
  a 60s connect timeout. Returns the connected Session.

  A JSchException that escapes is rethrown wrapped in a new JSchException
  whose message names the destination (\"master:name\" when the node has a
  :master, otherwise its dns-name or ip-addr)."
  [{:keys [dns-name ip-addr] :as node}]
  (let [name-or-ip (or dns-name ip-addr)
        dest (if (:master node)
               (format "%s:%s" (:master node) (:name node))
               name-or-ip)]
    (try
      ;; TODO: rename this metric?
      (with-timing-metric :circle.ssh.connect-session []
        (wait-for
         {:sleep (time/seconds 1)
          :timeout (time/seconds 60)
          :catch [com.jcraft.jsch.JSchException]
          :error-hook (fn [e]
                        (infof "caught %s trying to connect to %s %s"
                               e name-or-ip (sanitise-node node))
                        ;; Diagnostics only: a failed lookup (e.g. an
                        ;; UnknownHostException when DNS is what's broken)
                        ;; must not escape the hook and mask the real error.
                        (try
                          (infof "%s resolves to %s"
                                 name-or-ip (-> name-or-ip (InetAddress/getByName) (.getHostAddress)))
                          (catch Exception lookup-ex
                            (infof "could not resolve %s: %s" name-or-ip lookup-ex))))}
         #(let [s (session node)]
            ;; This means: send a "server alive?" packet across the session every 5000ms, and
            ;; consider the connection dead if 600 of these packets in a row don't get a response
            ;; n.b. jsch's implementation of keepalives is INCREDIBLY BAD; we really can't be
            ;; aggressive with this!
            (.setServerAliveInterval s 5000)
            (.setServerAliveCountMax s 600)
            (ssh/connect s 60000) ;; TODO: fine-tune this timeout
            s)))
      (catch com.jcraft.jsch.JSchException e
        (throw (com.jcraft.jsch.JSchException.
                (format "SSH Connecting to %s" dest) e))))))

(defn reusable-session?
  "Dispatch fn for with-session. Returns false when `node` has no :session
  key, otherwise whether that session is still connected (ssh/connected?).
  Any extra arguments are ignored."
  [node & _]
  (if (contains? node :session)
    (ssh/connected? (:session node))
    false))

(defmulti with-session
  "Calls (f session) with an SSH session for `node` and returns its result.
  Dispatches on reusable-session?: a still-connected :session on the node is
  reused directly; otherwise a new one is opened with retry-connect and the
  call is wrapped in ssh/with-connection."
  reusable-session?)

(defmethod with-session true
  [{existing :session} f]
  (f existing))

(defmethod with-session false
  [node f]
  (let [fresh (retry-connect node)]
    (ssh/with-connection fresh
      (f fresh))))

;; Accumulator (nil before the first chunk) + chunk -> the same StringBuilder.
(t/ann default-output-handler [(t/Option StringBuilder) String -> StringBuilder])
(defn default-output-handler
  "Output handler remote-exec-session uses unless told otherwise.
  Appends the chunk `s` to the accumulated StringBuilder, creating one when
  the accumulator is still nil, and returns that builder so the whole output
  is retained."
  [maybe-stringbuilder s]
  (let [^StringBuilder acc (if-not (nil? maybe-stringbuilder)
                             maybe-stringbuilder
                             (StringBuilder.))]
    (.append acc s)
    acc))

(defn abort-relative-timeout-middleware
  "Abort middleware entry that fires after `relative-timeout` (a clj-time
  period) without output. Returns [:wrap-relative-timeout wrapper]; the
  wrapper decorates an abort? fn so that, when the inner fn returns falsey and
  the current time is past last-output-time + relative-timeout, it yields
  {:type :relative-timeout}."
  [relative-timeout]
  (letfn [(timed-out? [last-output-time]
            (time/after? (time/now) (time/plus last-output-time relative-timeout)))
          (wrap [abort?]
            (fn [last-output-time]
              (or (abort? last-output-time)
                  (when (timed-out? last-output-time)
                    {:type :relative-timeout}))))]
    [:wrap-relative-timeout wrap]))

(defn abort-absolute-timeout-middleware
  "Abort middleware entry with a hard deadline: now + `absolute-timeout`, the
  deadline being fixed when this fn is called. Returns
  [:wrap-absolute-timeout wrapper]; once the deadline has passed and the
  wrapped abort? fn returns falsey, the wrapper yields {:type :absolute-timeout}."
  [absolute-timeout]
  (let [deadline (time/plus (time/now) absolute-timeout)
        wrap (fn [abort?]
               (fn [last-output-time]
                 (or (abort? last-output-time)
                     (when (time/after? (time/now) deadline)
                       {:type :absolute-timeout}))))]
    [:wrap-absolute-timeout wrap]))

(defn abort-canceled-middleware
  "Abort middleware entry that fires once the build is canceled.
  `build` is dereferenced on every check; when its :canceled value is truthy
  and the wrapped abort? fn returned falsey, {:type :build-canceled} is
  produced. Returns [:wrap-canceled wrapper]."
  [build]
  (let [wrap (fn [abort?]
               (fn [last-output-time]
                 (let [inner (abort? last-output-time)]
                   (cond
                     inner inner
                     (:canceled @build) {:type :build-canceled}))))]
    [:wrap-canceled wrap]))

(defn- exec-binary-cmd*
  "Runs `cmd` (a form, rendered with sh/emit-form) on `session` with an empty
  stdin and copies the channel's :out-stream into `output-stream` via io/copy.
  The exec channel is always disconnected afterwards. Returns {:exit status}."
  [session cmd output-stream]
  (binding [clj-ssh.ssh/*piped-stream-buffer-size* (* 1024 100)]
    (let [command (sh/emit-form cmd)
          options {"StrictHostKeyChecking" "no"
                   "GSSAPIAuthentication" "no"}
          result (ssh/ssh-exec session command "" :stream options)
          chan (:channel result)]
      (try
        (io/copy (:out-stream result) output-stream)
        {:exit (ssh/exit-status chan)}
        (finally
          (ssh/disconnect-channel chan))))))

(defn- log-jcsh-results
  "Wraps `f` (an exec fn returning a map with :exit) so each call reports
  JSch health metrics to statsd; the return value and any exception are
  passed through unchanged.

  Metrics:
  - :circle.backend.jsch.call            every call (the denominator for
                                          alerting on error percentages)
  - :circle.backend.jsch.error           the result's :exit is -1
  - :circle.backend.jsch.jsch-exception  f threw a JSchException
  - :circle.backend.jsch.exception       f threw any other Throwable (this is
                                          where e.g. NullPointerExceptions land)"
  [f]
  (letfn [(bump [metric] (statsd/increment metric))]
    (fn [& args]
      (try
        (bump :circle.backend.jsch.call)
        (let [result (apply f args)]
          (when (= -1 (:exit result))
            (bump :circle.backend.jsch.error))
          result)
        (catch com.jcraft.jsch.JSchException jsch-ex
          (bump :circle.backend.jsch.jsch-exception)
          (throw jsch-ex))
        (catch Throwable other
          (bump :circle.backend.jsch.exception)
          (throw other))))))

(defn- remote-exec-session*
  "Runs `cmd` over an already-connected `session` and returns the
  {:exit :out :err} map produced by process-exec (:exit may be -1).

  `cmd` is either a ready shell string or a seq form (a list, lazy seq or
  cons, e.g. the result of sh/q or concat); seq forms are rendered with
  sh/emit-form first.

  Options:
  - :in                stdin contents, default \"\". Supplying :in at all also
                       stops a pty from being requested.
  - :abort-middleware  middleware compiled around (constantly false) into the
                       abort? fn polled by process-exec; default is a 10 minute
                       no-output timeout.
  - :handle-out, :handle-err  (acc chunk) -> acc fns for stdout / stderr,
                       default default-output-handler.

  The exec channel is disconnected in all cases."
  [session cmd {:keys [in abort-middleware handle-out handle-err]
                :or {
                     ;; Note that "" has very different semantics
                     ;; from nil here: empty string will provide
                     ;; an empty stdin, whereas nil will (IIUC)
                     ;; attempt to connect the jvm's stdin to
                     ;; the ssh command.
                     in ""
                     abort-middleware [(abort-relative-timeout-middleware (time/minutes 10))]
                     handle-out default-output-handler
                     handle-err default-output-handler}
                :as opts}]
  (let [cmd (if (seq? cmd)
              ;; seq? rather than list?: forms built with concat or syntax
              ;; quote are lazy seqs / conses, which list? rejects and which
              ;; would otherwise reach ssh-exec unrendered.
              (sh/emit-form cmd)
              cmd)
        abort? (middleware/compile abort-middleware (constantly false))
        ;; For now, we will make the assumption that if :in is provided, we don't want a pty,
        ;; but otherwise we do.
        ssh-opts (merge {"StrictHostKeyChecking" "no"
                         "GSSAPIAuthentication" "no"}
                        (when-not (contains? opts :in) {:pty true}))]
    (binding [clj-ssh.ssh/*piped-stream-buffer-size* (* 1024 100)]
      (let [{:keys [channel out-stream err-stream]} (ssh/ssh-exec session
                                                                  cmd
                                                                  in
                                                                  :stream
                                                                  ssh-opts)]
        (try
          (process-exec channel out-stream err-stream abort? handle-out handle-err)
          (finally
            (ssh/disconnect-channel channel)))))))

(def exec-binary-cmd
  "exec-binary-cmd* wrapped by log-jcsh-results to report JSch metrics."
  (log-jcsh-results exec-binary-cmd*))

(def remote-exec-session
  "remote-exec-session* wrapped by log-jcsh-results to report JSch metrics."
  (log-jcsh-results remote-exec-session*))

(t/ann ^:no-check remote-exec [Node (t/U sh/SteveExpr sh/SteveOneLiner) -> sh/ShMap])
(defn remote-exec
  "Runs `cmd` on `node` and returns {:exit :out :err}; a non-zero exit is
  returned, not thrown.

  `node` is a map accepted by with-session: either one carrying a live
  :session, or one with the connection keys session/retry-connect need.
  `cmd` is a shell string or a form. Trailing keyword options go to
  remote-exec-session (:in, :abort-middleware, :handle-out, :handle-err).

  A node flagged :oom? is not contacted; {:exit -1 :out \"\" :err ...} is
  returned instead. A JSchException is logged and rethrown wrapped in a
  JSchException whose message names the node."
  [node cmd & {:as opts}]
  (if (:oom? node)
    {:exit -1 :out "" :err "container suspended: it's out of memory"}
    (try
      (with-session node
        #(remote-exec-session % cmd (or opts {})))
      (catch com.jcraft.jsch.JSchException ex
        (let [dest (if (:master node)
                     (str (:master node) ":" (:name node))
                     (or (:ip-addr node) (:dns-name node)))
              message (format "Exception %s when calling a command on node %s" ex dest)]
          ;; Pass message as an argument rather than as the format string:
          ;; it embeds the exception text, which may contain '%'.
          (errorf ex "%s" message)
          (throw (com.jcraft.jsch.JSchException. message ex)))))))

(defn remote-exec!
  "Same as remote-exec, but throws when the command exits non-zero.

  On failure the response is logged and thrown with slingshot's throw+,
  with :node added; :node keeps only the :name :tags :type :master and
  :ip-addr keys of `node`, so no key material ends up in the thrown data.
  A node flagged :oom? makes remote-exec return :exit -1, so it throws here.
  Returns the response map otherwise."
  [node cmd & opts]
  (let [resp (apply remote-exec node cmd opts)]
    (when (not= 0 (:exit resp))
      (let [cleaned-node (select-keys node [:name :tags :type :master :ip-addr])]
        (errorf "Failure %s running a command on node %s" resp cleaned-node)
        (throw+ (assoc resp :node cleaned-node))))
    resp))

(defn compute-remote-home
  "Asks the remote shell for $HOME (echo REMOTE_HOME=$HOME) and returns the
  value with surrounding whitespace trimmed (e.g. a trailing carriage return
  from the pty). Throws ex-info carrying the raw :out when no REMOTE_HOME=
  line comes back, instead of a bare NullPointerException."
  [node]
  (let [out (:out (remote-exec! node "echo REMOTE_HOME=$HOME"))
        home (some->> out (re-find #"REMOTE_HOME=(.*)") second)]
    (when-not home
      (throw (ex-info "could not determine remote home directory" {:out out})))
    (str/trim home)))

(defn remote-home
  "Returns the home directory cached on `node` under :remote-home, or nil.
  No SSH call is made (see compute-remote-home for that)."
  [node]
  (get node :remote-home))

(t/ann ^:no-check scp (t/IFn [Node & :mandatory {:direction (t/U (t/Value :to-remote)
                                                                 (t/Value :to-local))
                                                 :local-path (t/U String (t/Seq String))
                                                 :remote-path (t/U String (t/Seq String))}
                              -> sh/ShMap]))
(defn scp
  "Copies files between this machine and `node` over SFTP.

  :direction :to-remote puts local-path to remote-path; :to-local gets
  remote-path into local-path; any other value throws. The source side may be
  a seq of paths (local-path for :to-remote, remote-path for :to-local).

  A node flagged :oom? is never contacted: a slingshot exception holding the
  node without its :private-key :public-key :encrypted-keypair is thrown."
  [node & {:keys [local-path remote-path direction]}]
  (when (:oom? node)
    (throw+ {:message "container suspended: it's out of memory"
             :node (dissoc node :private-key :public-key :encrypted-keypair)}))
  (with-session node
    (fn [ssh-session]
      (case direction
        :to-remote (ssh/sftp ssh-session {} :put local-path remote-path)
        :to-local (ssh/sftp ssh-session {} :get remote-path local-path)
        (throwf "direction must be :to-local or :to-remote")))))

(defn rsync
  "Mirrors files between this machine and the node with a local
  `rsync -aqz --delete` over ssh (10 minute timeout via sh/sh!).

  :direction :to-local copies remote-path into local-path; any other value
  copies local-path to remote-path. Beware --delete: files at the destination
  that are missing from the source are removed. ssh authenticates with the
  node's private key (written to a temp file), with strict host-key checking,
  GSSAPI and password auth disabled. Reads :username :ip-addr :private-key
  and :port (default 22) from the node."
  [{:keys [username ip-addr private-key port] :or {port 22}}
   & {:keys [local-path remote-path direction]}]
  (let [remote (sh/q-arg (format "%s@%s:%s" username ip-addr remote-path))
        local (sh/q-arg local-path)
        from (if (= :to-local direction) remote local)
        to (if (= :to-local direction) local remote)]
    (with-temp-file private-key-file
      ;; Restrict the file to its owner *before* writing the key, so it is
      ;; never readable by others (even briefly) under a permissive umask.
      (let [key-file (io/file private-key-file)]
        (.setReadable key-file false false)
        (.setWritable key-file false false)
        (.setReadable key-file true true)
        (.setWritable key-file true true))
      ;; OpenSSH can refuse to load a key whose final line lacks a newline.
      (spit private-key-file (str (str/trim private-key) "\n"))
      (sh/sh!
       (sh/q
        (chmod 600 ~private-key-file)
        (rsync -aqz
               --delete
               -e (quoted (ssh -p ~port -i ~private-key-file
                               -o "'StrictHostKeyChecking no'"
                               -o "'GSSAPIAuthentication no'"
                               -o "'PasswordAuthentication no'"))
               ~from
               ~to))
       :timeout (time/minutes 10)))))

(defn remote-tempfile
  "Creates a fresh temp file on `node` (mktemp -t circle-XXXX) and returns
  its trimmed path. mktemp failures are thrown by remote-exec!."
  [node]
  (let [{:keys [out]} (remote-exec! node "mktemp -t circle-XXXX")]
    (str/trim out)))

(defn remote-mkdir
  "Creates `path`, including missing parents, on `node` with mkdir -p.
  Failures are thrown by remote-exec!; returns the trimmed stdout."
  [node path]
  (let [command (sh/emit-form (sh/q (mkdir -p ~path)))
        response (remote-exec! node command)]
    (str/trim (:out response))))

(t/ann remote-spit [Node String String -> sh/ShMap])
(defn remote-spit
  "Writes the string `content` to `path` on `node`: the content goes to a
  local temp file (tmp-spit), which is uploaded with scp and removed by
  with-delete afterwards. Returns scp's result."
  [node content path]
  (let [staged (tmp-spit content)]
    (with-delete staged
      (scp node :direction :to-remote :local-path staged :remote-path path))))

(defn remote-slurp
  "Returns the contents of the remote file at `path` as a string. The file is
  downloaded with scp into a local temp file, which is slurped and then
  removed by with-delete."
  [node path]
  (let [local-copy (fs/tempfile)]
    (with-delete local-copy
      (scp node :direction :to-local :local-path local-copy :remote-path path)
      (slurp local-copy))))

(defn remote-find
  "Runs `find <dir> <args...> -print` on `node` for every directory in
  dir-or-dirs that exists remotely, and returns the printed paths as a seq of
  strings (blank lines dropped).

  dir-or-dirs may be a single path or any (nested) collection of paths; empty
  entries are skipped. `args` are spliced into the find command verbatim,
  unquoted. Returns nil when the remote command does not exit 0."
  [node dir-or-dirs & args]
  (let [dir-lines (->> [dir-or-dirs] flatten (remove empty?) (str/join \newline))
        find-one (str/join " " (concat ["find" "\"$D\""] args ["-print"]))
        ;; directories arrive on stdin, one per line; each existing one is searched
        find-each (sh/emit-form
                    (sh/q (while "read -r D"
                            (when "[ -e \"$D\" ]"
                              ~find-one))))
        resp (remote-exec node find-each :in (str dir-lines "\n"))
        out (:out resp)]
    (when (and (some-> resp :exit zero?)
               (not (nil? out)))
      (remove empty? (str/split out #"\r?\n")))))

(defn install-private-key
  "Installs `priv-key` as the node's default ssh key (~/.ssh/id_rsa).

  The key is validated first (throw-if-not), uploaded into a remote temp file,
  moved into place and made owner read-only (chmod 400). Returns the
  remote-exec result of the mkdir/mv/chmod step; a non-zero exit there is
  returned, not thrown."
  [node priv-key]
  ;; validate before creating anything on the remote box
  (throw-if-not priv-key "ssh private key is required")
  (let [remote-path (remote-tempfile node)]
    (remote-spit node priv-key remote-path)
    ;; chmod the destination: after the mv, remote-path no longer exists
    (remote-exec node (sh/emit-form (sh/q (mkdir -p "~/.ssh")
                                          (mv ~remote-path "~/.ssh/id_rsa")
                                          (chmod 400 "~/.ssh/id_rsa"))))))

(defn add-private-key-command
  "Builds (but does not run) the shell commands that install an ssh key under
  <home>/.ssh/<keyname> (and <keyname>.pub) and append a matching Host block
  to ~/.ssh/config.

  Options:
  - :private-key, :keyname  required; keyname is a file name without path
  - :public-key             also written, as <keyname>.pub, when present
  - :hostname               restricts the block to that Host with
                            IdentitiesOnly yes; without it the block is
                            'Host !<gh-http-host> *' with IdentitiesOnly no
  - :gh-http-host           host excluded from that catch-all (default github.com)
  - :ignore-host-key-checking  adds 'StrictHostKeyChecking no' to the block
  Keys are trimmed before being written and key files are chmod 600.
  Returns the concatenation of the sh/q command forms."
  [home & {:keys [hostname keyname private-key public-key gh-http-host ignore-host-key-checking]
           :or {gh-http-host "github.com"}}]
  (throw-if-not private-key "ssh private key is required")
  (throw-if-not keyname "ssh key name is required")
  (let [public-key (some-> public-key str/trim)
        private-key (some-> private-key str/trim)
        remote-dest-private (fs/join home (format ".ssh/%s" keyname))
        remote-dest-public (fs/join home (format ".ssh/%s.pub" keyname))
        config (str/join \newline
                         [(format "Host %s" (or hostname
                                                (format "!%s *" gh-http-host)))
                          ;; when the hostname is specified, don't use ssh-agent identities
                          (format "IdentitiesOnly %s" (if hostname "yes" "no"))
                          (format "IdentityFile %s" remote-dest-private)
                          (if ignore-host-key-checking "StrictHostKeyChecking no" "")])
        ;; config is echoed inside single quotes; turn each ' into '\'' so a
        ;; quote in e.g. the hostname cannot terminate the string early.
        quoted-config (str/replace config "'" "'\\''")]
    (concat (sh/q (mkdir -p "~/.ssh"))
            (sh/q (~(format "echo '%s' >> ~/.ssh/config" quoted-config)))
            (when public-key
              (concat (sh/q (~(sed/overwrite-file remote-dest-public public-key)))
                      (sh/q (chmod 600 ~remote-dest-public))))
            (sh/q (~(sed/overwrite-file remote-dest-private private-key)))
            (sh/q (chmod 600 ~remote-dest-private)))))

(defn add-private-key
  "Installs an SSH key on `node` to be tried when logging in.

  Keyword args are passed to add-private-key-command, notably:
   - :keyname, a file name without path such as 'id_github'; should be unique
   - :hostname, when present the key is only used for that specific host
  Paths are based on the node's :remote-home. Returns the remote-exec result."
  [node & {:as args}]
  (->> (apply-map add-private-key-command (remote-home node) args)
       (sh/emit-form)
       (remote-exec node)))

(defn setup-ssh-agent-command
  "Returns the sh/q form that starts an ssh-agent and appends its output to
  ~/.circlerc, each line cut to its first four ';'-separated fields.
  Presumably this keeps just the agent's environment setup (e.g. the
  SSH_AUTH_SOCK / SSH_AGENT_PID assignments) -- TODO confirm. Later commands
  source ~/.circlerc to reach the agent (see add-key-to-ssh-agent-command)."
  []
  (sh/q (pipe (echo "$(ssh-agent)")
              (cut -d "\\;" -f "1,2,3,4" >> "~/.circlerc"))))

(defn add-key-to-ssh-agent-command
  "Returns the sh/q form that sources ~/.circlerc (where
  setup-ssh-agent-command records the agent environment) and then runs
  ssh-add on <home>/.ssh/<keyname>."
  [home keyname]
  (let [remote-keypath (fs/join home (format ".ssh/%s" keyname))]
    (sh/q (source "~/.circlerc")
          (ssh-add ~remote-keypath))))

;; Remote authorized_keys file; relies on the remote shell expanding the leading ~.
(def authorized-keys-path "~/.ssh/authorized_keys")

(defn append-authorized-keys
  "Append content to authorized-keys with proper perms.

  Ensures ~/.ssh exists, echoes a newline into the file first (presumably so
  the new content does not get glued onto a last line without a newline),
  appends `content` via sed/append-file, then recursively removes
  group/other read permission from ~/.ssh. Returns the remote-exec result;
  a non-zero exit is returned, not thrown.

  NOTE(review): go-r leaves any group/other *write* bit in place; sshd's
  StrictModes checks for writable files -- confirm whether go-rw is wanted."
  [node content]
  (remote-exec node (sh/emit-form (sh/q (mkdir -p "~/.ssh")
                                        (echo "\n" >> ~authorized-keys-path)
                                        ~(sed/append-file authorized-keys-path content)
                                        (chmod -R go-r "~/.ssh")))))

(defn authorize-keys
  "Appends all of `pub-keys`, one per line, to the node's authorized keys."
  [node pub-keys]
  (->> pub-keys
       (str/join "\n")
       (append-authorized-keys node)))

(defn authorize-key
  "Appends the single public key `pub-key` to the node's authorized keys."
  [node pub-key]
  (->> pub-key
       (append-authorized-keys node)))

(defn remote-host-fingerprint
  "Returns the fingerprint of the node's RSA host key: the second
  space-separated field of the output of
  `ssh-keygen -l -f /etc/ssh/ssh_host_rsa_key.pub`. A failing command is not
  thrown; whatever it printed is parsed as-is."
  [node]
  (let [cmd (sh/emit-form
             (sh/q (ssh-keygen -l -f "/etc/ssh/ssh_host_rsa_key.pub")))
        fields (str/split (:out (remote-exec node cmd)) #" ")]
    (second fields)))

(defn anon-key-name
  "Derives a file name for a key known only by its public key: the key's
  fingerprint (ssh-keys/fingerprint) with every ':' removed."
  [public-key]
  (str/replace (ssh-keys/fingerprint public-key) ":" ""))

(defn init
  "Patches clj-ssh so that channel connects time out.

  Rebinds the root of clj-ssh.ssh/connect-channel (via intern) to a fn that
  calls (.connect channel 60000) inside a :circle.ssh.connect-channel timing
  metric. This affects every caller of clj-ssh.ssh/connect-channel in the
  process; presumably meant to be called once at startup -- confirm.

  NOTE(review): the replacement takes exactly one argument (the channel); if
  the clj-ssh version in use calls connect-channel with other arities they
  would fail -- verify against clj-ssh."
  []
  ;; clj-ssh doesn't expose the channel connection timeout, so force it.
  ;; TODO: we could add retry logic here, if the occasional transient channel
  ;; connection failures persist.
  (intern 'clj-ssh.ssh 'connect-channel (fn [^Channel channel]
                                          (with-timing-metric :circle.ssh.connect-channel []
                                            ;; JSch Channel.connect(int) timeout, in milliseconds
                                            (.connect channel 60000)))))
