#!/usr/bin/env python2

import math
import numpy
import sys
import os
import pickle

xmlstarlet='xml'

from math import sqrt

def get_cmap(struct, contact_range=None):
  """Build the symmetric contact map of one structure.

  Two distinct sites are in contact when their Euclidean distance is at most
  ``contact_range``.

  Args:
    struct: sequence of sites; each site is indexable and its first three
      components are used as x, y, z.
    contact_range: contact distance.  When omitted it is ``lam * sigma``,
      taken from the module-level settings of this script.

  Returns:
    An (N, N) float numpy array holding 1 for contacting pairs and 0
    elsewhere, including the diagonal.
  """
  if contact_range is None:
    # The distance below is a true (square-rooted) distance, so it has to be
    # compared with lam*sigma.  The previous test used lam*lam*sigma*sigma,
    # i.e. the squared range, which is only correct when lam*sigma == 1.
    contact_range = lam * sigma

  nsites = len(struct)
  cmap = numpy.zeros((nsites, nsites), float)
  if nsites < 2:
    return cmap

  coords = numpy.array([[site[0], site[1], site[2]] for site in struct],
                       dtype=float)
  # All pairwise separation vectors at once: delta[i][j] = r_i - r_j.
  delta = coords[:, numpy.newaxis, :] - coords[numpy.newaxis, :, :]
  dist = numpy.sqrt((delta * delta).sum(axis=2))
  cmap[dist <= contact_range] = 1

  # A site is never counted as being in contact with itself.
  for i in range(nsites):
    cmap[i][i] = 0

  return cmap

def rmsd(cmap1, cmap2, min_separation=None):
  """Return the normalised overlap of two contact maps.

  Despite its name this is not a root-mean-square deviation: it is the dot
  product of the two maps divided by the product of their norms, taken over
  the pairs (i, j) with j >= i + 1 + min_separation.

  Args:
    cmap1: square contact map (N x N).
    cmap2: contact map indexed over the same pairs as ``cmap1``.
    min_separation: number of near-diagonal bands to leave out.  When omitted
      the module-level ``skip`` setting is used.

  Returns:
    The overlap as a float, or 0.0 when either map has no entries in the
    selected region (zero norm) or the map is empty.
  """
  if min_separation is None:
    min_separation = skip

  nsites = len(cmap1)
  if nsites == 0:
    return 0.0

  first = numpy.asarray(cmap1, dtype=float)
  second = numpy.asarray(cmap2, dtype=float)

  # Indices of the upper triangle starting min_separation bands above the
  # first off-diagonal.
  rows, cols = numpy.triu_indices(nsites, 1 + min_separation)
  a = first[rows, cols]
  b = second[rows, cols]

  norm = sqrt(numpy.dot(a, a)) * sqrt(numpy.dot(b, b))
  if norm == 0:
    return 0.0
  return float(numpy.dot(a, b) / norm)

def get_structlist(outputfile):
  """Read the stored structure images from a bzip2-compressed output file.

  Runs ``bzcat | xmlstarlet | gawk`` through the shell.  Every non-empty
  output line is parsed as colon-separated atoms, each atom being
  comma-separated floats taken from the x, y and z attributes.

  Args:
    outputfile: path of the compressed XML file; it is inserted into the
      shell command unquoted, so it must not contain shell metacharacters.

  Returns:
    A list of 2-D numpy arrays (one row per atom).  When more than the
    module-level ``maxnumofstruct`` structures are found the list is reduced
    before being returned.
  """
  cmd = 'bzcat '+outputfile+' | '+xmlstarlet+' sel -t -m \'/OutputData/StructureImages\' -m \'Image\' -m \'Atom\' -v \'@x\' -o "," -v \'@y\' -o \',\' -v \'@z\' -o \':\' -b -n | gawk \'{if (NF) print $0}\''

  structlist = []
  for line in os.popen(cmd).readlines():
    # Drop the newline and the separator left after the last atom.
    line = line.rstrip(':\n')
    structlist.append(numpy.array([[float(x) for x in atom.split(',')] for atom in line.split(':')]))

  print("Number of structures is "+str(len(structlist)))
  if len(structlist) > maxnumofstruct:
    print("\nTruncating to "+str(maxnumofstruct))
    if len(structlist) == 500 and maxnumofstruct == 100:
      # Thin 500 -> 400 -> 300 -> 200 -> 100 by successive strided deletions
      # so that the survivors are spread over the whole run.
      for stride in (5, 4, 3, 2):
        del structlist[0::stride]
    else:
      # Honour the configured limit; this used to be hard-coded to 100.
      del structlist[maxnumofstruct:]

  return structlist

