#!/usr/bin/python

import pylab
import xsfutil

'''
#usage:

import xsfutil
import arpes3D
from os import *
chdir('/home/prokop/Desktop/Dorje/SiAg/4x4-3D-4.0-SiProject')
F1 = xsfutil.loadXSF('ARPES_3D.xsf')
arpes3D.plotCut(F1, cut=1.0, Efermi = -2.28233751996685, ylim=(-3,3), vmax=1.0)


getcwd()
'''

   
def plotCut(xsf, cut=0.0, Efermi = 0.0,xlim=None, ylim=None, vmax=1.0, interpolation='bicubic', cmap='Greys', figsize=None):
    '''Plot a momentum-energy (k, E) slice of a 3D ARPES volume at fixed ky.

    xsf           : sequence (F, lvec, dims). F is indexed F[iky][iE][ik];
                    lvec[0] is the grid origin and lvec[1], lvec[2], lvec[3]
                    span k, E and ky respectively (only their diagonal
                    components are read, so an axis-aligned grid is assumed);
                    dims[0] is the number of ky slices.
    cut           : ky value of the slice; the nearest slice is shown and
                    out-of-range values are clamped to the first/last slice.
    Efermi        : subtracted from the energy axis so E is plotted relative
                    to it.
    xlim, ylim    : optional (min, max) axis limits.
    vmax, interpolation, cmap : forwarded to pylab.imshow.
    figsize       : if given, a new figure of this size is opened first.
    '''
    F    = xsf[0]
    lvec = xsf[1]
    dims = xsf[2]
    klim = (lvec[0,0], lvec[0,0]+lvec[1,0])
    elim = (lvec[0,1], lvec[0,1]+lvec[2,1])
    # Nearest ky slice. int() because Python 2's round() returns a float,
    # which numpy rejects as an index. Clamping keeps a cut below the range
    # from wrapping around via a negative index and a cut at the upper edge
    # from raising IndexError.
    # NOTE(review): spacing is taken as span/dims; if the grid includes both
    # endpoints it is span/(dims-1) -- confirm against xsfutil.
    icut = int(round(dims[0]*(cut-lvec[0,2])/lvec[3,2]))
    icut = max(0, min(dims[0]-1, icut))
    if figsize is not None:
        pylab.figure(figsize=figsize)
    pylab.imshow(F[icut], cmap=cmap, extent=[klim[0], klim[1], elim[0]-Efermi, elim[1]-Efermi],
                 interpolation=interpolation, origin='lower', vmax=vmax, aspect='auto')
    # 'is not None': limits may be numpy arrays, for which '!= None' is ambiguous.
    if xlim is not None:
        pylab.xlim(xlim[0], xlim[1])
    if ylim is not None:
        pylab.ylim(ylim[0], ylim[1])
    pylab.title(" ky = "+str(cut))
    pylab.tight_layout()
    pylab.ion()
    pylab.show()

def plotHex( r, a0, clr='r' ):
    '''Draw one unfilled hexagon centred on the origin into the current axes.

    r   : circumradius (distance from the centre to each vertex).
    a0  : rotation in radians; the first vertex lies at angle a0.
    clr : edge colour (default 'r', the previously hard-coded value).
    '''
    ax = pylab.gca()
    # Exact multiples of pi/3. The former arange(0,2,0.33333)*pi gave 7
    # slightly drifting angles, the last (~2*pi) duplicating the first vertex.
    angles = pylab.arange(6)*(pylab.pi/3.0) + a0
    pts = pylab.transpose(pylab.array([pylab.cos(angles), pylab.sin(angles)])*r)
    ax.add_patch(pylab.Polygon(pts, closed=True, fill=False, hatch=None, ec=clr))
    
