from DecisionTree import DecisionTree

import re
import random
import operator
from functools import reduce
import string
import sys

def convert(value):
    try:
        answer = float(value)
        return answer
    except:
        return value

def sample_index(sample_name):
    '''
    When the training data is read from a CSV file, we assume that the first column
    of each data record contains a unique integer identifier for the record in that
    row. This training data is stored in a dictionary whose keys are the prefix
    'sample_' followed by the identifying integers.  The purpose of this function is
    to return the identifying integer associated with a data record.
    '''
    m = re.search('_(.+)$', sample_name)
    return int(m.group(1))

def cleanup_csv(line):
    line = line.translate(bytes.maketrans(b":?/()[]{}'",b"          ")) \
           if sys.version_info[0] == 3 else line.translate(string.maketrans(":?/()[]{}'","          "))
    double_quoted = re.findall(r'"[^\"]+"', line[line.find(',') : ])
    for item in double_quoted:
        clean = re.sub(r',', r'', item[1:-1].strip())
        parts = re.split(r'\s+', clean.strip())
        line = str.replace(line, item, '_'.join(parts))
    white_spaced = re.findall(r',(\s*[^,]+)(?=,|$)', line)
    for item in white_spaced:
        litem = item
        litem = re.sub(r'\s+', '_', litem)
        litem = re.sub(r'^\s*_|_\s*$', '', litem) 
        line = str.replace(line, "," + item, "," + litem) if line.endswith(item) else str.replace(line, "," + item + ",", "," + litem + ",") 
    fields = re.split(r',', line)
    newfields = []
    for field in fields:
        newfield = field.strip()
        if newfield == '':
            newfields.append('NA')
        else:
            newfields.append(newfield)
    line = ','.join(newfields)
    return line
    
