#!/usr/bin/python

from pylab import *
from matplotlib.collections import LineCollection
from xsfutil import *
from basUtils import *
import Element 
import sys

from colorsys import hls_to_rgb


# ======== functions =============

def slice2scatter(Xs,Ys, lvec):
	"""
	Convert one 2D slice of a displacement field into scatter-plot coordinates.

	Every grid point (row i, column j) is placed at its undistorted position in
	the cell and then shifted by its displacement (Ys[i,j], Xs[i,j]).

	Xs, Ys : 2D arrays of equal shape (rows, cols) with the x / y displacements.
	         Rows are treated as the vertical (y) axis and columns as the
	         horizontal (x) axis, matching how imshow displays such an array.
	lvec   : lattice array; lvec[0,0] / lvec[0,1] are the x / y origin,
	         lvec[1,0] is the cell length along x, lvec[2,1] along y.

	Returns a float array of shape (2, rows, cols): Ps[0] holds y coordinates
	and Ps[1] holds x coordinates, so plot it as scatter(Ps[1], Ps[0]).
	"""
	nrow = len(Xs)
	ncol = len(Xs[0])
	Ps = mgrid[0:nrow,0:ncol].astype(float)
	# Ps[0] runs along rows (y): grid spacing, origin and displacement are all y quantities
	Ps[0] *= (lvec[2,1]/float(nrow))
	Ps[0] += Ys + lvec[0,1]
	# Ps[1] runs along columns (x): grid spacing, origin and displacement are all x quantities
	Ps[1] *= (lvec[1,0]/float(ncol))
	Ps[1] += Xs + lvec[0,0]
	return Ps


def plotBonds( xyz, bonds ):
	"""
	Draw each bond as a thin black line segment on the current axes.

	xyz   : indexable where xyz[1] and xyz[2] are the horizontal and vertical
	        coordinate sequences of the atoms.
	bonds : iterable of entries whose first two items are atom indices;
	        any further items are ignored.
	"""
	for bond in bonds:
		start, end = bond[0], bond[1]
		x0 = xyz[1][start]
		y0 = xyz[2][start]
		dx = xyz[1][end] - x0
		dy = xyz[2][end] - y0
		# an arrow with zero-sized head is just a straight segment
		arrow( x0, y0, dx, dy, head_width=0.0, head_length=0.0, fc='k', ec='k', lw=1.0, ls='solid' )

def colorize_XY2HL(Xs, Ys, s=1.0 ):
	"""
	Color a 2D vector field by hue (direction) and lightness (magnitude).

	Xs, Ys : arrays of equal shape holding the two vector components.
	s      : constant saturation passed to colorsys.hls_to_rgb.

	The raw magnitude is used as lightness without normalization, so
	magnitudes above 1 can produce channel values outside [0, 1].
	Returns (rgb, vmax): rgb has shape Xs.shape + (3,), vmax is the largest magnitude.
	"""
	field = Xs + 1j*Ys
	magnitude = abs(field)
	# direction angle in [-pi, pi] -> hue in [0.5, 1.5]; hls_to_rgb wraps hue modulo 1
	hue = 0.5 + (angle(field) + pi) / (2 * pi)
	channels = vectorize(hls_to_rgb)(hue, magnitude, s)   # tuple (R, G, B) of arrays
	rgb = array(channels).transpose(1, 2, 0)              # (3, n, m) -> (n, m, 3)
	return rgb, magnitude.max()

def colorize_XY2HS(Xs, Ys, l=0.5 ):
	"""
	Color a 2D vector field by hue (direction) and saturation (magnitude).

	Xs, Ys : arrays of equal shape holding the two vector components.
	l      : constant lightness passed to colorsys.hls_to_rgb.

	Saturation is the magnitude divided by the largest magnitude in the field.
	An identically zero field is rendered fully unsaturated (gray of
	lightness l) instead of producing NaN from 0/0.
	Returns (c, vmax): c has shape Xs.shape + (3,), vmax is the largest
	magnitude (consistent with colorize_XY2HL / colorize_XY2RG).
	"""
	z = Xs + 1j*Ys
	r = abs(z)
	vmax = r.max()
	# direction angle -> hue in [0.5, 1.5]; hls_to_rgb wraps hue modulo 1
	h = (angle(z) + pi) / (2 * pi) + 0.5
	# guard the all-zero field, where r/vmax would be 0/0
	s = r/vmax if vmax > 0 else r*0.0
	LL = s*0 + l   # constant-lightness array with the same shape as s
	c = vectorize(hls_to_rgb)(h, LL, s)   # tuple (R, G, B) of arrays
	c = array(c).transpose(1, 2, 0)        # (3, n, m) -> (n, m, 3)
	return c, vmax

