# -*- coding: utf-8 -*-
"""
A collection of functions that manipulate data from cube files

author: Oxana Andriuc

This code is under CC BY-NC-SA license: https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
"""

import sys
import numpy as np
import ReadCube as rc


def Derivative(fpath=None,x_vector=None, y_vector=None, z_vector=None, val=None, order=1, au=False, density=1, value_type='esp', units=None):
    """
    arguments: path for a cube file (str), increment vector in the x direction (list), increment vector in the y direction (list),
               increment vector in the z direction (list), matrix of values (list), order (int), atomic units (bool), density (float), value type (str), units(None/str)
    calls: ExtractData (from ReadCube, only if fpath is specified)

    returns: [[d_x, d_y, d_z], [dx, dy, dz]] where d_x, d_y and d_z are arrays with the same shape as the value matrix
             holding its derivative of the requested order along array axes 0, 1 and 2, and dx, dy and dz are the
             lengths of the x, y and z increment vectors (used as the grid spacing along each axis)

    Each derivative is obtained by applying numpy.gradient along a single axis once, then order-1 more times
    (so order <= 1 gives the first derivative); mixed derivatives are not computed.
    When fpath is given, the increment vectors and the value matrix are read from the cube file and the
    x_vector, y_vector, z_vector and val arguments are ignored.
    NOTE(review): using the length of each increment vector as the spacing along its array axis presumes
    the grid axes are orthogonal - confirm for non-orthogonal cells.
    """
    if fpath:
        cube = rc.ExtractData(fpath, au=au, value_type=value_type, density=density, units=units)
        increments = (cube[4], cube[6], cube[8])
        val = cube[11]
    else:
        increments = (x_vector, y_vector, z_vector)

    # grid spacing along each axis = Euclidean length of the corresponding increment vector
    spacings = []
    for vector in increments:
        comp_a, comp_b, comp_c = vector
        spacings.append((comp_a**2 + comp_b**2 + comp_c**2)**0.5)

    derivatives = []
    for axis, step in enumerate(spacings):
        result = np.gradient(val, step, axis=axis)
        # repeat the first derivative along the same axis to reach the requested order
        for _ in range(order - 1):
            result = np.gradient(result, step, axis=axis)
        derivatives.append(result)

    return [derivatives, spacings]


