__version__   = '1.0.2'
__author__    = "Avinash Kak (kak@purdue.edu)"
__date__      = '2020-January-12'   
__url__       = 'https://engineering.purdue.edu/kak/distCGP/ComputationalGraphPrimer-1.0.2.html'
__copyright__ = "(C) 2020 Avinash Kak. Python Software Foundation."



import sys,os,os.path
import numpy as np
import re
import math
import random
import copy
import matplotlib.pyplot as plt
import networkx as nx

#______________________________  ComputationalGraphPrimer Class Definition  ________________________________

class ComputationalGraphPrimer(object):

    def __init__(self, *args, **kwargs ):
        if args:
            raise ValueError(  
                   '''ComputationalGraphPrimer constructor can only be called with keyword arguments for 
                      the following keywords: expressions, output_vars, dataset_size, grad_delta,
                      learning_rate, display_vals_how_often, and debug''')
        expressions = output_vars = dataset_size = grad_delta = display_vals_how_often = learning_rate = debug  = None
        if 'expressions' in kwargs                   :   expressions = kwargs.pop('expressions')
        if 'output_vars' in kwargs                   :   output_vars = kwargs.pop('output_vars')
        if 'dataset_size' in kwargs                  :   dataset_size = kwargs.pop('dataset_size')
        if 'learning_rate' in kwargs                 :   learning_rate = kwargs.pop('learning_rate')
        if 'grad_delta' in kwargs                    :   grad_delta = kwargs.pop('grad_delta')
        if 'display_vals_how_often' in kwargs        :   display_vals_how_often = kwargs.pop('display_vals_how_often')
        if 'debug' in kwargs                         :   debug = kwargs.pop('debug') 
        if len(kwargs) != 0: raise ValueError('''You have provided unrecognizable keyword args: %s''' % ', '.join(kwargs))
        if expressions:
            self.expressions = expressions
        else:
            sys.exit("you need to supply a list of expressions")
        if output_vars:
            self.output_vars = output_vars
        if dataset_size:
            self.dataset_size = dataset_size
        else:
            sys.exit("you need to supply a dataset size")
        if learning_rate:
            self.learning_rate = learning_rate
        else:
            self.learning_rate = 1e-6
        if grad_delta:
            self.grad_delta = grad_delta
        else:
            self.grad_delta = 1e-4
        if display_vals_how_often:
            self.display_vals_how_often = display_vals_how_often
        else:
            self.display_vals_how_often = 1000           ## default display frequency
        self.dataset_input_samples  = {i : None for i in range(dataset_size)}
        self.true_output_vals       = {i : None for i in range(dataset_size)}
        self.vals_for_learnable_params = None
        if debug:                             
            self.debug = debug
        else:
            self.debug = 0
        self.gradient_of_loss = None
        self.gradients_for_learnable_params = None
        self.expressions_dict = {}
        self.LOSS = []                               ##  loss values for all iterations of training
        self.all_vars = set()
        self.independent_vars = set()
        self.dependent_vars = []
        self.learnable_params = set()
        self.depends_on = {}                         ##  See Introduction for the meaning of this 
        self.leads_to = {}                           ##  See Introduction for the meaning of this 
    
    def parse_expressions(self):
        ''' 
        This method creates a DAG from a set of expressions that involve variables and learnable
        parameters. The expressions are based on the assumption that a symbolic name that starts
        with the letter 'x' is a variable, with all other symbolic names being learnable parameters.
        The computational graph is represented by two dictionaries, 'depends_on' and 'leads_to'.
        To illustrate the meaning of the dictionaries, "depends_on['xz']" would be set to a list
        of all the other variables whose outgoing arcs end in the node 'xz'.  That is,
        "depends_on['xz']" is best read as "node 'xz' depends on ....", where the dots stand for
        the list of nodes that is the value of "depends_on['xz']".  The 'leads_to' dictionary has
        the opposite meaning: "leads_to['xz']" is set to the set of nodes at the ends of all the
        arcs that emanate from 'xz'.
        '''
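        ##  To illustrate (hypothetical example, not part of the original module): with the
        ##  expressions
        ##
        ##      ['xx=xa^2', 'xy=ab*xx+ac*xa', 'xz=bc*xx^2+xy', 'xw=cd*xx+xz']
        ##
        ##  the names 'ab', 'ac', 'bc', and 'cd' become learnable parameters, 'xa' is the only
        ##  independent variable, and this method constructs, among other entries,
        ##
        ##      depends_on['xz']  =  ['xx', 'xy']
        ##      leads_to['xx']    =  {'xy', 'xz', 'xw'}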
        for exp in self.expressions:
            left,right = exp.split('=')
            self.all_vars.add(left)
            self.expressions_dict[left] = right
            self.depends_on[left] = []
            parts = re.findall('([a-zA-Z]+)', right)
            for part in parts:
                if part.startswith('x'):
                    self.all_vars.add(part)
                    self.depends_on[left].append(part)
                else:
                    self.learnable_params.add(part)
        if self.debug:
            print("\n\nall variables: %s" % str(self.all_vars))
            print("\n\nlearnable params: %s" % str(self.learnable_params))
            print("\n\ndependencies: %s" % str(self.depends_on))
            print("\n\nexpressions dict: %s" % str(self.expressions_dict))
        for var in self.all_vars:
            if var not in self.depends_on:              # that is, var is not a key in the depends_on dict
                self.independent_vars.add(var)
        if self.debug:
            print("\n\nindependent vars: %s" % str(self.independent_vars))
        self.dependent_vars = [var for var in self.all_vars if var not in self.independent_vars]
        self.leads_to = {var : set() for var in self.all_vars}
        for k,v in self.depends_on.items():
            for var in v:
                self.leads_to[var].add(k)    

    def display_network1(self):
        G = nx.DiGraph()
        G.add_nodes_from(self.all_vars)
        edges = []
        for ver1 in self.leads_to:
            for ver2 in self.leads_to[ver1]:
                edges.append( (ver1,ver2) )
        G.add_edges_from( edges )
        nx.draw(G, with_labels=True, font_weight='bold')
        plt.show()

    def display_network2(self):
        '''
        Provides a fancier display of the network graph
        '''
        G = nx.DiGraph()
        G.add_nodes_from(self.all_vars)
        edges = []
        for ver1 in self.leads_to:
            for ver2 in self.leads_to[ver1]:
                edges.append( (ver1,ver2) )
        G.add_edges_from( edges )
        pos = nx.circular_layout(G)    
        nx.draw(G, pos, with_labels = True, edge_color='b', node_color='lightgray', 
                          arrowsize=20, arrowstyle='fancy', node_size=1200, font_size=20, 
                          font_color='black')
        plt.title("Computational graph for the expressions")
        plt.show()


    def train_on_all_data(self):
        '''
        The purpose of this method is to call forward_propagate_one_input_sample_with_partial_deriv_calc()
        repeatedly on all input/output ground-truth training data pairs generated by the method 
        gen_gt_dataset().  The call to the forward_propagate...() method returns the predicted value
        at the output nodes from the supplied values at the input nodes.  The "train_on_all_data()"
        method calculates the error associated with the predicted value.  The call to
        forward_propagate...() also returns the partial derivatives estimated by using the finite
        difference method in the computational graph.  Using those partial derivatives, the 
        "train_on_all_data()" method backpropagates the loss to the interior nodes in the 
        computational graph and updates the values for the learnable parameters.
        '''
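        ##  A sketch of what the loops below implement: for each learnable parameter, a path from
        ##  an output variable down to that parameter is constructed, the partial derivatives
        ##  estimated by finite differences during forward propagation are multiplied together
        ##  along the path (in the spirit of the chain rule), and the parameter is then updated
        ##  through gradient descent:
        ##
        ##      param  <--  param  -  learning_rate * product_of_partials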
        self.vals_for_learnable_params = {var: random.uniform(0,1) for var in self.learnable_params}
        print("\n\n\nvalues for all learnable parameters: %s" % str(self.vals_for_learnable_params))
        for sample_index in range(self.dataset_size):
            if sample_index % self.display_vals_how_often == 0:
                print("\n\n\n=========  [Forward Propagation] Training with sample indexed: %d ============" % sample_index)
            input_vals_for_ind_vars = {var: self.dataset_input_samples[sample_index][var] for var in self.independent_vars}
            predicted_output_vals, partial_var_to_param, partial_var_to_var = \
                         self.forward_propagate_one_input_sample_with_partial_deriv_calc(sample_index, input_vals_for_ind_vars)
            error = [self.true_output_vals[sample_index][var] - predicted_output_vals[var] for var in self.output_vars]
            loss = np.linalg.norm(error)
            if self.debug:
                print("\n\n\nloss for training sample indexed %d: %s" % (sample_index, str(loss)))
            self.LOSS.append(loss)
            if sample_index % self.display_vals_how_often == 0:
                print("\n\n\nestimated partial derivatives of vars wrt learnable parameters:")
                for k,v in partial_var_to_param.items():
                    print("\nk=%s     v=%s" % (k, str(v)))
                print("\n\n\nestimated partial derivatives of vars wrt other vars:")
                for k,v in partial_var_to_var.items():
                    print("\nk=%s     v=%s" % (k, str(v)))
            paths = {param : [] for param in self.learnable_params}
            for var1 in partial_var_to_param:
                for var2 in partial_var_to_param[var1]:
                    for param in self.learnable_params:
                        if partial_var_to_param[var1][var2][param] is not None:
                            paths[param] += [var1,var2,param]
            for param in paths:
                if not paths[param]: continue            ## no path was found for this param in this sample
                node = paths[param][0]
                if node in self.output_vars: 
                    continue
                for var_out in self.output_vars:        
                    if node in self.depends_on[var_out]:
                        paths[param].insert(0,var_out) 
                    else:
                        for node2 in self.depends_on[var_out]:
                            if node in self.depends_on[node2]:
                                paths[param].insert(0,node2)
                                paths[param].insert(0,var_out)
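            ##  multiply the estimated partials along each path and update the parameter: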
            for param in self.learnable_params:
                if not paths[param]: continue            ## nothing to update for this param in this sample
                product_of_partials = 1.0
                for i in range(len(paths[param]) - 2):
                    var1 = paths[param][i]
                    var2 = paths[param][i+1]
                    product_of_partials *= partial_var_to_var[var1][var2]
                if self.debug:
                    print("\n\nfor param=%s, product of partials: %s" % (param, str(product_of_partials)))
                product_of_partials *=  partial_var_to_param[var1][var2][param]
                self.vals_for_learnable_params[param] -=  self.learning_rate * product_of_partials
            if sample_index % self.display_vals_how_often == 0:
                print("\n\n\nat sample index: %d, vals for learnable parameters: %s" % (sample_index, str(self.vals_for_learnable_params)))


    def forward_propagate_one_input_sample_with_partial_deriv_calc(self, sample_index, input_vals_for_ind_vars):
        '''
        If you want to look at how the information flows in the DAG when you don't have to worry about
        estimating the partial derivatives, see the method gen_gt_dataset().  As you will notice in the
        implementation code for that method, there is nothing much to pushing the input values through
        the nodes and the arcs of a computational graph if we are not concerned about estimating the
        partial derivatives.

        On the other hand, if you want to see how one might also estimate the partial derivatives
        during the forward flow of information in a computational graph, the forward_propagate...()
        method presented here is the one to examine.  We first split the expression that the node 
        variable depends on into its constituent parts on the basis of the '+' and '-' operators
        and, subsequently, for each part, we estimate the partial of the node variable with respect
        to the variables and the learnable parameters in that part.
        '''
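        ##  To illustrate with a hypothetical expression 'xz=bc*xx^2+xy': splitting at the '+'
        ##  and '-' operators yields the parts ['bc*xx^2', 'xy'].  In the part 'bc*xx^2', the
        ##  variable is 'xx' and the learnable parameter is 'bc'.  The partial of 'xz' with
        ##  respect to 'xx' is then estimated with the forward difference
        ##
        ##      ( part_value(xx + grad_delta)  -  part_value(xx) ) / grad_delta
        ##
        ##  and the partial with respect to 'bc' is estimated in the same manner.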
        predicted_output_vals = {var : None for var in self.output_vars}
        vals_for_dependent_vars = {var: None for var in self.all_vars if var not in self.independent_vars}
        partials_var_to_param = {var1 : {var2 : {ele : None for ele in self.learnable_params} for var2 in self.all_vars} for var1 in self.all_vars}
        partials_var_to_var   = {var1 : {var2 : None for var2 in self.all_vars} for var1 in self.all_vars}       
        while any(v is None for v in vals_for_dependent_vars.values()):
            for var1 in self.all_vars:
                if var1 in self.dependent_vars and vals_for_dependent_vars[var1] is None: continue
                for var2 in self.leads_to[var1]:
                    if any([vals_for_dependent_vars[var] is None for var in self.depends_on[var2] if var not in self.independent_vars]): continue
                    exp = self.expressions_dict[var2]
                    learnable_params_in_exp = [ele for ele in self.learnable_params if ele in exp]
                    ##  in order to calculate the partials of the node (each node stands for a variable)
                    ##  values with respect to the learnable params, and, then, with respect to the 
                    ##  source vars, we must break the exp at '+' and '-' operators:
                    parts =  re.split(r'\+|-', exp)
                    ##  keep the sign of each part so that a part preceded by '-' is subtracted,
                    ##  not added, when the part values (and their partials) are combined below:
                    part_signs = [1.0] + [-1.0 if op == '-' else 1.0 for op in re.findall(r'[+-]', exp)]
                    if self.debug:
                        print("\n\n\n\n  ====for var2=%s =================   for exp=%s     parts: %s" % (var2, str(exp), str(parts)))
                    vals_for_parts = []
                    for part_sign, part in zip(part_signs, parts):
                        splits_at_arith = re.split(r'\*|/', part)
                        if len(splits_at_arith) > 1:
                            operand1 = splits_at_arith[0]
                            operand2 = splits_at_arith[1]
                            if '^' in operand1:
                                operand1 = operand1[:operand1.find('^')]
                            if '^' in operand2:
                                operand2 = operand2[:operand2.find('^')]
                            if operand1.startswith('x'):
                                var_in_part = operand1
                                param_in_part = operand2
                            elif operand2.startswith('x'):
                                var_in_part = operand2
                                param_in_part = operand1
                            else:
                                sys.exit("you are not following the convention -- aborting")
                        else:
                            if '^' in part:
                                ele_in_part = part[:part.find('^')]
                                if ele_in_part.startswith('x'):
                                    var_in_part = ele_in_part
                                    param_in_part = ""
                                else:
                                    param_in_part = ele_in_part
                                    var_in_part = ""
                            else:
                                if part.startswith('x'):
                                    var_in_part = part
                                    param_in_part = ""
                                else:
                                    param_in_part = part
                                    var_in_part = ""
                        if self.debug:
                            print("\n\n\nvar_in_part: %s    param_in_part=%s" % (var_in_part, param_in_part))
                        part_for_partial_var2var = copy.deepcopy(part)
                        part_for_partial_var2param = copy.deepcopy(part)
                        if self.debug:
                            print("\n\nSTEP1a: part: %s  of   exp: %s" % (part, exp))
                            print("STEP1b: part_for_partial_var2var: %s  of   exp: %s" % (part_for_partial_var2var, exp))
                            print("STEP1c: part_for_partial_var2param: %s  of   exp: %s" % (part_for_partial_var2param, exp))
                        if var_in_part in self.independent_vars:
                            part = part.replace(var_in_part, str(input_vals_for_ind_vars[var_in_part]))
                            part_for_partial_var2var  = part_for_partial_var2var.replace(var_in_part, str(input_vals_for_ind_vars[var_in_part] + self.grad_delta))
                            part_for_partial_var2param = part_for_partial_var2param.replace(var_in_part, str(input_vals_for_ind_vars[var_in_part]))
                            if self.debug:
                                print("\n\nSTEP2a: part: %s   of   exp=%s" % (part, exp))
                                print("STEP2b: part_for_partial_var2var: %s   of   exp=%s" % (part_for_partial_var2var, exp))
                                print("STEP2c: part_for_partial_var2param: %s   of   exp=%s" % (part_for_partial_var2param, exp))
                        if var_in_part in self.dependent_vars:
                            if vals_for_dependent_vars[var_in_part] is not None:
                                part = part.replace(var_in_part, str(vals_for_dependent_vars[var_in_part]))
                                part_for_partial_var2var  = part_for_partial_var2var.replace(var_in_part, str(vals_for_dependent_vars[var_in_part] + self.grad_delta))
                                part_for_partial_var2param = part_for_partial_var2param.replace(var_in_part, str(vals_for_dependent_vars[var_in_part]))
                            if self.debug:
                                print("\n\nSTEP3a: part=%s   of   exp: %s" % (part, exp))
                                print("STEP3b: part_for_partial_var2var=%s   of   exp: %s" % (part_for_partial_var2var, exp))
                                print("STEP3c: part_for_partial_var2param: %s   of   exp=%s" % (part_for_partial_var2param, exp))
                        ##  now we do the same thing wrt the learnable parameters:
                        if param_in_part != "" and param_in_part in self.learnable_params:
                            if self.vals_for_learnable_params[param_in_part] is not None:
                                part = part.replace(param_in_part, str(self.vals_for_learnable_params[param_in_part]))
                                part_for_partial_var2var  = part_for_partial_var2var.replace(param_in_part, str(self.vals_for_learnable_params[param_in_part]))
                                part_for_partial_var2param  = part_for_partial_var2param.replace(param_in_part, str(self.vals_for_learnable_params[param_in_part] + self.grad_delta))
                                if self.debug:
                                    print("\n\nSTEP4a: part: %s  of  exp: %s" % (part, exp))
                                    print("STEP4b: part_for_partial_var2var=%s   of   exp: %s" % (part_for_partial_var2var, exp))
                                    print("STEP4c: part_for_partial_var2param=%s   of   exp: %s" % (part_for_partial_var2param, exp))
                        ###  Now evaluate the part for each of the three cases:
                        evaled_part = eval( part.replace('^', '**') )
                        vals_for_parts.append(part_sign * evaled_part)
                        evaled_partial_var2var = eval( part_for_partial_var2var.replace('^', '**') )
                        if param_in_part != "":
                            evaled_partial_var2param = eval( part_for_partial_var2param.replace('^', '**') )
                        if var_in_part != "":
                            partials_var_to_var[var2][var_in_part] = part_sign * (evaled_partial_var2var - evaled_part) / self.grad_delta
                            if param_in_part != "":
                                partials_var_to_param[var2][var_in_part][param_in_part] = part_sign * (evaled_partial_var2param - evaled_part) / self.grad_delta
                    vals_for_dependent_vars[var2] = sum(vals_for_parts)
        predicted_output_vals = {var : vals_for_dependent_vars[var] for var in self.output_vars}
        return predicted_output_vals, partials_var_to_param, partials_var_to_var



    def calculate_loss(self, predicted_val, true_val):
        error = true_val - predicted_val
        loss = np.linalg.norm(error)
        return loss

    def plot_loss(self):
        plt.figure()
        plt.plot(self.LOSS)
        plt.show()

    def gen_gt_dataset(self, vals_for_learnable_params={}):
        '''
        This method illustrates that it is trivial to forward-propagate the information through
        the computational graph if you are not concerned about estimating the partial derivatives
        at the same time.  This method generates 'dataset_size' input/output pairs for the
        computational graph for the given values of the learnable parameters.
        '''
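        ##  A hypothetical call, with the parameter values chosen arbitrarily for illustration:
        ##
        ##      cgp.gen_gt_dataset( vals_for_learnable_params = {'ab':1.0, 'bc':2.0, 'cd':3.0, 'ac':4.0} )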
        N = self.dataset_size
        for i in range(N):
            if self.debug:
                print("\n\n\n================== Gen GT: iteration %d ============================\n" % i)
            vals_for_ind_vars = {var: random.uniform(0,1) for var in self.independent_vars}
            self.dataset_input_samples[i] = vals_for_ind_vars    
            vals_for_dependent_vars = {var: None for var in self.all_vars if var not in self.independent_vars}
            while any(v is None for v in vals_for_dependent_vars.values()):
                for var1 in self.all_vars:
                    for var2 in self.leads_to[var1]:
                        if vals_for_dependent_vars[var2] is not None: continue
                        predecessor_vars = self.depends_on[var2]
                        predecessor_vars_without_inds = [x for x in predecessor_vars if x not in self.independent_vars]
                        if any(vals_for_dependent_vars[vart] is None for vart in predecessor_vars_without_inds): continue
                        exp = self.expressions_dict[var2]
                        if self.debug:
                            print("\n\nSTEP1: exp: %s" % exp)
                        for var in self.independent_vars:
                            exp = exp.replace(var, str(vals_for_ind_vars[var]))
                        if self.debug:
                            print("\n\nSTEP2: exp: %s" % exp)
                        for var in self.dependent_vars:
                            if vals_for_dependent_vars[var] is not None:
                                exp = exp.replace(var, str(vals_for_dependent_vars[var]))
                        if self.debug:
                            print("\n\nSTEP3: exp: %s" % exp)
                        for ele in self.learnable_params:
                            exp = exp.replace(ele, str(vals_for_learnable_params[ele]))
                        if self.debug:                     
                            print("\n\nSTEP4: exp: %s" % exp)
                        vals_for_dependent_vars[var2] = eval( exp.replace('^', '**') )
            self.true_output_vals[i] = {ovar : vals_for_dependent_vars[ovar] for ovar in self.output_vars}
        if self.debug:
            print("\n\n\ninput samples: %s" % str(self.dataset_input_samples))
            print("\n\n\noutput vals: %s" % str(self.true_output_vals))

#_________________________  End of ComputationalGraphPrimer Class Definition ___________________________


#______________________________    Test code follows    _________________________________

if __name__ == '__main__': 
    pass
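
    ##  What follows is a minimal usage sketch, not part of the original test suite.  The
    ##  expressions and the numbers below are illustrative; they merely follow the conventions
    ##  described in the docstrings above (a name that starts with 'x' is a variable, every
    ##  other name is a learnable parameter).

    cgp = ComputationalGraphPrimer(
                   expressions = ['xx=xa^2',
                                  'xy=ab*xx+ac*xa',
                                  'xz=bc*xx^2+xy',
                                  'xw=cd*xx+xz'],
                   output_vars = ['xw'],
                   dataset_size = 1000,
                   learning_rate = 1e-6,
                   grad_delta = 1e-4,
                   display_vals_how_often = 100,
          )
    cgp.parse_expressions()
    cgp.display_network2()                ## close the display window to continue
    cgp.gen_gt_dataset( vals_for_learnable_params = {'ab':1.0, 'bc':2.0, 'cd':3.0, 'ac':4.0} )
    cgp.train_on_all_data()
    cgp.plot_loss()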