#!/usr/bin/env python

__version__ = '1.7'
__author__  = "Avinash Kak (kak@purdue.edu)"
__date__    = '2012-July-29'
__url__     = 'https://engineering.purdue.edu/kak/distDT/DecisionTree-1.7.html'
__copyright__ = "(C) 2012 Avinash Kak. Python Software Foundation."


import math
import re
import sys
import functools 


#-------------------------  Utility Functions ---------------------------

def sample_index(sample_name):
    '''
    We assume that every record in the training datafile begins with
    the name of the record. The only requirement is that these names
    end in a suffix like '_23', as in 'sample_23' for the 23rd training
    record.  This function returns the integer in the suffix.
    '''
    m = re.search('_(.+)$', sample_name)
    return int(m.group(1))
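
# For example, with a hypothetical record name:
#
#     sample_index('sample_23')          =>   23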

# Meant only for an array of strings (no nesting):
def deep_copy_array(array_in):
    array_out = []
    for i in range(0, len(array_in)):
        array_out.append( array_in[i] )
    return array_out

# Returns simultaneously the minimum value and its positional index in an
# array. [Could also have used min() and index() defined for Python's
# sequence types.]
def minimum(arr):
    min_val, min_index = None, None
    for i in range(0, len(arr)):
        if min_val is None or arr[i] < min_val:
            min_index = i
            min_val = arr[i]
    return min_val, min_index
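
# For example, with a hypothetical argument:
#
#     minimum([10, 3, 7])                =>   (3, 1)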


#---------------------- DecisionTree Class Definition ---------------------

class DecisionTree(object):
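    '''
    A minimal usage sketch. The file name 'training.dat' and the
    feature names and values in the test sample are placeholders for
    illustration only:

        dt = DecisionTree( training_datafile = "training.dat" )
        dt.get_training_data()
        root_node = dt.construct_decision_tree_classifier()
        test_sample = ['exercising=>never', 'smoking=>heavy']
        classification = dt.classify(root_node, test_sample)

    The feature names and values used in a test sample must match
    those declared in the training datafile.
    '''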

    def __init__(self, *args, **kwargs ):
        if args:
            raise ValueError(  
                   '''DecisionTree constructor can only be called
                      with keyword arguments for the following
                      keywords: training_datafile, entropy_threshold,
                      max_depth_desired, debug1, and debug2''') 

        allowed_keys = 'training_datafile','entropy_threshold', \
                       'max_depth_desired','debug1','debug2'
        keywords_used = kwargs.keys()
        for keyword in keywords_used:
            if keyword not in allowed_keys:
                raise ValueError("Wrong keyword used --- check spelling") 

        training_datafile = entropy_threshold = max_depth_desired = None
        debug1 = debug2 = None

        if 'training_datafile' in kwargs : \
                           training_datafile = kwargs.pop('training_datafile')
        if 'entropy_threshold' in kwargs : \
                           entropy_threshold = kwargs.pop('entropy_threshold')
        if 'max_depth_desired' in kwargs : \
                           max_depth_desired = kwargs.pop('max_depth_desired')
        if 'debug1' in kwargs  :  debug1 = kwargs.pop('debug1')
        if 'debug2' in kwargs  :  debug2 = kwargs.pop('debug2')

        if training_datafile:
            self._training_datafile = training_datafile
        else:
            raise ValueError('''You must specify a training datafile''')
        if entropy_threshold: 
            self._entropy_threshold =  entropy_threshold
        else:
            self._entropy_threshold =  0.001        
        if max_depth_desired:
            self._max_depth_desired = max_depth_desired 
        else:
            self._max_depth_desired = None
        if debug1:
            self._debug1 = debug1
        else:
            self._debug1 = 0
        if debug2:
            self._debug2 = debug2
        else:
            self._debug2 = 0
        self._root_node = None
        self._probability_cache           = {}
        self._entropy_cache               = {}
        self._training_data_dict          = {}
        self._features_and_values_dict    = {}
        self._samples_class_label_dict    = {}
        self._class_names                 = []
        self._class_priors                = []
        self._feature_names               = []

    def get_training_data(self):
        recording_features_flag = 0
        recording_training_data = 0
        table_header = None
        column_labels_dict = {}
        FILE = None
        try:
            FILE = open( self._training_datafile )
        except IOError:
            print("unable to open %s" % self._training_datafile)
            sys.exit(1)
        for line in FILE:
            line = line.rstrip()
            lineskip = r'^[\s=#]*$'
            if re.search(lineskip, line): 
                continue
            elif re.search(r'\s*class', line, re.IGNORECASE) \
                       and not recording_training_data \
                       and not recording_features_flag:
                classpattern = r'^\s*class names:\s*([ \S]+)\s*'
                m = re.search(classpattern, line, re.IGNORECASE)
                if not m: 
                    raise ValueError('''No class names in training file''')
                self._class_names = m.group(1).split()
                continue
            elif re.search(r'\s*feature names and their values', \
                               line, re.IGNORECASE):
                recording_features_flag = 1
                continue
            elif re.search(r'training data', line, re.IGNORECASE):
                recording_training_data = 1
                recording_features_flag = 0
                continue
            elif not recording_training_data and recording_features_flag:
                feature_name_value_pattern = r'^\s*(\S+)\s*=>\s*(.+)'
                m = re.search(feature_name_value_pattern, line, re.IGNORECASE)
                if not m:
                    raise ValueError("Ill-formatted feature line in training file: " + line)
                feature_name = m.group(1)
                feature_values = m.group(2).split()
                self._features_and_values_dict[feature_name]  = feature_values
            elif recording_training_data:
                if not table_header:
                    table_header = line.split()
                    for i in range(2, len(table_header)):
                        column_labels_dict[i] = table_header[i]
                    continue
                record = line.split()
                if record[1] not in self._class_names:
                    sys.exit('''The class name in a row of training data does '''
                             '''not match the class names extracted earlier from '''
                             '''the file. You may have used commas or some other '''
                             '''punctuation to separate out the class names '''
                             '''earlier''' )                
                self._samples_class_label_dict[record[0]] = record[1]
                self._training_data_dict[record[0]] = []
                for i in range(2, len(record)):
                    feature_name_for_i = column_labels_dict[i]
                    if record[i] not in self._features_and_values_dict[feature_name_for_i]:
                        sys.exit('''The feature value for a row of training data does '''
                                 '''not correspond to the different possible values '''
                                 '''declared at the top of the training file. You may '''
                                 '''have used commas or other punctuation marks to '''
                                 '''separate out the feature values ''' )
                    self._training_data_dict[record[0]].append(
                          column_labels_dict[i] + "=>" + record[i] )
        FILE.close()                        
        # list() so that, under Python 3, _feature_names is a stable list
        # rather than a live view of the dictionary's keys:
        self._feature_names = list(self._features_and_values_dict.keys())
        empty_classes = []
        for classname in self._class_names:        
            if classname not in self._samples_class_label_dict.values():
                empty_classes.append( classname )
        if empty_classes and self._debug1:
            num_empty_classes = len(empty_classes)
            print("\nDid you know you have %d empty classes?  The decision module can ignore these classes for you." % (num_empty_classes))
            print("EMPTY CLASSES: " , empty_classes) 
            ans = None
            if sys.version_info[0] == 3:
                ans = input("\nDo you wish to continue? Enter 'y' if yes:  ")
            else:
                ans = raw_input("\nDo you wish to continue? Enter 'y' if yes:  ")
            ans = ans.strip()
            if ans != 'y':
                sys.exit(0)
        for classname in empty_classes:
            self._class_names.remove(classname)
        if self._debug1:
            print("Class names: ", self._class_names)
            print( "Feature names: ", self._feature_names)
            print("Features and values: ", self._features_and_values_dict.items())
        for feature in self._feature_names:
            values_for_feature = self._features_and_values_dict[feature]
            for value in values_for_feature:
                feature_and_value = "".join([feature, "=>", value])
                self._probability_cache[feature_and_value] = \
                       self.probability_for_feature_value(feature, value)
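
    # A sketch of the training datafile layout that get_training_data()
    # expects; all class, feature, and value names below are hypothetical:
    #
    #     Class names: malignant benign
    #
    #     Feature names and their values:
    #         smoking => never light heavy
    #         exercising => never occasionally regularly
    #
    #     Training Data:
    #
    #     sample       class        smoking      exercising
    #     ===================================================
    #     sample_0     benign       never        regularly
    #     sample_1     malignant    heavy        never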

    def show_training_data(self):
        print("Class names: ", self._class_names)
        print("\n\nFeatures and Their Possible Values:\n\n")
        features = self._features_and_values_dict.keys()
        for feature in sorted(features):
            print("%s ---> %s" \
                  % (feature, self._features_and_values_dict[feature]))
        print("\n\nSamples vs. Class Labels:\n\n")
        for item in sorted(self._samples_class_label_dict.items(), \
                key = lambda x: sample_index(x[0]) ):
            print(item)
        print("\n\nTraining Samples:\n\n")
        for item in sorted(self._training_data_dict.items(), \
                key = lambda x: sample_index(x[0]) ):
            print(item)


#------------------    Classify with Decision Tree  -----------------------

    def classify(self, root_node, features_and_values):
        if not self.check_names_used(features_and_values):
            raise ValueError("Error in the names you have used for features and/or values") 
        classification = self.recursive_descent_for_classification( \
                                    root_node, features_and_values )
        if self._debug2:
            print("\nThe classification:")
            for class_name in self._class_names:
                print("    " + class_name + " with probability " + \
                                        str(classification[class_name]))
        return classification
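
    # For example, with the hypothetical features used earlier:
    #
    #     dt.classify( root_node, ['smoking=>heavy', 'exercising=>never'] )
    #
    # returns a dictionary that maps each class name to its estimated
    # probability for the test sample.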

    def recursive_descent_for_classification(self, node, features_and_values):
        feature_test_at_node = node.get_feature()
        value_for_feature = None
        remaining_features_and_values = []
        for feature_and_value in features_and_values:
            pattern = r'(.+)=>(.+)'
            m = re.search(pattern, feature_and_value)
            feature,value = m.group(1),m.group(2)
            if feature == feature_test_at_node:
                value_for_feature = value
            else:
                remaining_features_and_values.append(feature_and_value)
        feature_value_combo = None
        if feature_test_at_node and value_for_feature is not None:
            feature_value_combo = \
                    "".join([feature_test_at_node,"=>",value_for_feature])
        children = node.get_children()
        answer = {}
        if len(children) == 0:
            leaf_node_class_probabilities = node.get_class_probabilities()
            for i in range(0, len(self._class_names)):
                answer[self._class_names[i]] = leaf_node_class_probabilities[i]
            return answer
        for child in children:
            branch_features_and_values = child.get_branch_features_and_values()
            last_feature_and_value_on_branch = branch_features_and_values[-1] 
            if last_feature_and_value_on_branch == feature_value_combo:
                answer = self.recursive_descent_for_classification(child, \
                                        remaining_features_and_values)
                break
        return answer

    def classify_by_asking_questions(self, root_node):
        classification = self.interactive_recursive_descent_for_classification(root_node)
        return classification

    def interactive_recursive_descent_for_classification(self, node):
        # Check for a leaf node first, since a leaf carries no feature
        # test (its feature can be None when all features are used up):
        children = node.get_children()
        if len(children) == 0:
            leaf_node_class_probabilities = node.get_class_probabilities()
            answer = {}
            for i in range(0, len(self._class_names)):
                answer[self._class_names[i]] = leaf_node_class_probabilities[i]
            return answer
        feature_test_at_node = node.get_feature()
        possible_values_for_feature = \
                   self._features_and_values_dict[feature_test_at_node]
        value_for_feature = None
        while True:
            value_for_feature = None
            if sys.version_info[0] == 3:
                value_for_feature = \
                   input( "\nWhat is the value for the feature '" + \
               feature_test_at_node + "'?" + "\n" +    \
               "Enter one of: " + str(possible_values_for_feature) + " => " )
            else:
                value_for_feature = \
                   raw_input( "\nWhat is the value for the feature '" + \
               feature_test_at_node + "'?" + "\n" +    \
               "Enter one of: " + str(possible_values_for_feature) + " => " )
            value_for_feature = value_for_feature.strip()
            answer_found = 0
            for value in possible_values_for_feature:
                if value == value_for_feature: 
                    answer_found = 1
                    break
            if answer_found == 1: break
            print("\n")
            print("You entered illegal value. Let's try again")
            print("\n")
        feature_value_combo = \
                "".join([feature_test_at_node,"=>",value_for_feature])
        answer = {}
        for child in children:
            branch_features_and_values = child.get_branch_features_and_values()
            last_feature_and_value_on_branch = branch_features_and_values[-1] 
            if last_feature_and_value_on_branch == feature_value_combo:
                answer = self.interactive_recursive_descent_for_classification(child)
                break
        return answer

#----------------------  Construct Decision Tree  -------------------------- 

    def construct_decision_tree_classifier(self):
        if self._debug2:        
            self.determine_data_condition() 
            print("\nStarting construction of the decision tree:\n") 
        class_probabilities = \
          list(map(lambda x: self.prior_probability_for_class(x), \
                                                   self._class_names))
        entropy = self.class_entropy_on_priors()
        root_node = Node(None, entropy, class_probabilities, [])
        self._root_node = root_node
        self.recursive_descent(root_node)
        return root_node        

    def recursive_descent(self, node):
        features_and_values_on_branch = node.get_branch_features_and_values()
        best_feature, best_feature_entropy =  \
         self.best_feature_calculator(features_and_values_on_branch)
        node.set_feature(best_feature)
        if self._debug2: node.display_node() 
        if self._max_depth_desired is not None and \
         len(features_and_values_on_branch) >= self._max_depth_desired:
            return
        if best_feature is None: return
        if best_feature_entropy \
                   < node.get_entropy() - self._entropy_threshold:
            values_for_feature = \
                  self._features_and_values_dict[best_feature]
            feature_value_combos = \
              map(lambda x: "".join([best_feature,"=>",x]), values_for_feature)
            for feature_and_value in feature_value_combos:
                extended_branch_features_and_values = None
                if features_and_values_on_branch is None:
                    extended_branch_features_and_values = feature_and_value
                else:
                    extended_branch_features_and_values = \
                        deep_copy_array( features_and_values_on_branch )
                    extended_branch_features_and_values.append(\
                                                      feature_and_value)
                class_probabilities = list(map(lambda x: \
         self.probability_for_a_class_given_sequence_of_features_and_values(\
                x, extended_branch_features_and_values), self._class_names))
                child_node = Node(None, best_feature_entropy, \
                     class_probabilities, extended_branch_features_and_values)
                node.add_child_link( child_node )
                self.recursive_descent(child_node)

    def best_feature_calculator(self, features_and_values_on_branch):
        features_already_used = []
        for feature_and_value in features_and_values_on_branch:
            pattern = r'(.+)=>(.+)'
            m = re.search(pattern, feature_and_value)
            feature = m.group(1)
            features_already_used.append(feature)
        feature_tests_not_yet_used = []
        for feature in self._feature_names:
            if (feature not in features_already_used):
                feature_tests_not_yet_used.append(feature)
        if len(feature_tests_not_yet_used) == 0: return None, None
        array_of_entropy_values_for_different_features = []
        for i in range(0, len(feature_tests_not_yet_used)):
            values = \
             self._features_and_values_dict[feature_tests_not_yet_used[i]]
            entropy_for_new_feature = None
            for value in values:
                feature_and_value_string = \
                   "".join([feature_tests_not_yet_used[i], "=>", value]) 
                extended_features_and_values_on_branch = None
                if len(features_and_values_on_branch) > 0:
                    extended_features_and_values_on_branch =  \
                          deep_copy_array(features_and_values_on_branch)
                    extended_features_and_values_on_branch.append(  \
                                              feature_and_value_string) 
                else:
                    extended_features_and_values_on_branch  =    \
                        [feature_and_value_string]
                if entropy_for_new_feature is None:
                    entropy_for_new_feature =  \
                   self.class_entropy_for_a_given_sequence_of_features_values(\
                             extended_features_and_values_on_branch) \
                     * \
                     self.probability_of_a_sequence_of_features_and_values( \
                         extended_features_and_values_on_branch)
                    continue
                else:
                    entropy_for_new_feature += \
                  self.class_entropy_for_a_given_sequence_of_features_values(\
                         extended_features_and_values_on_branch) \
                     *  \
                     self.probability_of_a_sequence_of_features_and_values( \
                         extended_features_and_values_on_branch)
            array_of_entropy_values_for_different_features.append(\
                                         entropy_for_new_feature)
        best_entropy, best_index = \
                   minimum(array_of_entropy_values_for_different_features)
        return feature_tests_not_yet_used[best_index], best_entropy
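
    # What best_feature_calculator() minimizes is the class entropy
    # averaged over the values v of each yet-unused feature F, with each
    # term weighted by the probability of the extended branch:
    #
    #     sum over v of  P(branch, F=v) * H(class | branch, F=v)
    #
    # where H() is the Shannon entropy computed by the entropy
    # calculators in the next section.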


#--------------------------  Entropy Calculators  --------------------------

    def class_entropy_on_priors(self):
        if 'priors' in self._entropy_cache:
            return self._entropy_cache['priors']
        entropy = None
        for class_name in self._class_names:
            prob = self.prior_probability_for_class(class_name)
            if (prob >= 0.0001) and (prob <= 0.999):
                log_prob = math.log(prob,2)
            if prob < 0.0001:
                log_prob = 0 
            if prob > 0.999:
                log_prob = 0 
            if entropy is None:
                entropy = -1.0 * prob * log_prob
                continue
            entropy += -1.0 * prob * log_prob
        self._entropy_cache['priors'] = entropy
        return entropy
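
    # The value returned above is the Shannon entropy over the class
    # priors:
    #
    #     H  =  - sum over classes c of  p(c) * log2 p(c)
    #
    # with log2 p(c) clamped to 0 when p(c) < 0.0001 or p(c) > 0.999 to
    # sidestep numerical trouble at the extremes.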

    def class_entropy_for_a_given_sequence_of_features_values(self, \
                                       array_of_features_and_values):
        sequence = ":".join(array_of_features_and_values)
        if sequence in self._entropy_cache:
            return self._entropy_cache[sequence]
        entropy = None    
        for class_name in self._class_names:
            prob = \
           self.probability_for_a_class_given_sequence_of_features_and_values(\
                 class_name, array_of_features_and_values)
            if prob == 0:
                prob = 1.0/len(self._class_names)
            if (prob >= 0.0001) and (prob <= 0.999):
                log_prob = math.log(prob,2)
            if prob < 0.0001:
                log_prob = 0 
            if prob > 0.999:
                log_prob = 0 
            if entropy is None:
                entropy = -1.0 * prob * log_prob
                continue
            entropy += -1.0 * prob * log_prob
        self._entropy_cache[sequence] = entropy
        return entropy


#-------------------------  Probability Calculators ------------------------

    def prior_probability_for_class(self, class_name):
        class_name_in_cache = "".join(["prior::", class_name])
        if class_name_in_cache in self._probability_cache:
            return self._probability_cache[class_name_in_cache]
        total_num_of_samples = len( self._samples_class_label_dict )
        all_values = self._samples_class_label_dict.values()
        for this_class_name in self._class_names:
            trues = list(filter( lambda x: x == this_class_name, all_values ))
            prior_for_this_class = (1.0 * len(trues)) / total_num_of_samples
            this_class_name_in_cache = "".join(["prior::", this_class_name])
            self._probability_cache[this_class_name_in_cache] = \
                                                    prior_for_this_class
        return self._probability_cache[class_name_in_cache]

    def probability_for_feature_value(self, feature, value):
        feature_and_value = "".join([feature, "=>", value])
        if feature_and_value in self._probability_cache:
            return self._probability_cache[feature_and_value]
        values_for_feature = self._features_and_values_dict[feature]
        values_for_feature = list(map(lambda x: feature + "=>" + x, \
                                                   values_for_feature))
        value_counts = [0] * len(values_for_feature)
        for sample in sorted(self._training_data_dict.keys(), \
                key = lambda x: sample_index(x) ):
            features_and_values = self._training_data_dict[sample]
            for i in range(0, len(values_for_feature)):
                for current_value in features_and_values:
                    if values_for_feature[i] == current_value:
                        value_counts[i] += 1 
        for i in range(0, len(values_for_feature)):
            self._probability_cache[values_for_feature[i]] = \
                      value_counts[i] / (1.0 * len(self._training_data_dict))
        if feature_and_value in self._probability_cache:
            return self._probability_cache[feature_and_value]
        else:
            return 0

    def probability_for_feature_value_given_class(self, feature_name, \
                                        feature_value, class_name):
        feature_value_class = \
             "".join([feature_name,"=>",feature_value,"::",class_name])
        if feature_value_class in self._probability_cache:
            return self._probability_cache[feature_value_class]
        samples_for_class = []
        for sample_name in self._samples_class_label_dict.keys():
            if self._samples_class_label_dict[sample_name] == class_name:
                samples_for_class.append(sample_name) 
        values_for_feature = self._features_and_values_dict[feature_name]
        values_for_feature = \
        list(map(lambda x: "".join([feature_name,"=>",x]), values_for_feature))
        value_counts = [0] * len(values_for_feature)
        for sample in samples_for_class:
            features_and_values = self._training_data_dict[sample]
            for i in range(0, len(values_for_feature)):
                for current_value in (features_and_values):
                    if values_for_feature[i] == current_value:
                        value_counts[i] += 1 
        total_count = sum(value_counts)
        for i in range(0, len(values_for_feature)):
            feature_and_value_for_class = \
                     "".join([values_for_feature[i],"::",class_name])
            self._probability_cache[feature_and_value_for_class] = \
                                       value_counts[i] / (1.0 * total_count)
        feature_and_value_and_class = \
              "".join([feature_name, "=>", feature_value,"::",class_name])
        if feature_and_value_and_class in self._probability_cache:
            return self._probability_cache[feature_and_value_and_class]
        else:
            return 0

    def probability_of_a_sequence_of_features_and_values(self, \
                                        array_of_features_and_values):
        sequence = ":".join(array_of_features_and_values)
        if sequence in self._probability_cache:
            return self._probability_cache[sequence]
        probability = None
        for feature_and_value in array_of_features_and_values:
            pattern = r'(.+)=>(.+)'
            m = re.search(pattern, feature_and_value)
            feature,value = m.group(1),m.group(2)
            if probability is None:
                probability = \
                   self.probability_for_feature_value(feature, value)
                continue
            else:
                probability *= \
                  self.probability_for_feature_value(feature, value)
        self._probability_cache[sequence] = probability
        return probability
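
    # Note that the running product above amounts to a naive-Bayes style
    # independence assumption for the features on a branch:
    #
    #     P(f1=v1, f2=v2, ...)  =  P(f1=v1) * P(f2=v2) * ...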

    def probability_for_sequence_of_features_and_values_given_class(self, \
                            array_of_features_and_values, class_name):
        sequence = ":".join(array_of_features_and_values)
        sequence_with_class = "".join([sequence, "::", class_name])
        if sequence_with_class in self._probability_cache:
            return self._probability_cache[sequence_with_class]
        probability = None
        for feature_and_value in array_of_features_and_values:
            pattern = r'(.+)=>(.+)'
            m = re.search(pattern, feature_and_value)
            feature,value = m.group(1),m.group(2)
            if probability is None:
                probability = self.probability_for_feature_value_given_class(\
                                                 feature, value, class_name)
                continue
            else:
                probability *= self.probability_for_feature_value_given_class(\
                                           feature, value, class_name)
        self._probability_cache[sequence_with_class] = probability
        return probability 

    def probability_for_a_class_given_feature_value(self, class_name, \
                                              feature_name, feature_value):
        prob = self.probability_for_feature_value_given_class( \
                                 feature_name, feature_value, class_name)
        answer = (prob * self.prior_probability_for_class(class_name)) \
                 /                                                     \
                 self.probability_for_feature_value(feature_name,feature_value)
        return answer
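
    # The calculation above is Bayes' rule:
    #
    #     P(class | f=v)  =  P(f=v | class) * P(class) / P(f=v)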

    def probability_for_a_class_given_sequence_of_features_and_values(self, \
                    class_name, array_of_features_and_values):
        sequence = ":".join(array_of_features_and_values)
        class_and_sequence = "".join([class_name, "::", sequence])
        if class_and_sequence in self._probability_cache:
            return self._probability_cache[class_and_sequence]
        array_of_class_probabilities = [0] * len(self._class_names)
        for i in range(0, len(self._class_names)):
            prob = \
            self.probability_for_sequence_of_features_and_values_given_class(\
                   array_of_features_and_values, self._class_names[i]) 
            if prob == 0:
                array_of_class_probabilities[i] = 0 
                continue
            prob_of_feature_sequence = \
                self.probability_of_a_sequence_of_features_and_values(  \
                                              array_of_features_and_values)
            prior = self.prior_probability_for_class(self._class_names[i])
            array_of_class_probabilities[i] = \
                            prob * prior / prob_of_feature_sequence
        sum_probability = sum(array_of_class_probabilities)
        if sum_probability == 0:
            array_of_class_probabilities = [1.0/len(self._class_names)] \
                                              * len(self._class_names)
        else:
            array_of_class_probabilities = \
                     list(map(lambda x: x / sum_probability,\
                               array_of_class_probabilities))
        for i in range(0, len(self._class_names)):
            this_class_and_sequence = \
                     "".join([self._class_names[i], "::", sequence])
            self._probability_cache[this_class_and_sequence] = \
                                     array_of_class_probabilities[i]
        return self._probability_cache[class_and_sequence]


#---------------------  Class Based Utilities  ---------------------

    def determine_data_condition(self):
        num_of_features = len(self._feature_names)
        values = list(self._features_and_values_dict.values())
        print("Number of features: ", num_of_features)
        max_num_values = 0
        for i in range(0, len(values)):
            if ((not max_num_values) or (len(values[i]) > max_num_values)):
                max_num_values = len(values[i])
        print("Largest number of feature values is: ", max_num_values)
        estimated_number_of_nodes = max_num_values ** num_of_features
        print("\nWORST CASE SCENARIO WITHOUT TAKING INTO ACCOUNT YOUR SETTING FOR \
ENTROPY_THRESHOD: The decision tree COULD have as many as %d nodes. The exact number of \
nodes created depends critically on the entropy_threshold used for node expansion. \
(The default for this threshold is 0.001.)" % (estimated_number_of_nodes))
        if estimated_number_of_nodes > 10000:
            print("\nTHIS IS WAY TOO MANY NODES. Consider using a relatively large \
value for entropy_threshold to reduce the number of nodes created.\n")
            ans = None
            if sys.version_info[0] == 3:
                ans = input("\nDo you wish to continue? Enter 'y' if yes:  ")
            else:
                ans = raw_input("\nDo you wish to continue? Enter 'y' if yes:  ")
            ans = ans.strip()
            if ans != 'y':
                sys.exit(0)
        print("\nHere are the probabilities of feature-value pairs in your data:\n\n")
        for feature in self._feature_names:
            values_for_feature = self._features_and_values_dict[feature]
            for value in values_for_feature:
                prob = self.probability_for_feature_value(feature,value) 
                print("Probability of feature-value pair (%s,%s): %.3f" % \
                                                (feature,value,prob)) 

    def check_names_used(self, features_and_values_test_data):
        for feature_and_value in features_and_values_test_data:
            pattern = r'(.+)=>(.+)'
            m = re.search(pattern, feature_and_value)
            if m is None:
                raise ValueError("Your test data has a formatting error")
            feature,value = m.group(1),m.group(2)
            if feature not in self._feature_names:
                return 0
            if value not in self._features_and_values_dict[feature]:
                return 0
        return 1

    def get_class_names(self):
        return self._class_names


#----------------  Generate Your Own Training Data  ----------------

class TrainingDataGenerator(object):
    def __init__(self, *args, **kwargs ):
        if args:
            raise ValueError(  
                   '''TrainingDataGenerator can only be called
                      with keyword arguments for the following
                      keywords: output_datafile, parameter_file,
                      number_of_training_samples, write_to_file,
                      debug1, and debug2''') 
        allowed_keys = 'output_datafile','parameter_file', \
                       'number_of_training_samples', 'write_to_file', \
                       'debug1','debug2'
        keywords_used = kwargs.keys()
        for keyword in keywords_used:
            if keyword not in allowed_keys:
                raise ValueError("Wrong keyword used --- check spelling") 

        output_datafile = parameter_file = number_of_training_samples = None
        write_to_file = debug1 = debug2 = None

        if 'output_datafile' in kwargs : \
                           output_datafile = kwargs.pop('output_datafile')
        if 'parameter_file' in kwargs : \
                           parameter_file = kwargs.pop('parameter_file')
        if 'number_of_training_samples' in kwargs : \
          number_of_training_samples = kwargs.pop('number_of_training_samples')
        if 'write_to_file' in kwargs : \
                                   write_to_file = kwargs.pop('write_to_file')
        if 'debug1' in kwargs  :  debug1 = kwargs.pop('debug1')
        if 'debug2' in kwargs  :  debug2 = kwargs.pop('debug2')

        if output_datafile:
            self._output_datafile = output_datafile
        else:
            raise ValueError('''You must specify an output datafile''')
        if parameter_file: 
            self._parameter_file =  parameter_file
        else:
            raise ValueError('''You must specify a parameter file''')
        if number_of_training_samples:
            self._number_of_training_samples = number_of_training_samples
        else:
            raise ValueError('''You forgot to specify the number of training samples needed''')
        if write_to_file:
            self._write_to_file = write_to_file
        else:
            self._write_to_file = 0          
        if debug1:
            self._debug1 = debug1
        else:
            self._debug1 = 0
        if debug2:
            self._debug2 = debug2
        else:
            self._debug2 = 0
        self._training_sample_records     = {}
        self._features_and_values_dict    = {}
        self._bias_dict                   = {}
        self._class_names                 = []
        self._class_priors                = []

    # Read the parameter file for generating the TRAINING data
    def read_parameter_file( self ):
        debug1 = self._debug1
        debug2 = self._debug2
        write_to_file = self._write_to_file
        number_of_training_samples = self._number_of_training_samples
        input_parameter_file = self._parameter_file
        all_params = []
        param_string = ''
        try:
            FILE = open(input_parameter_file, 'r')
        except IOError:
            print("unable to open %s" % input_parameter_file)
            sys.exit(1)
        all_params = FILE.read()
        all_params = re.split(r'\n', all_params)
        FILE.close()
        pattern = r'^(?![ ]*#)'
        try:
            regex = re.compile( pattern )
        except re.error:
            print("error in your pattern")
            sys.exit(1)
        all_params = list( filter( regex.search, all_params ) )
        all_params = list( filter( None, all_params ) )
        all_params = [x.rstrip('\n') for x in all_params]
        param_string = ' '.join( all_params )
        pattern = r'^\s*class names:(.*?)\s*class priors:(.*?)(feature: .*)'
        m = re.search( pattern, param_string )
        rest_params = m.group(3)
        self._class_names = list( filter(None, re.split(r'\s+', m.group(1))) )
        self._class_priors = list( filter(None, re.split(r'\s+', m.group(2))) )
        pattern = r'(feature:.*?) (bias:.*)'
        m = re.search( pattern, rest_params  )
        feature_string = m.group(1)
        bias_string = m.group(2)
        features_and_values_dict = {}
        features = list( filter( None, re.split( r'(feature[:])', feature_string ) ) )
        for item in features:
            if re.match(r'feature', item): continue
            splits = list( filter(None, re.split(r' ', item)) )
            for i in range(0, len(splits)):
                if i == 0: features_and_values_dict[splits[0]] = []
                else:
                    if re.match( r'values', splits[i] ): continue
                    features_and_values_dict[splits[0]].append( splits[i] )
        self._features_and_values_dict = features_and_values_dict
        bias_dict = {}
        biases = list( filter(None, re.split(r'(bias[:]\s*class[:])', bias_string )) )
        for item in biases:
            if re.match(r'bias', item): continue
            splits = list( filter(None, re.split(r' ', item)) )
            feature_name = ''
            for i in range(0, len(splits)):
                if i == 0:
                    bias_dict[splits[0]] = {}
                elif ( re.search( r'(^.+)[:]$', splits[i] ) ):
                    m = re.search(  r'(^.+)[:]$', splits[i] )
                    feature_name = m.group(1)
                    bias_dict[splits[0]][feature_name] = []
                else:
                    if not feature_name: continue
                    bias_dict[splits[0]][feature_name].append( splits[i] )
        self._bias_dict = bias_dict
        if self._debug1:
            print("\n\n") 
            print("Class names: " + str(self._class_names))
            print("\n") 
            num_of_classes = len(self._class_names)
            print("Number of classes: " + str(num_of_classes))
            print("\n")
            print("Class priors: " + str(self._class_priors))
            print("\n\n")
            print("Here are the features and their possible valuesn")
            print("\n")
            items = self._features_and_values_dict.items()
            for item in items:
                print(item[0] + " ===> " + str(item[1]))
            print("\n")
            print("Here is the biasing for each class:")
            print("\n")          
            items = self._bias_dict.items()
            for item in items:
                print("\n")
                print(item[0])
                items2 = list( item[1].items() )
                for i in range(0, len(items2)):
                    print( items2[i])
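
    # A sketch of the parameter file layout implied by the parsing in
    # read_parameter_file(); all names and numbers below are hypothetical:
    #
    #     class names: malignant benign
    #     class priors: 0.4 0.6
    #     feature: smoking
    #     values: never light heavy
    #     feature: exercising
    #     values: never occasionally regularly
    #     bias:  class: malignant
    #            smoking: heavy=0.7
    #            exercising: never=0.8
    #     bias:  class: benign
    #            smoking: never=0.8
    #            exercising: regularly=0.7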

    def gen_training_data( self ):
        class_names = self._class_names
        class_priors = self._class_priors
        training_sample_records = {}
        features_and_values_dict = self._features_and_values_dict
        bias_dict  = self._bias_dict
        how_many_training_samples = self._number_of_training_samples
        class_priors_to_unit_interval_map = {}
        accumulated_interval = 0
        for i in range(0, len(class_names)):
            class_priors_to_unit_interval_map[class_names[i]] = \
            (accumulated_interval, accumulated_interval+float(class_priors[i]))
            accumulated_interval += float(class_priors[i])
        if self._debug1:
            print("Mapping of class priors to unit interval:")
            print("\n")
            items = class_priors_to_unit_interval_map.items()
            for item in items:
                print(item[0] + " ===> " + str(item[1]))
        class_and_feature_based_value_priors_to_unit_interval_map = {}
        for class_name  in class_names:
            class_and_feature_based_value_priors_to_unit_interval_map[class_name] = {}
            for feature in features_and_values_dict.keys():
                class_and_feature_based_value_priors_to_unit_interval_map[class_name][feature] = {}
        for class_name  in class_names:
            for feature in features_and_values_dict.keys():
                values = features_and_values_dict[feature]
                if len(bias_dict[class_name][feature]) > 0:
                    bias_string = bias_dict[class_name][feature][0]
                else:
                    no_bias = 1.0 / len(values)
                    bias_string = values[0] +  "=" + str(no_bias)
                value_priors_to_unit_interval_map = {}
                splits = list( filter( None, re.split(r'\s*=\s*', bias_string) ) )
                chosen_for_bias_value = splits[0]
                chosen_bias = splits[1]
                remaining_bias = 1 - float(chosen_bias)
                remaining_portion_bias = remaining_bias / (len(values) -1)
                accumulated = 0
                for i in range(0, len(values)):
                    if (values[i] == chosen_for_bias_value):
                        value_priors_to_unit_interval_map[values[i]] = \
                          [accumulated, accumulated + float(chosen_bias)]
                        accumulated += float(chosen_bias)
                    else:
                        value_priors_to_unit_interval_map[values[i]] = \
                          [accumulated, accumulated + remaining_portion_bias]
                        accumulated += remaining_portion_bias
                class_and_feature_based_value_priors_to_unit_interval_map[class_name][feature] = value_priors_to_unit_interval_map
                if self._debug2:
                    print("\n")
                    print( "For class " + class_name + \
                       ": Mapping feature value priors for feature '" + \
                       feature + "' to unit interval: ")
                    print("\n")
                    items = value_priors_to_unit_interval_map.items()
                    for item in items:
                        print("    " + item[0] + " ===> " + str(item[1]))
        import random
        ran = random.Random()
        ele_index = 0
        while (ele_index < how_many_training_samples):
            sample_name = "sample" + "_" + str(ele_index)
            training_sample_records[sample_name] = []
            # Generate the class label for this training sample:
            roll_the_dice  = ran.randint(0,1000) / 1000.0
            class_label = ''
            for class_name  in class_priors_to_unit_interval_map.keys():
                v = class_priors_to_unit_interval_map[class_name]
                if ( (roll_the_dice >= v[0]) and (roll_the_dice <= v[1]) ):
                    training_sample_records[sample_name].append( 
                                             "class=" + class_name )
                    class_label = class_name
                    break
            for feature in sorted(list(features_and_values_dict.keys())):
                roll_the_dice  = ran.randint(0,1000) / 1000.0
                value_label = ''
                value_priors_to_unit_interval_map = \
                  class_and_feature_based_value_priors_to_unit_interval_map[class_label][feature]
                for value_name in value_priors_to_unit_interval_map.keys():
                    v = value_priors_to_unit_interval_map[value_name]
                    if ( (roll_the_dice >= v[0]) and (roll_the_dice <= v[1]) ):
                        training_sample_records[sample_name].append( \
                                            feature + "=" + value_name )
                        value_label = value_name
                        break
            ele_index += 1
        self._training_sample_records = training_sample_records
        if self._debug2:
            print("\n\n")
            print("TERMINAL DISPLAY OF TRAINING RECORDS:")
            print("\n\n")
            sample_names = training_sample_records.keys()
            sample_names = sorted( sample_names, key=lambda x: sample_index(x) )
            for sample_name in sample_names:
                print(sample_name + " => " + \
                             str(training_sample_records[sample_name]))
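
    # The sampling logic above maps the class priors (and, per class, the
    # feature-value priors) onto contiguous segments of the unit interval
    # and draws a uniform random number to choose a segment.  With
    # hypothetical priors of 0.4 and 0.6 for two classes, the segments
    # are (0, 0.4) and (0.4, 1.0); a roll of 0.55 falls in the second
    # segment and selects the second class.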

    def find_longest_feature_or_value(self):
        features_and_values_dict = self._features_and_values_dict
        max_length = 0
        for feature in features_and_values_dict.keys():
            if not max_length:
                max_length = len(str(feature))
            if len(str(feature)) > max_length:
                max_length = len(str(feature)) 
            values = features_and_values_dict[feature]
            for value in values:
                if len(str(value)) > max_length:
                    max_length = len(str(value)) 
        return max_length

    def write_training_data_to_file( self ):
        features_and_values_dict = self._features_and_values_dict
        class_names = self._class_names
        output_file = self._output_datafile
        training_sample_records = self._training_sample_records
        try:
            FILE = open(self._output_datafile, 'w') 
        except IOError:
            print("Unable to open file: " + self._output_datafile)
            sys.exit(1)
        class_names_string = ''
        for aname in class_names:
            class_names_string += aname + " "
        class_names_string = class_names_string.rstrip()
        FILE.write("Class names: %s\n\n" % class_names_string ) 
        FILE.write("Feature names and their values:\n")
        features = list( features_and_values_dict.keys() )
        if len(features) == 0:
            print("You probably forgot to call gen_training_data() before " + \
                          "calling write_training_data_to_file()") 
            sys.exit(1)
        for i in range(0, len(features)):
            values = features_and_values_dict[features[i]]
            values_string = ''
            for aname in values:
                values_string += aname + " "
            values_string = values_string.rstrip()
            FILE.write("     %(s1)s => %(s2)s\n" % {'s1':features[i], 's2':values_string} )
        FILE.write("\n\nTraining Data:\n\n")
        num_of_columns = len(features) + 2
        field_width = self.find_longest_feature_or_value() + 2
        if field_width < 12: field_width = 12
        title_string = str.ljust( "sample", field_width ) + \
                       str.ljust( "class", field_width )
        features.sort()
        for feature_name in features:
            title_string += str.ljust( str(feature_name), field_width )
        FILE.write(title_string + "\n")
        separator = '=' * len(title_string)
        FILE.write(separator + "\n")
        sample_names = list( training_sample_records.keys() )
        sample_names = sorted( sample_names, key=lambda x: sample_index(x) )
        record_string = ''
        for sample_name in sample_names:
            sample_name_string = str.ljust(sample_name, field_width)
            record_string += sample_name_string
            record = training_sample_records[sample_name]
            item_parts_dict = {}
            for item in record:
                splits = list( filter(None, re.split(r'=', item)) )
                item_parts_dict[splits[0]] = splits[1]
            record_string += str.ljust(item_parts_dict["class"], field_width)
            del item_parts_dict["class"]
            kees = list(item_parts_dict.keys())
            kees.sort()
            for kee in kees:
                record_string += str.ljust(item_parts_dict[kee], field_width)
            FILE.write(record_string + "\n")
            record_string = ''
        FILE.close()


#------------------------  Generate Your Own Test Data ---------------------

class TestDataGenerator(object):
    def __init__(self, *args, **kwargs ):
        if args:
            raise ValueError(  
                   '''TestDataGenerator can only be called
                      with keyword arguments for the following
                      keywords: parameter_file, output_test_datafile,
                      output_class_labels_file, number_of_test_samples, 
                      write_to_file, debug1, and debug2''') 
        allowed_keys = 'output_test_datafile','parameter_file', \
                       'number_of_test_samples', 'write_to_file', \
                       'output_class_labels_file', 'debug1', 'debug2'
        keywords_used = kwargs.keys()
        for keyword in keywords_used:
            if keyword not in allowed_keys:
                raise ValueError("Wrong keyword used --- check spelling") 
        output_test_datafile = parameter_file = number_of_test_samples = None
        output_class_labels_file = write_to_file = debug1 = debug2 = None
        if 'output_test_datafile' in kwargs : \
                    output_test_datafile = kwargs.pop('output_test_datafile')
        if 'output_class_labels_file' in kwargs : \
             output_class_labels_file =  kwargs.pop('output_class_labels_file')
        if 'parameter_file' in kwargs : \
                           parameter_file = kwargs.pop('parameter_file')
        if 'number_of_test_samples' in kwargs : \
          number_of_test_samples = kwargs.pop('number_of_test_samples')
        if 'write_to_file' in kwargs : \
                                   write_to_file = kwargs.pop('write_to_file')
        if 'debug1' in kwargs  :  debug1 = kwargs.pop('debug1')
        if 'debug2' in kwargs  :  debug2 = kwargs.pop('debug2')
        if output_test_datafile:
            self._output_test_datafile = output_test_datafile
        else:
            raise ValueError('''You must specify an output test datafile''')
        if output_class_labels_file:
            self._output_class_labels_file = output_class_labels_file
        else:
            raise ValueError('''You must specify an output file for class labels''')
        if parameter_file: 
            self._parameter_file =  parameter_file
        else:
            raise ValueError('''You must specify a parameter file''')
        if number_of_test_samples:
            self._number_of_test_samples = number_of_test_samples
        else:
            raise ValueError('''You forgot to specify the number of test samples needed''')
        if write_to_file:
            self._write_to_file = write_to_file
        else:
            self._write_to_file = 0          
        if debug1: self._debug1 = debug1
        else: self._debug1 = 0
        if debug2: self._debug2 = debug2
        else: self._debug2 = 0
        self._test_sample_records         = {}
        self._features_and_values_dict    = {}
        self._bias_dict                   = {}
        self._class_names                 = []
        self._class_priors                = []

    # Read the parameter file for generating the TEST data
    def read_parameter_file( self ):
        debug1 = self._debug1
        debug2 = self._debug2
        write_to_file = self._write_to_file
        number_of_test_samples = self._number_of_test_samples
        input_parameter_file = self._parameter_file
        all_params = []
        param_string = ''
        try:
            FILE = open(input_parameter_file, 'r') 
        except IOError:
            print("Unable to open file: " + input_parameter_file)
            sys.exit(1)
        all_params = FILE.read()
        all_params = re.split(r'\n', all_params)
        FILE.close()
        pattern = r'^(?![ ]*#)'
        try:
            regex = re.compile( pattern )
        except re.error:
            print("error in your pattern")
            sys.exit(1)
        all_params = list( filter( regex.search, all_params ) )
        all_params = list( filter( None, all_params ) )
        all_params = [x.rstrip('\n') for x in all_params]
        param_string = ' '.join( all_params )
        pattern = r'^\s*class names:(.*?)\s*class priors:(.*?)(feature: .*)'
        m = re.search( pattern, param_string )
        rest_params = m.group(3)
        self._class_names = list( filter(None, re.split(r'\s+', m.group(1))) )
        self._class_priors = list( filter(None, re.split(r'\s+', m.group(2))) )
        pattern = r'(feature:.*?) (bias:.*)'
        m = re.search( pattern, rest_params  )
        feature_string = m.group(1)
        bias_string = m.group(2)
        features_and_values_dict = {}
        features = list( filter( None, re.split( r'(feature[:])', feature_string ) ) )
        for item in features:
            if re.match(r'feature', item): continue
            splits = list( filter(None, re.split(r' ', item)) )
            for i in range(0, len(splits)):
                if i == 0: features_and_values_dict[splits[0]] = []
                else:
                    if re.match( r'values', splits[i] ): continue
                    features_and_values_dict[splits[0]].append( splits[i] )
        self._features_and_values_dict = features_and_values_dict
        bias_dict = {}
        biases = list( filter(None, re.split(r'(bias[:]\s*class[:])', bias_string )) )
        for item in biases:
            if re.match(r'bias', item): continue
            splits = list( filter(None, re.split(r' ', item)) )
            feature_name = ''
            for i in range(0, len(splits)):
                if i == 0:
                    bias_dict[splits[0]] = {}
                elif ( re.search( r'(^.+)[:]$', splits[i] ) ):
                    m = re.search(  r'(^.+)[:]$', splits[i] )
                    feature_name = m.group(1)
                    bias_dict[splits[0]][feature_name] = []
                else:
                    if not feature_name: continue
                    bias_dict[splits[0]][feature_name].append( splits[i] )
        self._bias_dict = bias_dict
        if self._debug1:
            print("\n\n")
            print("Class names: " + str(self._class_names))
            print("\n")
            num_of_classes = len(self._class_names)
            print("Number of classes: " + str(num_of_classes))
            print("\n")
            print("Class priors: " + str(self._class_priors))
            print("\n\n")
            print("Here are the features and their possible values:")
            print("\n")
            items = self._features_and_values_dict.items()
            for item in items:
                print(item[0] + " ===> " + str(item[1]))
            print("\n")
            print("Here is the biasing for each class:")
            print("\n")            
            items = self._bias_dict.items()
            for item in items:
                print("\n")
                print(item[0])
                items2 = list( item[1].items() )
                for i in range(0, len(items2)):
                    print( items2[i])

    def gen_test_data( self ):
        class_names = self._class_names
        class_priors = self._class_priors
        test_sample_records = {}
        features_and_values_dict = self._features_and_values_dict
        bias_dict  = self._bias_dict
        how_many_test_samples = self._number_of_test_samples
        file_for_class_labels = self._output_class_labels_file
        class_priors_to_unit_interval_map = {}
        accumulated_interval = 0
        for i in range(0, len(class_names)):
            class_priors_to_unit_interval_map[class_names[i]] = \
            (accumulated_interval, accumulated_interval+float(class_priors[i]))
            accumulated_interval += float(class_priors[i])
        if self._debug1:
            print("Mapping of class priors to unit interval:")
            print("\n")
            items = class_priors_to_unit_interval_map.items()
            for item in items:
                print(item[0] + " ===> " + str(item[1]))
        class_and_feature_based_value_priors_to_unit_interval_map = {}
        for class_name  in class_names:
            class_and_feature_based_value_priors_to_unit_interval_map[class_name] = {}
            for feature in features_and_values_dict.keys():
                class_and_feature_based_value_priors_to_unit_interval_map[class_name][feature] = {}
        for class_name  in class_names:
            for feature in features_and_values_dict.keys():
                values = features_and_values_dict[feature]
                if len(bias_dict[class_name][feature]) > 0:
                    bias_string = bias_dict[class_name][feature][0]
                else:
                    no_bias = 1.0 / len(values)
                    bias_string = values[0] +  "=" + str(no_bias)
                value_priors_to_unit_interval_map = {}
                splits = list( filter( None, re.split(r'\s*=\s*', bias_string) ) )
                chosen_for_bias_value = splits[0]
                chosen_bias = splits[1]
                remaining_bias = 1 - float(chosen_bias)
                remaining_portion_bias = remaining_bias / (len(values) -1)
                accumulated = 0
                for i in range(0, len(values)):
                    if (values[i] == chosen_for_bias_value):
                        value_priors_to_unit_interval_map[values[i]] = \
                          [accumulated, accumulated + float(chosen_bias)]
                        accumulated += float(chosen_bias)
                    else:
                        value_priors_to_unit_interval_map[values[i]] = \
                          [accumulated, accumulated + remaining_portion_bias]
                        accumulated += remaining_portion_bias
                class_and_feature_based_value_priors_to_unit_interval_map[class_name][feature] = value_priors_to_unit_interval_map
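                # Worked example with hypothetical numbers: for a feature
                # with four values and bias_string 'heavy=0.7', the value
                # 'heavy' gets probability 0.7 and each of the other three
                # values gets (1 - 0.7)/3 = 0.1, giving intervals such as
                #
                #     heavy  => [0.0, 0.7],  medium => [0.7, 0.8],
                #     light  => [0.8, 0.9],  never  => [0.9, 1.0]
                #
                # with the exact layout depending on the order of the values.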
                if self._debug1:
                    print("\n")
                    print("For class " + class_name + \
                       ": Mapping feature value priors for feature '" + \
                       feature + "' to unit interval:")
                    print("\n")
                    items = value_priors_to_unit_interval_map.items()
                    for item in items:
                        print("    " + item[0] + " ===> " + str(item[1]))
        import random
        ran = random.Random()
        ele_index = 0
        while (ele_index < how_many_test_samples):
            # Sample names carry the '_NN' suffix expected by sample_index():
            sample_name = "sample" + "_" + str(ele_index)
            test_sample_records[sample_name] = []
            # Generate the class label for this test sample by drawing a
            # uniform number in [0,1] and finding the class whose
            # unit-interval segment contains it:
            roll_the_dice = ran.randint(0,1000) / 1000.0
            class_label = ''
            for class_name  in class_priors_to_unit_interval_map.keys():
                v = class_priors_to_unit_interval_map[class_name]
                if ( (roll_the_dice >= v[0]) and (roll_the_dice <= v[1]) ):
                    test_sample_records[sample_name].append( 
                                             "class=" + class_name )
                    class_label = class_name
                    break
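            # Now draw a value for each feature from the value intervals
            # conditioned on the class label just sampled: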
            for feature in sorted(list(features_and_values_dict.keys())):
                roll_the_dice  = ran.randint(0,1000) / 1000.0
                value_label = ''
                value_priors_to_unit_interval_map = \
                  class_and_feature_based_value_priors_to_unit_interval_map[class_label][feature]
                for value_name in value_priors_to_unit_interval_map.keys():
                    v = value_priors_to_unit_interval_map[value_name]
                    if ( (roll_the_dice >= v[0]) and (roll_the_dice <= v[1]) ):
                        test_sample_records[sample_name].append( \
                                            feature + "=" + value_name )
                        value_label = value_name
                        break
            ele_index += 1
        self._test_sample_records = test_sample_records
        if self._debug1:
            print("\n\n")
            print("TERMINAL DISPLAY OF TEST RECORDS:")
            print("\n\n")
            sample_names = test_sample_records.keys()
            sample_names = sorted(sample_names, key=sample_index)
            for sample_name in sample_names:
                print(sample_name + " => " + \
                                 str(test_sample_records[sample_name]))

    def find_longest_value(self):
        features_and_values_dict = self._features_and_values_dict
        max_length = 0
        for feature in features_and_values_dict.keys():
            values = features_and_values_dict[feature]
            for value in values:
                if len(str(value)) > max_length:
                    max_length = len(str(value))
        return max_length
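    # find_longest_value() is used by write_test_data_to_file() below to
    # size the fixed-width columns of the output table (longest value
    # plus two characters of padding).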

    def write_test_data_to_file(self):
        features_and_values_dict = self._features_and_values_dict
        class_names = self._class_names
        output_file = self._output_test_datafile
        test_sample_records = self._test_sample_records
        try:
            FILE = open(self._output_test_datafile, 'w') 
        except IOError:
            print("Unable to open file: " + self._output_test_datafile)
            sys.exit(1)
        try:
            FILE2 = open(self._output_class_labels_file, 'w') 
        except IOError:
            print("Unable to open file: " + self._output_class_labels_file)
            sys.exit(1)
        header = '''
# REQUIRED LINE FOLLOWS (the first uncommented line below):
# The line shown below must begin with the string
#
#             "Feature Order For Data:"  
#
# What comes after this string can be any number of feature labels.  
# The feature values shown in the table in the rest of the file will 
# be considered to be in the same order as shown in the next line.
                '''
        FILE.write(header + "\n\n\n")       
        title_string = "Feature Order For Data: "
        features = list(features_and_values_dict.keys())
        features.sort()
        for feature_name in features:
            title_string += str(feature_name) + " "
        title_string = title_string.rstrip()
        FILE.write(title_string + "\n\n")
        num_of_columns = len(features) + 1
        field_width = self.find_longest_value() + 2
        sample_names = test_sample_records.keys()
        sample_names = sorted(sample_names, key=sample_index)
        record_string = ''
        for sample_name in sample_names:
            sample_name_string = sample_name.ljust(13)
            record_string += sample_name_string
            record = test_sample_records[sample_name]
            item_parts_dict = {}
            for item in record:
                splits = list( filter(None, re.split(r'=', item)) )
                item_parts_dict[splits[0]] = splits[1]
            label_string = sample_name + " " + item_parts_dict["class"]
            FILE2.write(label_string + "\n")
            del item_parts_dict["class"]
            feature_names = sorted(item_parts_dict.keys())
            for feature_name in feature_names:
                record_string += item_parts_dict[feature_name].ljust(field_width)
            FILE.write(record_string + "\n")
            record_string = ''
        FILE.close()
        FILE2.close()
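    # For illustration (with hypothetical feature values), the two files
    # written above might contain lines like
    #
    #     output test datafile:    sample_0     heavy   never   light   medium
    #     class-labels file:       sample_0 malignant
    #
    # with one fixed-width column per feature, the features appearing in
    # sorted order of their names.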

#-----------------------------   Class Node  ----------------------------

# The nodes of the decision tree are instances of this class:
class Node(object):

    nodes_created = -1

    def __init__(self, feature, entropy, class_probabilities, \
                                                branch_features_and_values):
        self._serial_number = self.get_next_serial_num()
        self._feature       = feature
        self._entropy       = entropy
        self._class_probabilities = class_probabilities
        self._branch_features_and_values = branch_features_and_values
        self._linked_to = []

    def get_next_serial_num(self):
        Node.nodes_created += 1
        return Node.nodes_created

    def get_serial_num(self):
        return self._serial_number

    # This is a static method:
    @staticmethod
    def how_many_nodes():
        # nodes_created starts at -1 and holds the serial number of the
        # most recently created node; the count is therefore one more:
        return Node.nodes_created + 1

    # Returns the feature test at the current node:
    def get_feature(self):
        return self._feature

    def set_feature(self, feature):
        self._feature = feature

    def get_entropy(self):
        return self._entropy

    def get_class_probabilities(self):
        return self._class_probabilities

    def get_branch_features_and_values(self):
        return self._branch_features_and_values

    def add_to_branch_features_and_values(self, feature_and_value):
        self._branch_features_and_values.append(feature_and_value)

    def get_children(self):
        return self._linked_to

    def add_child_link(self, new_node):
        self._linked_to.append(new_node)                  

    # Note that this sets the child-link list to None rather than to an
    # empty list, so get_children() should not be called afterwards:
    def delete_all_links(self):
        self._linked_to = None

    def display_node(self):
        feature_at_node = self.get_feature() or " "
        entropy_at_node = self.get_entropy()
        class_probabilities = self.get_class_probabilities()
        serial_num = self.get_serial_num()
        branch_features_and_values = self.get_branch_features_and_values()
        print("\n\nNODE " + str(serial_num) +
              ":\n   Branch features and values to this node: " +
              str(branch_features_and_values) +
              "\n   Class probabilities at current node: " +
              str(class_probabilities) +
              "\n   Entropy at current node: " + str(entropy_at_node) +
              "\n   Best feature test at current node: " +
              feature_at_node + "\n\n")

    def display_decision_tree(self, offset):
        serial_num = self.get_serial_num()
        if len(self.get_children()) > 0:
            feature_at_node = self.get_feature() or " "
            entropy_at_node = self.get_entropy()
            class_probabilities = self.get_class_probabilities()
            print("NODE " + str(serial_num) + ":  " + offset + "feature: " +
                  feature_at_node + "   entropy: " + str(entropy_at_node) +
                  "   class probs: " + str(class_probabilities) + "\n")
            offset += "   "
            for child in self.get_children():
                child.display_decision_tree(offset)
        else:
            entropy_at_node = self.get_entropy()
            class_probabilities = self.get_class_probabilities()
            print("NODE " + str(serial_num) + ": " + offset +
                  "   entropy: " + str(entropy_at_node) +
                  "    class probs: " + str(class_probabilities) + "\n")
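    # For a small tree, a call like root.display_decision_tree("   ")
    # prints output of roughly this shape (hypothetical features and
    # numbers):
    #
    #     NODE 0:     feature: exercising   entropy: 0.957   class probs: [0.62, 0.38]
    #     NODE 1:        feature: smoking   entropy: 0.81    class probs: [0.85, 0.15]
    #
    # with each level of the tree indented three spaces more than its parent.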


#----------------------------  Test Code Follows  -----------------------

if __name__ == '__main__':

    dt = DecisionTree( training_datafile = "training.dat",
                       max_depth_desired = 2,
                       entropy_threshold = 0.1,
                       debug1 = 1,
                     )
    dt.get_training_data()

    dt.show_training_data()

    prob = dt.prior_probability_for_class( 'benign' )
    print("prior for benign: ", prob)
    prob = dt.prior_probability_for_class( 'malignant' )
    print("prior for malignant: ", prob)

    prob = dt.probability_for_feature_value( 'smoking', 'heavy')
    print(prob)

    dt.determine_data_condition()

    root_node = dt.construct_decision_tree_classifier()
    root_node.display_decision_tree("   ")

    test_sample = ['exercising=>never', 'smoking=>heavy', 'fatIntake=>heavy', 'videoAddiction=>heavy']
    classification = dt.classify(root_node, test_sample)
    print("Classification: ", classification)

    test_sample = ['videoAddiction=>none', 'exercising=>occasionally', 'smoking=>never', 'fatIntake=>medium']
    classification = dt.classify(root_node, test_sample)
    print("Classification: ", classification)

    print("Number of nodes created: ", root_node.how_many_nodes())