#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
A collection of functions that save the data from cube files in convenient formats

author: Oxana Andriuc

This code is under CC BY-NC-SA license: https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
"""

import os
import sys
import time
from decimal import Decimal

import pandas as pd

import ReadCube as rc


def SaveAll(fpath, outfile=None, au=False, density=1, value_type='esp', units=None):
    """
    May be slow for large grids: every grid point is formatted and written.

    arguments: path for a cube file (str), output file name (with extension; str, opt), au (bool, opt), density (float), value type (str), units(None/str)

    calls: ExtractData, Axes, ValuesAsDictionary (from ReadCube)

    returns: -

    this function writes the x, y, z coordinates and the values from a cube file to a new tab-separated file (4 columns).
    If outfile is None, the input path with its extension replaced by '.csv' is used.
    au, density, value_type and units are passed on to ExtractData; au also selects the coordinate unit label
    written in the header ('bohr' if True, 'A' otherwise).
    Before writing, the estimated file size is printed and the user is asked (via input()) to proceed,
    to keep only every n-th point, or to abort.
    """
    if outfile is None:
        # replace the extension (e.g. '.cube' or '.cub') with '.csv'
        outfile = os.path.splitext(fpath)[0] + '.csv'

    exdata = rc.ExtractData(fpath, au=au, value_type=value_type, density=density, units=units)
    origin = exdata[1]
    n_x = exdata[3]
    x_vector = exdata[4]
    n_y = exdata[5]
    y_vector = exdata[6]
    n_z = exdata[7]
    z_vector = exdata[8]
    vals = exdata[12]
    val_units = exdata[13]
    axes_list = rc.Axes(origin=origin, n_x=n_x, x_vector=x_vector, n_y=n_y, y_vector=y_vector, n_z=n_z, z_vector=z_vector)
    gridpts = rc.ValuesAsDictionary(vals=vals, axes_list=axes_list, origin=origin)[0]

    # list of (key, value) pairs; the key holds the x, y, z coordinates of a grid point
    l = list(gridpts.items())
    print(n_x, n_y, n_z, len(l))

    # Ask how to proceed; option 2 lets the user thin the data and then shows the new size estimate.
    d = 1    # only every d-th point is written
    while True:
        n = (len(l) + d - 1) // d    # number of data rows that will actually be written
        fsize = 67 * (1 + n)         # rough estimate: ~67 bytes per line, plus the header line

        if fsize < 1000:
            fsize = round(fsize)
            sizeunit = 'B'
        elif fsize < 1000000:
            fsize = round(fsize / 1000)
            sizeunit = 'KB'
        elif fsize < 1000000000:
            fsize = round(fsize / 1000000)
            sizeunit = 'MB'
        else:
            # without this branch sizeunit would be undefined for files of 1 GB or more
            fsize = round(fsize / 1000000000)
            sizeunit = 'GB'

        proceed = input('\nYour file size will be '+str(fsize)+sizeunit+'. Options: 1. Proceed 2. Reduce the number of points then proceed 3. Abort. Pick an option [1/2/3]: ')

        if proceed == '2':
            d = 0
            while d < 1:    # only positive integers are meaningful (0 would divide by zero)
                try:
                    d = int(input('Please input an integer number n so that only every n value is saved to the file: '))
                except ValueError:
                    d = 0
                if d < 1:
                    print('Please enter a positive integer.')
        elif proceed in ('1', '3'):
            break
        else:
            print('Invalid option, please pick 1, 2 or 3.')

    if proceed == '3':
        return

    l_units = 'bohr' if au else 'A'
    col_names = [('X coord/'+l_units).rjust(15), ('Y coord/'+l_units).rjust(15), ('Z coord/'+l_units).rjust(15), ('Value/'+val_units).rjust(18)]

    print('\nProgress: ', end="\r")

    # Collect the rows in a plain list and build the DataFrame once at the end:
    # enlarging a DataFrame row by row with df.loc[i] copies it repeatedly and is very slow.
    rows = []
    for j, i in enumerate(range(0, len(l), d)):
        coords, value = l[i]
        rows.append([('%.6f' % coords[0]).rjust(15), ('%.6f' % coords[1]).rjust(15), ('%.6f' % coords[2]).rjust(15), ('%.5E' % Decimal(str(value))).rjust(18)])

        # print progress (n >= 1 whenever this loop body runs, so the division is safe)
        if n < 100 or j % (n // 100) == 0:
            pct = j * 100 // n
            print('Progress: '+pct*'|'+(104-len(str(pct))-pct)*' '+str(pct)+'%', end="\r")

    df = pd.DataFrame(rows, columns=col_names)
    df.to_csv(outfile, sep='\t', index=False)     # write dataframe to file

    print('Progress: '+100*'|'+' 100%\n')


def SaveVdW(fpath, outfile=None, factor=1, au=False,  density=1, value_type='esp', thickness=0.3, units=None):
    """
    arguments: path for a cube file (str), output file name (with extension; str, opt), factor (float, opt), au (bool, opt), density (float), value type (str), thickness (float, opt), units(None/str)

    calls: ExtractData, Axes, ValuesAsDictionary, GetVdWPoints (from ReadCube)

    returns: -

    this function writes the x, y, z coordinates and the values of all the points that lie on a surface defined by the van der Waals radii of the atoms to a new tab-separated file (4 columns).
    If outfile is None, the input path with its extension replaced by '.csv' is used.
    au, density, value_type and units are passed on to ExtractData; factor, au and thickness are passed on to GetVdWPoints.
    au also selects the coordinate unit label written in the header ('bohr' if True, 'A' otherwise).
    Before writing, the estimated file size is printed and the user is asked (via input()) to proceed,
    to keep only every n-th point, or to abort.
    """
    if outfile is None:
        # replace the extension (e.g. '.cube' or '.cub') with '.csv'
        outfile = os.path.splitext(fpath)[0] + '.csv'

    exdata = rc.ExtractData(fpath, au=au, value_type=value_type, density=density, units=units)
    origin = exdata[1]
    n_x = exdata[3]
    x_vector = exdata[4]
    n_y = exdata[5]
    y_vector = exdata[6]
    n_z = exdata[7]
    z_vector = exdata[8]
    atoms_list = exdata[9]
    vals = exdata[12]
    val_units = exdata[13]
    axes_list = rc.Axes(origin=origin, n_x=n_x, x_vector=x_vector, n_y=n_y, y_vector=y_vector, n_z=n_z, z_vector=z_vector)
    gridpts = rc.ValuesAsDictionary(vals=vals, axes_list=axes_list, origin=origin)[0]

    # each entry is indexed as (x, y, z, value) below
    l = rc.GetVdWPoints(axes_list=axes_list, gridpts=gridpts, atoms_list=atoms_list, x_vector=x_vector, y_vector=y_vector, z_vector=z_vector, origin=origin, factor=factor, au=au, thickness=thickness)[0]

    # Ask how to proceed; option 2 lets the user thin the data and then shows the new size estimate.
    d = 1    # only every d-th point is written
    while True:
        n = (len(l) + d - 1) // d    # number of data rows that will actually be written
        fsize = 67 * (1 + n)         # rough estimate: ~67 bytes per line, plus the header line

        if fsize < 1000:
            fsize = round(fsize)
            sizeunit = 'B'
        elif fsize < 1000000:
            fsize = round(fsize / 1000)
            sizeunit = 'KB'
        elif fsize < 1000000000:
            fsize = round(fsize / 1000000)
            sizeunit = 'MB'
        else:
            # without this branch sizeunit would be undefined for files of 1 GB or more
            fsize = round(fsize / 1000000000)
            sizeunit = 'GB'

        proceed = input('\nYour file size will be '+str(fsize)+sizeunit+'. Options: 1. Proceed 2. Reduce the number of points then proceed 3. Abort. Pick an option [1/2/3]: ')

        if proceed == '2':
            d = 0
            while d < 1:    # only positive integers are meaningful (0 would divide by zero)
                try:
                    d = int(input('Please input an integer number n so that only every n value is saved to the file: '))
                except ValueError:
                    d = 0
                if d < 1:
                    print('Please enter a positive integer.')
        elif proceed in ('1', '3'):
            break
        else:
            print('Invalid option, please pick 1, 2 or 3.')

    if proceed == '3':
        return

    l_units = 'bohr' if au else 'A'
    col_names = [('X coord/'+l_units).rjust(15), ('Y coord/'+l_units).rjust(15), ('Z coord/'+l_units).rjust(15), ('Value/'+val_units).rjust(18)]

    print('\nProgress: ', end="\r")

    # Collect the rows in a plain list and build the DataFrame once at the end:
    # enlarging a DataFrame row by row with df.loc[i] copies it repeatedly and is very slow.
    rows = []
    for j, i in enumerate(range(0, len(l), d)):
        pt = l[i]
        rows.append([('%.6f' % pt[0]).rjust(15), ('%.6f' % pt[1]).rjust(15), ('%.6f' % pt[2]).rjust(15), ('%.5E' % Decimal(str(pt[3]))).rjust(18)])

        # print progress (n >= 1 whenever this loop body runs, so the division is safe)
        if n < 100 or j % (n // 100) == 0:
            pct = j * 100 // n
            print('Progress: '+pct*'|'+(104-len(str(pct))-pct)*' '+str(pct)+'%', end="\r")

    df = pd.DataFrame(rows, columns=col_names)
    df.to_csv(outfile, sep='\t', index=False)     # write dataframe to file

    print('Progress: '+100*'|'+' 100%\n')


if __name__ == '__main__':
    # Map command line arguments to function arguments.
    # Usage: <script> FunctionName positional_arg ... keyword=value ...
    # e.g.   <script> SaveVdW molecule.cube factor=1.2 au=true

    # Re-join argv and drop spaces around '=' so that "key = value" is treated as "key=value".
    arg_str = ' '.join(sys.argv)
    arg_str = arg_str.replace(" =", "=").replace("= ", "=")
    args = arg_str.split()

    if len(args) < 2:
        sys.exit('usage: ' + args[0] + ' FunctionName [positional args] [key=value ...]')

    fct = args[1]
    target = globals().get(fct)
    if not callable(target):
        sys.exit('unknown function: ' + fct)

    comp_args = []    # compulsory arguments (passed on as strings)
    arg_dict = {}     # optional arguments
    for argument in args[2:]:
        if '=' not in argument:
            comp_args.append(argument)
            continue

        key, val = argument.split('=', 1)    # keep any further '=' as part of the value
        if val.lower() in ('true', 'false', 'none'):
            val = val.capitalize()           # 'TRUE' -> 'True', 'none' -> 'None', ...
        try:
            # eval turns a string into the right format (list, float, boolean etc.).
            # NOTE(review): eval executes arbitrary Python typed on the command line. The empty namespace
            # only prevents accidental lookups of this script's globals/builtins; it is NOT a sandbox.
            # Consider ast.literal_eval if these arguments can come from an untrusted source.
            arg_dict[key] = eval(val, {'__builtins__': {}}, {})
        except (NameError, SyntaxError):
            # plain strings such as file names or unit labels (e.g. 'out.csv', '2020.csv', 'esp')
            arg_dict[key] = val

    target(*comp_args, **arg_dict)
