#!/apps/python/3.4.2//bin/python3


##!/usr/bin/env python3
# Decides whether the range-separation parameter has converged when tuning a range-separated hybrid (RSH) functional and, if not, sets up the next batch of input files.
# Command-line arguments: base_filename max_cycles min_step_grad max_step_grad; all other file names are derived from base_filename.
# Note that "mu" is the name used in this code for the parameter being optimised (the range-separation parameter).
#from init_OTRSH import write_RSH
#from init_OTRSH import mod_es_input
import re
import sys
import random
# GLOBAL CONSTANTS
# Width of the mu bracket below which the bracket search is treated as converged
# and the gradient search starts (higher values start the gradient search sooner).
grad_thres_mu = 0.05

# Extensions applied when renaming the base file
# (e.g. base_fname.replace('.com', output_ext) gives the .out file).
output_ext = '_OTRSH.out'
is_ext_low = '_ionstate_mulow.com'      # ionised state (low mu)
is_ext_high = '_ionstate_muhigh.com'    # ionised state (high mu)
gs_ext_low = '_gsstate_mulow.com'       # neutral state (low mu)
gs_ext_high = '_gsstate_muhigh.com'     # neutral state (high mu)
init_step_grad = 0.01   # change in mu for the first step of the gradient search
# Acceptable deviation of the ionisation potential from -E_HOMO at convergence.
# Written in eV (0.01) and divided by 27.212 to convert to atomic units.
IP_conver_thres = 0.01/27.212

# COMMAND-LINE ARGUMENTS (sys.argv[0] is the name of this script)
if len(sys.argv) < 5:
    sys.exit('usage: ' + sys.argv[0] + ' base_fname max_cycles min_step_grad max_step_grad')
base_fname = sys.argv[1]              # full path to the base Gaussian .com file
max_cycles = int(sys.argv[2])         # maximum number of gradient-search steps
min_step_grad = float(sys.argv[3])    # minimum mu step the gradient search can take
max_step_grad = float(sys.argv[4])    # maximum mu step the gradient search can take


def write_RSH(filename,xc_funct,P5_val,mu_range,alpha_val,beta_val,new_filename):
    """Write a copy of a Gaussian input with the tuned RSH keywords inserted.

    Every '#' in ``filename`` is replaced by '# ' followed by the functional
    name and the IOp overrides built from the parameters.  Any line that
    (stripped) starts with '%chk' and ends with '.chk' is redirected to a
    checkpoint named after ``new_filename``.  The resulting text is printed
    to stdout and written to ``new_filename``.

    Each numeric parameter is encoded as round(value * 10000), zero-padded to
    five digits:
      P5_val    -> IOp(3/78)
      alpha_val -> IOp(3/130), IOp(3/131)
      beta_val  -> IOp(3/119), IOp(3/120)
      mu_range  -> IOp(3/107), IOp(3/108)
    IOp(3/76) and IOp(3/77) are always set to 1000010000.
    """
    def _encode(value):
        # value scaled by 10^4, rounded, and zero padded to five characters
        return format(round(value * 10000), '05.0f')

    alpha_code = _encode(alpha_val)
    beta_code = _encode(beta_val)
    mu_code = _encode(mu_range)
    keywords = [
        xc_funct,
        'IOp(3/76=1000010000)',
        'IOp(3/77=1000010000)',
        'IOp(3/78=' + _encode(P5_val) + '10000)',
        'IOp(3/130=' + alpha_code + ') IOp(3/131=' + alpha_code + ')',
        'IOp(3/119=' + beta_code + '00000) IOp(3/120=' + beta_code + '00000)',
        'IOp(3/107=' + mu_code + '00000) IOp(3/108=' + mu_code + '00000)',
    ]
    route_prefix = '# ' + ' '.join(keywords)
    chk_name = new_filename.split('/')[-1].replace('.com', '.chk')

    rewritten = []
    with open(filename, 'r') as src:
        for raw in src:
            text = raw.replace('#', route_prefix)
            bare = text.strip()
            # Point the checkpoint file at one named after the new input.
            if bare.startswith('%chk') and bare.endswith('.chk'):
                text = '%chk=' + chk_name + '\n'
            rewritten.append(text)

    contents = ''.join(rewritten)
    print(contents)
    with open(new_filename, 'w') as dst:
        dst.write(contents)


