#!/usr/bin/env python

__version__ = '1.6'
__author__  = "Avinash Kak (kak@purdue.edu)"
__date__    = '2012-June-20'
__url__     = 'https://engineering.purdue.edu/kak/distDT/DecisionTree-1.6.html'
__copyright__ = "(C) 2012 Avinash Kak. Python Software Foundation."


import math
import re
import sys
import functools
import random


#-------------------------  Utility Functions ---------------------------

def sample_index(sample_name):
    '''
    We assume that every record in the training datafile begins with
    the name of the record. The only requirement is that these names
    end in a suffix like '_23', as in 'sample_23' for the 23rd training
    record.  This function returns the integer in the suffix.
    '''
    m = re.search(r'_(.+)$', sample_name)
    return int(m.group(1))
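
# A quick illustration of the naming convention assumed above:
#
#     sample_index('sample_23')     # returns 23
#     sample_index('record_7')      # returns 7  (any prefix works)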

# Meant only for an array of strings (no nesting):
def deep_copy_array(array_in):
    return list(array_in)

# Returns simultaneously the minimum value and its positional index in an
# array. [Could also have used min() and index() defined for Python's
# sequence types.]
def minimum(arr):
    min_value, min_index = None, None
    for i in range(len(arr)):
        if min_value is None or arr[i] < min_value:
            min_index = i
            min_value = arr[i]
    return min_value, min_index


#---------------------- DecisionTree Class Definition ---------------------

class DecisionTree(object):

    def __init__(self, *args, **kwargs ):
        if args:
            raise ValueError(  
                   '''DecisionTree constructor can only be called
                      with keyword arguments for the following
                      keywords: training_datafile, entropy_threshold,
                      max_depth_desired, debug1, and debug2''') 

        allowed_keys = 'training_datafile','entropy_threshold', \
                       'max_depth_desired','debug1','debug2'
        keywords_used = kwargs.keys()
        for keyword in keywords_used:
            if keyword not in allowed_keys:
                raise ValueError("Wrong keyword used --- check spelling") 

        training_datafile = entropy_threshold = max_depth_desired = None
        debug1 = debug2 = None

        if 'training_datafile' in kwargs:
            training_datafile = kwargs.pop('training_datafile')
        if 'entropy_threshold' in kwargs:
            entropy_threshold = kwargs.pop('entropy_threshold')
        if 'max_depth_desired' in kwargs:
            max_depth_desired = kwargs.pop('max_depth_desired')
        if 'debug1' in kwargs:  debug1 = kwargs.pop('debug1')
        if 'debug2' in kwargs:  debug2 = kwargs.pop('debug2')

        if training_datafile:
            self._training_datafile = training_datafile
        else:
            raise ValueError('''You must specify a training datafile''')
        if entropy_threshold: 
            self._entropy_threshold =  entropy_threshold
        else:
            self._entropy_threshold =  0.001        
        if max_depth_desired:
            self._max_depth_desired = max_depth_desired 
        else:
            self._max_depth_desired = None
        if debug1:
            self._debug1 = debug1
        else:
            self._debug1 = 0
        if debug2:
            self._debug2 = debug2
        else:
            self._debug2 = 0
        self._root_node = None
        self._probability_cache           = {}
        self._entropy_cache               = {}
        self._training_data_dict          = {}
        self._features_and_values_dict    = {}
        self._samples_class_label_dict    = {}
        self._class_names                 = []
        self._class_priors                = []
        self._feature_names               = []

    def get_training_data(self):
        recording_features_flag = 0
        recording_training_data = 0
        table_header = None
        column_labels_dict = {}
        FILE = None
        try:
            FILE = open( self._training_datafile )
        except IOError:
            print("unable to open %s" % self._training_datafile)
            sys.exit(1)
        for line in FILE:
            line = line.rstrip()
            lineskip = r'^[\s=#]*$'
            if re.search(lineskip, line): 
                continue
            elif re.search(r'\s*class', line, re.IGNORECASE) \
                       and not recording_training_data \
                       and not recording_features_flag:
                classpattern = r'^\s*class names:\s*([ \S]+)\s*'
                m = re.search(classpattern, line, re.IGNORECASE)
                if not m: 
                    raise ValueError('''No class names in training file''')
                self._class_names = m.group(1).split()
                continue
            elif re.search(r'\s*feature names and their values', \
                               line, re.IGNORECASE):
                recording_features_flag = 1
                continue
            elif re.search(r'training data', line, re.IGNORECASE):
                recording_training_data = 1
                recording_features_flag = 0
                continue
            elif not recording_training_data and recording_features_flag:
                feature_name_value_pattern = r'^\s*(\S+)\s*=>\s*(.+)'
                m = re.search(feature_name_value_pattern, line, re.IGNORECASE)
                if m is None:
                    raise ValueError("Ill-formatted feature line: %s" % line)
                feature_name = m.group(1)
                feature_values = m.group(2).split()
                self._features_and_values_dict[feature_name]  = feature_values
            elif recording_training_data:
                if not table_header:
                    table_header = line.split()
                    for i in range(2, len(table_header)):
                        column_labels_dict[i] = table_header[i]
                    continue
                record = line.split()
                self._samples_class_label_dict[record[0]] = record[1]
                self._training_data_dict[record[0]] = []
                for i in range(2, len(record)):
                    self._training_data_dict[record[0]].append(
                          column_labels_dict[i] + "=>" + record[i] )
        FILE.close()                        
        self._feature_names = list(self._features_and_values_dict.keys())
        if self._debug2:
            print("Class names: ", self._class_names)
            print( "Feature names: ", self._feature_names)
            print("Features and values: ", self._features_and_values_dict.items())
        for feature in self._feature_names:
            values_for_feature = self._features_and_values_dict[feature]
            for value in values_for_feature:
                feature_and_value = "".join([feature, "=>", value])
                self._probability_cache[feature_and_value] = \
                       self.probability_for_feature_value(feature, value)

    def show_training_data(self):
        print("Class names: ", self._class_names)
        print("\n\nFeatures and Their Possible Values:\n\n")
        features = self._features_and_values_dict.keys()
        for feature in sorted(features):
            print("%s ---> %s" \
                  % (feature, self._features_and_values_dict[feature]))
        print("\n\nSamples vs. Class Labels:\n\n")
        for item in sorted(self._samples_class_label_dict.items(), \
                key = lambda x: sample_index(x[0]) ):
            print(item)
        print("\n\nTraining Samples:\n\n")
        for item in sorted(self._training_data_dict.items(), \
                key = lambda x: sample_index(x[0]) ):
            print(item)


#------------------    Classify with Decision Tree  -----------------------

    def classify(self, root_node, features_and_values):
        if not self.check_names_used(features_and_values):
            raise ValueError("Error in the names you have used for features and/or values") 
        classification = self.recursive_descent_for_classification( \
                                    root_node, features_and_values )
        if self._debug1:
            print("\nThe classification:")
            for class_name in self._class_names:
                print("    " + class_name + " with probability " + \
                                        str(classification[class_name]))
        return classification
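
    # A classification sketch (the feature names and values shown are
    # hypothetical; they must match the names and values declared in your
    # training datafile):
    #
    #     test_sample = ['exercising=>never', 'smoking=>heavy']
    #     classification = dt.classify(root, test_sample)
    #     # 'classification' maps each class name to its probability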

    def recursive_descent_for_classification(self, node, features_and_values):
        feature_test_at_node = node.get_feature()
        value_for_feature = None
        remaining_features_and_values = []
        for feature_and_value in features_and_values:
            pattern = r'(.+)=>(.+)'
            m = re.search(pattern, feature_and_value)
            feature,value = m.group(1),m.group(2)
            if feature == feature_test_at_node:
                value_for_feature = value
            else:
                remaining_features_and_values.append(feature_and_value)
        if feature_test_at_node:
            if value_for_feature is None:
                raise ValueError("Your test sample does not include a value " + \
                                 "for the feature '" + feature_test_at_node + \
                                 "' tested at this node")
            feature_value_combo = \
                    "".join([feature_test_at_node,"=>",value_for_feature])
        children = node.get_children()
        answer = {}
        if len(children) == 0:
            leaf_node_class_probabilities = node.get_class_probabilities()
            for i in range(0, len(self._class_names)):
                answer[self._class_names[i]] = leaf_node_class_probabilities[i]
            return answer
        for child in children:
            branch_features_and_values = child.get_branch_features_and_values()
            last_feature_and_value_on_branch = branch_features_and_values[-1] 
            if last_feature_and_value_on_branch == feature_value_combo:
                answer = self.recursive_descent_for_classification(child, \
                                        remaining_features_and_values)
                break
        return answer

    def classify_by_asking_questions(self, root_node):
        classification = self.interactive_recursive_descent_for_classification(root_node)
        return classification

    def interactive_recursive_descent_for_classification(self, node):
        feature_test_at_node = node.get_feature()
        possible_values_for_feature = \
                   self._features_and_values_dict[feature_test_at_node]
        value_for_feature = None
        while True:
            prompt = "\nWhat is the value for the feature '" + \
                     feature_test_at_node + "'?" + "\n" + \
                     "Enter one of: " + str(possible_values_for_feature) + " => "
            if sys.version_info[0] == 3:
                value_for_feature = input(prompt)
            else:
                value_for_feature = raw_input(prompt)
            value_for_feature = value_for_feature.strip()
            if value_for_feature in possible_values_for_feature:
                break
            print("\nYou entered an illegal value.  Let's try again.\n")
        feature_value_combo = \
                "".join([feature_test_at_node,"=>",value_for_feature])
        children = node.get_children()
        answer = {}
        if len(children) == 0:
            leaf_node_class_probabilities = node.get_class_probabilities()
            for i in range(0, len(self._class_names)):
                answer[self._class_names[i]] = leaf_node_class_probabilities[i]
            return answer
        for child in children:
            branch_features_and_values = child.get_branch_features_and_values()
            last_feature_and_value_on_branch = branch_features_and_values[-1] 
            if last_feature_and_value_on_branch == feature_value_combo:
                answer = self.interactive_recursive_descent_for_classification(child)
                break
        return answer
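
    # Interactive classification sketch: given the root node returned by
    # construct_decision_tree_classifier(), the call below prompts the user
    # for one feature value at a time and returns the same class-probability
    # dictionary that classify() returns:
    #
    #     classification = dt.classify_by_asking_questions(root)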

#----------------------  Construct Decision Tree  -------------------------- 

    def construct_decision_tree_classifier(self):
        if self._debug1:        
            self.determine_data_condition() 
            print("\nStarting construction of the decision tree:\n") 
        class_probabilities = \
          list(map(lambda x: self.prior_probability_for_class(x), \
                                                   self._class_names))
        entropy = self.class_entropy_on_priors()
        root_node = Node(None, entropy, class_probabilities, [])
        self._root_node = root_node
        self.recursive_descent(root_node)
        return root_node        

    def recursive_descent(self, node):
        features_and_values_on_branch = node.get_branch_features_and_values()
        best_feature, best_feature_entropy =  \
         self.best_feature_calculator(features_and_values_on_branch)
        node.set_feature(best_feature)
        if self._debug1: node.display_node() 
        if self._max_depth_desired is not None and \
         len(features_and_values_on_branch) >= self._max_depth_desired:
            return
        if best_feature is None: return
        if best_feature_entropy \
                   < node.get_entropy() - self._entropy_threshold:
            values_for_feature = \
                  self._features_and_values_dict[best_feature]
            feature_value_combos = \
              list(map(lambda x: "".join([best_feature,"=>",x]), values_for_feature))
            for feature_and_value in feature_value_combos:
                extended_branch_features_and_values = None
                if features_and_values_on_branch is None:
                    extended_branch_features_and_values = feature_and_value
                else:
                    extended_branch_features_and_values = \
                        deep_copy_array( features_and_values_on_branch )
                    extended_branch_features_and_values.append(\
                                                      feature_and_value)
                class_probabilities = list(map(lambda x: \
         self.probability_for_a_class_given_sequence_of_features_and_values(\
                x, extended_branch_features_and_values), self._class_names))
                child_node = Node(None, best_feature_entropy, \
                     class_probabilities, extended_branch_features_and_values)
                node.add_child_link( child_node )
                self.recursive_descent(child_node)

    def best_feature_calculator(self, features_and_values_on_branch):
        features_already_used = []
        for feature_and_value in features_and_values_on_branch:
            pattern = r'(.+)=>(.+)'
            m = re.search(pattern, feature_and_value)
            feature = m.group(1)
            features_already_used.append(feature)
        feature_tests_not_yet_used = []
        for feature in self._feature_names:
            if (feature not in features_already_used):
                feature_tests_not_yet_used.append(feature)
        if len(feature_tests_not_yet_used) == 0: return None, None
        array_of_entropy_values_for_different_features = []
        for i in range(0, len(feature_tests_not_yet_used)):
            values = \
             self._features_and_values_dict[feature_tests_not_yet_used[i]]
            entropy_for_new_feature = None
            for value in values:
                feature_and_value_string = \
                   "".join([feature_tests_not_yet_used[i], "=>", value]) 
                extended_features_and_values_on_branch = None
                if len(features_and_values_on_branch) > 0:
                    extended_features_and_values_on_branch =  \
                          deep_copy_array(features_and_values_on_branch)
                    extended_features_and_values_on_branch.append(  \
                                              feature_and_value_string) 
                else:
                    extended_features_and_values_on_branch  =    \
                        [feature_and_value_string]
                if entropy_for_new_feature is None:
                    entropy_for_new_feature =  \
                   self.class_entropy_for_a_given_sequence_of_features_values(\
                             extended_features_and_values_on_branch) \
                     * \
                     self.probability_of_a_sequence_of_features_and_values( \
                         extended_features_and_values_on_branch)
                    continue
                else:
                    entropy_for_new_feature += \
                  self.class_entropy_for_a_given_sequence_of_features_values(\
                         extended_features_and_values_on_branch) \
                     *  \
                     self.probability_of_a_sequence_of_features_and_values( \
                         extended_features_and_values_on_branch)
            array_of_entropy_values_for_different_features.append(\
                                         entropy_for_new_feature)
        min_entropy, min_index = minimum(array_of_entropy_values_for_different_features)
        return feature_tests_not_yet_used[min_index], min_entropy
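
    # The quantity minimized above is, for each candidate feature F not yet
    # used on the branch, the average class entropy over the values of F:
    #
    #     H(C | branch, F)  =  sum_over_v  P(branch, F=v) * H(C | branch, F=v)
    #
    # Since H(C | branch) is the same for all candidates, picking the feature
    # with the smallest H(C | branch, F) maximizes the information gain.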


#--------------------------  Entropy Calculators  --------------------------

    def class_entropy_on_priors(self):
        if 'priors' in self._entropy_cache:
            return self._entropy_cache['priors']
        entropy = None
        for class_name in self._class_names:
            prob = self.prior_probability_for_class(class_name)
            if 0.0001 <= prob <= 0.999:
                log_prob = math.log(prob,2)
            else:
                # Probabilities indistinguishable from 0 or 1 contribute
                # zero entropy:
                log_prob = 0
            if entropy is None:
                entropy = -1.0 * prob * log_prob
            else:
                entropy += -1.0 * prob * log_prob
        self._entropy_cache['priors'] = entropy
        return entropy
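
    # A quick sanity check: with two equiprobable classes (priors 0.5 and
    # 0.5), the loop above yields  -0.5*log2(0.5) - 0.5*log2(0.5) = 1 bit,
    # the maximum possible entropy for two classes.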

    def class_entropy_for_a_given_feature_and_given_value(self, \
                                feature_name, feature_value):
        feature_and_value = "".join([feature_name, "=>", feature_value])
        if feature_and_value in self._entropy_cache:
            return self._entropy_cache[feature_and_value]
        entropy = None
        for class_name in self._class_names:
            prob = self.probability_for_a_class_given_feature_value( \
                                class_name, feature_name, feature_value)
            if 0.0001 <= prob <= 0.999:
                log_prob = math.log(prob,2)
            else:
                log_prob = 0
            if entropy is None:
                entropy = -1.0 * prob * log_prob
            else:
                entropy += -1.0 * prob * log_prob
        self._entropy_cache[feature_and_value] = entropy 
        return entropy

    def class_entropy_for_a_given_feature(self, feature_name):
        if feature_name in self._entropy_cache:
            return self._entropy_cache[feature_name]
        entropy = None
        for feature_value in self._features_and_values_dict[feature_name]:
            if entropy is None:
                entropy = \
                    self.class_entropy_for_a_given_feature_and_given_value(\
                                              feature_name, feature_value) \
                    * \
                    self.probability_for_feature_value( \
                                       feature_name,feature_value)
                continue
            entropy += \
                self.class_entropy_for_a_given_feature_and_given_value( \
                                              feature_name, feature_value) \
                *  \
                self.probability_for_feature_value(feature_name,feature_value)
        self._entropy_cache[feature_name] = entropy
        return entropy

    def class_entropy_for_a_given_sequence_of_features_values(self, \
                                       array_of_features_and_values):
        sequence = ":".join(array_of_features_and_values)
        if sequence in self._entropy_cache:
            return self._entropy_cache[sequence]
        entropy = None
        for class_name in self._class_names:
            prob = \
           self.probability_for_a_class_given_sequence_of_features_and_values(\
                 class_name, array_of_features_and_values)
            if prob == 0:
                prob = 1.0 / len(self._class_names)
            if 0.0001 <= prob <= 0.999:
                log_prob = math.log(prob,2)
            else:
                log_prob = 0
            if entropy is None:
                entropy = -1.0 * prob * log_prob
            else:
                entropy += -1.0 * prob * log_prob
        self._entropy_cache[sequence] = entropy
        return entropy


#-------------------------  Probability Calculators ------------------------

    def prior_probability_for_class(self, class_name):
        class_name_in_cache = "".join(["prior::", class_name])
        if class_name_in_cache in self._probability_cache:
            return self._probability_cache[class_name_in_cache]
        total_num_of_samples = len( self._samples_class_label_dict )
        all_values = self._samples_class_label_dict.values()
        for this_class_name in self._class_names:
            trues = list(filter( lambda x: x == this_class_name, all_values ))
            prior_for_this_class = (1.0 * len(trues)) / total_num_of_samples
            this_class_name_in_cache = "".join(["prior::", this_class_name])
            self._probability_cache[this_class_name_in_cache] = \
                                                    prior_for_this_class
        return self._probability_cache[class_name_in_cache]

    def probability_for_feature_value(self, feature, value):
        feature_and_value = "".join([feature, "=>", value])
        if feature_and_value in self._probability_cache:
            return self._probability_cache[feature_and_value]
        values_for_feature = self._features_and_values_dict[feature]
        values_for_feature = list(map(lambda x: feature + "=>" + x, \
                                                   values_for_feature))
        value_counts = [0] * len(values_for_feature)
        for sample in sorted(self._training_data_dict.keys(), \
                key = lambda x: sample_index(x) ):
            features_and_values = self._training_data_dict[sample]
            for i in range(0, len(values_for_feature)):
                for current_value in features_and_values:
                    if values_for_feature[i] == current_value:
                        value_counts[i] += 1 
        for i in range(0, len(values_for_feature)):
            self._probability_cache[values_for_feature[i]] = \
                      value_counts[i] / (1.0 * len(self._training_data_dict))
        if feature_and_value in self._probability_cache:
            return self._probability_cache[feature_and_value]
        else:
            return 0

    def probability_for_feature_value_given_class(self, feature_name, \
                                        feature_value, class_name):
        feature_value_class = \
             "".join([feature_name,"=>",feature_value,"::",class_name])
        if feature_value_class in self._probability_cache:
            return self._probability_cache[feature_value_class]
        samples_for_class = []
        for sample_name in self._samples_class_label_dict.keys():
            if self._samples_class_label_dict[sample_name] == class_name:
                samples_for_class.append(sample_name) 
        values_for_feature = self._features_and_values_dict[feature_name]
        values_for_feature = \
        list(map(lambda x: "".join([feature_name,"=>",x]), values_for_feature))
        value_counts = [0] * len(values_for_feature)
        for sample in samples_for_class:
            features_and_values = self._training_data_dict[sample]
            for i in range(0, len(values_for_feature)):
                for current_value in (features_and_values):
                    if values_for_feature[i] == current_value:
                        value_counts[i] += 1 
        total_count = sum(value_counts)
        for i in range(0, len(values_for_feature)):
            feature_and_value_for_class = \
                     "".join([values_for_feature[i],"::",class_name])
            self._probability_cache[feature_and_value_for_class] = \
                                       value_counts[i] / (1.0 * total_count)
        feature_and_value_and_class = \
              "".join([feature_name, "=>", feature_value,"::",class_name])
        if feature_and_value_and_class in self._probability_cache:
            return self._probability_cache[feature_and_value_and_class]
        else:
            return 0

    def probability_of_a_sequence_of_features_and_values(self, \
                                        array_of_features_and_values):
        sequence = ":".join(array_of_features_and_values)
        if sequence in self._probability_cache:
            return self._probability_cache[sequence]
        probability = None
        for feature_and_value in array_of_features_and_values:
            pattern = r'(.+)=>(.+)'
            m = re.search(pattern, feature_and_value)
            feature,value = m.group(1),m.group(2)
            if probability is None:
                probability = \
                   self.probability_for_feature_value(feature, value)
                continue
            else:
                probability *= \
                  self.probability_for_feature_value(feature, value)
        self._probability_cache[sequence] = probability
        return probability

    def probability_for_sequence_of_features_and_values_given_class(self, \
                            array_of_features_and_values, class_name):
        sequence = ":".join(array_of_features_and_values)
        sequence_with_class = "".join([sequence, "::", class_name])
        if sequence_with_class in self._probability_cache:
            return self._probability_cache[sequence_with_class]
        probability = None
        for feature_and_value in array_of_features_and_values:
            pattern = r'(.+)=>(.+)'
            m = re.search(pattern, feature_and_value)
            feature,value = m.group(1),m.group(2)
            if probability is None:
                probability = self.probability_for_feature_value_given_class(\
                                                 feature, value, class_name)
                continue
            else:
                probability *= self.probability_for_feature_value_given_class(\
                                           feature, value, class_name)
        self._probability_cache[sequence_with_class] = probability
        return probability 

    def probability_for_a_class_given_feature_value(self, class_name, \
                                              feature_name, feature_value):
        prob = self.probability_for_feature_value_given_class( \
                                 feature_name, feature_value, class_name)
        answer = (prob * self.prior_probability_for_class(class_name)) \
                 /                                                     \
                 self.probability_for_feature_value(feature_name,feature_value)
        return answer
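
    # The calculation above is just Bayes' rule:
    #
    #     P(class | feature=value) =
    #         P(feature=value | class) * P(class) / P(feature=value)
    #
    # with all three quantities on the right estimated from the training data.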

    def probability_for_a_class_given_sequence_of_features_and_values(self, \
                    class_name, array_of_features_and_values):
        sequence = ":".join(array_of_features_and_values)
        class_and_sequence = "".join([class_name, "::", sequence])
        if class_and_sequence in self._probability_cache:
            return self._probability_cache[class_and_sequence]
        array_of_class_probabilities = [0] * len(self._class_names)
        for i in range(0, len(self._class_names)):
            prob = \
            self.probability_for_sequence_of_features_and_values_given_class(\
                   array_of_features_and_values, self._class_names[i]) 
            if prob == 0:
                array_of_class_probabilities[i] = 0 
                continue
            prob_of_feature_sequence = \
                self.probability_of_a_sequence_of_features_and_values(  \
                                              array_of_features_and_values)
            array_of_class_probabilities[i] =   \
               prob * self.prior_probability_for_class(self._class_names[i]) \
                 / prob_of_feature_sequence
        sum_probability = sum(array_of_class_probabilities)
        if sum_probability == 0:
            array_of_class_probabilities = [1.0/len(self._class_names)] \
                                              * len(self._class_names)
        else:
            array_of_class_probabilities = \
                     list(map(lambda x: x / sum_probability,\
                               array_of_class_probabilities))
        for i in range(0, len(self._class_names)):
            this_class_and_sequence = \
                     "".join([self._class_names[i], "::", sequence])
            self._probability_cache[this_class_and_sequence] = \
                                     array_of_class_probabilities[i]
        return self._probability_cache[class_and_sequence]
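
    # Sketch of the computation above: for each class c, the posterior is
    #
    #     P(c | v1,...,vn)  =  P(v1,...,vn | c) * P(c) / P(v1,...,vn)
    #
    # after which the class probabilities are normalized to sum to 1 (or set
    # to the uniform distribution when every numerator is zero).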


#---------------------  Class Based Utilities  ---------------------

    def determine_data_condition(self):
        num_of_features = len(self._feature_names)
        values = list(self._features_and_values_dict.values())
        print("Number of features: ", num_of_features)
        max_num_values = 0
        for i in range(0, len(values)):
            if ((not max_num_values) or (len(values[i]) > max_num_values)):
                max_num_values = len(values[i])
        print("Largest number of feature values is: ", max_num_values)
        estimated_number_of_nodes = max_num_values ** num_of_features
        print("\nWORST CASE SCENARIO: The decision tree COULD have as many \
as   \n   %d nodes. The exact number of nodes created depends\n   critically \
on the entropy_threshold used for node expansion.\n   (The default for this \
threshold is 0.001.)\n" % (estimated_number_of_nodes))
        if estimated_number_of_nodes > 10000:
            print("\nTHIS IS WAY TOO MANY NODES. Consider using a relatively \
large\n   value for entropy_threshold to reduce the number of nodes created.\n")
            if sys.version_info[0] == 3:
                ans = input("\nDo you wish to continue? Enter 'y' if yes:  ")
            else:
                ans = raw_input("\nDo you wish to continue? Enter 'y' if yes:  ")
            if ans != 'y':
                sys.exit(0)
        print("\nHere are the probabilities of feature-value pairs in your data:\n\n")
        for feature in self._feature_names:
            values_for_feature = self._features_and_values_dict[feature]
            for value in values_for_feature:
                prob = self.probability_for_feature_value(feature,value) 
                print("Probability of feature-value pair (%s,%s): %.3f" % \
                                                (feature,value,prob)) 

    def check_names_used(self, features_and_values_test_data):
        for feature_and_value in features_and_values_test_data:
            pattern = r'(.+)=>(.+)'
            m = re.search(pattern, feature_and_value)
            if m is None:
                raise ValueError("Your test data has a formatting error")
            feature,value = m.group(1),m.group(2)
            if feature not in self._feature_names:
                return 0
            if value not in self._features_and_values_dict[feature]:
                return 0
        return 1

    def get_class_names(self):
        return self._class_names


#----------------  Generate Your Own Training Data  ----------------

class TrainingDataGenerator(object):
    def __init__(self, *args, **kwargs ):
        if args:
            raise ValueError(  
                   '''TrainingDataGenerator can only be called
                      with keyword arguments for the following
                      keywords: output_datafile, parameter_file,
                      number_of_training_samples, write_to_file,
                      debug1, and debug2''') 
        allowed_keys = 'output_datafile','parameter_file', \
                       'number_of_training_samples', 'write_to_file', \
                       'debug1','debug2'
        keywords_used = kwargs.keys()
        for keyword in keywords_used:
            if keyword not in allowed_keys:
                raise ValueError("Wrong keyword used --- check spelling") 

        output_datafile = parameter_file = number_of_training_samples = None
        write_to_file = debug1 = debug2 = None

        if 'output_datafile' in kwargs:
            output_datafile = kwargs.pop('output_datafile')
        if 'parameter_file' in kwargs:
            parameter_file = kwargs.pop('parameter_file')
        if 'number_of_training_samples' in kwargs:
            number_of_training_samples = kwargs.pop('number_of_training_samples')
        if 'write_to_file' in kwargs:
            write_to_file = kwargs.pop('write_to_file')
        if 'debug1' in kwargs:  debug1 = kwargs.pop('debug1')
        if 'debug2' in kwargs:  debug2 = kwargs.pop('debug2')

        if output_datafile:
            self._output_datafile = output_datafile
        else:
            raise ValueError('''You must specify an output datafile''')
        if parameter_file: 
            self._parameter_file =  parameter_file
        else:
            raise ValueError('''You must specify a parameter file''')
        if number_of_training_samples:
            self._number_of_training_samples = number_of_training_samples
        else:
            raise ValueError('''You forgot to specify the number of training samples needed''')
        if write_to_file:
            self._write_to_file = write_to_file
        else:
            self._write_to_file = 0          
        if debug1:
            self._debug1 = debug1
        else:
            self._debug1 = 0
        if debug2:
            self._debug2 = debug2
        else:
            self._debug2 = 0
        self._training_sample_records     = {}
        self._features_and_values_dict    = {}
        self._bias_dict                   = {}
        self._class_names                 = []
        self._class_priors                = []

    # Read the parameter file for generating the TRAINING data
    def read_parameter_file( self ):
        debug1 = self._debug1
        debug2 = self._debug2
        write_to_file = self._write_to_file
        number_of_training_samples = self._number_of_training_samples
        input_parameter_file = self._parameter_file
        all_params = []
        param_string = ''
        try:
            FILE = open(input_parameter_file, 'r')
        except IOError:
            print("unable to open %s" % input_parameter_file)
            sys.exit(1)
        all_params = FILE.read()
        all_params = re.split(r'\n', all_params)
        FILE.close()
        pattern = r'^(?![ ]*#)'
        try:
            regex = re.compile( pattern )
        except re.error:
            print("error in your pattern")
            sys.exit(1)
        all_params = list( filter( regex.search, all_params ) )
        all_params = list( filter( None, all_params ) )
        all_params = [x.rstrip('\n') for x in all_params]
        param_string = ' '.join( all_params )
        pattern = r'^\s*class names:(.*?)\s*class priors:(.*?)(feature: .*)'
        m = re.search( pattern, param_string )
        rest_params = m.group(3)
        self._class_names = list( filter(None, re.split(r'\s+', m.group(1))) )
        self._class_priors = list( filter(None, re.split(r'\s+', m.group(2))) )
        pattern = r'(feature:.*?) (bias:.*)'
        m = re.search( pattern, rest_params  )
        feature_string = m.group(1)
        bias_string = m.group(2)
        features_and_values_dict = {}
        features = list( filter( None, re.split( r'(feature[:])', feature_string ) ) )
        for item in features:
            if re.match(r'feature', item): continue
            splits = list( filter(None, re.split(r' ', item)) )
            for i in range(0, len(splits)):
                if i == 0: features_and_values_dict[splits[0]] = []
                else:
                    if re.match( r'values', splits[i] ): continue
                    features_and_values_dict[splits[0]].append( splits[i] )
        self._features_and_values_dict = features_and_values_dict
        bias_dict = {}
        biases = list( filter(None, re.split(r'(bias[:]\s*class[:])', bias_string )) )
        for item in biases:
            if re.match(r'bias', item): continue
            splits = list( filter(None, re.split(r' ', item)) )
            feature_name = ''
            for i in range(0, len(splits)):
                if i == 0:
                    bias_dict[splits[0]] = {}
                elif ( re.search( r'(^.+)[:]$', splits[i] ) ):
                    m = re.search(  r'(^.+)[:]$', splits[i] )
                    feature_name = m.group(1)
                    bias_dict[splits[0]][feature_name] = []
                else:
                    if not feature_name: continue
                    bias_dict[splits[0]][feature_name].append( splits[i] )
        self._bias_dict = bias_dict
        if self._debug1:
            print("\n\n") 
            print("Class names: " + str(self._class_names))
            print("\n") 
            num_of_classes = len(self._class_names)
            print("Number of classes: " + str(num_of_classes))
            print("\n")
            print("Class priors: " + str(self._class_priors))
            print("\n\n")
            print("Here are the features and their possible valuesn")
            print("\n")
            items = self._features_and_values_dict.items()
            for item in items:
                print(item[0] + " ===> " + str(item[1]))
            print("\n")
            print("Here is the biasing for each class:")
            print("\n")          
            items = self._bias_dict.items()
            for item in items:
                print("\n")
                print(item[0])
                items2 = list( item[1].items() )
                for i in range(0, len(items2)):
                    print( items2[i])

    def gen_training_data( self ):
        class_names = self._class_names
        class_priors = self._class_priors
        training_sample_records = {}
        features_and_values_dict = self._features_and_values_dict
        bias_dict  = self._bias_dict
        how_many_training_samples = self._number_of_training_samples
        class_priors_to_unit_interval_map = {}
        accumulated_interval = 0
        for i in range(0, len(class_names)):
            class_priors_to_unit_interval_map[class_names[i]] = \
            (accumulated_interval, accumulated_interval+float(class_priors[i]))
            accumulated_interval += float(class_priors[i])
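        # For example (assuming two classes with priors 0.4 and 0.6), the map
        # would hold  class1 => (0.0, 0.4)  and  class2 => (0.4, 1.0),  so a
        # uniform draw from [0,1] selects each class with its prior probability.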
        if self._debug1:
            print("Mapping of class priors to unit interval:")
            print("\n")
            items = class_priors_to_unit_interval_map.items()
            for item in items:
                print(item[0] + " ===> " + str(item[1]))
        class_and_feature_based_value_priors_to_unit_interval_map = {}
        for class_name  in class_names:
            class_and_feature_based_value_priors_to_unit_interval_map[class_name] = {}
            for feature in features_and_values_dict.keys():
                class_and_feature_based_value_priors_to_unit_interval_map[class_name][feature] = {}
        for class_name  in class_names:
            for feature in features_and_values_dict.keys():
                values = features_and_values_dict[feature]
                if len(bias_dict[class_name][feature]) > 0:
                    bias_string = bias_dict[class_name][feature][0]
                else:
                    no_bias = 1.0 / len(values)
                    bias_string = values[0] +  "=" + str(no_bias)
                value_priors_to_unit_interval_map = {}
                splits = list( filter( None, re.split(r'\s*=\s*', bias_string) ) )
                chosen_for_bias_value = splits[0]
                chosen_bias = splits[1]
                remaining_bias = 1 - float(chosen_bias)
                remaining_portion_bias = remaining_bias / (len(values) -1)
                accumulated = 0
                for i in range(0, len(values)):
                    if (values[i] == chosen_for_bias_value):
                        value_priors_to_unit_interval_map[values[i]] = \
                          [accumulated, accumulated + float(chosen_bias)]
                        accumulated += float(chosen_bias)
                    else:
                        value_priors_to_unit_interval_map[values[i]] = \
                          [accumulated, accumulated + remaining_portion_bias]
                        accumulated += remaining_portion_bias
                class_and_feature_based_value_priors_to_unit_interval_map[class_name][feature] = value_priors_to_unit_interval_map
                if self._debug2:
                    print("\n")
                    print( "For class " + class_name + \
                       ": Mapping feature value priors for feature '" + \
                       feature + "' to unit interval: ")
                    print("\n")
                    items = value_priors_to_unit_interval_map.items()
                    for item in items:
                        print("    " + item[0] + " ===> " + str(item[1]))
        ele_index = 0
        while (ele_index < how_many_training_samples):
            sample_name = "sample" + "_" + str(ele_index)
            training_sample_records[sample_name] = []
            # Generate the class label for this training sample:
            ran = random.Random()
            roll_the_dice  = ran.randint(0,1000) / 1000.0
            class_label = ''
            for class_name  in class_priors_to_unit_interval_map.keys():
                v = class_priors_to_unit_interval_map[class_name]
                if ( (roll_the_dice >= v[0]) and (roll_the_dice <= v[1]) ):
                    training_sample_records[sample_name].append( 
                                             "class=" + class_name )
                    class_label = class_name
                    break
            for feature in sorted(list(features_and_values_dict.keys())):
                roll_the_dice  = ran.randint(0,1000) / 1000.0
                value_label = ''
                value_priors_to_unit_interval_map = \
                  class_and_feature_based_value_priors_to_unit_interval_map[class_label][feature]
                for value_name in value_priors_to_unit_interval_map.keys():
                    v = value_priors_to_unit_interval_map[value_name]
                    if ( (roll_the_dice >= v[0]) and (roll_the_dice <= v[1]) ):
                        training_sample_records[sample_name].append( \
                                            feature + "=" + value_name )
                        value_label = value_name
                        break
            ele_index += 1
        self._training_sample_records = training_sample_records
        if self._debug2:
            print("\n\n")
            print("TERMINAL DISPLAY OF TRAINING RECORDS:")
            print("\n\n")
            sample_names = training_sample_records.keys()
            sample_names = sorted( sample_names, key=lambda x: sample_index(x) )
            for sample_name in sample_names:
                print(sample_name + " => " + \
                             str(training_sample_records[sample_name]))

    def find_longest_feature_or_value(self):
        features_and_values_dict = self._features_and_values_dict
        max_length = 0
        for feature in features_and_values_dict.keys():
            if not max_length:
                max_length = len(str(feature))
            if len(str(feature)) > max_length:
                max_length = len(str(feature)) 
            values = features_and_values_dict[feature]
            for value in values:
                if len(str(value)) > max_length:
                    max_length = len(str(value)) 
        return max_length

    def write_training_data_to_file( self ):
        features_and_values_dict = self._features_and_values_dict
        class_names = self._class_names
        output_file = self._output_datafile
        training_sample_records = self._training_sample_records
        try:
            FILE = open(self._output_datafile, 'w') 
        except IOError:
            print("Unable to open file: " + self._output_datafile)
            sys.exit(1)
        class_names_string = ''
        for aname in class_names:
            class_names_string += aname + " "
        class_names_string = class_names_string.rstrip()
        FILE.write("Class names: %s\n\n" % class_names_string ) 
        FILE.write("Feature names and their values:\n")
        features = list( features_and_values_dict.keys() )
        if len(features) == 0:
            print("You probably forgot to call gen_training_data() before " + \
                          "calling write_training_data_to_file()") 
            sys.exit(1)
        for i in range(0, len(features)):
            values = features_and_values_dict[features[i]]
            values_string = ''
            for aname in values:
                values_string += aname + " "
            values_string = values_string.rstrip()
            FILE.write("     %(s1)s => %(s2)s\n" % {'s1':features[i], 's2':values_string} )
        FILE.write("\n\nTraining Data:\n\n")
        num_of_columns = len(features) + 2
        field_width = self.find_longest_feature_or_value() + 2
        if field_width < 12: field_width = 12
        title_string = str.ljust( "sample", field_width ) + \
                       str.ljust( "class", field_width )
        features.sort()
        for feature_name in features:
            title_string += str.ljust( str(feature_name), field_width )
        FILE.write(title_string + "\n")
        separator = '=' * len(title_string)
        FILE.write(separator + "\n")
        sample_names = list( training_sample_records.keys() )
        sample_names = sorted( sample_names, key=lambda x: sample_index(x) )
        record_string = ''
        for sample_name in sample_names:
            sample_name_string = str.ljust(sample_name, field_width)
            record_string += sample_name_string
            record = training_sample_records[sample_name]
            item_parts_dict = {}
            for item in record:
                splits = list( filter(None, re.split(r'=', item)) )
                item_parts_dict[splits[0]] = splits[1]
            record_string += str.ljust(item_parts_dict["class"], field_width)
            del item_parts_dict["class"]
            feature_keys = sorted(item_parts_dict.keys())
            for feature_key in feature_keys:
                record_string += str.ljust(item_parts_dict[feature_key], field_width)
            FILE.write(record_string + "\n")
            record_string = ''
        FILE.close()


#------------------------  Generate Your Own Test Data ---------------------

class TestDataGenerator(object):
    def __init__(self, *args, **kwargs ):
        if args:
            raise ValueError(  
                   '''TestDataGenerator can only be called
                      with keyword arguments for the following
                      keywords: parameter_file, output_test_datafile,
                      output_class_labels_file, number_of_test_samples, 
                      write_to_file, debug1, and debug2''') 
        allowed_keys = 'output_test_datafile','parameter_file', \
                       'number_of_test_samples', 'write_to_file', \
                       'output_class_labels_file', 'debug1', 'debug2'
        keywords_used = kwargs.keys()
        for keyword in keywords_used:
            if keyword not in allowed_keys:
                raise ValueError("Wrong keyword used --- check spelling") 
        output_test_datafile = parameter_file = number_of_test_samples = None
        output_class_labels_file = write_to_file = debug1 = debug2 = None
        if 'output_test_datafile' in kwargs:
            output_test_datafile = kwargs.pop('output_test_datafile')
        if 'output_class_labels_file' in kwargs:
            output_class_labels_file = kwargs.pop('output_class_labels_file')
        if 'parameter_file' in kwargs:
            parameter_file = kwargs.pop('parameter_file')
        if 'number_of_test_samples' in kwargs:
            number_of_test_samples = kwargs.pop('number_of_test_samples')
        if 'write_to_file' in kwargs:
            write_to_file = kwargs.pop('write_to_file')
        if 'debug1' in kwargs:  debug1 = kwargs.pop('debug1')
        if 'debug2' in kwargs:  debug2 = kwargs.pop('debug2')
        if output_test_datafile:
            self._output_test_datafile = output_test_datafile
        else:
            raise ValueError('''You must specify an output test datafile''')
        if output_class_labels_file:
            self._output_class_labels_file = output_class_labels_file
        else:
            raise ValueError('''You must specify an output file for class labels''')
        if parameter_file: 
            self._parameter_file =  parameter_file
        else:
            raise ValueError('''You must specify a parameter file''')
        if number_of_test_samples:
            self._number_of_test_samples = number_of_test_samples
        else:
            raise ValueError('''You forgot to specify the number of test samples needed''')
        if write_to_file:
            self._write_to_file = write_to_file
        else:
            self._write_to_file = 0          
        if debug1: self._debug1 = debug1
        else: self._debug1 = 0
        if debug2: self._debug2 = debug2
        else: self._debug2 = 0
        self._test_sample_records         = {}
        self._features_and_values_dict    = {}
        self._bias_dict                   = {}
        self._class_names                 = []
        self._class_priors                = []

    # Read the parameter file for generating the TEST data
    def read_parameter_file( self ):
        debug1 = self._debug1
        debug2 = self._debug2
        write_to_file = self._write_to_file
        number_of_test_samples = self._number_of_test_samples
        input_parameter_file = self._parameter_file
        all_params = []
        param_string = ''
        try:
            FILE = open(input_parameter_file, 'r') 
        except IOError:
            print("Unable to open file: " + input_parameter_file)
            sys.exit(1)
        all_params = FILE.read()
        all_params = re.split(r'\n', all_params)
        FILE.close()
        pattern = r'^(?![ ]*#)'
        try:
            regex = re.compile( pattern )
        except re.error:
            print("error in your pattern")
            sys.exit(1)
        all_params = list( filter( regex.search, all_params ) )
        all_params = list( filter( None, all_params ) )
        all_params = [x.rstrip('\n') for x in all_params]
        param_string = ' '.join( all_params )
        pattern = r'^\s*class names:(.*?)\s*class priors:(.*?)(feature: .*)'
        m = re.search( pattern, param_string )
        rest_params = m.group(3)
        self._class_names = list( filter(None, re.split(r'\s+', m.group(1))) )
        self._class_priors = list( filter(None, re.split(r'\s+', m.group(2))) )
        pattern = r'(feature:.*?) (bias:.*)'
        m = re.search( pattern, rest_params  )
        feature_string = m.group(1)
        bias_string = m.group(2)
        features_and_values_dict = {}
        features = list( filter( None, re.split( r'(feature[:])', feature_string ) ) )
        for item in features:
            if re.match(r'feature', item): continue
            splits = list( filter(None, re.split(r' ', item)) )
            for i in range(0, len(splits)):
                if i == 0: features_and_values_dict[splits[0]] = []
                else:
                    if re.match( r'values', splits[i] ): continue
                    features_and_values_dict[splits[0]].append( splits[i] )
        self._features_and_values_dict = features_and_values_dict
        bias_dict = {}
        biases = list( filter(None, re.split(r'(bias[:]\s*class[:])', bias_string )) )
        for item in biases:
            if re.match(r'bias', item): continue
            splits = list( filter(None, re.split(r' ', item)) )
            feature_name = ''
            for i in range(0, len(splits)):
                if i == 0:
                    bias_dict[splits[0]] = {}
                elif ( re.search( r'(^.+)[:]$', splits[i] ) ):
                    m = re.search(  r'(^.+)[:]$', splits[i] )
                    feature_name = m.group(1)
                    bias_dict[splits[0]][feature_name] = []
                else:
                    if not feature_name: continue
                    bias_dict[splits[0]][feature_name].append( splits[i] )
        self._bias_dict = bias_dict
        if self._debug1:
            print("\n\n")
            print("Class names: " + str(self._class_names))
            print("\n")
            num_of_classes = len(self._class_names)
            print("Number of classes: " + str(num_of_classes))
            print("\n")
            print("Class priors: " + str(self._class_priors))
            print("\n\n")
            print("Here are the features and their possible values:")
            print("\n")
            items = self._features_and_values_dict.items()
            for item in items:
                print(item[0] + " ===> " + str(item[1]))
            print("\n")
            print("Here is the biasing for each class:")
            print("\n")            
            items = self._bias_dict.items()
            for item in items:
                print("\n")
                print(item[0])
                items2 = list( item[1].items() )
                for i in range(0, len(items2)):
                    print( items2[i])

    def gen_test_data( self ):
        class_names = self._class_names
        class_priors = self._class_priors
        test_sample_records = {}
        features_and_values_dict = self._features_and_values_dict
        bias_dict  = self._bias_dict
        how_many_test_samples = self._number_of_test_samples
        file_for_class_labels = self._output_class_labels_file
        class_priors_to_unit_interval_map = {}
        accumulated_interval = 0
        for i in range(0, len(class_names)):
            class_priors_to_unit_interval_map[class_names[i]] = \
            (accumulated_interval, accumulated_interval+float(class_priors[i]))
            accumulated_interval += float(class_priors[i])
        if self._debug1:
            print("Mapping of class priors to unit interval:")
            print("\n")
            items = class_priors_to_unit_interval_map.items()
            for item in items:
                print(item[0] + " ===> " + str(item[1]))
        class_and_feature_based_value_priors_to_unit_interval_map = {}
        for class_name  in class_names:
            class_and_feature_based_value_priors_to_unit_interval_map[class_name] = {}
            for feature in features_and_values_dict.keys():
                class_and_feature_based_value_priors_to_unit_interval_map[class_name][feature] = {}
        for class_name  in class_names:
            for feature in features_and_values_dict.keys():
                values = features_and_values_dict[feature]
                if len(bias_dict[class_name][feature]) > 0:
                    bias_string = bias_dict[class_name][feature][0]
                else:
                    no_bias = 1.0 / len(values)
                    bias_string = values[0] +  "=" + str(no_bias)
                value_priors_to_unit_interval_map = {}
                splits = list( filter( None, re.split(r'\s*=\s*', bias_string) ) )
                chosen_for_bias_value = splits[0]
                chosen_bias = splits[1]
                remaining_bias = 1 - float(chosen_bias)
                remaining_portion_bias = remaining_bias / (len(values) -1)
                accumulated = 0
                for i in range(0, len(values)):
                    if (values[i] == chosen_for_bias_value):
                        value_priors_to_unit_interval_map[values[i]] = \
                          [accumulated, accumulated + float(chosen_bias)]
                        accumulated += float(chosen_bias)
                    else:
                        value_priors_to_unit_interval_map[values[i]] = \
                          [accumulated, accumulated + remaining_portion_bias]
                        accumulated += remaining_portion_bias
                class_and_feature_based_value_priors_to_unit_interval_map[class_name][feature] = value_priors_to_unit_interval_map
                if self._debug1:
                    print("\n")
                    print("For class " + class_name + \
                       ": Mapping feature value priors for feature '" + \
                       feature + "' to unit interval:")
                    print("\n")
                    items = value_priors_to_unit_interval_map.items()
                    for item in items:
                        print("    " + item[0] + " ===> " + str(item[1]))
        ele_index = 0
        import random
        ran = random.Random()
        while (ele_index < how_many_test_samples):
            sample_name = "sample_" + str(ele_index)
            test_sample_records[sample_name] = []
            # Generate the class label for this test sample:
            roll_the_dice = ran.randint(0,1000) / 1000.0
            class_label = ''
            for class_name  in class_priors_to_unit_interval_map.keys():
                v = class_priors_to_unit_interval_map[class_name]
                if ( (roll_the_dice >= v[0]) and (roll_the_dice <= v[1]) ):
                    test_sample_records[sample_name].append( 
                                             "class=" + class_name )
                    class_label = class_name
                    break
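            # Generate a value for each feature, conditioned on the class
            # label just chosen: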
            for feature in sorted(list(features_and_values_dict.keys())):
                roll_the_dice  = ran.randint(0,1000) / 1000.0
                value_label = ''
                value_priors_to_unit_interval_map = \
                  class_and_feature_based_value_priors_to_unit_interval_map[class_label][feature]
                for value_name in value_priors_to_unit_interval_map.keys():
                    v = value_priors_to_unit_interval_map[value_name]
                    if ( (roll_the_dice >= v[0]) and (roll_the_dice <= v[1]) ):
                        test_sample_records[sample_name].append( \
                                            feature + "=" + value_name )
                        value_label = value_name
                        break
            ele_index += 1
        self._test_sample_records = test_sample_records
        if self._debug1:
            print("\n\n")
            print("TERMINAL DISPLAY OF TEST RECORDS:")
            print("\n\n")
            sample_names = test_sample_records.keys()
            # Sort numerically on the integer suffix of each sample name:
            sample_names = sorted(sample_names, key=sample_index)
            for sample_name in sample_names:
                print(sample_name + " => " + \
                                 str(test_sample_records[sample_name]))

    def find_longest_value(self):
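        '''
        Returns the character length of the longest feature value; used to
        set the column width when writing out the test data.
        '''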
        features_and_values_dict = self._features_and_values_dict
        max_length = 0
        for feature in features_and_values_dict.keys():
            values = features_and_values_dict[feature]
            for value in values:
                if len(str(value)) > max_length:
                    max_length = len(str(value))
        return max_length

    def write_test_data_to_file(self):
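        '''
        Writes the generated test samples to the output test datafile, one
        row of feature values per sample, and writes the corresponding class
        labels to a separate file.  A data row would look something like
        (values purely illustrative):

            sample_0   heavy   never   ...

        with the columns ordered as announced on the "Feature Order For
        Data:" line.
        '''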
        features_and_values_dict = self._features_and_values_dict
        class_names = self._class_names
        output_file = self._output_test_datafile
        test_sample_records = self._test_sample_records
        try:
            FILE = open(self._output_test_datafile, 'w') 
        except IOError:
            print("Unable to open file: " + self._output_test_datafile)
            sys.exit(1)
        try:
            FILE2 = open(self._output_class_labels_file, 'w') 
        except IOError:
            print("Unable to open file: " + self._output_class_labels_file)
            sys.exit(1)
        header = '''
# REQUIRED LINE FOLLOWS (the first uncommented line below):
# The line shown below must begin with the string 
#
#             "Feature Order For Data:"  
#
# What comes after this string can be any number of feature labels.  
# The feature values shown in the table in the rest of the file will 
# be considered to be in the same order as shown on the next line.
                '''
        FILE.write(header + "\n\n\n")       
        title_string = "Feature Order For Data: "
        features = list(features_and_values_dict.keys())
        features.sort()
        for feature_name in features:
            title_string += str(feature_name) + " "
        title_string = title_string.rstrip()
        FILE.write(title_string + "\n\n")
        num_of_columns = len(features) + 1
        field_width = self.find_longest_value() + 2
        sample_names = test_sample_records.keys()
        # Sort numerically on the integer suffix of each sample name:
        sample_names = sorted(sample_names, key=sample_index)
        record_string = ''
        for sample_name in sample_names:
            sample_name_string = str.ljust(sample_name, field_width)
            record_string += sample_name_string
            record = test_sample_records[sample_name]
            item_parts_dict = {}
            for item in record:
                splits = list( filter(None, re.split(r'=', item)) )
                item_parts_dict[splits[0]] = splits[1]
            label_string = sample_name + " " + item_parts_dict["class"]
            FILE2.write(label_string + "\n")
            del item_parts_dict["class"]
            kees = list(item_parts_dict.keys())
            kees.sort()
            for kee in kees:
                record_string += str.ljust(item_parts_dict[kee], field_width)
            FILE.write(record_string + "\n")
            record_string = ''
        FILE.close()
        FILE2.close()

#-----------------------------   Class Node  ----------------------------

# The nodes of the decision tree are instances of this class:
class Node(object):
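    '''
    A Node holds the best feature test at that point in the tree, the
    entropy and the class probabilities associated with the node, the
    feature=value tests on the branch leading to it, and the links to its
    children.  Nodes are serial-numbered in the order of their creation,
    starting at 0 for the root.
    '''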

    nodes_created = -1

    def __init__(self, feature, entropy, class_probabilities, \
                                                branch_features_and_values):
        self._serial_number = self.get_next_serial_num()
        self._feature       = feature
        self._entropy       = entropy
        self._class_probabilities = class_probabilities
        self._branch_features_and_values = branch_features_and_values
        self._linked_to = []

    def get_next_serial_num(self):
        Node.nodes_created += 1
        return Node.nodes_created

    def get_serial_num(self):
        return self._serial_number

    # Note that this is a static method.  Since serial numbers begin at 0,
    # the number of nodes created so far is one more than the serial number
    # of the most recently created node:
    @staticmethod
    def how_many_nodes():
        return Node.nodes_created + 1

    # Returns the feature test at the current node:
    def get_feature(self):
        return self._feature

    def set_feature(self, feature):
        self._feature = feature

    def get_entropy(self):
        return self._entropy

    def get_class_probabilities(self):
        return self._class_probabilities

    def get_branch_features_and_values(self):
        return self._branch_features_and_values

    def add_to_branch_features_and_values(self, feature_and_value):
        self._branch_features_and_values.append(feature_and_value)

    def get_children(self):
        return self._linked_to

    def add_child_link(self, new_node):
        self._linked_to.append(new_node)                  

    def delete_all_links(self):
        self._linked_to = []

    def display_node(self):
        feature_at_node = self.get_feature() or " "
        entropy_at_node = self.get_entropy()
        class_probabilities = self.get_class_probabilities()
        serial_num = self.get_serial_num()
        branch_features_and_values = self.get_branch_features_and_values()
        print("\n\nNODE " + str(serial_num) + ":\n   Branch features and values to this \
node: " + str(branch_features_and_values) + "\n   Class probabilities at \
current node: " + str(class_probabilities) + "\n   Entropy at current \
node: " + str(entropy_at_node) + "\n   Best feature test at current \
node: " + feature_at_node + "\n\n")

    def display_decision_tree(self, offset):
        serial_num = self.get_serial_num()
        if len(self.get_children()) > 0:
            feature_at_node = self.get_feature() or " "
            entropy_at_node = self.get_entropy()
            class_probabilities = self.get_class_probabilities()
            print("NODE " + str(serial_num) + ":  " + offset +  "feature: " + feature_at_node \
+ "   entropy: " + str(entropy_at_node) + "   class probs: " + str(class_probabilities) + "\n")
            offset += "   "
            for child in self.get_children():
                child.display_decision_tree(offset)
        else:
            entropy_at_node = self.get_entropy()
            class_probabilities = self.get_class_probabilities()
            print("NODE " + str(serial_num) + ": " + offset + "   entropy: " \
+ str(entropy_at_node) + "    class probs: " + str(class_probabilities) + "\n")


#----------------------------  Test Code Follows  -----------------------

if __name__ == '__main__':
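    # A quick demonstration: build a decision tree from 'training.dat',
    # inspect a few class and feature probabilities, display the tree,
    # and classify two test samples.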

    dt = DecisionTree( training_datafile = "training.dat",
                       max_depth_desired = 2,
                       entropy_threshold = 0.1,
                       debug1 = 1,
                     )
    dt.get_training_data()

    dt.show_training_data()

    prob = dt.prior_probability_for_class( 'benign' )
    print("prior for benign: ", prob)
    prob = dt.prior_probability_for_class( 'malignant' )
    print("prior for malignant: ", prob)

    prob = dt.probability_for_feature_value( 'smoking', 'heavy')
    print(prob)

    dt.determine_data_condition()

    root_node = dt.construct_decision_tree_classifier()
    root_node.display_decision_tree("   ")

    test_sample = ['exercising=>never', 'smoking=>heavy', 'fatIntake=>heavy', 'videoAddiction=>heavy']
    classification = dt.classify(root_node, test_sample)
    print("Classification: ", classification)

    test_sample = ['videoAddiction=>none', 'exercising=>occasionally', 'smoking=>never', 'fatIntake=>medium']
    classification = dt.classify(root_node, test_sample)
    print("Classification: ", classification)

    print("Number of nodes created: ", root_node.how_many_nodes())