def plotHexs(alat, a0, clr, Range, lw=1.0):
    '''Tile the current axes with hexagons centred on a triangular lattice.

    alat  : real-space lattice constant. The drawn lattice spacing is
            r = 2*pi/(alat*sqrt(3)/2) (presumably the reciprocal-lattice
            vector length of a hexagonal lattice -- the hexagons would then
            be Brillouin zones).
    a0    : lattice rotation in radians; the first lattice vector points
            along a0, the second along a0 + 60 degrees.
    clr   : edge colour.
    Range : radius to cover; n = int(Range/r) cells are drawn on each side
            of the origin along both lattice vectors.
    lw    : edge line width.

    Each hexagon has its edges on the perpendicular bisectors of the six
    nearest-neighbour lattice vectors (inradius r/2, circumradius r/sqrt(3)).
    '''
    s32 = 0.5*pylab.sqrt(3.0)   # sqrt(3)/2 = sin(60 deg)
    r = 2*pylab.pi/alat/s32
    n = int(Range/r)
    ax = pylab.gca()
    third = pylab.pi/3.0        # exact 60 deg (was pi*0.333333)
    a = r*pylab.array([pylab.cos(a0), pylab.sin(a0)])
    b = r*pylab.array([pylab.cos(a0+third), pylab.sin(a0+third)])
    # Six vertices at 30, 90, ..., 330 degrees from a0 (exact; the former
    # arange(0,2,0.333333) produced a 7th, duplicate vertex).
    angles = (pylab.arange(6) + 0.5)*third + a0
    pts = pylab.transpose(pylab.array([pylab.cos(angles), pylab.sin(angles)])*(0.5*r/s32))
    for ia in range(-n, n+1):
        for ib in range(-n, n+1):
            ax.add_patch(pylab.Polygon(pts + a*ia + b*ib, closed=True, fill=False,
                                       hatch=None, ec=clr, lw=lw))

def plotMultiHexs( hexs, extent=(-2,2,-2,2), lw=1 ):
    '''Open a new figure and overlay several hexagonal lattices with plotHexs.

    hexs   : iterable of [alat, angle_deg, clr] or [alat, angle_deg, clr, lw]
             entries. The angle is converted to radians. An optional 4th
             element overrides the default line width for that lattice.
    extent : (xmin, xmax, ymin, ymax) of the plotted window.
    lw     : line width for entries that do not specify their own (this
             parameter was previously ignored, and 3-element entries raised
             IndexError).
    '''
    # abs() so a reversed axis does not give a negative figure height;
    # float() avoids Python 2 integer division (extent=[-2,2,-1,1] gave 0).
    xsz = abs(float(extent[1] - extent[0]))
    ysz = abs(float(extent[3] - extent[2]))
    aspect = ysz/xsz
    pylab.figure(figsize=(6, 6*aspect))
    # Window diagonal: radius handed to plotHexs so the lattice covers the view.
    kR = pylab.sqrt(xsz**2 + ysz**2)
    for hx in hexs:
        hx_lw = hx[3] if len(hx) > 3 else lw
        plotHexs(hx[0], hx[1]*pylab.pi/180, hx[2], kR, lw=hx_lw)
    pylab.axis('equal')
    pylab.ylim(extent[2], extent[3])
    pylab.xlim(extent[0], extent[1])

            
'''
Example:
  SiAg  
    arpes3D.plotCutE( SiAg4 , cut=0,Efermi=-2.39416747588695, hexs=[[11.6106933437,0.0,'r'],[11.6106933437/3,0.0,'g'],[11.6106933437/4,0.0,'b']])
  SiPt
    PtR19Si=[[12.35,6.6,'c'],[12.35/4.35889894354,30.0,'g'],[12.35/3,6.6,'m']]
    arpes3D.plotCutE(PtMartin,  Efermi=-2.82019285836379,vmax=None,cut=-0.5, hexs=PtR19Si)
'''