def mod_es_input(filename,es_mult):
    """Rewrite a Gaussian input in place for the ionised (one electron removed) state.

    The charge/multiplicity line is taken to be the line after the second
    blank line, found while exactly one line containing '#' (the route
    section) has been seen.  It is replaced by "<charge+1>,<es_mult>".
    The charge and multiplicity may be separated by spaces, tabs or commas.

    filename -- full path to the input file to modify (rewritten in place).
    es_mult  -- multiplicity of the ionised state.

    Raises ValueError if the charge/multiplicity line cannot be located, so
    an ionised calculation is never silently set up with the neutral charge.
    """
    with open(filename, 'r') as f:
        inp_file = f.readlines()
    route_line = 0
    blanks = 0
    for idx, text in enumerate(inp_file):
        if '#' in text:
            route_line += 1
        if text.strip() == '':        # looking for the 2nd blank line after '#'
            blanks += 1
        if route_line == 1 and blanks == 2:
            if idx + 1 >= len(inp_file):
                raise ValueError('charge/multiplicity line missing after second blank line in ' + filename)
            fields = re.split(r'[\s,]+', inp_file[idx + 1].strip())
            inp_file[idx + 1] = str(int(fields[0]) + 1) + ',' + str(es_mult) + '\n'
            break
    else:
        raise ValueError('could not locate the charge/multiplicity line in ' + filename)
    with open(filename, 'w') as f:
        f.write(''.join(inp_file))












def extract_info_log_files(neut_file,ionised_file,ionised_multiplicity):
    """Pull the energies for a single mu value out of a pair of Gaussian .log files.

    INPUT:
    neut_file            -- full path to the .log file for the neutral SPE
    ionised_file         -- full path to the .log file for the SPE on the ionised species
    ionised_multiplicity -- multiplicity of the ionised species

    OUTPUT (all values in atomic units, as printed in the .log files):
    (E_neutral, E_ion, HOMO_E, IP, diff_IP_E_HOMO) where
    E_neutral      = energy on the last 'SCF Done:' line of neut_file
    E_ion          = energy on the last 'SCF Done:' line of ionised_file
    HOMO_E         = highest occupied eigenvalue of the spin channel the
                     electron is removed from (-ve value for bound states)
    IP             = E_ion - E_neutral (+ve value)
    diff_IP_E_HOMO = IP - (-HOMO_E)

    Raises ValueError when a required line is missing from a log file, or when
    the neutral species is open shell and ionised_multiplicity equals its
    multiplicity (then the ionised spin channel cannot be determined).
    """
    # Electron counts of the neutral system: fields 0 and 3 of the first line
    # containing 'alpha electrons'.
    numb_alpha = numb_beta = None
    with open(neut_file, 'r') as f:
        for line in f:
            if 'alpha electrons' in line:
                fields = line.split()
                numb_alpha = int(fields[0])
                numb_beta = int(fields[3])
                break
    if numb_alpha is None:
        raise ValueError("no 'alpha electrons' line found in " + neut_file)
    gs_mult = abs(numb_alpha - numb_beta) + 1

    # Work out which occupied-eigenvalue lines to search in the neutral output.
    if gs_mult == 1:
        # Matches both alpha and beta occupied lines (only alpha lines are
        # expected); the maximum over all of them is used.
        eigen_str = 'occ. eigenvalues --'
    elif ionised_multiplicity > gs_mult:
        # Multiplicity rises: the electron is removed from the minority spin.
        eigen_str = 'Beta  occ. eigenvalues' if numb_alpha > numb_beta else 'Alpha  occ. eigenvalues'
    elif ionised_multiplicity < gs_mult:
        # Multiplicity falls: the electron is removed from the majority spin.
        eigen_str = 'Alpha  occ. eigenvalues' if numb_alpha > numb_beta else 'Beta  occ. eigenvalues'
    else:
        raise ValueError('ionised_multiplicity (' + str(ionised_multiplicity) +
                         ') must differ from the neutral multiplicity (' + str(gs_mult) +
                         ') for an open-shell neutral species')

    energy_re = re.compile(r'-\d+\.\d+')
    # One leading character (normally '-' or a space) plus a fixed-point number;
    # this also separates values printed run together, e.g. '-275.12345-100.12345'.
    eigen_re = re.compile(r'.\d+\.\d+')

    # Neutral file: last SCF energy and every matching occupied eigenvalue.
    E_neutral = None
    all_eigens = []
    with open(neut_file, 'r') as f:
        for line in f:
            if 'SCF Done:' in line:
                E_neutral = float(energy_re.findall(line)[0])
            if eigen_str in line:
                all_eigens.extend(float(eigen) for eigen in eigen_re.findall(line))

    # Ionised file: last SCF energy only.
    E_ion = None
    with open(ionised_file, 'r') as f:
        for line in f:
            if 'SCF Done:' in line:
                E_ion = float(energy_re.findall(line)[0])

    if E_neutral is None:
        raise ValueError("no 'SCF Done:' line found in " + neut_file)
    if E_ion is None:
        raise ValueError("no 'SCF Done:' line found in " + ionised_file)
    if not all_eigens:
        raise ValueError("no '" + eigen_str + "' lines found in " + neut_file)

    print(all_eigens)
    HOMO_E = max(all_eigens)
    IP = E_ion - E_neutral
    diff_IP_E_HOMO = IP - (HOMO_E * -1)
    return (E_neutral, E_ion, HOMO_E, IP, diff_IP_E_HOMO)


