#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Created on Mon May 22 16:59:17 2017

@author: tam10
"""

from params import params
from lepspoint import lepspoint
from lepnorm import lepnorm

import numpy as np
from numpy.linalg.linalg import LinAlgError

import matplotlib.pyplot as plt
import matplotlib.collections as mcoll
from matplotlib.animation import FuncAnimation
from mpl_toolkits.mplot3d import Axes3D
import warnings

import tkinter as tk
import tkinter.messagebox as msgbox
from tkinter.filedialog import asksaveasfilename

class Interactive():
    
    def __init__(self): #Initialise Class
        """Build the LEPS GUI, draw the first plot and run the Tk event loop.

        In order, this:

        1. creates the Tk root window;
        2. sets default simulation constants and stores every entry default
           (converted with its type and optional processing function) as an
           attribute of the same name;
        3. lays out all input frames and the geometry-information labels;
        4. calls ``get_params()``, ``update_geometry_info()`` and
           ``update_plot()`` for the first run;
        5. enters ``self.root.mainloop()``.

        Because of step 5 the constructor does not return until the window
        is closed.
        """
        
        ###Initialise tkinter###
        self.root = tk.Tk()
        self.root.title("LEPS GUI")
        self.root.resizable(0,0) # fixed-size window (neither axis resizable)
        
        ###Initialise defaults###
        
        self.atom_map = {
           #Atom: [Index, VdW radius, colour]
           #Index      - index in dropdown list in selection
           #VdW radius - used for animation
           #Colour     - used for animation
            "H" : [1, 1.20, '#eeeeee'],
            "F" : [2, 1.47, '#ffdd00'],
            "Cl": [3, 1.75, '#32d600'],
            "D" : [4, 1.20, '#c0c0c0'],
            "I" : [5, 1.98, '#de00a0'],
            "O" : [6, 1.52, '#ff0000']
        }
        
        self.dt = 0.005 #Time step in dynamics trajectory
        self.H = 0.424  #Surface parameter
        self.lim = 10   #Calculation will stop once this distance is exceeded

        self.Vmat = None       #Array where potential is stored for each gridpoint
        self.old_params = None #Variable used to prevent surface being recalculated
        
        self.entries  = {} #Dictionary of entries to be read on refresh (user input)
        self.defaults = {  #Defaults for each entry
           #Key        : Default value  , type , processing function
           #Atom keys are stored as their atom_map index (the value passed to
           #params()); "steps" is clamped to at least 1.
            "a"        : ["H"           , str  , lambda x: self.atom_map[x][0]],
            "b"        : ["H"           , str  , lambda x: self.atom_map[x][0]],
            "c"        : ["H"           , str  , lambda x: self.atom_map[x][0]],
            "xrabi"    : ["2.3"         , float, None                         ],
            "xrbci"    : ["0.74"        , float, None                         ],
            "prabi"    : ["-2.5"        , float, None                         ],
            "prbci"    : ["-1.5"        , float, None                         ],
            "steps"    : ["500"         , int  , lambda x: max(1, x)          ],
            "cutoff"   : ["-20"         , float, None                         ],
            "spacing"  : ["5"           , int  , None                         ],
            "calc_type": ["Dynamics"    , str  , None                         ],
            "theta"    : ["180"         , float, None                         ],
            "plot_type": ["Contour Plot", str  , None                         ]
        }
        
        #Store variable as class attributes
        #(instance attributes, so every input has a value before the GUI
        #widgets are read for the first time)
        for key, l in self.defaults.items():
            val, vtype, procfunc = l
            val = vtype(val)
            if procfunc: #Check whether processing is needed
                val = procfunc(val)
            setattr(self, key, val)
        
        #This is needed to allow surface to be calculated on the first run
        self._firstrun = True
        
        ###GUI###
        
        #Default frame format
        sunken = dict(height=2, bd=1, relief="sunken")
        
        #Atoms Selection Frame
        selection_frame = self._add_frame(self.root, "Atoms", sunken, dict(row=0, column=0, rowspan=2, sticky="news",padx=5,pady=5))
        
        self._add_label(selection_frame, {"text": "Atom A:"}, {"row":0, "column":0})
        self._add_label(selection_frame, {"text": "Atom B:"}, {"row":1, "column":0})
        self._add_label(selection_frame, {"text": "Atom C:"}, {"row":2, "column":0})
        
        self._add_optionmenu(selection_frame, "a", ["H", "F", "Cl", "D", "I", "O"],grid_kwargs={"row": 0, "column": 1})
        self._add_optionmenu(selection_frame, "b", ["H", "F", "Cl", "D", "I", "O"],grid_kwargs={"row": 1, "column": 1})
        self._add_optionmenu(selection_frame, "c", ["H", "F", "Cl", "D", "I", "O"],grid_kwargs={"row": 2, "column": 1})
        
        #Initial Conditions Frame
        #Each entry is given update_geometry_info as its change callback
        values_frame = self._add_frame(self.root, "Initial Conditions", sunken, dict(row=2, column=0, rowspan=2, sticky="news",padx=5,pady=5))
        
        self._add_label(values_frame, {"text": "AB Distance:     "}, {"row":0, "column":0})
        self._add_label(values_frame, {"text": "BC Distance:     "}, {"row":1, "column":0})
        self._add_label(values_frame, {"text": "AB Momentum:   "  }, {"row":2, "column":0})
        self._add_label(values_frame, {"text": "BC Momentum:   "  }, {"row":3, "column":0})
        
        self._add_entry(values_frame, "xrabi", {}, {"row":0, "column":1}, {"width":10}, self.update_geometry_info)
        self._add_entry(values_frame, "xrbci", {}, {"row":1, "column":1}, {"width":10}, self.update_geometry_info)
        self._add_entry(values_frame, "prabi", {}, {"row":2, "column":1}, {"width":10}, self.update_geometry_info)
        self._add_entry(values_frame, "prbci", {}, {"row":3, "column":1}, {"width":10}, self.update_geometry_info)
        
        #Angle Frame
        angle_frame = self._add_frame(self.root, "Collision Angle", sunken, dict(row=4, column=0, sticky="news",padx=5,pady=5))
        
        self._add_scale(
            angle_frame, "theta",
            {"from_":0, "to":180, "orient":"horizontal"},
            {"row":0,"column":0,"sticky":"ew"},
            {"length":200}
        )
        
        #Update and Export
        #Handlers are bound to "<Button-1>" (mouse press) rather than passed
        #as the button's command
        update_frame = self._add_frame(self.root, None, sunken, dict(row=5, column=0, columnspan=3, sticky="news",padx=5,pady=5))
        self._add_button(
            update_frame, {"text": "Update Plot"},
            dict(row=0, column=0, sticky="news",padx=5,pady=5),
            {"<Button-1>": self.update_plot}
        )
        self._add_button(
            update_frame, {"text": "Get Last Geometry"},
            dict(row=0, column=1, sticky="news",padx=5,pady=5),
            {"<Button-1>": self.get_last_geo}
        )
        self._add_button(
            update_frame, {"text": "Export Data"},
            dict(row=0, column=2, sticky="news",padx=5,pady=5),
            {"<Button-1>": self.export}
        )
        
        #Calculation Type Frame
        calc_type_frame = self._add_frame(self.root, "Calculation Type", sunken, dict(row=0, column=1, sticky="news",padx=5,pady=5))
        
        self._add_optionmenu(
            calc_type_frame, "calc_type", [
                "Dynamics", 
                "MEP", 
                "Opt TS",
                "Opt Min"
            ], {}, {"row":1, "column":0}, {"width":20}
        )
        
        #Plot Type Frame
        #These strings must match the names dispatched in update_plot()
        type_frame = self._add_frame(self.root, "Plot Type", sunken, dict(row=1, column=1, sticky="news",padx=5,pady=5))
        
        self._add_optionmenu(
            type_frame, "plot_type", [
                "Contour Plot", 
                "Surface Plot",
                "Internuclear Distances vs Time", 
                "Internuclear Momenta vs Time", 
                "Energy vs Time",
                "p(AB) vs p(BC)",
                "v(AB) vs v(BC)",
                "Animation"
            ], {}, {"row":1, "column":0}, {"width":20}
        )
        
        #Steps Frame
        steps_frame = self._add_frame(self.root, "Steps", sunken, dict(row=2, column=1, sticky="news",padx=5,pady=5))
        self._add_entry(steps_frame, "steps", {}, {"row":0, "column":0}, {"width":6})
        
        #Cutoff Frame
        cutoff_frame = self._add_frame(self.root, "Cutoff (Kcal/mol)", sunken, dict(row=3, column=1, sticky="news",padx=5,pady=5))
        
        self._add_scale(
            cutoff_frame, "cutoff",
            {"from_":-100, "to":0, "orient":"horizontal"},
            {"row":0,"column":0,"sticky":"ew"},
            {"length":200}
        )
        
        #Contour Spacing Frame
        spacing_frame = self._add_frame(self.root, "Contour Spacing", sunken, dict(row=4, column=1, sticky="news",padx=5,pady=5))
        
        self._add_scale(
            spacing_frame, "spacing",
            {"from_":1, "to":10, "orient":"horizontal"},
            {"row":0,"column":0,"sticky":"ew"},
            {"length":200}
        )
        
        #Geometry Info Frame
        #The empty labels kept as self.i_* below are placeholders; presumably
        #update_geometry_info() fills them in (it is not visible here).
        geometry_frame = self._add_frame(self.root, "Initial Geometry Information", sunken, dict(row=0, column=2, rowspan=5, sticky="news",padx=5,pady=5))
        
        self._add_button(
            geometry_frame, {"text": "Refresh"},
            dict(row=0, column=0, sticky="news",padx=5,pady=5),
            {"<Button-1>": self.update_geometry_info}
        )
        
        energy_frame = self._add_frame(geometry_frame, "Energy", sunken, dict(row=1, column=0, sticky="news",padx=5,pady=5))
        self._add_label(energy_frame, {"text": "Kinetic:   "}, {"row":0, "column":0})
        self._add_label(energy_frame, {"text": "Potential: "}, {"row":0, "column":1})
        self._add_label(energy_frame, {"text": "Total:     "}, {"row":0, "column":2})
        
        self.i_ke   = self._add_label(energy_frame, {"text": ""}, {"row":1, "column":0})
        self.i_pe   = self._add_label(energy_frame, {"text": ""}, {"row":1, "column":1})
        self.i_etot = self._add_label(energy_frame, {"text": ""}, {"row":1, "column":2})
        
        forces_frame = self._add_frame(geometry_frame, "Forces", sunken, dict(row=2, column=0, sticky="news",padx=5,pady=5))
        self._add_label(forces_frame, {"text": "AB:        "}, {"row":0, "column":0})
        self._add_label(forces_frame, {"text": "BC:        "}, {"row":0, "column":1})
        self._add_label(forces_frame, {"text": "Total:     "}, {"row":0, "column":2})
        
        self.i_fab  = self._add_label(forces_frame, {"text": ""}, {"row":1, "column":0})
        self.i_fbc  = self._add_label(forces_frame, {"text": ""}, {"row":1, "column":1})
        self.i_ftot = self._add_label(forces_frame, {"text": ""}, {"row":1, "column":2})
        
        hessian_frame = self._add_frame(geometry_frame, "Hessian", sunken, dict(row=3, column=0, sticky="news",padx=5,pady=5))
        self._add_label(hessian_frame, {"text": "1:         "}, {"row":0, "column":1})
        self._add_label(hessian_frame, {"text": "2:         "}, {"row":0, "column":2})
        self._add_label(hessian_frame, {"text": "Eigenvalue:"}, {"row":1, "column":0})
        self._add_label(hessian_frame, {"text": "AB Vector: "}, {"row":2, "column":0})
        self._add_label(hessian_frame, {"text": "BC Vector: "}, {"row":3, "column":0})
        
        self.i_eval1 = self._add_label(hessian_frame, {"text": ""}, {"row":1, "column":1})
        self.i_eval2 = self._add_label(hessian_frame, {"text": ""}, {"row":1, "column":2})
        
        self.i_evec11 = self._add_label(hessian_frame, {"text": ""}, {"row":2, "column":1})
        self.i_evec12 = self._add_label(hessian_frame, {"text": ""}, {"row":2, "column":2})
        self.i_evec21 = self._add_label(hessian_frame, {"text": ""}, {"row":3, "column":1})
        self.i_evec22 = self._add_label(hessian_frame, {"text": ""}, {"row":3, "column":2})
        
        ###First Run###
        
        # Initialise params and info
        self.get_params()
        self.update_geometry_info()
        
        #Plot
        #NOTE(review): this silences every Python warning for the whole
        #process, not only matplotlib's - consider narrowing it.
        warnings.filterwarnings("ignore")
        self.fig = plt.figure('Plot', figsize=(5,5))
        self.update_plot()
        
        #Make sure all plots are closed on exit
        #(runs when the window manager's close button is pressed)
        def cl():            
            plt.close('all')
            self.root.destroy()
            
        self.root.protocol("WM_DELETE_WINDOW", cl)
        #Blocks here until the root window is destroyed
        self.root.mainloop()
        
    def _read_entries(self):
        """Read entries from the GUI, process them and set attributes.

        For every ``key -> [source, vtype, procfunc]`` in ``self.entries``:

        1. the raw value (``source.get()``) is cast with ``vtype``;
        2. if ``procfunc`` is given, it is applied to the cast value;
        3. the result is stored as ``self.<key>``.

        Input that cannot be converted, such as a half-typed number or an
        atom symbol missing from ``atom_map``, is skipped. The attribute
        then keeps its previous value. Other exceptions (including
        KeyboardInterrupt) are no longer swallowed.
        """
        for key, (source, vtype, procfunc) in self.entries.items():
            try:
                val = self._cast(source, vtype)
                if procfunc:
                    val = procfunc(val)
            except (ValueError, TypeError, KeyError):
                # Invalid or incomplete user input: keep the last good value
                continue
            setattr(self, key, val)
            
    def _cast(self, entry, type):
        """Return the current value of ``entry`` converted by ``type``.

        ``entry`` is anything with a ``get()`` method (a widget or a tkinter
        variable). ``type`` is a callable such as ``int``, ``float`` or
        ``str``. Conversion errors propagate to the caller.
        """
        raw_value = entry.get()
        converted = type(raw_value)
        return converted
            
    def _add_frame(self, parent, text=None, frame_kwargs=None, grid_kwargs=None):
        """Insert a frame (box) into ``parent``, grid it and return it.

        Parameters
        ----------
        parent : tkinter widget
            Container for the new frame.
        text : str or None
            When truthy, a ``tk.LabelFrame`` captioned with ``text`` is
            created. Otherwise a plain ``tk.Frame`` is created.
        frame_kwargs : dict or None
            Extra options for the frame constructor.
        grid_kwargs : dict or None
            Options for ``frame.grid``.
        """
        # None sentinels instead of shared mutable ``{}`` defaults
        frame_kwargs = {} if frame_kwargs is None else frame_kwargs
        grid_kwargs = {} if grid_kwargs is None else grid_kwargs

        if text:
            frame = tk.LabelFrame(parent, text=text, **frame_kwargs)
        else:
            frame = tk.Frame(parent, **frame_kwargs)
        frame.grid(**grid_kwargs)
        return frame
        
    def _add_label(self, frame, text_kwargs=None, grid_kwargs=None, config_kwargs=None):
        """Insert a ``tk.Label`` into ``frame``, grid it and return it.

        ``text_kwargs`` go to the constructor (e.g. ``{"text": "..."}``),
        ``grid_kwargs`` go to ``grid`` and ``config_kwargs`` go to ``config``.
        The label is returned so callers can update its text later.
        """
        # None sentinels instead of shared mutable ``{}`` defaults
        text_kwargs = {} if text_kwargs is None else text_kwargs
        grid_kwargs = {} if grid_kwargs is None else grid_kwargs
        config_kwargs = {} if config_kwargs is None else config_kwargs

        label = tk.Label(frame, **text_kwargs)
        label.grid(**grid_kwargs)
        label.config(**config_kwargs)
        return label
        
    def _add_scale(self, frame, key, scale_kwargs=None, grid_kwargs=None, config_kwargs=None):
        """Insert a slider preset to ``self.defaults[key]`` and register it.

        The ``tk.Scale`` widget is stored in ``self.entries[key]`` as
        ``[scale, vtype, procfunc]``, so ``_read_entries`` reads its position
        through ``Scale.get``. The widget is returned.
        """
        # None sentinels instead of shared mutable ``{}`` defaults
        scale_kwargs = {} if scale_kwargs is None else scale_kwargs
        grid_kwargs = {} if grid_kwargs is None else grid_kwargs
        config_kwargs = {} if config_kwargs is None else config_kwargs

        val, vtype, procfunc = self.defaults[key]

        scale = tk.Scale(frame, **scale_kwargs)
        # The default is stored as a string; Tk parses it directly, so no
        # intermediate StringVar is needed.
        scale.set(val)
        scale.grid(**grid_kwargs)
        scale.config(**config_kwargs)
        # NOTE(review): this configures the slider's own (empty) grid. It may
        # have been meant for ``frame`` so that the slider stretches; confirm.
        scale.grid_columnconfigure(0, weight=1)

        self.entries[key] = [scale, vtype, procfunc]
        return scale

    def _add_button(self, frame, button_kwargs=None, grid_kwargs=None, bind_kwargs=None, config_kwargs=None):
        """Insert a button, bind its event handlers and return it.

        Parameters
        ----------
        frame : tkinter widget
            Container for the button.
        button_kwargs : dict or None
            Constructor options (e.g. ``{"text": "Update Plot"}``).
        grid_kwargs : dict or None
            Options for ``grid``.
        bind_kwargs : dict or None
            Maps Tk event sequences (e.g. ``"<Button-1>"``) to handlers.
        config_kwargs : dict or None
            Extra ``config`` options. The background defaults to blue; a
            ``bg`` entry here overrides it instead of raising ``TypeError``
            for a duplicate keyword.
        """
        # None sentinels instead of shared mutable ``{}`` defaults
        button_kwargs = {} if button_kwargs is None else button_kwargs
        grid_kwargs = {} if grid_kwargs is None else grid_kwargs
        bind_kwargs = {} if bind_kwargs is None else bind_kwargs

        button = tk.Button(frame, **button_kwargs)
        button.grid(**grid_kwargs)
        for sequence, handler in bind_kwargs.items():
            button.bind(sequence, handler)

        options = {"bg": "blue"}
        if config_kwargs:
            options.update(config_kwargs)
        button.config(**options)
        return button
        
    def _add_entry(self, frame, key, entry_kwargs=None, grid_kwargs=None, config_kwargs=None, attach_func=None):
        """Add a text entry preset to ``self.defaults[key]`` and register it.

        The ``tk.Entry`` is stored in ``self.entries[key]`` as
        ``[entry, vtype, procfunc]``. If ``attach_func`` is given, it is
        called (with Tk's ``(name, index, mode)`` trace arguments) every time
        the entry's text is written. The entry widget is returned.
        """
        # None sentinels instead of shared mutable ``{}`` defaults
        entry_kwargs = {} if entry_kwargs is None else entry_kwargs
        grid_kwargs = {} if grid_kwargs is None else grid_kwargs
        config_kwargs = {} if config_kwargs is None else config_kwargs

        val, vtype, procfunc = self.defaults[key]
        variable = tk.StringVar()
        variable.set(val)
        if attach_func:
            # trace_add replaces the deprecated Variable.trace ("trace variable")
            variable.trace_add("write", attach_func)

        entry = tk.Entry(frame, textvariable=variable, **entry_kwargs)
        entry.grid(**grid_kwargs)
        entry.config(**config_kwargs)

        # Keep the StringVar alive for as long as the widget. A collected
        # tkinter Variable unsets its Tcl variable and deletes its trace
        # commands, which would silently disable attach_func.
        entry._textvariable_ref = variable

        self.entries[key] = [entry, vtype, procfunc]
        return entry
        
    def _add_optionmenu(self, frame, key, items, optionmenu_kwargs=None, grid_kwargs=None, config_kwargs=None):
        """Add a dropdown menu of ``items`` preset to ``self.defaults[key]``.

        The backing ``tk.StringVar`` (not the widget) is registered in
        ``self.entries[key]``, so ``_read_entries`` reads the current choice.
        Holding it there also keeps the variable from being garbage-collected.
        The ``tk.OptionMenu`` widget is returned.
        """
        # None sentinels instead of shared mutable ``{}`` defaults
        optionmenu_kwargs = {} if optionmenu_kwargs is None else optionmenu_kwargs
        grid_kwargs = {} if grid_kwargs is None else grid_kwargs
        config_kwargs = {} if config_kwargs is None else config_kwargs

        val, vtype, procfunc = self.defaults[key]
        variable = tk.StringVar()
        variable.set(val)

        optionmenu = tk.OptionMenu(frame, variable, *items, **optionmenu_kwargs)
        optionmenu.grid(**grid_kwargs)
        optionmenu.config(**config_kwargs)

        self.entries[key] = [variable, vtype, procfunc]
        return optionmenu
        
    def _add_radio(self, frame, key, radio_kwargs=None, grid_kwargs=None, config_kwargs=None, variable=None):
        """Add a radio button bound to ``variable`` and register it.

        If ``variable`` is None, a new ``tk.StringVar`` preset to
        ``self.defaults[key]`` is created. Pass a shared variable to group
        several buttons.

        The variable, not the ``Radiobutton``, is stored in
        ``self.entries[key]``. ``_read_entries`` calls ``.get()`` on it, and
        ``tk.Radiobutton`` has no ``get`` method. Storing it also stops a
        locally created variable from being garbage-collected. The
        ``Radiobutton`` widget is returned.
        """
        # None sentinels instead of shared mutable ``{}`` defaults
        radio_kwargs = {} if radio_kwargs is None else radio_kwargs
        grid_kwargs = {} if grid_kwargs is None else grid_kwargs
        config_kwargs = {} if config_kwargs is None else config_kwargs

        val, vtype, procfunc = self.defaults[key]
        if variable is None:
            variable = tk.StringVar()
            variable.set(val)

        radio = tk.Radiobutton(frame, variable=variable, **radio_kwargs)
        radio.grid(**grid_kwargs)
        radio.config(**config_kwargs)

        self.entries[key] = [variable, vtype, procfunc]
        return radio
            
    def get_params(self):
        """Fetch and store the parameters for the selected atom triple.

        ``params(self.a, self.b, self.c)`` must return 17 values:

        - three masses;
        - nine LEPS constants (``Dr*``, ``lr*`` and ``Br*`` for AB, BC and AC);
        - the plot window ``mina, maxa, minb, maxb``;
        - ``cont``.

        Each value is stored as an attribute of the same name. The reduced
        masses ``mab``, ``mbc`` and ``mac`` are then set.

        If ``params`` fails, or returns the wrong number of values, an error
        dialog is shown and the exception is re-raised.
        """
        try:
            (ma, mb, mc,
             Drab, Drbc, Drac,
             lrab, lrbc, lrac,
             Brab, Brbc, Brac,
             mina, maxa, minb, maxb, cont) = params(self.a, self.b, self.c)
        except Exception:
            msgbox.showerror("Error", "Parameters for this atom combination not available!")
            raise

        # Masses
        self.ma, self.mb, self.mc = ma, mb, mc

        # LEPS parameters, grouped by kind
        self.Drab, self.Drbc, self.Drac = Drab, Drbc, Drac
        self.lrab, self.lrbc, self.lrac = lrab, lrbc, lrac
        self.Brab, self.Brbc, self.Brac = Brab, Brbc, Brac

        # Plot window and plot parameter
        self.mina, self.maxa = mina, maxa
        self.minb, self.maxb = minb, maxb
        self.cont = cont

        # Reduced mass of each atom pair
        def reduced_mass(m1, m2):
            return (m1 * m2) / (m1 + m2)

        self.mab = reduced_mass(ma, mb)
        self.mbc = reduced_mass(mb, mc)
        self.mac = reduced_mass(ma, mc)
            
    def get_surface(self):
        """Compute ``self.Vmat``, the potential on a regular AB/BC grid.

        The atom parameters are refreshed first. The grid evaluation is then
        skipped when the atoms and collision angle are unchanged since the
        last completed calculation.

        Sets:

        - ``self.x``: AB distances from ``mina`` to ``maxa``;
        - ``self.y``: BC distances from ``minb`` to ``maxb``;
        - ``self.Vmat``: shape ``(len(self.y), len(self.x))``, where
          ``Vmat[j, i]`` is the ``lepspoint`` energy at ``(x[i], y[j])``.

        The grid step is 0.02.
        """
        self.get_params()

        # Check if params have changed. If not, no need to recalculate
        new_params = [self.a, self.b, self.c, self.theta]
        if self.old_params == new_params and not self._firstrun:
            return

        resl = 0.02  # Resolution
        grad = 0     # Gradient calc type (0 = energy)
        self._firstrun = False

        # Get grid
        self.x = np.arange(self.mina, self.maxa, resl)
        self.y = np.arange(self.minb, self.maxb, resl)
        self.Vmat = np.zeros((len(self.y), len(self.x)))

        # The collision angle is identical for every grid point; convert once
        theta = np.deg2rad(self.theta)

        # Calculate potential for each gridpoint
        for drabcount, drab in enumerate(self.x):
            for drbccount, drbc in enumerate(self.y):
                self.Vmat[drbccount, drabcount] = lepspoint(
                    drab,
                    drbc,
                    theta,
                    self.Drab,
                    self.Drbc,
                    self.Drac,
                    self.Brab,
                    self.Brbc,
                    self.Brac,
                    self.lrab,
                    self.lrbc,
                    self.lrac,
                    self.H,
                    grad
                )

        self.old_params = new_params
                        
    def get_trajectory(self):
        """Get dynamics, MEP or optimisation path from the initial conditions.

        The behaviour depends on ``self.calc_type``:

        * ``"Dynamics"``: each step is taken by ``lepnorm``, and velocities
          carry over between steps.
        * ``"MEP"``: each step is taken by ``lepnorm``, with velocities zeroed
          before every step (``lepnorm``'s last argument is True).
        * ``"Opt Min"`` / ``"Opt TS"``: each step is the displacement
          ``sum_i (v_i . F) v_i / lambda_i`` over the eigenpairs of the 3x3
          Hessian returned by ``lepspoint``. An eigenvalue test first checks
          that the geometry has no negative eigenvalues (Opt Min) or exactly
          one (Opt TS).

        At most ``self.steps`` steps are taken. The loop stops early, after
        showing a message box, if:

        - ``lepnorm`` raises ``LinAlgError``;
        - the AB or BC distance exceeds ``self.lim``;
        - the eigenvalue test fails.

        Results are stored as lists on ``self``:

        - positions ``ra``/``rb``/``rc``, relative to the centre of mass;
        - distances ``xrab``/``xrbc``/``xrac``;
        - velocities ``vrab``/``vrbc``/``vrac``;
        - time ``t``;
        - energies ``Vrint``, ``Ktot`` and ``etot``;
        - forces ``Frab``/``Frbc``/``Frac``;
        - accelerations ``arab``/``arbc``/``arac``;
        - Hessian components ``hr1r1`` ... ``hr3r3``.

        When the loop finishes, all of these per-step lists have the same
        length.
        """

        itlimit = self.steps #Max number of steps
        dt      = self.dt    #Time step
        ti      = 0          #Initial time variable
        tf      = 0          #Final time variable

        ma   = self.ma       #Mass of A
        mb   = self.mb       #Mass of B
        mc   = self.mc       #Mass of C

        mab  = self.mab      #Reduced mass of AB
        mbc  = self.mbc      #Reduced mass of BC

        Drab = self.Drab
        Drbc = self.Drbc
        Drac = self.Drac

        lrab = self.lrab
        lrbc = self.lrbc
        lrac = self.lrac

        Brab = self.Brab
        Brbc = self.Brbc
        Brac = self.Brac

        thetai = np.deg2rad(self.theta) #Collision Angle
        grad = 2 #Calculating gradients and Hessian

        xrabi = self.xrabi   #Initial AB separation
        xrbci = self.xrbci   #Initial BC separation
        #Initial AC separation (law of cosines)
        xraci = ((xrabi ** 2) + (xrbci ** 2) - 2 * xrabi * xrbci * np.cos(thetai)) ** 0.5
        prabi = self.prabi   #Initial AB momentum
        prbci = self.prbci   #Initial BC momentum

        vrabi = prabi / mab  #Initial AB Velocity
        vrbci = prbci / mbc  #Initial BC Velocity

        #Positions of A, B and C relative to B
        a = np.array([- xrabi, 0.])
        b = np.array([0., 0.])
        c = np.array([- np.cos(thetai) * xrbci, np.sin(thetai) * xrbci])

        #Get centre of mass
        com = (a * ma + b * mb + c * mc) / (ma + mb + mc)
        com = np.real(com)

        #Translate to centre of mass (for animation)
        a -= com
        b -= com
        c -= com

        self.ra    = [a]
        self.rb    = [b]
        self.rc    = [c]

        #Initial AC Velocity
        #NOTE(review): self.ra[-1] and self.rc[-1] are a and c themselves, so
        #va and vc are zero vectors and vraci is always 0. Confirm whether an
        #AC velocity derived from vrabi/vrbci was intended.
        va = (a - self.ra[-1]) / dt
        vc = (c - self.rc[-1]) / dt
        vraci = np.linalg.norm(va + vc)

        #Initialise outputs
        self.xrab  = [xrabi]
        self.xrbc  = [xrbci]
        self.xrac  = [xraci]
        self.vrab  = [vrabi]
        self.vrbc  = [vrbci]
        self.vrac  = [vraci]
        self.t     = [ti]

        self.Vrint = Vrint = []
        self.Ktot  = Ktot  = []

        self.Frab  = Frab  = []
        self.Frbc  = Frbc  = []
        self.Frac  = Frac  = []
        self.arab  = arab  = []
        self.arbc  = arbc  = []
        self.arac  = arac  = []
        self.etot  = etot  = []

        self.hr1r1 = hr1r1 = []
        self.hr1r2 = hr1r2 = []
        self.hr1r3 = hr1r3 = []
        self.hr2r2 = hr2r2 = []
        self.hr2r3 = hr2r3 = []
        self.hr3r3 = hr3r3 = []

        #Flag to stop appending to output in case of a crash
        terminate = False

        #Values recorded if lepnorm fails before any step has succeeded.
        #A later failure keeps the previous step's values.
        arabi = arbci = araci = Ktoti = np.nan

        for itcounter in range(itlimit):
            if self.calc_type != "Dynamics":
                #MEP and optimisations do not carry velocity between steps
                vrabi = 0
                vrbci = 0
                vraci = 0

            #Get current potential, forces, and Hessian
            Vrinti,Frabi,Frbci,Fraci,hr1r1i,hr1r2i,hr1r3i,hr2r2i,hr2r3i,hr3r3i = lepspoint(xrabi,xrbci,thetai,Drab,Drbc,Drac,Brab,Brbc,Brac,lrab,lrbc,lrac,self.H,grad)
            Vrint.append(Vrinti)
            Frab.append(Frabi)
            Frbc.append(Frbci)
            Frac.append(Fraci)
            hr1r1.append(hr1r1i)
            hr1r2.append(hr1r2i)
            hr1r3.append(hr1r3i)
            hr2r2.append(hr2r2i)
            hr2r3.append(hr2r3i)
            hr3r3.append(hr3r3i)

            if self.calc_type in ["Opt Min", "Opt TS"]: #Optimisation calculations

                #Diagonalise Hessian
                hessian = np.array([[hr1r1i, hr1r2i, hr1r3i], [hr1r2i, hr2r2i, hr2r3i], [hr1r3i, hr2r3i, hr3r3i]])
                eigenvalues, eigenvectors = np.linalg.eig(hessian)

                #Get forces for opt calculation
                forces = np.array([Frabi, Frbci, Fraci])

                #Eigenvalue test: a TS needs exactly one negative eigenvalue,
                #a minimum needs none
                neg_eig_i = [i for i,eig in enumerate(eigenvalues) if eig < -0.01]
                if len(neg_eig_i) == 0 and self.calc_type == "Opt TS":
                    msgbox.showinfo("Eigenvalues Info", "No negative eigenvalues at this geometry")
                    terminate = True
                elif len(neg_eig_i) == 1 and self.calc_type == "Opt Min":
                    msgbox.showerror("Eigenvalues Error", "Too many negative eigenvalues at this geometry")
                    terminate = True
                elif len(neg_eig_i) > 1:
                    msgbox.showerror("Eigenvalues Error", "Too many negative eigenvalues at this geometry")
                    terminate = True

                #Optimiser: sum over modes of (v . F) v / lambda.
                #np.linalg.eig returns the eigenvectors as COLUMNS, so pair
                #each eigenvalue with a row of the transpose.
                disps = np.array([0.,0.,0.])
                for e_val, e_vec in zip(eigenvalues, eigenvectors.T):
                    disps += np.dot(e_vec, forces) * e_vec / e_val

                xrabf = xrabi + disps[0]
                xrbcf = xrbci + disps[1]
                xracf = ((xrabf ** 2) + (xrbcf ** 2) - 2 * xrabf * xrbcf * np.cos(thetai)) ** 0.5

                arabi  = 0
                arbci  = 0
                araci  = 0
                thetaf = thetai
                vrabf  = 0
                vrbcf  = 0
                vracf  = 0
                Ktoti  = 0
                tf    += dt

            else: #Dynamics/MEP
                try:
                    xrabf,xrbcf,xracf,thetaf,vrabf,vrbcf,vracf,tf,arabi,arbci,araci,Ktoti = lepnorm(xrabi,xrbci,thetai,Frabi,Frbci,Fraci,vrabi,vrbci,vraci,hr1r1i,hr1r2i,hr1r3i,hr2r2i,hr2r3i,hr3r3i,ma,mb,mc,ti,dt,self.calc_type == "MEP")
                except LinAlgError:
                    msgbox.showerror("Surface Error", "Energy could not be evaulated at step {}. Steps truncated".format(itcounter + 1))
                    terminate = True
                    #Hold the current state so the bookkeeping below never
                    #reads unbound names. Without this, a failure on the
                    #first step raised UnboundLocalError.
                    xrabf, xrbcf, xracf, thetaf = xrabi, xrbci, xraci, thetai
                    vrabf, vrbcf, vracf = vrabi, vrbci, vraci
                    tf = ti

            if xrabf > self.lim or xrbcf > self.lim: #Stop calc if distance lim is exceeded
                msgbox.showerror("Surface Error", "Surface Limits exceeded at step {}. Steps truncated".format(itcounter + 1))
                terminate = True

            arab.append(arabi)
            arbc.append(arbci)
            arac.append(araci)
            Ktot.append(Ktoti)

            #Total energy
            etot.append(Vrint[itcounter] + Ktot[itcounter])

            if itcounter != itlimit - 1 and not terminate:

                #As above
                a = np.array([- xrabf, 0.])
                b = np.array([0., 0.])
                c = np.array([- np.cos(thetaf) * xrbcf, np.sin(thetaf) * xrbcf])
                com = (a * ma + b * mb + c * mc) / (ma + mb + mc)

                com = np.real(com)

                a -= com
                b -= com
                c -= com

                #Get A-C Velocity (finite difference of the A-C separation)
                r0 = np.linalg.norm(self.rc[-1] - self.ra[-1])
                r1 = np.linalg.norm(c - a)
                vrac = (r1 - r0) / dt

                self.ra.append(a)                       #A  Pos
                self.rb.append(b)                       #B  Pos
                self.rc.append(c)                       #C  Pos
                self.xrab.append(xrabf)                 #A-B Distance
                self.xrbc.append(xrbcf)                 #B-C Distance
                self.xrac.append(xracf)                 #A-C Distance
                self.vrab.append(vrabf)                 #A-B Velocity
                self.vrbc.append(vrbcf)                 #B-C Velocity
                self.vrac.append(vrac)                  #A-C Velocity
                self.t.append(tf)                       #Time

            xrabi = xrabf
            xrbci = xrbcf
            xraci = xracf
            vrabi = vrabf
            vrbci = vrbcf
            vraci = vracf
            ti = tf

            if terminate:
                break
        
    def get_last_geo(self, *args):
        """Copy the final geometry and momenta into the initial-condition boxes.

        The AB/BC distance entries receive the last ``xrab``/``xrbc`` values.
        The AB/BC momentum entries receive the last ``vrab``/``vrbc`` velocity
        multiplied by ``mab``/``mbc``. ``*args`` absorbs a Tk event.
        """
        # Each value is read only after its box has been cleared, in the
        # same order the boxes are filled.
        final_state = (
            ("xrabi", lambda: self.xrab[-1]),
            ("xrbci", lambda: self.xrbc[-1]),
            ("prabi", lambda: self.vrab[-1] * self.mab),
            ("prbci", lambda: self.vrbc[-1] * self.mbc),
        )
        for key, read_value in final_state:
            box = self.entries[key][0]
            box.delete(0, tk.END)
            box.insert(0, read_value())
            
    def export(self, *args):
        """Run the current calculation and save every per-step quantity as CSV.

        The GUI inputs are read, the atom parameters refreshed and the
        trajectory recomputed. The user is then asked for a file name;
        cancelling the dialog returns without writing anything.

        The file has one header row and one row per time step. A column
        shorter than ``self.t`` is padded with empty cells. ``*args`` absorbs
        the Tk event when the method is bound to a button.
        """
        self._read_entries()
        # update_plot() refreshes masses/LEPS parameters via get_surface();
        # do the same here so an atom change since the last plot is honoured.
        self.get_params()
        self.get_trajectory()

        filename = asksaveasfilename(defaultextension=".csv")
        if not filename:
            return

        sources = [
            ["Time",            self.t                           ],
            ["AB Distance",     self.xrab                        ],
            ["BC Distance",     self.xrbc                        ],
            ["AC Distance",     self.xrac                        ],
            ["AB Velocity",     self.vrab                        ],
            ["BC Velocity",     self.vrbc                        ],
            ["AC Velocity",     self.vrac                        ],
            # Momentum = pair velocity x pair reduced mass (as in plot_inm_vs_t)
            ["AB Momentum",     [v * self.mab for v in self.vrab]],
            ["BC Momentum",     [v * self.mbc for v in self.vrbc]],
            ["AC Momentum",     [v * self.mac for v in self.vrac]],
            ["AB Force",        self.Frab                        ],
            ["BC Force",        self.Frbc                        ],
            ["AC Force",        self.Frac                        ],
            ["Total Potential", self.Vrint                       ],
            ["Total Kinetic",   self.Ktot                        ],
            ["Total Energy",    self.etot                        ],
            ["AB AB Hess Comp", self.hr1r1                       ],
            ["AB BC Hess Comp", self.hr1r2                       ],
            ["AB AC Hess Comp", self.hr1r3                       ],
            ["BC BC Hess Comp", self.hr2r2                       ],
            ["BC AC Hess Comp", self.hr2r3                       ],
            ["AC AC Hess Comp", self.hr3r3                       ]
        ]

        lines = [",".join(title for title, _ in sources)]
        for step in range(len(self.t)):
            cells = []
            for _, column in sources:
                try:
                    cells.append(str(column[step]))
                except IndexError:
                    # Column shorter than the time axis: leave the cell empty
                    cells.append("")
            lines.append(",".join(cells))

        with open(filename, "w") as f:
            f.write("\n".join(lines) + "\n")
            
    def update_plot(self, *args):
        """Recalculate and redraw according to the selected plot type.

        The GUI inputs are read again, then the surface and trajectory are
        recomputed. The drawing routine(s) registered for ``self.plot_type``
        are called; an unrecognised plot type draws nothing. ``*args``
        absorbs the Tk event when the method is bound to a button.
        """
        self._read_entries()
        self.get_surface()
        self.get_trajectory()

        # Method names are resolved lazily, only for the selected plot type
        routines = {
            "Contour Plot": ("plot_contour", "plot_init_pos"),
            "Surface Plot": ("plot_surface", "plot_init_pos"),
            "Internuclear Distances vs Time": ("plot_ind_vs_t",),
            "Internuclear Momenta vs Time": ("plot_inm_vs_t",),
            "Energy vs Time": ("plot_e_vs_t",),
            "p(AB) vs p(BC)": ("plot_momenta",),
            "v(AB) vs v(BC)": ("plot_velocities",),
            "Animation": ("animation",),
        }
        for method_name in routines.get(self.plot_type, ()):
            getattr(self, method_name)()
            
    def plot_contour(self):
        """Draw potential contours with the trajectory overlaid.

        The current figure is cleared. Contour levels run from just below the
        surface minimum up to ``self.cutoff`` in steps of ``self.spacing``.
        The path ``(xrab, xrbc)`` is drawn with ``colorline`` using the "jet"
        colormap.
        """
        plt.clf()
        axes = plt.gca()
        for axis in (axes.get_xaxis(), axes.get_yaxis()):
            # Show absolute tick values instead of an offset
            axis.get_major_formatter().set_useOffset(False)

        plt.xlabel("AB Distance")
        plt.ylabel("BC Distance")

        grid_ab, grid_bc = np.meshgrid(self.x, self.y)
        lowest_level = np.min(self.Vmat) - 1
        contour_levels = np.arange(lowest_level, float(self.cutoff), self.spacing)
        plt.contour(grid_ab, grid_bc, self.Vmat, levels=contour_levels)
        plt.xlim([min(self.x), max(self.x)])
        plt.ylim([min(self.y), max(self.y)])

        path = colorline(self.xrab, self.xrbc, cmap=plt.get_cmap("jet"), linewidth=1)
        axes.add_collection(path)

        plt.draw()
        plt.pause(0.0001)  # Let the GUI event loop run so MPL does not block
        
    def plot_surface(self):
        """Draw a 3D surface of the clipped potential with the trajectory.

        All open figures are closed first, because a 3D projection needs a
        fresh figure. ``self.Vmat`` is clipped to ``[-10000, self.cutoff]``
        and drawn as a translucent surface. Its contours are projected onto a
        plane 10 below the clipped minimum, and the path
        ``(xrab, xrbc, Vrint)`` is plotted on top.
        """
        plt.close('all') #New figure needed for 3D axes
        self.fig_3d = plt.figure('Surface Plot', figsize=(5,5))

        # Since matplotlib 3.6, Axes3D(fig) no longer adds itself to the
        # figure, which leaves an empty window. add_subplot with the "3d"
        # projection works on both old and new versions.
        ax = self.fig_3d.add_subplot(111, projection='3d')

        # Label the 3D axes directly. plt.xlabel acts on plt.gca(), which is
        # only this axes if it is attached to the figure.
        ax.set_xlabel("AB Distance")
        ax.set_ylabel("BC Distance")

        X, Y = np.meshgrid(self.x, self.y)
        ax.set_xlim3d([min(self.x),max(self.x)])
        ax.set_ylim3d([min(self.y),max(self.y)])

        # Clip high-energy walls so the region of interest stays visible
        Z = np.clip(self.Vmat, -10000, self.cutoff)

        ax.plot_surface(X, Y, Z, rstride=self.spacing, cstride=self.spacing, cmap='jet', alpha=0.3, linewidth=0)
        ax.contour(X, Y, Z, zdir='z', cmap='jet', stride=self.spacing, offset=np.min(Z) - 10)
        ax.plot(self.xrab, self.xrbc, self.Vrint)

        plt.draw()
        plt.pause(0.0001)
        
    def plot_ind_vs_t(self):
        """Plot the A-B, B-C and A-C separations against time.

        The current figure is cleared and one labelled line is drawn per atom
        pair, with a legend.
        """
        plt.clf()
        axes = plt.gca()
        for axis in (axes.get_xaxis(), axes.get_yaxis()):
            # Show absolute tick values instead of an offset
            axis.get_major_formatter().set_useOffset(False)

        plt.xlabel("Time")
        plt.ylabel("Distance")

        curves = []
        for distances, tag in ((self.xrab, "A-B"), (self.xrbc, "B-C"), (self.xrac, "A-C")):
            line, = plt.plot(self.t, distances, label=tag)
            curves.append(line)
        plt.legend(handles=curves)

        plt.draw()
        plt.pause(0.0001)
        
    def plot_inm_vs_t(self):
        """Plot the A-B, B-C and A-C internuclear momenta against time.

        Each momentum series is the matching velocity series (self.vrab,
        self.vrbc, self.vrac) multiplied element-wise by its mass attribute
        (self.mab, self.mbc, self.mac).
        """
        plt.clf()
        axes = plt.gca()
        axes.get_xaxis().get_major_formatter().set_useOffset(False)
        axes.get_yaxis().get_major_formatter().set_useOffset(False)

        plt.xlabel("Time")
        plt.ylabel("Momentum")

        components = (
            ("A-B", self.vrab, self.mab),
            ("B-C", self.vrbc, self.mbc),
            ("A-C", self.vrac, self.mac),
        )
        # Compute every series before plotting anything.
        momenta = [(name, [v * mass for v in velocities])
                   for name, velocities, mass in components]

        handles = []
        for name, series in momenta:
            line, = plt.plot(self.t, series, label=name)
            handles.append(line)

        plt.legend(handles=handles)

        plt.draw()
        plt.pause(0.0001)
        
    def plot_momenta(self):
        """Plot the AB momentum against the BC momentum as a colour-graded path.

        Momenta are velocity * mass (self.vrab * self.mab and
        self.vrbc * self.mbc). The path is coloured along its length with
        the "jet" colormap via colorline(), then the view is rescaled to fit.
        """
        plt.clf()
        axes = plt.gca()

        plt.xlabel("AB Momentum")
        plt.ylabel("BC Momentum")

        p_ab = [v * self.mab for v in self.vrab]
        p_bc = [v * self.mbc for v in self.vrbc]
        path = colorline(p_ab, p_bc, cmap=plt.get_cmap("jet"), linewidth=1)

        axes.add_collection(path)
        axes.autoscale()  # fit the view to the added path
        plt.draw()
        plt.pause(0.0001)
        
    def plot_velocities(self):
        """Plot the AB velocity against the BC velocity as a colour-graded path.

        The path (self.vrab, self.vrbc) is coloured along its length with the
        "jet" colormap via colorline(), then the view is rescaled to fit.
        """
        plt.clf()
        axes = plt.gca()

        plt.xlabel("AB Velocity")
        plt.ylabel("BC Velocity")

        jet = plt.get_cmap("jet")
        axes.add_collection(colorline(self.vrab, self.vrbc, cmap=jet, linewidth=1))
        axes.autoscale()  # fit the view to the added path

        plt.draw()
        plt.pause(0.0001)
        
    def plot_e_vs_t(self):
        """Plot potential (self.Vrint) and kinetic (self.Ktot) energy against time."""
        plt.clf()
        axes = plt.gca()
        for axis in (axes.get_xaxis(), axes.get_yaxis()):
            axis.get_major_formatter().set_useOffset(False)

        plt.xlabel("Time")
        plt.ylabel("Energy")

        curves = (
            ("Potential Energy", self.Vrint),
            ("Kinetic Energy", self.Ktot),
        )
        handles = []
        for name, data in curves:
            line, = plt.plot(self.t, data, label=name)
            handles.append(line)

        plt.legend(handles=handles)

        plt.draw()
        plt.pause(0.0001)
        
    def animation(self):
        """Animate the A, B and C atoms moving along the computed trajectory.

        Closes all figures and opens an 'Animation' figure. Each atom is drawn
        as a circle of radius 0.25 * its VdW radius, in the colour given by
        self.atom_map for the element chosen in the GUI entries.
        FuncAnimation then moves the circles through the positions stored in
        self.ra, self.rb and self.rc. There is one frame per entry of
        self.ra, 20 ms between frames, and the animation repeats.
        """
        plt.close('all')
        self.ani_fig = plt.figure('Animation', figsize=(5, 5))

        # Fit the view to the paths of all three atoms, padded by 1 unit
        # (larger than the biggest drawn radius, 1.98 * 0.25). The previous
        # limits used only A's minimum and C's maximum. That assumed A always
        # lies below/left of C and ignored B, so atoms could move outside the
        # visible area (e.g. in bent geometries).
        coords = np.vstack([np.asarray(path, dtype=float)[:, :2]
                            for path in (self.ra, self.rb, self.rc)])
        lower = coords.min(axis=0) - 1
        upper = coords.max(axis=0) + 1

        ax = plt.axes(xlim=(lower[0], upper[0]), ylim=(lower[1], upper[1]))
        ax.set_aspect('equal')

        patches = []
        for at_name in ("a", "b", "c"):
            symbol = self.entries[at_name][0].get()
            _, vdw, colour = self.atom_map[symbol]
            start = getattr(self, "r" + at_name)[0]
            patches.append(plt.Circle(start, vdw * 0.25, fc=colour))

        def init():
            """Add the three atom circles to the axes."""
            for patch in patches:
                ax.add_patch(patch)
            return tuple(patches)

        def update(i):
            """Move each circle to its position in frame i."""
            ap, bp, cp = patches
            ap.center = self.ra[i]
            bp.center = self.rb[i]
            cp.center = self.rc[i]
            return ap, bp, cp,

        # Keep a reference on self so the animation is not garbage collected.
        self.anim = FuncAnimation(self.ani_fig, update, init_func=init, frames=len(self.ra), repeat=True, interval=20)

        try:
            plt.show()
        except Exception as err:
            # Showing the window is best-effort, but report the failure
            # instead of silently hiding it. A bare except also swallowed
            # KeyboardInterrupt/SystemExit.
            warnings.warn("Could not show animation: {}".format(err))
        
    def plot_init_pos(self):
        """Cross representing initial geometry"""
        if not self.plot_type == "Contour Plot":
            return
            
        self.init_pos_plot, = plt.plot([self.xrabi], [self.xrbci], marker='x', markersize=6, color="red")
        plt.draw()
        plt.pause(0.0001)
        
    def get_first(self):
        """Advance the trajectory by a single step to obtain geometry data.

        Re-reads the user entries (self._read_entries) and the derived
        parameters (self.get_params). It then evaluates lepspoint at the
        initial geometry with grad=2, and feeds that output into one lepnorm
        step of size self.dt starting from t = 0.

        Returns
        -------
        tuple
            The 10 values unpacked from lepspoint (Vrinti, Frabi, Frbci,
            Fraci, hr1r1i, hr1r2i, hr1r3i, hr2r2i, hr2r3i, hr3r3i), followed
            by the 12 values unpacked from lepnorm (xrabf, xrbcf, xracf,
            thetaf, vrabf, vrbcf, vracf, tf, arabi, arbci, araci, Ktoti).
        """
        self._read_entries()
        self.get_params()

        step = self.dt
        mass_a, mass_b, mass_c = self.ma, self.mb, self.mc
        r_ab, r_bc = self.xrabi, self.xrbci

        angle = np.deg2rad(self.theta)

        # Starting velocities along AB and BC come from the initial momenta
        # divided by self.mab / self.mbc; the AC velocity always starts at 0.
        vel_ab = self.prabi / self.mab
        vel_bc = self.prbci / self.mbc
        vel_ac = 0

        (Vrinti, Frabi, Frbci, Fraci,
         hr1r1i, hr1r2i, hr1r3i,
         hr2r2i, hr2r3i,
         hr3r3i) = lepspoint(r_ab, r_bc, angle,
                             self.Drab, self.Drbc, self.Drac,
                             self.Brab, self.Brbc, self.Brac,
                             self.lrab, self.lrbc, self.lrac,
                             self.H, 2)

        (xrabf, xrbcf, xracf, thetaf,
         vrabf, vrbcf, vracf, tf,
         arabi, arbci, araci, Ktoti) = lepnorm(
            r_ab, r_bc, angle, Frabi, Frbci, Fraci,
            vel_ab, vel_bc, vel_ac,
            hr1r1i, hr1r2i, hr1r3i, hr2r2i, hr2r3i, hr3r3i,
            mass_a, mass_b, mass_c, 0, step, False)

        return (Vrinti, Frabi, Frbci, Fraci,
                hr1r1i, hr1r2i, hr1r3i, hr2r2i, hr2r3i, hr3r3i,
                xrabf, xrbcf, xracf, thetaf, vrabf, vrbcf, vracf, tf,
                arabi, arbci, araci, Ktoti)
        
    def update_geometry_info(self, *args):
        """Updates the info pane"""
        try:
            Vrinti,Frabi,Frbci,Fraci,hr1r1,hr1r2,hr1r3,hr2r2,hr2r3,hr3r3,xrabf,xrbcf,xracf,thetaf,vrabf,vrbcf,vracf,tf,arabi,arbci,araci,Ktoti = self.get_first()
            hessian = np.array([[hr1r1, hr1r2, hr1r3], [hr1r2, hr2r2, hr2r3], [hr1r3, hr2r3, hr3r3]])
            eigenvalues, eigenvectors = np.linalg.eig(hessian)
            ke     = "{:+7.3f}".format(Ktoti)
            pe     = "{:+7.3f}".format(Vrinti)
            etot   = "{:+7.3f}".format(Vrinti + Ktoti)
            fab    = "{:+7.3f}".format(Frabi)
            fbc    = "{:+7.3f}".format(Frbci)
            
            eval1  = "{:+7.3f}".format(eigenvalues[0])
            eval2  = "{:+7.3f}".format(eigenvalues[1])
            
            evec11 = "{:+7.3f}".format(eigenvectors[0][0])
            evec12 = "{:+7.3f}".format(eigenvectors[0][1])
            evec21 = "{:+7.3f}".format(eigenvectors[1][0])
            evec22 = "{:+7.3f}".format(eigenvectors[1][1])
            
        except:
            ke     = "       "
            pe     = "       "
            etot   = "       "
            fab    = "       "
            fbc    = "       "
            eval1  = "       "
            eval2  = "       "
            evec11 = "       "
            evec12 = "       "
            evec21 = "       "
            evec22 = "       "
            
        self.i_ke["text"] = ke
        self.i_pe["text"] = pe
        self.i_etot["text"] = etot
        
        self.i_fab["text"] = fab
        self.i_fbc["text"] = fbc
        
        self.i_eval1["text"] = eval1
        self.i_eval2["text"] = eval2
        
        self.i_evec11["text"] = evec11
        self.i_evec12["text"] = evec12
        
        self.i_evec21["text"] = evec21
        self.i_evec22["text"] = evec22      
        
def colorline(x, y, z=None, cmap=None, norm=None, linewidth=3, alpha=1.0):
    """Build a LineCollection that draws the path (x, y) with varying colour.

    Adapted from:
    http://nbviewer.ipython.org/github/dpsanders/matplotlib-examples/blob/master/colorline.ipynb
    http://matplotlib.org/examples/pylab_examples/multicolored_line.html

    Parameters
    ----------
    x, y : sequences of float
        Path coordinates; consecutive points are joined by segments
        (see make_segments).
    z : scalar or sequence, optional
        Values mapped through ``norm`` and ``cmap`` to colour the segments.
        Defaults to len(x) values equally spaced on [0, 1], i.e. one value
        per point (one more than the number of segments).
    cmap : Colormap, optional
        Defaults to the "copper" colormap.
    norm : Normalize, optional
        Defaults to a new ``Normalize(0.0, 1.0)`` created for each call.
    linewidth, alpha : float
        Passed through to the LineCollection.

    Returns
    -------
    matplotlib.collections.LineCollection
        Not yet attached to any axes; the caller adds it with
        ``ax.add_collection``.
    """
    if cmap is None:
        cmap = plt.get_cmap('copper')
    if norm is None:
        # Create per call. A default built at definition time is one mutable
        # Normalize object shared by every collection this function returns.
        norm = plt.Normalize(0.0, 1.0)

    # Default colors equally spaced on [0,1]:
    if z is None:
        z = np.linspace(0.0, 1.0, len(x))

    # Special case if a single number:
    if not hasattr(z, "__iter__"):  # to check for numerical input -- this is a hack
        z = np.array([z])

    z = np.asarray(z)

    segments = make_segments(x, y)
    lc = mcoll.LineCollection(segments, array=z, cmap=cmap, norm=norm,
                              linewidth=linewidth, alpha=alpha)

    return lc


def make_segments(x, y):
    """Pair consecutive (x, y) points into line segments for LineCollection.

    Returns an array of shape (len(x) - 1, 2, 2). Segment k runs from point k
    to point k + 1, and each point is stored as [x, y]. Fewer than two points
    give an empty (0, 2, 2) array.
    """
    xy = np.column_stack((x, y))          # shape (n_points, 2)
    starts, ends = xy[:-1], xy[1:]        # each (n_points - 1, 2)
    return np.stack((starts, ends), axis=1)
        
if __name__ == "__main__":
    # Script entry point: construct the LEPS GUI. A module-level reference
    # is kept to the instance.
    gui = Interactive()
    