def get_temperature(file):
    """Return the /OutputData/KEnergy/T/@val value of a compressed output file.

    The file is decompressed with bzcat and queried with the configured
    xmlstarlet command; whatever the pipeline prints is converted with
    float(), so empty or non-numeric output raises ValueError.
    """
    query = " sel -t -v '/OutputData/KEnergy/T/@val'"
    pipeline = 'bzcat ' + file + ' | ' + xmlstarlet + query
    raw_value = os.popen(pipeline).read()
    return float(raw_value)

from operator import add

def get_rmsd_data(structlist):
  """Score every pair of structures by comparing their contact maps.

  Args:
    structlist: list of structures accepted by get_cmap().

  Returns:
    A tuple ``(minimum, average, scores, cmaplist)`` where ``scores`` is the
    symmetric matrix of rmsd() values (the diagonal is left at zero),
    ``cmaplist`` holds one contact map per structure, ``minimum`` is the
    index of the first row of ``scores`` with the smallest sum and
    ``average`` is that sum divided by the number of structures.
  """
  nstructs = len(structlist)
  cmaplist = [get_cmap(struct) for struct in structlist]

  scores = numpy.zeros((nstructs, nstructs), float)
  for first in range(nstructs):
    for second in range(first + 1, nstructs):
      scores[first][second] = rmsd(cmaplist[first], cmaplist[second])
      scores[second][first] = scores[first][second]

  # NOTE(review): rmsd() yields an overlap-like score, so the row with the
  # smallest sum may not be the most representative structure -- confirm
  # that picking the minimum is still what is wanted here.
  minimum = 0
  minsum = 1e308
  for index, row in enumerate(scores):
    rowsum = sum(row)
    if rowsum < minsum:
      minimum, minsum = index, rowsum

  return minimum, minsum / nstructs, scores, cmaplist

import copy

def cluster_analysis(arrayList, minval, threshold):
  """Greedily extract clusters from a pairwise score matrix.

  Every structure seeds a candidate cluster made of itself plus all other
  structures whose score with it exceeds ``minval``.  The largest candidate
  is accepted, its members are removed from every remaining candidate, and
  the step is repeated while the largest candidate has more than
  ``threshold`` members.  Ties are resolved by comparing the candidates as
  lists, i.e. by size, then member list, then seed index.

  Args:
    arrayList: square matrix (numpy array or nested sequence) of scores.
    minval: a pair is linked when its score is strictly greater than this.
    threshold: a cluster must have strictly more members than this.

  Returns:
    A list of ``[size, sorted_member_indices, seed_index]`` entries in the
    order the clusters were accepted; empty when no cluster qualifies or the
    matrix is empty.
  """
  nstructs = len(arrayList)
  if nstructs == 0:
    # max() of an empty candidate list would raise ValueError.
    return []

  candidates = []
  for seed in range(nstructs):
    members = [seed]
    for other in range(nstructs):
      if other != seed and arrayList[seed][other] > minval:
        members.append(other)
    candidates.append([len(members), members, seed])

  clusters = []
  while True:
    best = max(candidates)
    # The "<= 0" test stops the loop once every structure has been claimed;
    # without it a negative threshold would append empty clusters forever.
    if best[0] <= threshold or best[0] <= 0:
      break

    taken = sorted(best[1])
    clusters.append([best[0], taken, best[2]])

    # Remove the accepted members from every candidate (including the one
    # just accepted) while keeping the order of the survivors.
    claimed = set(taken)
    for candidate in candidates:
      candidate[1] = [member for member in candidate[1] if member not in claimed]
      candidate[0] = len(candidate[1])

  return clusters