def bracket_search(out_file,conver_threshold):
    """One step of the bracket search on the optimisation table.

    out_file         -- list of .out file lines (one entry per line); the last
                        two rows must be complete, with mu in the first
                        comma-separated field and IP - (-E_HOMO) in the last.
    conver_threshold -- if the last two mu values differ by less than this,
                        the bracket search is converged.

    Returns (out_file, conver).  When converged, out_file is returned
    unchanged and conver == 1 (the gradient search can start).  Otherwise
    conver == 0 and two entries are appended in place: a copy of whichever of
    the last two rows has the smaller |IP - (-E_HOMO)| (ties go to the last
    row), then the midpoint of the two mu values as a bare string.
    """
    def _field(row, position):
        return float(row.strip().split(',')[position])

    newest, older = out_file[-1], out_file[-2]
    newest_mu = _field(newest, 0)
    older_mu = _field(older, 0)

    if abs(newest_mu - older_mu) < conver_threshold:
        return out_file, 1

    midpoint = (newest_mu + older_mu) / 2
    # Keep the better row at the bottom so the next iteration compares it
    # against the midpoint.
    newest_is_worse = abs(_field(newest, -1)) > abs(_field(older, -1))
    out_file.append(older if newest_is_worse else newest)
    out_file.append(str(midpoint))
    return out_file, 0

def init_grad_search(out_file,init_step_grad):
    """Start the gradient search after the bracket search has converged.

    Appends to out_file (in place) and returns it:
      'START GRAD SEARCH\\n'
      a copy of the better of the last two rows (smaller |IP - (-E_HOMO)|,
      ties go to the last row -- the same rule bracket_search uses)
      the bare mu value  best_mu +/- init_step_grad, stepping from the best
      row towards the other row (upwards when both mu values are equal).

    out_file       -- list of .out file lines; the last two rows must be
                      complete (mu first field, IP - (-E_HOMO) last field).
    init_step_grad -- size of the first mu step of the gradient search.
    """
    last_row, prev_row = out_file[-1], out_file[-2]
    last_mu = float(last_row.strip().split(',')[0])
    prev_mu = float(prev_row.strip().split(',')[0])
    last_dev = abs(float(last_row.strip().split(',')[-1]))
    prev_dev = abs(float(prev_row.strip().split(',')[-1]))

    if last_dev > prev_dev:
        best_row, best_mu, other_mu = prev_row, prev_mu, last_mu
    else:
        best_row, best_mu, other_mu = last_row, last_mu, prev_mu

    # The minimum is expected between the best mu and the other one.
    if other_mu < best_mu:
        new_mu = best_mu - init_step_grad
    else:
        new_mu = best_mu + init_step_grad

    out_file.append('START GRAD SEARCH\n')
    out_file.append(str(best_row))
    out_file.append(str(new_mu))
    return out_file