def plotCutE(xsf, cut=0.0, Efermi = 0.0,xlim=None, ylim=None, vmax=1.0, interpolation='bicubic', cmap='Greys', hexs=None, lw=1.0):
    '''Plot a constant-energy (kx, ky) map of a 3D ARPES volume.

    xsf        : sequence (F, lvec, dims). F is indexed F[iky][iE][ikx];
                 lvec[0] is the grid origin and lvec[1], lvec[2], lvec[3]
                 span kx, E and ky (only their diagonal components are read);
                 dims[1] is the number of energy slices.
    cut        : energy of the slice relative to Efermi (the absolute energy
                 looked up is cut+Efermi). The nearest slice is shown, and
                 out-of-range energies are clamped to the first/last slice.
    Efermi     : energy reference added to cut.
    xlim, ylim : optional (min, max) axis limits.
    vmax, interpolation, cmap : forwarded to pylab.imshow.
    hexs       : optional list of [alat, angle_deg, clr] entries, each
                 overlaid with plotHexs.
    lw         : line width of the overlaid hexagons.
    '''
    F    = xsf[0]
    lvec = xsf[1]
    dims = xsf[2]
    kxlim = (lvec[0,0], lvec[0,0]+lvec[1,0])
    kylim = (lvec[0,2], lvec[0,2]+lvec[3,2])
    # Nearest energy slice. int() because Python 2's round() returns a float
    # (an invalid numpy index). Clamping stops an out-of-range energy from
    # wrapping around through a negative index.
    # NOTE(review): spacing is taken as span/dims; if the grid includes both
    # endpoints it is span/(dims-1) -- confirm against xsfutil.
    icut = int(round(dims[1]*(cut+Efermi-lvec[0,1])/lvec[2,1]))
    icut = max(0, min(dims[1]-1, icut))
    pylab.imshow(F[:,icut,:], cmap=cmap, extent=[kxlim[0], kxlim[1], kylim[0], kylim[1]],
                 interpolation=interpolation, origin='lower', vmax=vmax, aspect='equal')
    # 'is not None': limits may be numpy arrays, for which '!= None' is ambiguous.
    if xlim is not None:
        pylab.xlim(xlim[0], xlim[1])
    if ylim is not None:
        pylab.ylim(ylim[0], ylim[1])
    if hexs is not None:
        # Data-window diagonal: radius handed to plotHexs so the lattice covers the map.
        kR = pylab.sqrt((kxlim[1]-kxlim[0])**2 + (kylim[1]-kylim[0])**2)
        for hx in hexs:
            plotHexs(hx[0], hx[1]*pylab.pi/180, hx[2], kR, lw=lw)
    pylab.title(" E = "+str(cut))
    pylab.tight_layout()
    pylab.ion()
    pylab.show()
    
    