def print_structure(structure, filename):
  """Write one structure to ``filename`` (VECT list) and ``filename.txyz``.

  The VECT file gets a header followed by one "x y z" line per site and a
  closing brace.  The txyz file gets a count line followed by one line per
  site: 1-based index, atom label, the three coordinates with 10 decimals,
  the atom label again and the index of a neighbouring site.

  The atom label is "C" unless the sequence letter looked up for the site in
  config.0.end.xml.bz2 (current directory) contains "1", in which case it is
  "O".  Note that this lookup decompresses that file once per site.

  Args:
    structure: sequence of sites, each indexable as x, y, z.
    filename: path of the VECT file; ".txyz" is appended for the second file.
  """
  f = open(filename, 'w')
  try:
    # Opened inside the outer try so that f is closed if this open fails.
    g = open(filename+".txyz", 'w')
    try:
      length = str(len(structure))
      f.write("{VECT 1 "+length+" 0 \n")
      f.write(length+"\n")
      f.write("0\n")

      g.write(length+" RMSDimage\n")

      for i in range(len(structure)):
        x, y, z = structure[i][0], structure[i][1], structure[i][2]
        f.write("%s %s %s\n" % (x, y, z))

        atom = "C"
        # Use the configured xmlstarlet command, as the other readers in this
        # file do, instead of a hard-coded "xml".
        cmd = "bzcat config.0.end.xml.bz2 | "+xmlstarlet+" sel -t -v \"/DYNAMOconfig/Dynamics/Interactions/Interaction[@Type='SquareWellSeq']/Sequence/Element[@seqID='"+str(i)+"']/@Letter\""
        for line in os.popen(cmd).readlines():
          if "1" in line:
            atom = "O"

        # The last column is the 1-based index of the previous site; the
        # first site refers to site 2 instead.
        p = i
        if p == 0:
          p = 2
        g.write("%s %s %s %s %s %s %s\n" % (i+1, atom, "%15.10f" % x, "%15.10f" % y, "%15.10f" % z, atom, p))
      #f.write("1 1 1 1\n")
      f.write("}\n")
    finally:
      g.close()
  finally:
    f.close()

def contact_hist(membersofcluster, cmaps, name):
  """Write the average contact map of a cluster to 'CMapHist.<name>.dat'.

  Args:
    membersofcluster: indices into ``cmaps`` of the structures to average.
    cmaps: list of equally sized square contact maps.
    name: label inserted into the output file name.

  The file holds one text line per matrix row; every value is followed by a
  single space.
  """
  nsites = len(cmaps[0])
  totals = numpy.zeros((nsites, nsites), float)
  for member in membersofcluster:
    totals += cmaps[member]

  nmembers = float(len(membersofcluster))

  out = open('CMapHist.'+str(name)+'.dat', 'w')
  try:
    for row in totals:
      cells = [str(count / nmembers) + ' ' for count in row]
      out.write(''.join(cells) + "\n")
  finally:
    out.close()

def print_vmd_structure(structure, filename):
  """Write a structure as a simple text coordinate file.

  The file starts with the number of sites and the title line
  "RMSD VMD structure"; each following line is "C x y z" with every
  coordinate multiplied by 3.4.

  Args:
    structure: sequence of sites, each indexable as x, y, z.
    filename: path of the file to write.
  """
  scale = 3.4
  out = open(filename, 'w')
  try:
    out.write(str(len(structure)) + "\n")
    out.write("RMSD VMD structure\n")
    for index in range(len(structure)):
      site = structure[index]
      out.write("C %s %s %s\n" % (site[0]*scale, site[1]*scale, site[2]*scale))
  finally:
    out.close()

import getopt, sys

try:
  optlist, filelist = getopt.getopt(sys.argv[1:], [], [ 'sigma=', 'skip=', 'cutoff=', 'threshold=', 'regen', 'maxstructs=' ])
except getopt.GetoptError, err:
  # print help information and exit:
  print str(err) # will print something like "option -a not recognized"
  sys.exit(2)

sigma=1.0
skip=0
cutoff=0.8
lam=1.5
threshold=5
maxnumofstruct=100
regen=False

for o, a in optlist:
  if o == "--sigma":
    sigma = float(a)
  elif o in ("--skip"):
    skip = int(a)
  elif o in ("--cutoff"):
    cutoff = float(a)
  elif o in ("--lambda"):
    lam = float(a)
  elif o in ("--threshold"):
    threshold = int(a)
  elif o in ("--regen"):
    regen = True
  elif o in ("--maxstructs"):
    maxnumofstruct= int(a)
  else:
    assert False, "unhandled option"

