(ns bionlp.biobert
  "A libpython-clj interface to running onnx transformers models (e.g. BERT models) in a batched mode.
   In particular, it can split a long text into batches and run BERT-based NER classification on each batch.
   Examples:
   ```
   user> (require '[bionlp.biobert :as biobert])

   ;; First instantiate a transformers NLP pipeline from a model name or a path to a model directory.
   ;; Any NER model from the Hugging Face model repository (https://huggingface.co/models) can be used.
   user> (def nlp-pipe (biobert/nlp-pipeline \"alvaroalon2/biobert_diseases_ner\"))

   ;; Run NER
   user> (biobert/ner-tag nlp-pipe \"The objective of this study was to provide more accurate frequency estimates
          of breast cancer susceptibility gene 1 (BRCA1) germline alterations in the ovarian cancer population.\")
   ```"
  (:require [clojure.string :as str]
            [bionlp.util :as util]
            [libpython-clj2.require :refer [require-python]]
            [libpython-clj2.python :as py :refer [py. py.. initialize!
                                                  run-simple-string
                                                  import-module]]))

;; Initialize the embedded Python runtime. This is a load-time side effect: it runs
;; once when this namespace is loaded, before the Python module is required below.
(initialize!)

;; Load the Python `onnx_transformers` module. `:bind-ns` makes the module itself
;; available in this namespace as the symbol `onnx_transformers`, which
;; `nlp-pipeline` calls into.
(require-python '[onnx_transformers :bind-ns])

(defn nlp-pipeline
  "Returns a transformers NER pipeline for the supplied model.
   Calls the `pipeline` attribute of the Python `onnx_transformers` module with the
   task name \"ner\" and the keyword arguments `model` and `onnx` (set to true).
   `model` - model name or path to a model directory
   Returns the Python pipeline object; `ner-tag` invokes it directly on a text string."
  [model]
  ;; Python equivalent: onnx_transformers.pipeline("ner", model=model, onnx=True)
  (py. onnx_transformers "pipeline" "ner" :model model :onnx true))

(defn- merge-ner-token
  "Reducer step for `ner-tag`: folds one raw pipeline result `res` (a map-like value
   with \"word\", \"index\" and \"score\" entries) into the vector `acc` of aggregated
   entities. Each aggregated entity is a map with:
   `:token-regex` - regex source that should match the entity in the original text
   `:index`       - index of the last token merged into the entity
   `:score`       - score of the entity (pairwise running average of merged tokens)"
  [acc res]
  (let [word (get res "word")
        index (get res "index")
        score (get res "score")
        prev (peek acc)
        prev-regex (:token-regex prev)
        ;; Runs of '#' (presumably word-piece continuation markers) become an optional
        ;; word-character group so the pattern can span the pieces of a split word.
        token (str/trim (str/replace (util/str-escape-regex-chars word) #"#+" "(\\\\w+)?"))
        ;; A token is glued to the previous one without a space when it is a word
        ;; piece, contains an apostrophe, is a hyphen, or when the previous entity
        ;; ends with a hyphen or an apostrophe (e.g. BRCA + - + 1).
        attached? (boolean
                   (or (str/index-of word "#")
                       (str/index-of word "'")
                       (= word "-")
                       (contains? #{\- \'} (last prev-regex))))
        fresh {:token-regex token :score score :index index}]
    (cond
      ;; The first entity has nothing to merge with; keep its score untouched
      ;; instead of averaging it with a default of 0.
      (nil? prev)
      (conj acc fresh)

      ;; A gap in token indices and no attachment marker: start a new entity.
      (and (> (- (get res "index" 0) (:index prev 0)) 1)
           (not attached?))
      (conj acc fresh)

      ;; Otherwise extend the previous entity (pop/peek are O(1) on vectors).
      :else
      (conj (pop acc)
            {:token-regex (str/trim (str prev-regex (if attached? "" " ") token))
             :index index
             :score (/ (+ (:score prev 0) score) 2)}))))

(defn ner-tag
  "Runs nlp-pipeline on a text segment and aggregates the tagged tokens into entities.
   `nlp-pipeline` - transformers pipeline instance to use for NER
   `txt` - text segment
   Returns a lazy sequence of maps with keys:
   `:token` - first case-insensitive match of the aggregated entity in `txt`,
              or nil when the entity could not be located in `txt`
   `:index` - index of the last pipeline token merged into the entity
   `:score` - aggregated score of the entity
   If the pipeline call throws, the exception is swallowed and an empty result is returned."
  [nlp-pipeline txt]
  (let [ner-result (try
                     (nlp-pipeline txt)
                     ;; Best effort: a failing pipeline call yields no entities (for now)
                     (catch Exception _ '()))]
    ;; aggregate tagged tokens into single token based on IOB tags
    ;; index and tag type can be used to determine consecutive tokens
    (->> ner-result
         (reduce merge-ner-token [])
         (map (fn [tok]
                (let [pat (re-pattern (str "(?i)" (:token-regex tok) "(\\w+)?"))
                      match (re-find pat txt)]
                  ;; re-find yields [whole-match & groups] when the pattern has groups;
                  ;; the whole match is the surface form of the entity.
                  {:token (if (vector? match) (first match) match)
                   :index (:index tok)
                   :score (:score tok)}))))))

(defn batched-ner
  "Tags a long text by running `ner-tag` over overlapping word windows.
   The text is split on whitespace into windows of up to 200 words that advance
   195 words at a time (a short final window is padded with one single-space element).
   Each window is re-joined with single spaces and tagged, and the per-window
   results are concatenated in order.
   `nlp-pipeline` - transformers pipeline instance to use for NER
   `txt` - text to tag"
  [nlp-pipeline txt]
  (let [words (str/split txt #"\s+")
        tag-window (fn [window]
                     (ner-tag nlp-pipeline (str/join " " window)))]
    (mapcat tag-window (partition 200 195 [" "] words))))