def colorize_XY2RG(Xs, Ys, l=0.5, margin=10 ):
	"""
	Color a 2D vector field with red <- x component and green <- y component.

	Xs, Ys : 2D arrays of equal shape holding the two vector components.
	l      : unused; kept so the signature matches the other colorize_* helpers.
	margin : number of pixels dropped on each side when searching for the
	         scale vmax, so values near the border do not set the scale.
	         If the slice has no interior left (too small), the whole slice is used.

	Each component is mapped as 0.5*v/vmax + 0.5, so zero displacement is 0.5
	and +-vmax map to 1 / 0; values larger than vmax fall outside [0, 1].
	For an identically zero field the image is uniform (0.5, 0.5, 0).
	Returns (c, vmax): c has shape Xs.shape + (3,), vmax is the scale used.
	"""
	r = sqrt(Xs**2 + Ys**2)
	# explicit end indices so margin=0 means "no crop" rather than an empty slice
	interior = r[margin:r.shape[0]-margin, margin:r.shape[1]-margin]
	vmax = interior.max() if interior.size else r.max()
	# avoid 0/0 for an all-zero field
	scale = vmax if vmax > 0 else 1.0
	Red   = 0.5*Xs/scale + 0.5
	Green = 0.5*Ys/scale + 0.5
	c = array( (Red, Green, zeros(shape(Red)) ) )  # (3, n, m)
	c = c.transpose(1, 2, 0)                        # -> (n, m, 3)
	return c, vmax

	
# ======== code ==================

dz = 0.1

#ilist = [35,38,40,43,45,48,50,53,55,58]
#ilist = [40,45,50,55]
#ilist = [45]

#ilist = [0,3,5,7, 10,13,15,17, 20,23,25, 27, 30, 33, 35,37, 40,43,45,47, 50,53,55,57]
#ilist = [ 20,23,25, 27, 30, 33, 35,37, 40,43,45,47, 50,53,55,57]
#ilist = [ 20,25, 30, 35, 40,45,50,55]
#ilist = [ 20,25, 30, 35, 40,]
#ilist = [ 20,23,25, 27, 30, 33, 35,37, 40,43,45,47, 50,53,55,57]



ilist = arange( 20,58,2 )

#bd    =  0; shift = -10
bd    =  30; shift = -7


Xs,lvec, nDim, head = loadXSF('OutX.xsf')
Ys,lvec, nDim, head = loadXSF('OutY.xsf')
print " lvec ",lvec
print " nDim ",nDim
#dz = lvec[3,2] / nDim[0]
#print " dz ", dz

# if there is not x0 in xsf




lvec[0,0] += -10
lvec[0,1] += -10
extent=( shift, -shift ,   shift, -shift  )
#extent=( lvec[0,0], lvec[1,0]+lvec[0,0] ,   lvec[0,1],    lvec[2,1]+lvec[0,1]  )


n=len(ilist)

xyz = loadBas('../surf.bas')[0]
#xyz = loadBas('surf.bas')[0]
bonds = findBondsSimple(xyz, 1.7 ) 
#print bonds


# ========== plot radial distortion
figure( figsize=(3*n,0.5+3 ) );
for i in range(n):
	ii = ilist[i]
	subplot(1,n,i+1);
	title( "z=%2.2f$\AA$" %(dz*ii)  );
	HSBs,vmax = colorize_XY2RG(Xs[ii],Ys[ii])
	print i,ii,vmax
	imshow( HSBs[bd:-bd,bd:-bd],  extent=extent, vmin=0, vmax=vmax,  origin='image' ) 
	scatter( xyz[1], xyz[2], s=pylab.sqrt(xyz[0])*10,   c='#FFFFFF'  )
	xlim(extent[0],extent[1])
	ylim(extent[2],extent[3])