# Build one record per input file:
#   [temperature, get_rmsd_data() result, file name, structure list]
# Only the first two entries are cached in '<file>.crmspickle3'.
filedata = [ ]
for filename in filelist:
  print("Processing file "+filename)
  cachename = filename+'.crmspickle3'
  if (not regen) and os.path.exists(cachename):
    print("Found pickled data, skipping minimisation")
    # NOTE: unpickling can execute arbitrary code; only use cache files that
    # this script wrote itself.
    f = open(cachename, 'rb')
    try:
      tempdata = pickle.load(f)
    finally:
      f.close()
    structlist = get_structlist(filename)
  else:
    print("Could not find pickled minimisation data")
    temperature = get_temperature(filename)
    # Parse the structures once and reuse them for the record below.
    structlist = get_structlist(filename)
    tempdata = [ temperature, get_rmsd_data(structlist) ]

    f = open(cachename, 'wb')
    try:
      pickle.dump(tempdata,f)
    finally:
      f.close()

  tempdata.append(filename)
  tempdata.append(structlist)
  filedata.append(tempdata)


# Order by temperature, then file name.  A plain sort() would compare the
# numpy arrays inside the records whenever two temperatures are equal.
filedata.sort(key=lambda entry: (entry[0], entry[2]))

# For every file write '<file>.crmsdhist' (normalised 100-bin histogram of
# the pair scores on [0, 1)) and '<file>.crmsdarray' (the full score matrix,
# one row per line, every value followed by a space).
for data in filedata:
  scores = data[1][2]
  nstructs = len(scores)
  histfile = open(data[2]+'.crmsdhist', 'w')
  matrixfile = open(data[2]+'.crmsdarray', 'w')
  try:
    xmax = 1.0
    delx = xmax / 100.0
    bins = numpy.zeros(100, float)
    count = 0
    for first in range(nstructs):
      for second in range(first+1, nstructs):
        coord = scores[first][second]/delx
        # Scores that fall at or beyond bin 100 are neither binned nor counted.
        if (coord < 100):
          bins[int(coord)] += 1
          count += 1

    for first in range(nstructs):
      row = [str(scores[first][second]) for second in range(nstructs)]
      matrixfile.write(' '.join(row) + " \n")

    for index in range(len(bins)):
      histfile.write("%s %s\n" % ((index+0.5)*delx, bins[index]/float(count)))
  finally:
    histfile.close()
    matrixfile.close()

# Write 'crmsdhistsurface.dat': for every file one block of
# "temperature bin-centre fraction" lines (101 bins of width 0.05), each
# block terminated by a line holding a single space.
surface = open('crmsdhistsurface.dat', 'w')
try:
  for data in filedata:
    scores = data[1][2]
    nstructs = len(scores)
    xmax = 5.0
    delx = xmax / 100.0
    bins = numpy.zeros(101, float)
    count = 0
    for first in range(nstructs):
      for second in range(first+1, nstructs):
        slot = int(scores[first][second]/delx)
        # Only slots below 100 are filled, so the last bin always stays empty.
        if (slot < 100):
          bins[slot] += 1
          count += 1

    for index in range(len(bins)):
      surface.write("%s %s %s\n" % (data[0], (index+0.5)*delx, bins[index]/float(count)))

    surface.write(" \n")
finally:
  surface.close()


# Cluster every file and write three summaries, one line per file:
#   cclusters.dat        temperature, number of clusters
#   cclusterfraction.dat temperature, summed cluster sizes / number of structures
#   cavgrmsd.dat         temperature, average and index returned by get_rmsd_data
# When exactly one file is analysed the clusters themselves are reported and
# the seed structure and contact histogram of each cluster are written too.
f = open('cclusters.dat', 'w')
g = open('cclusterfraction.dat', 'w')
h = open('cavgrmsd.dat', 'w')
try:
  for data in filedata:
    clusters = cluster_analysis(data[1][2], cutoff, threshold)
    if (len(filedata)==1):
      print(clusters)
      for cluster in clusters:
        print("Printing structure number "+str(cluster[2]))
        print_structure(data[3][cluster[2]],"Cluster"+str(cluster[2])+".list")
        contact_hist(cluster[1],data[1][3], cluster[2])
        # Reported only after the histogram file has really been written.
        print("Cmap histogram printed")
        for cluster2 in clusters:
          print("  Struct "+str(cluster[2])+" rmsd vs struct "+str(cluster2[2])+" = "+str(data[1][2][cluster[2]][cluster2[2]]))

    # Do not call this `sum`: rebinding that name at module level would
    # break the builtin sum() used by get_rmsd_data.
    clustered = 0
    for cluster in clusters:
      clustered += cluster[0]

    f.write("%s %s\n" % (data[0], len(clusters)))
    g.write("%s %s\n" % (data[0], float(clustered) / float(len(data[1][2]))))
    h.write("%s %s %s\n" % (data[0], data[1][1], data[1][0]))
finally:
  f.close()
  g.close()
  h.close()
  
