# -*- coding: utf-8 -*-
__version__ = '2.0.6'
__author__ = "Avinash Kak (kak@purdue.edu)"
__date__ = '2021-March-17'
__url__ = 'https://engineering.purdue.edu/kak/distDLS/DLStudio-2.0.6.html'
__copyright__ = "(C) 2021 Avinash Kak. Python Software Foundation."
import sys,os,os.path
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as tvt
import torch.optim as optim
import numpy as np
from PIL import ImageFilter
import numbers
import re
import math
import random
import copy
import matplotlib.pyplot as plt
import gzip
import pickle
import pymsgbox
import time
import logging
#______________________________ DLStudio Class Definition ________________________________
class DLStudio(object):
def __init__(self, *args, **kwargs ):
if args:
raise ValueError(
'''DLStudio constructor can only be called with keyword arguments for
the following keywords: epochs, learning_rate, batch_size, momentum,
convo_layers_config, image_size, dataroot, path_saved_model, classes,
image_size, convo_layers_config, fc_layers_config, debug_train, use_gpu, and
debug_test''')
learning_rate = epochs = batch_size = convo_layers_config = momentum = None
image_size = fc_layers_config = dataroot = path_saved_model = classes = use_gpu = None
debug_train = debug_test = None
if 'dataroot' in kwargs : dataroot = kwargs.pop('dataroot')
if 'learning_rate' in kwargs : learning_rate = kwargs.pop('learning_rate')
if 'momentum' in kwargs : momentum = kwargs.pop('momentum')
if 'epochs' in kwargs : epochs = kwargs.pop('epochs')
if 'batch_size' in kwargs : batch_size = kwargs.pop('batch_size')
if 'convo_layers_config' in kwargs : convo_layers_config = kwargs.pop('convo_layers_config')
if 'image_size' in kwargs : image_size = kwargs.pop('image_size')
if 'fc_layers_config' in kwargs : fc_layers_config = kwargs.pop('fc_layers_config')
if 'path_saved_model' in kwargs : path_saved_model = kwargs.pop('path_saved_model')
if 'classes' in kwargs : classes = kwargs.pop('classes')
if 'use_gpu' in kwargs : use_gpu = kwargs.pop('use_gpu')
if 'debug_train' in kwargs : debug_train = kwargs.pop('debug_train')
if 'debug_test' in kwargs : debug_test = kwargs.pop('debug_test')
if len(kwargs) != 0: raise ValueError('''You have provided unrecognizable keyword args''')
if dataroot:
self.dataroot = dataroot
if convo_layers_config:
self.convo_layers_config = convo_layers_config
if image_size:
self.image_size = image_size
if fc_layers_config:
self.fc_layers_config = fc_layers_config
# if fc_layers_config[0] is not -1:
if fc_layers_config[0] != -1:
raise Exception("""\n\n\nYour 'fc_layers_config' construction option is not correct. """
"""The first element of the list of nodes in the fc layer must be -1 """
"""because the input to fc will be set automatically to the size of """
"""the final activation volume of the convolutional part of the network""")
if path_saved_model:
self.path_saved_model = path_saved_model
if classes:
self.class_labels = classes
if learning_rate:
self.learning_rate = learning_rate
else:
self.learning_rate = 1e-6
if momentum:
self.momentum = momentum
if epochs:
self.epochs = epochs
if batch_size:
self.batch_size = batch_size
if use_gpu is not None:
self.use_gpu = use_gpu
if use_gpu is True:
if torch.cuda.is_available():
self.device = torch.device("cuda:0")
else:
raise Exception("You requested GPU support, but there's no GPU on this machine")
else:
self.device = torch.device("cpu")
if debug_train:
self.debug_train = debug_train
else:
self.debug_train = 0
if debug_test:
self.debug_test = debug_test
else:
self.debug_test = 0
self.debug_config = 0
# self.device = torch.device("cuda:0" if torch.cuda.is_available() and self.use_gpu is False else "cpu")
def parse_config_string_for_convo_layers(self):
'''
Each collection of 'n' otherwise identical layers in a convolutional network is
specified by a string that looks like:
"nx[a,b,c,d]-MaxPool(k)"
where
n = num of this type of convo layer
a = number of out_channels [in_channels determined by prev layer]
b,c = kernel for this layer is of size (b,c) [b along height, c along width]
d = stride for convolutions
k = maxpooling over kxk patches with stride of k
Example:
"n1x[a1,b1,c1,d1]-MaxPool(k1) n2x[a2,b2,c2,d2]-MaxPool(k2)"
'''
configuration = self.convo_layers_config
configs = configuration.split()
all_convo_layers = []
image_size_after_layer = self.image_size
for k,config in enumerate(configs):
two_parts = config.split('-')
how_many_conv_layers_with_this_config = int(two_parts[0][:config.index('x')])
if self.debug_config:
print("\n\nhow many convo layers with this config: %d" % how_many_conv_layers_with_this_config)
maxpooling_size = int(re.findall(r'\d+', two_parts[1])[0])
if self.debug_config:
print("\nmax pooling size for all convo layers with this config: %d" % maxpooling_size)
for conv_layer in range(how_many_conv_layers_with_this_config):
convo_layer = {'out_channels':None,
'kernel_size':None,
'convo_stride':None,
'maxpool_size':None,
'maxpool_stride': None}
kernel_params = two_parts[0][config.index('x')+1:][1:-1].split(',')
if self.debug_config:
print("\nkernel_params: %s" % str(kernel_params))
convo_layer['out_channels'] = int(kernel_params[0])
convo_layer['kernel_size'] = (int(kernel_params[1]), int(kernel_params[2]))
convo_layer['convo_stride'] = int(kernel_params[3])
image_size_after_layer = [x // convo_layer['convo_stride'] for x in image_size_after_layer]
convo_layer['maxpool_size'] = maxpooling_size
convo_layer['maxpool_stride'] = maxpooling_size
image_size_after_layer = [x // convo_layer['maxpool_size'] for x in image_size_after_layer]
all_convo_layers.append(convo_layer)
configs_for_all_convo_layers = {i : all_convo_layers[i] for i in range(len(all_convo_layers))}
if self.debug_config:
print("\n\nAll convo layers: %s" % str(configs_for_all_convo_layers))
last_convo_layer = configs_for_all_convo_layers[len(all_convo_layers)-1]
out_nodes_final_layer = image_size_after_layer[0] * image_size_after_layer[1] * \
last_convo_layer['out_channels']
self.fc_layers_config[0] = out_nodes_final_layer
self.configs_for_all_convo_layers = configs_for_all_convo_layers
return configs_for_all_convo_layers
def build_convo_layers(self, configs_for_all_convo_layers):
conv_layers = nn.ModuleList()
in_channels_for_next_layer = None
for layer_index in configs_for_all_convo_layers:
if self.debug_config:
print("\n\n\nLayer index: %d" % layer_index)
in_channels = 3 if layer_index == 0 else in_channels_for_next_layer
out_channels = configs_for_all_convo_layers[layer_index]['out_channels']
kernel_size = configs_for_all_convo_layers[layer_index]['kernel_size']
padding = tuple((k-1) // 2 for k in kernel_size)
stride = configs_for_all_convo_layers[layer_index]['convo_stride']
maxpool_size = configs_for_all_convo_layers[layer_index]['maxpool_size']
if self.debug_config:
print("\n in_channels=%d out_channels=%d kernel_size=%s stride=%s \
maxpool_size=%s" % (in_channels, out_channels, str(kernel_size), str(stride),
str(maxpool_size)))
conv_layers.append( nn.Conv2d( in_channels,out_channels,kernel_size,stride=stride,padding=padding) )
conv_layers.append( nn.MaxPool2d( maxpool_size ) )
conv_layers.append( nn.ReLU() ),
in_channels_for_next_layer = out_channels
return conv_layers
def build_fc_layers(self):
fc_layers = nn.ModuleList()
for layer_index in range(len(self.fc_layers_config) - 1):
fc_layers.append( nn.Linear( self.fc_layers_config[layer_index],
self.fc_layers_config[layer_index+1] ) )
return fc_layers
def load_cifar_10_dataset(self):
    '''
    Create the data loaders for the training and the test portions of CIFAR-10 and
    store them in self.train_data_loader and self.test_data_loader.  The dataset is
    looked for (and downloaded to, if needed) in self.dataroot.

    The transformation applied to the images ends with the images being normalized.
    In the call "Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))", the three numbers in the
    first tuple are the per-channel means and the three numbers in the second tuple the
    per-channel standard deviations.  Each channel value is changed to:

                 image_channel_val = (image_channel_val - mean) / std

    ToTensor() first converts the usual 0-255 integer pixel values to floats between
    0 and 1.0.  With the mean and the std both set to 0.5, the normalization then
    changes that range to -1.0 to +1.0.  If needed, the inverse normalization is

                 image_channel_val  =   (image_channel_val * std) + mean
    '''
    channel_means = (0.5, 0.5, 0.5)
    channel_stds = (0.5, 0.5, 0.5)
    to_normalized_tensor = tvt.Compose([tvt.ToTensor(),
                                        tvt.Normalize(channel_means, channel_stds)])
    ## The training and the test portions of the dataset:
    training_set = torchvision.datasets.CIFAR10(root=self.dataroot, train=True,
                                                download=True, transform=to_normalized_tensor)
    testing_set = torchvision.datasets.CIFAR10(root=self.dataroot, train=False,
                                               download=True, transform=to_normalized_tensor)
    ## Only the training data is shuffled:
    self.train_data_loader = torch.utils.data.DataLoader(
        training_set, batch_size=self.batch_size, shuffle=True, num_workers=2)
    self.test_data_loader = torch.utils.data.DataLoader(
        testing_set, batch_size=self.batch_size, shuffle=False, num_workers=2)
def load_cifar_10_dataset_with_augmentation(self):
    '''
    Same as load_cifar_10_dataset() except that the training images are augmented:
    each training image is randomly cropped to 32x32 after being padded by 4 pixels
    and is randomly flipped horizontally.  The test images are only converted to
    tensors and normalized.  The loaders are stored in self.train_data_loader and
    self.test_data_loader.
    '''
    normalizer = tvt.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    augmenting_steps = [tvt.RandomCrop(32, padding=4), tvt.RandomHorizontalFlip()]
    transform_train = tvt.Compose(augmenting_steps + [tvt.ToTensor(), normalizer])
    ## No augmentation for the test data:
    transform_test = tvt.Compose([tvt.ToTensor(), normalizer])
    training_set = torchvision.datasets.CIFAR10(root=self.dataroot, train=True,
                                                download=True, transform=transform_train)
    testing_set = torchvision.datasets.CIFAR10(root=self.dataroot, train=False,
                                               download=True, transform=transform_test)
    loader_options = dict(batch_size=self.batch_size, num_workers=2)
    self.train_data_loader = torch.utils.data.DataLoader(training_set, shuffle=True,
                                                         **loader_options)
    self.test_data_loader = torch.utils.data.DataLoader(testing_set, shuffle=False,
                                                        **loader_options)
def imshow(self, img):
    '''
    Display a tensor as an image.  The tensor is first mapped through "img/2 + 0.5"
    (the inverse of a normalization with mean 0.5 and std 0.5) and its first axis
    is moved to the last position before the call to plt.imshow().
    '''
    unnormalized = img / 2 + 0.5
    channels_last = np.transpose(unnormalized.numpy(), (1, 2, 0))
    plt.imshow(channels_last)
    plt.show()
class Net(nn.Module):
    """
    A network assembled from two ready-made sequences of modules: all the modules
    in 'convo_layers' are applied in order, the result is flattened to one row per
    batch element, and then all the modules in 'fc_layers' are applied in order.
    """
    def __init__(self, convo_layers, fc_layers):
        super(DLStudio.Net, self).__init__()
        self.my_modules_convo = convo_layers
        self.my_modules_fc = fc_layers

    def forward(self, x):
        for convo_module in self.my_modules_convo:
            x = convo_module(x)
        ## Keep the batch axis and collapse everything else:
        batch_size = x.size(0)
        x = x.view(batch_size, -1)
        for fc_module in self.my_modules_fc:
            x = fc_module(x)
        return x
def run_code_for_training(self, net, display_images=False):
    """
    Train a deep copy of 'net' with SGD and the cross-entropy loss on the batches
    supplied by self.train_data_loader for self.epochs epochs.

    The network passed in by the caller is NOT modified: the copy is what gets
    trained and what is handed to self.save_model() at the end.

    Every 1000 iterations the method prints the ground-truth and the predicted
    labels for the current batch along with the loss averaged over those 1000
    iterations.  The same average is appended to "performance_numbers_<epochs>.txt".
    After training, a plot of the averaged losses is saved to
    "playing_with_skips_loss.png" and shown.

    Parameters:
        net:             the network to train (an nn.Module)
        display_images:  when True, the current batch of images is also displayed
                         at each reporting step
    """
    report_interval = 1000
    filename_for_out = "performance_numbers_" + str(self.epochs) + ".txt"
    net = copy.deepcopy(net)
    net = net.to(self.device)
    ## run_code_for_testing() switches the network it is given to the eval mode.
    ## If the same network object is later passed in here, the copy would otherwise
    ## be trained with its dropout and batch-norm layers in the eval mode:
    net.train()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=self.learning_rate, momentum=self.momentum)
    print("\n\nStarting training loop...")
    start_time = time.perf_counter()
    loss_tally = []
    with open(filename_for_out, 'w') as FILE:
        for epoch in range(self.epochs):
            print("")
            running_loss = 0.0
            for i, data in enumerate(self.train_data_loader):
                inputs, labels = data
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)
                ## Since PyTorch likes to construct dynamic computational graphs, we need to
                ## zero out the previously calculated gradients for the learnable parameters:
                optimizer.zero_grad()
                outputs = net(inputs)
                loss = criterion(outputs, labels)
                running_loss += loss.item()
                if i % report_interval == report_interval - 1:
                    elapsed_time = time.perf_counter() - start_time
                    progress = "[epoch:%d/%d iter=%4d elapsed_time=%5d secs] " % (
                        epoch + 1, self.epochs, i + 1, elapsed_time)
                    _, predicted = torch.max(outputs.data, 1)
                    ## Iterate over the labels actually present in the batch since the
                    ## last batch of an epoch may be smaller than self.batch_size:
                    print("\n\n" + progress + "Ground Truth: " +
                          ' '.join('%10s' % self.class_labels[int(label)] for label in labels))
                    print(progress + "Predicted Labels: " +
                          ' '.join('%10s' % self.class_labels[int(label)] for label in predicted))
                    ## running_loss is the sum over the last 'report_interval' iterations:
                    avg_loss = running_loss / float(report_interval)
                    loss_tally.append(avg_loss)
                    print(progress + "Loss: %.3f" % avg_loss)
                    FILE.write("%.3f\n" % avg_loss)
                    FILE.flush()
                    running_loss = 0.0
                    if display_images:
                        ## Raise the root logger's level while plotting and then restore it:
                        logger = logging.getLogger()
                        old_level = logger.level
                        logger.setLevel(100)
                        try:
                            plt.figure(figsize=[6, 3])
                            plt.imshow(np.transpose(torchvision.utils.make_grid(
                                inputs, normalize=False, padding=3, pad_value=255).cpu(), (1, 2, 0)))
                            plt.show()
                        finally:
                            logger.setLevel(old_level)
                loss.backward()
                optimizer.step()
    print("\nFinished Training\n")
    self.save_model(net)
    plt.figure(figsize=(10, 5))
    plt.title("Labeling Loss vs. Iterations")
    ## The label gives plt.legend() an entry to show:
    plt.plot(loss_tally, label="training loss")
    plt.xlabel("iterations")
    plt.ylabel("loss")
    plt.legend()
    plt.savefig("playing_with_skips_loss.png")
    plt.show()
def display_tensor_as_image(self, tensor, title=""):
    '''
    This method converts the argument tensor into a photo image that you can display
    in your terminal screen. It can convert tensors of three different shapes
    into images: (3,H,W), (1,H,W), and (H,W), where H, for height, stands for the
    number of pixels in the vertical direction and W, for width, for the same
    along the horizontal direction.  When the first element of the shape is 3,
    that means that the tensor represents a color image in which each pixel in
    the (H,W) plane has three values for the three color channels.  On the other
    hand, when the first element is 1, that stands for a tensor that will be
    shown as a grayscale image.  And when the shape is just (H,W), that is
    automatically taken to be for a grayscale image.

    A tensor whose values lie within [-1.0, 1.0] and whose minimum is negative is
    taken to be normalized and is mapped back to [0.0, 1.0] before it is displayed.
    A tensor of any other shape terminates the program through sys.exit().
    '''
    ## plt.imshow() needs a numpy array, and a tensor can only be converted to one
    ## if it lives on the CPU and is detached from the computational graph:
    tensor = tensor.detach().cpu()
    lowest, highest = torch.min(tensor).item(), torch.max(tensor).item()
    if -1.0 <= lowest < 0.0 and highest <= 1.0:
        ## The tensors must be between 0.0 and 1.0 for the display:
        print("\n\n\nimage un-normalization called")
        tensor = tensor / 2.0 + 0.5
    num_axes = len(tensor.shape)
    if num_axes == 3 and tensor.shape[0] == 3:
        ## For a color image, the channel axis must be moved from the first to the
        ## last position:
        pixels, color_map = tensor.numpy().transpose(1, 2, 0), None
    elif num_axes == 3 and tensor.shape[0] == 1:
        ## A grayscale image that carries an explicit channel axis of size 1:
        pixels, color_map = tensor[0, :, :].numpy(), 'gray'
    elif num_axes == 2:
        ## A single plane of shape (H,W), for example one color channel:
        pixels, color_map = tensor.numpy(), 'gray'
    else:
        sys.exit("\n\n\nfrom 'display_tensor_as_image()': tensor for image is ill formed -- aborting")
    ## The figure is created only after the shape has been found to be acceptable:
    plt.figure(title)
    if color_map is None:
        plt.imshow(pixels)
    else:
        plt.imshow(pixels, cmap=color_map)
    plt.show()
def check_a_sampling_of_images(self):
    '''
    Displays the images in the first batch supplied by self.train_data_loader and
    prints their class labels.
    '''
    ## The builtin next() works with every iterator.  The ".next()" method is not
    ## part of the iterator protocol and is missing from the loader iterators of the
    ## more recent versions of PyTorch:
    images, labels = next(iter(self.train_data_loader))
    # Since negative pixel values make no sense for display, setting the 'normalize'
    # option to True will change the range back from (-1.0,1.0) to (0.0,1.0):
    self.display_tensor_as_image(torchvision.utils.make_grid(images, normalize=True))
    # Print class labels for the images shown.  The labels present in the batch are
    # used since a batch may hold fewer than self.batch_size images:
    print(' '.join('%5s' % self.class_labels[int(label)] for label in labels))
def save_model(self, model):
'''
Save the trained model to a disk file
'''
torch.save(model.state_dict(), self.path_saved_model)
def run_code_for_testing(self, net, display_images=False):
net.load_state_dict(torch.load(self.path_saved_model))
net = net.eval()
net = net.to(self.device)
## In what follows, in addition to determining the predicted label for each test
## image, we will also compute some stats to measure the overall performance of
## the trained network. This we will do in two different ways: For each class,
## we will measure how frequently the network predicts the correct labels. In
## addition, we will compute the confusion matrix for the predictions.
filename_for_results = "classification_results_" + str(self.epochs) + ".txt"
FILE = open(filename_for_results, 'w')
correct = 0
total = 0
confusion_matrix = torch.zeros(len(self.class_labels), len(self.class_labels))
class_correct = [0] * len(self.class_labels)
class_total = [0] * len(self.class_labels)
with torch.no_grad():
for i,data in enumerate(self.test_data_loader):
## data is set to the images and the labels for one batch at a time:
images, labels = data
images = images.to(self.device)
labels = labels.to(self.device)
if i % 1000 == 999:
print("\n\n[i=%d:] Ground Truth: " % (i+1) + ' '.join('%5s' % self.class_labels[labels[j]]
for j in range(self.batch_size)))
outputs = net(images)
## max() returns two things: the max value and its index in the 10 element
## output vector. We are only interested in the index --- since that is
## essentially the predicted class label:
_, predicted = torch.max(outputs.data, 1)#
# if display_images and i % 1000 == 999:
if i % 1000 == 999:
print("[i=%d:] Predicted Labels: " % (i+1) + ' '.join('%5s' % self.class_labels[predicted[j]]
for j in range(self.batch_size)))
logger = logging.getLogger()
old_level = logger.level
if display_images:
logger.setLevel(100)
plt.figure(figsize=[6,3])
plt.imshow(np.transpose(torchvision.utils.make_grid(images,
normalize=False, padding=3, pad_value=255).cpu(), (1,2,0)))
plt.show()
logger.setLevel(old_level)
for label,prediction in zip(labels,predicted):
confusion_matrix[label][prediction] += 1
total += labels.size(0)
correct += (predicted == labels).sum().item()
## comp is a list of size batch_size of "True" and "False" vals
comp = predicted == labels
for j in range(self.batch_size):
label = labels[j]
## The following works because, in a numeric context, the boolean value
## "False" is the same as number 0 and the boolean value True is the
## same as number 1. For that reason "4 + True" will evaluate to 5 and
## "4 + False" will evaluate to 4. Also, "1 == True" evaluates to "True"
## "1 == False" evaluates to "False". However, note that "1 is True"
## evaluates to "False" because the operator "is" does not provide a
## numeric context for "True". And so on. In the statement that follows,
## while c[j].item() will either return "False" or "True", for the
## addition operator, Python will use the values 0 and 1 instead.
class_correct[label] += comp[j].item()
class_total[label] += 1
for j in range(len(self.class_labels)):
print('Prediction accuracy for %5s : %2d %%' % (self.class_labels[j], 100 * class_correct[j] / class_total[j]))
FILE.write('\n\nPrediction accuracy for %5s : %2d %%\n' % (self.class_labels[j], 100 * class_correct[j] / class_total[j]))
print("\n\n\nOverall accuracy of the network on the 10000 test images: %d %%" % (100 * correct / float(total)))
FILE.write("\n\n\nOverall accuracy of the network on the 10000 test images: %d %%\n" % (100 * correct / float(total)))
print("\n\nDisplaying the confusion matrix:\n")
FILE.write("\n\nDisplaying the confusion matrix:\n\n")
out_str = " "
for j in range(len(self.class_labels)): out_str += "%7s" % self.class_labels[j]
print(out_str + "\n")
FILE.write(out_str + "\n\n")
for i,label in enumerate(self.class_labels):
out_percents = [100 * confusion_matrix[i,j] / float(class_total[i])
for j in range(len(self.class_labels))]
out_percents = ["%.2f" % item.item() for item in out_percents]
out_str = "%6s: " % self.class_labels[i]
for j in range(len(self.class_labels)): out_str += "%7s" % out_percents[j]
print(out_str)
FILE.write(out_str + "\n")
FILE.close()
########################################################################################
############### Start Definition of Inner Class ExperimentsWithSequential #############
class ExperimentsWithSequential(nn.Module):
    """
    Demonstrates how to use the torch.nn.Sequential container class

    An instance wraps a DLStudio instance ('dl_studio') and forwards the data
    loading, training, and testing calls to it.  The run_code_for_training() and
    run_code_for_testing() wrappers accept the optional 'display_images' argument,
    the same as the corresponding wrappers in the other experiment classes.
    """
    def __init__(self, dl_studio):
        super().__init__()
        self.dl_studio = dl_studio

    def load_cifar_10_dataset(self):
        self.dl_studio.load_cifar_10_dataset()

    def load_cifar_10_dataset_with_augmentation(self):
        self.dl_studio.load_cifar_10_dataset_with_augmentation()

    class Net(nn.Module):
        """
        To see if the DLStudio class would work with any network that a user may want
        to experiment with, I copy-and-pasted the network shown below from the following
        page by Zhenye at GitHub:
                  https://zhenye-na.github.io/2018/09/28/pytorch-cnn-cifar10.html

        The first fully-connected layer expects 4096 inputs.  That is what the
        convolutional part produces for 32x32 images: 256 channels of size 4x4
        after the three 2x2 max-poolings.
        """
        def __init__(self):
            super().__init__()
            self.conv_seqn = nn.Sequential(
                # Conv Layer block 1:
                nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1),
                nn.BatchNorm2d(32),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2),
                # Conv Layer block 2:
                nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
                nn.BatchNorm2d(128),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2),
                nn.Dropout2d(p=0.05),
                # Conv Layer block 3:
                nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1),
                nn.BatchNorm2d(256),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2),
            )
            self.fc_seqn = nn.Sequential(
                nn.Dropout(p=0.1),
                nn.Linear(4096, 1024),
                nn.ReLU(inplace=True),
                nn.Linear(1024, 512),
                nn.ReLU(inplace=True),
                nn.Dropout(p=0.1),
                nn.Linear(512, 10)
            )

        def forward(self, x):
            x = self.conv_seqn(x)
            # flatten
            x = x.view(x.size(0), -1)
            x = self.fc_seqn(x)
            return x

    def run_code_for_training(self, net, display_images=False):
        self.dl_studio.run_code_for_training(net, display_images)

    def save_model(self, model):
        '''
        Save the trained model to a disk file
        '''
        torch.save(model.state_dict(), self.dl_studio.path_saved_model)

    def run_code_for_testing(self, model, display_images=False):
        self.dl_studio.run_code_for_testing(model, display_images)