class DecisionTreeWithBagging(object):
    def __init__(self, *args, **kwargs ):
        if kwargs and args:
            raise SyntaxError(  
                   '''DecisionTreeWithBagging constructor can only be called with keyword arguments for
                      the following keywords: training_datafile, entropy_threshold,
                      max_depth_desired,csv_class_column_index,symbolic_to_numeric_cardinality_threshold,
                      number_of_histogram_bins,csv_columns_for_features,csv_cleanup_needed,
                      number_of_histogram_bins,how_many_bags,bag_overlap_fraction,debug1''') 
        allowed_keys = 'training_datafile','entropy_threshold','max_depth_desired','csv_class_column_index',\
                       'symbolic_to_numeric_cardinality_threshold','csv_columns_for_features',\
                       'number_of_histogram_bins','csv_cleanup_needed','how_many_bags','bag_overlap_fraction','debug1'
        keywords_used = kwargs.keys()
        for keyword in keywords_used:
            if keyword not in allowed_keys:
                raise SyntaxError(keyword + ":  Wrong keyword used --- check spelling") 
        training_datafile=entropy_threshold=max_depth_desired=csv_class_column_index=number_of_histogram_bins=None
        symbolic_to_numeric_cardinality_threshold=csv_columns_for_features=how_many_bags=csv_cleanup_needed=None
        bag_overlap_fraction=debug1=None
        if kwargs and not args:
            if 'how_many_bags' in kwargs : how_many_bags = kwargs.pop('how_many_bags')
            if 'bag_overlap_fraction' in kwargs : bag_overlap_fraction = kwargs.pop('bag_overlap_fraction')
            if 'training_datafile' in kwargs : training_datafile = kwargs['training_datafile']
            if 'csv_class_column_index' in kwargs: csv_class_column_index = kwargs.pop('csv_class_column_index')
            if 'csv_columns_for_features' in kwargs: \
                                  csv_columns_for_features = kwargs.pop('csv_columns_for_features')
            if 'csv_cleanup_needed' in kwargs: csv_cleanup_needed = kwargs.pop('csv_cleanup_needed')
            if 'debug1' in kwargs  :  debug1 = kwargs.pop('debug1')
        if not args and training_datafile:
            self._training_datafile = training_datafile
        elif not args and not training_datafile:
                raise Exception('''You must specify a training datafile''')
        else:
            if args[0] != 'evalmode':
                raise Exception("""When supplying non-keyword arg, it can only be 'evalmode'""")
        if csv_class_column_index:
            self._csv_class_column_index        =      csv_class_column_index
        else:
            self._csv_class_column_index        =      None
        if csv_columns_for_features:
            self._csv_columns_for_features      =      csv_columns_for_features
        else: 
            self._csv_columns_for_features      =      None            
        self._number_of_training_samples        =      None
        self._how_many_bags                     =      how_many_bags
        if csv_cleanup_needed:
            self._csv_cleanup_needed            =      csv_cleanup_needed
        else:
            self._csv_cleanup_needed            =      0
        self._segmented_training_data           =      {}
        self._all_trees                         =      {i:DecisionTree(**kwargs) for i in range(how_many_bags)}
        self._root_nodes                        =      []
        self._classifications                   =      None
        if bag_overlap_fraction is not None: 
            self._bag_overlap_fraction          =      bag_overlap_fraction 
        else:
            self._bag_overlap_fraction          =      0.20            
        self._bag_sizes                         =      []
        if debug1:
            self._debug1                        =      debug1
        else:
            self._debug1                        =      0

    def get_training_data_for_bagging(self):
        if not self._training_datafile.endswith('.csv'): 
            TypeError("Aborted. get_training_data_from_csv() is only for CSV files")
        class_names = []
        all_record_ids_with_class_labels = {}
        firstline = None
        data_dict = {}
        with open(self._training_datafile) as f:
            for i,line in enumerate(f):
                record = cleanup_csv(line) if self._csv_cleanup_needed else line
                if i == 0:
                    firstline = record
                    continue
                parts = record.rstrip().split(r',')
                data_dict[parts[0].strip('"')] = parts[1:]
                class_names.append(parts[self._csv_class_column_index])
                all_record_ids_with_class_labels[parts[0].strip('"')] = parts[self._csv_class_column_index]
                if i%10000 == 0:
                    print('.'),
                    sys.stdout.flush()
                sys.stdout = sys.__stdout__
            f.close() 
        self._how_many_total_training_samples = i   # i is less by 1 from total num of records; but that's okay
        unique_class_names = list(set(class_names))
        if self._debug1:
            print("\n\nTotal number of training samples: %d\n" % self._how_many_total_training_samples)
        self._number_of_training_samples = len(data_dict)
        all_feature_names = firstline.rstrip().split(',')[1:]
        class_column_heading = all_feature_names[self._csv_class_column_index - 1]        
        feature_names = [all_feature_names[i-1] for i in self._csv_columns_for_features]
        class_for_sample_dict = { "sample_" + key : 
               class_column_heading + "=" + data_dict[key][self._csv_class_column_index - 1] for key in data_dict}
        sample_names = ["sample_" + key for key in data_dict]
        random.shuffle(sample_names) 
        bag_size = int(len(sample_names) / self._how_many_bags)      
        def bags(l,n):
            for i in range(0,len(l),n):
                yield l[i:i+n]
        data_sample_bags = list(bags(sample_names, bag_size))[0:self._how_many_bags]
        if (len(sample_names) %  bag_size) > 0:
            data_sample_bags[-1] += sample_names[ self._how_many_bags * bag_size : ]
        self._bag_sizes = [ len(data_sample_bags[i]) for i in range(self._how_many_bags) ]
        if self._bag_overlap_fraction is not None:
            number_of_samples_needed_from_other_bags = int( len(data_sample_bags[0]) * self._bag_overlap_fraction )
            for i in range(self._how_many_bags): 
                samples_in_other_bags = reduce( lambda x,y: x+y, [data_sample_bags[x]
                                                                  for x in range(self._how_many_bags) if x != i])
                new_samples_to_be_added = random.sample(samples_in_other_bags, number_of_samples_needed_from_other_bags)
                data_sample_bags[i] += new_samples_to_be_added
            self._bag_sizes = [ len(data_sample_bags[i]) for i in range(self._how_many_bags) ]
        class_for_sample_dict_bags = { i : {sample_name :  class_for_sample_dict[sample_name]
                                 for sample_name in data_sample_bags[i] } for i in range(self._how_many_bags) }
        feature_values_for_samples_dict = {"sample_" + key :         
                  list(map(operator.add, list(map(operator.add, feature_names, "=" * len(feature_names))), 
           [str(convert(data_dict[key][i-1])) for i in self._csv_columns_for_features])) for key in data_dict}
        features_and_values_dict = {all_feature_names[i-1] :
            [convert(data_dict[key][i-1]) for key in data_dict] for i in self._csv_columns_for_features}
        all_class_names = sorted(list(set(class_for_sample_dict.values())))
        if self._debug1: print("\n All class names: "+ str(all_class_names))
        numeric_features_valuerange_dict = {}
        feature_values_how_many_uniques_dict = {}
        features_and_unique_values_dict = {}
        feature_values_for_samples_dict = {"sample_" + key :         
                  list(map(operator.add, list(map(operator.add, feature_names, "=" * len(feature_names))), 
           [str(convert(data_dict[key][i-1])) for i in self._csv_columns_for_features])) 
                           for key in data_dict}
        feature_values_for_samples_dict_bags =  { b : {sample_name :  feature_values_for_samples_dict[sample_name]
                                 for sample_name in data_sample_bags[b] } for b in range(self._how_many_bags) }
        features_and_values_dict = {all_feature_names[i-1] :
            [convert(data_dict[key][i-1]) for key in data_dict] for i in self._csv_columns_for_features}
        all_class_names = sorted(list(set(class_for_sample_dict.values())))
        if self._debug1: print("\n All class names: "+ str(all_class_names))
        features_and_values_dict_bags = { b :  { all_feature_names[i-1] :
          [convert(data_dict[key][i-1]) for  key in data_dict
                                                   if "sample_" + key in data_sample_bags[b] ]
                            for i in self._csv_columns_for_features } for b in range(self._how_many_bags) }
        numeric_features_valuerange_dict_bags = {b : {} for b in range(self._how_many_bags)}        
        feature_values_how_many_uniques_dict_bags = {b : {} for b in range(self._how_many_bags)}
        features_and_unique_values_dict_bags = {b : {} for b in range(self._how_many_bags)}
        for i in range(self._how_many_bags):
            for feature in features_and_values_dict_bags[i]:
                unique_values_for_feature = list(set(features_and_values_dict_bags[i][feature]))
                unique_values_for_feature = sorted(list(filter(lambda x: x != 'NA', unique_values_for_feature)))
                feature_values_how_many_uniques_dict_bags[i][feature] = len(unique_values_for_feature)
                if all(isinstance(x,float) for x in unique_values_for_feature):
                    numeric_features_valuerange_dict_bags[i][feature] = \
                                  [min(unique_values_for_feature), max(unique_values_for_feature)]
                    unique_values_for_feature.sort(key=float)
                features_and_unique_values_dict_bags[i][feature] = sorted(unique_values_for_feature)
        for i in range(self._how_many_bags):                
            self._all_trees[i]._class_names = all_class_names
            self._all_trees[i]._feature_names = feature_names
            self._all_trees[i]._samples_class_label_dict = class_for_sample_dict_bags[i]
            self._all_trees[i]._training_data_dict  =  feature_values_for_samples_dict_bags[i]
            self._all_trees[i]._features_and_values_dict    =  features_and_values_dict_bags[i]
            self._all_trees[i]._features_and_unique_values_dict    =  features_and_unique_values_dict_bags[i]
            self._all_trees[i]._numeric_features_valuerange_dict = numeric_features_valuerange_dict_bags[i]
            self._all_trees[i]._feature_values_how_many_uniques_dict = feature_values_how_many_uniques_dict_bags[i]
        if self._debug1:
            for i in range(self._how_many_bags):            
                print("\n\n=============================   For bag %d   ==================================\n" % i)
                print("\nAll class names: " + str(self._all_trees[i]._class_names))
                print("\nEach sample data record:")
                for item in sorted(self._all_trees[i]._training_data_dict.items(), key = lambda x: sample_index(x[0]) ):
                    print(item[0]  + "  =>  "  + str(item[1]))
                print("\nclass label for each data sample:")
                for item in sorted(self._all_trees[i]._samples_class_label_dict.items(), key=lambda x: sample_index(x[0])):
                    print(item[0]  + "  =>  "  + str(item[1]))
                print("\nfeatures and the values taken by them:")
                for item in sorted(self._all_trees[i]._features_and_values_dict.items()):
                    print(item[0]  + "  =>  "  + str(item[1]))
                print("\nnumeric features and their ranges:")
                for item in sorted(self._all_trees[i]._numeric_features_valuerange_dict.items()):
                    print(item[0]  + "  =>  "  + str(item[1]))
                print("\nunique values for the features:")
                for item in sorted(self._all_trees[i]._features_and_unique_values_dict.items()):
                    print(item[0]  + "  =>  "  + str(item[1]))
                print("\nnumber of unique values in each feature:")
                for item in sorted(self._all_trees[i]._feature_values_how_many_uniques_dict.items()):
                    print(item[0]  + "  =>  "  + str(item[1]))

    def get_number_of_training_samples(self):
        return self._number_of_training_samples

    def show_training_data_in_bags(self):
        for i in range(self._how_many_bags):
            print("\n\n=============================   For bag %d   ==================================\n" % i)
            self._all_trees[i].show_training_data()            

    def calculate_first_order_probabilities(self):            
        list(map(lambda x: self._all_trees[x].calculate_first_order_probabilities(), range(self._how_many_bags)))

    def calculate_class_priors(self):            
        list(map(lambda x: self._all_trees[x].calculate_class_priors(), range(self._how_many_bags)))
        
    def construct_decision_trees_for_bags(self):            
        self._root_nodes = \
             list(map(lambda x: self._all_trees[x].construct_decision_tree_classifier(), range(self._how_many_bags)))

    def display_decision_trees_for_bags(self):
        for i in range(self._how_many_bags):
            print("\n\n=============================   For bag %d   ==================================\n" % i)
            self._root_nodes[i].display_decision_tree("     ")

    def classify_with_bagging(self, test_sample):
        self._classifications = list(map(lambda x: self._all_trees[x].classify(self._root_nodes[x], test_sample),
                                   range(self._how_many_bags)))

    def display_classification_results_for_each_bag(self):         
        classifications = self._classifications
        if classifications is None:
            raise Exception('''You must first call "classify_with_bagging()" before invoking "display_classification_results_for_each_bag()" ''')
        solution_paths = list(map(lambda x: x['solution_path'], classifications))
        for i in range(self._how_many_bags):
            print("\n\n=============================   For bag %d   ==================================\n" % i)
            print("\nbag size: %d\n" % self._bag_sizes[i])
            classification = classifications[i]
            del classification['solution_path']
            which_classes = list( classification.keys() )
            which_classes = sorted(which_classes, key=lambda x: classification[x], reverse=True)
            print("\nClassification:\n")
            print("     "  + str.ljust("class name", 30) + "probability")
            print("     ----------                    -----------")
            for which_class in which_classes:
                if which_class is not 'solution_path':
                    print("     "  + str.ljust(which_class, 30) +  str(classification[which_class]))
            print("\nSolution path in the decision tree: " + str(solution_paths[i]))
            print("\nNumber of nodes created: " + str(self._root_nodes[i].how_many_nodes()))

    def get_majority_vote_classification(self):
        classifications = self._classifications
        if classifications is None:
            raise Exception('''You must first call "classify_with_bagging()" before invoking "get_majority_vote_classification()" ''')
        decision_classes = {class_label : 0 for class_label in self._all_trees[0]._class_names}
        for i in range(self._how_many_bags):
            classification = classifications[i]
            if 'solution_path' in classification:
                del classification['solution_path']
            sorted_classes = sorted(list(classification.keys()), key=lambda x: classification[x], reverse=True)  
            decision_classes[sorted_classes[0]] += 1
        sorted_by_votes_decision_classes = \
                         sorted(list(decision_classes.keys()), key=lambda x: decision_classes[x], reverse=True)
        return sorted_by_votes_decision_classes[0]

    def get_all_class_names(self):
        return self._all_trees[0]._class_names