def Laplacian(fpath=None, x_vector=None, y_vector=None, z_vector=None, val=None,  au=False, density=1, value_type='esp', units=None):
    """
    arguments: path for a cube file (str), increment vector in the x direction (list), increment vector in the y direction (list),
               increment vector in the z direction (list), matrix of values (list), atomic units (bool), density (float), value type (str), units(None/str)

    calls: Derivative

    returns: a nested list (n_x x n_y x n_z) holding the Laplacian at every data point of the grid

    The Laplacian at each point is the sum of the second order derivatives with respect to x, y and z as
    returned by Derivative(order=2). When fpath is given the data are read from the cube file; otherwise the
    increment vectors and val are used (au, density, value_type and units are then not forwarded).
    A banner and a 50-character progress bar are printed to stdout while the sum is built.
    """
    if fpath:
        second_derivs = Derivative(fpath=fpath, order=2, au=au, value_type=value_type, density=density, units=units)[0]
    else:
        second_derivs = Derivative(x_vector=x_vector, y_vector=y_vector, z_vector=z_vector, val=val, order=2)[0]
    d2x, d2y, d2z = second_derivs

    n_x = len(d2x)
    n_y = len(d2x[0])
    n_z = len(d2x[0][0])
    total = n_x * n_y * n_z

    rule = '-' * 74
    print('\n\n ' + rule + '\n|' + ' Manipulating data from the cube file '.center(74) + '|\n ' + rule)
    label = 'Computing Laplacian: '.ljust(21)

    lapl = []
    done = 0
    for ix in range(n_x):
        plane = []
        for iy in range(n_y):
            row = []
            for iz in range(n_z):
                done += 1
                # redraw the bar every total//50 points (on every point for grids smaller than 50 points)
                if total < 50 or done % (total // 50) == 0:
                    filled = done * 50 // total
                    percent = str(done * 100 // total)
                    print(label + '#' * filled + ' ' * (54 - filled - len(percent)) + percent + '%', end='\r')
                row.append(d2x[ix][iy][iz] + d2y[ix][iy][iz] + d2z[ix][iy][iz])
            plane.append(row)
        lapl.append(plane)

    print(label + '#' * 50 + ' 100%')
    return lapl


def VdWLaplacian(fpath=None,axes_list=None, gridpts=None, atoms_list=None, x_vector=None, y_vector=None, z_vector=None, val=None, factor=1,vdw_ind=None,vdw_pts=None, au=False, density=1, value_type='esp', units=None, thickness=0.3, origin=None):
    """
    arguments: path for a cube file (str), axes list (list), grid points (dict), atoms list (list), increment vector in the x direction (list), increment vector in the y direction (list),
               increment vector in the z direction (list), matrix of values (list), factor (float), vdw points with indices (dict), vdw points with coordinates (dict), atomic units (bool), density (float), value type (str), units(None/str),
               thickness (float), origin of the grid (list; only used when fpath is not given and the vdw points have to be computed)

    calls: ExtractData (if fpath is specified), Axes, ValuesAsDictionary, GetVdWPoints (from ReadCube; if they are not passed as arguments), Laplacian

    returns: two lists: one containing coordinates + Laplacian for each data point on the van der Waals surface and one containing indices + Laplacian for each data point on the van der Waals surface

    When fpath is given, the increment vectors, the value matrix and (if needed) the grid description are read
    from the cube file, overriding the corresponding arguments. Otherwise, if vdw_ind or vdw_pts is missing,
    axes_list, gridpts, atoms_list and origin must be supplied so GetVdWPoints can compute them.
    vdw_pts[i] must describe the same grid point as vdw_ind[i].
    """
    if fpath:
        exdata = rc.ExtractData(fpath, au=au, value_type=value_type, density=density, units=units)
        x_vector = exdata[4]
        y_vector = exdata[6]
        z_vector = exdata[8]
        val = exdata[11]

        if vdw_ind is None or vdw_pts is None:
            # rebuild the grid description GetVdWPoints needs from the cube file
            origin = exdata[1]
            atoms_list = exdata[9]
            axes_list = rc.Axes(origin=origin, n_x=exdata[3], x_vector=x_vector, n_y=exdata[5], y_vector=y_vector, n_z=exdata[7], z_vector=z_vector)
            gridpts = rc.ValuesAsDictionary(vals=exdata[12], axes_list=axes_list, origin=origin)[0]

    if vdw_ind is None or vdw_pts is None:
        # without a cube file, origin comes from the caller (it was previously an undefined name here)
        vdw = rc.GetVdWPoints(axes_list=axes_list, gridpts=gridpts, atoms_list=atoms_list, x_vector=x_vector, y_vector=y_vector, z_vector=z_vector, origin=origin, factor=factor, au=au, thickness=thickness)
        if vdw_ind is None:
            vdw_ind = vdw[1]
        if vdw_pts is None:
            vdw_pts = vdw[0]

    lapl = Laplacian(x_vector=x_vector, y_vector=y_vector, z_vector=z_vector, val=val)

    vdw_l_pts = []
    vdw_l_ind = []

    # positional indexing keeps the original contract: vdw_ind[i] and vdw_pts[i] refer to the same point
    for i in range(len(vdw_ind)):
        i_x, i_y, i_z = vdw_ind[i][0], vdw_ind[i][1], vdw_ind[i][2]
        value = lapl[i_x][i_y][i_z]
        point = vdw_pts[i]
        vdw_l_pts.append([point[0], point[1], point[2], value])
        vdw_l_ind.append([i_x, i_y, i_z, value])

    return vdw_l_pts, vdw_l_ind


def Hessian(fpath=None,x_vector=None, y_vector=None, z_vector=None, val=None, au=False, density=1, value_type='esp', units=None):
    """
    arguments: path for a cube file (str), increment vector in the x direction (list), increment vector in the y direction (list),
               increment vector in the z direction (list), matrix of values (list), atomic units (bool), density (float), value type (str), units(None/str)

    calls: Derivative

    returns: a nested list (n_x x n_y x n_z) whose element at every data point is the 3x3 Hessian as a list of
             lists; entry [r][c] is the derivative along axis c of the first derivative along axis r

    First derivatives come from Derivative(order=1); the second derivatives are taken with numpy.gradient using
    the same grid spacings. When fpath is given the data are read from the cube file.
    """
    if fpath:
        first, spacing = Derivative(fpath=fpath, order=1, au=au, value_type=value_type, density=density, units=units)
    else:
        first, spacing = Derivative(x_vector=x_vector, y_vector=y_vector, z_vector=z_vector, val=val, order=1)
    Dx, Dy, Dz = first

    n_x = len(Dx)
    n_y = len(Dx[0])
    n_z = len(Dx[0][0])

    # second[r][c] = d/d(axis c) of the first derivative along axis r
    second = [[np.gradient(first_der, spacing[axis], axis=axis) for axis in range(3)]
              for first_der in (Dx, Dy, Dz)]

    def hessian_at(ix, iy, iz):
        # 3x3 matrix of second derivatives at a single grid point
        return [[second[r][c][ix][iy][iz] for c in range(3)] for r in range(3)]

    return [[[hessian_at(ix, iy, iz) for iz in range(n_z)] for iy in range(n_y)] for ix in range(n_x)]

def VdWHessian(fpath=None,axes_list=None, gridpts=None, atoms_list=None, x_vector=None, y_vector=None, z_vector=None, val=None, vdw_ind=None, vdw_pts=None, factor=1, au=False, density=1, value_type='esp', units=None, thickness=0.3, origin=None):
    """
    arguments: path for a cube file (str), axes list (list), grid points (dict), atoms list (list), increment vector in the x direction (list), increment vector in the y direction (list),
               increment vector in the z direction (list), matrix of values (list), vdw points with indices (dict), vdw points with coordinates (dict), factor (float), atomic units (bool), density (float), value type (str), units(None/str),
               thickness (float), origin of the grid (list; only used when fpath is not given and the vdw points have to be computed)

    calls: ExtractData (if fpath is specified), Axes, ValuesAsDictionary, GetVdWPoints (from ReadCube; if they are not passed as arguments), Hessian

    returns: two lists: one containing coordinates + Hessian matrix for each data point on the van der Waals surface and one containing indices + Hessian matrix for each data point on the van der Waals surface

    When fpath is given, the increment vectors, the value matrix and (if needed) the grid description are read
    from the cube file, overriding the corresponding arguments. Otherwise, if vdw_ind or vdw_pts is missing,
    axes_list, gridpts, atoms_list and origin must be supplied so GetVdWPoints can compute them.
    vdw_pts[i] must describe the same grid point as vdw_ind[i].
    """
    if fpath:
        exdata = rc.ExtractData(fpath, au=au, value_type=value_type, density=density, units=units)
        x_vector = exdata[4]
        y_vector = exdata[6]
        z_vector = exdata[8]
        val = exdata[11]

        if vdw_ind is None or vdw_pts is None:
            # rebuild the grid description GetVdWPoints needs from the cube file
            origin = exdata[1]
            atoms_list = exdata[9]
            axes_list = rc.Axes(origin=origin, n_x=exdata[3], x_vector=x_vector, n_y=exdata[5], y_vector=y_vector, n_z=exdata[7], z_vector=z_vector)
            gridpts = rc.ValuesAsDictionary(vals=exdata[12], axes_list=axes_list, origin=origin)[0]

    if vdw_ind is None or vdw_pts is None:
        # without a cube file, origin comes from the caller (it was previously an undefined name here)
        vdw = rc.GetVdWPoints(axes_list=axes_list, gridpts=gridpts, atoms_list=atoms_list, x_vector=x_vector, y_vector=y_vector, z_vector=z_vector, origin=origin, factor=factor, au=au, thickness=thickness)
        if vdw_ind is None:
            vdw_ind = vdw[1]
        if vdw_pts is None:
            vdw_pts = vdw[0]

    hess = Hessian(x_vector=x_vector, y_vector=y_vector, z_vector=z_vector, val=val)

    vdw_h_pts = []
    vdw_h_ind = []

    # positional indexing keeps the original contract: vdw_ind[i] and vdw_pts[i] refer to the same point
    for i in range(len(vdw_ind)):
        i_x, i_y, i_z = vdw_ind[i][0], vdw_ind[i][1], vdw_ind[i][2]
        matrix = hess[i_x][i_y][i_z]
        point = vdw_pts[i]
        vdw_h_pts.append([point[0], point[1], point[2], matrix])
        vdw_h_ind.append([i_x, i_y, i_z, matrix])

    return vdw_h_pts, vdw_h_ind

def ParseCommandLine(argv):
    """
    arguments: command line tokens (list of str), e.g. sys.argv

    returns: the function name (str), the positional arguments (list of str) and the keyword arguments (dict)

    Tokens of the form key=value become keyword arguments; spaces around '=' are tolerated
    ("key = value", "key =value" and "key= value" are accepted). Everything after the first '=' is the value.
    Values are evaluated with eval so numbers, lists, booleans and None get their Python type;
    true/false/none are accepted in any letter case. Values that are not valid Python expressions
    (e.g. file names and paths such as ./file.cube) are kept as plain strings.

    raises: ValueError if no function name is given
    """
    # re-join and re-split so that "key = value" spread over several tokens collapses into one token
    arg_str = " ".join(argv).replace(" =", "=").replace("= ", "=")
    args = arg_str.split()
    if len(args) < 2:
        raise ValueError("no function name given")

    fct = args[1]
    comp_args = []    # compulsory (positional) arguments
    arg_dict = {}     # optional (keyword) arguments
    for argument in args[2:]:
        if "=" not in argument:
            comp_args.append(argument)
            continue
        key, _, raw = argument.partition("=")
        lowered = raw.lower()
        if lowered == "true":
            raw = "True"
        elif lowered == "false":
            raw = "False"
        elif lowered == "none":
            raw = "None"
        try:
            # NOTE(security): eval runs arbitrary expressions; acceptable only because the text comes from
            # the user invoking this script. Never pass untrusted strings through here.
            arg_dict[key] = eval(raw, globals(), {})
        except (NameError, SyntaxError):
            # not a Python expression (e.g. a bare file name or a path): keep the raw string
            arg_dict[key] = raw
    return fct, comp_args, arg_dict


if __name__ == '__main__':
    # Map command line arguments to function arguments, e.g.
    #   python <this script> Laplacian fpath=file.cube au=true
    try:
        fct, comp_args, arg_dict = ParseCommandLine(sys.argv)
    except ValueError:
        sys.exit('usage: ' + sys.argv[0] + ' FunctionName [positional arguments] [key=value ...]')
    target = globals().get(fct)
    if not callable(target):
        sys.exit('unknown function: ' + fct)
    target(*comp_args, **arg_dict)