########################################################################################
################## Start Definition of Inner Class ExperimentsWithCIFAR ###############
class ExperimentsWithCIFAR(nn.Module):
    """
    A wrapper around a DLStudio instance ('dl_studio') for experimenting with the
    CIFAR-10 dataset.  The data loading, training, and testing calls are forwarded
    to the wrapped instance.  Two networks, Net and Net2, are provided.
    """
    def __init__(self, dl_studio):
        super().__init__()
        self.dl_studio = dl_studio

    def load_cifar_10_dataset(self):
        self.dl_studio.load_cifar_10_dataset()

    def load_cifar_10_dataset_with_augmentation(self):
        self.dl_studio.load_cifar_10_dataset_with_augmentation()

    ##  You can instantiate two different types of networks when experimenting with
    ##  the inner class ExperimentsWithCIFAR.  The network shown below is from the
    ##  PyTorch tutorial
    ##
    ##     https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
    ##
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(3, 6, 5)
            self.conv2 = nn.Conv2d(6, 16, 5)
            self.fc1 = nn.Linear(16 * 5 * 5, 120)
            self.fc2 = nn.Linear(120, 84)
            self.fc3 = nn.Linear(84, 10)

        def forward(self, x):
            ## The functional form of max-pooling has no parameters and avoids
            ## constructing new pooling modules on every call:
            x = F.max_pool2d(F.relu(self.conv1(x)), 2, 2)
            x = F.max_pool2d(F.relu(self.conv2(x)), 2, 2)
            x = x.view(-1, 16 * 5 * 5)
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)
            return x

    ##  Instead of using the network shown above, you can also use the network shown below
    ##  if you are playing with the ExperimentsWithCIFAR inner class. If that's what you
    ##  want to do, in the script "playing_with_cifar10.py" in the Examples directory,
    ##  you will need to replace the statement
    ##                          model = exp_cifar.Net()
    ##  by the statement
    ##                          model = exp_cifar.Net2()
    ##
    class Net2(nn.Module):
        def __init__(self):
            """
            I created this network class just to see if it was possible to simply calculate
            the size of the first of the fully connected layers from strides in the convo
            layers up to that point and from the out_channels used in the top-most convo
            layer.   In what you see below, I am keeping track of all the strides by pushing
            them into the array 'strides'.  Subsequently, in the formula shown in line (A),
            I use the product of all strides and the number of out_channels for the topmost
            layer to compute the size of the first fully-connected layer.

            The strides of conv3 and pool3 are all 1 and are therefore left out of
            'strides'.  The formula in line (A) is for 32x32 input images.
            """
            super().__init__()
            self.relu = nn.ReLU()
            strides = []
            patch_size = 2
            ## conv1:
            out_ch, ker_size, conv_stride, pool_stride = 128, 5, 1, 2
            self.conv1 = nn.Conv2d(3, out_ch, (ker_size, ker_size), stride=conv_stride,
                                   padding=(ker_size - 1) // 2)
            self.pool1 = nn.MaxPool2d(patch_size, pool_stride)
            strides += (conv_stride, pool_stride)
            ## conv2:
            in_ch = out_ch
            out_ch, ker_size, conv_stride, pool_stride = 128, 3, 1, 2
            self.conv2 = nn.Conv2d(in_ch, out_ch, ker_size, stride=conv_stride,
                                   padding=(ker_size - 1) // 2)
            self.pool2 = nn.MaxPool2d(patch_size, pool_stride)
            strides += (conv_stride, pool_stride)
            ## conv3:
            ## meant for repeated invocation, must have same in_ch, out_ch and strides of 1
            ## The 2x2 kernel with a padding of 1 grows each side by one pixel and the
            ## 2x2 pooling with a stride of 1 shrinks it back by one pixel:
            in_ch = out_ch
            out_ch, ker_size, conv_stride, pool_stride = in_ch, 2, 1, 1
            self.conv3 = nn.Conv2d(in_ch, out_ch, ker_size, stride=conv_stride, padding=1)
            self.pool3 = nn.MaxPool2d(patch_size, pool_stride)
            ## figure out the number of nodes needed for entry into fc.  The int() call
            ## converts the numpy integer returned by np.prod() to a plain Python int:
            in_size_for_fc = int(out_ch * (32 // np.prod(strides)) ** 2)         ## (A)
            self.in_size_for_fc = in_size_for_fc
            self.fc1 = nn.Linear(in_size_for_fc, 150)
            self.fc2 = nn.Linear(150, 100)
            self.fc3 = nn.Linear(100, 10)

        def forward(self, x):
            ##  forward() expects x to be shaped (B,3,32,32) where B is the batch size,
            ##  3 the number of in_channels, and where the input image size is 32x32.
            x = self.relu(self.conv1(x))
            x = self.pool1(x)
            x = self.relu(self.conv2(x))
            x = self.pool2(x)
            for _ in range(5):
                x = self.pool3(self.relu(self.conv3(x)))
            x = x.view(-1, self.in_size_for_fc)
            x = self.relu(self.fc1(x))
            x = self.relu(self.fc2(x))
            x = self.fc3(x)
            return x

    def run_code_for_training(self, net, display_images=False):
        self.dl_studio.run_code_for_training(net, display_images)

    def save_model(self, model):
        '''
        Save the trained model to a disk file
        '''
        torch.save(model.state_dict(), self.dl_studio.path_saved_model)

    def run_code_for_testing(self, model, display_images=False):
        self.dl_studio.run_code_for_testing(model, display_images)
###%%%
########################################################################################
################### Start Definition of Inner Class SkipConnections ##################
class SkipConnections(nn.Module):
    """
    This educational class is meant for illustrating the concepts related to the
    use of skip connections in neural network.  It is now well known that deep
    networks are difficult to train because of the vanishing gradients problem.
    What that means is that as the depth of network increases, the loss gradients
    calculated for the early layers become more and more muted, which suppresses
    the learning of the parameters in those layers.  An important mitigation
    strategy for addressing this problem consists of creating a CNN using blocks
    with skip connections.

    With the code shown in this inner class of the module, you can now experiment
    with skip connections in a CNN to see how a deep network with this feature
    might improve the classification results.  As you will see in the code shown
    below, the network that allows you to construct a CNN with skip connections
    is named BMEnet.  As shown in the script playing_with_skip_connections.py in
    the Examples directory of the distribution, you can easily create a CNN with
    arbitrary depth just by using the "depth" constructor option for the BMEnet
    class.  The basic block of the network constructed by BMEnet is called
    SkipBlock which, very much like the BasicBlock in ResNet-18, has a couple of
    convolutional layers whose output is combined with the input to the block.

    Note that the value given to the "depth" constructor option for the
    BMEnet class does NOT translate directly into the actual depth of the
    CNN. [Again, see the script playing_with_skip_connections.py in the Examples
    directory for how to use this option.] The value of "depth" is translated
    into how many instances of SkipBlock to use for constructing the CNN.
    """
    def __init__(self, dl_studio):
        super().__init__()
        self.dl_studio = dl_studio

    def load_cifar_10_dataset(self):
        self.dl_studio.load_cifar_10_dataset()

    def load_cifar_10_dataset_with_augmentation(self):
        self.dl_studio.load_cifar_10_dataset_with_augmentation()

    class SkipBlock(nn.Module):
        """
        in inner class of DLStudio:   SkipConnections

        A block of one or two 3x3 convolutions, each followed by batch-norm and
        ReLU, whose output is optionally added to the input of the block.  The
        second convolution is applied only when in_ch equals out_ch.  When in_ch
        differs from out_ch, two copies of the input are stacked along the channel
        axis before the addition, which requires out_ch to be twice in_ch.  With
        'downsample' set, a 1x1 convolution with a stride of 2 is applied to both
        the output and the skip path.
        """
        def __init__(self, in_ch, out_ch, downsample=False, skip_connections=True):
            super().__init__()
            self.downsample = downsample
            self.skip_connections = skip_connections
            self.in_ch = in_ch
            self.out_ch = out_ch
            self.convo1 = nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1)
            self.convo2 = nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1)
            self.bn1 = nn.BatchNorm2d(out_ch)
            self.bn2 = nn.BatchNorm2d(out_ch)
            if downsample:
                self.downsampler = nn.Conv2d(in_ch, out_ch, 1, stride=2)

        def forward(self, x):
            identity = x
            out = self.convo1(x)
            out = self.bn1(out)
            out = torch.nn.functional.relu(out)
            if self.in_ch == self.out_ch:
                out = self.convo2(out)
                out = self.bn2(out)
                out = torch.nn.functional.relu(out)
            if self.downsample:
                out = self.downsampler(out)
                identity = self.downsampler(identity)
            if self.skip_connections:
                ## The additions are deliberately NOT done in place.  The output of
                ## a ReLU is needed for calculating its gradient, and modifying it
                ## in place makes the backward pass fail:
                if self.in_ch == self.out_ch:
                    out = out + identity
                else:
                    out = out + torch.cat((identity, identity), dim=1)
            return out

    class BMEnet(nn.Module):
        """
        in inner class of DLStudio:   SkipConnections

        A CNN for 32x32 color images built from SkipBlock instances.  The 'depth'
        option must be one of 8, 16, 32 or 64; depth//8 blocks are created for
        each of the 64-channel and the 128-channel stages.  The 'skip128ds' block
        is constructed but is not used in forward().
        """
        def __init__(self, skip_connections=True, depth=32):
            super().__init__()
            if depth not in [8, 16, 32, 64]:
                sys.exit("BMEnet has been tested for depth for only 8, 16, 32, and 64")
            self.depth = depth // 8
            self.conv = nn.Conv2d(3, 64, 3, padding=1)
            self.pool = nn.MaxPool2d(2, 2)
            self.skip64_arr = nn.ModuleList()
            for _ in range(self.depth):
                self.skip64_arr.append(DLStudio.SkipConnections.SkipBlock(
                    64, 64, skip_connections=skip_connections))
            self.skip64ds = DLStudio.SkipConnections.SkipBlock(
                64, 64, downsample=True, skip_connections=skip_connections)
            self.skip64to128 = DLStudio.SkipConnections.SkipBlock(
                64, 128, skip_connections=skip_connections)
            self.skip128_arr = nn.ModuleList()
            for _ in range(self.depth):
                self.skip128_arr.append(DLStudio.SkipConnections.SkipBlock(
                    128, 128, skip_connections=skip_connections))
            self.skip128ds = DLStudio.SkipConnections.SkipBlock(
                128, 128, downsample=True, skip_connections=skip_connections)
            ## 2048 = 128 channels of size 4x4:  32x32 input, halved once by the
            ## max-pool and twice by skip64ds
            self.fc1 = nn.Linear(2048, 1000)
            self.fc2 = nn.Linear(1000, 10)

        def forward(self, x):
            x = self.pool(torch.nn.functional.relu(self.conv(x)))
            for skip64 in self.skip64_arr[:self.depth // 4]:
                x = skip64(x)
            ## The same downsampling block (and therefore the same weights) is used twice:
            x = self.skip64ds(x)
            for skip64 in self.skip64_arr[self.depth // 4:]:
                x = skip64(x)
            x = self.skip64ds(x)
            x = self.skip64to128(x)
            for skip128 in self.skip128_arr[:self.depth // 4]:
                x = skip128(x)
            for skip128 in self.skip128_arr[self.depth // 4:]:
                x = skip128(x)
            x = x.view(-1, 2048)
            x = torch.nn.functional.relu(self.fc1(x))
            x = self.fc2(x)
            return x

    def run_code_for_training(self, net, display_images=False):
        self.dl_studio.run_code_for_training(net, display_images)

    def save_model(self, model):
        '''
        Save the trained model to a disk file
        '''
        torch.save(model.state_dict(), self.dl_studio.path_saved_model)

    def run_code_for_testing(self, model, display_images=False):
        ## Pass on what the caller asked for instead of a hard-coded False:
        self.dl_studio.run_code_for_testing(model, display_images=display_images)
########################################################################################
################# Start Definition of Inner Class CustomDataLoading ##################
class CustomDataLoading(nn.Module):
"""This is a testbed for experimenting with a completely ground-up attempt at
designing a custom data loader. Ordinarily, if the basic format of how the
dataset is stored is similar to one of the datasets that the Torchvision
module knows about, you can go ahead and use that for your own dataset. At
worst, you may need to carry out some light customizations depending on the
number of classes involved, etc.
However, if the underlying dataset is stored in a manner that does not look
like anything in Torchvision, you have no choice but to supply yourself all
of the data loading infrastructure. That is what this inner class of the
DLStudio module is all about.
The custom data loading exercise here is related to a dataset called
PurdueShapes5 that contains 32x32 images of binary shapes belonging to the
following five classes:
1. rectangle
2. triangle
3. disk
4. oval
5. star
The dataset was generated by randomizing the sizes and the orientations
of these five patterns. Since the patterns are rotated with a very simple
non-interpolating transform, just the act of random rotations can introduce
boundary and even interior noise in the patterns.
Each 32x32 image is stored in the dataset as the following list:
[R, G, B, Bbox, Label]
where
R : is a 1024 element list of the values for the red component
of the color at all the pixels
G : the same as above but for the green component of the color
B : the same as above but for the blue component of the color
Bbox : a list like [x1,y1,x2,y2] that defines the bounding box
for the object in the image
Label : the shape of the object
I serialize the dataset with Python's pickle module and then compress it with
the gzip module.
You will find the following dataset directories in the "data" subdirectory
of Examples in the DLStudio distro:
PurdueShapes5-10000-train.gz
PurdueShapes5-1000-test.gz
PurdueShapes5-20-train.gz
PurdueShapes5-20-test.gz
The number that follows the main name string "PurdueShapes5-" is for the
number of images in the dataset.
You will find the last two datasets, with 20 images each, useful for debugging
your logic for object detection and bounding-box regression.
"""
def __init__(self, dl_studio, dataserver_train=None, dataserver_test=None, dataset_file_train=None, dataset_file_test=None):
    """
    Remember the owning DLStudio instance and the two dataset objects.

    `dataset_file_train` and `dataset_file_test` are accepted but not
    stored by this constructor.
    """
    super(DLStudio.CustomDataLoading, self).__init__()
    self.dl_studio, self.dataserver_train, self.dataserver_test = (
        dl_studio, dataserver_train, dataserver_test)
class PurdueShapes5Dataset(torch.utils.data.Dataset):
    """
    Dataset wrapper around a gzip-compressed pickle of PurdueShapes5 records.

    Each record is indexed as ``[R, G, B, bbox, label]`` where the first
    three entries hold 1024 values each (reshaped here to 32x32).  The
    archive also carries a label map whose key/value pairs are reversed into
    ``self.class_labels``.

    Only the archive "PurdueShapes5-10000-train.gz" (in 'train' mode) is
    cached with torch.save in the current working directory; every other
    archive is re-read from ``dl_studio.dataroot`` on each construction.
    """
    _CACHEABLE_ARCHIVE = "PurdueShapes5-10000-train.gz"
    _DATASET_CACHE = "torch_saved_PurdueShapes5-10000_dataset.pt"
    _LABEL_MAP_CACHE = "torch_saved_PurdueShapes5_label_map.pt"

    def __init__(self, dl_studio, train_or_test, dataset_file, transform=None):
        """
        Parameters:
            dl_studio:      object whose ``dataroot`` string is prefixed to `dataset_file`
            train_or_test:  'train' enables caching for the cacheable archive
            dataset_file:   file name of the gzip-compressed pickle
            transform:      optional callable applied to each sample dict
        """
        super(DLStudio.CustomDataLoading.PurdueShapes5Dataset, self).__init__()
        cacheable = train_or_test == 'train' and dataset_file == self._CACHEABLE_ARCHIVE
        if cacheable and os.path.exists(self._DATASET_CACHE) and \
                         os.path.exists(self._LABEL_MAP_CACHE):
            print("\nLoading training data from the torch-saved archive")
            self.dataset = torch.load(self._DATASET_CACHE)
            self.label_map = torch.load(self._LABEL_MAP_CACHE)
        else:
            if cacheable:
                print("""\n\n\nLooks like this is the first time you will be loading in\n"""
                      """the dataset for this script. First time loading could take\n"""
                      """a minute or so. Any subsequent attempts will only take\n"""
                      """a few seconds.\n\n\n""")
            self.dataset, self.label_map = self._read_archive(dl_studio.dataroot + dataset_file)
            if cacheable:
                torch.save(self.dataset, self._DATASET_CACHE)
                torch.save(self.label_map, self._LABEL_MAP_CACHE)
        # reverse the key-value pairs in the label dictionary:
        self.class_labels = dict(map(reversed, self.label_map.items()))
        self.transform = transform

    @staticmethod
    def _read_archive(path):
        """Return the ``(dataset, label_map)`` pair unpickled from the gzip file at `path`."""
        # The context manager closes the gzip handle, which used to be leaked.
        with gzip.open(path, 'rb') as f:
            raw = f.read()
        # NOTE: pickle can execute arbitrary code; only load archives from a trusted source.
        return pickle.loads(raw, encoding='latin1')

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``{'image': 3x32x32 float tensor, 'bbox': ..., 'label': ...}`` for record `idx`."""
        record = self.dataset[idx]
        im_tensor = torch.zeros(3,32,32, dtype=torch.float)
        for channel in range(3):
            plane = np.array( record[channel] ).reshape(32,32)
            im_tensor[channel,:,:] = torch.from_numpy(plane)
        sample = {'image' : im_tensor,
                  'bbox' : record[3],
                  'label' : record[4] }
        if self.transform:
            sample = self.transform(sample)
        return sample
def load_PurdueShapes5_dataset(self, dataserver_train, dataserver_test ):
    """
    Wrap the two dataset objects in DataLoaders and store them as
    ``self.train_dataloader`` (shuffled) and ``self.test_dataloader``
    (in order).  Both use ``self.dl_studio.batch_size`` and 4 workers.

    A torchvision transform used to be built here but was never applied to
    anything, so it has been removed.
    """
    self.train_dataloader = torch.utils.data.DataLoader(dataserver_train,
                       batch_size=self.dl_studio.batch_size,shuffle=True, num_workers=4)
    self.test_dataloader = torch.utils.data.DataLoader(dataserver_test,
                       batch_size=self.dl_studio.batch_size,shuffle=False, num_workers=4)
class SkipBlock(nn.Module):
    """
    Convolutional building block with an optional skip connection.

    The input goes through conv -> batch-norm -> ReLU; when ``in_ch ==
    out_ch`` a second conv -> batch-norm -> ReLU stage follows.  With
    ``downsample=True`` a 1x1 stride-2 convolution is applied to both the
    processed tensor and the identity.  With ``skip_connections=True`` the
    identity is added to the output; when the channel counts differ it is
    added separately to the first ``in_ch`` channels and to the remaining
    channels.
    """
    def __init__(self, in_ch, out_ch, downsample=False, skip_connections=True):
        # The explicit two-argument form used to name
        # DLStudio.SkipConnections.SkipBlock, which is a different class and
        # makes super() raise TypeError for instances of this one.
        super().__init__()
        self.downsample = downsample
        self.skip_connections = skip_connections
        self.in_ch = in_ch
        self.out_ch = out_ch
        self.convo1 = nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1)
        # convo2 is only used in forward() when in_ch == out_ch.
        self.convo2 = nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.bn2 = nn.BatchNorm2d(out_ch)
        if downsample:
            self.downsampler = nn.Conv2d(in_ch, out_ch, 1, stride=2)

    def forward(self, x):
        identity = x
        out = self.convo1(x)
        out = self.bn1(out)
        out = torch.nn.functional.relu(out)
        if self.in_ch == self.out_ch:
            out = self.convo2(out)
            out = self.bn2(out)
            out = torch.nn.functional.relu(out)
        if self.downsample:
            out = self.downsampler(out)
            identity = self.downsampler(identity)
        if self.skip_connections:
            # The additions are done out of place: modifying a ReLU output in
            # place invalidates the tensor autograd saved for the backward pass.
            if self.in_ch == self.out_ch:
                out = out + identity
            else:
                out = torch.cat((out[:,:self.in_ch,:,:] + identity,
                                 out[:,self.in_ch:,:,:] + identity), dim=1)
        return out
class BMEnet(nn.Module):
    """
    An inner class of DLStudio.CustomDataLoading.

    A stack of skip blocks between a 3->64 convolutional stem and two
    fully-connected layers that emit 10 scores per image.  `depth` must be
    one of 8, 16, 32 or 64; ``depth // 8`` skip blocks are created for each
    of the 64-channel and the 128-channel stages.
    """
    def __init__(self, skip_connections=True, depth=32):
        # The explicit two-argument form used to name
        # DLStudio.SkipConnections.BMEnet, a different class, which makes
        # super() raise TypeError for instances of this one.
        super().__init__()
        # The whitelist used to contain 6 (a typo for 8) and disagreed with
        # its own error message; it now matches the BMEnet in SkipConnections.
        if depth not in [8, 16, 32, 64]:
            sys.exit("BMEnet has been tested for depth for only 8, 16, 32, and 64")
        self.depth = depth // 8
        self.conv = nn.Conv2d(3, 64, 3, padding=1)
        self.skip64_arr = nn.ModuleList()
        for _ in range(self.depth):
            self.skip64_arr.append(DLStudio.SkipConnections.SkipBlock(64, 64,
                                                    skip_connections=skip_connections))
        self.skip64ds = DLStudio.SkipConnections.SkipBlock(64, 64,
                                    downsample=True, skip_connections=skip_connections)
        self.skip64to128 = DLStudio.SkipConnections.SkipBlock(64, 128,
                                                    skip_connections=skip_connections )
        self.skip128_arr = nn.ModuleList()
        for _ in range(self.depth):
            self.skip128_arr.append(DLStudio.SkipConnections.SkipBlock(128, 128,
                                                    skip_connections=skip_connections))
        # skip128ds is not used by forward(); it is kept so that the set of
        # parameters (and therefore saved state dicts) is unchanged.
        self.skip128ds = DLStudio.SkipConnections.SkipBlock(128,128,
                                    downsample=True, skip_connections=skip_connections)
        self.fc1 = nn.Linear(2048, 1000)
        self.fc2 = nn.Linear(1000, 10)

    def forward(self, x):
        """Return the 10 raw class scores for each image in `x`."""
        split = self.depth // 4
        x = nn.MaxPool2d(2,2)(torch.nn.functional.relu(self.conv(x)))
        for skip64 in self.skip64_arr[:split]:
            x = skip64(x)
        x = self.skip64ds(x)
        for skip64 in self.skip64_arr[split:]:
            x = skip64(x)
        x = self.skip64ds(x)
        x = self.skip64to128(x)
        for skip128 in self.skip128_arr[:split]:
            x = skip128(x)
        for skip128 in self.skip128_arr[split:]:
            x = skip128(x)
        x = x.view(-1, 2048 )
        x = torch.nn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        return x
def run_code_for_training_with_custom_loading(self, net):
    """
    Train a deep copy of `net` on ``self.train_dataloader`` with
    cross-entropy loss and SGD, then save the trained copy with
    ``self.save_model``.

    The loss averaged over each 1000 iterations is printed and appended to
    "performance_numbers_<epochs>.txt".  When ``dl_studio.debug_train`` is
    set, ground-truth and predicted labels plus the input batch are shown at
    the same interval.
    """
    filename_for_out = "performance_numbers_" + str(self.dl_studio.epochs) + ".txt"
    # The context manager closes the log file even if training raises.
    with open(filename_for_out, 'w') as FILE:
        net = copy.deepcopy(net)
        net = net.to(self.dl_studio.device)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(net.parameters(),
                     lr=self.dl_studio.learning_rate, momentum=self.dl_studio.momentum)
        class_labels = self.dataserver_train.class_labels
        for epoch in range(self.dl_studio.epochs):
            running_loss = 0.0
            for i, data in enumerate(self.train_dataloader):
                inputs, labels = data['image'], data['label']
                show_debug = self.dl_studio.debug_train and i % 1000 == 999
                if show_debug:
                    print("\n\n\nlabels: %s" % str(labels))
                    print("\n\n\ntype of labels: %s" % type(labels))
                    # Iterate over the labels actually present so that a
                    # short final batch does not raise IndexError.
                    print("\n\n[iter=%d:] Ground Truth: " % (i+1) +
                    ' '.join('%5s' % class_labels[label.item()] for label in labels))
                inputs = inputs.to(self.dl_studio.device)
                labels = labels.to(self.dl_studio.device)
                optimizer.zero_grad()
                outputs = net(inputs)
                loss = criterion(outputs, labels)
                if show_debug:
                    _, predicted = torch.max(outputs.data, 1)
                    # Was self.dataserver (undefined attribute) indexed with a
                    # tensor; the label table lives on dataserver_train and is
                    # keyed by plain ints.
                    print("[iter=%d:] Predicted Labels: " % (i+1) +
                     ' '.join('%5s' % class_labels[pred.item()] for pred in predicted))
                    self.dl_studio.display_tensor_as_image(torchvision.utils.make_grid(
             inputs, normalize=True), "see terminal for TRAINING results at iter=%d" % (i+1))
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
                if i % 1000 == 999:
                    avg_loss = running_loss / float(1000)
                    print("[epoch:%d, batch:%5d] loss: %.3f" % (epoch + 1, i + 1, avg_loss))
                    FILE.write("%.3f\n" % avg_loss)
                    FILE.flush()
                    running_loss = 0.0
    print("\nFinished Training\n")
    self.save_model(net)
def save_model(self, model):
    '''
    Serialize the parameters of `model` (its state dictionary) to the path
    held in ``self.dl_studio.path_saved_model``.
    '''
    target_path = self.dl_studio.path_saved_model
    torch.save(model.state_dict(), target_path)
def run_code_for_testing_with_custom_loading(self, net):
    """
    Load the saved weights into `net`, classify every batch of
    ``self.test_dataloader`` and print per-class accuracy, overall accuracy
    and the confusion matrix (as row percentages).

    Class names and the number of classes come from
    ``self.dataserver_train.class_labels``.

    NOTE(review): the network is evaluated in whatever train/eval mode the
    caller left it in; call ``net.eval()`` beforehand if batch-norm
    statistics should be frozen.
    """
    net.load_state_dict(torch.load(self.dl_studio.path_saved_model))
    class_labels = self.dataserver_train.class_labels
    num_classes = len(class_labels)
    correct = 0
    total = 0
    confusion_matrix = torch.zeros(num_classes, num_classes)
    class_correct = [0] * num_classes
    class_total = [0] * num_classes
    with torch.no_grad():
        for i, data in enumerate(self.test_dataloader):
            images, labels = data['image'], data['label']
            labels = labels.tolist()
            show_debug = self.dl_studio.debug_test and i % 1000 == 0
            if show_debug:
                print("\n\n[i=%d:] Ground Truth: " %i + ' '.join('%10s' %
                           class_labels[label] for label in labels))
            outputs = net(images)
            ##  max() returns two things: the max value and its index in the 10 element
            ##  output vector.  We are only interested in the index --- since that is
            ##  essentially the predicted class label:
            _, predicted = torch.max(outputs.data, 1)
            predicted = predicted.tolist()
            if show_debug:
                print("[i=%d:] Predicted Labels: " %i + ' '.join('%10s' %
                           class_labels[prediction] for prediction in predicted))
                self.dl_studio.display_tensor_as_image(
                      torchvision.utils.make_grid(images, normalize=True),
                      "see terminal for test results at i=%d" % i)
            # Walk the samples actually present in the batch; indexing with
            # range(batch_size) raised IndexError on a short final batch.
            for label,prediction in zip(labels,predicted):
                confusion_matrix[label][prediction] += 1
                class_total[label] += 1
                if label == prediction:
                    class_correct[label] += 1
                    correct += 1
            total += len(labels)
    print("\n")
    for j in range(num_classes):
        if class_total[j] == 0:
            # A class with no test samples has no defined accuracy.
            print('Prediction accuracy for %5s : n/a (no test samples)' % class_labels[j])
            continue
        print('Prediction accuracy for %5s : %2d %%' % (
                       class_labels[j], 100 * class_correct[j] / class_total[j]))
    overall = 100 * correct / float(total) if total else 0.0
    # Report the real number of test images instead of a hard-coded 10000.
    print("\n\n\nOverall accuracy of the network on the %d test images: %d %%" %
                                                                   (total, overall))
    print("\n\nDisplaying the confusion matrix:\n")
    out_str = " "
    for j in range(num_classes):
        out_str += "%15s" % class_labels[j]
    print(out_str + "\n")
    for i,label in enumerate(class_labels):
        out_percents = [100 * confusion_matrix[i,j] / float(class_total[i])
                                 for j in range(num_classes)]
        out_percents = ["%.2f" % item.item() for item in out_percents]
        out_str = "%12s: " % class_labels[i]
        for j in range(num_classes):
            out_str += "%15s" % out_percents[j]
        print(out_str)
########################################################################################
################### Start Definition of Inner Class DetectAndLocalize ################
class DetectAndLocalize(nn.Module):
"""
The purpose of this inner class is to focus on object detection in images --- as
opposed to image classification. Most people would say that object detection
is a more challenging problem than image classification because, in general,
the former also requires localization. The simplest interpretation of what
is meant by localization is that the code that carries out object detection
must also output a bounding-box rectangle for the object that was detected.
You will find in this inner class some examples of LOADnet classes meant
for solving the object detection and localization problem. The acronym
"LOAD" in "LOADnet" stands for
"LOcalization And Detection"
The different network examples included here are LOADnet1, LOADnet2, and
LOADnet3. For now, only pay attention to LOADnet2 since that's the class I
have worked with the most for the 1.0.7 distribution.
"""
def __init__(self, dl_studio, dataserver_train=None, dataserver_test=None, dataset_file_train=None, dataset_file_test=None):
    """
    Remember the owning DLStudio instance and the two dataset objects, and
    start with the ``debug`` flag switched off.

    `dataset_file_train` and `dataset_file_test` are accepted but not
    stored by this constructor.
    """
    super(DLStudio.DetectAndLocalize, self).__init__()
    self.dl_studio, self.dataserver_train, self.dataserver_test = (
        dl_studio, dataserver_train, dataserver_test)
    self.debug = False
class PurdueShapes5Dataset(torch.utils.data.Dataset):
    """
    Dataset wrapper around a gzip-compressed pickle of PurdueShapes5 records.

    Each record is indexed as ``[R, G, B, bbox, label]`` where the first
    three entries hold 1024 values each (reshaped here to 32x32).  The
    archive also carries a label map whose key/value pairs are reversed into
    ``self.class_labels``.

    In 'train' mode the archives listed in ``_TRAINING_CACHES`` are cached
    with torch.save in the current working directory; any other archive is
    re-read from ``dl_studio.dataroot`` on each construction.
    """
    _LABEL_MAP_CACHE = "torch-saved-PurdueShapes5-label-map.pt"
    # archive file name -> name of its torch-saved dataset cache
    _TRAINING_CACHES = {
        "PurdueShapes5-10000-train.gz":
            "torch-saved-PurdueShapes5-10000-dataset.pt",
        "PurdueShapes5-10000-train-noise-20.gz":
            "torch-saved-PurdueShapes5-10000-dataset-noise-20.pt",
        "PurdueShapes5-10000-train-noise-50.gz":
            "torch-saved-PurdueShapes5-10000-dataset-noise-50.pt",
        "PurdueShapes5-10000-train-noise-80.gz":
            "torch-saved-PurdueShapes5-10000-dataset-noise-80.pt",
    }

    def __init__(self, dl_studio, train_or_test, dataset_file, transform=None):
        """
        Parameters:
            dl_studio:      object whose ``dataroot`` string is prefixed to `dataset_file`
            train_or_test:  'train' enables caching for the known training archives
            dataset_file:   file name of the gzip-compressed pickle
            transform:      optional callable applied to each sample dict
        """
        super(DLStudio.DetectAndLocalize.PurdueShapes5Dataset, self).__init__()
        dataset_cache = None
        if train_or_test == 'train':
            dataset_cache = self._TRAINING_CACHES.get(dataset_file)
        if dataset_cache is None:
            self.dataset, self.label_map = self._read_archive(dl_studio.dataroot + dataset_file)
        elif os.path.exists(dataset_cache) and os.path.exists(self._LABEL_MAP_CACHE):
            print("\nLoading training data from the torch-saved archive")
            self.dataset = torch.load(dataset_cache)
            self.label_map = torch.load(self._LABEL_MAP_CACHE)
        else:
            print("""\n\n\nLooks like this is the first time you will be loading in\n"""
                  """the dataset for this script. First time loading could take\n"""
                  """a minute or so. Any subsequent attempts will only take\n"""
                  """a few seconds.\n\n\n""")
            self.dataset, self.label_map = self._read_archive(dl_studio.dataroot + dataset_file)
            torch.save(self.dataset, dataset_cache)
            torch.save(self.label_map, self._LABEL_MAP_CACHE)
        # reverse the key-value pairs in the label dictionary:
        self.class_labels = dict(map(reversed, self.label_map.items()))
        self.transform = transform

    @staticmethod
    def _read_archive(path):
        """Return the ``(dataset, label_map)`` pair unpickled from the gzip file at `path`."""
        # The context manager closes the gzip handle, which used to be leaked.
        with gzip.open(path, 'rb') as f:
            raw = f.read()
        # NOTE: pickle can execute arbitrary code; only load archives from a trusted source.
        if sys.version_info[0] == 3:
            return pickle.loads(raw, encoding='latin1')
        return pickle.loads(raw)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``{'image': 3x32x32 float tensor, 'bbox': float tensor, 'label': ...}``."""
        record = self.dataset[idx]
        im_tensor = torch.zeros(3,32,32, dtype=torch.float)
        for channel in range(3):
            plane = np.array( record[channel] ).reshape(32,32)
            im_tensor[channel,:,:] = torch.from_numpy(plane)
        bb_tensor = torch.tensor(record[3], dtype=torch.float)
        sample = {'image' : im_tensor,
                  'bbox' : bb_tensor,
                  'label' : record[4] }
        if self.transform:
            sample = self.transform(sample)
        return sample
def load_PurdueShapes5_dataset(self, dataserver_train, dataserver_test ):
    """
    Build the training loader (shuffled) and the test loader (in order) and
    store them as ``self.train_dataloader`` and ``self.test_dataloader``.
    Both use ``self.dl_studio.batch_size`` and 4 workers.
    """
    make_loader = torch.utils.data.DataLoader
    self.train_dataloader = make_loader(
        dataserver_train, batch_size=self.dl_studio.batch_size, shuffle=True, num_workers=4)
    self.test_dataloader = make_loader(
        dataserver_test, batch_size=self.dl_studio.batch_size, shuffle=False, num_workers=4)
class SkipBlock(nn.Module):
    """
    An inner class of DetectAndLocalize.

    Convolutional building block with an optional skip connection.  The
    input goes through conv -> batch-norm -> ReLU; when ``in_ch == out_ch``
    a second conv -> batch-norm -> ReLU stage follows.  With
    ``downsample=True`` a 1x1 stride-2 convolution is applied to both the
    processed tensor and the identity.  With ``skip_connections=True`` the
    identity is added to the output; when the channel counts differ it is
    added separately to the first ``in_ch`` channels and to the remaining
    channels.
    """
    def __init__(self, in_ch, out_ch, downsample=False, skip_connections=True):
        super(DLStudio.DetectAndLocalize.SkipBlock, self).__init__()
        self.downsample = downsample
        self.skip_connections = skip_connections
        self.in_ch = in_ch
        self.out_ch = out_ch
        self.convo1 = nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1)
        # convo2 is only used in forward() when in_ch == out_ch.
        self.convo2 = nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.bn2 = nn.BatchNorm2d(out_ch)
        if downsample:
            self.downsampler = nn.Conv2d(in_ch, out_ch, 1, stride=2)

    def forward(self, x):
        identity = x
        out = self.convo1(x)
        out = self.bn1(out)
        out = torch.nn.functional.relu(out)
        if self.in_ch == self.out_ch:
            out = self.convo2(out)
            out = self.bn2(out)
            out = torch.nn.functional.relu(out)
        if self.downsample:
            out = self.downsampler(out)
            identity = self.downsampler(identity)
        if self.skip_connections:
            # The additions are done out of place: modifying a ReLU output in
            # place invalidates the tensor autograd saved for the backward pass.
            if self.in_ch == self.out_ch:
                out = out + identity
            else:
                out = torch.cat((out[:,:self.in_ch,:,:] + identity,
                                 out[:,self.in_ch:,:,:] + identity), dim=1)
        return out
class LOADnet1(nn.Module):
    """
    The acronym 'LOAD' stands for 'LOcalization And Detection'.
    LOADnet1 only uses fully-connected layers for the regression.

    forward() returns ``(class_scores, bbox)`` with 5 class scores and 4
    bounding-box values per image.  The layer sizes below assume 32x32
    input images.
    """
    def __init__(self, skip_connections=True, depth=32):
        super(DLStudio.DetectAndLocalize.LOADnet1, self).__init__()
        self.pool_count = 3
        self.depth = depth // 2
        self.conv = nn.Conv2d(3, 64, 3, padding=1)
        self.skip64 = DLStudio.DetectAndLocalize.SkipBlock(64, 64,
                                                   skip_connections=skip_connections)
        self.skip64ds = DLStudio.DetectAndLocalize.SkipBlock(64, 64,
                                   downsample=True, skip_connections=skip_connections)
        self.skip64to128 = DLStudio.DetectAndLocalize.SkipBlock(64, 128,
                                                    skip_connections=skip_connections )
        self.skip128 = DLStudio.DetectAndLocalize.SkipBlock(128, 128,
                                                     skip_connections=skip_connections)
        self.skip128ds = DLStudio.DetectAndLocalize.SkipBlock(128,128,
                                    downsample=True, skip_connections=skip_connections)
        self.fc1 = nn.Linear(128 * (32 // 2**self.pool_count)**2, 1000)
        self.fc2 = nn.Linear(1000, 5)
        # The regression head reads the stem output: 64 channels of 16x16 for
        # a 32x32 input, i.e. 16384 features per image.  The previous size of
        # 32768 made the flattening step merge two images into one row and
        # return half as many bounding boxes as there were images.
        self.fc3 = nn.Linear(64 * 16 * 16, 1000)
        self.fc4 = nn.Linear(1000, 4)

    def forward(self, x):
        relu = torch.nn.functional.relu
        x = nn.MaxPool2d(2,2)(relu(self.conv(x)))
        repeats = self.depth // 4
        ## The labeling section:
        # Each loop now feeds its own output back in.  Previously the first
        # loop restarted from `x` on every pass (so repeating it changed
        # nothing) and `x1` was undefined when `repeats` was 0.
        x1 = x
        for _ in range(repeats):
            x1 = self.skip64(x1)
        x1 = self.skip64ds(x1)
        for _ in range(repeats):
            x1 = self.skip64(x1)
        x1 = self.skip64to128(x1)
        for _ in range(repeats):
            x1 = self.skip128(x1)
        x1 = self.skip128ds(x1)
        for _ in range(repeats):
            x1 = self.skip128(x1)
        x1 = x1.view(-1, 128 * (32 // 2**self.pool_count)**2 )
        x1 = relu(self.fc1(x1))
        x1 = self.fc2(x1)
        ## The Bounding Box regression:
        # Flatten per image so the batch dimension is always preserved.
        x2 = x.view(x.size(0), -1)
        x2 = relu(self.fc3(x2))
        x2 = self.fc4(x2)
        return x1,x2
class LOADnet2(nn.Module):
    """
    The acronym 'LOAD' stands for 'LOcalization And Detection'.
    LOADnet2 uses both convo and linear layers for regression.

    forward() returns ``(class_scores, bbox)``: 10 class scores from the
    skip-block branch and 4 bounding-box values from the conv + linear
    regression branch.  Both branches start from the same stem output.
    """
    def __init__(self, skip_connections=True, depth=8):
        super(DLStudio.DetectAndLocalize.LOADnet2, self).__init__()
        if depth not in [8,10,12,14,16]:
            sys.exit("LOADnet2 has only been tested for 'depth' values 8, 10, 12, 14, and 16")
        make_block = DLStudio.DetectAndLocalize.SkipBlock
        self.depth = depth // 2
        self.conv = nn.Conv2d(3, 64, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(128)
        self.skip64_arr = nn.ModuleList(
            [make_block(64, 64, skip_connections=skip_connections)
             for _ in range(self.depth)])
        self.skip64ds = make_block(64, 64, downsample=True,
                                   skip_connections=skip_connections)
        self.skip64to128 = make_block(64, 128, skip_connections=skip_connections)
        self.skip128_arr = nn.ModuleList(
            [make_block(128, 128, skip_connections=skip_connections)
             for _ in range(self.depth)])
        self.skip128ds = make_block(128, 128, downsample=True,
                                    skip_connections=skip_connections)
        self.fc1 = nn.Linear(2048, 1000)
        self.fc2 = nn.Linear(1000, 10)
        ##  for regression
        regression_convs = [
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]
        self.conv_seqn = nn.Sequential(*regression_convs)
        regression_linears = [
            nn.Linear(16384, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 4),
        ]
        self.fc_seqn = nn.Sequential(*regression_linears)

    def forward(self, x):
        functional = torch.nn.functional
        split = self.depth // 4
        x = functional.max_pool2d(functional.relu(self.conv(x)), 2, 2)
        ## The labeling section:
        x1 = x.clone()
        for block in self.skip64_arr[:split]:
            x1 = block(x1)
        x1 = self.skip64ds(x1)
        for block in self.skip64_arr[split:]:
            x1 = block(x1)
        x1 = self.skip64to128(self.bn1(x1))
        for block in self.skip128_arr[:split]:
            x1 = block(x1)
        x1 = self.skip128ds(self.bn2(x1))
        for block in self.skip128_arr[split:]:
            x1 = block(x1)
        x1 = functional.relu(self.fc1(x1.view(-1, 2048)))
        x1 = self.fc2(x1)
        ## The Bounding Box regression:
        # the same convolutional sequence is applied twice in a row
        x2 = self.conv_seqn(self.conv_seqn(x))
        # flatten
        x2 = self.fc_seqn(x2.view(x.size(0), -1))
        return x1,x2
class LOADnet3(nn.Module):
    """
    The acronym 'LOAD' stands for 'LOcalization And Detection'.
    LOADnet3 uses both convo and linear layers for regression.

    forward() returns ``(class_scores, bbox)``: 10 class scores from the
    skip-block branch and 4 bounding-box values from the conv + linear
    regression branch.  `depth` must be 4, 8 or 16.
    """
    def __init__(self, skip_connections=True, depth=8):
        super(DLStudio.DetectAndLocalize.LOADnet3, self).__init__()
        if depth not in [4, 8, 16]:
            # The message used to name LOADnet2.
            sys.exit("LOADnet3 has been tested for 'depth' for only 4, 8, and 16")
        self.depth = depth // 4
        self.conv = nn.Conv2d(3, 64, 3, padding=1)
        self.skip64_arr = nn.ModuleList()
        for _ in range(self.depth):
            self.skip64_arr.append(DLStudio.DetectAndLocalize.SkipBlock(64, 64,
                                                    skip_connections=skip_connections))
        self.skip64ds = DLStudio.DetectAndLocalize.SkipBlock(64, 64,
                                    downsample=True, skip_connections=skip_connections)
        self.skip64to128 = DLStudio.DetectAndLocalize.SkipBlock(64, 128,
                                                    skip_connections=skip_connections )
        self.skip128_arr = nn.ModuleList()
        for _ in range(self.depth):
            self.skip128_arr.append(DLStudio.DetectAndLocalize.SkipBlock(128, 128,
                                                    skip_connections=skip_connections))
        # skip128ds is not used by forward(); it is kept so that the set of
        # parameters (and therefore saved state dicts) is unchanged.
        self.skip128ds = DLStudio.DetectAndLocalize.SkipBlock(128,128,
                                    downsample=True, skip_connections=skip_connections)
        self.fc1 = nn.Linear(2048, 1000)
        self.fc2 = nn.Linear(1000, 10)
        ##  for regression
        self.conv_seqn = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
        self.fc_seqn = nn.Sequential(
            nn.Linear(16384, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 4)
        )

    def forward(self, x):
        split = self.depth // 4
        x = nn.MaxPool2d(2,2)(torch.nn.functional.relu(self.conv(x)))
        ## The labeling section:
        x1 = x.clone()
        for skip64 in self.skip64_arr[:split]:
            x1 = skip64(x1)
        x1 = self.skip64ds(x1)
        for skip64 in self.skip64_arr[split:]:
            x1 = skip64(x1)
        x1 = self.skip64ds(x1)
        x1 = self.skip64to128(x1)
        for skip128 in self.skip128_arr[:split]:
            x1 = skip128(x1)
        for skip128 in self.skip128_arr[split:]:
            x1 = skip128(x1)
        x1 = x1.view(-1, 2048 )
        x1 = torch.nn.functional.relu(self.fc1(x1))
        x1 = self.fc2(x1)
        ## The Bounding Box regression:
        # The previous code referenced self.skip64, self.skip128,
        # self.pool_count, self.fc3 and self.fc4, none of which this class
        # defines, so every call raised AttributeError.  The regression now
        # uses the conv_seqn / fc_seqn layers built in __init__, applied the
        # same way as in LOADnet2.
        x2 = self.conv_seqn(x)
        x2 = self.conv_seqn(x2)
        # flatten
        x2 = x2.view(x.size(0), -1)
        x2 = self.fc_seqn(x2)
        return x1,x2
class IOULoss(nn.Module):
    """
    One minus the mean intersection-over-union of two batches of binary masks.

    A pixel counts as "on" when its value equals 1.  For every sample the
    IoU of the on-pixels of `input` and `target` is computed; the result is
    ``1 - mean(IoU)`` as a one-element tensor.  The result is built from
    Python numbers, so no gradient flows back through it.
    """
    def __init__(self, batch_size):
        super(DLStudio.DetectAndLocalize.IOULoss, self).__init__()
        self.batch_size = batch_size

    def forward(self, input, target):
        """
        Parameters:
            input, target:  tensors indexed as [sample, row, column]

        Raises an exception when a sample has no on-pixel in either mask
        (its IoU would be 0/0).
        """
        # Use the number of samples actually present so that a short final
        # batch does not raise IndexError.
        num_samples = input.shape[0]
        composite_loss = []
        for idx in range(num_samples):
            pred_on = input[idx] == 1
            gt_on = target[idx] == 1
            # Whole-mask tensor operations replace the per-pixel loops over a
            # hard-coded 32x32 grid.
            union = (pred_on | gt_on).sum().item()
            if union == 0:
                raise ValueError("something_wrong")
            intersection = (pred_on & gt_on).sum().item()
            composite_loss.append(intersection / float(union))
        total_iou_for_batch = sum(composite_loss)
        return 1 - torch.tensor([total_iou_for_batch / num_samples])
def run_code_for_training_with_CrossEntropy_and_MSE_Losses(self, net):
    """
    Train a deep copy of `net` on ``self.train_dataloader`` using
    cross-entropy loss for the class labels and MSE loss for the bounding
    boxes, then save the trained copy with ``self.save_model``.

    `net` must return a pair ``(label_scores, bbox_prediction)``.  The two
    losses, averaged over each 500 iterations, are printed and appended to
    "performance_numbers_<epochs>label.txt" and
    "performance_numbers_<epochs>regres.txt".  When ``dl_studio.debug_train``
    is set, labels and boxes are printed and the batch is displayed with the
    ground-truth box drawn in channel 0 and the predicted box in channel 2.
    """
    filename_for_out1 = "performance_numbers_" + str(self.dl_studio.epochs) + "label.txt"
    filename_for_out2 = "performance_numbers_" + str(self.dl_studio.epochs) + "regres.txt"
    # The context manager closes both log files even if training raises.
    with open(filename_for_out1, 'w') as FILE1, open(filename_for_out2, 'w') as FILE2:
        net = copy.deepcopy(net)
        net = net.to(self.dl_studio.device)
        criterion1 = nn.CrossEntropyLoss()
        criterion2 = nn.MSELoss()
        optimizer = optim.SGD(net.parameters(),
                     lr=self.dl_studio.learning_rate, momentum=self.dl_studio.momentum)
        class_labels = self.dataserver_train.class_labels
        for epoch in range(self.dl_studio.epochs):
            running_loss_labeling = 0.0
            running_loss_regression = 0.0
            for i, data in enumerate(self.train_dataloader):
                inputs, bbox_gt, labels = data['image'], data['bbox'], data['label']
                show_debug = self.dl_studio.debug_train and i % 500 == 499
                if show_debug:
                    # Iterate over the samples actually present so that a
                    # short final batch does not raise IndexError.
                    print("\n\n[epoch=%d iter=%d:] Ground Truth: " % (epoch+1, i+1) +
                    ' '.join('%10s' % class_labels[label.item()] for label in labels))
                inputs = inputs.to(self.dl_studio.device)
                labels = labels.to(self.dl_studio.device)
                bbox_gt = bbox_gt.to(self.dl_studio.device)
                optimizer.zero_grad()
                if self.debug:
                    self.dl_studio.display_tensor_as_image(
                      torchvision.utils.make_grid(inputs.cpu(), nrow=4, normalize=True, padding=2, pad_value=10))
                outputs = net(inputs)
                outputs_label = outputs[0]
                bbox_pred = outputs[1]
                if show_debug:
                    inputs_copy = inputs.detach().clone()
                    inputs_copy = inputs_copy.cpu()
                    # Clamp a detached copy of the prediction to valid pixel
                    # coordinates for drawing only.
                    bbox_pc = bbox_pred.detach().clone()
                    bbox_pc[bbox_pc<0] = 0
                    bbox_pc[bbox_pc>31] = 31
                    bbox_pc[torch.isnan(bbox_pc)] = 0
                    _, predicted = torch.max(outputs_label.data, 1)
                    print("[epoch=%d iter=%d:] Predicted Labels: " % (epoch+1, i+1) +
                     ' '.join('%10s' % class_labels[pred.item()] for pred in predicted))
                    for idx in range(len(labels)):
                        i1 = int(bbox_gt[idx][1])
                        i2 = int(bbox_gt[idx][3])
                        j1 = int(bbox_gt[idx][0])
                        j2 = int(bbox_gt[idx][2])
                        k1 = int(bbox_pc[idx][1])
                        k2 = int(bbox_pc[idx][3])
                        l1 = int(bbox_pc[idx][0])
                        l2 = int(bbox_pc[idx][2])
                        print(" gt_bb: [%d,%d,%d,%d]"%(j1,i1,j2,i2))
                        print(" pred_bb: [%d,%d,%d,%d]"%(l1,k1,l2,k2))
                        inputs_copy[idx,0,i1:i2,j1] = 255
                        inputs_copy[idx,0,i1:i2,j2] = 255
                        inputs_copy[idx,0,i1,j1:j2] = 255
                        inputs_copy[idx,0,i2,j1:j2] = 255
                        inputs_copy[idx,2,k1:k2,l1] = 255
                        inputs_copy[idx,2,k1:k2,l2] = 255
                        inputs_copy[idx,2,k1,l1:l2] = 255
                        inputs_copy[idx,2,k2,l1:l2] = 255
                loss_labeling = criterion1(outputs_label, labels)
                # retain_graph keeps the shared part of the graph alive for
                # the second backward pass below.
                loss_labeling.backward(retain_graph=True)
                loss_regression = criterion2(bbox_pred, bbox_gt)
                loss_regression.backward()
                optimizer.step()
                running_loss_labeling += loss_labeling.item()
                running_loss_regression += loss_regression.item()
                if i % 500 == 499:
                    avg_loss_labeling = running_loss_labeling / float(500)
                    avg_loss_regression = running_loss_regression / float(500)
                    print("\n[epoch:%d, iteration:%5d] loss_labeling: %.3f loss_regression: %.3f " % (epoch + 1, i + 1, avg_loss_labeling, avg_loss_regression))
                    FILE1.write("%.3f\n" % avg_loss_labeling)
                    FILE1.flush()
                    FILE2.write("%.3f\n" % avg_loss_regression)
                    FILE2.flush()
                    running_loss_labeling = 0.0
                    running_loss_regression = 0.0
                if show_debug:
                    self.dl_studio.display_tensor_as_image(
                      torchvision.utils.make_grid(inputs_copy, normalize=True),
                      "see terminal for TRAINING results at iter=%d" % (i+1))
    print("\nFinished Training\n")
    self.save_model(net)
def save_model(self, model):
    """
    Persist the learnable state of `model` to disk.

    Only `model.state_dict()` is serialized (via `torch.save`); the target
    file is the one named by `self.dl_studio.path_saved_model`.
    """
    learned_state = model.state_dict()
    checkpoint_path = self.dl_studio.path_saved_model
    torch.save(learned_state, checkpoint_path)
def run_code_for_testing_detection_and_localization(self, net):
    """
    Evaluate a trained detection-and-localization network on the test data.

    The weights stored in `self.dl_studio.path_saved_model` are loaded into
    `net`, every batch served by `self.test_dataloader` is pushed through the
    network, and the predicted labels are compared with the ground truth.
    Per-class accuracy, overall accuracy and a confusion matrix (rows: true
    class, columns: predicted class, entries in percent of the row's samples)
    are printed to the terminal.

    When `self.dl_studio.debug_test` is true, every 50th batch is also printed
    and displayed with the ground-truth box drawn into channel 0 and the
    predicted box drawn into channel 2 of each image.

    Parameters:
        net:  a network whose forward pass returns an indexable pair; element
              0 is used as the class scores and element 1 as the regressed
              bounding boxes.

    Returns:
        None.  All results are reported through `print`.

    Compared with the earlier version of this method, the loops run over the
    number of samples actually present in a batch (so a short final batch no
    longer raises IndexError), classes with no test samples no longer cause
    a division by zero, and the reported image count is the real one.
    """
    net.load_state_dict(torch.load(self.dl_studio.path_saved_model))
    class_labels = self.dataserver_train.class_labels
    num_classes = len(class_labels)
    correct = 0
    total = 0
    confusion_matrix = torch.zeros(num_classes, num_classes)
    class_correct = [0] * num_classes
    class_total = [0] * num_classes
    with torch.no_grad():
        for i, data in enumerate(self.test_dataloader):
            images, bounding_box, labels = data['image'], data['bbox'], data['label']
            labels = labels.tolist()
            # The last batch may hold fewer samples than dl_studio.batch_size.
            num_in_batch = len(labels)
            show_batch = self.dl_studio.debug_test and i % 50 == 0
            if show_batch:
                print("\n\n[i=%d:] Ground Truth: " % i + ' '.join(
                    '%10s' % class_labels[label] for label in labels))
            outputs = net(images)
            outputs_label = outputs[0]
            outputs_regression = outputs[1]
            # Clamp the regressed coordinates to [0, 31] and neutralize NaNs
            # (presumably because the images are 32x32 -- TODO confirm).
            outputs_regression[outputs_regression < 0] = 0
            outputs_regression[outputs_regression > 31] = 31
            outputs_regression[torch.isnan(outputs_regression)] = 0
            output_bb = outputs_regression.tolist()
            _, predicted = torch.max(outputs_label.data, 1)
            predicted = predicted.tolist()
            if show_batch:
                print("[i=%d:] Predicted Labels: " % i + ' '.join(
                    '%10s' % class_labels[prediction] for prediction in predicted))
                for idx in range(num_in_batch):
                    # Boxes are [x1, y1, x2, y2]; i/k are rows, j/l are columns.
                    i1 = int(bounding_box[idx][1])
                    i2 = int(bounding_box[idx][3])
                    j1 = int(bounding_box[idx][0])
                    j2 = int(bounding_box[idx][2])
                    k1 = int(output_bb[idx][1])
                    k2 = int(output_bb[idx][3])
                    l1 = int(output_bb[idx][0])
                    l2 = int(output_bb[idx][2])
                    print(" gt_bb: [%d,%d,%d,%d]" % (j1, i1, j2, i2))
                    print(" pred_bb: [%d,%d,%d,%d]" % (l1, k1, l2, k2))
                    # Ground-truth rectangle goes into channel 0 ...
                    images[idx, 0, i1:i2, j1] = 255
                    images[idx, 0, i1:i2, j2] = 255
                    images[idx, 0, i1, j1:j2] = 255
                    images[idx, 0, i2, j1:j2] = 255
                    # ... and the predicted rectangle into channel 2.
                    images[idx, 2, k1:k2, l1] = 255
                    images[idx, 2, k1:k2, l2] = 255
                    images[idx, 2, k1, l1:l2] = 255
                    images[idx, 2, k2, l1:l2] = 255
                self.dl_studio.display_tensor_as_image(
                    torchvision.utils.make_grid(images, normalize=True),
                    "see terminal for test results at i=%d" % i)
            for label, prediction in zip(labels, predicted):
                confusion_matrix[label][prediction] += 1
                class_total[label] += 1
                if label == prediction:
                    correct += 1
                    class_correct[label] += 1
            total += num_in_batch
    print("\n")
    for j in range(num_classes):
        # A class that never shows up in the test data is reported as 0 %.
        accuracy = 100 * class_correct[j] / class_total[j] if class_total[j] else 0
        print('Prediction accuracy for %5s : %2d %%' % (class_labels[j], accuracy))
    overall = 100 * correct / float(total) if total else 0
    print("\n\n\nOverall accuracy of the network on the %d test images: %d %%" %
          (total, overall))
    print("\n\nDisplaying the confusion matrix:\n")
    out_str = " " + "".join("%15s" % class_labels[j] for j in range(num_classes))
    print(out_str + "\n")
    for i in range(num_classes):
        # An empty row stays all zeros instead of turning into NaNs.
        row_total = float(class_total[i]) if class_total[i] else 1.0
        out_percents = ["%.2f" % (100 * confusion_matrix[i, j] / row_total).item()
                        for j in range(num_classes)]
        out_str = "%12s: " % class_labels[i]
        out_str += "".join("%15s" % percent for percent in out_percents)
        print(out_str)
########################################################################################
################## Start Definition of Inner Class SemanticSegmentation ##############
class SemanticSegmentation(nn.Module):
"""The purpose of this inner class is to be able to use the DLStudio module for
experiments with semantic segmentation. At its simplest level, the
purpose of semantic segmentation is to assign correct labels to the
different objects in a scene, while localizing them at the same time. At
a more sophisticated level, a system that carries out semantic
segmentation should also output a symbolic expression based on the objects
found in the image and their spatial relationships with one another.
The workhorse of this inner class is the mUnet network that is based
on the UNET network that was first proposed by Ronneberger, Fischer and
Brox in the paper "U-Net: Convolutional Networks for Biomedical Image
Segmentation". Their Unet extracts binary masks for the cell pixel blobs
of interest in biomedical images. The output of their Unet can
therefore be treated as a pixel-wise binary classifier at each pixel
position. The mUnet class, on the other hand, is intended for
segmenting out multiple objects simultaneously from an image. [A weaker
reason for "Multi" in the name of the class is that it uses skip
connections not only across the two arms of the "U", but also along
the arms. The skip connections in the original Unet are only between the
two arms of the U.] In mUnet, each object type is assigned a separate
channel in the output of the network.
This version of DLStudio also comes with a new dataset,
PurdueShapes5MultiObject, for experimenting with mUnet. Each image in
this dataset contains a random number of selections from five different
shapes, with the shapes being randomly scaled, oriented, and located in
each image. The five different shapes are: rectangle, triangle, disk,
oval, and star.
"""
def __init__(self, dl_studio, dataserver_train=None, dataserver_test=None, dataset_file_train=None, dataset_file_test=None):
    """
    Record the owning DLStudio instance and the optional dataset servers.

    `dataset_file_train` and `dataset_file_test` are accepted by the
    signature but are not used by this constructor.
    """
    super(DLStudio.SemanticSegmentation, self).__init__()
    # The three assignments are independent of one another.
    self.dataserver_test = dataserver_test
    self.dataserver_train = dataserver_train
    self.dl_studio = dl_studio
class PurdueShapes5MultiObjectDataset(torch.utils.data.Dataset):
    """
    The very first thing to note is that the images in the dataset
    PurdueShapes5MultiObjectDataset are of size 64x64.  Each image has a
    random number (up to five) of the objects drawn from the following five
    shapes: rectangle, triangle, disk, oval, and star.  Each shape is
    randomized with respect to all its parameters, including those for its
    scale and location in the image.

    Each image in the dataset is represented by two data objects, one a list
    and the other a dictionary.  The list data object consists of the
    following items:

        [R, G, B, mask_array, mask_val_to_bbox_map]                      ## (A)

    and the other data object is a dictionary that is set to:

        label_map = {'rectangle':50,
                     'triangle' :100,
                     'disk'     :150,
                     'oval'     :200,
                     'star'     :250}                                    ## (B)

    Note that the second data object is the same for every image, as shown
    above.

    In the rest of this comment block, I'll explain in greater detail the
    elements of the list in line (A) above.

    R,G,B:
    ------
        Each of these is a 4096-element array whose elements store the
        corresponding color values at each of the 4096 pixels in a 64x64
        image.  That is, R is a list of 4096 integers, each between 0 and
        255, for the value of the red component of the color at each pixel.
        Similarly, for G and B.

    mask_array:
    ----------
        The fourth item in the list shown in line (A) above is for the mask
        which is a numpy array of shape:

                       (5, 64, 64)

        It is initialized by the command:

             mask_array = np.zeros((5,64,64), dtype=np.uint8)

        In essence, the mask_array consists of five planes, each of size
        64x64.  Each plane of the mask array represents an object type
        according to the following shape_index

                shape_index = (label_map[shape] - 50) // 50

        where the label_map is as shown in line (B) above.  In other words,
        the shape_index values for the different shapes are:

                 rectangle:  0
                  triangle:  1
                      disk:  2
                      oval:  3
                      star:  4

        Therefore, the first layer (of index 0) of the mask is where the
        pixel values of 50 are stored at all those pixels that belong to the
        rectangle shapes.  Similarly, the second mask layer (of index 1) is
        where the pixel values of 100 are stored at all those pixel
        coordinates that belong to the triangle shapes in an image; and so
        on.

        It is in the manner described above that we define five different
        masks for an image in the dataset.  Each mask is for a different
        shape and the pixel values at the nonzero pixels in each mask layer
        are keyed to the shapes also.

        A reader is likely to wonder as to the need for this redundancy in
        the dataset representation of the shapes in each image.  Such a
        reader is likely to ask: Why can't we just use the binary values 1s
        and 0s in each mask layer where the corresponding pixels are in the
        image?  Setting these mask values to 50, 100, etc., was done merely
        for convenience.  I went with the intuition that the learning needed
        for multi-object segmentation would become easier if each shape was
        represented by a different pixel value in the corresponding mask.
        So I went ahead and incorporated that in the dataset generation
        program itself.

        The mask values for the shapes are not to be confused with the
        actual RGB values of the pixels that belong to the shapes.  The RGB
        values at the pixels in a shape are randomly generated.  Yes, all
        the pixels in a shape instance in an image have the same RGB values
        (but that value has nothing to do with the values given to the mask
        pixels for that shape).

    mask_val_to_bbox_map:
    --------------------
        The fifth item in the list in line (A) above is a dictionary that
        tells us what bounding-box rectangle to associate with each shape in
        the image.  To illustrate what this dictionary looks like, assume
        that an image contains only one rectangle and only one disk, the
        dictionary in this case will look like:

            mask values to bbox mappings:  {200: [],
                                            250: [],
                                            100: [],
                                             50: [[56, 20, 63, 25]],
                                            150: [[37, 41, 55, 59]]}

        Should there happen to be two rectangles in the same image, the
        dictionary would then be like:

            mask values to bbox mappings:  {200: [],
                                            250: [],
                                            100: [],
                                             50: [[56, 20, 63, 25], [18, 16, 32, 36]],
                                            150: [[37, 41, 55, 59]]}

        Therefore, it is not a problem even if all the objects in an image
        are of the same type.  Remember, the objects that are selected for
        an image are chosen randomly from the different shapes.  By the way,
        an entry like '[56, 20, 63, 25]' for the bounding box means that the
        upper-left corner of the BBox for the 'rectangle' shape is at
        (56,20) and the lower-right corner of the same is at the pixel
        coordinates (63,25).

        As far as the BBox quadruples are concerned, in the definition

                [min_x,min_y,max_x,max_y]

        note that x is the horizontal coordinate, increasing to the right on
        your screen, and y is the vertical coordinate increasing downwards.
    """
    def __init__(self, dl_studio, train_or_test, dataset_file, transform=None):
        """
        Load the dataset and its label map.

        For the 10000-image training archive the unpickled data is cached
        in two `torch.save` files in the current working directory, and
        read back from there on later runs.  Every other archive is read
        directly from `dl_studio.dataroot + dataset_file`.

        After construction the instance always carries `dataset`,
        `label_map`, `class_labels` (the label map with keys and values
        swapped) and `transform`.  Earlier, `class_labels` was missing when
        the data came from the cache files.
        """
        super(DLStudio.SemanticSegmentation.PurdueShapes5MultiObjectDataset, self).__init__()
        cached_dataset = "torch_saved_PurdueShapes5MultiObject-10000_dataset.pt"
        cached_label_map = "torch_saved_PurdueShapes5MultiObject_label_map.pt"
        cacheable = (train_or_test == 'train' and
                     dataset_file == "PurdueShapes5MultiObject-10000-train.gz")
        if cacheable and os.path.exists(cached_dataset) and os.path.exists(cached_label_map):
            print("\nLoading training data from torch saved file")
            self.dataset = torch.load(cached_dataset)
            self.label_map = torch.load(cached_label_map)
        else:
            if cacheable:
                print("""\n\n\nLooks like this is the first time you will be loading in\n"""
                      """the dataset for this script. First time loading could take\n"""
                      """up to 3 minutes. Any subsequent attempts will only take\n"""
                      """a few seconds.\n\n\n""")
            root_dir = dl_studio.dataroot
            # The context manager guarantees that the gzip handle is closed.
            with gzip.open(root_dir + dataset_file, 'rb') as f:
                dataset = f.read()
            # NOTE(security): unpickling executes code embedded in the file;
            # only load archives that come from a trusted source.
            if sys.version_info[0] == 3:
                self.dataset, self.label_map = pickle.loads(dataset, encoding='latin1')
            else:
                self.dataset, self.label_map = pickle.loads(dataset)
            if cacheable:
                torch.save(self.dataset, cached_dataset)
                torch.save(self.label_map, cached_label_map)
        # reverse the key-value pairs in the label dictionary:
        self.class_labels = dict(map(reversed, self.label_map.items()))
        self.transform = transform

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """
        Return the sample at position `idx` as a dictionary with the keys

            'image'        float tensor of shape (3,64,64) holding R, G, B
            'mask_tensor'  tensor made from the stored mask array
            'bbox_tensor'  float tensor of shape (5,5,4)

        In `bbox_tensor` the first axis counts the bounding boxes recorded
        for one shape type (five slots, enough even if all the shapes in an
        image are of the same type), the second axis is the slot reserved
        for each shape type, and the last axis holds the four box
        coordinates.  Unused slots stay zero.

        Note that `self.transform` is stored by the constructor but is not
        applied here.
        """
        record = self.dataset[idx]
        im_tensor = torch.zeros(3, 64, 64, dtype=torch.float)
        for channel in range(3):
            # R, G and B are stored as flat sequences of 4096 values each.
            plane = np.array(record[channel]).reshape(64, 64)
            im_tensor[channel, :, :] = torch.from_numpy(plane)
        mask_tensor = torch.from_numpy(np.array(record[3]))
        mask_val_to_bbox_map = record[4]
        bbox_tensor = torch.zeros(5, 5, 4, dtype=torch.float)
        for key, boxes in mask_val_to_bbox_map.items():
            # Keys below 5 are used directly as the shape slot (the original
            # behavior).  Keys given as mask values (50, 100, ..., 250, as
            # described in the class docstring) are converted with the
            # shape_index formula so that they fit the five-slot axis.
            shape_index = key if key < 5 else (key - 50) // 50
            for bbox_idx, box in enumerate(boxes):
                bbox_tensor[bbox_idx, shape_index, :] = torch.from_numpy(np.array(box))
        sample = {'image'       : im_tensor,
                  'mask_tensor' : mask_tensor,
                  'bbox_tensor' : bbox_tensor }
        return sample
def load_PurdueShapes5MultiObject_dataset(self, dataserver_train, dataserver_test ):
    """
    Wrap the two dataset servers in DataLoaders and keep them on the instance
    as `self.train_dataloader` and `self.test_dataloader`.

    Both loaders use `self.dl_studio.batch_size` and four worker processes;
    only the training loader shuffles its samples.
    """
    make_loader = torch.utils.data.DataLoader
    samples_per_batch = self.dl_studio.batch_size
    self.train_dataloader = make_loader(
        dataserver_train, batch_size=samples_per_batch, shuffle=True, num_workers=4)
    self.test_dataloader = make_loader(
        dataserver_test, batch_size=samples_per_batch, shuffle=False, num_workers=4)
class SkipBlockDN(nn.Module):
    """
    Inner class: SemanticSegmentation

    This class is for the skip connections in the downward leg of the "U".

    A block applies convolution -> batch norm -> ReLU once, a second time
    when `in_ch == out_ch`, optionally halves the spatial resolution with a
    strided 1x1 convolution, and finally adds the block input back onto the
    result when `skip_connections` is true.
    """
    def __init__(self, in_ch, out_ch, downsample=False, skip_connections=True):
        """
        Parameters:
            in_ch:             number of input channels
            out_ch:            number of output channels
            downsample:        when true, a 1x1 convolution with stride 2 is
                               applied to both the processed tensor and the
                               identity branch
            skip_connections:  when true, the input is added to the output
        """
        super(DLStudio.SemanticSegmentation.SkipBlockDN, self).__init__()
        self.downsample = downsample
        self.skip_connections = skip_connections
        self.in_ch = in_ch
        self.out_ch = out_ch
        self.convo1 = nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1)
        # convo2 consumes the output of convo1; forward() only calls it when
        # in_ch == out_ch, which is the only case in which its shape fits.
        self.convo2 = nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.bn2 = nn.BatchNorm2d(out_ch)
        if downsample:
            self.downsampler = nn.Conv2d(in_ch, out_ch, 1, stride=2)

    def forward(self, x):
        """
        Run the block on `x` and return the result.

        The skip connection is added out of place (`out = out + identity`
        rather than `out += identity`) so that autograd never sees an
        in-place modification of a tensor that an earlier operation (for
        example the ReLU) may have saved for its backward pass.  The values
        produced by the forward pass are unchanged.
        """
        identity = x
        out = self.convo1(x)
        out = self.bn1(out)
        out = torch.nn.functional.relu(out)
        if self.in_ch == self.out_ch:
            out = self.convo2(out)
            out = self.bn2(out)
            out = torch.nn.functional.relu(out)
        if self.downsample:
            out = self.downsampler(out)
            identity = self.downsampler(identity)
        if self.skip_connections:
            if self.in_ch == self.out_ch:
                out = out + identity
            else:
                # The input is added to the first in_ch channels and again
                # to the remaining channels of the output.
                lower = out[:, :self.in_ch, :, :] + identity
                upper = out[:, self.in_ch:, :, :] + identity
                out = torch.cat((lower, upper), dim=1)
        return out
class SkipBlockUP(nn.Module):
    """
    This class is for the skip connections in the upward leg of the "U".

    A block applies transpose convolution -> batch norm -> ReLU once, a
    second time when `in_ch == out_ch`, optionally doubles the spatial
    resolution with a strided transpose convolution, and finally adds the
    block input back onto the result when `skip_connections` is true.
    """
    def __init__(self, in_ch, out_ch, upsample=False, skip_connections=True):
        """
        Parameters:
            in_ch:             number of input channels
            out_ch:            number of output channels
            upsample:          when true, a stride-2 transpose convolution is
                               applied to both the processed tensor and the
                               identity branch
            skip_connections:  when true, the input is added to the output
        """
        super(DLStudio.SemanticSegmentation.SkipBlockUP, self).__init__()
        self.upsample = upsample
        self.skip_connections = skip_connections
        self.in_ch = in_ch
        self.out_ch = out_ch
        self.convoT1 = nn.ConvTranspose2d(in_ch, out_ch, 3, padding=1)
        # convoT2 consumes the output of convoT1; forward() only calls it
        # when in_ch == out_ch.
        self.convoT2 = nn.ConvTranspose2d(in_ch, out_ch, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.bn2 = nn.BatchNorm2d(out_ch)
        if upsample:
            self.upsampler = nn.ConvTranspose2d(in_ch, out_ch, 1, stride=2, dilation=2, output_padding=1, padding=0)

    def forward(self, x):
        """
        Run the block on `x` and return the result.

        The skip connection is added out of place (`out = out + identity`
        rather than `out += identity`) so that autograd never sees an
        in-place modification of a tensor that an earlier operation (for
        example the ReLU) may have saved for its backward pass.  The values
        produced by the forward pass are unchanged.
        """
        identity = x
        out = self.convoT1(x)
        out = self.bn1(out)
        out = torch.nn.functional.relu(out)
        if self.in_ch == self.out_ch:
            out = self.convoT2(out)
            out = self.bn2(out)
            out = torch.nn.functional.relu(out)
        if self.upsample:
            out = self.upsampler(out)
            identity = self.upsampler(identity)
        if self.skip_connections:
            if self.in_ch == self.out_ch:
                out = out + identity
            else:
                # Only the input channels from index out_ch onwards take
                # part in the skip connection when the channel count shrinks.
                out = out + identity[:, self.out_ch:, :, :]
        return out
class mUnet(nn.Module):
    """
    This network is called mUnet because it is intended for segmenting out
    multiple objects simultaneously from an image.  [A weaker reason for
    "Multi" in the name of the class is that it uses skip connections not
    only across the two arms of the "U", but also along the arms.]  The
    classic UNET was first proposed by Ronneberger, Fischer and Brox in the
    paper "U-Net: Convolutional Networks for Biomedical Image Segmentation".
    Their UNET extracts binary masks for the cell pixel blobs of interest in
    biomedical images.  The output of their UNET can therefore be treated as
    a pixel-wise binary classifier at each pixel position.

    The mUnet presented here, on the other hand, is meant specifically for
    simultaneously identifying and localizing multiple objects in a given
    image.  Each object type is assigned a separate channel in the output of
    the network.

    I have created a dataset, PurdueShapes5MultiObject, for experimenting
    with mUnet.  Each image in this dataset contains a random number of
    selections from five different shapes, with the shapes being randomly
    scaled, oriented, and located in each image.  The five different shapes
    are: rectangle, triangle, disk, oval, and star.
    """
    def __init__(self, skip_connections=True, depth=16):
        """
        Build the layers of the network.

        Parameters:
            skip_connections:  forwarded to every SkipBlockDN / SkipBlockUP
            depth:             half of this value (`depth // 2`) is the
                               number of blocks placed in each of the four
                               block arrays
        """
        super(DLStudio.SemanticSegmentation.mUnet, self).__init__()
        down_block = DLStudio.SemanticSegmentation.SkipBlockDN
        up_block = DLStudio.SemanticSegmentation.SkipBlockUP
        self.depth = depth // 2
        self.conv_in = nn.Conv2d(3, 64, 3, padding=1)
        ##  For the DN arm of the U:
        self.bn1DN = nn.BatchNorm2d(64)
        self.bn2DN = nn.BatchNorm2d(128)
        self.skip64DN_arr = nn.ModuleList(
            [down_block(64, 64, skip_connections=skip_connections)
             for _ in range(self.depth)])
        self.skip64dsDN = down_block(64, 64, downsample=True,
                                     skip_connections=skip_connections)
        self.skip64to128DN = down_block(64, 128, skip_connections=skip_connections)
        self.skip128DN_arr = nn.ModuleList(
            [down_block(128, 128, skip_connections=skip_connections)
             for _ in range(self.depth)])
        self.skip128dsDN = down_block(128, 128, downsample=True,
                                      skip_connections=skip_connections)
        ##  For the UP arm of the U:
        self.bn1UP = nn.BatchNorm2d(128)
        self.bn2UP = nn.BatchNorm2d(64)
        self.skip64UP_arr = nn.ModuleList(
            [up_block(64, 64, skip_connections=skip_connections)
             for _ in range(self.depth)])
        self.skip64usUP = up_block(64, 64, upsample=True,
                                   skip_connections=skip_connections)
        self.skip128to64UP = up_block(128, 64, skip_connections=skip_connections)
        self.skip128UP_arr = nn.ModuleList(
            [up_block(128, 128, skip_connections=skip_connections)
             for _ in range(self.depth)])
        self.skip128usUP = up_block(128, 128, upsample=True,
                                    skip_connections=skip_connections)
        self.conv_out = nn.ConvTranspose2d(64, 5, 3, stride=2, dilation=2,
                                           output_padding=1, padding=2)

    def forward(self, x):
        """
        Push `x` down one arm of the U and back up the other.

        At three points on the way down the first half of the channels is
        cloned; on the way up each clone is written back over the first
        half of the channels at the matching stage.  The result of the
        final transpose convolution (five output channels) is returned.
        """
        split = self.depth // 4
        ##  Going down to the bottom of the U:
        x = torch.nn.functional.relu(self.conv_in(x))
        x = torch.nn.functional.max_pool2d(x, kernel_size=2, stride=2)
        for block in self.skip64DN_arr[:split]:
            x = block(x)
        keep1 = x.shape[1] // 2
        saved1 = x[:, :keep1, :, :].clone()
        x = self.skip64dsDN(x)
        for block in self.skip64DN_arr[split:]:
            x = block(x)
        x = self.bn1DN(x)
        keep2 = x.shape[1] // 2
        saved2 = x[:, :keep2, :, :].clone()
        x = self.skip64to128DN(x)
        for block in self.skip128DN_arr[:split]:
            x = block(x)
        x = self.bn2DN(x)
        keep3 = x.shape[1] // 2
        saved3 = x[:, :keep3, :, :].clone()
        for block in self.skip128DN_arr[split:]:
            x = block(x)
        x = self.skip128dsDN(x)
        ##  Coming up from the bottom of U on the other side:
        x = self.skip128usUP(x)
        for block in self.skip128UP_arr[:split]:
            x = block(x)
        x[:, :keep3, :, :] = saved3
        x = self.bn1UP(x)
        # NOTE(review): this loop walks the leading slice [:depth//4] a second
        # time, so the trailing blocks of skip128UP_arr never take part in the
        # forward pass.  Possibly [depth//4:] was intended -- confirm before
        # changing, since that would alter the behavior of saved checkpoints.
        for block in self.skip128UP_arr[:split]:
            x = block(x)
        x = self.skip128to64UP(x)
        for block in self.skip64UP_arr[split:]:
            x = block(x)
        x[:, :keep2, :, :] = saved2
        x = self.bn2UP(x)
        x = self.skip64usUP(x)
        for block in self.skip64UP_arr[:split]:
            x = block(x)
        x[:, :keep1, :, :] = saved1
        x = self.conv_out(x)
        return x
class SegmentationLoss(nn.Module):
    """
    I wrote this class before I switched to MSE loss.  I am leaving it here
    in case I need to get back to it in the future.

    For every sample, the first output channel is compared with each mask
    layer through the mean squared difference over the pixels; the values
    for the individual mask layers are summed per sample and the per-sample
    sums are averaged over the batch.
    """
    def __init__(self, batch_size):
        """
        Parameters:
            batch_size:  nominal batch size; stored on the instance for
                         callers that read it back
        """
        super(DLStudio.SemanticSegmentation.SegmentationLoss, self).__init__()
        self.batch_size = batch_size

    def forward(self, output, mask_tensor):
        """
        Return the loss as a zero-dimensional tensor.

        Parameters:
            output:       network output, indexed as (sample, channel, row,
                          column); only channel 0 is used
            mask_tensor:  masks, indexed as (sample, mask layer, row, column)

        The earlier implementation bounded the loop over the mask layers by
        `mask_tensor.shape[0]` (the batch axis) while indexing the layer
        axis, and wrote into fixed five-slot buffers.  It therefore skipped
        mask layers for batches smaller than five and raised IndexError for
        batches larger than five.  All mask layers are now always included,
        and the average is taken over the samples actually present.

        NOTE(review): every mask layer is compared with output channel 0,
        exactly as before -- confirm that this is the intended pairing.
        """
        first_channel = output[:, 0:1, :, :]
        # Broadcasting the single channel against all mask layers yields
        # one squared-error map per (sample, mask layer) pair.
        squared_error = (first_channel - mask_tensor) ** 2
        per_layer = squared_error.mean(dim=(2, 3))
        per_sample = per_layer.sum(dim=1)
        return per_sample.sum() / output.shape[0]
def run_code_for_training_for_semantic_segmentation(self, net):
    """
    Train a deep copy of `net` for semantic segmentation and save it.

    The copy is moved to `self.dl_studio.device` and trained with SGD
    (learning rate and momentum taken from `self.dl_studio`) on the MSE
    between the network output and the mask tensor, for
    `self.dl_studio.epochs` passes over `self.train_dataloader`.  After
    every 1000 batches the average loss over those batches is printed and
    appended to "performance_numbers_<epochs>.txt" in the current working
    directory.  When training ends the copy is handed to `self.save_model`.

    Because a copy is trained, the `net` object passed in is left untouched;
    the trained weights are available through the saved model file.

    Returns:
        None.
    """
    filename_for_out1 = "performance_numbers_" + str(self.dl_studio.epochs) + ".txt"
    device = self.dl_studio.device
    net = copy.deepcopy(net)
    net = net.to(device)
    criterion1 = nn.MSELoss()
    optimizer = optim.SGD(net.parameters(),
                          lr=self.dl_studio.learning_rate, momentum=self.dl_studio.momentum)
    # The context manager closes the log file even if training is interrupted
    # by an exception.
    with open(filename_for_out1, 'w') as FILE1:
        for epoch in range(self.dl_studio.epochs):
            print("")
            running_loss_segmentation = 0.0
            for i, data in enumerate(self.train_dataloader):
                im_tensor, mask_tensor = data['image'], data['mask_tensor']
                im_tensor = im_tensor.to(device)
                # The masks must be floats for the MSE loss.
                mask_tensor = mask_tensor.type(torch.FloatTensor)
                mask_tensor = mask_tensor.to(device)
                optimizer.zero_grad()
                output = net(im_tensor)
                segmentation_loss = criterion1(output, mask_tensor)
                segmentation_loss.backward()
                optimizer.step()
                running_loss_segmentation += segmentation_loss.item()
                if i % 1000 == 999:
                    avg_loss_segmentation = running_loss_segmentation / float(1000)
                    print("[epoch:%d,batch:%5d] MSE loss: %.3f" % (epoch+1, i+1, avg_loss_segmentation))
                    FILE1.write("%.3f\n" % avg_loss_segmentation)
                    FILE1.flush()
                    running_loss_segmentation = 0.0
    print("\nFinished Training\n")
    self.save_model(net)
def save_model(self, model):
    """
    Store the trained model on disk.

    The state dictionary of `model` is written with `torch.save` to the file
    named by `self.dl_studio.path_saved_model`.
    """
    state = model.state_dict()
    destination = self.dl_studio.path_saved_model
    torch.save(state, destination)
def run_code_for_testing_semantic_segmentation(self, net):
    """
    Visualize the output of a trained segmentation network on the test data.

    The weights stored in `self.dl_studio.path_saved_model` are loaded into
    `net`.  When `self.dl_studio.debug_test` is true, every 50th batch of
    `self.test_dataloader` is run through the network and shown as a grid
    with one column per image and the following rows:

        row 0      the pixel-wise maximum over the output channels, with the
                   ground-truth bounding boxes drawn in
        row 1      the input images, with the same boxes drawn into channel 0
        rows 2-6   one row per mask layer: 255 where the raw output lies in
                   the value band expected for that layer, 50 elsewhere

    Batches that are not displayed are skipped without running the network.

    The earlier version of this method hard-coded a batch size of 4 in the
    buffer sizes and grid layout and computed the images with per-pixel
    Python loops.  The layout is now derived from the batch that was
    actually served, and the per-pixel work is done with tensor operations;
    for batches of four 64x64 images the displayed grid is the same.

    Returns:
        None.
    """
    net.load_state_dict(torch.load(self.dl_studio.path_saved_model))
    # Open interval (low, high) of raw output values treated as "object
    # present" in each of the five mask layers.
    value_bands = [(25, 85), (65, 135), (115, 185), (165, 230), (210, float('inf'))]
    num_layers = len(value_bands)
    with torch.no_grad():
        for i, data in enumerate(self.test_dataloader):
            im_tensor, mask_tensor, bbox_tensor = data['image'], data['mask_tensor'], data['bbox_tensor']
            if not (self.dl_studio.debug_test and i % 50 == 0):
                continue
            print("\n\n\n\nShowing output for test batch %d: " % (i+1))
            outputs = net(im_tensor)
            num_images = im_tensor.shape[0]
            height, width = im_tensor.shape[2], im_tensor.shape[3]
            ##  Single-channel image holding, at each pixel, the largest value
            ##  found across the output channels:
            output_bw_tensor = torch.max(outputs, dim=1, keepdim=True)[0].to(torch.float64)
            for idx in range(num_images):
                for bbox_idx in range(bbox_tensor.shape[1]):
                    bb_tensor = bbox_tensor[idx, bbox_idx]
                    for k in range(bbox_tensor.shape[2]):
                        # Boxes are [x1, y1, x2, y2]; i is the row, j the column.
                        i1 = int(bb_tensor[k][1])
                        i2 = int(bb_tensor[k][3])
                        j1 = int(bb_tensor[k][0])
                        j2 = int(bb_tensor[k][2])
                        # An all-zero (unused) slot produces empty slices and
                        # therefore draws nothing.
                        for canvas in (output_bw_tensor, im_tensor):
                            canvas[idx, 0, i1:i2, j1] = 255
                            canvas[idx, 0, i1:i2, j2] = 255
                            canvas[idx, 0, i1, j1:j2] = 255
                            canvas[idx, 0, i2, j1:j2] = 255
            display_tensor = torch.zeros(num_images * (2 + num_layers), 3, height, width, dtype=float)
            display_tensor[:num_images, :, :, :] = output_bw_tensor
            display_tensor[num_images:2 * num_images, :, :, :] = im_tensor
            for mask_layer_idx, (low, high) in enumerate(value_bands):
                layer = outputs[:, mask_layer_idx, :, :]
                inside_band = (layer > low) & (layer < high)
                outputs[:, mask_layer_idx, :, :] = torch.where(
                    inside_band, torch.full_like(layer, 255), torch.full_like(layer, 50))
                first_slot = 2 * num_images + num_images * mask_layer_idx
                # unsqueeze(1) lets the single plane broadcast over the three
                # display channels.
                display_tensor[first_slot:first_slot + num_images, :, :, :] = \
                    outputs[:, mask_layer_idx, :, :].unsqueeze(1)
            self.dl_studio.display_tensor_as_image(
                torchvision.utils.make_grid(display_tensor, nrow=num_images, normalize=True, padding=2, pad_value=10))
########################################################################################
################## Start Definition of Inner Class TextClassification ################
class TextClassification(nn.Module):
"""
The purpose of this inner class is to be able to use the DLStudio module for simple
experiments in text classification. Consider, for example, the problem of automatic
classification of variable-length user feedback: you want to create a neural network
that can label an uploaded product review of arbitrary length as positive or negative.
One way to solve this problem is with a recurrent neural network in which you use a
hidden state for characterizing a variable-length product review with a fixed-length
state vector. This inner class allows you to carry out such experiments.
"""
def __init__(self, dl_studio, dataserver_train=None, dataserver_test=None, dataset_file_train=None, dataset_file_test=None):
    """
    Keep references to the owning DLStudio instance and to the optional
    training and testing dataset servers.

    `dataset_file_train` and `dataset_file_test` are part of the signature
    but are not used by this constructor.
    """
    super(DLStudio.TextClassification, self).__init__()
    # The three assignments are independent of one another.
    self.dataserver_test = dataserver_test
    self.dataserver_train = dataserver_train
    self.dl_studio = dl_studio
class SentimentAnalysisDataset(torch.utils.data.Dataset):
"""
The sentiment analysis datasets that I have made available were extracted from
an archive of user feedback comments as made available by Amazon for the year
2007. The original archive contains user feedback on 25 product categories.
For each product category, there are two files named 'positive.reviews' and
'negative.reviews', with each file containing 1000 reviews. I believe that
characterizing the reviews as 'positive' or 'negative' was carried out by
human annotators. Regardless, the reviews in these two files can be used to
train a neural network whose purpose would be to automatically characterize
a product as being positive or negative.
I have extracted the following datasets extracted from the Amazon archive:
sentiment_dataset_train_200.tar.gz vocab_size = 43,285
sentiment_dataset_test_200.tar.gz
sentiment_dataset_train_40.tar.gz vocab_size = 17,001
sentiment_dataset_test_40.tar.gz
sentiment_dataset_train_3.tar.gz vocab_size = 3,402
sentiment_dataset_test_3.tar.gz
The integer in the name of each dataset is the number of reviews collected
from the 'positive.reviews' and the 'negative.reviews' files for each product
category. Therefore, the dataset with 200 in its name has a total of 400
reviews for each product category.
As to why I am presenting these three different datasets, note that, as shown
above, the size of the vocabulary depends on the number of reviews selected
and the size of the vocabulary has a strong bearing on how long it takes to
train an algorithm for text classification. For one simple reason for that:
the size of the one-hot representation for the words equals the size of the
vocabulary. Therefore, the one-hot representation for the words for the
dataset with 200 in its name will be a one-axis tensor of size 43,285.
For a purely feedforward network, it is not a big deal for the input tensors
to be size Nx43285 where N is the number of words in a review. And even for
RNNs with simple feedback, that does not slow things down. However, when
using GRUs, it's an entirely different matter if you are tying to run your
experiments on, say, a laptop with a Quadro GPU. Hence the reason for providing
the datasets with 200 and 40 reviews. The dataset with just 3 reviews is for
debugging your code.
"""
def __init__(self, dl_studio, train_or_test, dataset_file):
super(DLStudio.TextClassification.SentimentAnalysisDataset, self).__init__()
self.train_or_test = train_or_test
root_dir = dl_studio.dataroot
f = gzip.open(root_dir + dataset_file, 'rb')
dataset = f.read()
# if train_or_test is 'train':
if train_or_test == 'train':
if sys.version_info[0] == 3:
self.positive_reviews_train, self.negative_reviews_train, self.vocab = pickle.loads(dataset, encoding='latin1')
else:
self.positive_reviews_train, self.negative_reviews_train, self.vocab = pickle.loads(dataset)
self.categories = sorted(list(self.positive_reviews_train.keys()))
self.category_sizes_train_pos = {category : len(self.positive_reviews_train[category]) for category in self.categories}
self.category_sizes_train_neg = {category : len(self.negative_reviews_train[category]) for category in self.categories}
self.indexed_dataset_train = []
for category in self.positive_reviews_train:
for review in self.positive_reviews_train[category]:
self.indexed_dataset_train.append([review, category, 1])
for category in self.negative_reviews_train:
for review in self.negative_reviews_train[category]:
self.indexed_dataset_train.append([review, category, 0])
random.shuffle(self.indexed_dataset_train)
# elif train_or_test is 'test':
elif train_or_test == 'test':
if sys.version_info[0] == 3:
self.positive_reviews_test, self.negative_reviews_test, self.vocab = pickle.loads(dataset, encoding='latin1')
else:
self.positive_reviews_test, self.negative_reviews_test, self.vocab = pickle.loads(dataset)
self.vocab = sorted(self.vocab)
self.categories = sorted(list(self.positive_reviews_test.keys()))
self.category_sizes_test_pos = {category : len(self.positive_reviews_test[category]) for category in self.categories}
self.category_sizes_test_neg = {category : len(self.negative_reviews_test[category]) for category in self.categories}
self.indexed_dataset_test = []
for category in self.positive_reviews_test:
for review in self.positive_reviews_test[category]:
self.indexed_dataset_test.append([review, category, 1])
for category in self.negative_reviews_test:
for review in self.negative_reviews_test[category]:
self.indexed_dataset_test.append([review, category, 0])
random.shuffle(self.indexed_dataset_test)
def get_vocab_size(self):
return len(self.vocab)
def one_hotvec_for_word(self, word):
    """
    Build the one-hot encoding of a single word.

    The result is a tensor of shape (1, len(self.vocab)) that is zero
    everywhere except at the position reported by self.vocab.index(word),
    where it is 1.
    """
    position = self.vocab.index(word)
    vocab_size = len(self.vocab)
    encoded = torch.zeros(1, vocab_size)
    encoded[0, position] = 1
    return encoded
def review_to_tensor(self, review):
    """
    Encode a review (a sequence of words) as a 2-D tensor.

    Row i of the returned (len(review), len(self.vocab)) tensor is filled
    with whatever self.one_hotvec_for_word() returns for the i-th word.
    """
    num_words = len(review)
    vocab_size = len(self.vocab)
    encoded = torch.zeros(num_words, vocab_size)
    for row, token in enumerate(review):
        encoded[row, :] = self.one_hotvec_for_word(token)
    return encoded
def sentiment_to_tensor(self, sentiment):
    """
    Pack a binary sentiment label into a two-element tensor of dtype
    torch.long.  A label equal to 0 (negative) sets element 0, a label
    equal to 1 (positive) sets element 1.  Any other label leaves both
    elements at zero.
    """
    packed = torch.zeros(2)
    if sentiment == 1:
        hot_slot = 1
    elif sentiment == 0:
        hot_slot = 0
    else:
        hot_slot = None
    if hot_slot is not None:
        packed[hot_slot] = 1
    return packed.type(torch.long)
def __len__(self):
    """
    Report how many indexed samples are available for the split named by
    self.train_or_test ('train' or 'test').  For any other value nothing
    is returned explicitly, i.e. the result is None.
    """
    mode = self.train_or_test
    if mode == 'train':
        return len(self.indexed_dataset_train)
    if mode == 'test':
        return len(self.indexed_dataset_test)
    return None
def __getitem__(self, idx):
    """
    Fetch the idx-th sample of the active split and convert it for the
    network.

    The training list is used when self.train_or_test == 'train', the
    test list otherwise.  The returned dict has the keys 'review' (output
    of review_to_tensor), 'category' (position of the category name in
    self.categories) and 'sentiment' (output of sentiment_to_tensor).
    """
    if self.train_or_test == 'train':
        entry = self.indexed_dataset_train[idx]
    else:
        entry = self.indexed_dataset_test[idx]
    words, category_name, label = entry[0], entry[1], entry[2]
    # Same conversion order as before: sentiment, then review, then category.
    sentiment_tensor = self.sentiment_to_tensor(label)
    words_tensor = self.review_to_tensor(words)
    category_position = self.categories.index(category_name)
    return {'review'    : words_tensor,
            'category'  : category_position,   # should be converted to tensor, but not yet used
            'sentiment' : sentiment_tensor}
def load_SentimentAnalysisDataset(self, dataserver_train, dataserver_test ):
    """
    Wrap the two dataset objects in DataLoaders and store them on self as
    train_dataloader (shuffled) and test_dataloader (not shuffled).  Both
    loaders use self.dl_studio.batch_size and a single worker.
    """
    make_loader = torch.utils.data.DataLoader
    self.train_dataloader = make_loader(
        dataserver_train,
        batch_size=self.dl_studio.batch_size,
        shuffle=True,
        num_workers=1,
    )
    self.test_dataloader = make_loader(
        dataserver_test,
        batch_size=self.dl_studio.batch_size,
        shuffle=False,
        num_workers=1,
    )
class TEXTnet(nn.Module):
    """
    A hand-rolled recurrent cell for classifying variable-length sentiment
    data one word at a time.  At every step the current input is joined with
    the previous hidden state; one linear layer maps the joined vector to the
    next hidden state and a second path (linear -> ReLU -> dropout -> linear
    -> LogSoftmax) maps it to the class scores.

    Based on the author's limited testing, the performance of this network is
    very poor because it has no protection against vanishing gradients when
    used in an RNN.

    Location:  Inner class TextClassification
    """
    def __init__(self, input_size, hidden_size, output_size):
        super(DLStudio.TextClassification.TEXTnet, self).__init__()
        self.input_size, self.hidden_size, self.output_size = input_size, hidden_size, output_size
        joined_width = input_size + hidden_size
        # The layers are created in the same order as before so that the
        # parameter ordering and the random initialisation are unchanged.
        self.combined_to_hidden = nn.Linear(joined_width, hidden_size)
        self.combined_to_middle = nn.Linear(joined_width, 100)
        self.middle_to_out = nn.Linear(100, output_size)
        self.logsoftmax = nn.LogSoftmax(dim=1)
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, input, hidden):
        """Run one time step and return (log-probabilities, next hidden state)."""
        joined = torch.cat((input, hidden), 1)
        next_hidden = self.combined_to_hidden(joined)
        middle = torch.relu(self.combined_to_middle(joined))
        scores = self.middle_to_out(self.dropout(middle))
        return self.logsoftmax(scores), next_hidden
class TEXTnetOrder2(nn.Module):
    """
    In this variant of the TEXTnet network, the value of hidden as used at
    each time step also includes its value at the previous time step.  This
    fact, not directly apparent by the definition of the class shown below,
    is made possible by the last parameter, cell, in the header of forward().
    All you can see here, at the end of forward(), is that the value of cell
    goes through a linear layer and through a sigmoid nonlinearity.  Since
    the sigmoid saturates at 0 and 1, it can act like a switch.  The training
    function decides which earlier cell value is fed back in, and that is how
    the hidden state at each time step gets mixed with an earlier one.

    The constructor parameter dls is accepted but not used by this class.

    Location:  Inner class TextClassification
    """
    def __init__(self, input_size, hidden_size, output_size, dls):
        super(DLStudio.TextClassification.TEXTnetOrder2, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        # The joined vector is [input, hidden, cell], hence 2*hidden_size.
        self.combined_to_hidden = nn.Linear(input_size + 2*hidden_size, hidden_size)
        self.combined_to_middle = nn.Linear(input_size + 2*hidden_size, 100)
        self.middle_to_out = nn.Linear(100, output_size)
        self.logsoftmax = nn.LogSoftmax(dim=1)
        self.dropout = nn.Dropout(p=0.1)
        # for the cell
        self.linear_for_cell = nn.Linear(hidden_size, hidden_size)

    def forward(self, input, hidden, cell):
        """
        Run one time step.

        Returns (log-probabilities, next hidden state, next cell value),
        where the cell value is sigmoid(linear_for_cell(next hidden state)).
        """
        combined = torch.cat((input, hidden, cell), 1)
        hidden = self.combined_to_hidden(combined)
        out = self.combined_to_middle(combined)
        out = torch.nn.functional.relu(out)
        out = self.dropout(out)
        out = self.middle_to_out(out)
        out = self.logsoftmax(out)
        # nn.Linear does not modify its input, so no defensive copy of
        # hidden is needed before deriving the cell value from it.
        cell = torch.sigmoid(self.linear_for_cell(hidden))
        return out,hidden,cell

    def initialize_cell(self, batch_size):
        """
        Return an all-zero cell tensor of shape (batch_size, hidden_size)
        with the dtype and device of the cell layer's weight.

        The batch_size argument used to be ignored (one row was always
        returned).  Callers that pass 1 get exactly what they got before.
        """
        weight = next(self.linear_for_cell.parameters()).data
        return weight.new_zeros(batch_size, self.hidden_size)
class GRUnet(nn.Module):
    """
    A GRU followed by ReLU, a linear layer and LogSoftmax.

    Source: https://blog.floydhub.com/gru-with-pytorch/
    with the only modification that the final output of forward() is now
    routed through LogSoftmax activation.

    Location:  Inner class TextClassification
    """
    def __init__(self, input_size, hidden_size, output_size, n_layers, drop_prob=0.2):
        super(DLStudio.TextClassification.GRUnet, self).__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.gru = nn.GRU(input_size=input_size, hidden_size=hidden_size,
                          num_layers=n_layers, batch_first=True, dropout=drop_prob)
        self.fc = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, x, h):
        """Feed x and hidden state h through the GRU; classify from the last step."""
        sequence_out, h = self.gru(x, h)
        last_step = sequence_out[:,-1]
        log_probs = self.logsoftmax(self.fc(self.relu(last_step)))
        return log_probs, h

    def init_hidden(self, batch_size):
        """Return zeros of shape (n_layers, batch_size, hidden_size) matching the parameters' dtype/device."""
        reference = next(self.parameters()).data
        return reference.new_zeros(self.n_layers, batch_size, self.hidden_size)
def save_model(self, model):
    """Write model.state_dict() to the file named by self.dl_studio.path_saved_model."""
    learned_state = model.state_dict()
    destination = self.dl_studio.path_saved_model
    torch.save(learned_state, destination)
def run_code_for_training_with_TEXTnet_no_gru(self, net, hidden_size):
    """
    Train a deep copy of net (a TEXTnet-style module whose forward takes
    (input, hidden) and returns (log-probabilities, hidden)) and save it
    through self.save_model().

    Each review is fed one word at a time, and the loss is computed on the
    output for the last word.  Only element 0 of every batch is read, so
    the dataloader is expected to deliver one review per batch.  The running
    loss averaged over every 100 iterations is printed and also written, one
    value per line, to "performance_numbers_<epochs>.txt" in the current
    directory.

    Parameters:
        net          the network to copy and train (the caller's object is
                     left untouched)
        hidden_size  width of the hidden state handed to the network
    """
    filename_for_out = "performance_numbers_" + str(self.dl_studio.epochs) + ".txt"
    net = copy.deepcopy(net)
    net = net.to(self.dl_studio.device)
    ##  Note that the TEXTnet and TEXTnetOrder2 both produce LogSoftmax output:
    criterion = nn.NLLLoss()
    optimizer = optim.SGD(net.parameters(),
                 lr=self.dl_studio.learning_rate, momentum=self.dl_studio.momentum)
    # time.clock() no longer exists in Python 3.8+; perf_counter() is the
    # documented replacement for measuring elapsed time.
    start_time = time.perf_counter()
    # The context manager guarantees the log file is closed even if
    # training is interrupted by an exception.
    with open(filename_for_out, 'w') as FILE:
        for epoch in range(self.dl_studio.epochs):
            print("")
            running_loss = 0.0
            for i, data in enumerate(self.train_dataloader):
                # The hidden state starts from zero for every review.
                hidden = torch.zeros(1, hidden_size)
                hidden = hidden.to(self.dl_studio.device)
                review_tensor,category,sentiment = data['review'], data['category'], data['sentiment']
                review_tensor = review_tensor.to(self.dl_studio.device)
                sentiment = sentiment.to(self.dl_studio.device)
                optimizer.zero_grad()
                input = torch.zeros(1,review_tensor.shape[2])
                input = input.to(self.dl_studio.device)
                for k in range(review_tensor.shape[1]):
                    input[0,:] = review_tensor[0,k]
                    output, hidden = net(input, hidden)
                # The one-hot sentiment is turned back into a class index for NLLLoss.
                loss = criterion(output, torch.argmax(sentiment,1))
                running_loss += loss.item()
                # NOTE(review): retain_graph=True looks unnecessary because a
                # fresh graph is built for every review -- confirm before removing.
                loss.backward(retain_graph=True)
                optimizer.step()
                if i % 100 == 99:
                    avg_loss = running_loss / float(100)
                    time_elapsed = time.perf_counter() - start_time
                    print("[epoch:%d iter:%4d elapsed_time: %4d secs] loss: %.3f" % (epoch+1,i+1, time_elapsed,avg_loss))
                    FILE.write("%.3f\n" % avg_loss)
                    FILE.flush()
                    running_loss = 0.0
    print("\nFinished Training\n")
    self.save_model(net)
def run_code_for_training_with_TEXTnetOrder2_no_gru(self, net, hidden_size):
    """
    Train a deep copy of net (a TEXTnetOrder2-style module whose forward
    takes (input, hidden, cell) and returns (log-probabilities, hidden,
    cell)) and save it through self.save_model().

    Each review is fed one word at a time.  The cell value passed to the
    network at a given step is the one produced two steps earlier (zeros
    for the first two steps).  The loss is computed on the output for the
    last word.  Only element 0 of every batch is read, so the dataloader is
    expected to deliver one review per batch.  The running loss averaged
    over every 100 iterations is printed and also written, one value per
    line, to "performance_numbers_<epochs>.txt" in the current directory.

    Parameters:
        net          the network to copy and train (the caller's object is
                     left untouched)
        hidden_size  width of the hidden state handed to the network
    """
    filename_for_out = "performance_numbers_" + str(self.dl_studio.epochs) + ".txt"
    net = copy.deepcopy(net)
    net = net.to(self.dl_studio.device)
    ##  Note that the TEXTnet and TEXTnetOrder2 both produce LogSoftmax output:
    criterion = nn.NLLLoss()
    optimizer = optim.SGD(net.parameters(),
                 lr=self.dl_studio.learning_rate, momentum=self.dl_studio.momentum)
    # time.clock() no longer exists in Python 3.8+; perf_counter() is the
    # documented replacement for measuring elapsed time.
    start_time = time.perf_counter()
    # The context manager guarantees the log file is closed even if
    # training is interrupted by an exception.
    with open(filename_for_out, 'w') as FILE:
        for epoch in range(self.dl_studio.epochs):
            print("")
            running_loss = 0.0
            for i, data in enumerate(self.train_dataloader):
                # Cell history and hidden state start from zero for every review.
                cell_prev = net.initialize_cell(1).to(self.dl_studio.device)
                cell_prev_2_prev = net.initialize_cell(1).to(self.dl_studio.device)
                hidden = torch.zeros(1, hidden_size)
                hidden = hidden.to(self.dl_studio.device)
                review_tensor,category,sentiment = data['review'], data['category'], data['sentiment']
                review_tensor = review_tensor.to(self.dl_studio.device)
                sentiment = sentiment.to(self.dl_studio.device)
                optimizer.zero_grad()
                input = torch.zeros(1,review_tensor.shape[2])
                input = input.to(self.dl_studio.device)
                for k in range(review_tensor.shape[1]):
                    input[0,:] = review_tensor[0,k]
                    output, hidden, cell = net(input, hidden, cell_prev_2_prev)
                    # Shift the two-step cell history; at the first step there
                    # is no older value to move yet.
                    if k == 0:
                        cell_prev = cell
                    else:
                        cell_prev_2_prev = cell_prev
                        cell_prev = cell
                # The one-hot sentiment is turned back into a class index for NLLLoss.
                loss = criterion(output, torch.argmax(sentiment,1))
                running_loss += loss.item()
                loss.backward()
                optimizer.step()
                if i % 100 == 99:
                    avg_loss = running_loss / float(100)
                    time_elapsed = time.perf_counter() - start_time
                    print("[epoch:%d iter:%4d elapsed_time: %4d secs] loss: %.3f" % (epoch+1,i+1, time_elapsed,avg_loss))
                    FILE.write("%.3f\n" % avg_loss)
                    FILE.flush()
                    running_loss = 0.0
    print("\nFinished Training\n")
    self.save_model(net)
def run_code_for_training_for_text_classification_with_gru(self, net, hidden_size):
    """
    Train a deep copy of net (a GRUnet-style module providing
    init_hidden(batch_size) and forward(x, h)) and save it through
    self.save_model().

    Each review is fed one word at a time as a (1, 1, vocab) tensor, and the
    loss is computed on the output for the last word.  Only element 0 of
    every batch is read, so the dataloader is expected to deliver one review
    per batch.  The running loss averaged over every 100 iterations is
    printed and also written, one value per line, to
    "performance_numbers_<epochs>.txt" in the current directory.  The
    elapsed time shown in each progress line restarts with every epoch; the
    "Total Training Time" line reports the time for the whole run.

    Parameters:
        net          the network to copy and train (the caller's object is
                     left untouched)
        hidden_size  not used here; the hidden state comes from
                     net.init_hidden()
    """
    filename_for_out = "performance_numbers_" + str(self.dl_studio.epochs) + ".txt"
    net = copy.deepcopy(net)
    net = net.to(self.dl_studio.device)
    ##  Note that the GRUnet now produces the LogSoftmax output:
    criterion = nn.NLLLoss()
    optimizer = optim.SGD(net.parameters(),
                 lr=self.dl_studio.learning_rate, momentum=self.dl_studio.momentum)
    # time.clock() no longer exists in Python 3.8+; perf_counter() is the
    # documented replacement for measuring elapsed time.
    training_start = time.perf_counter()
    # The context manager guarantees the log file is closed even if
    # training is interrupted by an exception.
    with open(filename_for_out, 'w') as FILE:
        for epoch in range(self.dl_studio.epochs):
            print("")
            running_loss = 0.0
            start_time = time.perf_counter()
            for i, data in enumerate(self.train_dataloader):
                review_tensor,category,sentiment = data['review'], data['category'], data['sentiment']
                review_tensor = review_tensor.to(self.dl_studio.device)
                sentiment = sentiment.to(self.dl_studio.device)
                optimizer.zero_grad()
                hidden = net.init_hidden(1).to(self.dl_studio.device)
                for k in range(review_tensor.shape[1]):
                    output, hidden = net(torch.unsqueeze(torch.unsqueeze(review_tensor[0,k],0),0), hidden)
                # The one-hot sentiment is turned back into a class index for NLLLoss.
                loss = criterion(output, torch.argmax(sentiment, 1))
                running_loss += loss.item()
                loss.backward()
                optimizer.step()
                if i % 100 == 99:
                    avg_loss = running_loss / float(100)
                    time_elapsed = time.perf_counter() - start_time
                    print("[epoch:%d iter:%4d elapsed_time:%4d secs] loss: %.3f" % (epoch+1,i+1, time_elapsed,avg_loss))
                    FILE.write("%.3f\n" % avg_loss)
                    FILE.flush()
                    running_loss = 0.0
    # Adding up the per-report elapsed readings counted the same seconds
    # several times, so the total is measured directly instead.
    print("Total Training Time: {}".format(str(time.perf_counter() - training_start)))
    print("\nFinished Training\n")
    self.save_model(net)
def run_code_for_testing_with_TEXTnet_no_gru(self, net, hidden_size):
    """
    Load the weights stored at self.dl_studio.path_saved_model into net,
    run it over self.test_dataloader and print a 2x2 confusion matrix
    (rows: true negative/positive, columns: predicted negative/positive)
    expressed as row percentages.

    Only element 0 of every batch is read, so the dataloader is expected to
    deliver one review per batch.  The network is switched to evaluation
    mode while scoring so that dropout does not perturb the predictions,
    and its previous mode is restored afterwards.

    Parameters:
        net          a TEXTnet-style module whose forward takes
                     (input, hidden) and returns (log-probabilities, hidden)
        hidden_size  width of the hidden state handed to the network
    """
    net.load_state_dict(torch.load(self.dl_studio.path_saved_model))
    # NOTE(review): the number of correct predictions is tallied here but
    # is not reported anywhere below.
    classification_accuracy = 0.0
    negative_total = 0
    positive_total = 0
    confusion_matrix = torch.zeros(2,2)
    # torch.no_grad() does not disable dropout; only eval mode does.
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            for i, data in enumerate(self.test_dataloader):
                review_tensor,category,sentiment = data['review'], data['category'], data['sentiment']
                input = torch.zeros(1,review_tensor.shape[2])
                hidden = torch.zeros(1, hidden_size)
                for k in range(review_tensor.shape[1]):
                    input[0,:] = review_tensor[0,k]
                    output, hidden = net(input, hidden)
                predicted_idx = torch.argmax(output).item()
                gt_idx = torch.argmax(sentiment).item()
                if i % 100 == 99:
                    print(" [i=%4d] predicted_label=%d gt_label=%d" % (i+1, predicted_idx,gt_idx))
                if predicted_idx == gt_idx:
                    classification_accuracy += 1
                if gt_idx == 0:
                    negative_total += 1
                elif gt_idx == 1:
                    positive_total += 1
                confusion_matrix[gt_idx,predicted_idx] += 1
    finally:
        net.train(was_training)
    out_percent = np.zeros((2,2), dtype='float')
    # Formatting to three decimals and parsing back rounds the percentages;
    # the float() makes the string-to-number conversion explicit.
    out_percent[0,0] = float("%.3f" % (100 * confusion_matrix[0,0] / float(negative_total)))
    out_percent[0,1] = float("%.3f" % (100 * confusion_matrix[0,1] / float(negative_total)))
    out_percent[1,0] = float("%.3f" % (100 * confusion_matrix[1,0] / float(positive_total)))
    out_percent[1,1] = float("%.3f" % (100 * confusion_matrix[1,1] / float(positive_total)))
    print("\n\nNumber of positive reviews tested: %d" % positive_total)
    print("\n\nNumber of negative reviews tested: %d" % negative_total)
    print("\n\nDisplaying the confusion matrix:\n")
    out_str = " "
    out_str += "%18s %18s" % ('predicted negative', 'predicted positive')
    print(out_str + "\n")
    for i,label in enumerate(['true negative', 'true positive']):
        out_str = "%12s: " % label
        for j in range(2):
            out_str += "%18s" % out_percent[i,j]
        print(out_str)
def run_code_for_testing_with_TEXTnetOrder2_no_gru(self, net, hidden_size):
    """
    Load the weights stored at self.dl_studio.path_saved_model into net,
    run it over self.test_dataloader and print a 2x2 confusion matrix
    (rows: true negative/positive, columns: predicted negative/positive)
    expressed as row percentages.

    The cell value passed to the network at a given step is the one produced
    two steps earlier (zeros for the first two steps).  Only element 0 of
    every batch is read, so the dataloader is expected to deliver one review
    per batch.  The network is switched to evaluation mode while scoring so
    that dropout does not perturb the predictions, and its previous mode is
    restored afterwards.

    Parameters:
        net          a TEXTnetOrder2-style module providing
                     initialize_cell() and forward(input, hidden, cell)
        hidden_size  width of the hidden state handed to the network
    """
    net.load_state_dict(torch.load(self.dl_studio.path_saved_model))
    # NOTE(review): the number of correct predictions is tallied here but
    # is not reported anywhere below.
    classification_accuracy = 0.0
    negative_total = 0
    positive_total = 0
    confusion_matrix = torch.zeros(2,2)
    # torch.no_grad() does not disable dropout; only eval mode does.
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            for i, data in enumerate(self.test_dataloader):
                cell_prev = net.initialize_cell(1)
                cell_prev_2_prev = net.initialize_cell(1)
                review_tensor,category,sentiment = data['review'], data['category'], data['sentiment']
                input = torch.zeros(1,review_tensor.shape[2])
                hidden = torch.zeros(1, hidden_size)
                for k in range(review_tensor.shape[1]):
                    input[0,:] = review_tensor[0,k]
                    output, hidden, cell = net(input, hidden, cell_prev_2_prev)
                    # Shift the two-step cell history; at the first step there
                    # is no older value to move yet.
                    if k == 0:
                        cell_prev = cell
                    else:
                        cell_prev_2_prev = cell_prev
                        cell_prev = cell
                predicted_idx = torch.argmax(output).item()
                gt_idx = torch.argmax(sentiment).item()
                if i % 100 == 99:
                    print(" [i=%4d] predicted_label=%d gt_label=%d" % (i+1, predicted_idx,gt_idx))
                if predicted_idx == gt_idx:
                    classification_accuracy += 1
                if gt_idx == 0:
                    negative_total += 1
                elif gt_idx == 1:
                    positive_total += 1
                confusion_matrix[gt_idx,predicted_idx] += 1
    finally:
        net.train(was_training)
    out_percent = np.zeros((2,2), dtype='float')
    # Formatting to three decimals and parsing back rounds the percentages;
    # the float() makes the string-to-number conversion explicit.
    out_percent[0,0] = float("%.3f" % (100 * confusion_matrix[0,0] / float(negative_total)))
    out_percent[0,1] = float("%.3f" % (100 * confusion_matrix[0,1] / float(negative_total)))
    out_percent[1,0] = float("%.3f" % (100 * confusion_matrix[1,0] / float(positive_total)))
    out_percent[1,1] = float("%.3f" % (100 * confusion_matrix[1,1] / float(positive_total)))
    print("\n\nNumber of positive reviews tested: %d" % positive_total)
    print("\n\nNumber of negative reviews tested: %d" % negative_total)
    print("\n\nDisplaying the confusion matrix:\n")
    out_str = " "
    out_str += "%18s %18s" % ('predicted negative', 'predicted positive')
    print(out_str + "\n")
    for i,label in enumerate(['true negative', 'true positive']):
        out_str = "%12s: " % label
        for j in range(2):
            out_str += "%18s" % out_percent[i,j]
        print(out_str)
def run_code_for_testing_text_classification_with_gru(self, net, hidden_size):
    """
    Load the weights stored at self.dl_studio.path_saved_model into net,
    run it over self.test_dataloader and print a 2x2 confusion matrix
    (rows: true negative/positive, columns: predicted negative/positive)
    expressed as row percentages.

    Each review is fed one word at a time as a (1, 1, vocab) tensor.  Only
    element 0 of every batch is read, so the dataloader is expected to
    deliver one review per batch.  The network is switched to evaluation
    mode while scoring so that dropout does not perturb the predictions,
    and its previous mode is restored afterwards.

    Parameters:
        net          a GRUnet-style module providing init_hidden(batch_size)
                     and forward(x, h)
        hidden_size  not used here; the hidden state comes from
                     net.init_hidden()
    """
    net.load_state_dict(torch.load(self.dl_studio.path_saved_model))
    # NOTE(review): the number of correct predictions is tallied here but
    # is not reported anywhere below.
    classification_accuracy = 0.0
    negative_total = 0
    positive_total = 0
    confusion_matrix = torch.zeros(2,2)
    # torch.no_grad() does not disable dropout; only eval mode does.
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            for i, data in enumerate(self.test_dataloader):
                review_tensor,category,sentiment = data['review'], data['category'], data['sentiment']
                hidden = net.init_hidden(1)
                for k in range(review_tensor.shape[1]):
                    output, hidden = net(torch.unsqueeze(torch.unsqueeze(review_tensor[0,k],0),0), hidden)
                predicted_idx = torch.argmax(output).item()
                gt_idx = torch.argmax(sentiment).item()
                if i % 100 == 99:
                    print(" [i=%d] predicted_label=%d gt_label=%d\n\n" % (i+1, predicted_idx,gt_idx))
                if predicted_idx == gt_idx:
                    classification_accuracy += 1
                if gt_idx == 0:
                    negative_total += 1
                elif gt_idx == 1:
                    positive_total += 1
                confusion_matrix[gt_idx,predicted_idx] += 1
    finally:
        net.train(was_training)
    out_percent = np.zeros((2,2), dtype='float')
    # Formatting to three decimals and parsing back rounds the percentages;
    # the float() makes the string-to-number conversion explicit.
    out_percent[0,0] = float("%.3f" % (100 * confusion_matrix[0,0] / float(negative_total)))
    out_percent[0,1] = float("%.3f" % (100 * confusion_matrix[0,1] / float(negative_total)))
    out_percent[1,0] = float("%.3f" % (100 * confusion_matrix[1,0] / float(positive_total)))
    out_percent[1,1] = float("%.3f" % (100 * confusion_matrix[1,1] / float(positive_total)))
    print("\n\nNumber of positive reviews tested: %d" % positive_total)
    print("\n\nNumber of negative reviews tested: %d" % negative_total)
    print("\n\nDisplaying the confusion matrix:\n")
    out_str = " "
    out_str += "%18s %18s" % ('predicted negative', 'predicted positive')
    print(out_str + "\n")
    for i,label in enumerate(['true negative', 'true positive']):
        out_str = "%12s: " % label
        for j in range(2):
            out_str += "%18s" % out_percent[i,j]
        print(out_str)
def plot_loss(self):
    """Open a new matplotlib figure, draw the values in ``self.LOSS`` and show it."""
    plt.figure()
    loss_history = self.LOSS
    plt.plot(loss_history)
    plt.show()
#_________________________ End of DLStudio Class Definition ___________________________
#______________________________ Test code follows _________________________________
if __name__ == "__main__":
    # Intentionally empty: running this module directly does nothing.
    pass