savefig( 'dcomplex_RG_all.png', bbox_inches='tight', pad_inches=0)


'''
# ========== plot radial distortion
figure( figsize=(3*n,0.5+3 ) );
for i in range(n):
	ii = ilist[i]
	print i,ii
	subplot(1,n,i+1);
	title( "z=%2.2f$\AA$" %(dz*ii)  );
	HSBs,vmax = colorize_XY2HL(Xs[ii],Ys[ii])
	imshow( HSBs,  extent=extent, vmin=0, vmax=vmax,  origin='image' ) 
	scatter( xyz[1], xyz[2], s=pylab.sqrt(xyz[0])*10,   c='#FFFFFF'  )
	xlim(extent[0],extent[1])
	ylim(extent[2],extent[3])
savefig( 'dcomplex_HL_all.png', bbox_inches='tight', pad_inches=0)

# ========== plot radial distortion
figure( figsize=(3*n,0.5+3 ) );
for i in range(n):
	ii = ilist[i]
	print i,ii
	subplot(1,n,i+1);
	title( "z=%2.2f$\AA$" %(dz*ii)  );
	HSBs,vmax = colorize_XY2HS(Xs[ii],Ys[ii])
	imshow( HSBs,  extent=extent, vmin=0, vmax=vmax,  origin='image' ) 
	scatter( xyz[1], xyz[2], s=pylab.sqrt(xyz[0])*10,   c='#FFFFFF'  )
	xlim(extent[0],extent[1])
	ylim(extent[2],extent[3])
savefig( 'dcomplex_HS_all.png', bbox_inches='tight', pad_inches=0)
'''

# =========== plot points position
figure( figsize=(3*n,0.5+3 ) );
for i in range(n):
	ii = ilist[i]
	print i,ii
	subplot(1,n,i+1);
	title( "z=%2.2f$\AA$" %(dz*ii)  );
	imshow( Xs[ii]*0,  extent=extent,  origin='image', cmap = 'binary' ) 
	Ps = slice2scatter(Xs[ii],Ys[ii], lvec)
	scatter( Ps[1].flat, Ps[0].flat, s=1,   c='#FF0000', edgecolors='none', alpha=0.25  )
	#scatter( Ps[1].flat, Ps[0].flat, s=1,   c='#FF0000', edgecolors='none', alpha=1.00   )
	scatter( xyz[1], xyz[2], s=pylab.sqrt(xyz[0])*10,   c='#FFFFFF'  )
	plotBonds( xyz, bonds )
	xlim(extent[0],extent[1])
	ylim(extent[2],extent[3])
savefig( 'points2_all.png', bbox_inches='tight', pad_inches=0)


'''
# ========== plot radial distortion
figure( figsize=(3*n,0.5+3 ) );
for i in range(n):
	ii = ilist[i]
	print i,ii
	subplot(1,n,i+1);
	title( "z=%2.2f$\AA$" %(dz*ii)  );
	imshow( -(Xs[ii]**2+Ys[ii]**2),  extent=extent,  origin='image', cmap = 'binary' ) 
	scatter( xyz[1], xyz[2], s=pylab.sqrt(xyz[0])*10,   c='#FFFFFF'  )
	#plotBonds( xyz, bonds )
	xlim(extent[0],extent[1])
	ylim(extent[2],extent[3])
savefig( 'dr2_all.png', bbox_inches='tight', pad_inches=0)

# ========== plot radial distortion
figure( figsize=(3*n,0.5+3 ) );
for i in range(n):
	ii = ilist[i]
	print i,ii
	subplot(1,n,i+1);
	title( "z=%2.2f$\AA$" %(dz*ii)  );
	imshow( -sqrt(Xs[ii]**2+Ys[ii]**2) ,  extent=extent,  origin='image', cmap = 'binary' ) 
	scatter( xyz[1], xyz[2], s=pylab.sqrt(xyz[0])*10,   c='#FFFFFF'  )
	#plotBonds( xyz, bonds )
	xlim(extent[0],extent[1])
	ylim(extent[2],extent[3])
savefig( 'dr_all.png', bbox_inches='tight', pad_inches=0)
'''

#show();