def grad_search(out_file,min_step,max_step,IP_conver_thres):
    """One finite-difference gradient step on |IP - (-E_HOMO)| with respect to mu.

    out_file        -- list of .out file lines; the last two rows must be
                       complete (mu first field, IP - (-E_HOMO) last field,
                       in atomic units).
    min_step        -- smallest mu step the optimiser may take.
    max_step        -- largest mu step the optimiser may take.
    IP_conver_thres -- convergence is reached when |IP - (-E_HOMO)| of the last
                       row is at or below this value (atomic units).

    If converged, appends the final deviation (converted to eV with 27.212),
    the final mu * 10000 and 'GRADIENT OPTIMISATION COMPLETE'.  Otherwise
    appends the next mu value as a bare string.  out_file is modified in place
    and returned in both cases.
    """
    last_fields = out_file[-1].strip().split(',')
    prev_fields = out_file[-2].strip().split(',')
    step_n_val = abs(float(last_fields[-1]))      # abs(IP-E_{HOMO}) for last step
    step_n_mu = float(last_fields[0])             # mu for last step
    step_nm1_val = abs(float(prev_fields[-1]))    # abs(IP-E_{HOMO}) for the step before
    step_nm1_mu = float(prev_fields[0])

    mu_change = step_n_mu - step_nm1_mu
    # Equal mu values make the finite difference undefined; treat as flat.
    grad = (step_n_val - step_nm1_val) / mu_change if mu_change != 0 else 0.0
    print('Gradient is equal to ' + str(grad))

    if step_n_val <= IP_conver_thres:
        out_file.append('Final abs((E_IP - E_HOMO)) / eV = ' + str(step_n_val*27.212))
        out_file.append('Final mu(range sep) parameter * 10000 = ' + format(round(step_n_mu*10000), '05.0f'))
        out_file.append('GRADIENT OPTIMISATION COMPLETE')
        return out_file

    # Random factor in [0.1, 1), weighted slightly towards >0.9, to break
    # out of oscillations.
    r_fact = random.random()
    if r_fact < 0.1:
        r_fact += 0.9
    if grad == 0:
        # Flat (or undefined) gradient: the Newton-like step is unbounded.
        new_step_size = max_step
    else:
        new_step_size = abs((1/grad) * step_n_val * r_fact)
    new_step_size = min(max(new_step_size, min_step), max_step)
    # Move in the direction that decreases |IP - (-E_HOMO)|.
    new_mu_step = new_step_size if grad < 0 else -new_step_size
    out_file.append(str(abs(new_mu_step + step_n_mu)))
    return out_file

#------------------------>EXECUTED CODE STARTS HERE<-----------------

#****SECTION 1****
# Read the run parameters from line 3 of the .out file and fill in the
# optimisation-table rows whose Gaussian jobs have finished.  An unfilled row
# reads "mu,<low|high>"; filling appends
# ",E_neutral,E_ion,HOMO_E,IP,diff_IP_E_HOMO" (atomic units).

# Row tag -> (ionised-state, neutral-state) file extensions.
_row_tag_exts = {'low': (is_ext_low, gs_ext_low), 'high': (is_ext_high, gs_ext_high)}


def _log_results_for_tag(tag):
    """Return extract_info_log_files(...) for the .log pair of tag, or None if tag is not 'low'/'high'."""
    if tag not in _row_tag_exts:
        return None
    is_ext, gs_ext = _row_tag_exts[tag]
    is_file = base_fname.replace('.com', is_ext.replace('.com', '.log'))
    gs_file = base_fname.replace('.com', gs_ext.replace('.com', '.log'))
    return extract_info_log_files(gs_file, is_file, ionised_multiplicity)


def _append_results(row, results):
    """Return row with the extracted quantities appended as comma-separated fields."""
    return row.strip() + ',' + ','.join(str(x) for x in results) + '\n'


with open(base_fname.replace('.com', output_ext), 'r') as f:
    out_file = list(f)
