# -*- coding: utf-8 -*-

__version__   = '2.5.5'
__author__    = "Avinash Kak (kak@purdue.edu)"
__date__      = '2025-May-28'                   
__url__       = 'https://engineering.purdue.edu/kak/distDLS/DLStudio-2.5.5.html'
__copyright__ = "(C) 2025 Avinash Kak. Python Software Foundation."


import sys,os,os.path,glob
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision                  
import torchvision.transforms as tvt
import torch.optim as optim
import numpy as np
from PIL import ImageFilter
from PIL import Image
import numbers
import re
import math
import random
import copy
import matplotlib.pyplot as plt
import gzip
import pickle
import pymsgbox
import time
import logging

import torchmetrics                                                             ##  for VQGAN
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity       ##  for VQGAN


## Python does not have a built-in decorator for declaring static variables inside a function,
## but you can use the following for achieving the same effect.  I believe I first saw it at
## stackoverflow.com:
def static_var(varname, value):
    def decorate(func):
        setattr(func, varname, value)
        return func
    return decorate
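
## A quick illustration of how the decorator is used (not part of the module's API): the
## decorator attaches the named attribute to the function object, so the value persists
## across calls in the manner of a C static variable:
##
##     @static_var("counter", 0)
##     def tick():
##         tick.counter += 1
##         return tick.counter
##
##     tick(); tick()            ## tick.counter is now 2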

#______________________________  DLStudio Class Definition  ________________________________

class DLStudio(object):

    def __init__(self, *args, **kwargs ):
        if args:
            raise ValueError(  
                   '''DLStudio constructor can only be called with keyword arguments for 
                      the following keywords: epochs, learning_rate, batch_size, momentum,
                      convo_layers_config, image_size, dataroot, path_saved_model, classes, 
                      fc_layers_config, debug_train, use_gpu, and debug_test''')
        learning_rate = epochs = batch_size = convo_layers_config = momentum = None
        image_size = fc_layers_config = dataroot =  path_saved_model = classes = use_gpu = None
        debug_train  = debug_test = None
        if 'dataroot' in kwargs                      :   dataroot = kwargs.pop('dataroot')
        if 'learning_rate' in kwargs                 :   learning_rate = kwargs.pop('learning_rate')
        if 'momentum' in kwargs                      :   momentum = kwargs.pop('momentum')
        if 'epochs' in kwargs                        :   epochs = kwargs.pop('epochs')
        if 'batch_size' in kwargs                    :   batch_size = kwargs.pop('batch_size')
        if 'convo_layers_config' in kwargs           :   convo_layers_config = kwargs.pop('convo_layers_config')
        if 'image_size' in kwargs                    :   image_size = kwargs.pop('image_size')
        if 'fc_layers_config' in kwargs              :   fc_layers_config = kwargs.pop('fc_layers_config')
        if 'path_saved_model' in kwargs              :   path_saved_model = kwargs.pop('path_saved_model')
        if 'classes' in kwargs                       :   classes = kwargs.pop('classes') 
        if 'use_gpu' in kwargs                       :   use_gpu = kwargs.pop('use_gpu') 
        if 'debug_train' in kwargs                   :   debug_train = kwargs.pop('debug_train') 
        if 'debug_test' in kwargs                    :   debug_test = kwargs.pop('debug_test') 
        if len(kwargs) != 0: raise ValueError('''You have provided unrecognizable keyword args''')
        if dataroot:
            self.dataroot = dataroot
        if convo_layers_config:
            self.convo_layers_config = convo_layers_config
        if image_size:
            self.image_size = image_size
        if fc_layers_config:
            self.fc_layers_config = fc_layers_config
            if fc_layers_config[0] != -1:
                raise Exception("""\n\n\nYour 'fc_layers_config' construction option is not correct. """
                                """The first element of the list of nodes in the fc layer must be -1 """
                                """because the input to fc will be set automatically to the size of """
                                """the final activation volume of the convolutional part of the network""")
        if  path_saved_model:
            self.path_saved_model = path_saved_model
        if classes:
            self.class_labels = classes
        if learning_rate:
            self.learning_rate = learning_rate
        else:
            self.learning_rate = 1e-6
        if momentum:
            self.momentum = momentum
        if epochs:
            self.epochs = epochs
        if batch_size:
            self.batch_size = batch_size
        if use_gpu is not None:
            self.use_gpu = use_gpu
            if use_gpu is True and torch.cuda.is_available():
                self.device = torch.device("cuda:0")
            else:
                self.device = torch.device("cpu")
        else:
            self.device = torch.device("cpu")
        if debug_train:                             
            self.debug_train = debug_train
        else:
            self.debug_train = 0
        if debug_test:                             
            self.debug_test = debug_test
        else:
            self.debug_test = 0
        self.debug_config = 0

    def parse_config_string_for_convo_layers(self):
        '''
        Each collection of 'n' otherwise identical layers in a convolutional network is 
        specified by a string that looks like:

                                 "nx[a,b,c,d]-MaxPool(k)"
        where 
                n      =  num of this type of convo layer
                a      =  number of out_channels                      [in_channels determined by prev layer] 
                b,c    =  kernel for this layer is of size (b,c)      [b along height, c along width]
                d      =  stride for convolutions
                k      =  maxpooling over kxk patches with stride of k

        Example:
                     "n1x[a1,b1,c1,d1]-MaxPool(k1)  n2x[a2,b2,c2,d2]-MaxPool(k2)"
        '''
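        ##  A minimal usage sketch (the option values below are only for illustration):
        ##
        ##      dls = DLStudio( dataroot="./data/", image_size=[32,32],
        ##                      convo_layers_config="1x[16,3,3,1]-MaxPool(2)  1x[32,3,3,1]-MaxPool(2)",
        ##                      fc_layers_config=[-1, 100, 10], path_saved_model="./saved_model",
        ##                      classes=('cat','dog'), epochs=2, batch_size=4,
        ##                      learning_rate=1e-3, momentum=0.9, use_gpu=True )
        ##      configs = dls.parse_config_string_for_convo_layers()
        ##
        ##  For a 32x32 input, each "-MaxPool(2)" halves the spatial size, so the final activation
        ##  volume consists of 32 channels of 8x8 maps and fc_layers_config[0] gets set to
        ##  32 * 8 * 8 = 2048.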
        configuration = self.convo_layers_config
        configs = configuration.split()
        all_convo_layers = []
        image_size_after_layer = self.image_size
        for k,config in enumerate(configs):
            two_parts = config.split('-')
            how_many_conv_layers_with_this_config = int(two_parts[0][:config.index('x')])
            if self.debug_config:
                print("\n\nhow many convo layers with this config: %d" % how_many_conv_layers_with_this_config)
            maxpooling_size = int(re.findall(r'\d+', two_parts[1])[0])
            if self.debug_config:
                print("\nmax pooling size for all convo layers with this config: %d" % maxpooling_size)
            for conv_layer in range(how_many_conv_layers_with_this_config):            
                convo_layer = {'out_channels':None, 
                               'kernel_size':None, 
                               'convo_stride':None, 
                               'maxpool_size':None,
                               'maxpool_stride': None}
                kernel_params = two_parts[0][config.index('x')+1:][1:-1].split(',')
                if self.debug_config:
                    print("\nkernel_params: %s" % str(kernel_params))
                convo_layer['out_channels'] = int(kernel_params[0])
                convo_layer['kernel_size'] = (int(kernel_params[1]), int(kernel_params[2]))
                convo_layer['convo_stride'] =  int(kernel_params[3])
                image_size_after_layer = [x // convo_layer['convo_stride'] for x in image_size_after_layer]
                convo_layer['maxpool_size'] = maxpooling_size
                convo_layer['maxpool_stride'] = maxpooling_size
                image_size_after_layer = [x // convo_layer['maxpool_size'] for x in image_size_after_layer]
                all_convo_layers.append(convo_layer)
        configs_for_all_convo_layers = {i : all_convo_layers[i] for i in range(len(all_convo_layers))}
        if self.debug_config:
            print("\n\nAll convo layers: %s" % str(configs_for_all_convo_layers))
        last_convo_layer = configs_for_all_convo_layers[len(all_convo_layers)-1]
        out_nodes_final_layer = image_size_after_layer[0] * image_size_after_layer[1] * \
                                                                      last_convo_layer['out_channels']
        self.fc_layers_config[0] = out_nodes_final_layer
        self.configs_for_all_convo_layers = configs_for_all_convo_layers
        return configs_for_all_convo_layers


    def build_convo_layers(self, configs_for_all_convo_layers):
        conv_layers = nn.ModuleList()
        in_channels_for_next_layer = None
        for layer_index in configs_for_all_convo_layers:
            if self.debug_config:
                print("\n\n\nLayer index: %d" % layer_index)
            in_channels = 3 if layer_index == 0 else in_channels_for_next_layer
            out_channels = configs_for_all_convo_layers[layer_index]['out_channels']
            kernel_size = configs_for_all_convo_layers[layer_index]['kernel_size']
            padding = tuple((k-1) // 2 for k in kernel_size)
            stride       = configs_for_all_convo_layers[layer_index]['convo_stride']
            maxpool_size = configs_for_all_convo_layers[layer_index]['maxpool_size']
            if self.debug_config:
                print("\n     in_channels=%d   out_channels=%d    kernel_size=%s     stride=%s    \
                maxpool_size=%s" % (in_channels, out_channels, str(kernel_size), str(stride), 
                str(maxpool_size)))
            conv_layers.append( nn.Conv2d( in_channels,out_channels,kernel_size,stride=stride,padding=padding) )
            conv_layers.append( nn.MaxPool2d( maxpool_size ) )
            conv_layers.append( nn.ReLU() )
            in_channels_for_next_layer = out_channels
        return conv_layers

    def build_fc_layers(self):
        fc_layers = nn.ModuleList()
        for layer_index in range(len(self.fc_layers_config) - 1):
            fc_layers.append( nn.Linear( self.fc_layers_config[layer_index], 
                                                                self.fc_layers_config[layer_index+1] ) )
        return fc_layers            

    def load_cifar_10_dataset(self):       
        '''
        In the code shown below, the call to "ToTensor()" converts the usual integer range of 0 to 255
        for the pixel values into floats in the range [0.0, 1.0], and the subsequent call to "Normalize()"
        changes that range to [-1.0, 1.0]. For additional explanation of the call to "tvt.ToTensor()",
        see Slide 31 of my Week 2 slides at the DL course website.  And see Slides 32 and 33 for the
        syntax "tvt.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))".  In this call, the first tuple supplies
        the means for the three color channels and the second tuple the corresponding standard
        deviations, which are applied according to the formula:

                 image_channel_val = (image_channel_val - mean) / std

        The end result is that the values in the image tensor will be normalized to fall between -1.0 
        and +1.0. If needed, we can undo the normalization with

                 image_channel_val  =   (image_channel_val * std) + mean
        '''
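        ##  For instance, to undo tvt.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)) on a batch 'imgs'
        ##  of shape (B,3,H,W) for display purposes, you could use (an illustrative sketch, not
        ##  called anywhere in this module):
        ##
        ##      unnormalized = imgs * 0.5 + 0.5        ## maps [-1.0, 1.0] back to [0.0, 1.0]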

        ##   ToTensor() maps the pixel values to [0.0, 1.0] and Normalize() then maps them to [-1.0, 1.0]:
        transform = tvt.Compose([tvt.ToTensor(),
                                 tvt.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])    ## accuracy: 51%
        ##  Define where the training and the test datasets are located:
        train_data_loc = torchvision.datasets.CIFAR10(root=self.dataroot, train=True, download=True, transform=transform)
        test_data_loc = torchvision.datasets.CIFAR10(root=self.dataroot, train=False, download=True, transform=transform)
        ##  Now create the data loaders:
        self.train_data_loader = torch.utils.data.DataLoader(train_data_loc,batch_size=self.batch_size, shuffle=True, num_workers=2)
        self.test_data_loader = torch.utils.data.DataLoader(test_data_loc,batch_size=self.batch_size, shuffle=False, num_workers=2)

    def load_cifar_10_dataset_with_augmentation(self):             
        '''
        In general, we want to do data augmentation for training:
        '''
        transform_train = tvt.Compose([
                                  tvt.RandomCrop(32, padding=4),
                                  tvt.RandomHorizontalFlip(),
                                  tvt.ToTensor(),
                                  tvt.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])        
        ##  Don't need any augmentation for the test data: 
        transform_test = tvt.Compose([
                               tvt.ToTensor(),
                               tvt.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        ##  Define where the training and the test datasets are located
        train_data_loc = torchvision.datasets.CIFAR10( root=self.dataroot, train=True, download=True, transform=transform_train )
        test_data_loc = torchvision.datasets.CIFAR10(  root=self.dataroot, train=False, download=True, transform=transform_test )
        ##  Now create the data loaders:
        self.train_data_loader = torch.utils.data.DataLoader(train_data_loc, batch_size=self.batch_size, shuffle=True, num_workers=2)
        self.test_data_loader = torch.utils.data.DataLoader(test_data_loc, batch_size=self.batch_size, shuffle=False, num_workers=2)

    def imshow(self, img):
        '''
        called by display_tensor_as_image() for displaying the image
        '''
        img = img / 2 + 0.5     # unnormalize
        npimg = img.numpy()
        plt.imshow(np.transpose(npimg, (1, 2, 0)))
        plt.show()

    class Net(nn.Module):
        def __init__(self, convo_layers, fc_layers):
            super(DLStudio.Net, self).__init__()
            self.my_modules_convo = convo_layers
            self.my_modules_fc = fc_layers
        def forward(self, x):
            for m in self.my_modules_convo:
                x = m(x)
            x = x.view(x.shape[0], -1)
            for m in self.my_modules_fc:
                x = m(x)
            return x


    def run_code_for_training(self, net, display_images=False):        
        filename_for_out = "performance_numbers_" + str(self.epochs) + ".txt"
        FILE = open(filename_for_out, 'w')
        net = copy.deepcopy(net)
        net = net.to(self.device)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(net.parameters(), lr=self.learning_rate, momentum=self.momentum)
        print("\n\nStarting training loop...")
        start_time = time.perf_counter()
        loss_tally = []
        elapsed_time = 0.0
        for epoch in range(self.epochs):  
            print("")
            running_loss = 0.0
            for i, data in enumerate(self.train_data_loader):
                inputs, labels = data
                if i % 1000 == 999:
                    current_time = time.perf_counter()
                    elapsed_time = current_time - start_time 
                    print("\n\n[epoch:%d/%d  iter=%4d  elapsed_time=%5d secs]   Ground Truth:     " % 
                          (epoch+1, self.epochs, i+1, elapsed_time) + 
                          ' '.join('%10s' % self.class_labels[labels[j]] for j in range(self.batch_size)))
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)
                ##  Since PyTorch accumulates gradients across calls to backward(), we need to
                ##  zero out the previously calculated gradients for the learnable parameters:
                optimizer.zero_grad()
                outputs = net(inputs)
                loss = criterion(outputs, labels)
                running_loss += loss.item()
                if i % 1000 == 999:
                    _, predicted = torch.max(outputs.data, 1)
                    print("[epoch:%d/%d  iter=%4d  elapsed_time=%5d secs]   Predicted Labels: " % 
                     (epoch+1, self.epochs, i+1, elapsed_time ) +
                     ' '.join('%10s' % self.class_labels[predicted[j]] for j in range(self.batch_size)))
                    avg_loss = running_loss / float(1000)
                    loss_tally.append(avg_loss)
                    print("[epoch:%d/%d  iter=%4d  elapsed_time=%5d secs]   Loss: %.3f" % 
                                                                   (epoch+1, self.epochs, i+1, elapsed_time, avg_loss))    
                    FILE.write("%.3f\n" % avg_loss)
                    FILE.flush()
                    running_loss = 0.0
                    if display_images:
                        logger = logging.getLogger()
                        old_level = logger.level
                        logger.setLevel(100)
                        plt.figure(figsize=[6,3])
                        plt.imshow(np.transpose(torchvision.utils.make_grid(inputs,  normalize=False, padding=3, pad_value=255).cpu(), (1,2,0)))
                        plt.show()
                        logger.setLevel(old_level)
                loss.backward()
                optimizer.step()
        print("\nFinished Training\n")
        self.save_model(net)
        plt.figure(figsize=(10,5))
        plt.title("Labeling Loss vs. Iterations")
        plt.plot(loss_tally)
        plt.xlabel("iterations")
        plt.ylabel("loss")
        plt.legend(["Plot of loss versus iterations"], fontsize="x-large")
        plt.savefig("loss_vs_iterations.png")
        plt.show()


    def display_tensor_as_image(self, tensor, title=""):
        '''
        This method converts the argument tensor into a photo image that you can display
        in your terminal screen. It can convert tensors of three different shapes
        into images: (3,H,W), (1,H,W), and (H,W), where H, for height, stands for the
        number of pixels in the vertical direction and W, for width, for the same
        along the horizontal direction.  When the first element of the shape is 3,
        that means that the tensor represents a color image in which each pixel in
        the (H,W) plane has three values for the three color channels.  On the other
        hand, when the first element is 1, that stands for a tensor that will be
        shown as a grayscale image.  And when the shape is just (H,W), that is
        automatically taken to be for a grayscale image.
        '''
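        ##  A typical usage sketch (assuming 'dls' is a DLStudio instance on which
        ##  load_cifar_10_dataset() has already been called):
        ##
        ##      images, labels = next(iter(dls.train_data_loader))
        ##      dls.display_tensor_as_image(images[0], "first image in the batch")
        ##
        ##  Here images[0] has shape (3,32,32) and its values lie in [-1.0, 1.0] because of the
        ##  Normalize() transform.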
        tensor_range = (torch.min(tensor).item(), torch.max(tensor).item())
        if tensor_range == (-1.0,1.0):
            ##  The tensors must be between 0.0 and 1.0 for the display:
            print("\n\n\nimage un-normalization called")
            tensor = tensor/2.0 + 0.5     # unnormalize
        plt.figure(title)
        ###  The call to plt.imshow() shown below needs a numpy array. We must also
        ###  transpose the array so that the number of channels (the same thing as the
        ###  number of color planes) is in the last element.  For a tensor, it would be in
        ###  the first element.
        if tensor.shape[0] == 3 and len(tensor.shape) == 3:
            plt.imshow( tensor.numpy().transpose(1,2,0) )
        ###  If the grayscale image was produced by calling torchvision.transform's
        ###  ".ToPILImage()", and the result converted to a tensor, the tensor shape will
        ###  again have three elements in it, however the first element that stands for
        ###  the number of channels will now be 1
        elif tensor.shape[0] == 1 and len(tensor.shape) == 3:
            tensor = tensor[0,:,:]
            plt.imshow( tensor.numpy(), cmap = 'gray' )
        ###  For any one color channel extracted from the tensor representation of a color
        ###  image, the shape of the tensor will be (H,W):
        elif len(tensor.shape) == 2:
            plt.imshow( tensor.numpy(), cmap = 'gray' )
        else:
            sys.exit("\n\n\nfrom 'display_tensor_as_image()': tensor for image is ill formed -- aborting")
        plt.show()

    def check_a_sampling_of_images(self):
        '''
        Displays the first batch_size number of images in your dataset.
        '''
        dataiter = iter(self.train_data_loader)
        images, labels = next(dataiter)
        # Since negative pixel values make no sense for display, setting the 'normalize' 
        # option to True will change the range back from (-1.0,1.0) to (0.0,1.0):
        self.display_tensor_as_image(torchvision.utils.make_grid(images, normalize=True))
        # Print class labels for the images shown:
        print(' '.join('%5s' % self.class_labels[labels[j]] for j in range(self.batch_size)))

    def save_model(self, model):
        '''
        Save the trained model to a disk file
        '''
        torch.save(model.state_dict(), self.path_saved_model)


    def run_code_for_testing(self, net, display_images=False):
        net.load_state_dict(torch.load(self.path_saved_model))
        net = net.eval()
        net = net.to(self.device)
        ##  In what follows, in addition to determining the predicted label for each test
        ##  image, we will also compute some stats to measure the overall performance of
        ##  the trained network.  This we will do in two different ways: For each class,
        ##  we will measure how frequently the network predicts the correct labels.  In
        ##  addition, we will compute the confusion matrix for the predictions.
        filename_for_results = "classification_results_" + str(self.epochs) + ".txt"
        FILE = open(filename_for_results, 'w')
        correct = 0
        total = 0
        confusion_matrix = torch.zeros(len(self.class_labels), len(self.class_labels))
        class_correct = [0] * len(self.class_labels)
        class_total = [0] * len(self.class_labels)
        with torch.no_grad():
            for i,data in enumerate(self.test_data_loader):
                ##  data is set to the images and the labels for one batch at a time:
                images, labels = data
                images = images.to(self.device)
                labels = labels.to(self.device)
                if i % 1000 == 999:
                    print("\n\n[i=%d:] Ground Truth:     " % (i+1) + ' '.join('%5s' % self.class_labels[labels[j]] for j in range(self.batch_size)))
                outputs = net(images)
                ##  max() returns two things: the max value and its index in the 10 element
                ##  output vector.  We are only interested in the index --- since that is 
                ##  essentially the predicted class label:
                _, predicted = torch.max(outputs.data, 1)
                if i % 1000 == 999:
                    print("[i=%d:] Predicted Labels: " % (i+1) + ' '.join('%5s' % self.class_labels[predicted[j]] for j in range(self.batch_size)))
                    logger = logging.getLogger()
                    old_level = logger.level
                    if display_images:
                        logger.setLevel(100)
                        plt.figure(figsize=[6,3])
                        plt.imshow(np.transpose(torchvision.utils.make_grid(images, normalize=False, padding=3, pad_value=255).cpu(), (1,2,0)))
                        plt.show()
                        logger.setLevel(old_level)
                for label,prediction in zip(labels,predicted):
                        confusion_matrix[label][prediction] += 1
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                ##  comp is a tensor of batch_size boolean values --- True where the prediction
                ##  matches the ground-truth label and False otherwise:
                comp = predicted == labels       
                for j in range(self.batch_size):
                    label = labels[j]
                    ##  The following works because, in a numeric context, Python treats the
                    ##  boolean False as 0 and True as 1.  For that reason "4 + True" evaluates
                    ##  to 5 and "4 + False" to 4.  Also, "1 == True" evaluates to True, whereas
                    ##  "1 is True" evaluates to False because the "is" operator tests identity
                    ##  rather than providing a numeric context.  In the statement that follows,
                    ##  comp[j].item() returns either False or True, and the addition operator
                    ##  treats those values as 0 and 1.
                    class_correct[label] += comp[j].item()
                    class_total[label] += 1
        for j in range(len(self.class_labels)):
            print('Prediction accuracy for %5s : %2d %%' % (self.class_labels[j], 100 * class_correct[j] / class_total[j]))
            FILE.write('\n\nPrediction accuracy for %5s : %2d %%\n' % (self.class_labels[j], 100 * class_correct[j] / class_total[j]))
        print("\n\n\nOverall accuracy of the network on the 10000 test images: %d %%" % (100 * correct / float(total)))
        FILE.write("\n\n\nOverall accuracy of the network on the 10000 test images: %d %%\n" % (100 * correct / float(total)))
        print("\n\nDisplaying the confusion matrix:\n")
        FILE.write("\n\nDisplaying the confusion matrix:\n\n")
        out_str = "         "
        for j in range(len(self.class_labels)):  out_str +=  "%7s" % self.class_labels[j]   
        print(out_str + "\n")
        FILE.write(out_str + "\n\n")
        for i,label in enumerate(self.class_labels):
            out_percents = [100 * confusion_matrix[i,j] / float(class_total[i]) 
                                                      for j in range(len(self.class_labels))]
            out_percents = ["%.2f" % item.item() for item in out_percents]
            out_str = "%6s:  " % self.class_labels[i]
            for j in range(len(self.class_labels)): out_str +=  "%7s" % out_percents[j]
            print(out_str)
            FILE.write(out_str + "\n")
        FILE.close()        


    ###%%%
    #####################################################################################################################
    #############################  Start Definition of Inner Class ExperimentsWithSequential ############################

    class ExperimentsWithSequential(nn.Module):                                
        """
        Demonstrates how to use the torch.nn.Sequential container class

        Class Path:  DLStudio  ->  ExperimentsWithSequential    
        """
        def __init__(self, dl_studio ):
            super(DLStudio.ExperimentsWithSequential, self).__init__()
            self.dl_studio = dl_studio

        def load_cifar_10_dataset(self):       
            self.dl_studio.load_cifar_10_dataset()

        def load_cifar_10_dataset_with_augmentation(self):             
            self.dl_studio.load_cifar_10_dataset_with_augmentation()

        class Net(nn.Module):
            """
            To see if the DLStudio class would work with any network that a user may
            want to experiment with, I copy-and-pasted the network shown below
            from Zhenye's GitHub blog: https://github.com/Zhenye-Na/blog

            Class Path:  DLStudio  ->  ExperimentsWithSequential  ->  Net
            """
            def __init__(self):
                super(DLStudio.ExperimentsWithSequential.Net, self).__init__()
                self.conv_seqn = nn.Sequential(
                    # Conv Layer block 1:
                    nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1),
                    nn.BatchNorm2d(32),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
                    nn.ReLU(inplace=True),
                    nn.MaxPool2d(kernel_size=2, stride=2),
                    # Conv Layer block 2:
                    nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
                    nn.BatchNorm2d(128),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1),
                    nn.ReLU(inplace=True),
                    nn.MaxPool2d(kernel_size=2, stride=2),
                    nn.Dropout2d(p=0.05),
                    # Conv Layer block 3:
                    nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1),
                    nn.BatchNorm2d(256),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
                    nn.ReLU(inplace=True),
                    nn.MaxPool2d(kernel_size=2, stride=2),
                )
                self.fc_seqn = nn.Sequential(
                    nn.Dropout(p=0.1),
                    nn.Linear(4096, 1024),
                    nn.ReLU(inplace=True),
                    nn.Linear(1024, 512),
                    nn.ReLU(inplace=True),
                    nn.Dropout(p=0.1),
                    nn.Linear(512, 10)
                )
    
            def forward(self, x):
                x = self.conv_seqn(x)
                # flatten
                x = x.view(x.shape[0], -1)
                x = self.fc_seqn(x)
                return x

        def run_code_for_training(self, net):        
            self.dl_studio.run_code_for_training(net)

        def save_model(self, model):
            '''
            Save the trained model to a disk file
            '''
            torch.save(model.state_dict(), self.dl_studio.path_saved_model)

        def run_code_for_testing(self, model):
            self.dl_studio.run_code_for_testing(model)


    ###%%%
    #####################################################################################################################
    ###############################  Start Definition of Inner Class ExperimentsWithCIFAR ###############################

    class ExperimentsWithCIFAR(nn.Module):              
        """
        Class Path:  DLStudio  ->  ExperimentsWithCIFAR
        """

        def __init__(self, dl_studio ):
            super(DLStudio.ExperimentsWithCIFAR, self).__init__()
            self.dl_studio = dl_studio

        def load_cifar_10_dataset(self):       
            self.dl_studio.load_cifar_10_dataset()

        def load_cifar_10_dataset_with_augmentation(self):             
            self.dl_studio.load_cifar_10_dataset_with_augmentation()

        ##  You can instantiate two different types of networks when experimenting with 
        ##  the inner class ExperimentsWithCIFAR.  The network shown below is from the 
        ##  PyTorch tutorial
        ##
        ##     https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
        ##
        class Net(nn.Module):
            """
            Class Path:  DLStudio  ->  ExperimentsWithCIFAR  ->  Net
            """
            def __init__(self):
                super(DLStudio.ExperimentsWithCIFAR.Net, self).__init__()
                self.conv1 = nn.Conv2d(3, 6, 5)
                self.conv2 = nn.Conv2d(6, 16, 5)
                self.fc1 = nn.Linear(16 * 5 * 5, 120)
                self.fc2 = nn.Linear(120, 84)
                self.fc3 = nn.Linear(84, 10)
        
            def forward(self, x):
                x = nn.MaxPool2d(2,2)(F.relu(self.conv1(x)))
                x = nn.MaxPool2d(2,2)(F.relu(self.conv2(x)))
                x  =  x.view( x.shape[0], - 1 )
                x = F.relu(self.fc1(x))
                x = F.relu(self.fc2(x))
                x = self.fc3(x)
                return x

        ##  Instead of using the network shown above, you can also use the network shown below
        ##  if you are playing with the ExperimentsWithCIFAR inner class. If that's what you
        ##  want to do, in the script "playing_with_cifar10.py" in the Examples directory,
        ##  you will need to replace the statement
        ##                          model = exp_cifar.Net()
        ##  by the statement
        ##                          model = exp_cifar.Net2()        
        ##
        class Net2(nn.Module):
            """
            Class Path:  DLStudio  ->  ExperimentsWithCIFAR  ->  Net2
            """
            def __init__(self):
                """
                I created this network class just to see if it was possible to simply calculate
                the size of the first of the fully connected layers from strides in the convo
                layers up to that point and from the out_channels used in the top-most convo 
                layer.   In what you see below, I am keeping track of all the strides by pushing 
                them into the array 'strides'.  Subsequently, in the formula shown in line (A),
                I use the product of all strides and the number of out_channels for the topmost
                layer to compute the size of the first fully-connected layer.
                """
                super(DLStudio.ExperimentsWithCIFAR.Net2, self).__init__()
                self.relu = nn.ReLU()
                strides = []
                patch_size = 2
                ## conv1:
                out_ch, ker_size, conv_stride, pool_stride = 128,5,1,2
                self.conv1 = nn.Conv2d(3, out_ch, (ker_size,ker_size), padding=(ker_size-1)//2)     
                self.pool1 = nn.MaxPool2d(patch_size, pool_stride)     
                strides += (conv_stride, pool_stride)
                ## conv2:
                in_ch = out_ch
                out_ch, ker_size, conv_stride, pool_stride = 128,3,1,2
                self.conv2 = nn.Conv2d(in_ch, out_ch, ker_size, padding=(ker_size-1)//2)
                self.pool2 = nn.MaxPool2d(patch_size, pool_stride)     
                strides += (conv_stride, pool_stride)
                ## conv3:                   
                in_ch = out_ch
                out_ch, ker_size, conv_stride, pool_stride = in_ch,2,1,1
                self.conv3 = nn.Conv2d(in_ch, out_ch, ker_size, padding=1)
                self.pool3 = nn.MaxPool2d(patch_size, pool_stride)         
                ## figure out the number of nodes needed for entry into fc:
                in_size_for_fc = out_ch * (32 // np.prod(strides)) ** 2                    ## (A)
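                ##  Worked example for line (A) with the settings above:  strides = [1,2,1,2],
                ##  so np.prod(strides) = 4 and the 32x32 CIFAR-10 input is reduced to 8x8
                ##  (conv3 with padding=1 takes 8x8 to 9x9, and the stride-1 pool3 brings it
                ##  back to 8x8).  With out_ch = 128, that gives in_size_for_fc = 128*8*8 = 8192.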
                self.in_size_for_fc = in_size_for_fc
                self.fc1 = nn.Linear(in_size_for_fc, 150)
                self.fc2 = nn.Linear(150, 100)
                self.fc3 = nn.Linear(100, 10)
        
            def forward(self, x):
                ##  We know that forward() begins its work with x shaped as (4,3,32,32), where
                ##  4 is the batch size, 3 the number of in_channels, and 32x32 the input image size.
                x = self.relu(self.conv1(x))  
                x = self.pool1(x)             
                x = self.relu(self.conv2(x))
                x = self.pool2(x)             
                x = self.pool3(self.relu(self.conv3(x)))
                x  =  x.view( x.shape[0], - 1 )
                x = self.relu(self.fc1( x ))
                x = self.relu(self.fc2( x ))
                x = self.fc3(x)
                return x

        def run_code_for_training(self, net, display_images=False):
            self.dl_studio.run_code_for_training(net, display_images)
            
        def save_model(self, model):
            '''
            Save the trained model to a disk file
            '''
            torch.save(model.state_dict(), self.dl_studio.path_saved_model)

        def run_code_for_testing(self, model, display_images=False):
            self.dl_studio.run_code_for_testing(model, display_images)


    ###%%%
    #####################################################################################################################
    #######################  Start Definition of Inner Class BMEnet for Illustrating Skip Connections  ##################

    class BMEnet(nn.Module):
        """
        This educational class is meant for illustrating the concepts related to the
        use of skip connections in a neural network.  It is now well known that deep
        networks are difficult to train because of the vanishing-gradients problem.
        What that means is that as the depth of a network increases, the loss gradients
        calculated for the early layers become more and more muted, which suppresses
        the learning of the parameters in those layers.  An important mitigation
        strategy for this problem consists of constructing the CNN from blocks that
        have skip connections.

        With the code shown in this inner class of the module, you can now experiment with
        skip connections in a CNN to see how a deep network with this feature might improve
        the classification results.  As you will see in the code shown below, the network
        that allows you to construct a CNN with skip connections is named BMEnet.  As shown
        in the script playing_with_skip_connections.py in the Examples directory of the
        distribution, you can easily create a CNN with arbitrary depth just by using the
        "depth" constructor option for the BMEnet class.  The basic block of the network
        constructed by BMEnet is called SkipBlock which, very much like the BasicBlock in
        ResNet-18, has a couple of convolutional layers whose output is combined with the
        input to the block.
    
        Note that the value given to the "depth" constructor option for the BMEnet class
        does NOT translate directly into the actual depth of the CNN. [Again, see the script
        playing_with_skip_connections.py in the Examples directory for how to use this
        option.] The value of "depth" specifies how many SkipBlock instances with the same
        input and output channels (and the same input and output sizes) are stacked between
        successive downsampling and channel-doubling instances of SkipBlock.
 
        Class Path: DLStudio -> BMEnet
        """
        def __init__(self, dl_studio, skip_connections=True, depth=8):
            super(DLStudio.BMEnet, self).__init__()
            self.dl_studio = dl_studio
            self.depth = depth
            image_size = dl_studio.image_size
            num_ds = 0                                 ## num_ds stands for number of downsampling steps
            self.conv = nn.Conv2d(3, 64, 3, padding=1)
            self.skip64_arr = nn.ModuleList()
            for i in range(self.depth):
                self.skip64_arr.append(DLStudio.BMEnet.SkipBlock(64, 64, skip_connections=skip_connections))
            self.skip64to128ds = DLStudio.BMEnet.SkipBlock(64, 128, downsample=True, skip_connections=skip_connections )
            num_ds += 1              
            self.skip128_arr = nn.ModuleList()
            for i in range(self.depth):
                self.skip128_arr.append(DLStudio.BMEnet.SkipBlock(128, 128, skip_connections=skip_connections))
            self.skip128to256ds = DLStudio.BMEnet.SkipBlock(128, 256, downsample=True, skip_connections=skip_connections )
            num_ds += 1
            self.skip256_arr = nn.ModuleList()
            for i in range(self.depth):
                self.skip256_arr.append(DLStudio.BMEnet.SkipBlock(256, 256, skip_connections=skip_connections))
            self.fc1 =  nn.Linear( (image_size[0]// (2 ** num_ds))  *  (image_size[1]//(2 ** num_ds))  * 256, 1000)
            self.fc2 =  nn.Linear(1000, 10)

        def forward(self, x):
            x = nn.functional.relu(self.conv(x))          
            for skip64 in self.skip64_arr:
                x = skip64(x)                
            x = self.skip64to128ds(x)
            for skip128 in self.skip128_arr:
                x = skip128(x)                
            x = self.skip128to256ds(x)
            for skip256 in self.skip256_arr:
                x = skip256(x)                
            x  =  x.view( x.shape[0], - 1 )
            x = nn.functional.relu(self.fc1(x))
            x = self.fc2(x)
            return x            


        def load_cifar_10_dataset(self):       
            self.dl_studio.load_cifar_10_dataset()

        def load_cifar_10_dataset_with_augmentation(self):             
            self.dl_studio.load_cifar_10_dataset_with_augmentation()


        class SkipBlock(nn.Module):
            """
            Class Path:   DLStudio  ->  BMEnet  ->  SkipBlock
            """            
            def __init__(self, in_ch, out_ch, downsample=False, skip_connections=True):
                super(DLStudio.BMEnet.SkipBlock, self).__init__()
                self.downsample = downsample
                self.skip_connections = skip_connections
                self.in_ch = in_ch
                self.out_ch = out_ch
                self.convo1 = nn.Conv2d(in_ch, in_ch, 3, stride=1, padding=1)
                self.convo2 = nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1)
                self.bn1 = nn.BatchNorm2d(in_ch)
                self.bn2 = nn.BatchNorm2d(out_ch)
                self.in2out  =  nn.Conv2d(in_ch, out_ch, 1)       
                if downsample:
                    ##  Setting stride to 2 and kernel_size to 1 amounts to retaining every
                    ##  other pixel in the image --- which halves the size of the image:
                    self.downsampler1 = nn.Conv2d(in_ch, in_ch, 1, stride=2)
                    self.downsampler2 = nn.Conv2d(out_ch, out_ch, 1, stride=2)

            def forward(self, x):
                identity = x                                     
                out = self.convo1(x)                              
                out = self.bn1(out)                              
                out = nn.functional.relu(out)
                out = self.convo2(out)                              
                out = self.bn2(out)                              
                out = nn.functional.relu(out)
                if self.downsample:
                    identity = self.downsampler1(identity)
                    out = self.downsampler2(out)
                if self.skip_connections:
                    if (self.in_ch == self.out_ch) and (self.downsample is False):
                        out = out + identity
                    elif (self.in_ch != self.out_ch) and (self.downsample is False):
                        identity = self.in2out( identity )     
                        out = out + identity
                    elif (self.in_ch != self.out_ch) and (self.downsample is True):
                        out = out + torch.cat((identity, identity), dim=1)
                return out

        def run_code_for_training(self, net, display_images=False):        
            self.dl_studio.run_code_for_training(net, display_images)
            
        def save_model(self, model):
            '''
            Save the trained model to a disk file
            '''
            torch.save(model.state_dict(), self.dl_studio.path_saved_model)

        def run_code_for_testing(self, model, display_images=False):
            self.dl_studio.run_code_for_testing(model, display_images)


    ###%%%
    #####################################################################################################################
    ################################  Start Definition of Inner Class CustomDataLoading  ################################

    class CustomDataLoading(nn.Module):             
        """
        This is a testbed for experimenting with a completely ground-up design of a custom
        data loader.  Ordinarily, if your dataset is stored in a format similar to one of the
        datasets that the Torchvision module knows about, you can go ahead and use the
        corresponding Torchvision loader for your own dataset.  At worst, you may need to
        carry out some light customization depending on the number of classes involved, etc.

        However, if the underlying dataset is stored in a manner that does not look like
        anything in Torchvision, you have no choice but to supply all of the data loading
        infrastructure yourself.  That is what this inner class of the main DLStudio class
        is all about.

        The custom data loading exercise here is related to a dataset called PurdueShapes5
        that contains 32x32 images of binary shapes belonging to the following five classes:

                       1.  rectangle
                       2.  triangle
                       3.  disk
                       4.  oval
                       5.  star

        The dataset was generated by randomizing the sizes and the orientations of these
        five patterns.  Since the patterns are rotated with a very simple non-interpolating
        transform, just the act of random rotations can introduce boundary and even interior
        noise in the patterns.

        Each 32x32 image is stored in the dataset as the following list:

                           [R, G, B, Bbox, Label]
        where
                R     :   is a 1024 element list of the values for the red component
                          of the color at all the pixels
           
                G     :   the same as above but for the green component of the color

                B     :   the same as above but for the blue component of the color

                Bbox  :   a list like [x1,y1,x2,y2] that defines the bounding box 
                          for the object in the image
           
                Label :   the shape of the object

        I serialize the dataset with Python's pickle module and then compress it with the
        gzip module.

        You will find the following dataset directories in the "data" subdirectory of
        Examples in the DLStudio distro:

               PurdueShapes5-10000-train.gz
               PurdueShapes5-1000-test.gz
               PurdueShapes5-20-train.gz
               PurdueShapes5-20-test.gz               

        The number that follows the main name string "PurdueShapes5-" is for the number of
        images in the dataset.

        You will find the last two datasets, with 20 images each, useful for debugging your
        logic for object detection and bounding-box regression.

        Class Path:   DLStudio  ->  CustomDataLoading
        """     
        def __init__(self, dl_studio, dataserver_train=None, dataserver_test=None, dataset_file_train=None, dataset_file_test=None):
            super(DLStudio.CustomDataLoading, self).__init__()
            self.dl_studio = dl_studio
            self.dataserver_train = dataserver_train
            self.dataserver_test = dataserver_test

        class PurdueShapes5Dataset(torch.utils.data.Dataset):
            """
            Class Path:   DLStudio  ->  CustomDataLoading  ->  PurdueShapes5Dataset
            """
            def __init__(self, dl_studio, train_or_test, dataset_file):
                super(DLStudio.CustomDataLoading.PurdueShapes5Dataset, self).__init__()
                if train_or_test == 'train' and dataset_file == "PurdueShapes5-10000-train.gz":
                    if os.path.exists("torch_saved_PurdueShapes5-10000_dataset.pt") and \
                              os.path.exists("torch_saved_PurdueShapes5_label_map.pt"):
                        print("\nLoading training data from the torch-saved archive")
                        self.dataset = torch.load("torch_saved_PurdueShapes5-10000_dataset.pt")
                        self.label_map = torch.load("torch_saved_PurdueShapes5_label_map.pt")
                        # reverse the key-value pairs in the label dictionary:
                        self.class_labels = dict(map(reversed, self.label_map.items()))
                    else: 
                        print("""\n\n\nLooks like this is the first time you will be loading in\n"""
                              """the dataset for this script. First time loading could take\n"""
                              """a minute or so.  Any subsequent attempts will only take\n"""
                              """a few seconds.\n\n\n""")
                        root_dir = dl_studio.dataroot
                        f = gzip.open(root_dir + dataset_file, 'rb')
                        dataset = f.read()
                        self.dataset, self.label_map = pickle.loads(dataset, encoding='latin1')
                        torch.save(self.dataset, "torch_saved_PurdueShapes5-10000_dataset.pt")
                        torch.save(self.label_map, "torch_saved_PurdueShapes5_label_map.pt")
                        # reverse the key-value pairs in the label dictionary:
                        self.class_labels = dict(map(reversed, self.label_map.items()))
                else:
                    root_dir = dl_studio.dataroot
                    f = gzip.open(root_dir + dataset_file, 'rb')
                    dataset = f.read()
                    self.dataset, self.label_map = pickle.loads(dataset, encoding='latin1')
                    # reverse the key-value pairs in the label dictionary:
                    self.class_labels = dict(map(reversed, self.label_map.items()))
             
            def __len__(self):
                return len(self.dataset)

            def __getitem__(self, idx):
                r = np.array( self.dataset[idx][0] )
                g = np.array( self.dataset[idx][1] )
                b = np.array( self.dataset[idx][2] )
                R,G,B = r.reshape(32,32), g.reshape(32,32), b.reshape(32,32)
                im_tensor = torch.zeros(3,32,32, dtype=torch.float)
                im_tensor[0,:,:] = torch.from_numpy(R)
                im_tensor[1,:,:] = torch.from_numpy(G)
                im_tensor[2,:,:] = torch.from_numpy(B)
                sample = {'image' : im_tensor, 
                          'bbox' : self.dataset[idx][3],                          
                          'label' : self.dataset[idx][4] }
                return sample

        def load_PurdueShapes5_dataset(self, dataserver_train, dataserver_test ):       
            transform = tvt.Compose([tvt.ToTensor(),
                                tvt.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])  
            self.train_dataloader = torch.utils.data.DataLoader(dataserver_train,
                               batch_size=self.dl_studio.batch_size,shuffle=True, num_workers=4)
            self.test_dataloader = torch.utils.data.DataLoader(dataserver_test,
                               batch_size=self.dl_studio.batch_size,shuffle=False, num_workers=4)


        class ECEnet(nn.Module):
            """
     
            Class Path: DLStudio -> CustomDataloading -> ECEnet
    
            """
            def __init__(self, dl_studio, skip_connections=True, depth=8):
                super(DLStudio.CustomDataLoading.ECEnet, self).__init__()
                self.dl_studio = dl_studio
                self.depth = depth
                image_size = dl_studio.image_size
                num_ds = 0                                 ## num_ds stands for number of downsampling steps
                self.conv = nn.Conv2d(3, 64, 3, padding=1)
                self.skip64_arr = nn.ModuleList()
                for i in range(self.depth):
                    self.skip64_arr.append(DLStudio.CustomDataLoading.SkipBlock2(64, 64, skip_connections=skip_connections))
                self.skip64to128ds = DLStudio.CustomDataLoading.SkipBlock2(64, 128, downsample=True, skip_connections=skip_connections )
                num_ds += 1              
                self.skip128_arr = nn.ModuleList()
                for i in range(self.depth):
                    self.skip128_arr.append(DLStudio.CustomDataLoading.SkipBlock2(128, 128, skip_connections=skip_connections))
                self.skip128to256ds = DLStudio.CustomDataLoading.SkipBlock2(128, 256, downsample=True, skip_connections=skip_connections )
                num_ds += 1
                self.skip256_arr = nn.ModuleList()
                for i in range(self.depth):
                    self.skip256_arr.append(DLStudio.CustomDataLoading.SkipBlock2(256, 256, skip_connections=skip_connections))
                self.fc1 =  nn.Linear( (image_size[0]// (2 ** num_ds))  *  (image_size[1]//(2 ** num_ds))  * 256, 1000)
                self.fc2 =  nn.Linear(1000, 10)
    
            def forward(self, x):
                x = nn.functional.relu(self.conv(x))          
                for skip64 in self.skip64_arr:
                    x = skip64(x)                
                x = self.skip64to128ds(x)
                for skip128 in self.skip128_arr:
                    x = skip128(x)                
                x = self.skip128to256ds(x)
                for skip256 in self.skip256_arr:
                    x = skip256(x)                
                x  =  x.view( x.shape[0], - 1 )
                x = nn.functional.relu(self.fc1(x))
                x = self.fc2(x)
                return x            
    
    
            def load_cifar_10_dataset(self):       
                self.dl_studio.load_cifar_10_dataset()
    
            def load_cifar_10_dataset_with_augmentation(self):             
                self.dl_studio.load_cifar_10_dataset_with_augmentation()
    
    
        class SkipBlock2(nn.Module):
            """
            Class Path:   DLStudio  ->  CustomDataLoading  ->  SkipBlock2
            """            
            def __init__(self, in_ch, out_ch, downsample=False, skip_connections=True):
                super(DLStudio.CustomDataLoading.SkipBlock2, self).__init__()
                self.downsample = downsample
                self.skip_connections = skip_connections
                self.in_ch = in_ch
                self.out_ch = out_ch
                self.convo1 = nn.Conv2d(in_ch, in_ch, 3, stride=1, padding=1)
                self.convo2 = nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1)
                self.bn1 = nn.BatchNorm2d(in_ch)
                self.bn2 = nn.BatchNorm2d(out_ch)
                self.in2out  =  nn.Conv2d(in_ch, out_ch, 1)       
                if downsample:
                    ##  Setting stride to 2 and kernel_size to 1 amounts to retaining every
                    ##  other pixel in the image --- which halves the size of the image:
                    self.downsampler1 = nn.Conv2d(in_ch, in_ch, 1, stride=2)
                    self.downsampler2 = nn.Conv2d(out_ch, out_ch, 1, stride=2)

            def forward(self, x):
                identity = x                                     
                out = self.convo1(x)                              
                out = self.bn1(out)                              
                out = nn.functional.relu(out)
                out = self.convo2(out)                              
                out = self.bn2(out)                              
                out = nn.functional.relu(out)
                if self.downsample:
                    identity = self.downsampler1(identity)
                    out = self.downsampler2(out)
                if self.skip_connections:
                    if (self.in_ch == self.out_ch) and (self.downsample is False):
                        out = out + identity
                    elif (self.in_ch != self.out_ch) and (self.downsample is False):
                        identity = self.in2out( identity )     
                        out = out + identity
                    elif (self.in_ch != self.out_ch) and (self.downsample is True):
                        out = out + torch.cat((identity, identity), dim=1)
                return out
    

        def run_code_for_training_with_custom_loading(self, net):        
            filename_for_out = "performance_numbers_" + str(self.dl_studio.epochs) + ".txt"
            FILE = open(filename_for_out, 'w')
            net = copy.deepcopy(net)
            net = net.to(self.dl_studio.device)
            criterion = nn.CrossEntropyLoss()
            optimizer = optim.SGD(net.parameters(), 
                         lr=self.dl_studio.learning_rate, momentum=self.dl_studio.momentum)
            for epoch in range(self.dl_studio.epochs):  
                running_loss = 0.0
                for i, data in enumerate(self.train_dataloader):
                    inputs, bounding_box, labels = data['image'], data['bbox'], data['label']
                    if self.dl_studio.debug_train and i % 1000 == 999:
                        print("\n\n\nlabels: %s" % str(labels))
                        print("\n\n\ntype of labels: %s" % type(labels))
                        print("\n\n[iter=%d:] Ground Truth:     " % (i+1) + 
                        ' '.join('%5s' % self.dataserver_train.class_labels[labels[j].item()] for j in range(self.dl_studio.batch_size)))
                    inputs = inputs.to(self.dl_studio.device)
                    labels = labels.to(self.dl_studio.device)
                    optimizer.zero_grad()
                    outputs = net(inputs)
                    loss = criterion(outputs, labels)
                    if self.dl_studio.debug_train and i % 1000 == 999:
                        _, predicted = torch.max(outputs.data, 1)
                        print("[iter=%d:] Predicted Labels: " % (i+1) + 
                         ' '.join('%5s' % self.dataserver_train.class_labels[predicted[j]] 
                                           for j in range(self.dl_studio.batch_size)))
                        self.dl_studio.display_tensor_as_image(torchvision.utils.make_grid(
             inputs, normalize=True), "see terminal for TRAINING results at iter=%d" % (i+1))
                    loss.backward()
                    optimizer.step()
                    running_loss += loss.item()
                    if i % 1000 == 999:    
                        avg_loss = running_loss / float(1000)
                        print("[epoch:%d, batch:%5d] loss: %.3f" % (epoch + 1, i + 1, avg_loss))
                        FILE.write("%.3f\n" % avg_loss)
                        FILE.flush()
                        running_loss = 0.0
            print("\nFinished Training\n")
            self.save_model(net)
            
        def save_model(self, model):
            '''
            Save the trained model to a disk file
            '''
            torch.save(model.state_dict(), self.dl_studio.path_saved_model)

        def run_code_for_testing_with_custom_loading(self, net):
            net.load_state_dict(torch.load(self.dl_studio.path_saved_model))
            correct = 0
            total = 0
            confusion_matrix = torch.zeros(len(self.dataserver_train.class_labels), 
                                           len(self.dataserver_train.class_labels))
            class_correct = [0] * len(self.dataserver_train.class_labels)
            class_total = [0] * len(self.dataserver_train.class_labels)
            with torch.no_grad():
                for i, data in enumerate(self.test_dataloader):
                    images, bounding_box, labels = data['image'], data['bbox'], data['label']
                    labels = labels.tolist()
                    if self.dl_studio.debug_test and i % 1000 == 0:
                        print("\n\n[i=%d:] Ground Truth:     " %i + ' '.join('%10s' % 
                          self.dataserver_train.class_labels[labels[j]] for j in range(self.dl_studio.batch_size)))
                    outputs = net(images)
                    ##  max() returns two things: the max value and its index in the 10 element
                    ##  output vector.  We are only interested in the index --- since that is 
                    ##  essentially the predicted class label:
                    _, predicted = torch.max(outputs.data, 1)
                    predicted = predicted.tolist()
                    if self.dl_studio.debug_test and i % 1000 == 0:
                        print("[i=%d:] Predicted Labels: " %i + ' '.join('%10s' % 
                          self.dataserver_train.class_labels[predicted[j]] for j in range(self.dl_studio.batch_size)))
                        self.dl_studio.display_tensor_as_image(
                              torchvision.utils.make_grid(images, normalize=True), 
                              "see terminal for test results at i=%d" % i)
                    for label,prediction in zip(labels,predicted):
                        confusion_matrix[label][prediction] += 1
                    total += len(labels)
                    correct +=  [predicted[ele] == labels[ele] for ele in range(len(predicted))].count(True)
                    comp = [predicted[ele] == labels[ele] for ele in range(len(predicted))]
                    for j in range(self.dl_studio.batch_size):
                        label = labels[j]
                        class_correct[label] += comp[j]
                        class_total[label] += 1
            print("\n")
            for j in range(len(self.dataserver_train.class_labels)):
                print('Prediction accuracy for %5s : %2d %%' % (
              self.dataserver_train.class_labels[j], 100 * class_correct[j] / class_total[j]))
            print("\n\n\nOverall accuracy of the network on the 10000 test images: %d %%" % 
                                                                   (100 * correct / float(total)))
            print("\n\nDisplaying the confusion matrix:\n")
            out_str = "                "
            for j in range(len(self.dataserver_train.class_labels)):  
                                 out_str +=  "%15s" % self.dataserver_train.class_labels[j]   
            print(out_str + "\n")
            for i,label in enumerate(self.dataserver_train.class_labels):
                out_percents = [100 * confusion_matrix[i,j] / float(class_total[i]) 
                                 for j in range(len(self.dataserver_train.class_labels))]
                out_percents = ["%.2f" % item.item() for item in out_percents]
                out_str = "%12s:  " % self.dataserver_train.class_labels[i]
                for j in range(len(self.dataserver_train.class_labels)): 
                                                       out_str +=  "%15s" % out_percents[j]
                print(out_str)
    
    ###%%%
    #####################################################################################################################
    ##################################  Start Definition of Inner Class DetectAndLocalize  ##############################

    class DetectAndLocalize(nn.Module):             
        """
        The purpose of this inner class is to focus on object detection in images --- as
        opposed to image classification.  Most people would say that object detection is a
        more challenging problem than image classification because, in general, the former
        also requires localization.  The simplest interpretation of what is meant by
        localization is that the code that carries out object detection must also output a
        bounding-box rectangle for the object that was detected.

        You will find in this inner class some examples of LOADnet classes meant for solving
        the object detection and localization problem.  The acronym "LOAD" in "LOADnet"
        stands for

                    "LOcalization And Detection"

        The different network examples included here are LOADnet1, LOADnet2, and LOADnet3.
        For now, only pay attention to LOADnet2 since that's the class I have worked with
        the most for the 1.0.7 distribution.

        Class Path:   DLStudio  ->  DetectAndLocalize
        """
        def __init__(self, dl_studio, dataserver_train=None, dataserver_test=None, dataset_file_train=None, dataset_file_test=None):
            super(DLStudio.DetectAndLocalize, self).__init__()
            self.dl_studio = dl_studio
            self.dataserver_train = dataserver_train
            self.dataserver_test = dataserver_test
            self.debug = False

        class PurdueShapes5Dataset(torch.utils.data.Dataset):
            """
            Class Path:   DLStudio  ->  DetectAndLocalize  ->  PurdueShapes5Dataset
            """
            def __init__(self, dl_studio, train_or_test, dataset_file):
                super(DLStudio.DetectAndLocalize.PurdueShapes5Dataset, self).__init__()
                ##  For the recognized training archives, the unpickled dataset and the label map
                ##  are cached in torch-saved files after the first load so that subsequent runs
                ##  start up quickly.  The mapping below pairs each gzipped training archive with
                ##  the name of its torch-saved cache file:
                dataset_cache_map = {
                    "PurdueShapes5-10000-train.gz"           :  "torch-saved-PurdueShapes5-10000-dataset.pt",
                    "PurdueShapes5-10000-train-noise-20.gz"  :  "torch-saved-PurdueShapes5-10000-dataset-noise-20.pt",
                    "PurdueShapes5-10000-train-noise-50.gz"  :  "torch-saved-PurdueShapes5-10000-dataset-noise-50.pt",
                    "PurdueShapes5-10000-train-noise-80.gz"  :  "torch-saved-PurdueShapes5-10000-dataset-noise-80.pt",
                }
                label_map_cache = "torch-saved-PurdueShapes5-label-map.pt"
                if train_or_test == 'train' and dataset_file in dataset_cache_map:
                    dataset_cache = dataset_cache_map[dataset_file]
                    if os.path.exists(dataset_cache) and os.path.exists(label_map_cache):
                        print("\nLoading training data from the torch-saved archive")
                        self.dataset = torch.load(dataset_cache)
                        self.label_map = torch.load(label_map_cache)
                    else: 
                        print("""\n\n\nLooks like this is the first time you will be loading in\n"""
                              """the dataset for this script. First time loading could take\n"""
                              """a minute or so.  Any subsequent attempts will only take\n"""
                              """a few seconds.\n\n\n""")
                        root_dir = dl_studio.dataroot
                        f = gzip.open(root_dir + dataset_file, 'rb')
                        dataset = f.read()
                        if sys.version_info[0] == 3:
                            self.dataset, self.label_map = pickle.loads(dataset, encoding='latin1')
                        else:
                            self.dataset, self.label_map = pickle.loads(dataset)
                        torch.save(self.dataset, dataset_cache)
                        torch.save(self.label_map, label_map_cache)
                else:
                    root_dir = dl_studio.dataroot
                    f = gzip.open(root_dir + dataset_file, 'rb')
                    dataset = f.read()
                    if sys.version_info[0] == 3:
                        self.dataset, self.label_map = pickle.loads(dataset, encoding='latin1')
                    else:
                        self.dataset, self.label_map = pickle.loads(dataset)
                # reverse the key-value pairs in the label dictionary:
                self.class_labels = dict(map(reversed, self.label_map.items()))
             
            def __len__(self):
                return len(self.dataset)

            def __getitem__(self, idx):
                r = np.array( self.dataset[idx][0] )
                g = np.array( self.dataset[idx][1] )
                b = np.array( self.dataset[idx][2] )
                R,G,B = r.reshape(32,32), g.reshape(32,32), b.reshape(32,32)
                im_tensor = torch.zeros(3,32,32, dtype=torch.float)
                im_tensor[0,:,:] = torch.from_numpy(R)
                im_tensor[1,:,:] = torch.from_numpy(G)
                im_tensor[2,:,:] = torch.from_numpy(B)
                bb_tensor = torch.tensor(self.dataset[idx][3], dtype=torch.float)
                sample = {'image' : im_tensor, 
                          'bbox' : bb_tensor,
                          'label' : self.dataset[idx][4] }
                return sample
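            ##  Each sample returned by __getitem__ is a dict with three keys:
            ##      'image' :  a (3,32,32) float tensor holding the RGB image
            ##      'bbox'  :  a 4-element float tensor with the bounding-box coordinates
            ##                 (interpreted downstream as [x_min, y_min, x_max, y_max])
            ##      'label' :  an integer index that can be mapped to a name via self.class_labels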

        def load_PurdueShapes5_dataset(self, dataserver_train, dataserver_test ):       
            self.train_dataloader = torch.utils.data.DataLoader(dataserver_train,
                               batch_size=self.dl_studio.batch_size,shuffle=True, num_workers=4)
            self.test_dataloader = torch.utils.data.DataLoader(dataserver_test,
                               batch_size=self.dl_studio.batch_size,shuffle=False, num_workers=4)
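            ##  Note: this method sets self.train_dataloader and self.test_dataloader, which are
            ##  the iterables used by the training and testing routines further down in this class.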
    

        class SkipBlock3(nn.Module):
            """
            Class Path:   DLStudio  ->  DetectAndLocalize  ->  SkipBlock3
            """            
            def __init__(self, in_ch, out_ch, downsample=False, skip_connections=True):
                super(DLStudio.DetectAndLocalize.SkipBlock3, self).__init__()
                self.downsample = downsample
                self.skip_connections = skip_connections
                self.in_ch = in_ch
                self.out_ch = out_ch
                self.convo1 = nn.Conv2d(in_ch, in_ch, 3, stride=1, padding=1)
                self.convo2 = nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1)
                self.bn1 = nn.BatchNorm2d(in_ch)
                self.bn2 = nn.BatchNorm2d(out_ch)
                self.in2out  =  nn.Conv2d(in_ch, out_ch, 1)       
                if downsample:
                    ##  Setting stride to 2 and kernel_size to 1 amounts to retaining every
                    ##  other pixel in the image --- which halves the size of the image:
                    self.downsampler1 = nn.Conv2d(in_ch, in_ch, 1, stride=2)
                    self.downsampler2 = nn.Conv2d(out_ch, out_ch, 1, stride=2)

            def forward(self, x):
                identity = x                                     
                out = self.convo1(x)                              
                out = self.bn1(out)                              
                out = nn.functional.relu(out)
                out = self.convo2(out)                              
                out = self.bn2(out)                              
                out = nn.functional.relu(out)
                if self.downsample:
                    identity = self.downsampler1(identity)
                    out = self.downsampler2(out)
                if self.skip_connections:
                    if (self.in_ch == self.out_ch) and (self.downsample is False):
                        out = out + identity
                    elif (self.in_ch != self.out_ch) and (self.downsample is False):
                        identity = self.in2out( identity )     
                        out = out + identity
                    elif (self.in_ch != self.out_ch) and (self.downsample is True):
                        out = out + torch.cat((identity, identity), dim=1)
                return out
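            ##  Quick shape check for this skip block (a sketch; the 16x16 maps correspond to the
            ##  32x32 PurdueShapes5 images after the initial conv+maxpool in the LOADnet classes):
            ##
            ##      SkipBlock3(64, 64)(torch.randn(1,64,16,16)).shape                    ##  (1, 64, 16, 16)
            ##      SkipBlock3(64, 64, downsample=True)(torch.randn(1,64,16,16)).shape   ##  (1, 64, 8, 8)
            ##      SkipBlock3(64, 128)(torch.randn(1,64,16,16)).shape                   ##  (1, 128, 16, 16)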



        class LOADnet1(nn.Module):
            """
            The acronym 'LOAD' stands for 'LOcalization And Detection'.  LOADnet1 only
            uses fully-connected layers for the regression

            Class Path:   DLStudio  ->  DetectAndLocalize  ->  LOADnet1
            """
            def __init__(self, skip_connections=True, depth=32):
                super(DLStudio.DetectAndLocalize.LOADnet1, self).__init__()
                self.pool_count = 3
                self.depth = depth // 2
                self.conv = nn.Conv2d(3, 64, 3, padding=1)
                self.skip64 = DLStudio.DetectAndLocalize.SkipBlock3(64, 64, skip_connections=skip_connections)
                self.skip64ds = DLStudio.DetectAndLocalize.SkipBlock3(64, 64, downsample=True, skip_connections=skip_connections)
                self.skip64to128 = DLStudio.DetectAndLocalize.SkipBlock3(64, 128, skip_connections=skip_connections )
                self.skip128 = DLStudio.DetectAndLocalize.SkipBlock3(128, 128, skip_connections=skip_connections)
                self.skip128ds = DLStudio.DetectAndLocalize.SkipBlock3(128,128, downsample=True, skip_connections=skip_connections)
                self.fc1 =  nn.Linear(128 * (32 // 2**self.pool_count)**2, 1000)
                self.fc2 =  nn.Linear(1000, 5)
                self.fc3 =  nn.Linear(16384, 1000)       ## 16384 = 64 channels x 16 x 16 after the initial conv+maxpool on 32x32 inputs
                self.fc4 =  nn.Linear(1000, 4)

            def forward(self, x):
                x = nn.MaxPool2d(2,2)(nn.functional.relu(self.conv(x)))          
                ## The labeling section:
                x1 = x.clone()
                for _ in range(self.depth // 4):
                    x1 = self.skip64(x1)                                               
                x1 = self.skip64ds(x1)
                for _ in range(self.depth // 4):
                    x1 = self.skip64(x1)                                               
                x1 = self.skip64to128(x1)
                for _ in range(self.depth // 4):
                    x1 = self.skip128(x1)                                               
                x1 = self.skip128ds(x1)                                               
                for _ in range(self.depth // 4):
                    x1 = self.skip128(x1)                                               
                x1  =  x1.view( x1.shape[0], -1 )
                x1 = nn.functional.relu(self.fc1(x1))
                x1 = self.fc2(x1)
                ## The Bounding Box regression:
                x2 =  x.view( x.shape[0], - 1 )
                x2 = nn.functional.relu(self.fc3(x2))
                x2 = self.fc4(x2)
                return x1,x2

        class LOADnet2(nn.Module):
            """
            The acronym 'LOAD' stands for 'LOcalization And Detection'.  LOADnet2 uses
            both convo and linear layers for regression

            Class Path:   DLStudio  ->  DetectAndLocalize  ->  LOADnet2
            """ 
            def __init__(self, skip_connections=True, depth=8):
                super(DLStudio.DetectAndLocalize.LOADnet2, self).__init__()
                if depth not in [8,10,12,14,16]:
                    sys.exit("LOADnet2 has only been tested for 'depth' values 8, 10, 12, 14, and 16")
                self.depth = depth // 2
                self.conv = nn.Conv2d(3, 64, 3, padding=1)
                self.bn1  = nn.BatchNorm2d(64)
                self.bn2  = nn.BatchNorm2d(128)
                self.skip64_arr = nn.ModuleList()
                for i in range(self.depth):
                    self.skip64_arr.append(DLStudio.DetectAndLocalize.SkipBlock3(64, 64,
                                                          skip_connections=skip_connections))
                self.skip64ds = DLStudio.DetectAndLocalize.SkipBlock3(64, 64, 
                                            downsample=True, skip_connections=skip_connections)
                self.skip64to128 = DLStudio.DetectAndLocalize.SkipBlock3(64, 128, 
                                                            skip_connections=skip_connections )
                self.skip128_arr = nn.ModuleList()
                for i in range(self.depth):
                    self.skip128_arr.append(DLStudio.DetectAndLocalize.SkipBlock3(128, 128,
                                                         skip_connections=skip_connections))
                self.skip128ds = DLStudio.DetectAndLocalize.SkipBlock3(128,128,
                                            downsample=True, skip_connections=skip_connections)
                self.fc1 =  nn.Linear(2048, 1000)
                self.fc2 =  nn.Linear(1000, 5)

                ##  for regression
                self.conv_seqn = nn.Sequential(
                    nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1),
                    nn.BatchNorm2d(64),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1),
                    nn.ReLU(inplace=True)
                )
                self.fc_seqn = nn.Sequential(
                    nn.Linear(16384, 1024),
                    nn.ReLU(inplace=True),
                    nn.Linear(1024, 512),
                    nn.ReLU(inplace=True),
                    nn.Linear(512, 4)        ## output for the 4 coords (x_min,y_min,x_max,y_max) of BBox
                )
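                ##  Note: the 16384 input size of fc_seqn corresponds to the 64x16x16 feature map
                ##  that conv_seqn produces for 32x32 input images, since the initial conv+maxpool
                ##  in forward() halves 32x32 down to 16x16.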

            def forward(self, x):
                x = nn.MaxPool2d(2,2)(nn.functional.relu(self.conv(x)))          
                ## The labeling section:
                x1 = x.clone()
                for i,skip64 in enumerate(self.skip64_arr[:self.depth//4]):
                    x1 = skip64(x1)                
                x1 = self.skip64ds(x1)
                for i,skip64 in enumerate(self.skip64_arr[self.depth//4:]):
                    x1 = skip64(x1)                
                x1 = self.bn1(x1)
                x1 = self.skip64to128(x1)
                for i,skip128 in enumerate(self.skip128_arr[:self.depth//4]):
                    x1 = skip128(x1)                
                x1 = self.bn2(x1)
                x1 = self.skip128ds(x1)
                for i,skip128 in enumerate(self.skip128_arr[self.depth//4:]):
                    x1 = skip128(x1)                
                x1 = x1.view( x1.shape[0], - 1 )
                x1 = nn.functional.relu(self.fc1(x1))
                x1 = self.fc2(x1)
                ## The Bounding Box regression:
                x2 = self.conv_seqn(x)
                # flatten
                x2 = x2.view( x.shape[0], - 1 )
                x2 = self.fc_seqn(x2)
                return x1,x2
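            ##  For a batch of PurdueShapes5 images of shape (B,3,32,32), forward() returns a pair
            ##  of tensors:  the class logits of shape (B,5) and the predicted bounding-box
            ##  coordinates of shape (B,4).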


        class LOADnet3(nn.Module):
            """
            The acronym 'LOAD' stands for 'LOcalization And Detection'.  LOADnet3 uses
            both convo and linear layers for regression

            Class Path:   DLStudio  ->  DetectAndLocalize  ->  LOADnet3

            """ 
            def __init__(self, skip_connections=True, depth=8):
                super(DLStudio.DetectAndLocalize.LOADnet3, self).__init__()
                if depth not in [4, 8, 16]:
                    sys.exit("LOADnet2 has been tested for 'depth' for only 4, 8, and 16")
                self.depth = depth // 4
                self.conv = nn.Conv2d(3, 64, 3, padding=1)
                self.skip64_arr = nn.ModuleList()
                for i in range(self.depth):
                    self.skip64_arr.append(DLStudio.DetectAndLocalize.SkipBlock3(64, 64,
                                                          skip_connections=skip_connections))
                self.skip64ds = DLStudio.DetectAndLocalize.SkipBlock3(64, 64, 
                                            downsample=True, skip_connections=skip_connections)
                self.skip64to128 = DLStudio.DetectAndLocalize.SkipBlock3(64, 128, 
                                                            skip_connections=skip_connections )
                self.skip128_arr = nn.ModuleList()
                for i in range(self.depth):
                    self.skip128_arr.append(DLStudio.DetectAndLocalize.SkipBlock3(128, 128,
                                                         skip_connections=skip_connections))
                self.skip128ds = DLStudio.DetectAndLocalize.SkipBlock3(128,128,
                                            downsample=True, skip_connections=skip_connections)
                self.fc1 =  nn.Linear(2048, 1000)
                self.fc2 =  nn.Linear(1000, 5)

                ##  for regression
                self.conv_seqn = nn.Sequential(
                    nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1),
                    nn.ReLU(inplace=True)
                )
                self.fc_seqn = nn.Sequential(
                    nn.Linear(16384, 1024),
                    nn.ReLU(inplace=True),
                    nn.Linear(1024, 512),
                    nn.ReLU(inplace=True),
                    nn.Linear(512, 4)
                )
            def forward(self, x):
                x = nn.MaxPool2d(2,2)(nn.functional.relu(self.conv(x)))          
                ## The labeling section:
                x1 = x.clone()
                for i,skip64 in enumerate(self.skip64_arr[:self.depth//4]):
                    x1 = skip64(x1)                
                x1 = self.skip64ds(x1)
                for i,skip64 in enumerate(self.skip64_arr[self.depth//4:]):
                    x1 = skip64(x1)                
                x1 = self.skip64ds(x1)
                x1 = self.skip64to128(x1)
                for i,skip128 in enumerate(self.skip128_arr[:self.depth//4]):
                    x1 = skip128(x1)                
                for i,skip128 in enumerate(self.skip128_arr[self.depth//4:]):
                    x1 = skip128(x1)                
                x1  =  x1.view( x1.shape[0], - 1 )
                x1 = nn.functional.relu(self.fc1(x1))
                x1 = self.fc2(x1)
                ## The Bounding Box regression:
                x2 = self.conv_seqn(x)
                # flatten
                x2 = x2.view( x.shape[0], -1 )
                x2 = self.fc_seqn(x2)
                return x1,x2


        class DIoULoss(nn.Module):
            """
            Class Path:   DLStudio  ->  DetectAndLocalize  ->  DIoULoss

            This is a Custom Loss Function for implementing the variants of the IoU 
            (Intersection over Union) loss as described on Slides 37 through 42 of my 
            Week 7 presentation on Object Detection and Localization.
            """
            def __init__(self, dl_studio, loss_mode):
                super(DLStudio.DetectAndLocalize.DIoULoss, self).__init__()
                self.dl_studio = dl_studio
                self.loss_mode = loss_mode

            def forward(self, predicted, target, loss_mode):
                debug = 0
                ##  We calculate the MSELoss between the predicted and the target BBs just for sanity check.
                ##  It is not used in the loss that is returned by this function [However, note that the 
                ##  d2_loss defined below is the same thing as what is returned by MSELoss]:
                displacement_loss = nn.MSELoss()(predicted, target)                                           
                ##  We call the MSELoss again, but this time with "reduction='none'".  The reason for that
                ##  is that we need to calculate the MSELoss on a per-instance basis in the batch for the
                ##  normalizations we are going to need later in our calculation of the IoU-based loss function.
                ##  The following call returns a tensor of shape (Bx4) where B is the batch size and 4
                ##  is for four numeric values in a BB vector.
                d2_loss_per_instance = nn.MSELoss(reduction='none')(predicted, target)                        
                ##  Averaging the above along Axis 1 gives us the instance based MSE Loss we want:
                d2_mean_loss_per_instance = torch.mean(d2_loss_per_instance, 1)                               
                ##  Averaging of the above along Axis 0 should give us a single scalar that would be
                ##  the same as the "displacement_loss" in the first line:
                d2_loss = torch.mean(d2_mean_loss_per_instance,0)                                             
                if debug:
                    print("\n\nMSE Loss: ", displacement_loss)
                    print("\n\nd2_loss_per_instance_in_batch: ", d2_loss_per_instance)
                    print("\n\nd2_mean_loss_per_instance_in_batch: ", d2_mean_loss_per_instance)
                    print("\n\nd2 loss: ", d2_loss)
  
                ##  Our next job is to figure out the BB for the convex hull of the predicted and target BBs.  To 
                ##  that end, we first find the upper-left corner of the convex hull by finding the infimum of
                ##  the min (i,j) coordinates associated with the predicted and the target BBs:
                hull_min_i  = torch.min( torch.cat( ( torch.transpose( torch.unsqueeze(predicted[:,0],0), 1,0 ),
                                                      torch.transpose( torch.unsqueeze(predicted[:,2],0), 1,0 ),
                                                         torch.transpose( torch.unsqueeze(target[:,0],0), 1,0 ),
                                                         torch.transpose( torch.unsqueeze(target[:,2],0), 1,0 ) ), 1 ), 1 )[0].type(torch.uint8)
                hull_min_j  = torch.min( torch.cat( ( torch.transpose( torch.unsqueeze(predicted[:,1],0), 1,0 ),
                                                      torch.transpose( torch.unsqueeze(predicted[:,3],0), 1,0 ),
                                                         torch.transpose( torch.unsqueeze(target[:,1],0), 1,0 ),
                                                         torch.transpose( torch.unsqueeze(target[:,3],0), 1,0 ) ), 1 ), 1 )[0].type(torch.uint8)

                ##  Next we need to find the lower-right corner of the convex hull.  We do so by finding the
                ##  supremum of the max (i,j) coordinates associated with the predicted and the target BBs:
                hull_max_i  = torch.max( torch.cat( ( torch.transpose( torch.unsqueeze(predicted[:,0],0), 1,0 ),
                                                      torch.transpose( torch.unsqueeze(predicted[:,2],0), 1,0 ),
                                                         torch.transpose( torch.unsqueeze(target[:,0],0), 1,0 ),
                                                         torch.transpose( torch.unsqueeze(target[:,2],0), 1,0 ) ), 1 ), 1 )[0].type(torch.uint8)

                hull_max_j  = torch.max( torch.cat( ( torch.transpose( torch.unsqueeze(predicted[:,1],0), 1,0 ),
                                                      torch.transpose( torch.unsqueeze(predicted[:,3],0), 1,0 ),
                                                         torch.transpose( torch.unsqueeze(target[:,1],0), 1,0 ),
                                                         torch.transpose( torch.unsqueeze(target[:,3],0), 1,0 ) ), 1 ), 1 )[0].type(torch.uint8)

                ##  We now call on the torch.cat to organize the instance-based convex_hull min and max coordinates
                ##  into what the convex-hull BB should look like for a batch.  If B is the batch size, the shape of 
                ##  convex_hull_bb should be (B, 4):
                convex_hull_bb = torch.cat( ( torch.transpose( torch.unsqueeze(hull_min_i,0), 1,0), 
                                              torch.transpose( torch.unsqueeze(hull_min_j,0), 1,0), 
                                              torch.transpose( torch.unsqueeze(hull_max_i,0), 1,0), 
                                              torch.transpose( torch.unsqueeze(hull_max_j,0), 1,0) ), 1 ).float().to(self.dl_studio.device)

                ##  Need the square of the diagonal of the convex hull for normalization:
                convex_hull_diagonal_squared  =  torch.square(convex_hull_bb[:,0] - convex_hull_bb[:,2])  +  torch.square(convex_hull_bb[:,1] - convex_hull_bb[:,3])

                ##  Since we will be using the BB corners for indexing, we need to convert them into ints:
                predicted = predicted.type(torch.uint8)
                target = target.type(torch.uint8)
                convex_hull_bb = convex_hull_bb.type(torch.uint8)

                ##  Our next job is to convert all three BBs --- predicted, target, and convex_hull --- into binary
                ##  for set operations of union, intersection, and the set-difference of the union from the 
                ##  convex hull.  We start by initializing the three arrays for each instance in the batch:
                img_size = self.dl_studio.image_size
                predicted_arr = torch.zeros(predicted.shape[0], img_size[0], img_size[1]).to(self.dl_studio.device)    
                target_arr = torch.zeros(predicted.shape[0], img_size[0], img_size[1]).to(self.dl_studio.device)       
                convex_hull_arr = torch.zeros(predicted.shape[0], img_size[0], img_size[1]).to(self.dl_studio.device)  
                ##  We fill the three arrays --- predicted, target, and convex_hull --- according to their respective BBs:
                for k in range(predicted_arr.shape[0]):                                                            
                    predicted_arr[ k, predicted[k,0]:predicted[k,2],  predicted[k,1]:predicted[k,3] ] = 1         
                    target_arr[ k, target[k,0]:target[k,2],  target[k,1]:target[k,3] ] = 1         
                    convex_hull_arr[ k, convex_hull_bb[k,0]:convex_hull_bb[k,2],  convex_hull_bb[k,1]:convex_hull_bb[k,3] ] = 1         
                ##  We are ready for the set operations:
                intersection_arr = predicted_arr * target_arr                                                     
                intersecs = torch.sum( intersection_arr, dim=(1,2) )                                              
                union_arr = torch.logical_or( predicted_arr > 0, target_arr > 0 ).type(torch.uint8)               
                unions = torch.sum( union_arr, dim=(1,2) )                                                        
                ## find the set difference of the convex hull and the union for each batch instance:
                diff_arr = (convex_hull_arr !=  union_arr).type(torch.uint8)
                ## what's the total number of pixels in the set difference:            
                diff_sum_per_instance = torch.sum( diff_arr, dim=(1,2) )
                ## also, what is the total number of pixels in the convex hull for each batch instance:
                convex_hull_sum_per_instance = torch.sum( convex_hull_arr, dim=(1,2) )
                if  (convex_hull_sum_per_instance < 10).any(): return torch.tensor([float('nan')])
                ## find the ratio we need for the DIoU formula [see Eq. (8) on Slide 40 of my Week 7 slides]:
                epsilon = 1e-6
                ratio = diff_sum_per_instance.type(torch.float) / (convex_hull_sum_per_instance.type(torch.float) + epsilon) 
                ## find the IoU            
                iou = intersecs / (unions + epsilon)                          
                iou_loss = torch.mean(1 - iou, 0)                             
                d2_normed = d2_mean_loss_per_instance / (convex_hull_diagonal_squared + epsilon)     
                d2_normed_loss = torch.mean(d2_normed, 0)        
                ratio_loss  =  torch.mean( ratio, 0 )
                if self.loss_mode == 'd2':
                    diou_loss =  d2_loss                         
                elif self.loss_mode == 'diou1':
                    diou_loss = iou_loss + d2_loss               
                elif self.loss_mode == 'diou2':
                    diou_loss = iou_loss + d2_normed_loss        
                elif self.loss_mode == 'diou3':
                    diou_loss = iou_loss + d2_normed_loss + ratio_loss
                return diou_loss
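            ##  A minimal sanity check for this loss (a sketch; 'dls' is assumed to be a DLStudio
            ##  instance whose image_size and device attributes have been set, and the BB values
            ##  are purely illustrative):
            ##
            ##      criterion = DLStudio.DetectAndLocalize.DIoULoss( dls, loss_mode = 'diou3' )
            ##      pred = torch.tensor( [[ 4.0,  4.0, 20.0, 20.0]] ).to(dls.device)
            ##      gt   = torch.tensor( [[ 6.0,  6.0, 24.0, 24.0]] ).to(dls.device)
            ##      loss = criterion( pred, gt, 'diou3' )     ## a scalar combining the IoU, normalized-d2, and ratio terms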



        def run_code_for_training_with_iou_regression(self, net, loss_mode='d2', show_images=True):
            """
            This training routine is called by

                     object_detection_and_localization_iou.py

            in the Examples directory.

            The possible values for loss_mode are:  'd2', 'diou1', 'diou2', 'diou3' with the following meanings:

            d2     :   d2_loss                                         This is just the MSE loss based on the square
                                                                       of the distance between the centers of the 
                                                                       predicted BB and the ground-truth BB.

            diou1  :   iou_loss   +   d2_loss                          We add to the pure IOU loss the value d2_loss
                                                                       defined above
          
            diou2  :   iou_loss   +   d2_normed_loss                   We now normalize the squared distance between the
                                                                       centers of the predicted BB and ground_truth BB by
                                                                       the diagonal of the convex hull of the two BBs.

            diou3  :   iou_loss   +   d2_normed_loss + ratio_loss      To the diou2 loss we add the ratio of the area
                                                                       of the convex hull NOT covered by the union of
                                                                       the predicted and ground-truth BBs to the total
                                                                       area of the convex hull.


            IMPORTANT NOTE:  You are likely to get the best results if you set the learning rate to 1e-4 for d2 and 
                             diou1 options.  If the option you use is diou2 or diou3, set the learning rate to 5e-3
            """
            filename_for_out1 = "performance_numbers_" + str(self.dl_studio.epochs) + "label.txt"
            filename_for_out2 = "performance_numbers_" + str(self.dl_studio.epochs) + "regres.txt"
            FILE1 = open(filename_for_out1, 'w')
            FILE2 = open(filename_for_out2, 'w')
            net = copy.deepcopy(net)
            net = net.to(self.dl_studio.device)
            criterion1 = nn.CrossEntropyLoss()
            criterion2 = self.dl_studio.DetectAndLocalize.DIoULoss(self.dl_studio, loss_mode)
            optimizer = optim.SGD(net.parameters(), lr=self.dl_studio.learning_rate, momentum=self.dl_studio.momentum)
            print("\n\nStarting training loop...\n\n")
            start_time = time.perf_counter()
            labeling_loss_tally = []   
            regression_loss_tally = [] 
            elapsed_time = 0.0   
            for epoch in range(self.dl_studio.epochs):  
                print("")
                running_loss_labeling = 0.0
                running_loss_regression = 0.0       
                for i, data in enumerate(self.train_dataloader):
                    gt_too_small = False
                    inputs, bbox_gt, labels = data['image'], data['bbox'], data['label']
                    inputs = inputs.to(self.dl_studio.device)
                    labels = labels.to(self.dl_studio.device)
                    bbox_gt = bbox_gt.to(self.dl_studio.device)
                    optimizer.zero_grad()
                    if self.debug:
                        self.dl_studio.display_tensor_as_image(
                          torchvision.utils.make_grid(inputs.cpu(), nrow=4, normalize=True, padding=2, pad_value=10))
                    outputs = net(inputs)
                    outputs_label = outputs[0]
                    bbox_pred = outputs[1]
                    if i % 500 == 499:
                        current_time = time.perf_counter()
                        elapsed_time = current_time - start_time
                        print("\n\n\n[epoch:%d/%d  iter=%4d  elapsed_time=%5d secs]      Ground Truth:     " % 
                                 (epoch+1, self.dl_studio.epochs, i+1, elapsed_time) 
                               + ' '.join('%10s' % self.dataserver_train.class_labels[labels[j].item()] 
                                                                for j in range(self.dl_studio.batch_size)))
                        inputs_copy = inputs.detach().clone()
                        inputs_copy = inputs_copy.cpu()
                        bbox_pc = bbox_pred.detach().clone()
                        bbox_pc[bbox_pc<0] = 0
                        bbox_pc[bbox_pc>31] = 31
                        bbox_pc[torch.isnan(bbox_pc)] = 0
                        _, predicted = torch.max(outputs_label.data, 1)
                        print("[epoch:%d/%d  iter=%4d  elapsed_time=%5d secs]  Predicted Labels:     " % 
                                (epoch+1, self.dl_studio.epochs, i+1, elapsed_time)  
                              + ' '.join('%10s' % self.dataserver_train.class_labels[predicted[j].item()] 
                                                                 for j in range(self.dl_studio.batch_size)))
                        if show_images == True:
                            for idx in range(self.dl_studio.batch_size):
                                i1 = int(bbox_gt[idx][1])
                                i2 = int(bbox_gt[idx][3])
                                j1 = int(bbox_gt[idx][0])
                                j2 = int(bbox_gt[idx][2])
                                k1 = int(bbox_pc[idx][1])
                                k2 = int(bbox_pc[idx][3])
                                l1 = int(bbox_pc[idx][0])
                                l2 = int(bbox_pc[idx][2])
                                print("                    gt_bb:  [%d,%d,%d,%d]"%(i1,j1,i2,j2))        
                                print("                  pred_bb:  [%d,%d,%d,%d]"%(k1,l1,k2,l2))
                                inputs_copy[idx,1,i1:i2,j1] = 255
                                inputs_copy[idx,1,i1:i2,j2] = 255
                                inputs_copy[idx,1,i1,j1:j2] = 255
                                inputs_copy[idx,1,i2,j1:j2] = 255
                                inputs_copy[idx,0,k1:k2,l1] = 255                      
                                inputs_copy[idx,0,k1:k2,l2] = 255
                                inputs_copy[idx,0,k1,l1:l2] = 255
                                inputs_copy[idx,0,k2,l1:l2] = 255
    
                    loss_labeling = criterion1(outputs_label, labels)
                    loss_labeling.backward(retain_graph=True)        
                    loss_regression = criterion2(bbox_pred, bbox_gt, loss_mode)
                    if torch.isnan(loss_regression): continue
                    loss_regression.backward()
                    optimizer.step()
                    running_loss_labeling += loss_labeling.item()    
                    running_loss_regression += loss_regression.item()                
                    if i % 500 == 499:    
                        avg_loss_labeling = running_loss_labeling / float(500)
                        avg_loss_regression = running_loss_regression / float(500)
                        labeling_loss_tally.append(avg_loss_labeling)  
                        regression_loss_tally.append(avg_loss_regression)    
                        print("[epoch:%d/%d  iter=%4d  elapsed_time=%5d secs]       loss_labeling %.3f        loss_regression: %.3f " %  (epoch+1, self.dl_studio.epochs, i+1, elapsed_time, avg_loss_labeling, avg_loss_regression))
                        FILE1.write("%.3f\n" % avg_loss_labeling)
                        FILE1.flush()
                        FILE2.write("%.3f\n" % avg_loss_regression)
                        FILE2.flush()
                        running_loss_labeling = 0.0
                        running_loss_regression = 0.0

                    if show_images == True:
                        if i%500==499:
                            logger = logging.getLogger()
                            old_level = logger.level
                            logger.setLevel(100)
                            plt.figure(figsize=[8,3])
                            plt.imshow(np.transpose(torchvision.utils.make_grid(inputs_copy, normalize=True,
                                                                             padding=3, pad_value=255).cpu(), (1,2,0)))
                            plt.show()
                            logger.setLevel(old_level)
            print("\nFinished Training\n")
            self.save_model(net)
            plt.figure(figsize=(10,5))
            plt.title("Labeling Loss vs. Iterations")
            plt.plot(labeling_loss_tally)
            plt.xlabel("iterations")
            plt.ylabel("labeling loss")
#            plt.legend()
            plt.legend(["Plot of loss versus iterations"], fontsize="x-large")
            plt.savefig("labeling_loss.png")
            plt.show()
            plt.title("regression Loss vs. Iterations")
            plt.plot(regression_loss_tally)
            plt.xlabel("iterations")
            plt.ylabel("regression loss")
#            plt.legend()
            plt.legend(["Plot of loss versus iterations"], fontsize="x-large")
            plt.savefig("regression_loss.png")
            plt.show()



        def run_code_for_training_with_CrossEntropy_and_MSE_Losses(self, net, show_images=True):        
            """
            This training routine is called by

                     object_detection_and_localization.py

            in the Examples directory.
            """
            filename_for_out1 = "performance_numbers_" + str(self.dl_studio.epochs) + "label.txt"
            filename_for_out2 = "performance_numbers_" + str(self.dl_studio.epochs) + "regres.txt"
            FILE1 = open(filename_for_out1, 'w')
            FILE2 = open(filename_for_out2, 'w')
            net = copy.deepcopy(net)
            net = net.to(self.dl_studio.device)
            criterion1 = nn.CrossEntropyLoss()
            criterion2 = nn.MSELoss()
            optimizer = optim.SGD(net.parameters(), lr=self.dl_studio.learning_rate, momentum=self.dl_studio.momentum)
            print("\n\nStarting training loop...\n\n")
            start_time = time.perf_counter()
            labeling_loss_tally = []   
            regression_loss_tally = [] 
            elapsed_time = 0.0   
            for epoch in range(self.dl_studio.epochs):  
                print("")
                running_loss_labeling = 0.0
                running_loss_regression = 0.0       
                for i, data in enumerate(self.train_dataloader):
                    gt_too_small = False
                    inputs, bbox_gt, labels = data['image'], data['bbox'], data['label']
                    if i % 500 == 499:
                        current_time = time.perf_counter()
                        elapsed_time = current_time - start_time
                        print("\n\n\n[epoch:%d/%d  iter=%4d  elapsed_time=%5d secs]      Ground Truth:     " % 
                                 (epoch+1, self.dl_studio.epochs, i+1, elapsed_time) 
                               + ' '.join('%10s' % self.dataserver_train.class_labels[labels[j].item()] 
                                                                for j in range(self.dl_studio.batch_size)))
                    inputs = inputs.to(self.dl_studio.device)
                    labels = labels.to(self.dl_studio.device)
                    bbox_gt = bbox_gt.to(self.dl_studio.device)
                    optimizer.zero_grad()
                    if self.debug:
                        self.dl_studio.display_tensor_as_image(
                          torchvision.utils.make_grid(inputs.cpu(), nrow=4, normalize=True, padding=2, pad_value=10))
                    outputs = net(inputs)
                    outputs_label = outputs[0]
                    bbox_pred = outputs[1]
                    if i % 500 == 499:
                        inputs_copy = inputs.detach().clone()
                        inputs_copy = inputs_copy.cpu()
                        bbox_pc = bbox_pred.detach().clone()
                        bbox_pc[bbox_pc<0] = 0
                        bbox_pc[bbox_pc>31] = 31
                        bbox_pc[torch.isnan(bbox_pc)] = 0
                        _, predicted = torch.max(outputs_label.data, 1)
                        print("[epoch:%d/%d  iter=%4d  elapsed_time=%5d secs]  Predicted Labels:     " % 
                                (epoch+1, self.dl_studio.epochs, i+1, elapsed_time)  
                              + ' '.join('%10s' % self.dataserver_train.class_labels[predicted[j].item()] 
                                                                 for j in range(self.dl_studio.batch_size)))
                        for idx in range(self.dl_studio.batch_size):
                            i1 = int(bbox_gt[idx][1])
                            i2 = int(bbox_gt[idx][3])
                            j1 = int(bbox_gt[idx][0])
                            j2 = int(bbox_gt[idx][2])
                            k1 = int(bbox_pc[idx][1])
                            k2 = int(bbox_pc[idx][3])
                            l1 = int(bbox_pc[idx][0])
                            l2 = int(bbox_pc[idx][2])
                            print("                    gt_bb:  [%d,%d,%d,%d]"%(i1,j1,i2,j2))
                            print("                  pred_bb:  [%d,%d,%d,%d]"%(k1,l1,k2,l2))
                            inputs_copy[idx,1,i1:i2,j1] = 255
                            inputs_copy[idx,1,i1:i2,j2] = 255
                            inputs_copy[idx,1,i1,j1:j2] = 255
                            inputs_copy[idx,1,i2,j1:j2] = 255
                            inputs_copy[idx,0,k1:k2,l1] = 255                      
                            inputs_copy[idx,0,k1:k2,l2] = 255
                            inputs_copy[idx,0,k1,l1:l2] = 255
                            inputs_copy[idx,0,k2,l1:l2] = 255
                    loss_labeling = criterion1(outputs_label, labels)
                    loss_labeling.backward(retain_graph=True)        
                    loss_regression = criterion2(bbox_pred, bbox_gt)
                    loss_regression.backward()
                    optimizer.step()
                    running_loss_labeling += loss_labeling.item()    
                    running_loss_regression += loss_regression.item()                
                    if i % 500 == 499:    
                        avg_loss_labeling = running_loss_labeling / float(500)
                        avg_loss_regression = running_loss_regression / float(500)
                        labeling_loss_tally.append(avg_loss_labeling)  
                        regression_loss_tally.append(avg_loss_regression)    
                        print("[epoch:%d/%d  iter=%4d  elapsed_time=%5d secs]       loss_labeling %.3f        loss_regression: %.3f " %  (epoch+1, self.dl_studio.epochs, i+1, elapsed_time, avg_loss_labeling, avg_loss_regression))
                        FILE1.write("%.3f\n" % avg_loss_labeling)
                        FILE1.flush()
                        FILE2.write("%.3f\n" % avg_loss_regression)
                        FILE2.flush()
                        running_loss_labeling = 0.0
                        running_loss_regression = 0.0
#                    if i%500==499:
                    if i%500==499 and show_images is True:
                        logger = logging.getLogger()
                        old_level = logger.level
                        logger.setLevel(100)
                        plt.figure(figsize=[8,3])
                        plt.imshow(np.transpose(torchvision.utils.make_grid(inputs_copy, normalize=True,
                                                                         padding=3, pad_value=255).cpu(), (1,2,0)))
                        plt.show()
                        logger.setLevel(old_level)
            print("\nFinished Training\n")
            self.save_model(net)
            plt.figure(figsize=(10,5))
            plt.title("Labeling Loss vs. Iterations")
            plt.plot(labeling_loss_tally)
            plt.xlabel("iterations")
            plt.ylabel("labeling loss")
#            plt.legend()
            plt.legend(["Plot of loss versus iterations"], fontsize="x-large")
            plt.savefig("labeling_loss.png")
            plt.show()
            plt.title("regression Loss vs. Iterations")
            plt.plot(regression_loss_tally)
            plt.xlabel("iterations")
            plt.ylabel("regression loss")
#            plt.legend()
            plt.legend(["Plot of loss versus iterations"], fontsize="x-large")
            plt.savefig("regression_loss.png")
            plt.show()


        def save_model(self, model):
            '''
            Save the trained model to a disk file
            '''
            torch.save(model.state_dict(), self.dl_studio.path_saved_model)


        def run_code_for_testing_detection_and_localization(self, net):
            net.load_state_dict(torch.load(self.dl_studio.path_saved_model))
            correct = 0
            total = 0
            confusion_matrix = torch.zeros(len(self.dataserver_train.class_labels), 
                                           len(self.dataserver_train.class_labels))
            class_correct = [0] * len(self.dataserver_train.class_labels)
            class_total = [0] * len(self.dataserver_train.class_labels)
            with torch.no_grad():
                for i, data in enumerate(self.test_dataloader):
                    images, bounding_box, labels = data['image'], data['bbox'], data['label']
                    labels = labels.tolist()
                    if self.dl_studio.debug_test and i % 50 == 0:
                        print("\n\n[i=%d:] Ground Truth:     " %i + ' '.join('%10s' % 
                         self.dataserver_train.class_labels[labels[j]] for j in range(self.dl_studio.batch_size)))
                    outputs = net(images)
                    outputs_label = outputs[0]
                    outputs_regression = outputs[1]
                    outputs_regression[outputs_regression < 0] = 0
                    outputs_regression[outputs_regression > 31] = 31
                    outputs_regression[torch.isnan(outputs_regression)] = 0
                    output_bb = outputs_regression.tolist()
                    _, predicted = torch.max(outputs_label.data, 1)
                    predicted = predicted.tolist()
                    if self.dl_studio.debug_test and i % 50 == 0:
                        print("[i=%d:] Predicted Labels: " %i + ' '.join('%10s' % 
                              self.dataserver_train.class_labels[predicted[j]] for j in range(self.dl_studio.batch_size)))
                        for idx in range(self.dl_studio.batch_size):
                            i1 = int(bounding_box[idx][1])
                            i2 = int(bounding_box[idx][3])
                            j1 = int(bounding_box[idx][0])
                            j2 = int(bounding_box[idx][2])
                            k1 = int(output_bb[idx][1])
                            k2 = int(output_bb[idx][3])
                            l1 = int(output_bb[idx][0])
                            l2 = int(output_bb[idx][2])
                            print("                    gt_bb:  [%d,%d,%d,%d]"%(j1,i1,j2,i2))
                            print("                  pred_bb:  [%d,%d,%d,%d]"%(l1,k1,l2,k2))
                            images[idx,0,i1:i2,j1] = 255
                            images[idx,0,i1:i2,j2] = 255
                            images[idx,0,i1,j1:j2] = 255
                            images[idx,0,i2,j1:j2] = 255
                            images[idx,2,k1:k2,l1] = 255                      
                            images[idx,2,k1:k2,l2] = 255
                            images[idx,2,k1,l1:l2] = 255
                            images[idx,2,k2,l1:l2] = 255
                        logger = logging.getLogger()
                        old_level = logger.level
                        logger.setLevel(100)
                        plt.figure(figsize=[8,3])
                        plt.imshow(np.transpose(torchvision.utils.make_grid(images, normalize=True,
                                                                         padding=3, pad_value=255).cpu(), (1,2,0)))
                        plt.show()
                        logger.setLevel(old_level)
                    for label,prediction in zip(labels,predicted):
                        confusion_matrix[label][prediction] += 1
                    total += len(labels)
                    correct +=  [predicted[ele] == labels[ele] for ele in range(len(predicted))].count(True)
                    comp = [predicted[ele] == labels[ele] for ele in range(len(predicted))]
                    for j in range(self.dl_studio.batch_size):
                        label = labels[j]
                        class_correct[label] += comp[j]
                        class_total[label] += 1
            print("\n")
            for j in range(len(self.dataserver_train.class_labels)):
                print('Prediction accuracy for %5s : %2d %%' % (
              self.dataserver_train.class_labels[j], 100 * class_correct[j] / class_total[j]))
            print("\n\n\nOverall accuracy of the network on the 1000 test images: %d %%" % 
                                                                   (100 * correct / float(total)))
            print("\n\nDisplaying the confusion matrix:\n")
            out_str = "                "
            for j in range(len(self.dataserver_train.class_labels)):  
                                 out_str +=  "%15s" % self.dataserver_train.class_labels[j]   
            print(out_str + "\n")
            for i,label in enumerate(self.dataserver_train.class_labels):
                out_percents = [100 * confusion_matrix[i,j] / float(class_total[i]) 
                                 for j in range(len(self.dataserver_train.class_labels))]
                out_percents = ["%.2f" % item.item() for item in out_percents]
                out_str = "%12s:  " % self.dataserver_train.class_labels[i]
                for j in range(len(self.dataserver_train.class_labels)): 
                                                       out_str +=  "%15s" % out_percents[j]
                print(out_str)
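            ##  As a quick cross-check, the overall accuracy printed above can also be read off the
            ##  confusion matrix accumulated in the loop -- an illustrative one-liner, not used by
            ##  the code itself:
            ##
            ##      overall_acc = torch.diag(confusion_matrix).sum().item() / confusion_matrix.sum().item()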



    ###%%%
    #####################################################################################################################
    #################################  Start Definition of Inner Class SemanticSegmentation  ############################

    class SemanticSegmentation(nn.Module):             
        """
        The purpose of this inner class is to be able to use the DLStudio platform for
        experiments with semantic segmentation.  At its simplest level, the purpose of
        semantic segmentation is to assign correct labels to the different objects in a
        scene, while localizing them at the same time.  At a more sophisticated level, a
        system that carries out semantic segmentation should also output a symbolic
        expression based on the objects found in the image and their spatial relationships
        with one another.

        The workhorse of this inner class is the mUNet network that is based on the UNET
        network that was first proposed by Ronneberger, Fischer and Brox in the paper
        "U-Net: Convolutional Networks for Biomedical Image Segmentation".  Their Unet
        extracts binary masks for the cell pixel blobs of interest in biomedical images.
        The output of their Unet can therefore be treated as a pixel-wise binary classifier
        at each pixel position.  The mUnet class, on the other hand, is intended for
        segmenting out multiple objects simultaneously from an image. [A weaker reason for
        "Multi" in the name of the class is that it uses skip connections not only across
        the two arms of the "U", but also along the arms.  The skip connections in the
        original Unet are only between the two arms of the U.]  In mUnet, each object type is
        assigned a separate channel in the output of the network.

        This version of DLStudio also comes with a new dataset, PurdueShapes5MultiObject,
        for experimenting with mUnet.  Each image in this dataset contains a random number
        of selections from five different shapes, with the shapes being randomly scaled,
        oriented, and located in each image.  The five different shapes are: rectangle,
        triangle, disk, oval, and star.

           Class Path:   DLStudio  ->  SemanticSegmentation
        """
        def __init__(self, dl_studio, max_num_objects, dataserver_train=None, dataserver_test=None, dataset_file_train=None, dataset_file_test=None):
            super(DLStudio.SemanticSegmentation, self).__init__()
            self.dl_studio = dl_studio
            self.max_num_objects = max_num_objects
            self.dataserver_train = dataserver_train
            self.dataserver_test = dataserver_test


        class PurdueShapes5MultiObjectDataset(torch.utils.data.Dataset):
            """
            The very first thing to note is that the images in the dataset
            PurdueShapes5MultiObjectDataset are of size 64x64.  Each image has a random
            number (up to five) of the objects drawn from the following five shapes:
            rectangle, triangle, disk, oval, and star.  Each shape is randomized with
            respect to all its parameters, including those for its scale and location in the
            image.

            Each image in the dataset is represented by two data objects, one a list and the
            other a dictionary. The list data object consists of the following items:

                [R, G, B, mask_array, mask_val_to_bbox_map]                                   ## (A)
            
            and the other data object is a dictionary that is set to:
            
                label_map = {'rectangle':50, 
                             'triangle' :100, 
                             'disk'     :150, 
                             'oval'     :200, 
                             'star'     :250}                                                 ## (B)
            
            Note that the second data object for each image is the same, as shown above.

            In the rest of this comment block, I'll explain in greater detail the elements
            of the list in line (A) above.

            
            R,G,B:
            ------

            Each of these is a 4096-element array whose elements store the corresponding
            color values at each of the 4096 pixels in a 64x64 image.  That is, R is a list
            of 4096 integers, each between 0 and 255, for the value of the red component of
            the color at each pixel. Similarly, for G and B.
            

            mask_array:
            ----------

            The fourth item in the list shown in line (A) above is for the mask which is a
            numpy array of shape:
            
                           (5, 64, 64)
            
            It is initialized by the command:
            
                 mask_array = np.zeros((5,64,64), dtype=np.uint8)
            
            In essence, the mask_array consists of five planes, each of size 64x64.  Each
            plane of the mask array represents an object type according to the following
            shape_index
            
                    shape_index = (label_map[shape] - 50) // 50
            
            where the label_map is as shown in line (B) above.  In other words, the
            shape_index values for the different shapes are:
            
                     rectangle:  0
                      triangle:  1
                          disk:  2
                          oval:  3
                          star:  4
            
            Therefore, the first layer (of index 0) of the mask is where the pixel values of
            50 are stored at all those pixels that belong to the rectangle shapes.
            Similarly, the second mask layer (of index 1) is where the pixel values of 100
            are stored at all those pixel coordinates that belong to the triangle shapes in
            an image; and so on.
            
            It is in the manner described above that we define five different masks for an
            image in the dataset.  Each mask is for a different shape and the pixel values
            at the nonzero pixels in each mask layer are keyed to the shapes also.
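
            As a small worked example of the indexing described above, consider the 'disk'
            shape:

                 shape_index  =  (label_map['disk'] - 50) // 50   =   (150 - 50) // 50   =   2

            So the pixels that belong to the disks in an image show up in mask_array[2],
            and they show up there with the pixel value 150.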
            
            A reader is likely to wonder about the need for this redundancy in the dataset
            representation of the shapes in each image.  Such a reader is likely to ask: Why
            can't we just use the binary values 1 and 0 in each mask layer at the pixels
            covered by the corresponding shape in the image?  Setting these mask values to
            50, 100, etc., was done merely for convenience.  I went with the intuition that
            the learning needed for multi-object segmentation would become easier if each
            shape was represented by a different pixel value in the corresponding mask.  So
            I went ahead and incorporated that in the dataset generation program itself.

            The mask values for the shapes are not to be confused with the actual RGB values
            of the pixels that belong to the shapes. The RGB values at the pixels in a shape
            are randomly generated.  Yes, all the pixels in a shape instance in an image
            have the same RGB values (but that value has nothing to do with the values given
            to the mask pixels for that shape).
            
            
            mask_val_to_bbox_map:
            --------------------
                   
            The fifth item in the list in line (A) above is a dictionary that tells us what
            bounding-box rectangle to associate with each shape in the image.  To illustrate
            what this dictionary looks like, assume that an image contains only one
            rectangle and only one disk, the dictionary in this case will look like:
            
                mask values to bbox mappings:  {200: [], 
                                                250: [], 
                                                100: [], 
                                                 50: [[56, 20, 63, 25]], 
                                                150: [[37, 41, 55, 59]]}
            
            Should there happen to be two rectangles in the same image, the dictionary would
            then be like:
            
                mask values to bbox mappings:  {200: [], 
                                                250: [], 
                                                100: [], 
                                                 50: [[56, 20, 63, 25], [18, 16, 32, 36]], 
                                                150: [[37, 41, 55, 59]]}
            
            Therefore, it is not a problem even if all the objects in an image are of the
            same type.  Remember, the objects that are selected for an image are drawn
            randomly from the different shapes.  By the way, an entry like '[56, 20, 63,
            25]' for the bounding box means that the upper-left corner of the BBox for the
            'rectangle' shape is at (56,20) and the lower-right corner of the same is at the
            pixel coordinates (63,25).
            
            As far as the BBox quadruples are concerned, in the definition
            
                    [min_x,min_y,max_x,max_y]
            
            note that x is the horizontal coordinate, increasing to the right on your
            screen, and y is the vertical coordinate increasing downwards.
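
            With that convention in mind, here is a small illustrative sketch (not code that
            is used anywhere in this class) of how you could read off the bounding boxes for,
            say, the rectangles in an image and mark their outlines in an assumed 64x64x3
            numpy array called image_array:

                 for (min_x, min_y, max_x, max_y) in mask_val_to_bbox_map[50]:   ## 50 is the mask value for 'rectangle'
                     image_array[min_y, min_x:max_x, 0] = 255                    ## top edge of the bbox in the red channel
                     image_array[max_y, min_x:max_x, 0] = 255                    ## bottom edge
                     image_array[min_y:max_y, min_x, 0] = 255                    ## left edge
                     image_array[min_y:max_y, max_x, 0] = 255                    ## right edge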

            Class Path:   DLStudio  ->  SemanticSegmentation  ->  PurdueShapes5MultiObjectDataset

            """
            def __init__(self, dl_studio, segmenter, train_or_test, dataset_file):
                super(DLStudio.SemanticSegmentation.PurdueShapes5MultiObjectDataset, self).__init__()
                max_num_objects = segmenter.max_num_objects
                if train_or_test == 'train' and dataset_file == "PurdueShapes5MultiObject-10000-train.gz":
                    if os.path.exists("torch_saved_PurdueShapes5MultiObject-10000_dataset.pt") and \
                              os.path.exists("torch_saved_PurdueShapes5MultiObject_label_map.pt"):
                        print("\nLoading training data from torch saved file")
                        self.dataset = torch.load("torch_saved_PurdueShapes5MultiObject-10000_dataset.pt")
                        self.label_map = torch.load("torch_saved_PurdueShapes5MultiObject_label_map.pt")
                        # reverse the key-value pairs in the label dictionary:
                        self.class_labels = dict(map(reversed, self.label_map.items()))
                        self.num_shapes = len(self.label_map)
                        self.image_size = dl_studio.image_size
                    else: 
                        print("""\n\n\nLooks like this is the first time you will be loading in\n"""
                              """the dataset for this script. First time loading could take\n"""
                              """a few minutes.  Any subsequent attempts will only take\n"""
                              """a few seconds.\n\n\n""")
                        root_dir = dl_studio.dataroot
                        f = gzip.open(root_dir + dataset_file, 'rb')
                        dataset = f.read()
                        self.dataset, self.label_map = pickle.loads(dataset, encoding='latin1')
                        torch.save(self.dataset, "torch_saved_PurdueShapes5MultiObject-10000_dataset.pt")
                        torch.save(self.label_map, "torch_saved_PurdueShapes5MultiObject_label_map.pt")
                        # reverse the key-value pairs in the label dictionary:
                        self.class_labels = dict(map(reversed, self.label_map.items()))
                        self.num_shapes = len(self.class_labels)
                        self.image_size = dl_studio.image_size
                else:
                    root_dir = dl_studio.dataroot
                    f = gzip.open(root_dir + dataset_file, 'rb')
                    dataset = f.read()
                    if sys.version_info[0] == 3:
                        self.dataset, self.label_map = pickle.loads(dataset, encoding='latin1')
                    else:
                        self.dataset, self.label_map = pickle.loads(dataset)
                    # reverse the key-value pairs in the label dictionary:
                    self.class_labels = dict(map(reversed, self.label_map.items()))
                    self.num_shapes = len(self.class_labels)
                    self.image_size = dl_studio.image_size

            def __len__(self):
                return len(self.dataset)

            def __getitem__(self, idx):
                image_size = self.image_size
                r = np.array( self.dataset[idx][0] )
                g = np.array( self.dataset[idx][1] )
                b = np.array( self.dataset[idx][2] )
                R,G,B = r.reshape(image_size[0],image_size[1]), g.reshape(image_size[0],image_size[1]), b.reshape(image_size[0],image_size[1])
                im_tensor = torch.zeros(3,image_size[0],image_size[1], dtype=torch.float)
                im_tensor[0,:,:] = torch.from_numpy(R)
                im_tensor[1,:,:] = torch.from_numpy(G)
                im_tensor[2,:,:] = torch.from_numpy(B)
                mask_array = np.array(self.dataset[idx][3])
                max_num_objects = len( mask_array[0] ) 
                mask_tensor = torch.from_numpy(mask_array)
                mask_val_to_bbox_map =  self.dataset[idx][4]
                max_bboxes_per_entry_in_map = max([ len(mask_val_to_bbox_map[key]) for key in mask_val_to_bbox_map ])
                ##  The first dim of bbox_tensor is the number of bbox slots we may need for any one
                ##  shape type -- if all the objects in an image happen to be of the same shape, we
                ##  need as many bbox slots for that shape as there are objects in the image.
                ##  The second dim is the index reserved for each shape type.
                bbox_tensor = torch.zeros(max_num_objects,self.num_shapes,4, dtype=torch.float)
                for bbox_idx in range(max_bboxes_per_entry_in_map):
                    for key in mask_val_to_bbox_map:
                        if len(mask_val_to_bbox_map[key]) == 1:
                            if bbox_idx == 0:
                                bbox_tensor[bbox_idx,key,:] = torch.from_numpy(np.array(mask_val_to_bbox_map[key][bbox_idx]))
                        elif len(mask_val_to_bbox_map[key]) > 1 and bbox_idx < len(mask_val_to_bbox_map[key]):
                            bbox_tensor[bbox_idx,key,:] = torch.from_numpy(np.array(mask_val_to_bbox_map[key][bbox_idx]))
                sample = {'image'        : im_tensor, 
                          'mask_tensor'  : mask_tensor,
                          'bbox_tensor'  : bbox_tensor }
                return sample

        def load_PurdueShapes5MultiObject_dataset(self, dataserver_train, dataserver_test ):   
            self.train_dataloader = torch.utils.data.DataLoader(dataserver_train,
                        batch_size=self.dl_studio.batch_size,shuffle=True, num_workers=4)
            self.test_dataloader = torch.utils.data.DataLoader(dataserver_test,
                               batch_size=self.dl_studio.batch_size,shuffle=False, num_workers=4)
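
        ##  With the two dataloaders set up above, each batch yielded during training is a
        ##  dictionary of tensors.  Just as an illustrative sketch of what to expect (B is the
        ##  batch_size set in the DLStudio instance):
        ##
        ##      batch = next(iter(self.train_dataloader))
        ##      batch['image'].shape           ##  torch.Size([B, 3, 64, 64])
        ##      batch['mask_tensor'].shape     ##  torch.Size([B, 5, 64, 64])
        ##      batch['bbox_tensor'].shape     ##  torch.Size([B, max_num_objects, num_shapes, 4])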


        class SkipBlockDN(nn.Module):
            """
            This class is for the skip connections in the downward leg of the "U"

            Class Path:   DLStudio  ->  SemanticSegmentation  ->  SkipBlockDN
            """
            def __init__(self, in_ch, out_ch, downsample=False, skip_connections=True):
                super(DLStudio.SemanticSegmentation.SkipBlockDN, self).__init__()
                self.downsample = downsample
                self.skip_connections = skip_connections
                self.in_ch = in_ch
                self.out_ch = out_ch
                self.convo1 = nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1)
                self.convo2 = nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1)
                self.bn1 = nn.BatchNorm2d(out_ch)
                self.bn2 = nn.BatchNorm2d(out_ch)
                if downsample:
                    self.downsampler = nn.Conv2d(in_ch, out_ch, 1, stride=2)
            def forward(self, x):
                identity = x                                     
                out = self.convo1(x)                              
                out = self.bn1(out)                              
                out = nn.functional.relu(out)
                if self.in_ch == self.out_ch:
                    out = self.convo2(out)                              
                    out = self.bn2(out)                              
                    out = nn.functional.relu(out)
                if self.downsample:
                    out = self.downsampler(out)
                    identity = self.downsampler(identity)
                if self.skip_connections:
                    if self.in_ch == self.out_ch:
                        out = out + identity
                    else:
                        out = out + torch.cat((identity, identity), dim=1) 
                return out
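
            ##  A minimal shape check for SkipBlockDN (purely illustrative):
            ##
            ##      blk = DLStudio.SemanticSegmentation.SkipBlockDN(64, 128)
            ##      blk(torch.rand(1, 64, 32, 32)).shape          ##  torch.Size([1, 128, 32, 32])
            ##
            ##      ds  = DLStudio.SemanticSegmentation.SkipBlockDN(64, 64, downsample=True)
            ##      ds(torch.rand(1, 64, 32, 32)).shape           ##  torch.Size([1, 64, 16, 16])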


        class SkipBlockUP(nn.Module):
            """
            This class is for the skip connections in the upward leg of the "U"

            Class Path:   DLStudio  ->  SemanticSegmentation  ->  SkipBlockUP
            """
            def __init__(self, in_ch, out_ch, upsample=False, skip_connections=True):
                super(DLStudio.SemanticSegmentation.SkipBlockUP, self).__init__()
                self.upsample = upsample
                self.skip_connections = skip_connections
                self.in_ch = in_ch
                self.out_ch = out_ch
                self.convoT1 = nn.ConvTranspose2d(in_ch, out_ch, 3, padding=1)
                self.convoT2 = nn.ConvTranspose2d(in_ch, out_ch, 3, padding=1)
                self.bn1 = nn.BatchNorm2d(out_ch)
                self.bn2 = nn.BatchNorm2d(out_ch)
                if upsample:
                    self.upsampler = nn.ConvTranspose2d(in_ch, out_ch, 1, stride=2, dilation=2, output_padding=1, padding=0)
            def forward(self, x):
                identity = x                                     
                out = self.convoT1(x)                              
                out = self.bn1(out)                              
                out = nn.functional.relu(out)
                if self.in_ch == self.out_ch:
                    out = self.convoT2(out)                              
                    out = self.bn2(out)                              
                    out = nn.functional.relu(out)
                if self.upsample:
                    out = self.upsampler(out)
                    identity = self.upsampler(identity)
                if self.skip_connections:
                    if self.in_ch == self.out_ch:
                        out = out + identity                              
                    else:
                        out = out + identity[:,self.out_ch:,:,:]
                return out
        

        class mUNet(nn.Module):
            """
            This network is called mUNet because it is intended for segmenting out
            multiple objects simultaneously from an image. [A weaker reason for "Multi" in
            the name of the class is that it uses skip connections not only across the two
            arms of the "U", but also along the arms.]  The classic UNET was first
            proposed by Ronneberger, Fischer and Brox in the paper "U-Net: Convolutional
            Networks for Biomedical Image Segmentation".  Their UNET extracts binary masks
            for the cell pixel blobs of interest in biomedical images.  The output of their
            UNET can therefore be treated as a pixel-wise binary classifier at each pixel
            position.

            The mUNet presented here, on the other hand, is meant specifically for
            simultaneously identifying and localizing multiple objects in a given image.
            Each object type is assigned a separate channel in the output of the network.

            I have created a dataset, PurdueShapes5MultiObject, for experimenting with
            mUNet.  Each image in this dataset contains a random number of selections from
            five different shapes, with the shapes being randomly scaled, oriented, and
            located in each image.  The five different shapes are: rectangle, triangle,
            disk, oval, and star.

            Class Path:   DLStudio  ->  SemanticSegmentation  ->  mUNet

            """ 
            def __init__(self, skip_connections=True, depth=16):
                super(DLStudio.SemanticSegmentation.mUNet, self).__init__()
                self.depth = depth // 2
                self.conv_in = nn.Conv2d(3, 64, 3, padding=1)
                ##  For the DN arm of the U:
                self.bn1DN  = nn.BatchNorm2d(64)
                self.bn2DN  = nn.BatchNorm2d(128)
                self.skip64DN_arr = nn.ModuleList()
                for i in range(self.depth):
                    self.skip64DN_arr.append(DLStudio.SemanticSegmentation.SkipBlockDN(64, 64, skip_connections=skip_connections))
                self.skip64dsDN = DLStudio.SemanticSegmentation.SkipBlockDN(64, 64,   downsample=True, skip_connections=skip_connections)
                self.skip64to128DN = DLStudio.SemanticSegmentation.SkipBlockDN(64, 128, skip_connections=skip_connections )
                self.skip128DN_arr = nn.ModuleList()
                for i in range(self.depth):
                    self.skip128DN_arr.append(DLStudio.SemanticSegmentation.SkipBlockDN(128, 128, skip_connections=skip_connections))
                self.skip128dsDN = DLStudio.SemanticSegmentation.SkipBlockDN(128,128, downsample=True, skip_connections=skip_connections)
                ##  For the UP arm of the U:
                self.bn1UP  = nn.BatchNorm2d(128)
                self.bn2UP  = nn.BatchNorm2d(64)
                self.skip64UP_arr = nn.ModuleList()
                for i in range(self.depth):
                    self.skip64UP_arr.append(DLStudio.SemanticSegmentation.SkipBlockUP(64, 64, skip_connections=skip_connections))
                self.skip64usUP = DLStudio.SemanticSegmentation.SkipBlockUP(64, 64, upsample=True, skip_connections=skip_connections)
                self.skip128to64UP = DLStudio.SemanticSegmentation.SkipBlockUP(128, 64, skip_connections=skip_connections )
                self.skip128UP_arr = nn.ModuleList()
                for i in range(self.depth):
                    self.skip128UP_arr.append(DLStudio.SemanticSegmentation.SkipBlockUP(128, 128, skip_connections=skip_connections))
                self.skip128usUP = DLStudio.SemanticSegmentation.SkipBlockUP(128,128, upsample=True, skip_connections=skip_connections)
                self.conv_out = nn.ConvTranspose2d(64, 5, 3, stride=2,dilation=2,output_padding=1,padding=2)

            def forward(self, x):
                ##  Going down to the bottom of the U:
                x = nn.MaxPool2d(2,2)(nn.functional.relu(self.conv_in(x)))          
                for i,skip64 in enumerate(self.skip64DN_arr[:self.depth//4]):
                    x = skip64(x)                
        
                num_channels_to_save1 = x.shape[1] // 2
                save_for_upside_1 = x[:,:num_channels_to_save1,:,:].clone()
                x = self.skip64dsDN(x)
                for i,skip64 in enumerate(self.skip64DN_arr[self.depth//4:]):
                    x = skip64(x)                
                x = self.bn1DN(x)
                num_channels_to_save2 = x.shape[1] // 2
                save_for_upside_2 = x[:,:num_channels_to_save2,:,:].clone()
                x = self.skip64to128DN(x)
                for i,skip128 in enumerate(self.skip128DN_arr[:self.depth//4]):
                    x = skip128(x)                
        
                x = self.bn2DN(x)
                num_channels_to_save3 = x.shape[1] // 2
                save_for_upside_3 = x[:,:num_channels_to_save3,:,:].clone()
                for i,skip128 in enumerate(self.skip128DN_arr[self.depth//4:]):
                    x = skip128(x)                
                x = self.skip128dsDN(x)
                ## Coming up from the bottom of U on the other side:
                x = self.skip128usUP(x)          
                for i,skip128 in enumerate(self.skip128UP_arr[:self.depth//4]):
                    x = skip128(x)                
                x[:,:num_channels_to_save3,:,:] =  save_for_upside_3
                x = self.bn1UP(x)
                for i,skip128 in enumerate(self.skip128UP_arr[:self.depth//4]):
                    x = skip128(x)                
                x = self.skip128to64UP(x)
                for i,skip64 in enumerate(self.skip64UP_arr[self.depth//4:]):
                    x = skip64(x)                
                x[:,:num_channels_to_save2,:,:] =  save_for_upside_2
                x = self.bn2UP(x)
                x = self.skip64usUP(x)
                for i,skip64 in enumerate(self.skip64UP_arr[:self.depth//4]):
                    x = skip64(x)                
                x[:,:num_channels_to_save1,:,:] =  save_for_upside_1
                x = self.conv_out(x)
                return x
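
            ##  A minimal sanity check on the mUNet forward pass (purely illustrative):
            ##
            ##      net = DLStudio.SemanticSegmentation.mUNet(skip_connections=True, depth=16)
            ##      net(torch.rand(1, 3, 64, 64)).shape       ##  torch.Size([1, 5, 64, 64]) -- one channel per shape type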
        

        class SegmentationLoss(nn.Module):
            """
            I wrote this class before I switched to MSE loss.  I am leaving it here
            in case I need to get back to it in the future.  

            Class Path:   DLStudio  ->  SemanticSegmentation  ->  SegmentationLoss
            """
            def __init__(self, batch_size):
                super(DLStudio.SemanticSegmentation.SegmentationLoss, self).__init__()
                self.batch_size = batch_size
            def forward(self, output, mask_tensor):
                composite_loss = torch.zeros(1,self.batch_size)
                mask_based_loss = torch.zeros(1,5)
                for idx in range(self.batch_size):
                    outputh = output[idx,0,:,:]
                    for mask_layer_idx in range(mask_tensor.shape[1]):
                        mask = mask_tensor[idx,mask_layer_idx,:,:]
                        element_wise = (outputh - mask)**2                   
                        mask_based_loss[0,mask_layer_idx] = torch.mean(element_wise)
                    composite_loss[0,idx] = torch.sum(mask_based_loss)
                return torch.sum(composite_loss) / self.batch_size


        def run_code_for_training_for_semantic_segmentation(self, net):        
            filename_for_out1 = "performance_numbers_" + str(self.dl_studio.epochs) + ".txt"
            FILE1 = open(filename_for_out1, 'w')
            net = copy.deepcopy(net)
            net = net.to(self.dl_studio.device)
            criterion1 = nn.MSELoss()
            optimizer = optim.SGD(net.parameters(), 
                         lr=self.dl_studio.learning_rate, momentum=self.dl_studio.momentum)
            start_time = time.perf_counter()
            for epoch in range(self.dl_studio.epochs):  
                print("")
                running_loss_segmentation = 0.0
                for i, data in enumerate(self.train_dataloader):    
                    im_tensor,mask_tensor,bbox_tensor =data['image'],data['mask_tensor'],data['bbox_tensor']
                    im_tensor   = im_tensor.to(self.dl_studio.device)
                    mask_tensor = mask_tensor.type(torch.FloatTensor)
                    mask_tensor = mask_tensor.to(self.dl_studio.device)                 
                    bbox_tensor = bbox_tensor.to(self.dl_studio.device)
                    optimizer.zero_grad()
                    output = net(im_tensor) 
                    segmentation_loss = criterion1(output, mask_tensor)  
                    segmentation_loss.backward()
                    optimizer.step()
                    running_loss_segmentation += segmentation_loss.item()    
                    if i%500==499:    
                        current_time = time.perf_counter()
                        elapsed_time = current_time - start_time
                        avg_loss_segmentation = running_loss_segmentation / float(500)
                        print("[epoch=%d/%d, iter=%4d  elapsed_time=%3d secs]   MSE loss: %.3f" % (epoch+1, self.dl_studio.epochs, i+1, elapsed_time, avg_loss_segmentation))
                        FILE1.write("%.3f\n" % avg_loss_segmentation)
                        FILE1.flush()
                        running_loss_segmentation = 0.0
            print("\nFinished Training\n")
            self.save_model(net)
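
        ##  The following is only an illustrative sketch of how the training routine above is
        ##  meant to be invoked.  Here 'dls' stands for an already constructed DLStudio instance,
        ##  and the name of the testing archive is hypothetical:
        ##
        ##      segmenter = DLStudio.SemanticSegmentation( dl_studio=dls, max_num_objects=5 )
        ##      dataserver_train = DLStudio.SemanticSegmentation.PurdueShapes5MultiObjectDataset(
        ##                             dls, segmenter, train_or_test='train',
        ##                             dataset_file="PurdueShapes5MultiObject-10000-train.gz" )
        ##      dataserver_test  = DLStudio.SemanticSegmentation.PurdueShapes5MultiObjectDataset(
        ##                             dls, segmenter, train_or_test='test',
        ##                             dataset_file="PurdueShapes5MultiObject-1000-test.gz" )     ## hypothetical archive name
        ##      segmenter.dataserver_train = dataserver_train
        ##      segmenter.dataserver_test  = dataserver_test
        ##      segmenter.load_PurdueShapes5MultiObject_dataset( dataserver_train, dataserver_test )
        ##      model = DLStudio.SemanticSegmentation.mUNet( skip_connections=True, depth=16 )
        ##      segmenter.run_code_for_training_for_semantic_segmentation( model )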


        def save_model(self, model):
            '''
            Save the trained model to a disk file
            '''
            torch.save(model.state_dict(), self.dl_studio.path_saved_model)


        def run_code_for_testing_semantic_segmentation(self, net):
            net.load_state_dict(torch.load(self.dl_studio.path_saved_model))
            batch_size = self.dl_studio.batch_size
            image_size = self.dl_studio.image_size
            max_num_objects = self.max_num_objects
            with torch.no_grad():
                for i, data in enumerate(self.test_dataloader):
                    im_tensor,mask_tensor,bbox_tensor =data['image'],data['mask_tensor'],data['bbox_tensor']
                    if i % 50 == 0:
                        print("\n\n\n\nShowing output for test batch %d: " % (i+1))
                        outputs = net(im_tensor)                        
                        ## In the statement below: 1st arg for batch items, 2nd for channels, 3rd and 4th for image size
                        output_bw_tensor = torch.zeros(batch_size,1,image_size[0],image_size[1], dtype=float)
                        for image_idx in range(batch_size):
                            for layer_idx in range(max_num_objects): 
                                for m in range(image_size[0]):
                                    for n in range(image_size[1]):
                                        output_bw_tensor[image_idx,0,m,n]  =  torch.max( outputs[image_idx,:,m,n] )
                        display_tensor = torch.zeros(7 * batch_size,3,image_size[0],image_size[1], dtype=float)
                        for idx in range(batch_size):
                            for bbox_idx in range(max_num_objects):   
                                bb_tensor = bbox_tensor[idx,bbox_idx]
                                for k in range(max_num_objects):
                                    i1 = int(bb_tensor[k][1])
                                    i2 = int(bb_tensor[k][3])
                                    j1 = int(bb_tensor[k][0])
                                    j2 = int(bb_tensor[k][2])
                                    output_bw_tensor[idx,0,i1:i2,j1] = 255
                                    output_bw_tensor[idx,0,i1:i2,j2] = 255
                                    output_bw_tensor[idx,0,i1,j1:j2] = 255
                                    output_bw_tensor[idx,0,i2,j1:j2] = 255
                                    im_tensor[idx,0,i1:i2,j1] = 255
                                    im_tensor[idx,0,i1:i2,j2] = 255
                                    im_tensor[idx,0,i1,j1:j2] = 255
                                    im_tensor[idx,0,i2,j1:j2] = 255
                        display_tensor[:batch_size,:,:,:] = output_bw_tensor
                        display_tensor[batch_size:2*batch_size,:,:,:] = im_tensor

                        for batch_im_idx in range(batch_size):
                            for mask_layer_idx in range(max_num_objects):
                                for i in range(image_size[0]):
                                    for j in range(image_size[1]):
                                        if mask_layer_idx == 0:
                                            if 25 < outputs[batch_im_idx,mask_layer_idx,i,j] < 85:
                                                outputs[batch_im_idx,mask_layer_idx,i,j] = 255
                                            else:
                                                outputs[batch_im_idx,mask_layer_idx,i,j] = 50
                                        elif mask_layer_idx == 1:
                                            if 65 < outputs[batch_im_idx,mask_layer_idx,i,j] < 135:
                                                outputs[batch_im_idx,mask_layer_idx,i,j] = 255
                                            else:
                                                outputs[batch_im_idx,mask_layer_idx,i,j] = 50
                                        elif mask_layer_idx == 2:
                                            if 115 < outputs[batch_im_idx,mask_layer_idx,i,j] < 185:
                                                outputs[batch_im_idx,mask_layer_idx,i,j] = 255
                                            else:
                                                outputs[batch_im_idx,mask_layer_idx,i,j] = 50
                                        elif mask_layer_idx == 3:
                                            if 165 < outputs[batch_im_idx,mask_layer_idx,i,j] < 230:
                                                outputs[batch_im_idx,mask_layer_idx,i,j] = 255
                                            else:
                                                outputs[batch_im_idx,mask_layer_idx,i,j] = 50
                                        elif mask_layer_idx == 4:
                                            if outputs[batch_im_idx,mask_layer_idx,i,j] > 210:
                                                outputs[batch_im_idx,mask_layer_idx,i,j] = 255
                                            else:
                                                outputs[batch_im_idx,mask_layer_idx,i,j] = 50

                                display_tensor[2*batch_size+batch_size*mask_layer_idx+batch_im_idx,:,:,:]= outputs[batch_im_idx,mask_layer_idx,:,:]
                        self.dl_studio.display_tensor_as_image(
                           torchvision.utils.make_grid(display_tensor, nrow=batch_size, normalize=True, padding=2, pad_value=10))
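
        ##  A note on the pixel-level loops above: the channel-wise max and the per-layer
        ##  thresholding can also be written in vectorized form.  A rough sketch of the idea
        ##  (not a drop-in replacement for the code above):
        ##
        ##      bw, _ = torch.max(outputs, dim=1, keepdim=True)                 ## max over the 5 mask channels
        ##      lo = torch.tensor([25.0,  65.0, 115.0, 165.0, 210.0]).view(1, 5, 1, 1)
        ##      hi = torch.tensor([85.0, 135.0, 185.0, 230.0, float('inf')]).view(1, 5, 1, 1)
        ##      outputs = torch.where( (outputs > lo) & (outputs < hi),
        ##                             torch.tensor(255.0), torch.tensor(50.0) )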




    ###%%%
    #####################################################################################################################
    ####################################  Start Definition of Inner Class Autoencoder  ##################################

    class Autoencoder(nn.Module):             
        """
         The main reason for the existence of this inner class in DLStudio is for it to serve as the base class for VAE 
         (Variational Auto-Encoder).  That way, the VAE class can focus exclusively on the random-sampling logic 
         specific to variational encoding while the base class Autoencoder does the convolutional and 
         transpose-convolutional heavy lifting associated with the usual encoding-decoding of image data.

        Class Path:   DLStudio  ->  Autoencoder
        """
        def __init__(self, dl_studio, encoder_in_im_size, encoder_out_im_size, decoder_out_im_size, encoder_out_ch, num_repeats, path_saved_model):
            super(DLStudio.Autoencoder, self).__init__()
            self.dl_studio = dl_studio
            ## The parameter num_repeats is how many times you want to repeat in the Encoder the SkipBlock that has the same number of
            ## channels at the input and the output (See the code for EncoderForAutoenc):
            self.encoder  =  DLStudio.Autoencoder.EncoderForAutoenc( dl_studio, encoder_in_im_size, encoder_out_im_size, encoder_out_ch, num_repeats)
            decoder_in_im_size = encoder_out_im_size
            self.decoder  =  DLStudio.Autoencoder.DecoderForAutoenc( dl_studio, decoder_in_im_size, decoder_out_im_size )
            self.path_saved_model = path_saved_model

        def forward(self, x):   
            x =  self.encoder(x)                                                                                             
            x =  self.decoder(x)                                                                                             
            return x             
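
        ##  An illustrative sketch (not part of the class) of how this Autoencoder is meant to be
        ##  instantiated and trained.  Here 'dls' is an already constructed DLStudio instance, the
        ##  path for the saved model is hypothetical, and it is assumed that the user attaches the
        ##  train/test dataloaders to the instance before calling the training routine:
        ##
        ##      autoenc = DLStudio.Autoencoder( dl_studio=dls,
        ##                                      encoder_in_im_size=(64,64),
        ##                                      encoder_out_im_size=(16,16),
        ##                                      decoder_out_im_size=(64,64),
        ##                                      encoder_out_ch=256,
        ##                                      num_repeats=2,
        ##                                      path_saved_model="./saved_autoenc_model" )
        ##      autoenc.train_dataloader = my_train_dataloader      ## hypothetical, user-supplied DataLoader
        ##      autoenc.test_dataloader  = my_test_dataloader       ## hypothetical, user-supplied DataLoader
        ##      autoenc.run_code_for_training_autoencoder( display_train_loss=True )
        ##      autoenc.run_code_for_evaluating_autoencoder()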


        class EncoderForAutoenc(nn.Module):
            """
            The two main components of an Autoencoder are the encoder and the decoder. This is the encoder part of the
            Autoencoder.

            The parameter num_repeats is how many times you want to repeat in the Encoder the SkipBlock that has the same 
            number of channels at the input and the output.

            Class Path:   DLStudio  ->  Autoencoder  ->  EncoderForAutoenc
            """ 
            def __init__(self, dl_studio, encoder_in_im_size, encoder_out_im_size, encoder_out_ch, num_repeats, skip_connections=True):
                super(DLStudio.Autoencoder.EncoderForAutoenc, self).__init__()
                downsampling_ratio =  encoder_in_im_size[0] // encoder_out_im_size[0]
                num_downsamples =  downsampling_ratio // 2
                assert( num_downsamples == 1 or num_downsamples == 2 or num_downsamples == 4 )
                self.depth = num_downsamples
                self.encoder_out_im_size = encoder_out_im_size
                self.encoder_out_ch = encoder_out_ch
                self.conv_in = nn.Conv2d(3, 64, 3, padding=1)
                self.bn1DN  = nn.BatchNorm2d(64)
                self.bn2DN  = nn.BatchNorm2d(128)
                self.bn3DN  = nn.BatchNorm2d(256)
                self.skip_arr = nn.ModuleList()
                if self.depth == 1:
                    self.skip_arr.append(DLStudio.Autoencoder.SkipBlockEncoder(64, 64, downsample=True, skip_connections=skip_connections))
                    self.skip_arr.append(DLStudio.Autoencoder.SkipBlockEncoder(64, 128, downsample=False, skip_connections=skip_connections))
                    for _ in range(num_repeats):
                        self.skip_arr.append(DLStudio.Autoencoder.SkipBlockEncoder(128, 128, downsample=False, skip_connections=skip_connections))
                elif self.depth == 2:
                    self.skip_arr.append(DLStudio.Autoencoder.SkipBlockEncoder(64, 64, downsample=True, skip_connections=skip_connections))
                    self.skip_arr.append(DLStudio.Autoencoder.SkipBlockEncoder(64, 128, downsample=False, skip_connections=skip_connections))
                    for _ in range(num_repeats):
                        self.skip_arr.append(DLStudio.Autoencoder.SkipBlockEncoder(128, 128, downsample=False, skip_connections=skip_connections))
                    self.skip_arr.append(DLStudio.Autoencoder.SkipBlockEncoder(128, 128, downsample=True, skip_connections=skip_connections))
                    self.skip_arr.append(DLStudio.Autoencoder.SkipBlockEncoder(128, 256, downsample=False, skip_connections=skip_connections))
                    for _ in range(num_repeats):
                        self.skip_arr.append(DLStudio.Autoencoder.SkipBlockEncoder(256, 256, downsample=False, skip_connections=skip_connections))
                elif self.depth == 4:
                    self.skip_arr.append(DLStudio.Autoencoder.SkipBlockEncoder(64, 64, downsample=True, skip_connections=skip_connections))
                    self.skip_arr.append(DLStudio.Autoencoder.SkipBlockEncoder(64, 128, downsample=False, skip_connections=skip_connections))
                    for _ in range(num_repeats):
                        self.skip_arr.append(DLStudio.Autoencoder.SkipBlockEncoder(128, 128, downsample=False, skip_connections=skip_connections))
                    self.skip_arr.append(DLStudio.Autoencoder.SkipBlockEncoder(128, 128, downsample=True, skip_connections=skip_connections))
                    self.skip_arr.append(DLStudio.Autoencoder.SkipBlockEncoder(128, 256, downsample=False, skip_connections=skip_connections))
                    for _ in range(num_repeats):
                        self.skip_arr.append(DLStudio.Autoencoder.SkipBlockEncoder(256, 256, downsample=False, skip_connections=skip_connections))
                    self.skip_arr.append(DLStudio.Autoencoder.SkipBlockEncoder(256, 256, downsample=True, skip_connections=skip_connections))
                    self.skip_arr.append(DLStudio.Autoencoder.SkipBlockEncoder(256, 512, downsample=False, skip_connections=skip_connections))
                    for _ in range(num_repeats):
                        self.skip_arr.append(DLStudio.Autoencoder.SkipBlockEncoder(512, 512, downsample=False, skip_connections=skip_connections))
                self.skip128DN = DLStudio.Autoencoder.SkipBlockEncoder(128,128, skip_connections=skip_connections)

            def forward(self, x):
                x = nn.functional.relu(self.conv_in(x))          
                for layer in self.skip_arr:
                    x = layer(x)
                if (x.shape[2:] != self.encoder_out_im_size) or (x.shape[1] != self.encoder_out_ch):
                    print("\n\nShape of x at output of Encoder: ", x.shape) 
                    sys.exit("\n\nThe Encoder part of the Autoencoder is misconfigured. Encoder output not according to specs\n\n")
                return x


        class DecoderForAutoenc(nn.Module):
            """
            The two main components of an Autoencoder are the encoder and the decoder.
            This is the decoder part of the Autoencoder.

            For its final output layer, this Decoder uses a regular nn.Conv2d layer instead of a
            transpose convolution.  The Decoder that follows next, DecoderForAutoenc_CT, uses
            nn.ConvTranspose2d for that final layer.

            Class Path:   DLStudio  ->  Autoencoder  ->  DecoderForAutoenc
            """ 
            def __init__(self, dl_studio, decoder_in_im_size, decoder_out_im_size, skip_connections=True):
                super(DLStudio.Autoencoder.DecoderForAutoenc, self).__init__()
                upsampling_ratio =  decoder_out_im_size[0] // decoder_in_im_size[0]
                num_upsamples =  upsampling_ratio // 2
                assert( num_upsamples == 1 or num_upsamples == 2 or num_upsamples == 4)
                self.depth = num_upsamples
                self.decoder_out_im_size = decoder_out_im_size
                self.conv_in = nn.Conv2d(3, 64, 3, padding=1)
                self.bn1DN  = nn.BatchNorm2d(64)
                self.bn2DN  = nn.BatchNorm2d(128)
                self.bn3DN  = nn.BatchNorm2d(256)
                self.skip_arr = nn.ModuleList()
                if self.depth == 1:
                    self.skip_arr.append(DLStudio.Autoencoder.SkipBlockDecoder(128, 64, upsample=False, skip_connections=skip_connections))
                    self.skip_arr.append(DLStudio.Autoencoder.SkipBlockDecoder(64, 64, upsample=True, skip_connections=skip_connections))
                elif self.depth == 2:
                    self.skip_arr.append(DLStudio.Autoencoder.SkipBlockDecoder(256, 128, upsample=False, skip_connections=skip_connections))
                    self.skip_arr.append(DLStudio.Autoencoder.SkipBlockDecoder(128, 128, upsample=True, skip_connections=skip_connections))
                    self.skip_arr.append(DLStudio.Autoencoder.SkipBlockDecoder(128, 64, upsample=False, skip_connections=skip_connections))
                    self.skip_arr.append(DLStudio.Autoencoder.SkipBlockDecoder(64, 64, upsample=True, skip_connections=skip_connections))
                elif self.depth == 4:
                    self.skip_arr.append(DLStudio.Autoencoder.SkipBlockDecoder(512, 256, upsample=False, skip_connections=skip_connections))
                    self.skip_arr.append(DLStudio.Autoencoder.SkipBlockDecoder(256, 256, upsample=True, skip_connections=skip_connections))
                    self.skip_arr.append(DLStudio.Autoencoder.SkipBlockDecoder(256, 128, upsample=False, skip_connections=skip_connections))
                    self.skip_arr.append(DLStudio.Autoencoder.SkipBlockDecoder(128, 128, upsample=True, skip_connections=skip_connections))
                    self.skip_arr.append(DLStudio.Autoencoder.SkipBlockDecoder(128, 64, upsample=False, skip_connections=skip_connections))
                    self.skip_arr.append(DLStudio.Autoencoder.SkipBlockDecoder(64, 64, upsample=True, skip_connections=skip_connections))
                self.bn1UP  = nn.BatchNorm2d(256)
                self.bn2UP  = nn.BatchNorm2d(128)
                self.bn3UP  = nn.BatchNorm2d(64)
#                self.conv_out3 = nn.ConvTranspose2d(64,3, 3, stride=1,dilation=1,output_padding=0,padding=1)
                self.conv_out3 = nn.Conv2d(64,3,3,padding=1)

            def forward(self, x):
                for layer in self.skip_arr:
                    x = layer(x)
                x = self.conv_out3(x)
                if x.shape[2:] != self.decoder_out_im_size:
                    print("\n\nShape of x at output of Decoder: ", x.shape) 
                    sys.exit("\n\nThe Decoder part of the Autoencoder is misconfigured. Output image not according to specs\n\n")
                return x


        class DecoderForAutoenc_CT(nn.Module):
            """
            The "_CT" in the name of this class signifies that this Decoder uses ConvTranspose layers for upsampling.

            Note that using nn.ConvTranspose2d for upsampling may introduce gridding (checkerboard) artifacts in the output images.

            Class Path:   DLStudio  ->  Autoencoder  ->  DecoderForAutoenc_CT
            """ 
            def __init__(self, dl_studio, decoder_in_im_size, decoder_out_im_size, skip_connections=True):
                super(DLStudio.Autoencoder.DecoderForAutoenc_CT, self).__init__()
                upsampling_ratio =  decoder_out_im_size[0] // decoder_in_im_size[0]
                num_upsamples =  upsampling_ratio // 2
                assert( num_upsamples == 1 or num_upsamples == 2 or num_upsamples == 4)
                self.depth = num_upsamples
                self.decoder_out_im_size = decoder_out_im_size
                self.conv_in = nn.Conv2d(3, 64, 3, padding=1)
                self.bn1DN  = nn.BatchNorm2d(64)
                self.bn2DN  = nn.BatchNorm2d(128)
                self.bn3DN  = nn.BatchNorm2d(256)
                self.skip_arr = nn.ModuleList()
                if self.depth == 1:
                    self.skip_arr.append(DLStudio.Autoencoder.SkipBlockDecoder(128, 64, upsample=False, skip_connections=skip_connections))
                    self.skip_arr.append(DLStudio.Autoencoder.SkipBlockDecoder(64, 64, upsample=True, skip_connections=skip_connections))
                elif self.depth == 2:
                    self.skip_arr.append(DLStudio.Autoencoder.SkipBlockDecoder(256, 128, upsample=False, skip_connections=skip_connections))
                    self.skip_arr.append(DLStudio.Autoencoder.SkipBlockDecoder(128, 128, upsample=True, skip_connections=skip_connections))
                    self.skip_arr.append(DLStudio.Autoencoder.SkipBlockDecoder(128, 64, upsample=False, skip_connections=skip_connections))
                    self.skip_arr.append(DLStudio.Autoencoder.SkipBlockDecoder(64, 64, upsample=True, skip_connections=skip_connections))
                elif self.depth == 4:
                    self.skip_arr.append(DLStudio.Autoencoder.SkipBlockDecoder(512, 256, upsample=False, skip_connections=skip_connections))
                    self.skip_arr.append(DLStudio.Autoencoder.SkipBlockDecoder(256, 256, upsample=True, skip_connections=skip_connections))
                    self.skip_arr.append(DLStudio.Autoencoder.SkipBlockDecoder(256, 128, upsample=False, skip_connections=skip_connections))
                    self.skip_arr.append(DLStudio.Autoencoder.SkipBlockDecoder(128, 128, upsample=True, skip_connections=skip_connections))
                    self.skip_arr.append(DLStudio.Autoencoder.SkipBlockDecoder(128, 64, upsample=False, skip_connections=skip_connections))
                    self.skip_arr.append(DLStudio.Autoencoder.SkipBlockDecoder(64, 64, upsample=True, skip_connections=skip_connections))
                self.bn1UP  = nn.BatchNorm2d(256)
                self.bn2UP  = nn.BatchNorm2d(128)
                self.bn3UP  = nn.BatchNorm2d(64)
                self.conv_out3 = nn.ConvTranspose2d(64,3, 3, stride=1,dilation=1,output_padding=0,padding=1)

            def forward(self, x):
                for layer in self.skip_arr:
                    x = layer(x)
                x = self.conv_out3(x)
                if x.shape[2:] != self.decoder_out_im_size:
                    print("\n\nShape of x at output of Decoder: ", x.shape) 
                    sys.exit("\n\nThe Decoder part of the Autoencoder is misconfigured. Output image not according to specs\n\n")
                return x


        def save_autoencoder_model(self, model):
            '''
            Save the trained model to a disk file
            '''
            torch.save(model.state_dict(), self.path_saved_model)


        def run_code_for_training_autoencoder( self, display_train_loss=False ):
      
            autoencoder = self.to(self.dl_studio.device)
            criterion = nn.MSELoss()
            optimizer = optim.Adam(autoencoder.parameters(), lr=self.dl_studio.learning_rate)
            accum_times = []
            start_time = time.perf_counter()
            training_loss_tally = []
            print("")
            batch_size = self.dl_studio.batch_size
            print("\n\n batch_size: ", batch_size)
            print("\n\n number of batches in the dataset: ", len(self.train_dataloader))
    
            for epoch in range(self.dl_studio.epochs):                                                              
                print("")
                running_loss = 0.0
                for i, data in enumerate(self.train_dataloader):                                    
                    input_images, _ = data       
                    input_images = input_images.to(self.dl_studio.device)
                    optimizer.zero_grad()
                    autoencoder_output = autoencoder( input_images )
                    loss  =  criterion( autoencoder_output, input_images )
                    loss.backward()                                                                                        
                    optimizer.step()
                    running_loss += loss.item()
                    if i % 200 == 199:    
                        avg_loss = running_loss / float(200)
                        training_loss_tally.append(avg_loss)
                        running_loss = 0.0
                        current_time = time.perf_counter()
                        time_elapsed = current_time-start_time
                        print("[epoch:%2d/%2d  i:%4d  elapsed_time: %4d secs]     loss: %.4f" % (epoch+1, self.dl_studio.epochs, i+1,time_elapsed,avg_loss)) 
                        accum_times.append(current_time-start_time)
            print("\nFinished Training\n")
            self.save_autoencoder_model( autoencoder )
            if display_train_loss:
                plt.figure(figsize=(10,5))
                plt.title("Training Loss vs. Iterations")
                plt.plot(training_loss_tally)
                plt.xlabel("iterations")
                plt.ylabel("training loss")
                plt.legend(["Plot of loss versus iterations"], fontsize="x-large")
                plt.savefig("training_loss.png")
                plt.show()
    


        def run_code_for_evaluating_autoencoder(self, visualization_dir = "autoencoder_visualization_dir" ):
    
            if os.path.exists(visualization_dir):  
                """
                Clear out the previous outputs in the visualization directory
                """
                files = glob.glob(visualization_dir + "/*")
                for file in files: 
                    if os.path.isfile(file): 
                        os.remove(file) 
                    else: 
                        files = glob.glob(file + "/*") 
                        list(map(lambda x: os.remove(x), files)) 
            else: 
                os.mkdir(visualization_dir)   
            autoencoder = self
            autoencoder.load_state_dict(torch.load(self.path_saved_model))
            autoencoder.to(self.dl_studio.device)
            with torch.no_grad():
                for i, data in enumerate(self.test_dataloader):                                    
                    print("\n\n\n=========Showing results for test batch %d===============" % i)
                    test_images, _ = data     

                    test_images = test_images.to(self.dl_studio.device)
                    autoencoder_output = autoencoder( test_images )
                    autoencoder_output  =  ( autoencoder_output - autoencoder_output.min() ) / ( autoencoder_output.max() -  autoencoder_output.min() )
                    together = torch.zeros( test_images.shape[0], test_images.shape[1], test_images.shape[2], 2 * test_images.shape[3], dtype=torch.float )
                    together[:,:,:,0:test_images.shape[3]]  =  test_images
                    together[:,:,:,test_images.shape[3]:]  =   autoencoder_output 

                    plt.figure(figsize=(40,20))
                    plt.imshow(np.transpose(torchvision.utils.make_grid(together.cpu(), normalize=False, padding=3, pad_value=255).cpu(), (1,2,0)))
                    plt.title("Autoencoder Output Images for iteration %d" % i)
                    plt.savefig(visualization_dir + "/autoenc_output_%s" % str(i) + ".png")
                    plt.show()

    

        class SkipBlockEncoder(nn.Module):
            """
            This is a building-block class for the skip connections in EncoderForAutoenc

            Class Path:   DLStudio  ->  Autoencoder  ->  SkipBlockEncoder
            """
            def __init__(self, in_ch, out_ch, downsample=False, skip_connections=True):
                super(DLStudio.Autoencoder.SkipBlockEncoder, self).__init__()
                self.downsample = downsample
                self.skip_connections = skip_connections
                self.in_ch = in_ch
                self.out_ch = out_ch
                self.convo1 = nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1)
                self.convo2 = nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1)
#                self.bn1 = nn.BatchNorm2d(out_ch)
                self.gn1 = nn.GroupNorm(num_groups=32, num_channels=out_ch, eps=1e-6, affine=True)   
#                self.bn2 = nn.BatchNorm2d(out_ch)
                self.gn2 = nn.GroupNorm(num_groups=32, num_channels=out_ch, eps=1e-6, affine=True)   
                if downsample:
                    self.downsampler = nn.Conv2d(in_ch, out_ch, 1, stride=2)

            def forward(self, x):
                identity = x                                     
                out = self.convo1(x)                              
#                out = self.bn1(out)                              
                out = self.gn1(out)                              
                out = nn.functional.relu(out)
                if self.in_ch == self.out_ch:
                    out = self.convo2(out)                              
#                    out = self.bn2(out)                              
                    out = self.gn2(out)                              
                    out = nn.functional.relu(out)
                if self.downsample:
                    out = self.downsampler(out)
                    identity = self.downsampler(identity)
                if self.skip_connections:
                    if self.in_ch == self.out_ch:
                        out = out + identity
                    else:
                        out = out + torch.cat((identity, identity), dim=1) 
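                        ##  Note: duplicating the identity along the channel axis matches the channel
                        ##  count of 'out' only when out_ch == 2 * in_ch.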
                return out

        class SkipBlockDecoder(nn.Module):
            """
            This is a building-block class for the skip connections in DecoderForAutoenc

            This SkipBlock is based on using interpolation for upsampling.

            Class Path:   DLStudio  ->  Autoencoder  ->  SkipBlockDecoder
            """
            def __init__(self, in_ch, out_ch, upsample=False, skip_connections=True):
                super(DLStudio.Autoencoder.SkipBlockDecoder, self).__init__()
                self.upsample = upsample
                self.skip_connections = skip_connections
                self.in_ch = in_ch
                self.out_ch = out_ch
                self.conv1 = nn.Conv2d(in_ch, out_ch, 3,padding=1)
                self.conv2 = nn.Conv2d(in_ch, out_ch, 3,padding=1)
                self.conv3 = nn.Conv2d(out_ch, out_ch, 3,padding=1)
#                self.bn1 = nn.BatchNorm2d(out_ch)
                self.gn1 = nn.GroupNorm(num_groups=32, num_channels=out_ch, eps=1e-6, affine=True)   
#                self.bn2 = nn.BatchNorm2d(out_ch)
                self.gn2 = nn.GroupNorm(num_groups=32, num_channels=out_ch, eps=1e-6, affine=True)   
                self.gn3 = nn.GroupNorm(num_groups=32, num_channels=out_ch, eps=1e-6, affine=True)   

            def forward(self, x):
                identity = x                                     
                out = self.conv1(x)
#                out = self.bn1(out) 
                out = self.gn1(out) 
                out = nn.functional.relu(out)
                if self.in_ch == self.out_ch:
                    out = self.conv2(out)                              
#                    out = self.bn2(out)                              
                    out = self.gn2(out)                              
                    out = nn.functional.relu(out)
                if self.upsample:
                    ## This is the ONLY part where upsampling takes place in this skip block
                    out = F.interpolate(out, scale_factor=2.0)
                    identity = F.interpolate(identity, scale_factor=2.0)
                if self.skip_connections:
                    if self.in_ch == self.out_ch:
                        out = out + identity                              
                    else:
                        out = out + identity[:,self.out_ch:,:,:]
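                        ##  Note: this slice keeps the last (in_ch - out_ch) channels of the identity,
                        ##  which matches the channel count of 'out' only when in_ch == 2 * out_ch.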
                out = self.conv3(out)
                out = self.gn3(out)                              
                out = nn.functional.relu(out)
                return out


        class SkipBlockDecoder_CT(nn.Module):
            """
            This is a building-block class for the skip connections in DecoderForAutoenc

            This class uses convTranspose layers for upsampling

            Class Path:   DLStudio  ->  Autoencoder  ->  SkipBlockDecoder
            """
            def __init__(self, in_ch, out_ch, upsample=False, skip_connections=True):
                super(DLStudio.Autoencoder.SkipBlockDecoder_CT, self).__init__()
                self.upsample = upsample
                self.skip_connections = skip_connections
                self.in_ch = in_ch
                self.out_ch = out_ch
                self.convoT1 = nn.ConvTranspose2d(in_ch, out_ch, 3, padding=1)
                self.convoT2 = nn.ConvTranspose2d(in_ch, out_ch, 3, padding=1)
                self.bn1 = nn.BatchNorm2d(out_ch)
                self.bn2 = nn.BatchNorm2d(out_ch)
                if upsample:
                    self.upsampler = nn.ConvTranspose2d(in_ch, out_ch, 1, stride=2, dilation=2, output_padding=1, padding=0)
            def forward(self, x):
                identity = x                                     
                out = self.convoT1(x)                              
                out = self.bn1(out)                              
                out = nn.functional.relu(out)
                if self.in_ch == self.out_ch:
                    out = self.convoT2(out)                              
                    out = self.bn2(out)                              
                    out = nn.functional.relu(out)
                if self.upsample:
                    out = self.upsampler(out)
                    identity = self.upsampler(identity)
                if self.skip_connections:
                    if self.in_ch == self.out_ch:
                        out = out + identity                              
                    else:
                        out = out + identity[:,self.out_ch:,:,:]
                return out


        def set_dataloader(self):
            """
            Note the call to random_split() in the second statement for dividing the overall dataset of images into 
            two DISJOINT parts, one for training and the other for testing.  Since my evaluation of the Autoencoder at this
            time is purely on the basis of the visual quality of the output of the Decoder, I have set aside only
            200 randomly chosen images for testing.  Ordinarily, though, you would want to split the dataset in 
            the 70:30 or 80:20 ratio for training and testing.
            """
            dataset = torchvision.datasets.ImageFolder(root=self.dl_studio.dataroot,       
                           transform = tvt.Compose([                 
                                                tvt.Resize(self.dl_studio.image_size),             
                                                tvt.CenterCrop(self.dl_studio.image_size),         
                                                tvt.ToTensor(),                     
                                                tvt.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),         
                           ]))

            dataset_train, dataset_test  =  torch.utils.data.random_split( dataset, lengths = [len(dataset) - 200, 200])

            self.train_dataloader = torch.utils.data.DataLoader(dataset_train, batch_size=self.dl_studio.batch_size, shuffle=True, num_workers=4)
            self.test_dataloader = torch.utils.data.DataLoader(dataset_test, batch_size=self.dl_studio.batch_size, shuffle=True, num_workers=4)



    ###%%%
    #####################################################################################################################
    #######################################  Start Definition of Inner Class VAE  #######################################

    class VAE (Autoencoder):             
        """
        VAE stands for "Variational Auto Encoder".  These days, you are more likely to see it
        written as "variational autoencoder".  I consider VAE as one of the foundational neural
        architectures in Deep Learning.  VAE is based on the now celebrated 2014 paper 
        "Auto-Encoding Variational Bayes" by Kingma and Welling.  The idea is for the Encoder 
        part of an Encoder-Decoder pair to learn the probability distribution for the Latent 
        Space Representation of a training dataset.  Described loosely, the latent vector z for 
        an input image x would be the "essence" of what x is depicting.  Presumably, after the
        latent distribution has been learned, the Decoder should be able to transform any "noise" 
        vector sampled from the latent distribution and convert it into the sort of output you 
        would see during the training process.

        In case you are wondering about the dimensionality of the Latent Space, consider the case
        that the input images are eventually converted into 8x8 pixel arrays, with each pixel
        represented by a 128-dimensional embedding.  In a vectorized representation, this implies
        an 8192-dimensional output at the Encoder.  In the implementation here, the mean (mu) and 
        the log-variance (logvar) values learned by the Encoder are obtained by passing that 
        8192-dimensional vector through Linear layers whose output size equals the embedding size 
        (128 in this example).  The Decoder's job would be to sample this latent distribution, map 
        the sample back to the 8192-dimensional form, and attempt a reconstruction of what the user 
        wants to see at the output of the Decoder.

        As you can see, the VAE class is derived from the parent class Autoencoder.  The bulk of the
        computing in VAE is done through the functionality packed into the Autoencoder class.
        Therefore, in order to fully understand the VAE implementation here, your starting point
        should be the code for the Autoencoder class.  

        Class Path:   DLStudio  ->  VAE
        """  

        def __init__(self, dl_studio, encoder_in_im_size, encoder_out_im_size, decoder_out_im_size, encoder_out_ch, num_repeats, path_saved_encoder, path_saved_decoder ):
            super(DLStudio.VAE, self).__init__( dl_studio, encoder_in_im_size, encoder_out_im_size, decoder_out_im_size, encoder_out_ch, num_repeats, path_saved_model=None )
            self.parent_encoder =  DLStudio.Autoencoder.EncoderForAutoenc(dl_studio, encoder_in_im_size, encoder_out_im_size, encoder_out_ch, num_repeats, skip_connections=True )
            self.parent_decoder =  DLStudio.Autoencoder.DecoderForAutoenc(dl_studio, encoder_out_im_size, decoder_out_im_size)
            self.vae_encoder =  DLStudio.VAE.VaeEncoder(self.parent_encoder, encoder_out_im_size, encoder_out_ch)      
            self.vae_decoder =  DLStudio.VAE.VaeDecoder(self.parent_decoder, encoder_out_im_size, encoder_out_ch)
            self.encoder_out_im_size = self.encoder.encoder_out_im_size
            self.encoder_out_ch  =  self.encoder.encoder_out_ch
            self.path_saved_encoder = path_saved_encoder
            self.path_saved_decoder = path_saved_decoder


        class VaeEncoder(nn.Module):
            """
            The most important thing to note here is that this Encoder outputs ONLY the mean and the log-variance
            of the Gaussian distribution that models the latent vectors. VAEs are based on the assumption that 
            Latent Distributions are far simpler than the probability distributions that would model the image
            dataset used for training.

            Class Path:   DLStudio  ->  VAE  ->  VaeEncoder
            """
            def __init__(self, parent_encoder, encoder_out_im_size, encoder_out_ch):
                super(DLStudio.VAE.VaeEncoder, self).__init__()
                self.parent_encoder = parent_encoder
                self.num_nodes = encoder_out_im_size[0] * encoder_out_im_size[1]  * encoder_out_ch
                self.latent_dim = encoder_out_ch
                self.mu_layer =  nn.Linear(self.num_nodes, self.latent_dim)
                self.log_var_layer =  nn.Linear(self.num_nodes, self.latent_dim)

            def forward(self, x):
               encoded = self.parent_encoder(x)
               mu  = self.mu_layer(encoded.view(-1, self.num_nodes))        
               log_var = self.log_var_layer(encoded.view(-1, self.num_nodes))
               return mu, log_var


        class VaeDecoder(nn.Module):
            """
            The VAE Decoder's job is to take the mu and logvar values produced by the Encoder and
            generate an output image that contains the information that the user wants to see there.
            For obvious reasons, what exactly is seen at the output of the Decoder would 
            depend on the loss function used and the shape of the output tensor.  If all you wanted
            to see was a reduced dimensionality image at the output, you would need to change the
            final layers of the Decoder so that the final output corresponds to the shape that goes
            with that representation.

            Class Path:   DLStudio  ->  VAE  ->  VaeDecoder
            """
            def __init__(self, parent_decoder, encoder_out_im_size, encoder_out_ch):
                super(DLStudio.VAE.VaeDecoder, self).__init__()
                self.parent_decoder = parent_decoder
                self.encoder_out_im_size = encoder_out_im_size
                self.num_nodes = encoder_out_im_size[0] * encoder_out_im_size[1]  * encoder_out_ch
                self.latent_dim = encoder_out_ch
                self.reparametrized_to_decoder_input = nn.Linear(self.latent_dim, self.num_nodes)

            def reparameterize(self, mu, logvar):
               std  =  torch.exp(0.5 * logvar)

               ##  In the next statement, 'torch.randn' is sampling from an isotropic zero-mean 
               ##  unit-covariance Gaussian.  The call 'torch.randn_like' ensures that the returned 
               ##  tensor will have the same shape as the 'std' tensor.  
               ##
               ##  In order to understand the shape of 'std', consider the case when the size of the
               ##  pixel array at the Encoder output is 8x8, the embedding size 128, and the 
               ##  batch_size 48.  In this case, you have 64 pixels at the output of the Encoder, 
               ##  which are flattened into 8192 nodes (64 pixels times 128 channels at each pixel) 
               ##  before being mapped down to latent_dim = 128 by the Linear layers for the mu and 
               ##  logvar estimation.  So the shapes for all three of 'mu', 'logvar', and 'std' are 
               ##  identical and, for our example, that shape is [48, 128].
               eps =  torch.randn_like( std )                    ## sample from the standard normal N(0, I)
               return mu + eps * std
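               ##  The value returned above, z = mu + eps * std, is a sample from the Gaussian
               ##  N(mu, diag(exp(logvar))).  Writing the sampling this way keeps z a differentiable
               ##  function of mu and logvar, which is what allows the gradients to flow back into
               ##  the Encoder through the sampling step.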

            def forward(self, mu, logvar):
               z = self.reparameterize( mu, logvar )
               z = self.reparametrized_to_decoder_input(z)
               decoded = self.parent_decoder( z.view(-1, self.latent_dim, self.encoder_out_im_size[0], self.encoder_out_im_size[1]) )
               return decoded, mu, logvar


        def save_encoder_model(self, model):
            '''
            Save the trained model to a disk file
            '''
            torch.save(model.state_dict(), self.path_saved_encoder)


        def save_decoder_model(self, model):
            '''
            Save the trained model to a disk file
            '''
            torch.save(model.state_dict(), self.path_saved_decoder)



        def run_code_for_training_VAE( self, vae_net, loss_weighting, display_train_loss=False ):
            """
            The code for set_dataloader() for the VAE class shows how the overall dataset of images is divided into
            training and testing subsets.  

            The important thing to keep in mind about this function is the relative weighting of the reconstruction
            loss vis-a-vis the KL-divergence.  For an "optimized" VAE implementation, finding the best value to use
            for this relative weighting of the two loss components would be a part of hyperparameter tuning of the
            network.
            """            
            def loss_criterion(input_images, decoder_output_images, mu, log_var, weighting):
                recon_loss = nn.MSELoss(reduction='sum')( input_images, decoder_output_images )   ## reconstruction loss
                KLD = -0.5 * torch.sum( 1 + log_var - mu.pow(2)  -  log_var.exp() )               ## KL Divergence
                KLD =  KLD * weighting
                return  recon_loss + KLD,  recon_loss, KLD
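            ##  For the diagonal-Gaussian posterior N(mu, diag(exp(log_var))) and a standard-normal
            ##  prior, the KL divergence has the closed form used above:
            ##
            ##      KLD  =  -0.5 * sum( 1 + log_var - mu^2 - exp(log_var) )
            ##
            ##  which equals 0 exactly when mu = 0 and log_var = 0.  The 'weighting' factor trades
            ##  off reconstruction fidelity against how closely the latent distribution is pushed
            ##  toward the prior.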

            vae_encoder = vae_net.vae_encoder.to(self.dl_studio.device)
            vae_decoder = vae_net.vae_decoder.to(self.dl_studio.device)
            accum_times = []
            start_time = time.perf_counter()
            print("")
            batch_size = self.dl_studio.batch_size
            print("\n\n batch_size: ", batch_size)
            num_batches_in_data_source = len(self.train_dataloader)
            total_num_updates = self.dl_studio.epochs * num_batches_in_data_source
            print("\n\n number of batches in the dataset: ", num_batches_in_data_source)
            optimizer1 = optim.Adam(vae_encoder.parameters(), lr=self.dl_studio.learning_rate)     
            optimizer2 = optim.Adam(vae_decoder.parameters(), lr=self.dl_studio.learning_rate)     
            mu = logvar = 0.0

            total_training_loss_tally = []
            recons_loss_tally = []
            KL_divergence_tally = []

            for epoch in range(self.dl_studio.epochs):                                                              
                print("")
                ##  The following are needed for calculating the avg values between displays:
                running_loss = running_recon_loss = running_kld_loss = 0.0
                for i, data in enumerate(self.train_dataloader):                                    
                    input_images, _ = data                              
                    input_images = input_images.to(self.dl_studio.device)
                    optimizer1.zero_grad()
                    optimizer2.zero_grad()
                    mu, logvar =  vae_encoder( input_images )
                    ##  As required by VAE, the Decoder is only being supplied with the mean 'mu' and the log-variance 'logvar':
                    decoder_out, _, _ =  vae_decoder( mu, logvar )    
                    loss, recon_loss, kld_loss  =  loss_criterion( input_images,  decoder_out, mu, logvar, loss_weighting )
                    loss.backward()                                                                                        
                    optimizer1.step()
                    optimizer2.step()
                    running_loss += loss
                    running_recon_loss += recon_loss
                    running_kld_loss += kld_loss
                    if i % 200 == 199:    
                        avg_loss = running_loss / float(200)
                        avg_recon_loss = running_recon_loss / float(200)
                        avg_kld_loss = running_kld_loss / float(200)
                        total_training_loss_tally.append(avg_loss.item())
                        recons_loss_tally.append(avg_recon_loss.item())
                        KL_divergence_tally.append(avg_kld_loss.item())
                        running_loss = running_recon_loss = running_kld_loss = 0.0
                        current_time = time.perf_counter()
                        time_elapsed = current_time-start_time
                        print("[epoch:%2d/%2d  i:%4d  elapsed_time: %4d secs]     loss: %10.4f      recon_loss: %10.4f      kld_loss:  %10.4f " % 
                                               (epoch+1, self.dl_studio.epochs, i+1,time_elapsed,avg_loss,avg_recon_loss,avg_kld_loss)) 
                        accum_times.append(current_time-start_time)

            print("\nFinished Training\n")
            self.save_encoder_model( vae_encoder )
            self.save_decoder_model( vae_decoder )

            params_saved = { 'mean': mu, 'log_variance': logvar}
            pickle.dump(params_saved, open('params_saved.p', 'wb'))

            if display_train_loss:

                fig, (ax1,ax2,ax3) = plt.subplots(nrows=1, ncols=3, figsize=(20,5))

                ax1.plot(total_training_loss_tally)
                ax2.plot(recons_loss_tally)                                  
                ax3.plot(KL_divergence_tally)                                  

                ax1.set_xticks(np.arange(total_num_updates // 200))    ## since each val for plotting is generated every 200 iterations
                ax2.set_xticks(np.arange(total_num_updates // 200))
                ax3.set_xticks(np.arange(total_num_updates // 200))

                ax1.set_xlabel("iterations") 
                ax2.set_xlabel("iterations") 
                ax3.set_xlabel("iterations") 

                ax1.set_ylabel("total training loss") 
                ax2.set_ylabel("reconstruction loss") 
                ax3.set_ylabel("KL divergence") 

                plt.savefig("all_training_losses.png")
                plt.show()
   


        def run_code_for_evaluating_VAE(self, vae_net, visualization_dir = "vae_visualization_dir" ):
            """
            The main point here is to use the so-called "unseen images" for evaluating the performance
            of the VAE Encoder-Decoder network.  If you look at the set_dataloader() function for the
            VAE class, you will see me setting aside a certain number of the available images for testing.
            These randomly chosen images play NO role in training.
            """
    
            if os.path.exists(visualization_dir):  
                """
                Clear out the previous outputs in the visualization directory
                """
                files = glob.glob(visualization_dir + "/*")
                for file in files: 
                    if os.path.isfile(file): 
                        os.remove(file) 
                    else: 
                        files = glob.glob(file + "/*") 
                        list(map(lambda x: os.remove(x), files)) 
            else: 
                os.mkdir(visualization_dir)   
     
            vae_encoder = vae_net.vae_encoder.eval()
            vae_decoder = vae_net.vae_decoder.eval()
            vae_encoder.load_state_dict(torch.load(self.path_saved_encoder))
            vae_decoder.load_state_dict(torch.load(self.path_saved_decoder))
            vae_encoder.to(self.dl_studio.device) 
            vae_decoder.to(self.dl_studio.device)
            with torch.no_grad():
                for i, data in enumerate(self.test_dataloader):                                    
                    print("\n\n\n=========Showing results for test batch %d===============" % i)
                    test_images, _ = data     
                    test_images = test_images.to(self.dl_studio.device)
                    mu, logvar =  vae_encoder( test_images )
                    ##  In the next statement, using mu and logvar, the Decoder first uses the "reparameterization trick" 
                    ##  to sample the latent distribution and then feeds the sample into the rest of the Decoder for image generation:
                    decoder_out, _, _ =  vae_decoder( mu, logvar )   
                    decoder_out  =  ( decoder_out - decoder_out.min() ) / ( decoder_out.max() -  decoder_out.min() )
                    together = torch.zeros( test_images.shape[0], test_images.shape[1], test_images.shape[2], 2 * test_images.shape[3], dtype=torch.float )
                    together[:,:,:,0:test_images.shape[3]]  =  test_images
                    together[:,:,:,test_images.shape[3]:]  =   decoder_out 
                    plt.figure(figsize=(40,20))
                    plt.imshow(np.transpose(torchvision.utils.make_grid(together.cpu(), normalize=False, padding=3, pad_value=255).cpu(), (1,2,0)))
                    plt.title("VAE Output Images for iteration %d" % i)
                    plt.savefig(visualization_dir + "/vae_decoder_out_%s" % str(i) + ".png")
                    plt.show()



        def run_code_for_generating_images_from_noise_VAE(self, vae_net, visualization_dir = "vae_gen_visualization_dir" ):
            """
            This function is for testing the functioning of just the Generator (which is the Decoder) in
            the VAE network.  That is, after we have trained the VAE network, we disconnect the Encoder 
            and ask the Decoder to sample the latent distribution for generating the images.

            Remember, the latent distribution is represented entirely by the final values learned for 
            the mean (mu) and the log of the variance (logvar) that represent how close the training process
            was able to come to the ideal of zero-mean and unit-covariance isotropic distribution.
            Since the job of this function is to sample the latent distribution actually learned, we must
            also supply it with the (mu, logvar) values learned during training.
            """
            if os.path.exists(visualization_dir):  
                """
                Clear out the previous outputs in the visualization directory
                """
                files = glob.glob(visualization_dir + "/*")
                for file in files: 
                    if os.path.isfile(file): 
                        os.remove(file) 
                    else: 
                        files = glob.glob(file + "/*") 
                        list(map(lambda x: os.remove(x), files)) 
            else: 
                os.mkdir(visualization_dir)   
     
            vae_decoder = vae_net.vae_decoder.eval()
            vae_decoder.load_state_dict(torch.load(self.path_saved_decoder))
            params_saved = pickle.load( open('params_saved.p', 'rb') )
            mu, logvar = params_saved['mean'], params_saved['log_variance']

            ##  The size of the batch axis for the mu and logvar tensors will correspond to the number 
            ##  of images in the last batch used for training.  If you want the purely generative process
            ##  in this script (which uses the VAE Decoder in a standalone mode) to produce a batchful of 
            ##  images, you need to expand the previously learned mu and logvar tensors as shown below:
            if mu.shape[0] < self.dl_studio.batch_size:
                new_mu = torch.zeros( (self.dl_studio.batch_size, mu.shape[1]) ).float()  
                new_mu[:mu.shape[0]]  =  mu
                new_mu[mu.shape[0]:] = mu[:(self.dl_studio.batch_size - mu.shape[0])]

                new_logvar = torch.zeros( (self.dl_studio.batch_size, logvar.shape[1]) ).float()  
                new_logvar[:logvar.shape[0]]  = logvar
                new_logvar[logvar.shape[0]:] = logvar[:(self.dl_studio.batch_size - logvar.shape[0])]
            else:
                new_mu, new_logvar = mu, logvar
            mu = new_mu.to(self.dl_studio.device)
            logvar = new_logvar.to(self.dl_studio.device)
            vae_decoder.to(self.dl_studio.device)
            sample_standard_normal_distribution = True
            sample_learned_normal_distribution =  False
            with torch.no_grad():
                for i in range(5):
                    print("\n\n\n=========Showing results for test batch %d===============" % i)
                    if sample_standard_normal_distribution:
                        mu  =  torch.zeros_like(mu).float().to(self.dl_studio.device)
                        logvar = torch.zeros_like(logvar).float().to(self.dl_studio.device)     ## log-variance of 0 corresponds to unit variance
                    elif sample_learned_normal_distribution:
                        std = torch.exp(0.5 * logvar)
                    ##  In the next statement, using mu and logvar, the Decoder first uses the "reparameterization trick" 
                    ##  to sample the latent distribution and then feeds the sample into the rest of the Decoder for image generation:
                    decoder_out, _, _ =  vae_decoder( mu, logvar )   
                    decoder_out  =  ( decoder_out - decoder_out.min() ) / ( decoder_out.max() -  decoder_out.min() )
                    fake_input = torch.zeros_like(decoder_out).float().to(self.dl_studio.device)
                    together = torch.zeros( fake_input.shape[0], fake_input.shape[1], fake_input.shape[2], 2 * fake_input.shape[3], dtype=torch.float )
                    together[:,:,:,0:fake_input.shape[3]]  =  fake_input
                    together[:,:,:,fake_input.shape[3]:]  =   decoder_out 
                    plt.figure(figsize=(40,20))
                    plt.imshow(np.transpose(torchvision.utils.make_grid(together.cpu(), normalize=False, padding=3, pad_value=255).cpu(), (1,2,0)))
                    plt.title("VAE Output Images for iteration %d" % i)
                    plt.savefig(visualization_dir + "/vae_decoder_out_%s" % str(i) + ".png")
                    plt.show()


        def set_dataloader(self):
            """
            Note the call to random_split() in the second statement for dividing the overall dataset of images into 
            two DISJOINT parts, one for training and the other for testing.  Since my evaluation of the VAE at this
            time is purely on the basis of the visual quality of the output of the Decoder, I have set aside only
            200 randomly chosen images for testing.  Ordinarily, though, you would want to split the dataset in 
            the 70:30 or 80:20 ratio for training and testing.
            """
            dataset = torchvision.datasets.ImageFolder(root=self.dl_studio.dataroot,       
                           transform = tvt.Compose([                 
                                                tvt.Resize(self.dl_studio.image_size),             
                                                tvt.CenterCrop(self.dl_studio.image_size),         
                                                tvt.ToTensor(),                     
                                                tvt.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),         
                           ]))

            dataset_train, dataset_test  =  torch.utils.data.random_split( dataset, lengths = [len(dataset) - 200, 200])
            self.train_dataloader = torch.utils.data.DataLoader(dataset_train, batch_size=self.dl_studio.batch_size, shuffle=True, num_workers=4)
            self.test_dataloader = torch.utils.data.DataLoader(dataset_test, batch_size=self.dl_studio.batch_size, shuffle=True, num_workers=4)



    ###%%%
    #####################################################################################################################
    ######################################  Start Definition of Inner Class VQVAE  ######################################

    class VQVAE (Autoencoder):             
        """
        VQVAE is an important architecture in deep learning because it teaches us about what
        has come to be known as "Codebook Learning" for more efficient discrete representation
        of images with a finite vocabulary of embedding vectors.

        VQVAE stands for "Vector Quantized Variational Auto Encoder", which is also frequently
        represented by the acronym VQ-VAE.  The concept of VQ-VAE was formulated in the 2017
        publication "Neural Discrete Representation Learning" by van den Oord, Vinyals, and
        Kavukcuoglu.

        For the case of images, VQ-VAE means that we want to represent an input image using a
        user-specified number of embedding vectors.  You could think of the set of embedding
        vectors as constituting a fixed-size vocabulary for representing the input data.

        To make the definition of Codebook Learning more specific, say we are using an
        Encoder-Decoder to create such a fixed-vocabulary based representation for the images.
        Let's assume that the Encoder converts each input batch of images into a (B,C,H,W)
        shaped tensor where the height H and the width W are likely to be small numbers, say 8
        each, and C is likely to be, say, 128.  Let's also say that the batch size is 256.

        The total number of pixels in all the batch instances at the output of the Encoder will
        be B*H*W.  I'll represent this number of pixels with the notation BHW.  For the example
        numbers used above, BHW will be equal to 256*8*8 = 16384.

        Taking cognizance of the channel axis, we can say that each of the 16,384 pixels at the
        output of the Encoder is represented by a 128 element vector along the channel axis.

        As things stand, each C-dimensional pixel based vector at the output of the Encoder will
        be a continuous valued vector.

        The goal of VQ-VAE is to define a Codebook of K vectors, each of dimension D, with the
        idea that each of the C-dimensional BHW vectors at the output of the Encoder will be
        replaced by the closest of the K D-dimensional vectors in the Codebook.  For practical
        reasons, we require D=C.

        The Decoder's job then is to try its best to recreate the input using the Codebook
        approximations at the output of the Encoder.

        The goal of VQ-VAE is to demonstrate that it is possible to learn a Codebook with K
        elements that can subsequently be used to represent any input.

        You can think of the learned Codebook vectors as the quantized versions of what the
        Encoder presents at its output.
        
        As you can see, the VQVAE class is derived from the parent class Autoencoder.  The bulk of the
        computing in VQVAE is done through the functionality packed into the Autoencoder class.
        Therefore, in order to fully understand the VQVAE implementation here, your starting point
        should be the code for the Autoencoder class.  

        Note that the VQVAE code presented here is still tentative.  Most of the heavy lifting
        at the moment is done by the two Vector Representation classes I have borrowed from
        "zalandoresearch" at GitHub:

                  https://github.com/zalandoresearch/pytorch-vq-vae

        Class Path:   DLStudio  ->  VQVAE
        """  
        def __init__(self, dl_studio, encoder_in_im_size,  encoder_out_im_size, decoder_out_im_size, encoder_out_ch, num_repeats, num_codebook_vecs,
            codebook_vec_dim, commitment_cost, decay, path_saved_encoder, path_saved_decoder, path_saved_vector_quantizer, path_saved_prevq_and_postvq_convos ):
            super(DLStudio.VQVAE, self).__init__( dl_studio, encoder_in_im_size, encoder_out_im_size, decoder_out_im_size, encoder_out_ch, num_repeats, path_saved_model=None )
            self.parent_encoder =  DLStudio.Autoencoder.EncoderForAutoenc(dl_studio, encoder_in_im_size, encoder_out_im_size, encoder_out_ch, num_repeats, skip_connections=True) 
            self.parent_decoder =  DLStudio.Autoencoder.DecoderForAutoenc(dl_studio, encoder_out_im_size, decoder_out_im_size)
            self.num_codebook_vecs  = num_codebook_vecs
            self.codebook_vec_dim = codebook_vec_dim
            self.commitment_cost = commitment_cost
            self.decay = decay
            self.vqvae_encoder =  DLStudio.VQVAE.VQVaeEncoder(self.parent_encoder)      
            self.vqvae_decoder =  DLStudio.VQVAE.VQVaeDecoder(self.parent_decoder, encoder_out_im_size, encoder_out_ch)
            self.vector_quantizer = DLStudio.VQVAE.VectorQuantizerEMA(num_codebook_vecs, codebook_vec_dim, commitment_cost, decay)
            self.pre_vq_convo  =  nn.Conv2d(in_channels=encoder_out_ch, out_channels=codebook_vec_dim, kernel_size=1, stride=1)
            self.post_vq_convo =  nn.Conv2d(in_channels=encoder_out_ch, out_channels=codebook_vec_dim, kernel_size=1, stride=1)
            self.encoder_out_im_size = self.encoder.encoder_out_im_size
            self.encoder_out_ch  =  self.encoder.encoder_out_ch
            self.path_saved_encoder = path_saved_encoder
            self.path_saved_decoder = path_saved_decoder
            self.path_saved_vector_quantizer = path_saved_vector_quantizer
            self.path_saved_prevq_and_postvq_convos = path_saved_prevq_and_postvq_convos

        ##  These getter methods are for the subclass VQGAN that is derived from VQVAE:
        def get_vqvae_encoder(self):
            return self.vqvae_encoder 
        def get_vqvae_decoder(self):
            return self.vqvae_decoder 
        def get_vector_quantizer(self):      
            return self.vector_quantizer
        def get_pre_vq_convo(self):
            return self.pre_vq_convo
        def get_post_vq_convo(self):
            return self.post_vq_convo


        class VQVaeEncoder(nn.Module):
            """
            I'll use the same Encoder that is in VQVAE's parent class Autoencoder. 

            Class Path:   DLStudio  ->  VQVAE  ->  VQVaeEncoder
            """
            def __init__(self, parent_encoder):
                super(DLStudio.VQVAE.VQVaeEncoder, self).__init__()
                self.parent_encoder = parent_encoder

            def forward(self, x):
               encoded = self.parent_encoder(x)
               return encoded


        class VQVaeDecoder(nn.Module):
            """
            I'll use the same Decoder that is in VQVAE's parent class Autoencoder.

            Class Path:   DLStudio  ->  VQVAE  ->  VQVaeDecoder
            """
            def __init__(self, parent_decoder, encoder_out_im_size, encoder_out_ch):
                super(DLStudio.VQVAE.VQVaeDecoder, self).__init__()
                self.parent_decoder = parent_decoder
                self.encoder_out_im_size = encoder_out_im_size
                self.encoder_out_ch = encoder_out_ch

            def forward(self, quantized):
                decoded = self.parent_decoder( quantized.view(-1, self.encoder_out_ch, self.encoder_out_im_size[0], self.encoder_out_im_size[1]) )
                return decoded


        class VectorQuantizer(nn.Module):
            """
            This class is from:

                      https://github.com/zalandoresearch/pytorch-vq-vae

            This is an implementation of VQ-VAE by Aäron van den Oord et al. 

            Class Path:   DLStudio  ->  VQVAE  ->  VectorQuantizer
            """

            @static_var("_codebook", None) 
            def __init__(self, num_codebook_vecs, codebook_vec_dim, commitment_cost):
                super(DLStudio.VQVAE.VectorQuantizer, self).__init__()                
                self._codebook_vec_dim = codebook_vec_dim
                self._num_codebook_vecs = num_codebook_vecs
                
                self._codebook = nn.Embedding(self._num_codebook_vecs, self._codebook_vec_dim)
                self._codebook.weight.data.uniform_(-1/self._num_codebook_vecs, 1/self._num_codebook_vecs)
                self._commitment_cost = commitment_cost
        
            def forward(self, inputs):
                # convert inputs from BCHW -> BHWC
                inputs = inputs.permute(0, 2, 3, 1).contiguous()               ## Reshaping the output of the Encoder since the
                                                                               ##   channel axis is going to be treated as the
                                                                               ##   embedding axis.
                input_shape = inputs.shape                                     ## Needed later for shape restoration with unflattening 
                # Flatten input
                flat_input = inputs.view(-1, self._codebook_vec_dim)
                # Calculate distances between the input embedding vector and each of the codebook vectors
                distances = (torch.sum(flat_input**2, dim=1, keepdim=True) + torch.sum(self._codebook.weight**2, dim=1)
                                                                              - 2 * torch.matmul(flat_input, self._codebook.weight.t()))
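                ##  The expression above computes the squared Euclidean distance between every 
                ##  flattened encoder-output vector x and every codebook vector e by using the 
                ##  expansion  ||x - e||^2  =  ||x||^2 + ||e||^2 - 2 x.e,  which avoids forming
                ##  an explicit (BHW, K, D) difference tensor.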
                # Encoding
                encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
                encodings = torch.zeros(encoding_indices.shape[0], self._num_codebook_vecs, device=inputs.device)
                encodings.scatter_(1, encoding_indices, 1)
                # Quantize and unflatten
                quantized = torch.matmul(encodings, self._codebook.weight).view(input_shape)
                # Loss
                e_latent_loss = F.mse_loss(quantized.detach(), inputs)     
                q_latent_loss = F.mse_loss(quantized, inputs.detach())     
                loss = q_latent_loss + self._commitment_cost * e_latent_loss
                
                quantized = inputs + (quantized - inputs).detach()
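                ##  The statement above is the "straight-through estimator": in the forward pass the
                ##  result equals 'quantized', but in the backward pass the gradient flows to 'inputs'
                ##  as if the quantization step were the identity function.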
                avg_probs = torch.mean(encodings, dim=0)
                perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
                
                # convert quantized from BHWC -> BCHW
                return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings
        
        
        class VectorQuantizerEMA(nn.Module):
            """
            This class is from:

                      https://github.com/zalandoresearch/pytorch-vq-vae

            This is an implementation by Dominic Rampas of the VQ-VAE by Aäron van den Oord et al. 

            Class Path:   DLStudio  ->  VQVAE  ->  VectorQuantizerEMA
            """  
            static_codebook = {}
            def __init__(self, num_codebook_vecs, codebook_vec_dim, commitment_cost, decay, epsilon=1e-5):
                super(DLStudio.VQVAE.VectorQuantizerEMA, self).__init__()                
                self._codebook_vec_dim = codebook_vec_dim
                self._num_codebook_vecs = num_codebook_vecs
                DLStudio.VQVAE.VectorQuantizerEMA.static_codebook = nn.Embedding(self._num_codebook_vecs, self._codebook_vec_dim)
                self._codebook = DLStudio.VQVAE.VectorQuantizerEMA.static_codebook
                self._codebook.weight.data.normal_()
                self._commitment_cost = commitment_cost
                self.register_buffer('_ema_cluster_size', torch.zeros(num_codebook_vecs))
                self._ema_w = nn.Parameter(torch.Tensor(num_codebook_vecs, self._codebook_vec_dim))
                self._ema_w.data.normal_()
                self._decay = decay
                self._epsilon = epsilon
        
            def forward(self, inputs):
                # convert inputs from BCHW -> BHWC
                inputs = inputs.permute(0, 2, 3, 1).contiguous()                 ## Reshaping the output of the Encoder since the
                                                                                 ##   channel axis is going to be treated as the
                                                                                 ##   embedding axis.
                input_shape = inputs.shape                                       ## Needed later for shape restoration with unflattening 
                # Flatten input
                flat_input = inputs.view(-1, self._codebook_vec_dim)
                # Calculate distances
                distances = (torch.sum(flat_input**2, dim=1, keepdim=True) 
                            + torch.sum(self._codebook.weight**2, dim=1)
                            - 2 * torch.matmul(flat_input, self._codebook.weight.t()))
                # Encoding
                encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
                ##  One-hot encodings over the codebook vocabulary, of dimensionality num_codebook_vecs:
                encodings = torch.zeros(encoding_indices.shape[0], self._num_codebook_vecs, device=inputs.device)

                encodings.scatter_(1, encoding_indices, 1)    
                # Quantize and unflatten
                quantized = torch.matmul(encodings, self._codebook.weight).view(input_shape)
                if self.training:
                    self._ema_cluster_size = self._ema_cluster_size * self._decay +  (1 - self._decay) * torch.sum(encodings, 0)
                    # Laplace smoothing of the cluster size
                    n = torch.sum(self._ema_cluster_size.data)
                    self._ema_cluster_size = ( (self._ema_cluster_size + self._epsilon) / (n + self._num_codebook_vecs * self._epsilon) * n)
                    dw = torch.matmul(encodings.t(), flat_input)
                    self._ema_w = nn.Parameter(self._ema_w * self._decay + (1 - self._decay) * dw)
                    self._codebook.weight = nn.Parameter(self._ema_w / self._ema_cluster_size.unsqueeze(1))
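                    ##  The updates above implement exponential-moving-average (EMA) codebook learning:
                    ##  the per-codebook-vector cluster sizes and the per-cluster sums of the encoder
                    ##  outputs (dw) are tracked with decay 'self._decay', and each codebook vector is
                    ##  then reset to the (Laplace-smoothed) mean of the encoder outputs currently
                    ##  assigned to it.  No gradient-based update of the codebook is needed.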
                # Loss
                e_latent_loss = F.mse_loss(quantized.detach(), inputs)    
                loss = self._commitment_cost * e_latent_loss
                # Straight Through Estimator
                quantized = inputs + (quantized - inputs).detach()
                ##  avg_probs is the usage histogram over the codebook vectors; if all codebook vectors were used equally often, this histogram would be flat.
                avg_probs = torch.mean(encodings, dim=0)
                perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))  
                # convert quantized from BHWC -> BCHW
                return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings

        
        def save_encoder_model(self, model):
            '''
            Save the trained Encoder model to a disk file
            '''
            torch.save(model.state_dict(), self.path_saved_encoder)

        def save_decoder_model(self, model):
            '''
            Save the trained Decoder model to a disk file
            '''
            torch.save(model.state_dict(), self.path_saved_decoder)

        def save_vector_quantizer_model(self, model):
            '''
            Save the trained Vector Quantizer model to a disk file
            '''
            torch.save(model.state_dict(), self.path_saved_vector_quantizer)

        def save_prevq_and_postvq_convos(self, prevq_convo, postvq_convo):
            convo_dict =  {"prevq" : prevq_convo,  "postvq" : postvq_convo}
            torch.save( convo_dict, self.path_saved_prevq_and_postvq_convos )


        def run_code_for_training_VQVAE( self, vqvae, display_train_loss=False ):
            """
            The code for set_dataloader() for the VQVAE class shows how the overall dataset of images is divided into
            training and testing subsets.  
            """            
            vqvae_encoder = vqvae.vqvae_encoder.to(self.dl_studio.device)
            vqvae_vector_quantizer =  vqvae.vector_quantizer.to(self.dl_studio.device)
            vqvae_decoder = vqvae.vqvae_decoder.to(self.dl_studio.device)
            pre_vq_convo = vqvae.pre_vq_convo.to(self.dl_studio.device)
            post_vq_convo = vqvae.post_vq_convo.to(self.dl_studio.device)

            accum_times = []
            start_time = time.perf_counter()
            print("")
            batch_size = self.dl_studio.batch_size
            print("\n\n batch_size: ", batch_size)
            num_batches_in_data_source = len(self.train_dataloader)
            total_num_updates = self.dl_studio.epochs * num_batches_in_data_source
            print("\n\n number of batches in the dataset: ", num_batches_in_data_source)
            optimizer1 = optim.Adam(vqvae_encoder.parameters(), lr=self.dl_studio.learning_rate)     
            optimizer2 = optim.Adam(vqvae_decoder.parameters(), lr=self.dl_studio.learning_rate)     
            optimizer3 = optim.Adam(vqvae_vector_quantizer.parameters(), lr=self.dl_studio.learning_rate)     
            optimizer4 = optim.Adam(pre_vq_convo.parameters(), lr=self.dl_studio.learning_rate)     
            optimizer5 = optim.Adam(post_vq_convo.parameters(), lr=self.dl_studio.learning_rate)     

            training_loss_tally = []
            perplexity_tally = []        
            data_variance = 0.0
            for epoch in range(self.dl_studio.epochs):                                                              
                print("")
                running_loss = 0.0
                running_perplexity = 0.0
                for i, data in enumerate(self.train_dataloader):                                    
                    input_images, _ = data                              
                    input_images = input_images.to(self.dl_studio.device)
                    optimizer1.zero_grad()
                    optimizer2.zero_grad()
                    optimizer3.zero_grad()
                    optimizer4.zero_grad()
                    optimizer5.zero_grad()

                    z = vqvae_encoder(input_images)
                    z = pre_vq_convo(z)
                    vq_loss, quantized, perplexity, _ =  vqvae_vector_quantizer(z)
                    z = post_vq_convo(quantized)
                    decoder_out = vqvae_decoder(z)

                    recon_loss = nn.MSELoss(reduction='sum')( input_images, decoder_out ) 
                    loss = recon_loss + vq_loss
                    loss.backward()                                                                                        
                    optimizer1.step()
                    optimizer2.step()
                    optimizer3.step()
                    optimizer4.step()
                    optimizer5.step()

                    running_loss += loss
                    running_perplexity += perplexity
                    if i % 200 == 199:    
                        avg_loss = running_loss / float(200)
                        avg_perplexity = running_perplexity / float(200)                        
                        training_loss_tally.append(avg_loss.item())
                        perplexity_tally.append(avg_perplexity.item())
                        running_loss = 0.0
                        running_perplexity = 0.0
                        current_time = time.perf_counter()
                        time_elapsed = current_time-start_time
                        print("[epoch:%2d/%2d  i:%4d  elapsed_time: %4d secs]   loss: %10.6f         perplexity: %10.6f " % 
                                                             (epoch+1, self.dl_studio.epochs, i+1, time_elapsed, avg_loss, avg_perplexity)) 
                        accum_times.append(current_time-start_time)
            print("\nFinished Training VQVAE\n")
            self.save_encoder_model( vqvae_encoder )
            self.save_decoder_model( vqvae_decoder )
            self.save_vector_quantizer_model( vqvae_vector_quantizer )
            self.save_prevq_and_postvq_convos(pre_vq_convo, post_vq_convo)
            fig, (ax1,ax2) = plt.subplots(nrows=1, ncols=2, figsize=(12,5))
            ax1.plot(training_loss_tally)
            ax2.plot(perplexity_tally)                                  
            ax1.set_xticks(np.arange(total_num_updates // 200))    ## since each val for plotting is generated every 200 iterations
            ax2.set_xticks(np.arange(total_num_updates // 200))
            ax1.set_xlabel("iterations") 
            ax2.set_xlabel("iterations") 
            ax1.set_ylabel("vqvae training loss") 
            ax2.set_ylabel("vqvae Perplexity") 
            plt.savefig("vqvae_training_losses_and_perplexity.png")
            plt.show()


        def run_code_for_evaluating_VQVAE(self, vqvae, visualization_dir = "vqvae_visualization_dir" ):
            """
            The main point here is to use the so-called "unseen images" for evaluating the
            performance of VQVAE.  If you look at the set_dataloader() function for the VQVAE
            class, you will see me setting aside a certain number of the available images for
            testing.  These randomly chosen images play NO role in training.
            """
            if os.path.exists(visualization_dir):  
                """
                Clear out the previous outputs in the visualization directory
                """
                files = glob.glob(visualization_dir + "/*")
                for file in files: 
                    if os.path.isfile(file): 
                        os.remove(file) 
                    else: 
                        files = glob.glob(file + "/*") 
                        list(map(lambda x: os.remove(x), files)) 
            else: 
                os.mkdir(visualization_dir)   
     
            vqvae_encoder = vqvae.vqvae_encoder.eval()
            vqvae_decoder = vqvae.vqvae_decoder.eval()
            vqvae_vector_quantizer = vqvae.vector_quantizer.eval()

            ##  The pre-VQ and post-VQ convo layers were saved as complete modules, so loading them
            ##  brings back their trained weights:
            convo_dict = torch.load(self.path_saved_prevq_and_postvq_convos)
            pre_vq_convo   =  convo_dict["prevq"].eval()
            post_vq_convo  =  convo_dict["postvq"].eval()

            vqvae_encoder.load_state_dict(torch.load(self.path_saved_encoder))
            vqvae_decoder.load_state_dict(torch.load(self.path_saved_decoder))
            vqvae_vector_quantizer.load_state_dict(torch.load(self.path_saved_vector_quantizer))
            vqvae_encoder.to(self.dl_studio.device) 
            vqvae_decoder.to(self.dl_studio.device)
            vqvae_vector_quantizer.to(self.dl_studio.device)
            pre_vq_convo.to(self.dl_studio.device)
            post_vq_convo.to(self.dl_studio.device)

            with torch.no_grad():
                for i, data in enumerate(self.test_dataloader):                                    
                    print("\n\n\n=========Showing VQVAE results for test batch %d===============" % i)
                    test_images, _ = data     
                    test_images = test_images.to(self.dl_studio.device)

                    z = vqvae_encoder(test_images)
                    z = pre_vq_convo(z)
                    _, quantized, perplexity, _ =  vqvae_vector_quantizer(z)
                    z = post_vq_convo(quantized)
                    decoder_out = vqvae_decoder(z)

                    decoder_out  =  ( decoder_out - decoder_out.min() ) / ( decoder_out.max() -  decoder_out.min() )
                    together = torch.zeros( test_images.shape[0], test_images.shape[1], test_images.shape[2], 2 * test_images.shape[3], dtype=torch.float )
                    together[:,:,:,0:test_images.shape[3]]  =  test_images
                    together[:,:,:,test_images.shape[3]:]  =   decoder_out 
                    plt.figure(figsize=(40,20))
                    plt.imshow(np.transpose(torchvision.utils.make_grid(together.cpu(), normalize=False, padding=3, pad_value=255).cpu(), (1,2,0)))
                    plt.title("VQVAE Output Images for iteration %d" % i)
                    plt.savefig(visualization_dir + "/vqvae_decoder_out_%s" % str(i) + ".png")
                    plt.show()


        def set_dataloader(self):
            """
            Note the call to random_split() in the second statement for dividing the overall dataset of images into 
            two DISJOINT parts, one for training and the other for testing.  Since my evaluation of the VQVAE at this
            time is purely on the basis of the visual quality of the output of the Decoder, I have set aside only
            200 randomly chosen images for testing.  Ordinarily, though, you would want to split the dataset in 
            the 70:30 or 80:20 ratio for training and testing.
            """
            dataset = torchvision.datasets.ImageFolder(root=self.dl_studio.dataroot,       
                           transform = tvt.Compose([                 
                                                tvt.Resize(self.dl_studio.image_size),             
                                                tvt.CenterCrop(self.dl_studio.image_size),         
                                                tvt.ToTensor(),                     
                                                tvt.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),         
                           ]))
            dataset_train, dataset_test  =  torch.utils.data.random_split( dataset, lengths = [len(dataset) - 200, 200])
            self.train_dataloader = torch.utils.data.DataLoader(dataset_train, batch_size=self.dl_studio.batch_size, shuffle=True, num_workers=4)
            self.test_dataloader = torch.utils.data.DataLoader(dataset_test, batch_size=self.dl_studio.batch_size, shuffle=True, num_workers=4)
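
        ##  A minimal sketch (not used by the class itself):  if you want the train/test split created above to
        ##  be reproducible across runs, random_split() also accepts a seeded generator:
        ##
        ##      gen = torch.Generator().manual_seed(0)
        ##      dataset_train, dataset_test = torch.utils.data.random_split( dataset,
        ##                                           lengths = [len(dataset) - 200, 200], generator = gen )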



    ###%%%
    #####################################################################################################################
    ######################################  Start Definition of Inner Class VQGAN  ######################################

    class VQGAN (VQVAE):             

        def __init__(self, dl_studio, encoder_in_im_size,  encoder_out_im_size, decoder_out_im_size, encoder_out_ch, num_repeats, num_codebook_vecs, 
              codebook_vec_dim, commitment_cost, decay, perceptual_loss_factor, use_patch_gan_logic, path_saved_generator):
            super(DLStudio.VQGAN, self).__init__( dl_studio, encoder_in_im_size,  encoder_out_im_size, decoder_out_im_size, encoder_out_ch,
                                                                  num_repeats, num_codebook_vecs, codebook_vec_dim, commitment_cost, decay,
                              path_saved_encoder=None, path_saved_decoder=None, path_saved_vector_quantizer=None, path_saved_prevq_and_postvq_convos=None)
            self.num_codebook_vecs  = num_codebook_vecs
            self.codebook_vec_dim = codebook_vec_dim
            self.commitment_cost = commitment_cost
            self.decay = decay
            self.perceptual_loss_factor = perceptual_loss_factor
            self.vqgan_encoder =  super(DLStudio.VQGAN, self).get_vqvae_encoder()
            self.vqgan_decoder =  super(DLStudio.VQGAN, self).get_vqvae_decoder()
            self.vqgan_vector_quantizer =  super(DLStudio.VQGAN, self).get_vector_quantizer()
            self.vqgan_pre_vq_convo = super(DLStudio.VQGAN, self).get_pre_vq_convo()
            self.vqgan_post_vq_convo = super(DLStudio.VQGAN, self).get_post_vq_convo()
            self.discriminator = DLStudio.VQGAN.Discriminator_PatchGAN() if use_patch_gan_logic else DLStudio.VQGAN.Discriminator()
            self.path_saved_generator = path_saved_generator
            self.codebook = super(DLStudio.VQGAN, self).VectorQuantizerEMA.static_codebook

        class Discriminator(nn.Module):
            """
            This is for the non-patchGAN case.

            This is an implementation of the DCGAN Discriminator. I refer to the DCGAN network topology as
            the 4-2-1 network.  Each layer of the Discriminator network carries out a strided
            convolution with a 4x4 kernel, a 2x2 stride and a 1x1 padding for all but the final
            layer. The output of the final convolutional layer is pushed through a sigmoid to yield
            a scalar value as the final output for each image in a batch.
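
            A quick check on the spatial arithmetic (assuming the 64x64 inputs used elsewhere in this VQGAN
            code):  each 4-2-1 layer maps a spatial size H to floor((H + 2 - 4)/2) + 1 = H/2, so the feature
            maps shrink as 64 -> 32 -> 16 -> 8 -> 4, and the final 4x4 convolution with stride 1 and no
            padding reduces the 4x4 map to the single scalar mentioned above.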
    
            Class Path:  DLStudio  ->   VQGAN  ->  Discriminator
            """
            def __init__(self):
                super(DLStudio.VQGAN.Discriminator, self).__init__()
                self.conv_in = nn.Conv2d(  3,    64,      kernel_size=4,      stride=2,    padding=1, bias=False)
                self.conv_in2 = nn.Conv2d( 64,   128,     kernel_size=4,      stride=2,    padding=1, bias=False)
                self.conv_in3 = nn.Conv2d( 128,  256,     kernel_size=4,      stride=2,    padding=1, bias=False)
                self.conv_in4 = nn.Conv2d( 256,  512,     kernel_size=4,      stride=2,    padding=1, bias=False)
                self.conv_in5 = nn.Conv2d( 512,  1,       kernel_size=4,      stride=1,    padding=0, bias=False)
                self.bn1  = nn.BatchNorm2d(128)
                self.bn2  = nn.BatchNorm2d(256)
                self.bn3  = nn.BatchNorm2d(512)
                self.sig = nn.Sigmoid()
    
            def forward(self, x):                 
                x = torch.nn.functional.leaky_relu(self.conv_in(x), negative_slope=0.2, inplace=True)
                x = self.bn1(self.conv_in2(x))
                x = torch.nn.functional.leaky_relu(x, negative_slope=0.2, inplace=True)
                x = self.bn2(self.conv_in3(x))
                x = torch.nn.functional.leaky_relu(x, negative_slope=0.2, inplace=True)
                x = self.bn3(self.conv_in4(x))
                x = torch.nn.functional.leaky_relu(x, negative_slope=0.2, inplace=True)
                x = self.conv_in5(x)
                x = self.sig(x)
                return x


        class Discriminator_PatchGAN(nn.Module):
            """
            This is a slight variation of the Discriminator by Dominic Rampas:

                      https://github.com/zalandoresearch/pytorch-vq-vae
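
            For 64x64 RGB inputs (the image size assumed elsewhere in this VQGAN code), the stride and padding
            choices in the construction loop below yield a 1-channel 4x4 output map --- the spatial size shrinks
            as 64 -> 32 -> 18 -> 11 -> 5 -> 4.  That is why the training code that uses this Discriminator builds
            its target tensors with shape (batch_size, 1, 4, 4).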

            Class Path:  DLStudio  ->   VQGAN  ->  Discriminator_PatchGAN
            """
            def __init__(self):
                super(DLStudio.VQGAN.Discriminator_PatchGAN, self).__init__()
                num_filters_last = 128
                n_layers = 3
                layers = [nn.Conv2d(3, num_filters_last, 4, 2, 1), nn.LeakyReLU(0.2)]
                num_filters_mult = 1
        
                for i in range(1, n_layers + 1):
                    num_filters_mult_last = num_filters_mult
                    num_filters_mult = min(2 ** i, 8)
                    layers += [
                        nn.Conv2d(num_filters_last * num_filters_mult_last, num_filters_last * num_filters_mult, 4,
                                  2 if i < n_layers else 3, 3, bias=False),
                        nn.BatchNorm2d(num_filters_last * num_filters_mult),
                        nn.LeakyReLU(0.2, True)
                    ]
        
                layers.append(nn.Conv2d(num_filters_last * num_filters_mult, 1, 4, 1, 1))
                self.model = nn.Sequential(*layers)
            def forward(self, x):
                return nn.Sigmoid()( self.model(x) )
        


        class Discriminator_PatchGAN_2(nn.Module):
            """
            This Discriminator is from DLStudio's AdversarialLearning module.  
 
            Class Path:  DLStudio  ->   VQGAN  ->  Discriminator_PatchGAN_2
            """
            def __init__(self):
                super(DLStudio.VQGAN.Discriminator_PatchGAN_2, self).__init__()
                self.conv_in = nn.Conv2d(  3,    64,      kernel_size=4,      stride=2,    padding=1, bias=False)
                self.conv_in2 = nn.Conv2d( 64,   128,     kernel_size=4,      stride=2,    padding=1, bias=False)
                self.conv_in3 = nn.Conv2d( 128,  256,     kernel_size=4,      stride=2,    padding=1, bias=False)
                self.conv_in5 = nn.Conv2d( 256,  1,       kernel_size=5,      stride=1,    padding=0, bias=False)
                self.bn1  = nn.BatchNorm2d(128)
                self.bn2  = nn.BatchNorm2d(256)
                self.sig = nn.Sigmoid()
    
            def forward(self, x):                 
                x = torch.nn.functional.leaky_relu(self.conv_in(x), negative_slope=0.2, inplace=True)
                x = self.bn1(self.conv_in2(x))
                x = torch.nn.functional.leaky_relu(x, negative_slope=0.2, inplace=True)
                x = self.bn2(self.conv_in3(x))
                x = torch.nn.functional.leaky_relu(x, negative_slope=0.2, inplace=True)
                x = self.conv_in5(x)
                x = self.sig(x)
                return x


        def save_generator_model(self, model):
            '''
            Save the trained Generator (meaning, the VQGAN) model to a disk file
            '''
            torch.save(model.state_dict(), self.path_saved_generator)

        def weights_init(self, m):        
            """
            Uses the DCGAN initializations for the weights
            """
            classname = m.__class__.__name__     
            if classname.find('Conv') != -1:         
                nn.init.normal_(m.weight.data, 0.0, 0.02)      
            elif classname.find('BatchNorm') != -1:         
                nn.init.normal_(m.weight.data, 1.0, 0.02)       
                nn.init.constant_(m.bias.data, 0)      

        def run_code_for_training_VQGAN( self, vqgan ):
            """
            This uses the regular Discriminator.  That is, the Discriminator puts out A SINGLE SCALAR VALUE that
            expresses the probability that the image at its input came from the probability distribution that
            describes the training dataset.

            IMPORTANT:  You will get significantly better results if you train with the next function named

                                       run_code_for_PATCH_BASED_training_VQGAN

            Also note that the code for set_dataloader() in the parent VQVAE class shows how the overall dataset of 
            images is divided into training and testing subsets.  
            """            
            def normed_tensor(x):
                norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
                return x / (norm_factor + 1e-10)

            class Generator(nn.Module):
                """
                In keeping with the ethos of Adversarial Learning, it's good to bundle the code so that there
                is a readily identifiable separation between the Generator part and the Discriminator part.  In 
                our case, the Generator is the VQGAN network itself.
                """
                def __init__(self):
                    super(Generator, self).__init__()
                    self.vqgan_encoder = vqgan.vqgan_encoder
                    self.vqgan_vector_quantizer =  vqgan.vector_quantizer
                    self.vqgan_decoder = vqgan.vqgan_decoder
                    self.pre_vq_convo = vqgan.vqgan_pre_vq_convo
                    self.post_vq_convo = vqgan.vqgan_post_vq_convo

                def forward(self, input_images):      
                    z = self.vqgan_encoder(input_images)
                    z = self.pre_vq_convo(z)
                    vq_loss, quantized, perplexity, _ =  self.vqgan_vector_quantizer(z)
                    z = self.post_vq_convo(quantized)
                    decoder_out = self.vqgan_decoder(z)
                    decoder_out = normed_tensor(decoder_out)
                    return decoder_out, perplexity, vq_loss

            generator = Generator().to(self.dl_studio.device)
            print("\n\nType of generator constructed: ", type(generator))
            print("number of learnable params in generator: ", sum(p.numel() for p in generator.parameters() if p.requires_grad))
            self.generator = generator
            discriminator = vqgan.discriminator.to(self.dl_studio.device)
            discriminator.apply(self.weights_init)
            generator.apply(self.weights_init)
            print("number of learnable params in discriminator: ", sum(p.numel() for p in discriminator.parameters() if p.requires_grad))
            lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg').to(self.dl_studio.device)
            accum_times = []
            start_time = time.perf_counter()
            print("")
            batch_size = self.dl_studio.batch_size
            print("\n\n batch_size: ", batch_size)
            num_batches_in_data_source = len(self.train_dataloader)
            total_num_updates = self.dl_studio.epochs * num_batches_in_data_source
            print("\n\n number of batches in the dataset: ", num_batches_in_data_source)
            optimizer1 = optim.Adam(generator.parameters(), lr=self.dl_studio.learning_rate)
            optimizer2 = optim.Adam(discriminator.parameters(), lr=self.dl_studio.learning_rate)     
            disc_loss_reals_tally = []
            disc_loss_fakes_tally = []
            generator_loss_tally = []
            perplexity_tally = []        
            real_label = 1      ##  Will be used as target when the Discriminator is trained on the dataset images
            fake_label = 0      ##  Will be used as target when the Discriminator is fed the output of the "Generator" --- meaning the VQGAN
            for epoch in range(self.dl_studio.epochs):                                                              
                print("")
                running_disc_loss_reals =  running_disc_loss_fakes = running_generator_loss = running_perplexity =  0.0

                for i, data in enumerate(self.train_dataloader):                                    
                    input_images, _ = data                              
                    input_images = input_images.to(self.dl_studio.device)
                    input_images = normed_tensor(input_images)
                    optimizer1.zero_grad()             ## generator
                    optimizer2.zero_grad()             ## discriminator
                    ##  Maximization for Discriminator training --- Part 1:
                    ##
                    ##  Maximization-Part 1 means that we want the output of the Discriminator to be as large as possible, 
                    ##  meaning to be as close to 1.0 as possible when it sees the training images.  The Discriminator outputs 
                    ##  the prob that an image came from the same distribution as the training dataset. The larger this prob,
                    ##  the smaller the BCELoss:
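                    ##
                    ##  As a reminder of the arithmetic:  with target 1, BCELoss(p, 1) = -log(p), which is driven
                    ##  down as p --> 1;  with target 0 (used below for the fakes), BCELoss(p, 0) = -log(1 - p),
                    ##  which is driven down as p --> 0.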
                    targets = torch.full((input_images.shape[0],), real_label, dtype=torch.float, device=self.dl_studio.device)  
                    output_disc_reals = discriminator(input_images).view(-1)     ## Discriminator should produce a scalar for each image in the batch
                    lossD_for_reals = nn.BCELoss()(output_disc_reals, targets)                                                   

                    ##  Maximization for Discriminator training --- Part 2:
                    ##
                    targets = torch.full((input_images.shape[0],), fake_label, dtype=torch.float, device=self.dl_studio.device)  
                    targets = targets.float().to(self.dl_studio.device)
                    decoder_out, perplexity, vq_loss = generator(input_images)
                    output_disc_fakes = discriminator(decoder_out.detach()).view(-1)   ## detach() so that no gradients flow back into the Generator
                    lossD_for_fakes = nn.BCELoss()(output_disc_fakes, targets)
                    discriminator_loss  =   lossD_for_reals + lossD_for_fakes
                    discriminator_loss.backward()
                    ## Only the Discriminator params will be updated:
                    optimizer2.step()

                    ##  Minimization for Generator training
                    ##
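                    ##  The Generator objective assembled below combines four terms:  (1) pixel-level reconstruction
                    ##  (MSE), (2) the adversarial term that rewards fooling the Discriminator, (3) the vector-quantization
                    ##  loss returned by the codebook, and (4) the LPIPS perceptual loss scaled by perceptual_loss_factor.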
                    targets = torch.full((input_images.shape[0],), real_label, dtype=torch.float, device=self.dl_studio.device)  
                    output_disc_fakes_for_G = discriminator(decoder_out).view(-1)      ## fresh forward pass so that gradients reach the Generator
                    lossG_for_fakes = nn.BCELoss()(output_disc_fakes_for_G, targets)                                                   
                    recon_loss = nn.MSELoss()( input_images, decoder_out ) 
                    perceptual_loss = lpips( normed_tensor(input_images), normed_tensor(decoder_out) )
                    generator_loss  =  recon_loss +  lossG_for_fakes + vq_loss  + self.perceptual_loss_factor * perceptual_loss     
                    generator_loss.backward()
                    ## Only the VQGAN params (the Generator) will be updated:

                    optimizer1.step()
                    running_disc_loss_reals += lossD_for_reals
                    running_disc_loss_fakes += lossD_for_fakes
                    running_perplexity += perplexity
                    running_generator_loss += generator_loss
                    if i % 200 == 199:    
                        avg_gen_loss = running_generator_loss / float(200)
                        avg_disc_loss_reals = running_disc_loss_reals / float(200)
                        avg_disc_loss_fakes = running_disc_loss_fakes / float(200)
                        avg_perplexity = running_perplexity / float(200)                        
                        generator_loss_tally.append(avg_gen_loss.item())
                        disc_loss_reals_tally.append(avg_disc_loss_reals.item())
                        disc_loss_fakes_tally.append(avg_disc_loss_fakes.item())
                        perplexity_tally.append(avg_perplexity.item())
                        running_disc_loss_reals = running_disc_loss_fakes = running_generator_loss = running_perplexity = 0.0
                        current_time = time.perf_counter()
                        time_elapsed = current_time-start_time
                        print("[epoch:%2d/%2d  i:%4d  elapsed_time: %4d secs]   disc_loss_reals: %8.6f   disc_loss_fakes: %8.6f   gen_loss:  %8.6f   perplexity: %8.6f " %  (epoch+1, self.dl_studio.epochs, i+1, time_elapsed,                    avg_disc_loss_reals, avg_disc_loss_fakes, avg_gen_loss, avg_perplexity)) 
                        accum_times.append(current_time-start_time)
                torch.save(generator.state_dict(), "checkpoint_dir/checkpoint_" +  str(epoch))
                torch.save(self.codebook.state_dict(), "codebooks_saved/codebook_" +  str(epoch))
            print("\nFinished Training VQGAN\n")
            self.save_generator_model( generator )
            fig, (ax1,ax2,ax3,ax4) = plt.subplots(nrows=1, ncols=4, figsize=(20,5))
            ax1.plot(disc_loss_reals_tally)
            ax2.plot(disc_loss_fakes_tally)
            ax3.plot(generator_loss_tally)
            ax4.plot(perplexity_tally)                                  
            ax1.set_xticks(np.arange(total_num_updates // 200))    ## since each val for plotting is generated every 200 iterations
            ax2.set_xticks(np.arange(total_num_updates // 200))
            ax3.set_xticks(np.arange(total_num_updates // 200))
            ax4.set_xticks(np.arange(total_num_updates // 200))
            ax1.set_xlabel("iterations") 
            ax2.set_xlabel("iterations") 
            ax3.set_xlabel("iterations") 
            ax4.set_xlabel("iterations") 
            ax1.set_ylabel("discriminator loss - reals") 
            ax2.set_ylabel("discriminator loss - fakes") 
            ax3.set_ylabel("generator loss") 
            ax4.set_ylabel("vqgan Perplexity") 
            plt.savefig("vqgan_training_losses_and_perplexity.png")
            plt.show()


        def run_code_for_PATCH_BASED_training_VQGAN( self, vqgan ):
            """
            This is based on the patchGAN Discriminator.  That is, the Discriminator assumes that the input
            image can be thought of as being composed of an NxN array of patches.  Subsequently, it puts out an
            NxN array of probability numbers, with each number expressing the belief that the corresponding patch
            came from the same probability distribution that defines the training dataset of images.

            The code for set_dataloader() in the parent VQVAE class shows how the overall dataset of images is divided
            into training and testing subsets.  
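
            For the Discriminator_PatchGAN defined earlier in this class and the 64x64 images assumed throughout,
            N works out to 4 --- which is why the target tensors in the training loop below are created as
            torch.ones(batch_size, 1, 4, 4) for the reals and torch.zeros(batch_size, 1, 4, 4) for the fakes.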
            """            
            def normed_tensor(x):
                norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
                return x / (norm_factor + 1e-10)

            class Generator(nn.Module):
                def __init__(self):
                    super(Generator, self).__init__()
                    self.vqgan_encoder = vqgan.vqgan_encoder
                    self.vqgan_vector_quantizer =  vqgan.vector_quantizer
                    self.vqgan_decoder = vqgan.vqgan_decoder
                    self.pre_vq_convo = vqgan.vqgan_pre_vq_convo
                    self.post_vq_convo = vqgan.vqgan_post_vq_convo
                def forward(self, input_images):      
                    z = self.vqgan_encoder(input_images)
                    z = self.pre_vq_convo(z)
                    vq_loss, quantized, perplexity, encoding_indices =  self.vqgan_vector_quantizer(z)
                    z = self.post_vq_convo(quantized)
                    decoder_out = self.vqgan_decoder(z)
                    decoder_out = normed_tensor(decoder_out)
                    return decoder_out, perplexity, vq_loss, encoding_indices

            self.codebook = super(DLStudio.VQGAN, self).VectorQuantizerEMA.static_codebook
            generator = Generator().to(self.dl_studio.device)
            print("\n\nType of generator constructed: ", type(generator))
            print("number of learnable params in generator: ", sum(p.numel() for p in generator.parameters() if p.requires_grad))
            self.generator = generator
            discriminator = vqgan.discriminator.to(self.dl_studio.device)
            discriminator.apply(self.weights_init)
            generator.apply(self.weights_init)
            print("number of learnable params in discriminator: ", sum(p.numel() for p in discriminator.parameters() if p.requires_grad))
            lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg').to(self.dl_studio.device)
            accum_times = []
            start_time = time.perf_counter()
            print("")
            batch_size = self.dl_studio.batch_size
            print("\n\n batch_size: ", batch_size)
            num_batches_in_data_source = len(self.train_dataloader)
            total_num_updates = self.dl_studio.epochs * num_batches_in_data_source
            print("\n\n number of batches in the dataset: ", num_batches_in_data_source)
            optimizer1 = optim.Adam(generator.parameters(), lr=self.dl_studio.learning_rate)
            optimizer2 = optim.Adam(discriminator.parameters(), lr=self.dl_studio.learning_rate)     
            disc_loss_reals_tally = []
            disc_loss_fakes_tally = []
            generator_loss_tally = []
            perplexity_tally = []        
            data_variance = 0.0
            for epoch in range(self.dl_studio.epochs):                                                              
                print("")
                running_disc_loss_reals =  running_disc_loss_fakes = running_generator_loss = running_perplexity =  0.0
                for i, data in enumerate(self.train_dataloader):                                    
                    input_images, _ = data                              
                    input_images = input_images.to(self.dl_studio.device)
                    input_images = normed_tensor(input_images)
                    optimizer1.zero_grad()             ## generator
                    optimizer2.zero_grad()             ## discriminator

                    ##  Maximization for Discriminator training --- Part 1:
                    ##
                    ##  Maximization-Part 1 means that we want the output of the Discriminator to be as large as possible, 
                    ##  meaning to be as close to 1.0 as possible when it sees the training images.  The Discriminator outputs 
                    ##  the prob that an image came from the same distribution as the training dataset. The larger this prob,
                    ##  the smaller the BCELoss:
                    targets = torch.ones( input_images.shape[0], 1, 4, 4 ).float().to(self.dl_studio.device)
                    output_disc_reals = discriminator(input_images)     ## the patchGAN Discriminator produces a 4x4 array of probabilities for each image in the batch
                    lossD_for_reals = nn.BCELoss()(output_disc_reals, targets)                                                   

                    ##  Maximization for Discriminator training --- Part 2:
                    ##
                    targets = torch.zeros( input_images.shape[0], 1, 4, 4 ).float().to(self.dl_studio.device)
                    if 'singlefile' in str(type(generator)):
                        decoder_out, vq_loss, encoding_indices = generator(input_images)   
                        perplexity = len(encoding_indices)
                    else:
                        decoder_out, perplexity, vq_loss, encoding_indices = generator(input_images)   
                    output_disc_fakes = discriminator(decoder_out.detach())        ## detach() so that no gradients flow back into the Generator
                    lossD_for_fakes = nn.BCELoss()(output_disc_fakes, targets)
                    discriminator_loss  =   (lossD_for_reals + lossD_for_fakes).mean()
                    discriminator_loss.backward()
                    ## Only the Discriminator params will be updated:
                    optimizer2.step()

                    ##  Minimization for Generator training
                    ##
                    targets = torch.ones( input_images.shape[0], 1, 4, 4 ).float().to(self.dl_studio.device)
                    output_disc_fakes_for_G = discriminator(decoder_out)            ## fresh forward pass so that gradients reach the Generator
                    lossG_for_fakes = nn.BCELoss()(output_disc_fakes_for_G, targets)                                                   
                    recon_loss = nn.MSELoss()( input_images, decoder_out ) 
                    perceptual_loss = lpips( normed_tensor(input_images), normed_tensor(decoder_out) )
                    generator_loss  =  recon_loss +  lossG_for_fakes + vq_loss  + self.perceptual_loss_factor * perceptual_loss     
                    generator_loss.backward()
                    ## Only the VQGAN params (the Generator) will be updated:
                    optimizer1.step()

                    running_disc_loss_reals += lossD_for_reals
                    running_disc_loss_fakes += lossD_for_fakes
                    running_perplexity += perplexity
                    running_generator_loss += generator_loss
                    if i % 200 == 199:    
                        avg_gen_loss = running_generator_loss / float(200)
                        avg_disc_loss_reals = running_disc_loss_reals / float(200)
                        avg_disc_loss_fakes = running_disc_loss_fakes / float(200)
                        avg_perplexity = running_perplexity / float(200)                        

                        generator_loss_tally.append(avg_gen_loss.item())
                        disc_loss_reals_tally.append(avg_disc_loss_reals.item())
                        disc_loss_fakes_tally.append(avg_disc_loss_fakes.item())
                        if 'singlefile' in str(type(generator)):
                            perplexity_tally.append(avg_perplexity)
                        else:
                            perplexity_tally.append(avg_perplexity.item())
                        running_disc_loss_reals = running_disc_loss_fakes = running_generator_loss = running_perplexity = 0.0
                        current_time = time.perf_counter()
                        time_elapsed = current_time-start_time
                        print("[epoch:%2d/%2d  i:%4d  elapsed_time: %4d secs]   disc_loss_reals: %8.6f   disc_loss_fakes: %8.6f   gen_loss:  %8.6f   perplexity: %8.6f " %  (epoch+1, self.dl_studio.epochs, i+1, time_elapsed, avg_disc_loss_reals, avg_disc_loss_fakes, avg_gen_loss, avg_perplexity))
                        accum_times.append(current_time-start_time)
                if epoch % 10 == 9:
                    torch.save(generator.state_dict(), "checkpoint_dir/generator_" +  str(epoch))
                    torch.save(self.codebook.state_dict(), "checkpoint_dir/codebook_" +  str(epoch))
                    torch.save(self.vqgan_encoder.state_dict(), "checkpoint_dir/vqgan_encoder_" +  str(epoch))
                    torch.save(self.vqgan_decoder.state_dict(), "checkpoint_dir/vqgan_decoder_" +  str(epoch))
                    torch.save(self.vqgan_vector_quantizer.state_dict(), "checkpoint_dir/vqgan_vector_quantizer_" +  str(epoch))
                    torch.save(self.vqgan_pre_vq_convo.state_dict(), "checkpoint_dir/vqgan_pre_vq_convo_" +  str(epoch))
                    torch.save(self.vqgan_post_vq_convo.state_dict(), "checkpoint_dir/vqgan_post_vq_convo_" +  str(epoch))

            print("\nFinished Training VQGAN\n")
            self.save_generator_model( generator )
            fig, (ax1,ax2,ax3,ax4) = plt.subplots(nrows=1, ncols=4, figsize=(20,5))
            ax1.plot(disc_loss_reals_tally)
            ax2.plot(disc_loss_fakes_tally)
            ax3.plot(generator_loss_tally)
            ax4.plot(perplexity_tally)                                  
            ax1.set_xticks(np.arange(total_num_updates // 200))    ## since each val for plotting is generated every 200 iterations
            ax2.set_xticks(np.arange(total_num_updates // 200))
            ax3.set_xticks(np.arange(total_num_updates // 200))
            ax4.set_xticks(np.arange(total_num_updates // 200))
            ax1.set_xlabel("iterations") 
            ax2.set_xlabel("iterations") 
            ax3.set_xlabel("iterations") 
            ax4.set_xlabel("iterations") 
            ax1.set_ylabel("discriminator loss - reals") 
            ax2.set_ylabel("discriminator loss - fakes") 
            ax3.set_ylabel("generator loss") 
            ax4.set_ylabel("vqgan Perplexity") 
            plt.savefig("vqgan_training_losses_and_perplexity.png")
            plt.show()


        def run_code_for_evaluating_VQGAN(self, vqgan, visualization_dir = "vqgan_visualization_dir" ):
            """
            The main point here is to use the so-called "unseen images" for evaluating the
            performance of VQGAN.  If you look at the set_dataloader() function in the parent
            VQVAE class, you will see me setting aside a certain number of the available images for
            testing.  These randomly chosen images play NO role in training.
            """
            if os.path.exists(visualization_dir):  
                """
                Clear out the previous outputs in the visualization directory
                """
                files = glob.glob(visualization_dir + "/*")
                for file in files: 
                    if os.path.isfile(file): 
                        os.remove(file) 
                    else: 
                        files = glob.glob(file + "/*") 
                        list(map(lambda x: os.remove(x), files)) 
            else: 
                os.mkdir(visualization_dir)   

            def normed_tensor(x):
                norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
                return x / (norm_factor + 1e-10)

            class Generator(nn.Module):
                def __init__(self):
                    super(Generator, self).__init__()
                    self.vqgan_encoder = vqgan.vqgan_encoder
                    self.vqgan_vector_quantizer =  vqgan.vector_quantizer
                    self.vqgan_decoder = vqgan.vqgan_decoder
                    self.pre_vq_convo = vqgan.vqgan_pre_vq_convo
                    self.post_vq_convo = vqgan.vqgan_post_vq_convo
                def forward(self, input_images):      
                    z = self.vqgan_encoder(input_images)
                    z = self.pre_vq_convo(z)
                    vq_loss, quantized, perplexity, encoding_indices =  self.vqgan_vector_quantizer(z)
                    z = self.post_vq_convo(quantized)
                    decoder_out = self.vqgan_decoder(z)
                    decoder_out = normed_tensor(decoder_out)
                    return decoder_out, perplexity, vq_loss, encoding_indices

            generator = Generator().to(self.dl_studio.device)
            generator.load_state_dict(torch.load(self.path_saved_generator))
            with torch.no_grad():
                for i, data in enumerate(self.test_dataloader):                                    
                    print("\n\n\n=========Showing VQGAN results for test batch %d===============" % i)
                    test_images, _ = data     
                    test_images = test_images.to(self.dl_studio.device)
                    if 'singlefile' in str(type(generator)):
                        decoder_out, vq_loss, _ = generator(test_images)   
                    else:
                        decoder_out, _, vq_loss, _ = generator(test_images)   
                    decoder_out  =  ( decoder_out - decoder_out.min() ) / ( decoder_out.max() -  decoder_out.min() )
                    together = torch.zeros( test_images.shape[0], test_images.shape[1], test_images.shape[2], 2 * test_images.shape[3], dtype=torch.float )
                    together[:,:,:,0:test_images.shape[3]]  =  test_images
                    together[:,:,:,test_images.shape[3]:]  =   decoder_out 
                    plt.figure(figsize=(40,20))
                    plt.imshow(np.transpose(torchvision.utils.make_grid(together.cpu(), normalize=False, padding=3, pad_value=255).cpu(), (1,2,0)))
                    plt.title("VQGAN Output Images for iteration %d" % i)
                    plt.savefig(visualization_dir + "/vqgan_decoder_out_%s" % str(i) + ".png")
                    plt.show()
  

        def display_2_images( self, in_image, out_image ):
            """
            Will also work for a batch of images for the two arguments
            """
            out  =  ( out_image - out_image.min() ) / ( out_image.max() -  out_image.min() )
            together = torch.zeros( in_image.shape[0], in_image.shape[1], in_image.shape[2], 2 * in_image.shape[3], dtype=torch.float )
            together[:,:,:,0:in_image.shape[3]]  =  in_image
            together[:,:,:,in_image.shape[3]:]  =   out 
            plt.figure(figsize=(40,20))
            plt.imshow(np.transpose(torchvision.utils.make_grid(together.detach(), normalize=False, padding=3, pad_value=255).cpu(), (1,2,0)))
            plt.title("Intermediate Input and Output Images")
            plt.savefig("intermediate_input_and_output_images.png")
            plt.show()


        @torch.no_grad()
        def encode_image_into_sequence_of_indices_to_codebook_vectors(self, im_name=None ):
            """
            im_name is the name of a file with a suffix like ".jpg", ".png", etc.

            When im_name is not supplied, the function produces the output for a batch of images.

            WHY THIS FUNCTION IS USEFUL:  Codebook learning erases the distinction between processing
                                          language (that is, text) and processing images.  Codebook 
                                          learning is a natural fit for language processing because languages
                                          are serial structures and the most fundamental unit in such a 
            structure is a word (or a token, in the sense of a subword). Once you have set the vocabulary for the 
            fundamental units, it automatically follows that any sentence would be expressible as a sequence of
            the tokens and, consequently, as a sequence of the embedding vectors for the tokens.
            
            Codebook learning as made possible by VQGAN allows an automaton to understand images in exactly 
            the same manner as described above.  What a token vocabulary is for the case of languages
            is the codebook learned by VQGAN for the case of images. The size of the codebook for VQGAN is set 
            by the user, as is the size of the token vocabulary for the case of languages. Subsequently, 
            each embedding vector at the output of the VQGAN Encoder is replaced by the closest codebook vector.

            To be more specific, let's say that the user-specified size for the VQGAN codebook is 512 and the
            output of the VQGAN Encoder is of shape NxNxC, where both the height and the width are equal to N 
            and C is the number of channels. Such an encoder can be construed as representing an input image with 
            N^2 embedding vectors, each of size C.  Subsequently, the Vector Quantizer will replace each of 
            these N^2 embedding vectors with the closest codebook vector and, thus, you will have a codebook 
            based representation of the input image.
             
            As you play with these notions with the help of this function, you may become curious as to 
            what exactly in the images is represented by the codebook vectors.  Could the different codebook 
            vectors represent, say, the different types of textures in an image?  As you will discover by 
            playing with this function, at this moment in time there are no good answers to this question. 
            To illustrate, suppose the codebook is learned through just a small number of epochs and 
            the final value for the perplexity is, say, just around 2.0.  That means your codebook will 
            contain only a couple of significant vectors (despite the fact that the codebook size you 
            trained with was, say, 512). In such a case, when you map the N^2 embedding vectors at the 
            output of the VQGAN Encoder to the integer indices associated with the closest codebook vectors, 
            you are likely to see just a couple of different integer indices in the N^2-element long list.  
            What's interesting is that even with just two different integer indices, the outputs produced by 
            the VQGAN Decoder would look very different depending on the positions occupied by the two 
            different codebook vectors.  For example, for the case when the VQGAN Encoder produces an 8x8 
            array at its output (which means that an input image would be represented by 64 embedding vectors),
            the following sequence of integer indices

                219,219,219,219,15,15,15,15,219,15,15,219,15,15,15,15,15,15,219,219,15,15,15,15,15,15,15, 
                219,219,15,15,15,15,15,15,219,219,15,15,15,15,15,15,15,219,15,15,15,15,15,15,15,15,15,15,  
                15,15,15,15,15,15,15,15,15
 
            may lead the VQGAN Decoder to output a sunflower image and the following sequence, on the other
            hand,

                15,15,15,15,15,15,15,15,15,15,15,15,15,219,15,15,15,15,15,15,219,219,15,15,15,15,15,219,15, 
                219,15,15,15,15,219,219,15,219,15,15,15,15,15,219,15,219,15,15,15,15,15,219,15,15,15,15,15,  
                15,15,15,15,15,15,15

            may lead to the image of a rose. I am mentioning the names of the flowers in my explanation 
            because my observations are based on my experiments with the flower dataset from the Univ of
            Oxford.
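
            To make the bookkeeping concrete, here is a minimal sketch (with made-up sizes) of how a sequence
            of integer indices can be mapped back to the codebook vectors that the VQGAN Decoder ultimately
            consumes, assuming the codebook is stored as an nn.Embedding with S rows, each of dimension C:

                import torch, torch.nn as nn
                S, C, N = 512, 64, 8                               ##  codebook size, codebook vec dim, Encoder grid size
                codebook = nn.Embedding(S, C)                      ##  stand-in for the learned codebook
                indices  = torch.randint(0, S, (N*N,))             ##  a 64-element index sequence for one image
                quantized = codebook(indices).view(1, N, N, C)     ##  look up the codebook vectors
                quantized = quantized.permute(0, 3, 1, 2)          ##  shape (1, C, N, N), as expected by post_vq_convo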
            """
            def normed_tensor(x):
                norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
                return x / (norm_factor + 1e-10)

            self.vqgan_encoder.load_state_dict(torch.load("checkpoint_dir/vqgan_encoder_99"))
            self.vqgan_decoder.load_state_dict(torch.load("checkpoint_dir/vqgan_decoder_99"))
            self.vqgan_vector_quantizer.load_state_dict(torch.load("checkpoint_dir/vqgan_vector_quantizer_99"))
            self.vqgan_pre_vq_convo.load_state_dict(torch.load("checkpoint_dir/vqgan_pre_vq_convo_99"))
            self.vqgan_post_vq_convo.load_state_dict(torch.load("checkpoint_dir/vqgan_post_vq_convo_99"))
            self.codebook = super(DLStudio.VQGAN, self).VectorQuantizerEMA.static_codebook
            self.codebook.load_state_dict(torch.load("checkpoint_dir/codebook_99"))
            if im_name:
                im_as_array =  Image.open( im_name )
                transform = tvt.Compose( [tvt.Resize((64,64)), 
                                          tvt.CenterCrop((64,64)),         
                                          tvt.ToTensor(), 
                                          tvt.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] )
                im_as_tensor = transform( im_as_array )
                ##  Add batch axis:
                im_as_tensor  =  torch.unsqueeze( im_as_tensor, 0 )
            else:
                ## A multithreaded dataloader does not allow for a batch to be drawn randomly. So here's
                ## an admittedly primitive ploy to get around that --- advance a single iterator by a
                ## random number of batches and then use whatever batch comes next:
                dataloader_iter = iter(self.train_dataloader)
                for _ in range(random.randint(1,20)):
                    next(dataloader_iter)
                im_as_tensor = next(dataloader_iter)
                im_as_tensor = im_as_tensor[0]
            torch.set_printoptions(edgeitems=10_000, linewidth=120)
            z = self.vqgan_encoder(normed_tensor(im_as_tensor))
            z = self.pre_vq_convo(z)
            vq_loss, quantized, perplexity, encoding_indices_as_onehot_vecs =  self.vqgan_vector_quantizer(z)
            encoding_indices_as_ints = encoding_indices_as_onehot_vecs.argmax(1) 
            if im_name:
                print("\n\nFor image %s the encoding indices are:" % im_name)
                print(encoding_indices_as_ints)
            else:
                print("\n\nFor the batch, the encoding indices are are:")           
                batch_size = im_as_tensor.shape[0]
                how_many_per_image = len( encoding_indices_as_ints ) // batch_size
                for i in range(batch_size):
                   print("\nimage %d: %s" % (i, encoding_indices_as_ints[i*how_many_per_image : (i+1)*how_many_per_image]))
            torch.set_printoptions(profile="default")
            z = self.post_vq_convo(quantized)
            decoder_out = self.vqgan_decoder(z)
            decoder_out = normed_tensor(decoder_out)
            self.display_2_images(im_as_tensor, decoder_out)


        def run_code_for_transformer_based_modeling_VQGAN(self, vqgan, epochs_for_xformer_training, max_seq_length, embedding_size, codebook_size,
                     num_warmup_steps, optimizer_params, num_basic_decoders, num_atten_heads, masking, checkpoint_dir, visualization_dir ):
            """ 
            After codebook learning, for what I am going to focus on now, note that VQGAN Generator returns a sequence 
            of integers that are the indices of the codebook vectors for each of the embedding vectors at the output
            of the Encoder. To elaborate, assume that the Encoder outputs an NxNxC array where C is the number of 
            channels.  We can think of each of the N^2 array elements at the Encoder output as the embedding vectors, 
            with the embedding dimension being equal to C.

            Let's say S is the size of the codebook. The Vector Quantizer's job in VQGAN is to return the closest 
            codebook vector for each of the N^2 embedding vectors mentioned above.  If you look at what is returned
            by the class
                                       VectorQuantizerEMA            

            in VQGAN's parent class VQVAE, you will see the following as the last statement for the above:

                    return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings

            where "quantized" is a sequence of codebook vectors that are the closest approximations to the N^2 
            embedding vectors at the output of the Encoder and where "encodings" is a sequence of integer index
            values associated with the codebook vectors in "quantized".

            The focus now is on the integer sequence in "encodings".

            The integer sequence that you see in "encodings" is no different from the integer sequence you would
            use for the tokens in a sentence in natural language processing (NLP).  For NLP, the tokenizer that
            you train for a given corpus gives you a token vocabulary that has a prescribed size of, say, 30,000.
            Based on the relative frequencies of the words from which the tokens are derived, the tokenizer 
            gives you a set of tokens, the total number of which will be bounded by the prescribed size, and, for
            each token its integer mapping.  Subsequently, for all neural-network based downstream processing, 
            you will represent text through a token sequence that, in effect, will be a sequence of integer index
            values.

            You can therefore say that what a VQGAN gives you through "encodings" obliterates the difference between
            image processing and language processing and you can think of the Vector Quantizer (VQ) as the tokenizer
            in NLP. After the VQ has learned the codebook, those are your tokens --- in their embedding vector 
            representations.  

            The above implies that just as you can do autoregressive modeling of text, you should be able to carry
            out autoregressive modeling of images through the tokens, meaning through the codebook vectors.
            The goal of the implementation shown below is to illustrate exactly that.
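
            As a small illustration of the bookkeeping involved (mirroring what the training loop below actually
            does), turning the one-hot rows in "encodings" into one integer token sequence per image is just an
            argmax followed by a reshape:

                encoding_indices_as_ints = encodings.argmax(1)                     ##  shape: (batch_size * N^2,)
                token_sequences = encoding_indices_as_ints.view(batch_size, -1)    ##  shape: (batch_size, N^2)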

            """
            if os.path.exists(checkpoint_dir):                    ### this checkpoint_dir is just for transformer training
                files = glob.glob(checkpoint_dir + "/*")
                for file in files: 
                    if os.path.isfile(file): 
                        os.remove(file) 
                    else: 
                        files = glob.glob(file + "/*") 
                        list(map(lambda x: os.remove(x), files)) 
            else: 
                os.mkdir(checkpoint_dir)   

            if os.path.exists(visualization_dir):                 ### this visualization_dir is just for transformer training
                """
                Clear out the previous outputs in the visualization directory
                """
                files = glob.glob(visualization_dir + "/*")
                for file in files: 
                    if os.path.isfile(file): 
                        os.remove(file) 
                    else: 
                        files = glob.glob(file + "/*") 
                        list(map(lambda x: os.remove(x), files)) 
            else: 
                os.mkdir(visualization_dir)   

            def normed_tensor(x):
                norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
                return x / (norm_factor + 1e-10)

            self.vqgan_encoder.load_state_dict(torch.load("checkpoint_dir/vqgan_encoder_99"))
            self.vqgan_decoder.load_state_dict(torch.load("checkpoint_dir/vqgan_decoder_99"))
            self.vqgan_vector_quantizer.load_state_dict(torch.load("checkpoint_dir/vqgan_vector_quantizer_99"))
            self.vqgan_pre_vq_convo.load_state_dict(torch.load("checkpoint_dir/vqgan_pre_vq_convo_99"))
            self.vqgan_post_vq_convo.load_state_dict(torch.load("checkpoint_dir/vqgan_post_vq_convo_99"))
            self.codebook = super(DLStudio.VQGAN, self).VectorQuantizerEMA.static_codebook
            self.codebook.load_state_dict(torch.load("checkpoint_dir/codebook_99"))

            self.vqgan_encoder.to(self.dl_studio.device)
            self.vqgan_decoder.to(self.dl_studio.device)
            self.vqgan_vector_quantizer.to(self.dl_studio.device)
            self.vqgan_pre_vq_convo.to(self.dl_studio.device)
            self.vqgan_post_vq_convo.to(self.dl_studio.device)
            self.codebook.to(self.dl_studio.device)

            vocab_size = codebook_size
            xformer = DLStudio.TransformerFG( max_seq_length, embedding_size, vocab_size, num_warmup_steps, optimizer_params).to(self.dl_studio.device)
            master_decoder = DLStudio.MasterDecoderWithMasking(xformer, num_basic_decoders, num_atten_heads, masking).to(self.dl_studio.device)
            print("\nNumber of learnable params in Master Decoder: ", sum(p.numel() for p in master_decoder.parameters() if p.requires_grad))
            beta1,beta2,epsilon = optimizer_params['beta1'], optimizer_params['beta2'], optimizer_params['epsilon']     
            master_decoder_optimizer = DLStudio.ScheduledOptim(optim.Adam(master_decoder.parameters(), betas=(beta1,beta2), eps=epsilon),
                                                lr_mul=2, d_model=embedding_size, n_warmup_steps=num_warmup_steps)    
            ##  max_seq_length was set by the caller to  "encoder_out_size[0] * encoder_out_size[1]  + 2"  with 2 for the SoS and EoS tokens
            criterion = nn.NLLLoss()                                                                                            
            accum_times = []
            start_time = time.perf_counter()
            batch_size = self.dl_studio.batch_size
            print("\nbatch_size: ", batch_size)
            num_batches_in_data_source = len(self.train_dataloader)
            total_num_updates = self.dl_studio.epochs * num_batches_in_data_source
            print("\nnumber of batches in the dataset: ", num_batches_in_data_source)
            training_loss_tally = []
            running_loss = 0.0
            ##  Initialize the SoS and EoS tokens; they are made batch-wide inside the training loop.
            ##  Subsequently, during training, the SoS token is attached at the beginning of the integer-index
            ##  sequence for the codebook vectors, and the EoS token at the end.
            SoS_token = nn.Parameter( torch.randn(1,embedding_size, 1)).cuda()
            SoS_token_label = max_seq_length
            EoS_token = nn.Parameter( torch.randn(1,embedding_size, 1)).cuda()
            EoS_token_label = max_seq_length + 1
            debug = False
            print("\n\n\n           BE PATIENT! Transformers are slow to train --- especially on typical university lab hardware\n")
            for epoch in range(epochs_for_xformer_training):                                                              
                print("")
                print("\nepoch index: ", epoch)
                running_xformer_loss =  0.0
                ##  I am using a dataloader that expects a list of images in a batch followed by 
                ##  a list of the corresponding target label integers.  In our case, though, we only
                ##  have the images.  I synthesize the target label integers separately in what follows.
                for training_iter, data in enumerate(self.train_dataloader):                                    
                    master_decoder_optimizer.zero_grad()
                    input_images, _ = data                              
                    input_images = input_images.cuda()
                    input_images = normed_tensor(input_images)
                    batch_size = input_images.shape[0]                   ##  This may change at the end of an epoch
                    ##  Preparing the end tokens, SoS and EoS, for feeding into the transformer:
                    SoS_token_batch = SoS_token.repeat(batch_size,1,1)   ##  Repeat for each batch instance
                    SoS_token_batch = torch.transpose(SoS_token_batch, 1,2)
                    EoS_token_batch = EoS_token.repeat(batch_size,1,1)   ##  Repeat for each batch instance
                    EoS_token_batch = torch.transpose(EoS_token_batch, 1,2)
                    ##  Feeding the input image batch into the VQGAN Encoder:
                    z = self.vqgan_encoder(normed_tensor(input_images)).cuda()
                    z = self.pre_vq_convo(z)
                    ##  The Vector Quantizer will give us both the integer indices (as onehot vecs) and
                    ##  the codebook vectors associated with the input batch:
                    vq_loss, quantized, perplexity, encoding_indices_as_onehot_vecs =  self.vqgan_vector_quantizer(z)
                    ##  We need to turn the onehot vecs for the integer indices into actual integer values:
                    encoding_indices_as_ints = encoding_indices_as_onehot_vecs.argmax(1) 
                    ##  Now we must prepare the ground-truth target labels for the transformer:
                    indices_tensor = torch.tensor(encoding_indices_as_ints)
                    indices_tensor = indices_tensor.view(z.shape[0], -1).cuda()
                    target_labels = torch.zeros(size=(batch_size, max_seq_length), dtype=torch.int64).cuda()
                    target_labels[:,1:-1] = indices_tensor
                    target_labels[:,0] = SoS_token_label
                    target_labels[:,-1] = EoS_token_label
                    ##  Let's now process the codebook vectors that correspond to the above integer indices since
                    ##  they are going to be used as the embedding vectors associated with the above integer indices:
                    quantized =  quantized.reshape(z.shape[0], z.shape[1], -1).cuda()
                    ##  We need to now synthesize the input tensor for the transformer: 
                    input_tensor = torch.transpose(quantized, 1,2).cuda()   
                    input_tensor = torch.cat( (SoS_token_batch, input_tensor, EoS_token_batch), dim=1 )
                    predicted_indices = torch.zeros(batch_size, max_seq_length, dtype=torch.int64).cuda()
                    mask = torch.ones(1, dtype=int)                         ## initialize the mask                      
                    LOSS = 0.0
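                    ##  Autoregressive training:  at step word_index, the mask handed to the MasterDecoder is meant to
                    ##  expose only the first word_index positions of the input sequence, the NLL loss is computed against
                    ##  the ground-truth codebook index at that position, and the mask is then grown by one element.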
                    for word_index in range(1,input_tensor.shape[1]):
                        masked_input_seq = master_decoder.apply_mask(input_tensor, mask)                                
                        predicted_word_logprobs, predicted_word_index_values = master_decoder(input_tensor, mask)
                        predicted_indices[:,word_index] = predicted_word_index_values
                        loss = criterion(predicted_word_logprobs, target_labels[:, word_index])           
                        LOSS += loss
                        mask = torch.cat( ( mask, torch.ones(1, dtype=int) ) )                                          
                    predicted_indices = np.array(predicted_indices.cpu())
                    LOSS.backward()
                    master_decoder_optimizer.step_and_update_lr()                                                       
                    loss_normed = LOSS.item() / input_tensor.shape[0]
                    running_loss += loss_normed
                    prev_seq_logprobs  =  predicted_word_logprobs
                    if training_iter % 100 == 99:    
                        avg_loss = running_loss / float(100)
                        training_loss_tally.append(avg_loss)
                        running_loss = 0.0
                        current_time = time.perf_counter()
                        time_elapsed = current_time-start_time
                        print("[epoch:%2d/%d  iter:%4d  elapsed_time: %4d secs]     loss: %.4f" % (epoch+1,self.dl_studio.epochs,training_iter+1,time_elapsed,avg_loss)) 
                        accum_times.append(current_time-start_time)
                ##  At the beginning of the training session, the designated checkpoint_dir has already been flushed
                torch.save(master_decoder.state_dict(), checkpoint_dir + "/master_decoder_" +  str(epoch))
            print("\nFinished Training\n")
            plt.figure(figsize=(10,5))
            plt.title("FG Training Loss vs. Iterations")
            plt.plot(training_loss_tally)
            plt.xlabel("iterations")
            plt.ylabel("training loss")
            plt.legend(["Plot of loss versus iterations"], fontsize="x-large")
            plt.savefig("training_loss_FG_" +  str(self.dl_studio.epochs) + ".png")
            plt.show()



        def run_code_for_evaluating_transformer_based_modeling_using_VQGAN(self, vqgan, max_seq_length, embedding_size, codebook_size, 
                num_basic_decoders, num_atten_heads, masking, xformer_checkpoint, visualization_dir = "vqgan_xformer_visualization_dir" ):
            """
            After the VQGAN transformer has been trained as described in the previous "run_code_" method, we need
            to test the transformer model on previously unseen images.
            """
            if os.path.exists(visualization_dir):  
                """
                Clear out the previous outputs in the visualization directory
                """
                files = glob.glob(visualization_dir + "/*")
                for file in files: 
                    if os.path.isfile(file): 
                        os.remove(file) 
                    else: 
                        files = glob.glob(file + "/*") 
                        list(map(lambda x: os.remove(x), files)) 
            else: 
                os.mkdir(visualization_dir)   

            def normed_tensor(x):
                norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
                return x / (norm_factor + 1e-10)

            self.vqgan_encoder.load_state_dict(torch.load("checkpoint_dir/vqgan_encoder_99"))
            self.vqgan_decoder.load_state_dict(torch.load("checkpoint_dir/vqgan_decoder_99"))
            self.vqgan_vector_quantizer.load_state_dict(torch.load("checkpoint_dir/vqgan_vector_quantizer_99"))
            self.vqgan_pre_vq_convo.load_state_dict(torch.load("checkpoint_dir/vqgan_pre_vq_convo_99"))
            self.vqgan_post_vq_convo.load_state_dict(torch.load("checkpoint_dir/vqgan_post_vq_convo_99"))
            self.codebook = super(DLStudio.VQGAN, self).VectorQuantizerEMA.static_codebook
            self.codebook.load_state_dict(torch.load("checkpoint_dir/codebook_99"))

            self.vqgan_encoder.cuda()
            self.vqgan_decoder.cuda()
            self.vqgan_vector_quantizer.cuda()
            self.vqgan_pre_vq_convo.cuda()
            self.vqgan_post_vq_convo.cuda()
            self.codebook.cuda()

            vocab_size = codebook_size
            xformer = DLStudio.TransformerFG( max_seq_length, embedding_size, vocab_size)
            master_decoder = DLStudio.MasterDecoderWithMasking(xformer, num_basic_decoders, num_atten_heads, masking).cuda()
            master_decoder.load_state_dict(torch.load(xformer_checkpoint))
            master_decoder.cuda()

            SoS_token = nn.Parameter( torch.randn(1,embedding_size, 1)).cuda()
            SoS_token_label = max_seq_length
            EoS_token = nn.Parameter( torch.randn(1,embedding_size, 1)).cuda()
            EoS_token_label = max_seq_length + 1

            with torch.no_grad():
                for i, data in enumerate(self.test_dataloader):                                    
                    print("\n\n\n=========Showing VQGAN Tranformer results for test batch %d===============" % i)
                    torch.set_printoptions(edgeitems=10_000, linewidth=120)
                    test_images, _ = data     
                    test_images = test_images.cuda()
                    batch_size = test_images.shape[0]             
                    z = self.vqgan_encoder(normed_tensor(test_images)).cuda()
                    input_shape = z.shape                               ## to be used later on the output of the transformer
                    z = self.vqgan_pre_vq_convo(z)
                    _, quantized, _, encoding_indices_as_onehot_vecs =  self.vqgan_vector_quantizer(z)
                    encoding_indices_as_ints = encoding_indices_as_onehot_vecs.argmax(1) 
                    indices_tensor = torch.tensor(encoding_indices_as_ints)
                    indices_tensor = indices_tensor.view(z.shape[0], -1).cuda()
                    quantized =  quantized.reshape(z.shape[0], z.shape[1], -1).cuda()
                    test_images_tensor = torch.transpose(quantized, 1,2).cuda()
                    SoS_token_batch = SoS_token.repeat(batch_size,1,1)   ##  Repeat for each batch instance
                    SoS_token_batch = torch.transpose(SoS_token_batch, 1,2)
                    EoS_token_batch = EoS_token.repeat(batch_size,1,1)   ##  Repeat for each batch instance
                    EoS_token_batch = torch.transpose(EoS_token_batch, 1,2)
                    test_images_tensor = torch.cat( (SoS_token_batch, test_images_tensor, EoS_token_batch), dim=1 )
                    mask = torch.ones(1, dtype=int)                         ## initialize the mask                      
                    predicted_indices = torch.zeros(batch_size, max_seq_length, dtype=torch.int64).cuda()
                    for word_index in range(1,test_images_tensor.shape[1]):
                        masked_input_seq = master_decoder.apply_mask(test_images_tensor, mask)                                
                        predicted_word_logprobs, predicted_word_index_values = master_decoder(test_images_tensor, mask)
                        predicted_indices[:,word_index] = predicted_word_index_values
#                    print("\n\nPredicted indices for the images: ", predicted_indices)
                    torch.set_printoptions(profile="default")
                    ##  We now convert the integer-valued indices of the codebook vectors into one-hot vector
                    ##  representations of the same.  Subsequently, we matrix-multiply the tensor of one-hot vectors
                    ##  with the "matrix" that holds the codebook vector weights to get the sequence of codebook
                    ##  vectors predicted by the transformer.  These can then be fed into the VQGAN Decoder to
                    ##  construct the output image:
                    onehot_encodings = torch.zeros(batch_size * (max_seq_length-2), codebook_size).cuda()  ## max_seq_length includes
                                                                                                           ##  the two end tokens
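                    ##  Shape check:  'onehot_encodings' is of shape [batch_size*(max_seq_length-2), codebook_size],
                    ##  with scatter_() putting a single 1 in each row at the predicted codebook index.  Since
                    ##  'self.codebook.weight' is of shape [codebook_size, D], where D is the dimensionality of a
                    ##  codebook vector, the matmul below yields [batch_size*(max_seq_length-2), D], which is then
                    ##  reshaped to 'input_shape' before being handed to the VQGAN Decoder.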
                    onehot_encodings.scatter_(1, predicted_indices[:, 1:-1].reshape(-1, 1), 1)
                    quantized = torch.matmul(onehot_encodings, self.codebook.weight)
                    quantized = quantized.view(input_shape)
                    z = self.vqgan_post_vq_convo(quantized)
                    decoder_out = self.vqgan_decoder(z)
                    decoder_out = normed_tensor(decoder_out)
#                    self.display_2_images(test_images, decoder_out)
                    together = torch.zeros( test_images.shape[0], test_images.shape[1], test_images.shape[2], 2 * test_images.shape[3], dtype=torch.float )
                    together[:,:,:,0:test_images.shape[3]]  =  test_images
                    together[:,:,:,test_images.shape[3]:]  =   decoder_out 
                    plt.figure(figsize=(40,20))
                    plt.imshow(np.transpose(torchvision.utils.make_grid(together.cpu(), normalize=False, padding=3, pad_value=255).cpu(), (1,2,0)))
                    plt.title("VQGAN Output Images for iteration %d" % i)
                    plt.savefig(visualization_dir + "/vqgan_transformer_decoder_out_%s" % str(i) + ".png")
                    plt.show()


    ########################################  Start Definition of Inner Class TransformerFG  ####################################

    class TransformerFG(nn.Module):             
        """
        I have borrowed this from DLStudio's Transformers module.  "FG" stands for "First Generation", that is, the
        Transformer as originally proposed by Vaswani et al.
        """
        def __init__(self, max_seq_length, embedding_size, vocab_size, num_warmup_steps=None, optimizer_params=None):
            super(DLStudio.TransformerFG, self).__init__()
            self.max_seq_length = max_seq_length
            self.embedding_size = embedding_size
            self.num_warmup_steps = num_warmup_steps
            self.optimizer_params = optimizer_params
            self.vocab_size = vocab_size
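
        ##  A minimal construction sketch (the sizes shown are illustrative only), mirroring how TransformerFG
        ##  is paired with MasterDecoderWithMasking in the VQGAN transformer code earlier in this file:
        ##
        ##      xformer = DLStudio.TransformerFG( max_seq_length=64, embedding_size=256, vocab_size=512 )
        ##      master_decoder = DLStudio.MasterDecoderWithMasking( xformer, num_basic_decoders=4, num_atten_heads=4 ).cuda()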
    
    class EmbeddingGenerator(nn.Module):
        def __init__(self, xformer, embedding_size):
            super(DLStudio.EmbeddingGenerator, self).__init__()
            self.vocab_size =  xformer.vocab_size
            self.embedding_size = embedding_size                                             
            self.max_seq_length = xformer.max_seq_length                                                     
            self.embed = nn.Embedding(self.vocab_size, embedding_size)

        def forward(self, sentence_tensor):                                                                 
            ## Let's say your batch_size is 4 and that each sentence has a max_seq_length of 10.
            ## The sentence_tensor argument will then be of shape [4,10].  If the embedding size
            ## is 512, the following call will return a tensor of shape [4,10,512]:
            word_embeddings = self.embed(sentence_tensor)
            position_coded_word_embeddings = self.apply_positional_encoding( word_embeddings )
            return position_coded_word_embeddings

        def apply_positional_encoding(self, sentence_tensor):
            position_encodings = torch.zeros_like( sentence_tensor ).float()
            ## Calling unsqueeze() with arg 1 causes the "row tensor" to turn into a "column tensor",
            ##    which is needed in the products shown below. We create a 2D pattern by 
            ##    taking advantage of how PyTorch has overloaded the definition of the infix '*' 
            ##    tensor-tensor multiplication operator.  In effect, it creates an outer product of
            ##    what is essentially a column vector with what is essentially a row vector.
            word_positions = torch.arange(0, self.max_seq_length).unsqueeze(1)            
            div_term =  1.0 / (100.0 ** ( 2.0 * torch.arange(0, self.embedding_size, 2) / float(self.embedding_size) ))
            position_encodings[:, :, 0::2] =  torch.sin(word_positions * div_term)                             
            position_encodings[:, :, 1::2] =  torch.cos(word_positions * div_term)                             
            return sentence_tensor + position_encodings
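
        ##  A tiny worked example of the positional encoding above (hypothetical sizes):  with max_seq_length = 4
        ##  and embedding_size = 4, 'word_positions' is the column vector [[0],[1],[2],[3]] and 'div_term' is the
        ##  row vector [1.0, 0.01].  Their '*' product broadcasts into a 4x2 matrix of angles; the sin() of those
        ##  angles fills the even-indexed embedding coordinates (0 and 2) and the cos() fills the odd-indexed
        ##  ones (1 and 3) at every position of the sequence.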


    ###################################  Self Attention Code for TransformerFG  ###########################################

    class SelfAttention(nn.Module):
        """
        Borrowed from the Transformers module of DLStudio
        """  
        def __init__(self, xformer, num_atten_heads):
            super(DLStudio.SelfAttention, self).__init__()
            self.max_seq_length = xformer.max_seq_length                                                     
            self.embedding_size = xformer.embedding_size
            self.num_atten_heads = num_atten_heads
            self.qkv_size = self.embedding_size // num_atten_heads
            self.attention_heads_arr = nn.ModuleList( [DLStudio.AttentionHead(self.max_seq_length, 
                                    self.qkv_size, num_atten_heads)  for _ in range(num_atten_heads)] )           

        def forward(self, sentence_tensor):                                                                       
            concat_out_from_atten_heads = torch.zeros( sentence_tensor.shape[0], self.max_seq_length, 
                                                                  self.num_atten_heads * self.qkv_size).float()
            for i in range(self.num_atten_heads):                                                                 
                sentence_embed_slice = sentence_tensor[:, :, i * self.qkv_size : (i+1) * self.qkv_size]
                concat_out_from_atten_heads[:, :, i * self.qkv_size : (i+1) * self.qkv_size] =          \
                                                               self.attention_heads_arr[i](sentence_embed_slice)   
            return concat_out_from_atten_heads


    class AttentionHead(nn.Module):
        """
        Borrowed from the Transformers module of DLStudio
        """  
        def __init__(self,  max_seq_length, qkv_size, num_atten_heads):
            super(DLStudio.AttentionHead, self).__init__()
            self.qkv_size = qkv_size
            self.max_seq_length = max_seq_length
            self.WQ =  nn.Linear( self.qkv_size, self.qkv_size )                                                      
            self.WK =  nn.Linear( self.qkv_size, self.qkv_size )                                                      
            self.WV =  nn.Linear( self.qkv_size, self.qkv_size )                                                      
            self.softmax = nn.Softmax(dim=-1)                                                                          

        def forward(self, sent_embed_slice):           ## sent_embed_slice == sentence_embedding_slice                
            Q = self.WQ( sent_embed_slice )                                                                           
            K = self.WK( sent_embed_slice )                                                                           
            V = self.WV( sent_embed_slice )                                                                           
            A = K.transpose(2,1)                                                                                      
            QK_dot_prod = Q @ A                                                                                       
            rowwise_softmax_normalizations = self.softmax( QK_dot_prod )                                              
            Z = rowwise_softmax_normalizations @ V                                                                    
            coeff = 1.0/torch.sqrt(torch.tensor([self.qkv_size]).float()).cuda()
            Z = coeff * Z                                                                          
            return Z
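
        ##  Shape walk-through for a single attention head (hypothetical sizes):  with embedding_size = 256 and
        ##  num_atten_heads = 4, we get qkv_size = 64.  For a batch of B sequences of length S, the incoming
        ##  'sent_embed_slice' is of shape [B, S, 64];  Q, K, and V are each [B, S, 64];  'QK_dot_prod' is
        ##  [B, S, S];  and the returned Z is again [B, S, 64].  Note also that the 1/sqrt(qkv_size) scaling
        ##  here is applied to Z after the softmax-weighted sum, whereas the original Transformer applies it
        ##  to the QK dot products before the softmax.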


    #########################################  Basic Decoder Class for TransformerFG  #####################################

    class BasicDecoderWithMasking(nn.Module):
        """
        Borrowed from the Transformers module of DLStudio
        """  
        def __init__(self, xformer, num_atten_heads, masking=True):
            super(DLStudio.BasicDecoderWithMasking, self).__init__()
            self.masking = masking
            self.embedding_size = xformer.embedding_size                                             
            self.max_seq_length = xformer.max_seq_length                                                     
            self.num_atten_heads = num_atten_heads
            self.qkv_size = self.embedding_size // num_atten_heads
            self.self_attention_layer = DLStudio.SelfAttention(xformer, num_atten_heads)
            self.norm1 = nn.LayerNorm(self.embedding_size)
            self.norm2 = nn.LayerNorm(self.embedding_size)
            ## What follows are the linear layers for the FFN (Feed Forward Network) part of a BasicDecoder
            self.W1 =  nn.Linear( self.embedding_size, 4 * self.embedding_size )
            self.W2 =  nn.Linear( 4 * self.embedding_size, self.embedding_size ) 
            self.norm3 = nn.LayerNorm(self.embedding_size)

        def forward(self, sentence_tensor, mask):   
            masked_sentence_tensor = self.apply_mask(sentence_tensor, mask)
            Z_concatenated = self.self_attention_layer(masked_sentence_tensor).cuda()
            Z_out = self.norm1(Z_concatenated + masked_sentence_tensor)                     
            ## for FFN:
            basic_decoder_out =  nn.ReLU()(self.W1( Z_out.view( sentence_tensor.shape[0], self.max_seq_length, -1) ))                  
            basic_decoder_out =  self.W2( basic_decoder_out )                                                    
            basic_decoder_out = basic_decoder_out.view(sentence_tensor.shape[0], self.max_seq_length, self.embedding_size )
            basic_decoder_out =  basic_decoder_out  + Z_out 
            basic_decoder_out = self.norm3( basic_decoder_out )
            return basic_decoder_out

        def apply_mask(self, sentence_tensor, mask):  
            out = torch.zeros_like(sentence_tensor).float().cuda()
            out[:,:len(mask),:] = sentence_tensor[:,:len(mask),:] 
            return out    
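
        ##  How the mask is used in the training loop earlier in this file:  'mask' starts out as torch.ones(1)
        ##  and one more element is appended to it after each predicted token.  Since apply_mask() passes through
        ##  only the first len(mask) positions of the sequence and zeroes out the rest, a mask of length 3, for
        ##  example, lets only positions 0, 1, and 2 of 'sentence_tensor' through.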


    ######################################  MasterDecoder Class for TransformerFG #########################################

    class MasterDecoderWithMasking(nn.Module):
        """
        Borrowed from the Transformers module of DLStudio
        """  
        def __init__(self, xformer, num_basic_decoders, num_atten_heads, masking=True):
            super(DLStudio.MasterDecoderWithMasking, self).__init__()
            self.masking = masking
            self.max_seq_length = xformer.max_seq_length
            self.embedding_size = xformer.embedding_size
            self.vocab_size = xformer.vocab_size                                             
            self.basic_decoder_arr = nn.ModuleList([DLStudio.BasicDecoderWithMasking( xformer,
                                                    num_atten_heads, masking) for _ in range(num_basic_decoders)])  
            ##  Need the following layer because we want the prediction for each target token to be scored
            ##  against the entire target vocabulary.  The logits produced by this layer are converted into
            ##  log-probabilities by the LogSoftmax in forward():
            self.out = nn.Linear(self.embedding_size, self.vocab_size)                                          

        def forward(self, sentence_tensor, mask):                                                   
            out_tensor = sentence_tensor
            for i in range(len(self.basic_decoder_arr)):                                                 
                out_tensor = self.basic_decoder_arr[i](out_tensor, mask)                              
            word_index = mask.shape[0]
            last_word_tensor = out_tensor[:,word_index]                                      
            last_word_onehot = self.out(last_word_tensor)        
            output_word_logprobs = nn.LogSoftmax(dim=1)(last_word_onehot)                                     
            _, idx_max = torch.max(output_word_logprobs, 1)                
            ## the logprobs are over the entire target vocabulary
            return output_word_logprobs, idx_max


        def apply_mask(self, sentence_tensor, mask):  
            out = torch.zeros_like(sentence_tensor).float().cuda()
            out[:,:len(mask),:] = sentence_tensor[:,:len(mask),:] 
            return out    


    ###########################  ScheduledOptim Code for TransformerFG #############################

    class ScheduledOptim():
        """
        As in the Transformers module of DLStudio, for the scheduling of the learning rate
        during the warm-up phase of training TransformerFG, I have borrowed the class shown below
        from the GitHub code made available by Yu-Hsiang Huang at:

                https://github.com/jadore801120/attention-is-all-you-need-pytorch
        """
        def __init__(self, optimizer, lr_mul, d_model, n_warmup_steps):
            self._optimizer = optimizer
            self.lr_mul = lr_mul
            self.d_model = d_model
            self.n_warmup_steps = n_warmup_steps
            self.n_steps = 0

        def step_and_update_lr(self):
            "Step with the inner optimizer"
            self._update_learning_rate()
            self._optimizer.step()
    
        def zero_grad(self):
            "Zero out the gradients with the inner optimizer"
            self._optimizer.zero_grad()
    
        def _get_lr_scale(self):
            d_model = self.d_model
            n_steps, n_warmup_steps = self.n_steps, self.n_warmup_steps
            return (d_model ** -0.5) * min(n_steps ** (-0.5), n_steps * n_warmup_steps ** (-1.5))
    
        def _update_learning_rate(self):
            ''' Learning rate scheduling per step '''
            self.n_steps += 1
            lr = self.lr_mul * self._get_lr_scale()
            for param_group in self._optimizer.param_groups:
                param_group['lr'] = lr
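
        ##  The scaling above implements the warm-up schedule used by Vaswani et al.:
        ##
        ##      lr  =  lr_mul * d_model**(-0.5) * min( n_steps**(-0.5),  n_steps * n_warmup_steps**(-1.5) )
        ##
        ##  so the learning rate grows linearly for the first n_warmup_steps steps and then decays as
        ##  1/sqrt(n_steps).  A minimal usage sketch (the Adam settings shown are illustrative assumptions):
        ##
        ##      master_decoder_optimizer = DLStudio.ScheduledOptim(
        ##                  optim.Adam( master_decoder.parameters(), betas=(0.9, 0.98), eps=1e-9 ),
        ##                  lr_mul=2, d_model=embedding_size, n_warmup_steps=4000 )
        ##      ...
        ##      master_decoder_optimizer.zero_grad()
        ##      LOSS.backward()
        ##      master_decoder_optimizer.step_and_update_lr()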



    ###%%%
    #####################################################################################################################
    ####################################  Start Definition of Inner Class TextClassification  ###########################

    class TextClassification(nn.Module):             
        """
        The purpose of this inner class is to be able to use the DLStudio platform for simple 
        experiments in text classification.  Consider, for example, the problem of automatic 
        classification of variable-length user feedback: you want to create a neural network
        that can label an uploaded product review of arbitrary length as positive or negative.  
        One way to solve this problem is with a recurrent neural network in which you use a 
        hidden state for characterizing a variable-length product review with a fixed-length 
        state vector.  This inner class allows you to carry out such experiments.

        Class Path:  DLStudio -> TextClassification 
        """
        def __init__(self, dl_studio, dataserver_train=None, dataserver_test=None, dataset_file_train=None, 
                                                               dataset_file_test=None, display_train_loss=False):
            super(DLStudio.TextClassification, self).__init__()
            self.dl_studio = dl_studio
            self.dataserver_train = dataserver_train
            self.dataserver_test = dataserver_test
            self.display_train_loss = display_train_loss

        class SentimentAnalysisDataset(torch.utils.data.Dataset):
            """
            The sentiment analysis datasets that I have made available were extracted from
            an archive of user feedback comments as made available by Amazon for the year
            2007.  The original archive contains user feedback on 25 product categories. 
            For each product category, there are two files named 'positive.reviews' and
            'negative.reviews', with each file containing 1000 reviews. I believe that
            characterizing the reviews as 'positive' or 'negative' was carried out by 
            human annotators. Regardless, the reviews in these two files can be used to 
            train a neural network whose purpose would be to automatically characterize
            a product review as being positive or negative. 

            I have extracted the following datasets from the Amazon archive:

                 sentiment_dataset_train_200.tar.gz        vocab_size = 43,285
                 sentiment_dataset_test_200.tar.gz  

                 sentiment_dataset_train_40.tar.gz         vocab_size = 17,001
                 sentiment_dataset_test_40.tar.gz    

                 sentiment_dataset_train_3.tar.gz          vocab_size = 3,402
                 sentiment_dataset_test_3.tar.gz    

            The integer in the name of each dataset is the number of reviews collected 
            from the 'positive.reviews' and the 'negative.reviews' files for each product
            category.  Therefore, the dataset with 200 in its name has a total of 400 
            reviews for each product category.

            As to why I am presenting these three different datasets, note that, as shown
            above, the size of the vocabulary depends on the number of reviews selected,
            and the size of the vocabulary has a strong bearing on how long it takes to 
            train an algorithm for text classification.  Here is one simple reason for that: 
            the size of the one-hot representation for the words equals the size of the 
            vocabulary.  Therefore, the one-hot representation for the words for the 
            dataset with 200 in its name will be a one-axis tensor of size 43,285.

            For a purely feedforward network, it is not a big deal for the input tensors
            to be size Nx43285 where N is the number of words in a review.  And even for
            RNNs with simple feedback, that does not slow things down.  However, when 
            using GRUs, it's an entirely different matter if you are trying to run your
            experiments on, say, a laptop with a Quadro GPU.  Hence the reason for providing
            the datasets with 200 and 40 reviews.  The dataset with just 3 reviews is for
            debugging your code.

            Class Path:  DLStudio -> TextClassification -> SentimentAnalysisDataset
            """
            def __init__(self, dl_studio, train_or_test, dataset_file):
                super(DLStudio.TextClassification.SentimentAnalysisDataset, self).__init__()
                self.train_or_test = train_or_test
                root_dir = dl_studio.dataroot
                f = gzip.open(root_dir + dataset_file, 'rb')
                dataset = f.read()
                if train_or_test == 'train':
                    if sys.version_info[0] == 3:
                        self.positive_reviews_train, self.negative_reviews_train, self.vocab = pickle.loads(dataset, encoding='latin1')
                    else:
                        self.positive_reviews_train, self.negative_reviews_train, self.vocab = pickle.loads(dataset)
                    self.categories = sorted(list(self.positive_reviews_train.keys()))
                    self.category_sizes_train_pos = {category : len(self.positive_reviews_train[category]) for category in self.categories}
                    self.category_sizes_train_neg = {category : len(self.negative_reviews_train[category]) for category in self.categories}
                    self.indexed_dataset_train = []
                    for category in self.positive_reviews_train:
                        for review in self.positive_reviews_train[category]:
                            self.indexed_dataset_train.append([review, category, 1])
                    for category in self.negative_reviews_train:
                        for review in self.negative_reviews_train[category]:
                            self.indexed_dataset_train.append([review, category, 0])
                    random.shuffle(self.indexed_dataset_train)
                elif train_or_test == 'test':
                    if sys.version_info[0] == 3:
                        self.positive_reviews_test, self.negative_reviews_test, self.vocab = pickle.loads(dataset, encoding='latin1')
                    else:
                        self.positive_reviews_test, self.negative_reviews_test, self.vocab = pickle.loads(dataset)
                    self.vocab = sorted(self.vocab)
                    self.categories = sorted(list(self.positive_reviews_test.keys()))
                    self.category_sizes_test_pos = {category : len(self.positive_reviews_test[category]) for category in self.categories}
                    self.category_sizes_test_neg = {category : len(self.negative_reviews_test[category]) for category in self.categories}
                    self.indexed_dataset_test = []
                    for category in self.positive_reviews_test:
                        for review in self.positive_reviews_test[category]:
                            self.indexed_dataset_test.append([review, category, 1])
                    for category in self.negative_reviews_test:
                        for review in self.negative_reviews_test[category]:
                            self.indexed_dataset_test.append([review, category, 0])
                    random.shuffle(self.indexed_dataset_test)

            def get_vocab_size(self):
                return len(self.vocab)

            def one_hotvec_for_word(self, word):
                word_index =  self.vocab.index(word)
                hotvec = torch.zeros(1, len(self.vocab))
                hotvec[0, word_index] = 1
                return hotvec

            def review_to_tensor(self, review):
                review_tensor = torch.zeros(len(review), len(self.vocab))
                for i,word in enumerate(review):
                    review_tensor[i,:] = self.one_hotvec_for_word(word)
                return review_tensor
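
            ##  Shape note:  for the dataset with 200 reviews per category, len(self.vocab) is 43,285, so
            ##  one_hotvec_for_word() returns a [1, 43285] tensor with a single 1 in it, and review_to_tensor()
            ##  turns a review of N words into a tensor of shape [N, 43285].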

            def sentiment_to_tensor(self, sentiment):
                """
                Sentiment is ordinarily just a binary-valued thing: it is 0 for negative
                sentiment and 1 for positive sentiment.  We need to pack this value into a
                two-element tensor.
                """        
                sentiment_tensor = torch.zeros(2)
                if sentiment == 1:
                    sentiment_tensor[1] = 1
                elif sentiment == 0: 
                    sentiment_tensor[0] = 1
                sentiment_tensor = sentiment_tensor.type(torch.long)
                return sentiment_tensor
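
            ##  For example, sentiment == 1 is mapped to tensor([0, 1]) and sentiment == 0 to tensor([1, 0]),
            ##  both of type torch.long.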

            def __len__(self):
                if self.train_or_test == 'train':
                    return len(self.indexed_dataset_train)
                elif self.train_or_test == 'test':
                    return len(self.indexed_dataset_test)

            def __getitem__(self, idx):
                sample = self.indexed_dataset_train[idx] if self.train_or_test == 'train' else self.indexed_dataset_test[idx]
                review = sample[0]
                review_category = sample[1]
                review_sentiment = sample[2]
                review_sentiment = self.sentiment_to_tensor(review_sentiment)
                review_tensor = self.review_to_tensor(review)
                category_index = self.categories.index(review_category)
                sample = {'review'       : review_tensor, 
                          'category'     : category_index, # should be converted to tensor, but not yet used
                          'sentiment'    : review_sentiment }
                return sample

        def load_SentimentAnalysisDataset(self, dataserver_train, dataserver_test ):   
            self.train_dataloader = torch.utils.data.DataLoader(dataserver_train,
                        batch_size=self.dl_studio.batch_size,shuffle=True, num_workers=1)
            self.test_dataloader = torch.utils.data.DataLoader(dataserver_test,
                               batch_size=self.dl_studio.batch_size,shuffle=False, num_workers=1)
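
        ##  A minimal usage sketch for the dataset and the dataloaders (the DLStudio instance 'dls' and the
        ##  dataroot implied by it are hypothetical):
        ##
        ##      dataserver_train = DLStudio.TextClassification.SentimentAnalysisDataset( dls, 'train',
        ##                                        dataset_file="sentiment_dataset_train_200.tar.gz" )
        ##      dataserver_test  = DLStudio.TextClassification.SentimentAnalysisDataset( dls, 'test',
        ##                                        dataset_file="sentiment_dataset_test_200.tar.gz" )
        ##      text_cl = DLStudio.TextClassification( dl_studio=dls, dataserver_train=dataserver_train,
        ##                                             dataserver_test=dataserver_test )
        ##      text_cl.load_SentimentAnalysisDataset( dataserver_train, dataserver_test )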

        class TEXTnet(nn.Module):
            """
            This network is meant for semantic classification of variable-length sentiment 
            data.  Based on my limited testing, the performance of this network is very
            poor because it has no protection against vanishing gradients when used in an
            RNN.

            Class Path:  DLStudio -> TextClassification -> TEXTnet
            """
            def __init__(self, input_size, hidden_size, output_size):
                super(DLStudio.TextClassification.TEXTnet, self).__init__()
                self.input_size = input_size
                self.hidden_size = hidden_size
                self.output_size = output_size
                self.combined_to_hidden = nn.Linear(input_size + hidden_size, hidden_size)
                self.combined_to_middle = nn.Linear(input_size + hidden_size, 100)
                self.middle_to_out = nn.Linear(100, output_size)     
                self.logsoftmax = nn.LogSoftmax(dim=1)
                self.dropout = nn.Dropout(p=0.1)

            def forward(self, input, hidden):
                combined = torch.cat((input, hidden), 1)
                hidden = self.combined_to_hidden(combined)
                hidden = torch.tanh(hidden)                   
                out = self.combined_to_middle(combined)
                out = nn.functional.relu(out)
                out = self.dropout(out)
                out = self.middle_to_out(out)
                out = self.logsoftmax(out)
                return out,hidden         

            def init_hidden(self):
                hidden = torch.zeros(1, self.hidden_size)
                return hidden


        class TEXTnetOrder2(nn.Module):
            """
            In this variant of the TEXTnet network, the value of hidden as used at each
            time step also includes its value at the previous time step.  This fact, not
            directly apparent from the definition of the class shown below, is made possible
            by the last parameter, cell, in the header of forward().  As you can see below,
            at the end of forward(), the value of the cell goes through a linear layer
            and through a sigmoid nonlinearity. By the way, since the sigmoid saturates at 0
            and 1, it can act like a switch. Later, when I use this class in the training
            function, you will see the cell values being used in such a manner that the
            hidden state at each time step is mixed with the hidden state at the previous
            time step, but only to the extent allowed by the switching action of the sigmoid.

            Class Path:  DLStudio -> TextClassification -> TEXTnetOrder2
            """
            def __init__(self, input_size, hidden_size, output_size):
                super(DLStudio.TextClassification.TEXTnetOrder2, self).__init__()
                self.input_size = input_size
                self.hidden_size = hidden_size
                self.output_size = output_size
                self.combined_to_hidden = nn.Linear(input_size + 2*hidden_size, hidden_size)
                self.combined_to_middle = nn.Linear(input_size + 2*hidden_size, 100)
                self.middle_to_out = nn.Linear(100, output_size)     
                self.logsoftmax = nn.LogSoftmax(dim=1)
                self.dropout = nn.Dropout(p=0.1)
                # for the cell
                self.linear_for_cell = nn.Linear(hidden_size, hidden_size)

            def forward(self, input, hidden, cell):
                combined = torch.cat((input, hidden, cell), 1)
                hidden = self.combined_to_hidden(combined)
                hidden = torch.tanh(hidden)                     
                out = self.combined_to_middle(combined)
                out = nn.functional.relu(out)
                out = self.dropout(out)
                out = self.middle_to_out(out)
                out = self.logsoftmax(out)
                hidden_clone = hidden.clone()
                cell = torch.sigmoid(self.linear_for_cell(hidden_clone))
                return out,hidden,cell         

            def initialize_cell(self):
                weight = next(self.linear_for_cell.parameters()).data
                cell = weight.new(1, self.hidden_size).zero_()
                return cell

            def init_hidden(self):
                hidden = torch.zeros(1, self.hidden_size)
                return hidden


        class GRUnet(nn.Module):
            """
            Source: https://blog.floydhub.com/gru-with-pytorch/
            with the only modification that the final output of forward() is now
            routed through LogSoftmax activation. 

            In the definition shown below, input_size is the size of the vocabulary, the 
            hidden_size is typically 512, and the output_size is set to 2 for the two
            sentiments, positive and negative. 

            Class Path: DLStudio  ->  TextClassification  ->  GRUnet
            """
            def __init__(self, input_size, hidden_size, output_size, num_layers, drop_prob=0.2):
                super(DLStudio.TextClassification.GRUnet, self).__init__()
                self.hidden_size = hidden_size
                self.num_layers = num_layers
                self.gru = nn.GRU(input_size, hidden_size, num_layers)
                self.fc = nn.Linear(hidden_size, output_size)
                self.relu = nn.ReLU()
                self.logsoftmax = nn.LogSoftmax(dim=1)
                
            def forward(self, x, h):
                out, h = self.gru(x, h)
                out = self.fc(self.relu(out[:,-1]))
                out = self.logsoftmax(out)
                return out, h

            def init_hidden(self):
                weight = next(self.parameters()).data
                #                                     batch_size   
                hidden = weight.new(  self.num_layers,     1,         self.hidden_size   ).zero_()
                return hidden
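
            ##  A minimal usage sketch (the hidden_size and num_layers values below are illustrative):
            ##
            ##      net = DLStudio.TextClassification.GRUnet( vocab_size, hidden_size=512, output_size=2, num_layers=2 )
            ##      hidden = net.init_hidden()
            ##      for k in range(review_tensor.shape[1]):
            ##          output, hidden = net( torch.unsqueeze(torch.unsqueeze(review_tensor[0,k], 0), 0), hidden )
            ##
            ##  where 'output' holds the LogSoftmax scores for the negative and positive classes.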

        def save_model(self, model):
            "Save the trained model to a disk file"
            torch.save(model.state_dict(), self.dl_studio.path_saved_model)


        def run_code_for_training_with_TEXTnet(self, net, display_train_loss=False):        
            filename_for_out = "performance_numbers_" + str(self.dl_studio.epochs) + ".txt"
            FILE = open(filename_for_out, 'w')
            net.to(self.dl_studio.device)
            ## Note that the TEXTnet and TEXTnetOrder2 both produce LogSoftmax output:
            criterion = nn.NLLLoss()
            accum_times = []
            optimizer = optim.SGD(net.parameters(), 
                         lr=self.dl_studio.learning_rate, momentum=self.dl_studio.momentum)
            start_time = time.perf_counter()
            training_loss_tally = []
            for epoch in range(self.dl_studio.epochs):  
                print("")
                running_loss = 0.0
                for i, data in enumerate(self.train_dataloader):    
                    hidden = net.init_hidden().to(self.dl_studio.device)              
                    review_tensor,category,sentiment = data['review'], data['category'], data['sentiment']
                    review_tensor = review_tensor.to(self.dl_studio.device)
                    sentiment = sentiment.to(self.dl_studio.device)
                    optimizer.zero_grad()
                    input = torch.zeros(1,review_tensor.shape[2])
                    input = input.to(self.dl_studio.device)
                    for k in range(review_tensor.shape[1]):
                        input[0,:] = review_tensor[0,k]
                        output, hidden = net(input, hidden)
                    loss = criterion(output, torch.argmax(sentiment,1))
                    running_loss += loss.item()
#                    loss.backward(retain_graph=True)        
                    loss.backward()        
                    optimizer.step()
                    if i % 200 == 199:    
                        avg_loss = running_loss / float(200)
                        training_loss_tally.append(avg_loss)
                        current_time = time.perf_counter()
                        time_elapsed = current_time-start_time
                        print("[epoch:%d  iter:%4d  elapsed_time: %4d secs]     loss: %.5f" % (epoch+1,i+1, time_elapsed,avg_loss))
                        accum_times.append(current_time-start_time)
                        FILE.write("%.3f\n" % avg_loss)
                        FILE.flush()
                        running_loss = 0.0
            print("\nFinished Training\n")
            self.save_model(net)
            if display_train_loss:
                plt.figure(figsize=(10,5))
                plt.title("Training Loss vs. Iterations")
                plt.plot(training_loss_tally)
                plt.xlabel("iterations")
                plt.ylabel("training loss")
#                plt.legend()
                plt.legend(["Plot of loss versus iterations"], fontsize="x-large")
                plt.savefig("training_loss.png")
                plt.show()


        def run_code_for_training_with_TEXTnetOrder2(self, net, display_train_loss=False):        
            filename_for_out = "performance_numbers_" + str(self.dl_studio.epochs) + ".txt"
            FILE = open(filename_for_out, 'w')
            net.to(self.dl_studio.device)
            ## Note that the TEXTnet and TEXTnetOrder2 both produce LogSoftmax output:
            criterion = nn.NLLLoss()
            accum_times = []
            optimizer = optim.SGD(net.parameters(), 
                         lr=self.dl_studio.learning_rate, momentum=self.dl_studio.momentum)
            start_time = time.perf_counter()
            training_loss_tally = []
            for epoch in range(self.dl_studio.epochs):  
                print("")
                running_loss = 0.0
                for i, data in enumerate(self.train_dataloader):    
                    hidden = net.init_hidden().to(self.dl_studio.device)              
                    cell_prev = net.initialize_cell().to(self.dl_studio.device)
                    cell_prev_2_prev = net.initialize_cell().to(self.dl_studio.device)
                    review_tensor,category,sentiment = data['review'], data['category'], data['sentiment']
                    review_tensor = review_tensor.to(self.dl_studio.device)
                    sentiment = sentiment.to(self.dl_studio.device)
                    optimizer.zero_grad()
                    input = torch.zeros(1,review_tensor.shape[2])
                    input = input.to(self.dl_studio.device)
                    for k in range(review_tensor.shape[1]):
                        input[0,:] = review_tensor[0,k]
                        output, hidden, cell = net(input, hidden, cell_prev_2_prev)
                        if k == 0:
                            cell_prev = cell
                        else:
                            cell_prev_2_prev = cell_prev
                            cell_prev = cell
                    loss = criterion(output, torch.argmax(sentiment,1))
                    running_loss += loss.item()
                    loss.backward()        
                    optimizer.step()
                    if i % 200 == 199:    
                        avg_loss = running_loss / float(200)
                        training_loss_tally.append(avg_loss)
                        current_time = time.perf_counter()
                        time_elapsed = current_time-start_time
                        print("[epoch:%d  iter:%4d  elapsed_time: %4d secs]     loss: %.5f" % (epoch+1,i+1, time_elapsed,avg_loss))
                        accum_times.append(current_time-start_time)
                        FILE.write("%.3f\n" % avg_loss)
                        FILE.flush()
                        running_loss = 0.0
            print("\nFinished Training\n")
            self.save_model(net)
            if display_train_loss:
                plt.figure(figsize=(10,5))
                plt.title("Training Loss vs. Iterations")
                plt.plot(training_loss_tally)
                plt.xlabel("iterations")
                plt.ylabel("training loss")
#                plt.legend()
                plt.legend(["Plot of loss versus iterations"], fontsize="x-large")
                plt.savefig("training_loss.png")
                plt.show()


        def run_code_for_training_for_text_classification_with_GRU(self, net, display_train_loss=False): 
            filename_for_out = "performance_numbers_" + str(self.dl_studio.epochs) + ".txt"
            FILE = open(filename_for_out, 'w')
            net.to(self.dl_studio.device)
            ##  Note that the GRUnet now produces the LogSoftmax output:
            criterion = nn.NLLLoss()
            accum_times = []
            optimizer = optim.SGD(net.parameters(), 
                         lr=self.dl_studio.learning_rate, momentum=self.dl_studio.momentum)
            start_time = time.perf_counter()
            training_loss_tally = []
            for epoch in range(self.dl_studio.epochs):  
                print("")
                running_loss = 0.0
                for i, data in enumerate(self.train_dataloader):    
                    review_tensor,category,sentiment = data['review'], data['category'], data['sentiment']
                    review_tensor = review_tensor.to(self.dl_studio.device)
                    sentiment = sentiment.to(self.dl_studio.device)
                    ## The following type conversion needed for MSELoss:
                    ##sentiment = sentiment.float()
                    optimizer.zero_grad()
                    hidden = net.init_hidden().to(self.dl_studio.device)
                    for k in range(review_tensor.shape[1]):
                        output, hidden = net(torch.unsqueeze(torch.unsqueeze(review_tensor[0,k],0),0), hidden)
                    ## If using NLLLoss or CrossEntropyLoss:
                    loss = criterion(output, torch.argmax(sentiment, 1))
                    ## If using MSELoss:
                    ## loss = criterion(output, sentiment)     
                    running_loss += loss.item()
                    loss.backward()
                    optimizer.step()
                    if i % 200 == 199:    
                        avg_loss = running_loss / float(200)
                        training_loss_tally.append(avg_loss)
                        current_time = time.perf_counter()
                        time_elapsed = current_time-start_time
                        print("[epoch:%d  iter:%4d  elapsed_time:%4d secs]     loss: %.5f" % (epoch+1,i+1, time_elapsed,avg_loss))
                        accum_times.append(current_time-start_time)
                        FILE.write("%.3f\n" % avg_loss)
                        FILE.flush()
                        running_loss = 0.0
            print("Total Training Time: {}".format(str(sum(accum_times))))
            print("\nFinished Training\n")
            self.save_model(net)
            if display_train_loss:
                plt.figure(figsize=(10,5))
                plt.title("Training Loss vs. Iterations")
                plt.plot(training_loss_tally)
                plt.xlabel("iterations")
                plt.ylabel("training loss")
#                plt.legend()
                plt.legend(["Plot of loss versus iterations"], fontsize="x-large")
                plt.savefig("training_loss.png")
                plt.show()


        def run_code_for_testing_with_TEXTnet(self, net):
            net.load_state_dict(torch.load(self.dl_studio.path_saved_model))
            net.to(self.dl_studio.device)
            classification_accuracy = 0.0
            negative_total = 0
            positive_total = 0
            confusion_matrix = torch.zeros(2,2)
            with torch.no_grad():
                for i, data in enumerate(self.test_dataloader):
                    review_tensor,category,sentiment = data['review'], data['category'], data['sentiment']
                    input = torch.zeros(1,review_tensor.shape[2]).to(self.dl_studio.device)
                    hidden = net.init_hidden().to(self.dl_studio.device)
                    for k in range(review_tensor.shape[1]):
                        input[0,:] = review_tensor[0,k]
                        output, hidden = net(input, hidden)
                    predicted_idx = torch.argmax(output).item()
                    gt_idx = torch.argmax(sentiment).item()
                    if i % 100 == 99:
                        print("   [i=%4d]    predicted_label=%d       gt_label=%d" % (i+1, predicted_idx,gt_idx))
                    if predicted_idx == gt_idx:
                        classification_accuracy += 1
                    if gt_idx == 0: 
                        negative_total += 1
                    elif gt_idx == 1:
                        positive_total += 1
                    confusion_matrix[gt_idx,predicted_idx] += 1
            print("\nOverall classification accuracy: %0.2f%%" %  (float(classification_accuracy) * 100 /float(i)))
            out_percent = np.zeros((2,2), dtype='float')
            out_percent[0,0] = "%.3f" % (100 * confusion_matrix[0,0] / float(negative_total))
            out_percent[0,1] = "%.3f" % (100 * confusion_matrix[0,1] / float(negative_total))
            out_percent[1,0] = "%.3f" % (100 * confusion_matrix[1,0] / float(positive_total))
            out_percent[1,1] = "%.3f" % (100 * confusion_matrix[1,1] / float(positive_total))
            print("\n\nNumber of positive reviews tested: %d" % positive_total)
            print("\n\nNumber of negative reviews tested: %d" % negative_total)
            print("\n\nDisplaying the confusion matrix:\n")
            out_str = "                      "
            out_str +=  "%18s    %18s" % ('predicted negative', 'predicted positive')
            print(out_str + "\n")
            for i,label in enumerate(['true negative', 'true positive']):
                out_str = "%12s:  " % label
                for j in range(2):
                    out_str +=  "%18s" % out_percent[i,j]
                print(out_str)

        def run_code_for_testing_with_TEXTnetOrder2(self, net):
            net.load_state_dict(torch.load(self.dl_studio.path_saved_model))
            net.to(self.dl_studio.device)
            classification_accuracy = 0.0
            negative_total = 0
            positive_total = 0
            confusion_matrix = torch.zeros(2,2)
            with torch.no_grad():
                for i, data in enumerate(self.test_dataloader):
                    cell_prev = net.initialize_cell()
                    cell_prev_2_prev = net.initialize_cell()
                    review_tensor,category,sentiment = data['review'], data['category'], data['sentiment']
                    input = torch.zeros(1,review_tensor.shape[2]).to(self.dl_studio.device)
                    hidden = net.init_hidden().to(self.dl_studio.device)
                    for k in range(review_tensor.shape[1]):
                        input[0,:] = review_tensor[0,k]
                        output, hidden, cell = net(input, hidden, cell_prev_2_prev)
                        if k == 0:
                            cell_prev = cell
                        else:
                            cell_prev_2_prev = cell_prev
                            cell_prev = cell
                    predicted_idx = torch.argmax(output).item()
                    gt_idx = torch.argmax(sentiment).item()
                    if i % 100 == 99:
                        print("   [i=%4d]    predicted_label=%d       gt_label=%d" % (i+1, predicted_idx,gt_idx))
                    if predicted_idx == gt_idx:
                        classification_accuracy += 1
                    if gt_idx == 0: 
                        negative_total += 1
                    elif gt_idx == 1:
                        positive_total += 1
                    confusion_matrix[gt_idx,predicted_idx] += 1
            print("\nOverall classification accuracy: %0.2f%%" %  (float(classification_accuracy) * 100 /float(i)))
            out_percent = np.zeros((2,2), dtype='float')
            out_percent[0,0] = "%.3f" % (100 * confusion_matrix[0,0] / float(negative_total))
            out_percent[0,1] = "%.3f" % (100 * confusion_matrix[0,1] / float(negative_total))
            out_percent[1,0] = "%.3f" % (100 * confusion_matrix[1,0] / float(positive_total))
            out_percent[1,1] = "%.3f" % (100 * confusion_matrix[1,1] / float(positive_total))
            print("\n\nNumber of positive reviews tested: %d" % positive_total)
            print("\n\nNumber of negative reviews tested: %d" % negative_total)
            print("\n\nDisplaying the confusion matrix:\n")
            out_str = "                      "
            out_str +=  "%18s    %18s" % ('predicted negative', 'predicted positive')
            print(out_str + "\n")
            for i,label in enumerate(['true negative', 'true positive']):
                out_str = "%12s:  " % label
                for j in range(2):
                    out_str +=  "%18s" % out_percent[i,j]
                print(out_str)


        def run_code_for_testing_text_classification_with_GRU(self, net):
            net.load_state_dict(torch.load(self.dl_studio.path_saved_model))
            net.to(self.dl_studio.device)
            classification_accuracy = 0.0
            negative_total = 0
            positive_total = 0
            confusion_matrix = torch.zeros(2,2)
            with torch.no_grad():
                for i, data in enumerate(self.test_dataloader):
                    review_tensor,category,sentiment = data['review'], data['category'], data['sentiment']
                    hidden = net.init_hidden().to(self.dl_studio.device)
                    for k in range(review_tensor.shape[1]):
                        output, hidden = net(torch.unsqueeze(torch.unsqueeze(review_tensor[0,k],0),0), hidden)
                    predicted_idx = torch.argmax(output).item()
                    gt_idx = torch.argmax(sentiment).item()
                    if i % 100 == 99:
                        print("   [i=%d]    predicted_label=%d       gt_label=%d\n\n" % (i+1, predicted_idx,gt_idx))
                    if predicted_idx == gt_idx:
                        classification_accuracy += 1
                    if gt_idx == 0: 
                        negative_total += 1
                    elif gt_idx == 1:
                        positive_total += 1
                    confusion_matrix[gt_idx,predicted_idx] += 1
            print("\nOverall classification accuracy: %0.2f%%" %  (float(classification_accuracy) * 100 /float(i)))
            out_percent = np.zeros((2,2), dtype='float')
            out_percent[0,0] = "%.3f" % (100 * confusion_matrix[0,0] / float(negative_total))
            out_percent[0,1] = "%.3f" % (100 * confusion_matrix[0,1] / float(negative_total))
            out_percent[1,0] = "%.3f" % (100 * confusion_matrix[1,0] / float(positive_total))
            out_percent[1,1] = "%.3f" % (100 * confusion_matrix[1,1] / float(positive_total))
            print("\n\nNumber of positive reviews tested: %d" % positive_total)
            print("\n\nNumber of negative reviews tested: %d" % negative_total)
            print("\n\nDisplaying the confusion matrix:\n")
            out_str = "                      "
            out_str +=  "%18s    %18s" % ('predicted negative', 'predicted positive')
            print(out_str + "\n")
            for i,label in enumerate(['true negative', 'true positive']):
                out_str = "%12s:  " % label
                for j in range(2):
                    out_str +=  "%18s" % out_percent[i,j]
                print(out_str)


    ###%%%
    #####################################################################################################################
    ##########################  Start Definition of Inner Class TextClassificationWithEmbeddings  #######################

    class TextClassificationWithEmbeddings(nn.Module):             
        """
        The text processing class described previously, TextClassification, was based on
        using one-hot vectors for representing the words.  The main challenge we faced
        with one-hot vectors was that the larger the size of the training dataset, the
        larger the size of the vocabulary, and, therefore, the larger the size of the
        one-hot vectors.  The increase in the size of the one-hot vectors led to a
        model with a significantly larger number of learnable parameters --- and, that,
        in turn, created a need for a still larger training dataset.  Sounds like a classic
        example of a vicious circle.  In this section, I use the idea of word embeddings
        to break out of this vicious circle.

        Word embeddings are fixed-sized numerical representations for words that are
        learned on the basis of the similarity of word contexts.  The original and still
        the most famous of these representations are known as the word2vec
        embeddings. The embeddings that I use in this section consist of pre-trained
        300-element word vectors for 3 million words and phrases as learned from Google
        News reports.  I access these embeddings through the popular Gensim library.
 
        Class Path:  DLStudio -> TextClassificationWithEmbeddings
        """
        def __init__(self, dl_studio,dataserver_train=None,dataserver_test=None,dataset_file_train=None,dataset_file_test=None):
            super(DLStudio.TextClassificationWithEmbeddings, self).__init__()
            self.dl_studio = dl_studio
            self.dataserver_train = dataserver_train
            self.dataserver_test = dataserver_test

        class SentimentAnalysisDataset(torch.utils.data.Dataset):
            """
            In relation to the SentimentAnalysisDataset defined for the TextClassification section of 
            DLStudio, the __getitem__() method of the dataloader must now fetch the embeddings from
            the word2vec word vectors.

            Class Path:  DLStudio -> TextClassificationWithEmbeddings -> SentimentAnalysisDataset
            """
            def __init__(self, dl_studio, train_or_test, dataset_file, path_to_saved_embeddings=None):
                super(DLStudio.TextClassificationWithEmbeddings.SentimentAnalysisDataset, self).__init__()
                import gensim.downloader as gen_api
#                self.word_vectors = gen_api.load("word2vec-google-news-300")
                self.path_to_saved_embeddings = path_to_saved_embeddings
                self.train_or_test = train_or_test
                root_dir = dl_studio.dataroot
                f = gzip.open(root_dir + dataset_file, 'rb')
                dataset = f.read()
                if path_to_saved_embeddings is not None:
                    import gensim.downloader as genapi
                    from gensim.models import KeyedVectors 
                    if os.path.exists(path_to_saved_embeddings + 'vectors.kv'):
                        self.word_vectors = KeyedVectors.load(path_to_saved_embeddings + 'vectors.kv')
                    else:
                        print("""\n\nSince this is your first time to install the word2vec embeddings, it may take"""
                              """\na couple of minutes. The embeddings occupy around 3.6GB of your disk space.\n\n""")
                        self.word_vectors = genapi.load("word2vec-google-news-300")               
                        ##  'kv' stands for  "KeyedVectors", a special datatype used by gensim because it 
                        ##  has a smaller footprint than dict
                        self.word_vectors.save(path_to_saved_embeddings + 'vectors.kv')    
                if train_or_test == 'train':
                    if sys.version_info[0] == 3:
                        self.positive_reviews_train, self.negative_reviews_train, self.vocab = pickle.loads(dataset, encoding='latin1')
                    else:
                        self.positive_reviews_train, self.negative_reviews_train, self.vocab = pickle.loads(dataset)
                    self.categories = sorted(list(self.positive_reviews_train.keys()))
                    self.category_sizes_train_pos = {category : len(self.positive_reviews_train[category]) for category in self.categories}
                    self.category_sizes_train_neg = {category : len(self.negative_reviews_train[category]) for category in self.categories}
                    self.indexed_dataset_train = []
                    for category in self.positive_reviews_train:
                        for review in self.positive_reviews_train[category]:
                            self.indexed_dataset_train.append([review, category, 1])
                    for category in self.negative_reviews_train:
                        for review in self.negative_reviews_train[category]:
                            self.indexed_dataset_train.append([review, category, 0])
                    random.shuffle(self.indexed_dataset_train)
                elif train_or_test == 'test':
                    if sys.version_info[0] == 3:
                        self.positive_reviews_test, self.negative_reviews_test, self.vocab = pickle.loads(dataset, encoding='latin1')
                    else:
                        self.positive_reviews_test, self.negative_reviews_test, self.vocab = pickle.loads(dataset)
                    self.vocab = sorted(self.vocab)
                    self.categories = sorted(list(self.positive_reviews_test.keys()))
                    self.category_sizes_test_pos = {category : len(self.positive_reviews_test[category]) for category in self.categories}
                    self.category_sizes_test_neg = {category : len(self.negative_reviews_test[category]) for category in self.categories}
                    self.indexed_dataset_test = []
                    for category in self.positive_reviews_test:
                        for review in self.positive_reviews_test[category]:
                            self.indexed_dataset_test.append([review, category, 1])
                    for category in self.negative_reviews_test:
                        for review in self.negative_reviews_test[category]:
                            self.indexed_dataset_test.append([review, category, 0])
                    random.shuffle(self.indexed_dataset_test)

            def review_to_tensor(self, review):
                list_of_embeddings = []
                for i,word in enumerate(review):
                    if word in self.word_vectors.key_to_index:
                        embedding = self.word_vectors[word]
                        list_of_embeddings.append(np.array(embedding))
                    else:
                        continue                       ## skip words that are not in the word2vec vocabulary
#                review_tensor = torch.FloatTensor( list_of_embeddings )
                review_tensor = torch.FloatTensor( np.array(list_of_embeddings) )
                return review_tensor

            def sentiment_to_tensor(self, sentiment):
                """
                Sentiment is ordinarily just a binary-valued label.  It is 0 for negative
                sentiment and 1 for positive sentiment.  We need to pack this value into a
                two-element tensor.
                """        
                sentiment_tensor = torch.zeros(2)
                if sentiment == 1:
                    sentiment_tensor[1] = 1
                elif sentiment == 0: 
                    sentiment_tensor[0] = 1
                sentiment_tensor = sentiment_tensor.type(torch.long)
                return sentiment_tensor

            def __len__(self):
                if self.train_or_test == 'train':
                    return len(self.indexed_dataset_train)
                elif self.train_or_test == 'test':
                    return len(self.indexed_dataset_test)

            def __getitem__(self, idx):
                sample = self.indexed_dataset_train[idx] if self.train_or_test == 'train' else self.indexed_dataset_test[idx]
                review = sample[0]
                review_category = sample[1]
                review_sentiment = sample[2]
                review_sentiment = self.sentiment_to_tensor(review_sentiment)
                review_tensor = self.review_to_tensor(review)
                category_index = self.categories.index(review_category)
                sample = {'review'       : review_tensor, 
                          'category'     : category_index, # should be converted to tensor, but not yet used
                          'sentiment'    : review_sentiment }
                return sample
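
            ##  Illustrative sketch (with hypothetical values) of what a single sample returned by
            ##  __getitem__() looks like for a 12-word review:
            ##
            ##      sample['review'].shape    =>  torch.Size([12, 300])    ## one 300-dim word2vec vector per word
            ##      sample['category']        =>  3                        ## index into self.categories
            ##      sample['sentiment']       =>  tensor([0, 1])           ## one-hot sentiment, here positive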

        def load_SentimentAnalysisDataset(self, dataserver_train, dataserver_test ):   
            self.train_dataloader = torch.utils.data.DataLoader(dataserver_train,
                        batch_size=self.dl_studio.batch_size,shuffle=True, num_workers=2)
            self.test_dataloader = torch.utils.data.DataLoader(dataserver_test,
                               batch_size=self.dl_studio.batch_size,shuffle=False, num_workers=2)
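
        ##  An illustrative sketch of how the dataset servers and dataloaders are typically wired up,
        ##  assuming 'dls' is a DLStudio instance configured with dataroot, batch_size, etc.  The
        ##  dataset file names and the embeddings directory below are placeholders:
        ##
        ##      text_cl = DLStudio.TextClassificationWithEmbeddings( dl_studio = dls )
        ##      dataserver_train = DLStudio.TextClassificationWithEmbeddings.SentimentAnalysisDataset(
        ##                             dls, 'train', dataset_file = "your_train_dataset.tar.gz",
        ##                             path_to_saved_embeddings = "./word2vec/" )
        ##      dataserver_test  = DLStudio.TextClassificationWithEmbeddings.SentimentAnalysisDataset(
        ##                             dls, 'test',  dataset_file = "your_test_dataset.tar.gz",
        ##                             path_to_saved_embeddings = "./word2vec/" )
        ##      text_cl.dataserver_train = dataserver_train
        ##      text_cl.dataserver_test  = dataserver_test
        ##      text_cl.load_SentimentAnalysisDataset(dataserver_train, dataserver_test)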

        class TEXTnetWithEmbeddings(nn.Module):
            """
            This is the embeddings version of the TEXTnet class shown previously.  Since we
            are using the word2vec embeddings, we know that the input size for each word vector
            will be a constant value of 300.  Overall, though, this network is meant for semantic
            classification of variable-length sentiment data.  Based on my limited testing, the
            performance of this network is very poor because it has no protection against
            vanishing gradients when unrolled as an RNN over long word sequences.

            Class Path:  DLStudio -> TextClassificationWithEmbeddings -> TEXTnetWithEmbeddings
            """
            def __init__(self, input_size, hidden_size, output_size):
                super(DLStudio.TextClassificationWithEmbeddings.TEXTnetWithEmbeddings, self).__init__()
                self.input_size = input_size
                self.hidden_size = hidden_size
                self.output_size = output_size
                self.combined_to_hidden = nn.Linear(input_size + hidden_size, hidden_size)
                self.combined_to_middle = nn.Linear(input_size + hidden_size, 100)
                self.middle_to_out = nn.Linear(100, output_size)     
                self.logsoftmax = nn.LogSoftmax(dim=1)

            def forward(self, input, hidden):
                combined = torch.cat((input, hidden), 1)
                hidden = self.combined_to_hidden(combined)
                hidden = torch.tanh(hidden)                     
                out = self.combined_to_middle(combined)
                out = nn.functional.relu(out)
                out = self.middle_to_out(out)
                out = self.logsoftmax(out)
                return out,hidden         

            def init_hidden(self):
                hidden = torch.zeros(1, self.hidden_size)
                return hidden
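
            ##  Illustrative sketch (hypothetical sizes) of how this network is unrolled over a review,
            ##  one word vector at a time, as done in the training routine shown later in this class:
            ##
            ##      net = DLStudio.TextClassificationWithEmbeddings.TEXTnetWithEmbeddings(300, 512, 2)
            ##      hidden = net.init_hidden()                           ## shape: (1, 512)
            ##      for k in range(review_tensor.shape[1]):              ## review_tensor: (1, num_words, 300)
            ##          word = review_tensor[0, k].unsqueeze(0)          ## shape: (1, 300)
            ##          out, hidden = net(word, hidden)                  ## out: (1, 2) log-probabilities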


        class TEXTnetOrder2WithEmbeddings(nn.Module):
            """
            This is an embeddings version of the TEXTnetOrder2 class shown previously.
            With the embeddings, we know that the size of the tensor for each word will be 300.
            As to how TEXTnetOrder2 differs from TEXTnet, the value of hidden as used at
            each time step also includes its value at the previous time step.  This
            fact, not directly apparent from the definition of the class shown below,
            is made possible by the last parameter, cell, in the header of forward().
            All you can see here, at the end of forward(), is that the value of cell 
            goes through a linear layer and through a sigmoid nonlinearity. By the way, 
            since the sigmoid saturates at 0 and 1, it can act like a switch. Later 
            when I use this class in the training function, you will see the cell
            values being used in such a manner that the hidden state at each time
            step is mixed with the hidden state at the previous time step.

            Class Path:  DLStudio -> TextClassificationWithEmbeddings -> TEXTnetOrder2WithEmbeddings
            """
            def __init__(self, hidden_size, output_size, input_size=300):
                super(DLStudio.TextClassificationWithEmbeddings.TEXTnetOrder2WithEmbeddings, self).__init__()
                self.input_size = input_size
                self.hidden_size = hidden_size
                self.output_size = output_size
                self.combined_to_hidden = nn.Linear(input_size + 2*hidden_size, hidden_size)
                self.combined_to_middle = nn.Linear(input_size + 2*hidden_size, 100)
                self.middle_to_out = nn.Linear(100, output_size)     
                self.logsoftmax = nn.LogSoftmax(dim=1)
                self.dropout = nn.Dropout(p=0.1)
                # for the cell
                self.linear_for_cell = nn.Linear(hidden_size, hidden_size)

            def forward(self, input, hidden, cell):
                combined = torch.cat((input, hidden, cell), 1)
                hidden = self.combined_to_hidden(combined)
                hidden = torch.tanh(hidden)                     
                out = self.combined_to_middle(combined)
                out = nn.functional.relu(out)
                out = self.dropout(out)
                out = self.middle_to_out(out)
                out = self.logsoftmax(out)
                hidden_clone = hidden.clone()
#                cell = torch.tanh(self.linear_for_cell(hidden_clone))
                cell = torch.sigmoid(self.linear_for_cell(hidden_clone))
                return out,hidden,cell         

            def initialize_cell(self):
                weight = next(self.linear_for_cell.parameters()).data
                cell = weight.new(1, self.hidden_size).zero_()
                return cell

            def init_hidden(self):
                hidden = torch.zeros(1, self.hidden_size)
                return hidden
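
            ##  Illustrative sketch (mirroring the training routine shown later in this class) of how
            ##  the cell value is delayed by one time step, so that the hidden state computed at step k
            ##  also sees information from step k-1.  The hidden size of 512 is hypothetical:
            ##
            ##      net = DLStudio.TextClassificationWithEmbeddings.TEXTnetOrder2WithEmbeddings(512, 2)
            ##      cell_prev        = net.initialize_cell()
            ##      cell_prev_2_prev = net.initialize_cell()
            ##      hidden           = net.init_hidden()
            ##      for k in range(review_tensor.shape[1]):
            ##          word = review_tensor[0, k].unsqueeze(0)
            ##          out, hidden, cell = net(word, hidden, cell_prev_2_prev)
            ##          if k == 0:
            ##              cell_prev = cell
            ##          else:
            ##              cell_prev_2_prev = cell_prev
            ##              cell_prev = cell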


        class GRUnetWithEmbeddings(nn.Module):
            """
            For this embeddings-adapted version of the GRUnet shown earlier, we can assume that
            the 'input_size' for a tensor representing a word is always 300.
            Source: https://blog.floydhub.com/gru-with-pytorch/
            with the only modification being that the final output of forward() is now
            routed through a LogSoftmax activation.

            Class Path:  DLStudio -> TextClassificationWithEmbeddings -> GRUnetWithEmbeddings 
            """
            def __init__(self, input_size, hidden_size, output_size, num_layers=1): 
                """
                -- input_size is the size of the tensor for each word in a sequence of words.  If you use the
                       word2vec embeddings, the value of this variable will always be equal to 300.
                -- hidden_size is the size of the hidden state in the RNN
                -- output_size is the size of output of the RNN.  For binary classification of 
                       input text, output_size is 2.
                -- num_layers creates a stack of GRUs
                """
                super(DLStudio.TextClassificationWithEmbeddings.GRUnetWithEmbeddings, self).__init__()
                self.input_size = input_size
                self.hidden_size = hidden_size
                self.num_layers = num_layers
                self.gru = nn.GRU(input_size, hidden_size, num_layers)
                self.fc = nn.Linear(hidden_size, output_size)
                self.relu = nn.ReLU()
                self.logsoftmax = nn.LogSoftmax(dim=1)
                
            def forward(self, x, h):
                out, h = self.gru(x, h)
                out = self.fc(self.relu(out[:,-1]))
                out = self.logsoftmax(out)
                return out, h

            def init_hidden(self):
                weight = next(self.parameters()).data
                #                      num_layers       batch_size    hidden_size
                hidden = weight.new( self.num_layers,       1,        self.hidden_size ).zero_()
                return hidden
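
            ##  Illustrative sketch (hypothetical sizes) of feeding a review through the GRU one word at
            ##  a time, as done in the training and testing routines that follow:
            ##
            ##      net = DLStudio.TextClassificationWithEmbeddings.GRUnetWithEmbeddings(300, 100, 2, num_layers=2)
            ##      hidden = net.init_hidden()                                  ## shape: (num_layers, 1, 100)
            ##      for k in range(review_tensor.shape[1]):
            ##          word = review_tensor[0, k].view(1, 1, 300)              ## (seq_len=1, batch=1, 300)
            ##          out, hidden = net(word, hidden)                         ## out: (1, 2) log-probabilities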

        def save_model(self, model):
            "Save the trained model to a disk file"
            torch.save(model.state_dict(), self.dl_studio.path_saved_model)


        def run_code_for_training_with_TEXTnet_word2vec(self, net, display_train_loss=False):        
            filename_for_out = "performance_numbers_" + str(self.dl_studio.epochs) + ".txt"
            FILE = open(filename_for_out, 'w')
            net = copy.deepcopy(net)
            net = net.to(self.dl_studio.device)
            ## Note that the TEXTnet and TEXTnetOrder2 both produce LogSoftmax output. So we
            ## use nn.NLLLoss. The combined effect of LogSoftMax and NLLLoss is the same as 
            ## for the CrossEntropyLoss
            criterion = nn.NLLLoss()
            accum_times = []
            optimizer = optim.SGD(net.parameters(), 
                         lr=self.dl_studio.learning_rate, momentum=self.dl_studio.momentum)
            start_time = time.perf_counter()
            training_loss_tally = []
            for epoch in range(self.dl_studio.epochs):  
                print("")
                running_loss = 0.0
                for i, data in enumerate(self.train_dataloader):    
                    hidden = net.init_hidden().to(self.dl_studio.device)              
                    review_tensor,category,sentiment = data['review'], data['category'], data['sentiment']
                    review_tensor = review_tensor.to(self.dl_studio.device)
                    sentiment = sentiment.to(self.dl_studio.device)
                    optimizer.zero_grad()
                    input = torch.zeros(1,review_tensor.shape[2]).to(self.dl_studio.device)
                    for k in range(review_tensor.shape[1]):
                        input[0,:] = review_tensor[0,k]
                        output, hidden = net(input, hidden)
                    loss = criterion(output, torch.argmax(sentiment,1))
                    running_loss += loss.item()
                    loss.backward(retain_graph=True)        
                    optimizer.step()
                    if i % 200 == 199:    
                        avg_loss = running_loss / float(200)
                        training_loss_tally.append(avg_loss)
                        running_loss = 0.0
                        current_time = time.perf_counter()
                        time_elapsed = current_time-start_time
                        print("[epoch:%d  iter:%4d  elapsed_time: %4d secs]     loss: %.5f" % (epoch+1,i+1, time_elapsed,avg_loss))
                        accum_times.append(current_time-start_time)
                        FILE.write("%.3f\n" % avg_loss)
                        FILE.flush()
            print("\nFinished Training\n\n")
            self.save_model(net)
            if display_train_loss:
                plt.figure(figsize=(10,5))
                plt.title("Training Loss vs. Iterations")
                plt.plot(training_loss_tally)
                plt.xlabel("iterations")
                plt.ylabel("training loss")
#                plt.legend()
                plt.legend(["Plot of loss versus iterations"], fontsize="x-large")
                plt.savefig("training_loss.png")
                plt.show()


        def run_code_for_training_with_TEXTnetOrder2_word2vec(self, net, display_train_loss=False):        
            filename_for_out = "performance_numbers_" + str(self.dl_studio.epochs) + ".txt"
            FILE = open(filename_for_out, 'w')
            net = copy.deepcopy(net)
            net.to(self.dl_studio.device)
            ## Note that the TEXTnet and TEXTnetOrder2 both produce LogSoftmax output:
            criterion = nn.NLLLoss()
            accum_times = []
            optimizer = optim.SGD(net.parameters(), 
                                       lr=self.dl_studio.learning_rate, momentum=self.dl_studio.momentum)
            start_time = time.perf_counter()
            training_loss_tally = []
            for epoch in range(self.dl_studio.epochs):  
                print("")
                running_loss = 0.0
                for i, data in enumerate(self.train_dataloader):    
                    cell_prev = net.initialize_cell().to(self.dl_studio.device)
                    cell_prev_2_prev = net.initialize_cell().to(self.dl_studio.device)
                    hidden = net.init_hidden().to(self.dl_studio.device)              
                    review_tensor,category,sentiment = data['review'], data['category'], data['sentiment']
                    review_tensor = review_tensor.to(self.dl_studio.device)
                    sentiment = sentiment.to(self.dl_studio.device)
                    optimizer.zero_grad()
                    input = torch.zeros(1,review_tensor.shape[2])
                    input = input.to(self.dl_studio.device)
                    for k in range(review_tensor.shape[1]):
                        input[0,:] = review_tensor[0,k]
                        output, hidden, cell = net(input, hidden, cell_prev_2_prev)
                        if k == 0:
                            cell_prev = cell
                        else:
                            cell_prev_2_prev = cell_prev
                            cell_prev = cell
                    loss = criterion(output, torch.argmax(sentiment,1))
                    running_loss += loss.item()
                    loss.backward()        
                    optimizer.step()
                    if i % 200 == 199:    
                        avg_loss = running_loss / float(200)
                        training_loss_tally.append(avg_loss)
                        current_time = time.perf_counter()
                        time_elapsed = current_time-start_time
                        print("[epoch:%d  iter:%4d  elapsed_time: %4d secs]     loss: %.5f" % (epoch+1,i+1, time_elapsed,avg_loss))
                        accum_times.append(current_time-start_time)
                        FILE.write("%.3f\n" % avg_loss)
                        FILE.flush()
                        running_loss = 0.0
            print("\nFinished Training\n")
            self.save_model(net)
            if display_train_loss:
                plt.figure(figsize=(10,5))
                plt.title("Training Loss vs. Iterations")
                plt.plot(training_loss_tally)
                plt.xlabel("iterations")
                plt.ylabel("training loss")
#                plt.legend()
                plt.legend(["Plot of loss versus iterations"], fontsize="x-large")
                plt.savefig("training_loss.png")
                plt.show()


        def run_code_for_training_for_text_classification_with_GRU_word2vec(self, net, display_train_loss=False): 
            filename_for_out = "performance_numbers_" + str(self.dl_studio.epochs) + ".txt"
            FILE = open(filename_for_out, 'w')
            net = copy.deepcopy(net)
            net = net.to(self.dl_studio.device)
            ##  Note that the GRUnet now produces the LogSoftmax output:
            criterion = nn.NLLLoss()
            accum_times = []
            optimizer = optim.SGD(net.parameters(), 
                         lr=self.dl_studio.learning_rate, momentum=self.dl_studio.momentum)
            training_loss_tally = []
            start_time = time.perf_counter()
            for epoch in range(self.dl_studio.epochs):  
                print("")
                running_loss = 0.0
                for i, data in enumerate(self.train_dataloader):    
                    review_tensor,category,sentiment = data['review'], data['category'], data['sentiment']
                    review_tensor = review_tensor.to(self.dl_studio.device)
                    sentiment = sentiment.to(self.dl_studio.device)
                    ## The following type conversion needed for MSELoss:
                    ##sentiment = sentiment.float()
                    optimizer.zero_grad()
                    hidden = net.init_hidden().to(self.dl_studio.device)
                    for k in range(review_tensor.shape[1]):
                        output, hidden = net(torch.unsqueeze(torch.unsqueeze(review_tensor[0,k],0),0), hidden)
                    loss = criterion(output, torch.argmax(sentiment, 1))
                    running_loss += loss.item()
                    loss.backward()
                    optimizer.step()
                    if i % 200 == 199:    
                        avg_loss = running_loss / float(200)
                        training_loss_tally.append(avg_loss)
                        current_time = time.perf_counter()
                        time_elapsed = current_time-start_time
                        print("[epoch:%d  iter:%4d  elapsed_time:%4d secs]     loss: %.5f" % (epoch+1,i+1, time_elapsed,avg_loss))
                        accum_times.append(current_time-start_time)
                        FILE.write("%.5f\n" % avg_loss)
                        FILE.flush()
                        running_loss = 0.0
            FILE.close()
            self.save_model(net)
            print("Total Training Time: {}".format(str(sum(accum_times))))
            print("\nFinished Training\n\n")
            if display_train_loss:
                plt.figure(figsize=(10,5))
                plt.title("Training Loss vs. Iterations")
                plt.plot(training_loss_tally)
                plt.xlabel("iterations")
                plt.ylabel("training loss")
#                plt.legend()
                plt.legend(["Plot of loss versus iterations"], fontsize="x-large")
                plt.savefig("training_loss.png")
                plt.show()
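
        ##  Illustrative sketch of a driver for the GRU-based classifier, assuming 'text_cl' and the
        ##  dataloaders have been set up as in the dataset sketch shown earlier in this class:
        ##
        ##      model = DLStudio.TextClassificationWithEmbeddings.GRUnetWithEmbeddings(300, 100, 2, num_layers=2)
        ##      text_cl.run_code_for_training_for_text_classification_with_GRU_word2vec(model, display_train_loss=True)
        ##      text_cl.run_code_for_testing_text_classification_with_GRU_word2vec(model)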


        def run_code_for_testing_with_TEXTnet_word2vec(self, net):
            net.load_state_dict(torch.load(self.dl_studio.path_saved_model))
            net.to(self.dl_studio.device)
            classification_accuracy = 0.0
            negative_total = 0
            positive_total = 0
            confusion_matrix = torch.zeros(2,2)
            with torch.no_grad():
                for i, data in enumerate(self.test_dataloader):
                    review_tensor,category,sentiment = data['review'], data['category'], data['sentiment']
                    review_tensor = review_tensor.to(self.dl_studio.device)
                    category      = category.to(self.dl_studio.device)
                    sentiment     = sentiment.to(self.dl_studio.device)
                    input = torch.zeros(1,review_tensor.shape[2]).to(self.dl_studio.device)
                    hidden = net.init_hidden().to(self.dl_studio.device)
                    for k in range(review_tensor.shape[1]):
                        input[0,:] = review_tensor[0,k]
                        output, hidden = net(input, hidden)
                    predicted_idx = torch.argmax(output).item()
                    gt_idx = torch.argmax(sentiment).item()
                    if i % 100 == 99:
                        print("   [i=%4d]    predicted_label=%d       gt_label=%d" % (i+1, predicted_idx,gt_idx))
                    if predicted_idx == gt_idx:
                        classification_accuracy += 1
                    if gt_idx == 0: 
                        negative_total += 1
                    elif gt_idx == 1:
                        positive_total += 1
                    confusion_matrix[gt_idx,predicted_idx] += 1
            print("\nOverall classification accuracy: %0.2f%%" %  (float(classification_accuracy) * 100 /float(i)))
            out_percent = np.zeros((2,2), dtype='float')
            out_percent[0,0] = "%.3f" % (100 * confusion_matrix[0,0] / float(negative_total))
            out_percent[0,1] = "%.3f" % (100 * confusion_matrix[0,1] / float(negative_total))
            out_percent[1,0] = "%.3f" % (100 * confusion_matrix[1,0] / float(positive_total))
            out_percent[1,1] = "%.3f" % (100 * confusion_matrix[1,1] / float(positive_total))
            print("\n\nNumber of positive reviews tested: %d" % positive_total)
            print("\n\nNumber of negative reviews tested: %d" % negative_total)
            print("\n\nDisplaying the confusion matrix:\n")
            out_str = "                      "
            out_str +=  "%18s    %18s" % ('predicted negative', 'predicted positive')
            print(out_str + "\n")
            for i,label in enumerate(['true negative', 'true positive']):
                out_str = "%12s%%:  " % label
                for j in range(2):
                    out_str +=  "%18s%%" % out_percent[i,j]
                print(out_str)
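
        ##  A hypothetical numerical example of the confusion-matrix display produced above: if 100
        ##  negative reviews are tested and 80 of them are predicted as negative, the 'true negative'
        ##  row shows 80% and 20%.  Each row is normalized by the number of test reviews of that
        ##  ground-truth class (negative_total or positive_total).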


        def run_code_for_testing_with_TEXTnetOrder2_word2vec(self, net):
            net.load_state_dict(torch.load(self.dl_studio.path_saved_model))
            net.to(self.dl_studio.device)
            classification_accuracy = 0.0
            negative_total = 0
            positive_total = 0
            confusion_matrix = torch.zeros(2,2)
            with torch.no_grad():
                for i, data in enumerate(self.test_dataloader):
                    cell_prev = net.initialize_cell().to(self.dl_studio.device)
                    cell_prev_2_prev = net.initialize_cell().to(self.dl_studio.device)
                    review_tensor,category,sentiment = data['review'], data['category'], data['sentiment']
                    review_tensor = review_tensor.to(self.dl_studio.device)
                    sentiment     = sentiment.to(self.dl_studio.device)
                    input = torch.zeros(1,review_tensor.shape[2]).to(self.dl_studio.device)
                    hidden = net.init_hidden().to(self.dl_studio.device)
                    for k in range(review_tensor.shape[1]):
                        input[0,:] = review_tensor[0,k]
                        output, hidden, cell = net(input, hidden, cell_prev_2_prev)
                        if k == 0:
                            cell_prev = cell
                        else:
                            cell_prev_2_prev = cell_prev
                            cell_prev = cell
                    predicted_idx = torch.argmax(output).item()
                    gt_idx = torch.argmax(sentiment).item()
                    if i % 100 == 99:
                        print("   [i=%4d]    predicted_label=%d       gt_label=%d" % (i+1, predicted_idx,gt_idx))
                    if predicted_idx == gt_idx:
                        classification_accuracy += 1
                    if gt_idx == 0: 
                        negative_total += 1
                    elif gt_idx == 1:
                        positive_total += 1
                    confusion_matrix[gt_idx,predicted_idx] += 1
            print("\nOverall classification accuracy: %0.2f%%" %  (float(classification_accuracy) * 100 /float(i)))
            out_percent = np.zeros((2,2), dtype='float')
            out_percent[0,0] = "%.3f" % (100 * confusion_matrix[0,0] / float(negative_total))
            out_percent[0,1] = "%.3f" % (100 * confusion_matrix[0,1] / float(negative_total))
            out_percent[1,0] = "%.3f" % (100 * confusion_matrix[1,0] / float(positive_total))
            out_percent[1,1] = "%.3f" % (100 * confusion_matrix[1,1] / float(positive_total))
            print("\n\nNumber of positive reviews tested: %d" % positive_total)
            print("\n\nNumber of negative reviews tested: %d" % negative_total)
            print("\n\nDisplaying the confusion matrix:\n")
            out_str = "                      "
            out_str +=  "%18s    %18s" % ('predicted negative', 'predicted positive')
            print(out_str + "\n")
            for i,label in enumerate(['true negative', 'true positive']):
                out_str = "%12s:  " % label
                for j in range(2):
                    out_str +=  "%18s%%" % out_percent[i,j]
                print(out_str)


        def run_code_for_testing_text_classification_with_GRU_word2vec(self, net):
            net.load_state_dict(torch.load(self.dl_studio.path_saved_model))
            net.to(self.dl_studio.device)
            classification_accuracy = 0.0
            negative_total = 0
            positive_total = 0
            confusion_matrix = torch.zeros(2,2)
            with torch.no_grad():
                for i, data in enumerate(self.test_dataloader):
                    review_tensor,category,sentiment = data['review'], data['category'], data['sentiment']
                    review_tensor = review_tensor.to(self.dl_studio.device)
                    hidden = net.init_hidden().to(self.dl_studio.device)
                    for k in range(review_tensor.shape[1]):
                        output, hidden = net(torch.unsqueeze(torch.unsqueeze(review_tensor[0,k],0),0), hidden)
                    predicted_idx = torch.argmax(output).item()
                    gt_idx = torch.argmax(sentiment).item()
                    if i % 100 == 99:
                        print("   [i=%d]    predicted_label=%d       gt_label=%d" % (i+1, predicted_idx,gt_idx))
                    if predicted_idx == gt_idx:
                        classification_accuracy += 1
                    if gt_idx == 0: 
                        negative_total += 1
                    elif gt_idx == 1:
                        positive_total += 1
                    confusion_matrix[gt_idx,predicted_idx] += 1
            print("\nOverall classification accuracy: %0.2f%%" %  (float(classification_accuracy) * 100 /float(i)))
            out_percent = np.zeros((2,2), dtype='float')
            out_percent[0,0] = "%.3f" % (100 * confusion_matrix[0,0] / float(negative_total))
            out_percent[0,1] = "%.3f" % (100 * confusion_matrix[0,1] / float(negative_total))
            out_percent[1,0] = "%.3f" % (100 * confusion_matrix[1,0] / float(positive_total))
            out_percent[1,1] = "%.3f" % (100 * confusion_matrix[1,1] / float(positive_total))
            print("\n\nNumber of positive reviews tested: %d" % positive_total)
            print("\n\nNumber of negative reviews tested: %d" % negative_total)
            print("\n\nDisplaying the confusion matrix:\n")
            out_str = "                      "
            out_str +=  "%18s    %18s" % ('predicted negative', 'predicted positive')
            print(out_str + "\n")
            for i,label in enumerate(['true negative', 'true positive']):
                out_str = "%12s:  " % label
                for j in range(2):
                    out_str +=  "%18s%%" % out_percent[i,j]
                print(out_str)


#_________________________  End of DLStudio Class Definition ___________________________

#______________________________    Test code follows    _________________________________

if __name__ == '__main__': 
    pass