def plotCutESym(xsf, cut=0.0, Efermi = 0.0,xlim=None, ylim=None, vmax=1.0, interpolation='bicubic', cmap='Greys', hexs=None, lw=1.0):
    '''Plot a constant-energy (kx, ky) map symmetrised by point reflection.

    The selected energy slice F[:, icut, :] is stacked along ky below a copy
    with both ky and kx reversed, doubling the ky range.

    xsf        : sequence (F, lvec, dims). F is indexed F[iky][iE][ikx];
                 lvec[0] is the grid origin and lvec[1], lvec[2], lvec[3]
                 span kx, E and ky (only their diagonal components are read);
                 dims[1] is the number of energy slices.
    cut        : energy of the slice relative to Efermi (absolute energy
                 cut+Efermi). The nearest slice is shown, clamped to range.
    Efermi     : energy reference added to cut.
    xlim, ylim : optional (min, max) axis limits.
    vmax, interpolation, cmap : forwarded to pylab.imshow.
    hexs       : optional list of [alat, angle_deg, clr] entries, each
                 overlaid with plotHexs.
    lw         : line width of the overlaid hexagons.
    '''
    F    = xsf[0]
    lvec = xsf[1]
    dims = xsf[2]
    kxlim = (lvec[0,0], lvec[0,0]+lvec[1,0])
    # The mirrored half extends the ky axis downward by one span.
    # NOTE(review): this matches a k -> -k reflection only if the kx range is
    # symmetric about 0 and lvec[0,2] == 0 -- confirm for the data used.
    kylim = (lvec[0,2]-lvec[3,2], lvec[0,2]+lvec[3,2])
    # Nearest energy slice. int() for Python 2's float round(); clamp so an
    # out-of-range energy cannot wrap around through a negative index.
    icut = int(round(dims[1]*(cut+Efermi-lvec[0,1])/lvec[2,1]))
    icut = max(0, min(dims[1]-1, icut))
    # Symmetrise only the requested slice. This is the same image as building
    # the doubled 3D volume and slicing it, without allocating a 2x copy of F.
    FF = pylab.concatenate((F[::-1,icut,::-1], F[:,icut,:]), axis=0)
    pylab.imshow(FF, cmap=cmap, extent=[kxlim[0], kxlim[1], kylim[0], kylim[1]],
                 interpolation=interpolation, origin='lower', vmax=vmax, aspect='equal')
    # 'is not None': limits may be numpy arrays, for which '!= None' is ambiguous.
    if xlim is not None:
        pylab.xlim(xlim[0], xlim[1])
    if ylim is not None:
        pylab.ylim(ylim[0], ylim[1])
    if hexs is not None:
        # Data-window diagonal: radius handed to plotHexs so the lattice covers the map.
        kR = pylab.sqrt((kxlim[1]-kxlim[0])**2 + (kylim[1]-kylim[0])**2)
        for hx in hexs:
            plotHexs(hx[0], hx[1]*pylab.pi/180, hx[2], kR, lw=lw)
    pylab.title(" E = "+str(cut))
    pylab.tight_layout()
    pylab.ion()
    pylab.show()

    

def plotCutCompare(xsfs, cut=0.0, Efermi = 0.0,xlim=None, ylim=None, vmax=1.0, interpolation='bicubic', cmap='Greys'):
    '''Plot the same fixed-ky (k, E) cut of several ARPES volumes side by side.

    xsfs       : sequence of (F, lvec, dims) triples. Each F is indexed
                 F[iky][iE][ik]; lvec[0] is the origin and lvec[1], lvec[2],
                 lvec[3] span k, E and ky (diagonal components only);
                 dims[0] is the number of ky slices.
    cut        : ky value of the slice (nearest slice, clamped to range).
    Efermi     : subtracted from every panel's energy axis.
    xlim, ylim : optional (min, max) limits applied to every panel.
    vmax, interpolation, cmap : forwarded to pylab.imshow.
    '''
    nplots = len(xsfs)
    for iplot, xsf in enumerate(xsfs):
        F    = xsf[0]
        lvec = xsf[1]
        dims = xsf[2]
        klim = (lvec[0,0], lvec[0,0]+lvec[1,0])
        elim = (lvec[0,1], lvec[0,1]+lvec[2,1])
        # Nearest ky slice. int() for Python 2's float round(); clamp so
        # out-of-range cuts cannot wrap around through a negative index.
        icut = int(round(dims[0]*(cut-lvec[0,2])/lvec[3,2]))
        icut = max(0, min(dims[0]-1, icut))
        # Subplot positions are 1-based. Index 0 (the old first panel) is
        # invalid in matplotlib.
        pylab.subplot(1, nplots, iplot+1)
        pylab.imshow(F[icut], cmap=cmap, extent=[klim[0], klim[1], elim[0]-Efermi, elim[1]-Efermi],
                     interpolation=interpolation, origin='lower', vmax=vmax, aspect='auto')
        # 'is not None': limits may be numpy arrays, for which '!= None' is ambiguous.
        if xlim is not None:
            pylab.xlim(xlim[0], xlim[1])
        if ylim is not None:
            pylab.ylim(ylim[0], ylim[1])
    pylab.show()
    