xc_funct, P5_val, alpha_val, beta_val, ionised_multiplicity = out_file[2].strip().split(',')
P5_val, alpha_val, beta_val, ionised_multiplicity = float(P5_val), float(alpha_val), float(beta_val), int(ionised_multiplicity)

# Second-to-last row: only import .log data if the row is still "mu,tag".
if len(out_file[-2].split(',')) == 2:
    results = _log_results_for_tag(out_file[-2].strip().split(',')[-1])
    if results is not None:
        out_file[-2] = _append_results(out_file[-2], results)

# Last row: the mu value whose jobs were run last.  Only filled when tagged,
# so an already-complete row never receives stale or undefined values.
results = _log_results_for_tag(out_file[-1].strip().split(',')[-1])
if results is not None:
    print('MISSING DATA ON last line')
    out_file[-1] = _append_results(out_file[-1], results)



#*****SECTION 2*****
# Check convergence and append the next mu value to test to out_file.
# Once a 'START GRAD SEARCH' marker is present the bracket search is finished
# and a gradient-search step runs (at most max_cycles steps, counted in the
# .calc file); otherwise another bracket-search step is taken, which starts the
# gradient search once the bracket has converged.
if any('START GRAD SEARCH' in row for row in out_file):
    # Update the gradient-step counter (both in the .calc file and here).
    calc_fname = base_fname.replace('.com', '_OTRSH.calc')
    with open(calc_fname, 'r') as f:
        calc_file = list(f)
    curr_grad_step = None
    for idx, entry in enumerate(calc_file):
        if 'curr_grad_step' in entry:
            curr_grad_step = int(entry.strip().replace('curr_grad_step=', ''))
            calc_file[idx] = 'curr_grad_step=' + str(curr_grad_step + 1)
    if curr_grad_step is None:
        sys.exit('no curr_grad_step entry found in ' + calc_fname)
    # Write the updated .calc file once, after all lines are processed.
    with open(calc_fname, 'w') as f:
        f.write('\n'.join(x.strip() for x in calc_file))

    if curr_grad_step < max_cycles:
        # grad_search extends out_file in place.
        grad_search(out_file, min_step_grad, max_step_grad, IP_conver_thres)
    else:
        out_file.append('GRADIENT OPTIMISATION REACHED MAX CYCLES')
else:
    print('Bracket search start')
    # Checks whether the bracket search converged (conver == 1) and, if not,
    # appends the next mu value.
    out_file, conver = bracket_search(out_file, grad_thres_mu)
    print(out_file)
    if conver == 1:
        # Appends the 'START GRAD SEARCH' marker and the first gradient step.
        out_file = init_grad_search(out_file, init_step_grad)


#****SECTION X*****
# Create the Gaussian inputs for the new mu value at the bottom of out_file.
# The last entry is a bare mu value (no commas) unless the optimisation has
# finished.  Both the neutral and the ionised inputs are written with the
# same extensions Section 1 uses to find the .log files, and the row is
# tagged ',low' so the next run knows which logs to read.
_finished_markers = ('GRADIENT OPTIMISATION COMPLETE', 'GRADIENT OPTIMISATION REACHED MAX CYCLES')
if len(out_file[-1].split(',')) == 1 and not any(m in out_file[-1] for m in _finished_markers):
    out_mu = float(out_file[-1])
    gs_com = base_fname.replace('.com', gs_ext_low)
    is_com = base_fname.replace('.com', is_ext_low)
    write_RSH(base_fname, xc_funct, P5_val, out_mu, alpha_val, beta_val, gs_com)  # ground state calc (low mu)
    write_RSH(base_fname, xc_funct, P5_val, out_mu, alpha_val, beta_val, is_com)  # ionised state calc (low mu)
    mod_es_input(is_com, ionised_multiplicity)
    out_file[-1] = out_file[-1] + ',low'





#****SECTION X****
# Save the updated optimisation table back to the .out file: one entry per
# line, surrounding whitespace stripped, no trailing newline.
with open(base_fname.replace('.com', output_ext), 'w') as out_handle:
    out_handle.write('\n'.join(row.strip() for row in out_file))

#****SECTION X****
# Generate  .calc file, telling bash which files to run next.

