# -*- coding: utf-8 -*-

__version__   = '1.0.7'
__author__    = "Avinash Kak (kak@purdue.edu)"
__date__      = '2025-May-29'   
__url__       = 'https://engineering.purdue.edu/kak/distGPT/babyGPT-1.0.7.html'
__copyright__ = "(C) 2025 Avinash Kak. Python Software Foundation."


import sys,os,os.path
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as tvt
import numpy as np
import math
import random
import string
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import time
import glob                                                                                                           
import json
import logging                        ## for suppressing matplotlib warning messages
import re
import itertools
import newspaper
from collections import Counter
from newspaper import Article
import blingfire as bling                     ## has the best sentence detector
from transformers import PreTrainedTokenizerFast
from tokenizers import Tokenizer                                                                                               
from tokenizers.pre_tokenizers import Whitespace 


#############################################################################################################################
################################################  Top level utility functions  ##############################################

import signal

def ctrl_c_handler( signum, frame ):
    """
    SIGINT handler: announce the interruption and terminate this process at once.

    Args:
        signum: number of the delivered signal (not used).
        frame:  the interrupted stack frame (not used).
    """
    print("Killed by Ctrl C")
    ##  signal.SIGKILL is only defined on Unix.  Where it is missing, fall back to
    ##  SIGTERM so that the handler itself does not fail with an AttributeError.
    kill_signal = getattr( signal, "SIGKILL", signal.SIGTERM )
    os.kill( os.getpid(), kill_signal )

##  Registered at import time, so Ctrl-C is handled this way as soon as the module is loaded:
signal.signal( signal.SIGINT, ctrl_c_handler )


def dev():
    """Return the torch device to compute on: "cuda:0" when CUDA is available, otherwise "cpu"."""
    device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
    return torch.device(device_name)

def gen(container):
    """
    Yield container[0], container[1], ... for as long as the index is below len(container).

    The length is re-read before every item, so elements appended to the container
    while the generator is being consumed are yielded as well.
    """
    for position in itertools.count():
        if not position < len(container):
            break
        yield container[position]


###%%%
#############################################################################################################################
#############################################   babyGPT Class Definition   ##################################################

class babyGPT(object):

    def __init__(self, *args, **kwargs ):
        if args:
            raise ValueError(  
                   '''babyGPT constructor can only be called with keyword arguments for 
                      the following keywords: urls, max_seq_length, batch_size, embedding_size, num_atten_heads, beta1, beta2, epsilon,
                      num_warmup_steps, masking, use_gpu, verify_text_corpus, path_saved_model''')
        max_seq_length=batch_size=embedding_size=num_atten_heads=beta1=beta2=epsilon=num_warmup_steps=masking=use_gpu=verify_text_corpus=None
        urls=path_saved_model=None

        if 'urls' in kwargs                          :   urls = kwargs.pop('urls')
        if 'max_seq_length' in kwargs                :   max_seq_length = kwargs.pop('max_seq_length')
        if 'batch_size' in kwargs                    :   batch_size = kwargs.pop('batch_size')
        if 'embedding_size' in kwargs                :   embedding_size = kwargs.pop('embedding_size')
        if 'num_atten_heads' in kwargs               :   num_atten_heads = kwargs.pop('num_atten_heads')
        if 'beta1' in kwargs                         :   beta1 = kwargs.pop('beta1')
        if 'beta2' in kwargs                         :   beta2 = kwargs.pop('beta2')
        if 'epsilon' in kwargs                       :   epsilon = kwargs.pop('epsilon')
        if 'num_warmup_steps' in kwargs              :   num_warmup_steps = kwargs.pop('num_warmup_steps')
        if 'masking' in kwargs                       :   masking = kwargs.pop('masking')
        if 'use_gpu' in kwargs                       :   use_gpu = kwargs.pop('use_gpu')
        if 'verify_text_corpus' in kwargs            :   verify_text_corpus = kwargs.pop('verify_text_corpus')
        if 'path_saved_model' in kwargs              :   path_saved_model = kwargs.pop('path_saved_model')

        if urls:
            self.urls = urls
        else:
            self.urls = None 
        if max_seq_length:                         
            self.max_seq_length = max_seq_length    
        if batch_size:
            self.batch_size = batch_size
        if embedding_size:
            self.embedding_size = embedding_size
        if num_atten_heads:
            self.num_atten_heads = num_atten_heads
        if beta1:
            self.beta1 = beta1
        if beta2:
            self.beta2 = beta2
        if epsilon:
            self.epsilon = epsilon
        if num_warmup_steps:
            self.num_warmup_steps = num_warmup_steps
        if masking:
            self.masking = masking     
        if verify_text_corpus:
            self.verify_text_corpus = verify_text_corpus
        else:
            self.verify_text_corpus = False
        if path_saved_model:
            self.path_saved_model = path_saved_model


    ###%%%
    #############################################################################################################################
    ######################################  Start Definition of Inner Class ArticleGatherer  ###################################

    class ArticleGatherer:
        """
        This script is for collecting data for experimenting with the Transformer based
        unsupervised learning code in baby_gpt.py.  

        The articles are downloaded from the URLs that are specified by the argument 'urls' in the
        constructor shown below.  See the script "create_base_model.py" in the Examples directory
        for how to set the URL strings for this argument.  Here are some examples:

            urls = ['https://finance.yahoo.com','http://cnn.com',
                     'https://timesofindia.indiatimes.com',
                     'https://purdueexponent.org','https://slate.com', 
                     'https://sports.yahoo.com']
    
            urls = ['http://cnn.com']
    
            urls = ['https://slate.com']
    
            urls = ['https://timesofindia.indiatimes.com']
    
        """
        def __init__(self, gpt, urls, articles_dir = 'saved_articles_dir'):
            ##  'urls' is a local array in which we store all the article URLs from where we want to 
            ##   download the news articles:
            self.urls = gpt.urls
            self.articles_dir = articles_dir

        def download_articles(self):
            if os.path.exists(self.articles_dir): 
                articles = glob.glob(self.articles_dir + "/*") 
                for file in articles:        
                    if os.path.isfile(file):       
                        os.remove(file)      
                    else:       
                        files = glob.glob(file + "/*")         
                        list(map(lambda x: os.remove(x), files))
            else:       
                os.mkdir(self.articles_dir)      
            master_list_article_links =  []
            for url in self.urls:
                print("\n\nDownloading from URL: %s\n\n" % url)
                scraped = newspaper.build( url, memoize_articles=False )
                for article_link in scraped.articles:
                    master_list_article_links.append( article_link.url )
                    print(article_link.url)
                print("\n\nThe number of available articles: ", scraped.size())
            print("\n\n\nHere is a dump of the article url's from all the news websites: ", master_list_article_links)
            print("\n\n\nTotal number of articles in the dump: ", len(master_list_article_links) )

            article_index = 0
            for item_url in master_list_article_links:
                if not item_url.endswith(".html"):
                     continue
                article_file_name =  self.articles_dir + "/" +  "article_" + str(article_index) + ".txt"
                FILE = open(article_file_name, 'w')
                try:
                    article = Article(item_url)
                    article.download()
                    article.parse()
                except:
                    continue
                print("downloaded ", article_file_name)
                text = article.text
                FILE.write(text)
                FILE.flush()
                FILE.close()
                article_index += 1
    


    ###%%%
    #############################################################################################################################
    ######################################  Start Definition of Inner Class TrainTokenizer  #####################################

    class TrainTokenizer:
        """Tokenizers play a critical role in language modeling because they create a
        fixed-sized vocabulary for the corpus you are working with --- regardless of
        the size of the corpus itself.  Unless your text corpus is based on a set of
        documents frozen in time, ordinarily, as the size of a text corpus goes up,
        so does the size of the vocabulary --- despite the illusion to the contrary
        created by the fixed sizes of the language dictionaries you have seen all
        your life.  How we express ourselves is a living thing.  We are constantly
        inventing new words and new expressions; these form important components of
        what's referred to as the zeitgeist.

        Having a fixed-sized vocab is important because the loss functions used in
        deep-learning network used for language processing are based on
        maximum-likelihood prediction of the next token given the tokens seen
        previously.  That requires estimating the probabilities associated with all
        possible tokens at the next position.  As you can imagine, it would be
        impossible to engage in such probabilistic reasoning if you did not know in
        advance the size of the vocabulary.

        """
        def __init__(self, corpus_directory, target_vocab_size=50000):
            import babyGPT
            version_str =  babyGPT.__version__
            version_str = version_str.replace(".", "")
            self.tokenizer_json_stem   =   version_str + "_babygpt_tokenizer_"
            self.corpus_dir = corpus_directory
            self.unk_token = "[UNK]"                    # token for undecipherable bytes
            self.spl_tokens = ["<UNK>", "<SEP>", "<MASK>", "<CLS>"] 
            self.target_vocab_size = target_vocab_size

            ##  Since we are assuming utf-8 encoding of the text corpus, we already have
            ##  the mappings between the numbers 0 through 255 and their corresponding
            ##  tokens as would be yielded by calling the Python function chr() on the
            ##  integers between 0 and 255.  [For example, chr(255) returns the character
            ##  'ΓΏ'. What that means is that 255 is the Unicode code point for this symbol.]
            ##  So the first available index for a new token produced by the merge rule 
            ##  would be 256:
            self.next_index_available = 256
            ##  I use "testing_iter" for producing the intermediate results during training:
            self.testing_iter = 0                         

      
        def train_tokenizer(self):
            """
            Tokenization steps: 

            -- Start with a base vocabulary for the tokens that consists of all 256 integer
               values that can be taken by a byte.

            -- Search through consecutively occurring numeric codes for the Unicode bytes to 
               find the pair that is the most frequent

            -- Replace all such pairs of the more elementary tokens with the new token for all
               the words
           
            -- Apply the logic described above iteratively until the size of the tokenizer
               vocab has reached the prescribed value.

            Note that the size of the tokenizer vocabulary is sum of the size of the Base Vocab
            and the number of merges.  The Base Vocab consists of the unique individual
            characters in the training dataset
            """

            def word_as_num_seq(word):
                for char in list(word):
                    if char not in char_to_num_dict:
                        char_to_num_dict[char] = self.next_index_available
                        merge_rules_dict[ self.next_index_available ] = char
                        self.next_index_available += 1       
                return [char_to_num_dict[char] for char in list(word) ]
            
            def get_str_token( num ):
                """
                Note that merge_rules_dict is what becomes the vocab eventually.  We make the
                conversion by reversing the <key,num> pairs in merge_rules_dict.
                """
                if num in num_to_char_dict:
                    return num_to_char_dict[num]
                elif num in merge_rules_dict:
                    return merge_rules_dict[num]
                else:
                    sys.exit("\n\n[get_str_token]  merge_rules_dict has no merge rule for the int token %d\n\n" % num)
            
            def subword_for_num_seq( num_seq ):
                subword = ""
                for num in num_seq:
                    if num in num_to_char_dict:
                        subword += chr(num)
                    elif num in merge_rules_dict:
                        subword += merge_rules_dict[num]
                    else:
                        sys.exit("\n\n[subword_for_num_seq] merge_rules_dict has no merge rule for the int token %d\n\n" % num)
                return subword
            
            def update_tokenizer_dict( tokenizer_dict, most_frequent_pair, new_token_as_num ):
                new_tokenizer_dict = {word : [] for word in tokenizer_dict}
                for word in tokenizer_dict:
                    str_rep = ",".join(str(i) for i in tokenizer_dict[word])
                    to_be_replaced_pair =  r"\b" +  ",".join(str(i) for i in most_frequent_pair) + r"\b"
                    replacement = str(new_token_as_num) 
                    output_str= re.sub(to_be_replaced_pair, replacement, str_rep)
                    new_tokenizer_dict[word]  =  [int(i) for i in output_str.split(",")]
                return new_tokenizer_dict
            
            
            def find_best_ngram_and_update_word_tokens_dict(tokenizer_dict):
                all_consec_pairs_dict = { word : list( zip( tokenizer_dict[word], tokenizer_dict[word][1:] ) ) for word in tokenizer_dict }
                all_consec_triples_dict =   { word : list( zip( tokenizer_dict[word], tokenizer_dict[word][1:],  tokenizer_dict[word][2:] ) ) 
                                                                                                                     for word in tokenizer_dict }
                all_consec_quads_dict   =   { word : list( zip( tokenizer_dict[word], tokenizer_dict[word][1:],  tokenizer_dict[word][2:], 
                                                                                        tokenizer_dict[word][3:] ) ) for word in tokenizer_dict }   
                all_consec_all_ngrams_dict = {}
                for word in all_consec_pairs_dict:
                    if word in all_consec_triples_dict and  word in all_consec_quads_dict:
                        all_consec_all_ngrams_dict[word]  =  all_consec_pairs_dict[word] + all_consec_triples_dict[word] + all_consec_quads_dict[word]
                    elif word in all_consec_triples_dict:
                        all_consec_all_ngrams_dict[word]  =  all_consec_pairs_dict[word] + all_consec_triples_dict[word]
                    else:
                        all_consec_all_ngrams_dict[word]  =  all_consec_pairs_dict[word]
                all_consec_all_ngrams_dict  =   {word : all_consec_all_ngrams_dict[word] for word in all_consec_all_ngrams_dict 
                                                                                                      if len(all_consec_all_ngrams_dict[word]) > 0}
                most_frequent_ngram = list(Counter( list( itertools.chain(*all_consec_all_ngrams_dict.values()) ) ).keys()) [0]
                string_for_merges_array = "%s %s" % (get_str_token(most_frequent_ngram[0]), get_str_token(most_frequent_ngram[1]))
                merges.append( string_for_merges_array )
                subword_for_most_frequent_ngram  =  subword_for_num_seq( most_frequent_ngram )
                if self.testing_iter % 100 == 0:
                    print("\n\n[testing_iter: %d] Will merge the following subwords for the new most frequently occurring subword:" % self.testing_iter)
                    if len(most_frequent_ngram) == 2:
                        print("%s    %s" % (get_str_token(most_frequent_ngram[0]), get_str_token(most_frequent_ngram[1])))
                    elif len(most_frequent_ngram) == 3:
                        print("%s    %s    %s" % (get_str_token(most_frequent_ngram[0]), get_str_token(most_frequent_ngram[1]),  
                                                                                         get_str_token(most_frequent_ngram[2] )))        
                    else:
                        print("%s    %s    %s    %s" % (get_str_token(most_frequent_ngram[0]), get_str_token(most_frequent_ngram[1]),  
                                                        get_str_token(most_frequent_ngram[2]), get_str_token(most_frequent_ngram[3]) ))
                    print("\n\nAdding to tokenizer vocab: ",  subword_for_most_frequent_ngram)
                merge_rules_dict[self.next_index_available] = subword_for_most_frequent_ngram
                new_tokenizer_dict = update_tokenizer_dict( tokenizer_dict, most_frequent_ngram, self.next_index_available )
                if self.testing_iter % 100 == 0:
                    print("\n\n[testing_iter: %d] UPDATED tokenizer dict:\n" % self.testing_iter)
                    for word in new_tokenizer_dict:
                        print("%s  =>  %s" % (word, str( [get_str_token(i) for i in new_tokenizer_dict[word]] )))
                self.next_index_available += 1
                return new_tokenizer_dict
            seed_value = 0
            random.seed(seed_value)
            os.environ['PYTHONHASHSEED'] = str(seed_value)
            dir_textfiles =  self.corpus_dir
            ##  The dict defined in the next statement stores the mappings from the symbolic tokens to integers that represent 
            ##  them. For the number range 0 through 255, the mappings stored are those that are returned by calling chr() on 
            ##  the Unicode numbers between 0 and 255. Subsequently, as larger tokens are constructed by merging the "sub-word" 
            ##  tokens, we add those tokens and their associated numbers to this dict.   
            char_to_num_dict = { chr(num) :  num for num in range(256) }
            num_to_char_dict = { num : chr(num) for num in range(256) }
            merge_rules_dict = { i : "" for i in range(256, self.target_vocab_size) }
            ##  I store all pairwise merges in the following array.  Each element of this array is a string 
            ##  that looks like  "str1 str2" where str1 and str2 are the two subwords that are to be merged together.
            merges = []                            
            text = ""
            ##  Data text data from file. Note that using errors='ignore' may NOT be the right option for opening a file:  
            ##  https://stackoverflow.com/questions/45529507/unicodedecodeerror-utf-8-codec-cant-decode-byte-0x96-in-position-35-invalid
            if os.path.exists(dir_textfiles):
                    textfiles = glob.glob(dir_textfiles + "/*")
                    print("\n\nNumber of text files: ", len(textfiles))
                    for filedoc in textfiles:
                        if os.path.isfile(filedoc):
                            with open( filedoc, encoding='utf8', errors='ignore' ) as f:
                                text += f.read()
            print("\n\nlength of the text string: ", len(text))
            ##  We will store the merged char mappings for the new tokens in this dictionary
            merged_symbols_dict = {num : None for num in range(256, self.target_vocab_size) } 
            
            all_words = text.split()
            print("\n\nNumber of words in the list 'all_words': ", len(all_words))
            print("\n\nfirst 100 entries in all_words: ", all_words[:100])
            ##  We need the word frequencies BECAUSE we need to find the most frequently occurring token pair 
            ##  in the corpus.  That is, for a given token pair, we need to know the number of words in which 
            ##  that pair occurs.
            words_with_counts = Counter(all_words)
            unique_words = list(set( all_words ))
            print("\n\nnumber of UNIQUE words: ", len(unique_words))
            print("\nfirst 100 UNIQUE words: ", unique_words[:100])
            word_tokens_dict =  { word : word_as_num_seq(word) for word in unique_words }                     ##  Initialization of word_tokens_dict
            print("\n\nIterative learning of the merge rules:\n\n")
            for i in range(256): 
                merge_rules_dict[i] = chr(i)           ## the char returned by the function chr(i) is the char under utf-8 encoding
            while self.next_index_available <= self.target_vocab_size:
                self.testing_iter += 1
                new_word_tokens_dict = find_best_ngram_and_update_word_tokens_dict( word_tokens_dict )
                if self.testing_iter % 100 == 0:
                    print("\n\n[testing_iter = %d] Size of the tokenizer vocab: " % self.testing_iter,  self.next_index_available-1) 
                word_tokens_dict = new_word_tokens_dict
#                if self.testing_iter % 10000 == 0:
                if self.testing_iter % 5000 == 0:
                    FILE = open("merge_rules_dictionary_" +  str(self.testing_iter) + ".txt", 'w')
                    for i in merge_rules_dict: 
                        FILE.write("%d       =>       %s\n" % (i, merge_rules_dict[i]))
                    merge_rules_dict[self.target_vocab_size + 1] = "<UNK>"
                    vocab = {val : key for (key,val) in merge_rules_dict.items()}
                    print("\n\n[testing_iter: %d] vocab: " % self.testing_iter, vocab)
                    print("\n\n[testing_iter: %d] merges array:" % self.testing_iter, merges)
                    vocab_and_merges =  {"version" : "1.0", 
                                         "truncation" : None,
                                         "padding" : None,
                                         "added_tokens" : [
                                              {"id" : self.target_vocab_size+1, 
                                               "content" : "<UNK>",
                                               "single_word": False,  
                                               "lstrip": False,
                                               "rstrip": False, 
                                               "normalized": False, 
                                               "special": True,
                                              },
                                         ],
                                         "normalizer": None,
                                         "pre_tokenizer": {
                                             "type": "Whitespace"
                                         },  
                                         "model" :  {"type": "BPE", "dropout" : None, "vocab" :  vocab,  "merges" : merges } }
                    with open(self.tokenizer_json_stem + str(self.testing_iter) + ".json", "w") as outfile:
                        json.dump(vocab_and_merges, outfile, indent=4)
            FILE = open("merge_rules_dictionary_" +  str(self.testing_iter) + ".txt", 'w')
            for i in merge_rules_dict: 
                FILE.write("%d       =>       %s\n" % (i, merge_rules_dict[i]))
            merge_rules_dict[self.target_vocab_size + 1] = "<UNK>"
            vocab = {val : key for (key,val) in merge_rules_dict.items()}
            print("\n\nvocab: ", vocab)
            print("\n\nmerges array:", merges)
            vocab_and_merges =  {"version" : "1.0", 
                                 "truncation" : None,
                                 "padding" : None,
                                 "added_tokens" : [
                                      {"id" : self.target_vocab_size+1, 
                                       "content" : "<UNK>",
                                       "single_word": False,  
                                       "lstrip": False,
                                       "rstrip": False, 
                                       "normalized": False, 
                                       "special": True,
                                      },
                                 ],
                                 "normalizer": None,
                                 "pre_tokenizer": {
                                     "type": "Whitespace"
                                 },  
                                 "model" :  {"type": "BPE", "dropout" : None, "vocab" :  vocab,  "merges" : merges } }
            with open(self.tokenizer_json_stem + str(self.testing_iter) + ".json", "w") as outfile:
                json.dump(vocab_and_merges, outfile, indent=4)
            


    ###%%%
    #############################################################################################################################
    #############################  Start Definition of Inner Class ArticleDatasetWithBufferedContext  ###########################

    class ArticleDatasetWithBufferedContext:    
        """
        The parameter 'context_window_size' is related to how many tokens you can feed into the
        transformer at one iteration as the training corpus is being scanned.  In my Week 14 lecture
        on Transformers, I used the notation 'max_seq_len' for this parameter.

        """
        def __init__(self, gpt, tokenizer_json, context_window_size, context_buffer_size=7, articles_dir='saved_articles_dir'):

            if os.path.exists(articles_dir): 
                num_articles = len(glob.glob(articles_dir + "/*")) 
                if gpt.verify_text_corpus:
                    if num_articles == 0:
                        sys.exit("\n\nAborting --- You have no articles in the articles directory.  You may need to first use the ArticleGatherer")
                    ans = input("\n\nYou have %d articles in the articles directory. Continue? Enter 'y' if yes: " % num_articles)
                    ans = ans.strip()
                    if ans != ('y' or 'yes'): 
                        print("\n\nPlease run the 'run_gatherer()' function to gather the news articles.\n\n")
            else:
                sys.exit("\n\nAborting --- Your articles directory %s does not exist." % articles_dir)
            print("\n\nThe Dataloader will be applied to the previously collected trove of articles in %s." % articles_dir)
            print()
            self.dir_collected_articles = articles_dir
            self.num_articles = num_articles
            self.context_buffer_size = context_buffer_size
            ## The file named below must be a json file created by a tokenizer training routine:
            self.tokenizer_json = tokenizer_json                       
            self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=self.tokenizer_json)
            FILE = open(self.tokenizer_json)    
            tokenizer_dict =  json.load( FILE ) 
            self.batch_size = gpt.batch_size
            self.context_window_size = context_window_size
            self.inverse_lookup  =  {v:k for k,v in tokenizer_dict['model']['vocab'].items()}  
            self.articles = []
            self.articles_for_batch_instances = []
            self.encoded_streams = {}              ## A dict whose keys are batch instance indexes            
            self.all_encoded_streams  =   []       ## A list of the values in the above dict
            self.iteration_index = 0               ## This value is reset to 0 at the beginning of each new epoch
            self.epoch_index = 0
            self.datastreams_initialized = False

        def generate_article_streams(self):
            debug = False
            def gen(container):
                j = 0
                while j < len(container):
                    yield container[j]
                    j += 1
            random.shuffle(self.articles)
            self.articles_for_batch_instances = [self.articles[i:i+len(self.articles)//self.batch_size] for i in range(self.batch_size)]
            self.encoded_streams =  []
            ## Create a stream of encoding for each batch instance
            for i in  range(self.batch_size):
                article_gen = gen( self.articles_for_batch_instances[i] )
                encoded_stream = [] 
                for article in article_gen:
                    if article is None: break
                    FILE = open(article)
                    text = FILE.read()
                    if debug:
                        encoded = self.tokenizer.encode(text)
                        print("\n\n\ntext in article: ", text)
                        print("after tokenization and encoding: ", encoded)
                        the_tokens = [self.inverse_lookup[code] for code in encoded]
                        print("the individual tokens: ", the_tokens)
                    encoded_stream += self.tokenizer.encode(text)
                self.encoded_streams.append( encoded_stream )
    

        def generate_article_sequences_for_batch_instances(self):
            """
            "equalization" here means that we want all the streams AS EQUAL IN LENGTH AS POSSIBLE
            based on N different attempts at article randomization.  Highly unequal stream lengths 
            can make GPT learning inefficient --- and sometimes impossible.
            """
            debug = False
            ## We need to find the total number of tokens in all the articles in our corpus.  Subsequently,
            ## when we partition the corpus into sub-corpora, with one sub-corpus for each batch instance,
            ## we want to make sure that the total number of tokens available for the token-stream created
            ## for each batch instance is roughly the same.
            article_sizes = { article : None for article in self.articles }  ## size is measured in terms of the number of tokens
            master_article_gen = gen(self.articles)
            total_num_tokens = 0 
            for article in master_article_gen:
                FILE = open(article)
                text = FILE.read()
                article_tokens = self.tokenizer.encode( text )
                article_sizes[article] = len(article_tokens) 
                total_num_tokens += len(article_tokens)

            if debug:
                print("\n\narticle sizes: ", article_sizes)
                print("\n\ntotal_num_tokens: ", total_num_tokens)
                print("\n\n\n")

            ##  Now we want to assign articles to each batch instance in such a way that the total number
            ##  of tokens assigned to a batch instance is approximately the same for batch instances. I am
            ##  going to use the followings dicts for this logic:
            num_tokens_per_batch_instance = total_num_tokens // self.batch_size
            article_sequence_for_batch_instance = {i : [] for i in range(self.batch_size)}           ## The sub-corpora of articles
            char_stream_size_for_batch_instance = {i : 0 for i in range(self.batch_size)}          ## The token stream for each sub-corpus

            ##  Now we are ready to create a sub-corpus for each batch instance. Each sub-corpus will eventually 
            ##  be turned into a token stream.  The epoch-to-epoch randomization of the input data would consist
            ##  of randomizing the sequence of articles (meaning, the order in which the articles appear) in
            ##  each sub-corpus.
            for article in article_sizes:
                ##  This is a variant of the heuristic algorithms used commonly for solving the combinatorial NP-Hard BIN 
                ##  PACKING Optimization problem in which the object are placed in unit-sized bins so as to minimize the
                ##  bins used.  The heuristic I have used here is to assign an article to that sub-corpus that currently
                ##  has the least total number of tokens in it. REMEMBER we measure the size of an article in terms of the 
                ##  number of tokens needed for that article.
                smallest_idx =  (sorted(char_stream_size_for_batch_instance, key=char_stream_size_for_batch_instance.get ))[0]
                article_sequence_for_batch_instance[smallest_idx].append(article)
                char_stream_size_for_batch_instance[smallest_idx] += article_sizes[article]
            ##  Let's now check we did a good job of roughly equalizing the number of tokens for each sub-corpus:
            for i in  range(self.batch_size):
                total_num_tokens = 0 
                article_gen = gen(article_sequence_for_batch_instance[i])
                for article in article_gen:
                    FILE = open(article)
                    text = FILE.read()
                    article_tokens = self.tokenizer.encode( text )
                    article_sizes[article] = len(article_tokens) 
                    total_num_tokens += len(article_tokens)
         
            self.article_sequence_for_batch_instance = article_sequence_for_batch_instance

        def generate_token_streams_for_batch_instances(self):
            debug = False 
            article_sequence_for_batch_instance  = self.article_sequence_for_batch_instance
            for seq_idx in article_sequence_for_batch_instance:
                random.shuffle( article_sequence_for_batch_instance[seq_idx] )          ## randomization at the beginning of each epoch
            ## Create a stream of encoding for each batch instance
            self.encoded_streams =  {i : [] for i in range(self.batch_size)}
            for i in  range(self.batch_size):
                article_gen = gen(article_sequence_for_batch_instance[i])
                for article in article_gen:
                    FILE = open(article)
                    text = FILE.read()
                    ## Change made on Jan 29, 2025. Insert underscore between the words to help out with the detokenization step:
                    all_words = text.split()  
                    all_words = [word + " _" if re.search(r'.*[\w]$', word) else word for word in all_words] 
                    text = ' '.join(all_words)
                    article_tokens = self.tokenizer.encode( text )
                    self.encoded_streams[i] += article_tokens
            ## Now let's check the difference in length between the longest batch-instance stream
            ## and the shortest batch-instance stream:
            self.all_encoded_streams = list(self.encoded_streams.values())
            shortest_encoded_stream = min(self.all_encoded_streams, key=lambda x: len(x))
            longest_encoded_stream = max(self.all_encoded_streams, key=lambda x: len(x))
            stream_len_disparity =  len(longest_encoded_stream)  -  len(shortest_encoded_stream) 
            if debug:
                print("\n\nlength of the shortest stream: ", len(shortest_encoded_stream))
                print("length of the longest stream: ", len(longest_encoded_stream))
                print("value of stream_len_disparity: ", stream_len_disparity)


        def initialize_tokenized_data_streams(self):
            if self.datastreams_initialized == False:
                self.articles = glob.glob(self.dir_collected_articles + "/*")               
                self.generate_article_sequences_for_batch_instances()
                self.generate_token_streams_for_batch_instances()
                self.datastreams_initialized = True

        def dataloader_for_buffered_context(self, how_many):
            """
            The argument "how_many" means the size of the context_window_size that is specified in the call to the 
            constructor of ArticleDatasetWithBufferedContext. 

            This function returns a batch of token sequences on each call.  A batch is constructing by pulling the token sequences 
            for each batch instance from the 'batch_size' number of token streams created in the constructor of the 'Ddataloader'
            class. When that process gets too close the end of the shortest of the 'batch_size' number of streams, the articles 
            are randomized again for assignment to the individual batch-instance streams.

            The variable   self.iteration_index  keeps track of where the downloader is in each batch-instance stream as feed data
            one batch at a time into the Transformer.
            """
            debug = False
            batch_size = self.batch_size
            context_window_size = how_many
            cws_minus_one = context_window_size - 1
            codes_for_SOS = [89, 90, 91, 92, 93, 94, 96, 97, 98]

            if any( len( self.all_encoded_streams[i][self.iteration_index*cws_minus_one : ] )  < cws_minus_one for i in range(batch_size) ):
                self.epoch_index += 1
                print("\n\nStarting epoch: %d\n" % (self.epoch_index + 1))
                self.iteration_index = 0

            ## self.iteration_index == 0  means we are starting a new epoch
            if self.datastreams_initialized and self.iteration_index == 0:
                self.articles = glob.glob(self.dir_collected_articles + "/*")               
                self.generate_article_sequences_for_batch_instances()
                self.generate_token_streams_for_batch_instances()

            out = np.zeros(shape=(batch_size, context_window_size), dtype=int)

            for i in range(batch_size):
                out[i,1:] =  self.all_encoded_streams[i][self.iteration_index*cws_minus_one :  (self.iteration_index+1) * cws_minus_one]
                out[i,0] = 89
            self.iteration_index  += 1
            return out

        def test_dataloader(self, how_many):
            data = self.dataloader_for_buffered_context(how_many)
            print("\n\n\nshape of the data returned by the dataloader: ", data.shape)
            print("\n\ndata returned by the dataloader:")
            print(data)
            tokens = [[self.inverse_lookup[code] for code in data[i]] for i in range(self.batch_size)]
            print(tokens)
            
            data = self.dataloader_for_buffered_context(how_many)
            print("\n\n\nshape of the data returned by the dataloader: ", data.shape)
            print("\n\ndata returned by the dataloader:")
            print(data)
            tokens = [[self.inverse_lookup[code] for code in data[i]]  for i in range(self.batch_size)]
            print(tokens)
            

        def display_token_vocab(self):  
            for code in self.inverse_lookup:
                print("%d        =>       %s" % (code , str( self.inverse_lookup[code] ) ) )
            

    ###%%%
    #############################################################################################################################
    ########################################  Start Definition of Inner Class TransformerFG  ####################################

    class TransformerFG(nn.Module):             
        """
        I have borrowed from the DLStudio's Transformers module.  "FG" stands for "First Generation" --- which is the Transformer
        as originally proposed by Vaswani et al.
        """
        def __init__(self, max_seq_length, embedding_size, tokenizer_json, num_warmup_steps=None, optimizer_params=None):
            super(babyGPT.TransformerFG, self).__init__()
            self.max_seq_length = max_seq_length
            self.embedding_size = embedding_size
            self.num_warmup_steps = num_warmup_steps
            self.optimizer_params = optimizer_params
            self.tokenizer_json = tokenizer_json                       
            self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=self.tokenizer_json)
            FILE = open(self.tokenizer_json)    
            tokenizer_dict =  json.load( FILE ) 
            self.inverse_lookup  =  {v:k for k,v in tokenizer_dict['model']['vocab'].items()}  
            self.vocab_size = self.tokenizer.vocab_size
    
        def sentence_with_words_to_ints(self, sentences, lang):
            sentence_to_ints = torch.ones(len(sentences), self.max_seq_length, dtype=torch.long)
            for i in range(len(sentences)):
                words = sentences[i].split(' ')
                for j,word in enumerate(words):
                    sentence_to_ints[i,j] = self.en_vocab_dict[word] if lang=="en" else self.es_vocab_dict[word]
            return sentence_to_ints
    
    class EmbeddingGenerator(nn.Module):
        def __init__(self, xformer, embedding_size):
            super(babyGPT.EmbeddingGenerator, self).__init__()
            tokenizer = PreTrainedTokenizerFast(tokenizer_file=xformer.tokenizer_json)
            self.vocab_size =  xformer.vocab_size
            self.embedding_size = embedding_size                                             
            self.max_seq_length = xformer.max_seq_length                                                     
            self.embed = nn.Embedding(self.vocab_size, embedding_size)

        def forward(self, sentence_tensor):                                                                 
            sentence_tensor = sentence_tensor
            ## Let's say your batch_size is 4 and that each sentence has a max_seq_length of 10.
            ## The sentence_tensor argument will now be of shape [4,10].  If the embedding size is
            ## is 512, the following call will return a tensor of shape [4,10,512)
            word_embeddings = self.embed(sentence_tensor)
            position_coded_word_embeddings = self.apply_positional_encoding( word_embeddings )
            return position_coded_word_embeddings

        def apply_positional_encoding(self, sentence_tensor):
            position_encodings = torch.zeros_like( sentence_tensor ).float()
            ## Calling unsqueeze() with arg 1 causes the "row tensor" to turn into a "column tensor"
            ##    which is needed in the products shown below. We create a 2D pattern by 
            ##    taking advantage of how PyTorch has overloaded the definition of the infix '*' 
            ##    tensor-tensor multiplication operator.  It in effect creates an output-product of
            ##    of what is essentially a column vector with what is essentially a row vector.
            word_positions = torch.arange(0, self.max_seq_length).unsqueeze(1)            
            div_term =  1.0 / (100.0 ** ( 2.0 * torch.arange(0, self.embedding_size, 2) / float(self.embedding_size) ))
            position_encodings[:, :, 0::2] =  torch.sin(word_positions * div_term)                             
            position_encodings[:, :, 1::2] =  torch.cos(word_positions * div_term)                             
            return sentence_tensor + position_encodings

    ###%%%
    #######################################################################################################################
    ###################################  Self Attention Code for TransformerFG  ###########################################

    class SelfAttention(nn.Module):
        """
        Borrowed from the Transformers module of DLStudio
        """  
        def __init__(self, xformer, num_atten_heads):
            super(babyGPT.SelfAttention, self).__init__()
            self.max_seq_length = xformer.max_seq_length                                                     
            self.embedding_size = xformer.embedding_size
            self.num_atten_heads = num_atten_heads
            self.qkv_size = self.embedding_size // num_atten_heads
            self.attention_heads_arr = nn.ModuleList( [babyGPT.AttentionHead(self.max_seq_length, 
                                    self.qkv_size, num_atten_heads)  for _ in range(num_atten_heads)] )           

        def forward(self, sentence_tensor):                                                                       
            concat_out_from_atten_heads = torch.zeros( sentence_tensor.shape[0], self.max_seq_length, 
                                                                  self.num_atten_heads * self.qkv_size).float()
            for i in range(self.num_atten_heads):                                                                 
                sentence_embed_slice = sentence_tensor[:, :, i * self.qkv_size : (i+1) * self.qkv_size]
                concat_out_from_atten_heads[:, :, i * self.qkv_size : (i+1) * self.qkv_size] =          \
                                                               self.attention_heads_arr[i](sentence_embed_slice)   
            return concat_out_from_atten_heads


    class AttentionHead(nn.Module):
        """
        Borrowed from the Transformers module of DLStudio
        """  
        def __init__(self,  max_seq_length, qkv_size, num_atten_heads):
            super(babyGPT.AttentionHead, self).__init__()
            self.qkv_size = qkv_size
            self.max_seq_length = max_seq_length
            self.WQ =  nn.Linear( self.qkv_size, self.qkv_size )                                                      
            self.WK =  nn.Linear( self.qkv_size, self.qkv_size )                                                      
            self.WV =  nn.Linear( self.qkv_size, self.qkv_size )                                                      
            self.softmax = nn.Softmax(dim=-1)                                                                          

        def forward(self, sent_embed_slice):           ## sent_embed_slice == sentence_embedding_slice                
            Q = self.WQ( sent_embed_slice )                                                                           
            K = self.WK( sent_embed_slice )                                                                           
            V = self.WV( sent_embed_slice )                                                                           
            A = K.transpose(2,1)                                                                                      
            QK_dot_prod = Q @ A                                                                                       
            rowwise_softmax_normalizations = self.softmax( QK_dot_prod )                                              
            Z = rowwise_softmax_normalizations @ V                                                                    
            coeff = 1.0/torch.sqrt(torch.tensor([self.qkv_size]).float()).to(dev())                
            Z = coeff * Z                                                                          
            return Z


    ###%%%
    #######################################################################################################################
    #########################################  Basic Decoder Class for TransformerFG  #####################################

    class BasicDecoderWithMasking(nn.Module):
        """
        Borrowed from the Transformers module of DLStudio
        """  
        def __init__(self, xformer, num_atten_heads, masking=True):
            super(babyGPT.BasicDecoderWithMasking, self).__init__()
            self.masking = masking
            self.embedding_size = xformer.embedding_size                                             
            self.max_seq_length = xformer.max_seq_length                                                     
            self.num_atten_heads = num_atten_heads
            self.qkv_size = self.embedding_size // num_atten_heads
            self.self_attention_layer = babyGPT.SelfAttention(xformer, num_atten_heads)
            self.norm1 = nn.LayerNorm(self.embedding_size)
            self.norm2 = nn.LayerNorm(self.embedding_size)
            ## What follows are the linear layers for the FFN (Feed Forward Network) part of a BasicDecoder
            self.W1 =  nn.Linear( self.embedding_size, 4 * self.embedding_size )
            self.W2 =  nn.Linear( 4 * self.embedding_size, self.embedding_size ) 
            self.norm3 = nn.LayerNorm(self.embedding_size)

        def forward(self, sentence_tensor, mask):   
            masked_sentence_tensor = self.apply_mask(sentence_tensor, mask)
            Z_concatenated = self.self_attention_layer(masked_sentence_tensor).to(dev())
            Z_out = self.norm1(Z_concatenated + masked_sentence_tensor)                     
            ## for FFN:
            basic_decoder_out =  nn.ReLU()(self.W1( Z_out.view( sentence_tensor.shape[0], self.max_seq_length, -1) ))                  
            basic_decoder_out =  self.W2( basic_decoder_out )                                                    
            basic_decoder_out = basic_decoder_out.view(sentence_tensor.shape[0], self.max_seq_length, self.embedding_size )
            basic_decoder_out =  basic_decoder_out  + Z_out 
            basic_decoder_out = self.norm3( basic_decoder_out )
            return basic_decoder_out

        def apply_mask(self, sentence_tensor, mask):  
            out = torch.zeros_like(sentence_tensor).float().to(dev())
            out[:,:len(mask),:] = sentence_tensor[:,:len(mask),:] 
            return out    



    ###%%%
    #######################################################################################################################
    ######################################  MasterDecoder Class for TransformerFG #########################################

    class MasterDecoderWithMasking(nn.Module):
        """
        Borrowed from the Transformers module of DLStudio
        """  
        def __init__(self, xformer, num_basic_decoders, num_atten_heads, masking=True):
            super(babyGPT.MasterDecoderWithMasking, self).__init__()
            self.masking = masking
            self.max_seq_length = xformer.max_seq_length
            self.embedding_size = xformer.embedding_size
            self.vocab_size = xformer.vocab_size                                             
            self.basic_decoder_arr = nn.ModuleList([babyGPT.BasicDecoderWithMasking( xformer,
                                                    num_atten_heads, masking) for _ in range(num_basic_decoders)])  
            ##  Need the following layer because we want the prediction of each target word to be a probability 
            ##  distribution over the target vocabulary. The conversion to probs would be done by the criterion 
            ##  nn.CrossEntropyLoss in the training loop:
            self.out = nn.Linear(self.embedding_size, self.vocab_size)                                          

        def forward(self, sentence_tensor, mask):                                                   
            out_tensor = sentence_tensor
            for i in range(len(self.basic_decoder_arr)):                                                 
                out_tensor = self.basic_decoder_arr[i](out_tensor, mask)                              
            word_index = mask.shape[0]
            last_word_tensor = out_tensor[:,word_index]                                      
            last_word_onehot = self.out(last_word_tensor)        
            output_word_logprobs = nn.LogSoftmax(dim=1)(last_word_onehot)                                     
            _, idx_max = torch.max(output_word_logprobs, 1)                
            ## the logprobs are over the entire vocabulary of the tokenizer
            return output_word_logprobs, idx_max


        def apply_mask(self, sentence_tensor, mask):  
            out = torch.zeros_like(sentence_tensor).float().to(dev())
            out[:,:len(mask),:] = sentence_tensor[:,:len(mask),:] 
            return out    


    ###%%%
    #######################################################################################################################
    ############################################### Training babyGPT  #####################################################

    def save_decoder(self, decoder):
        "Save the trained decoder to a disk file"       
        torch.save(decoder.state_dict(), self.gpt.path_saved_model["saved_decoder"])

    def save_embedding_generator(self, embedding_generator):
        torch.save(embedding_generator.state_dict(), self.gpt.path_saved_model["saved_embeddings_generator"])

    def save_checkpoint_decoder(self, decoder, dir_name, iter_index):
        "Save the decoder checkpoint"       
        torch.save(decoder.state_dict(), dir_name + "/saved_decoder_" + str(iter_index))

    def save_checkpoint_embedding_generator(self, embedding_generator, dir_name, iter_index):
        "save checkpoint for the embedding_generator"
        torch.save(embedding_generator.state_dict(), dir_name + "/saved_embedding_generator_" + str(iter_index))        


    def run_code_with_buffered_context_for_training_TransformerFG(self, xformer, master_decoder, dataloader, 
                                                           checkpoint_frequency=1000, display_train_loss=False ):
        """
        Drawn from the training routines in the Transformer module of DLStudio
        """
        def detokenizer( token_sequence_as_string ):
            regex = r'\s_\s'
            out_words = ""
            try:
                out_words = re.split(regex, token_sequence_as_string)
            except TypeError as e:
                print(e)
                return [""] * len(token_sequence)
            ## Join together the space-separated token fragments into complete words, but make sure 
            ## you do NOT cross punctuation marks:
            new_all_words = []
            for word in out_words:
                 frag = word
                 while re.search( r'\w+\s\w+', frag ):
                     frag =  re.sub(r'(\w+)\s(\w+)', r'\1\2', frag)
                 new_all_words.append(frag)
            ## If a word obtained from the previous step include a fragment that terminates in a 
            ## punctuation mark which can be any of ".?,!]+.?", break it into two or more subwords:
            cleaned_all_words = []
            for word in new_all_words:
                new_words = []   
                if any(char in string.punctuation for char in word):
                    parts = re.findall(r'[^.?,!]+.?', word)
                    cleaned_all_words += parts
                else:
                    cleaned_all_words.append(word)
            return ' '.join(cleaned_all_words)

        checkpoint_dir =  "checkpoint_dir"
        if os.path.exists(checkpoint_dir):  
            files = glob.glob(checkpoint_dir + "/*")
            for file in files: 
                if os.path.isfile(file): 
                    os.remove(file) 
                else: 
                    files = glob.glob(file + "/*") 
                    list(map(lambda x: os.remove(x), files)) 
        else: 
            os.mkdir(checkpoint_dir)   

        context_window_size = dataloader.context_window_size
        context_buffer_size = dataloader.context_buffer_size
        FILE_for_training_results = open("saved_training_with_buffered_context_results.txt",'w')
        FILE_for_training_loss = open("training_loss_vs_iterations.txt",'w')
        master_decoder.to(dev())     
        embedding_generator = self.EmbeddingGenerator(xformer, self.embedding_size).to(dev())
        beta1,beta2,epsilon = xformer.optimizer_params['beta1'], xformer.optimizer_params['beta2'], xformer.optimizer_params['epsilon']     
        master_decoder_optimizer = self.ScheduledOptim(optim.Adam(master_decoder.parameters(), betas=(beta1,beta2), eps=epsilon),
                                            lr_mul=2, d_model=self.embedding_size, n_warmup_steps=self.num_warmup_steps)    
        criterion = nn.NLLLoss()                                                                                            
        accum_times = []
        start_time = time.perf_counter()
        training_loss_tally = []
        running_loss = 0.0
        print("")
        debug = False
        iter = 0
        prev_seq_logprobs = torch.ones(self.batch_size, xformer.vocab_size, dtype=torch.float).to(dev())
        prev_iteration_data = np.zeros((self.batch_size, context_buffer_size), dtype=int)
        dataloader.initialize_tokenized_data_streams()
        while True:
            new_data_for_new_iteration = dataloader.dataloader_for_buffered_context(context_window_size)
            new_prev_iteration_data = new_data_for_new_iteration[:, -context_buffer_size:]
            token_sequences_in_batch = [[dataloader.inverse_lookup[code] for code in new_data_for_new_iteration[i][1:]] 
                                                                                           for i in range(dataloader.batch_size)]
            if new_data_for_new_iteration is None: continue
            ## The first token in each batch instance is a starter token like '<SOS>':
            first_tokens_in_batch  =  new_data_for_new_iteration[:,0]
            first_tokens_in_batch = first_tokens_in_batch[...,None]
            data = np.concatenate( (first_tokens_in_batch, prev_iteration_data, new_data_for_new_iteration[:,1:]), axis=1 )
            iter += 1
            data = torch.from_numpy( data ).to(dev())
            input_tensor = embedding_generator( data )
            master_decoder_optimizer.zero_grad()
            mask = torch.ones(1, dtype=int)                         ## initialize the mask                      
            predicted_indexes = [[] for i in range(dataloader.batch_size)]
            predicted_tokens = [[] for i in range(dataloader.batch_size)]
            detokenized_predicted_word_sequence = [[] for i in range(dataloader.batch_size)]
            predicted_logprobs = []
            LOSS = 0.0
            for word_index in range(1,input_tensor.shape[1]):
                masked_input_seq = master_decoder.apply_mask(input_tensor, mask)                                
                predicted_word_logprobs, predicted_word_index_values = master_decoder(input_tensor, mask)
                if word_index == 0:
                    predicted_word_logprobs = predicted_word_logprobs * prev_seq_logprobs
                for i in  range(dataloader.batch_size):
                    predicted_indexes[i].append(predicted_word_index_values.cpu().numpy()[i])
                loss = criterion(predicted_word_logprobs, data[:, word_index])           
                LOSS += loss
                mask = torch.cat( ( mask, torch.ones(1, dtype=int) ) )                                          

            predicted_indexes = np.array(predicted_indexes)
            ## The following accounts for the fact that the first token is the SOS token, followed by context-buffer tokens
#            predicted_indexes = predicted_indexes[:, context_buffer_size+1:]              
            predicted_indexes = predicted_indexes[:, context_buffer_size:]              
            prev_iteration_data = new_prev_iteration_data
            LOSS.backward()
            master_decoder_optimizer.step_and_update_lr()                                                       
            loss_normed = LOSS.item() / input_tensor.shape[0]
            running_loss += loss_normed
            prev_seq_logprobs  =  predicted_word_logprobs
            if iter % 100 == 99:    
                avg_loss = running_loss / float(100)
                training_loss_tally.append(avg_loss)
                FILE_for_training_loss.write("%s\n" % str(avg_loss))
                running_loss = 0.0
                current_time = time.perf_counter()
                time_elapsed = current_time-start_time
                print("\n\n\n[iter:%4d  elapsed_time: %4d secs]     loss: %.4f\n\n" % (iter+1,time_elapsed,avg_loss)) 
                FILE_for_training_results.write("\n\n\n[iter:%4d  elapsed_time: %4d secs]     loss: %.4f\n\n\n" % (iter+1,time_elapsed,avg_loss)) 
                for j in range(dataloader.batch_size):
                    predicted_tokens[j] = dataloader.tokenizer.decode( predicted_indexes[j], skip_special_tokens=True )
                for i in random.sample( range(dataloader.batch_size), 4 ): 
                    print("Ground-Truth: ", detokenizer( ' '.join(token_sequences_in_batch[i]) ))
                    print("GT Token Seq: ", ' '.join(token_sequences_in_batch[i] ))
                    print("   Predicted: ", predicted_tokens[i])
                    print(" Detokenized: ", detokenizer( dataloader.tokenizer.decode( predicted_indexes[i], skip_special_tokens=True ) ))
                    print()
                    FILE_for_training_results.write("ground-truth: %s\n" % str(detokenizer( ' '.join(token_sequences_in_batch[i]) )))
                    FILE_for_training_results.write("GT Token Seq: %s\n" % str(' '.join(token_sequences_in_batch[i]) ))
                    FILE_for_training_results.write("   predicted: %s\n" % str(predicted_tokens[i]))
                    FILE_for_training_results.write(" detokenized: %s\n" % str(detokenizer( dataloader.tokenizer.decode(predicted_indexes[i],skip_special_tokens=True))))
                    FILE_for_training_results.write("\n")
                accum_times.append(current_time-start_time)
                FILE_for_training_results.flush()
                FILE_for_training_loss.flush()

            if iter % checkpoint_frequency == checkpoint_frequency-1:    
                print("\n\nSaving checkpoint at iteration: %d\n\n"% (iter+1))
                self.save_checkpoint_decoder(master_decoder, checkpoint_dir, iter+1)
                self.save_checkpoint_embedding_generator(embedding_generator, checkpoint_dir, iter+1)


    ###%%%
    #######################################################################################################################
    ###########################################  PromptResponder for babyGPT  #############################################

    class PromptResponder(nn.Module):
        """
        Prompting a trained babyGPT model means that you supply a small number of words (as, say, the 
        beginning of a new thought) as a prompt and the model supplies the rest of the words to complete 
        the thought.  The class comes with two methods, the first for extending your prompt until it 
        reaches a period, and the second for going beyond the first period encountered.
    
        Any interaction with a trained GPT model has to deal with the following issue:  What to do with
        the context buffer that is meant to be a continuation of the last part of the previous "sentence"
        fed into the transformer. 
    
        Ideally, we should be placing in the context buffer words that create a context for the prompt.
        But there is no easy way to do that without a more elaborate model. An example of more elaborate
        modeling would be to have the input to the transformer consist of, say, an SOS token, a special
        context token consisting possibly of integer index values beyond the tokenizer vocab, followed
        by a context buffer that would be the last part of the previous sentence, followed, finally, 
        by the new input tokens.
    
        babyGPT gives you two options regarding what to do with the context buffer for your prompt:
    
                --  all_zeros
    
                --  get_from_prompt
    
        With the first option, all of the integer encoding values in the context buffer are set to
        the integer zero.  And, with the second option, at this time, the context buffer contains
        a portion or all of the prompt itself.  If the tokenized version of the prompt is no longer
        than the size of the context buffer, the entire prompt is placed in the context buffer and
        the remaining positions are padded with zeros.  In the opposite case, just the initial 
        context_buffer_size number of elements of the prompt are retained.
        """
        def __init__(self, gpt,  xformer, master_decoder, context_window_size, context_buffer_size, tokenizer_json, checkpoint_dir, checkpoint_index):
            super(babyGPT.PromptResponder, self).__init__()
            self.gpt  =  gpt
            self.xformer = xformer
            self.master_decoder = master_decoder
            self.context_window_size = context_window_size
            self.context_buffer_size = context_buffer_size
            self.tokenizer_json = tokenizer_json                       
            self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=self.tokenizer_json)
            FILE = open(self.tokenizer_json)    
            tokenizer_dict =  json.load( FILE ) 
            self.inverse_lookup  =  {v:k for k,v in tokenizer_dict['model']['vocab'].items()}  
            self.vocab_size = self.tokenizer.vocab_size
            self.checkpoint_dir = checkpoint_dir
            self.checkpoint_index = checkpoint_index


        def generate_response_to_prompt_up_to_period(self, context_buffer_option=None, result_file=None):
            """
            This version tries to construct a more elaborate response to a single prompt by going beyond the first period that 
            is encountered.  The first part of the if-else block shown below is for extending the prompt to the first period.
            On the other hand, the else clause is for generating additional sentences beyond the first period.  I have yet to
            clean up the logic for that.
            """

            def detokenizer( token_sequence_as_string ):
                regex = r'\s_\s'
                out_words = ""
                try:
                    out_words = re.split(regex, token_sequence_as_string)
                except TypeError as e:
                    print(e)
                    return [""] * len(token_sequence)
                ## Join together the space-separated token fragments into complete words, but make sure 
                ## you do NOT cross punctuation marks:
                new_all_words = []
                for word in out_words:
                     frag = word
                     while re.search( r'\w+\s\w+', frag ):
                         frag =  re.sub(r'(\w+)\s(\w+)', r'\1\2', frag)
                     new_all_words.append(frag)
                ## If a word obtained from the previous step include a fragment that terminates in a 
                ## punctuation mark which can be any of ".?,!]+.?", break it into two or more subwords:
                cleaned_all_words = []
                for word in new_all_words:
                    new_words = []   
                    if any(char in string.punctuation for char in word):
                        parts = re.findall(r'[^.?,!]+.?', word)
                        cleaned_all_words += parts
                    else:
                        cleaned_all_words.append(word)
                return ' '.join(cleaned_all_words)
    

            if result_file is not None:
                FILE = open(result_file, 'w')

            master_decoder = self.master_decoder
            master_decoder.load_state_dict(torch.load(self.checkpoint_dir + "/" + 
                                                        self.gpt.path_saved_model['decoder'] +  '_' + str(self.checkpoint_index) ))
            master_decoder.to( dev() )     
            embedding_generator = self.gpt.EmbeddingGenerator(self.xformer, self.gpt.embedding_size).to( dev() )
            embedding_generator.load_state_dict(torch.load(self.checkpoint_dir + "/" +
                                                   self.gpt.path_saved_model['embedding_generator'] + '_' + str(self.checkpoint_index)))
            embedding_generator.to( dev() )
            debug = False
            prompt = ""
            with torch.no_grad():
                while True:
                    context_buffer = np.zeros(shape=(self.context_buffer_size), dtype=int)
                    while True:
                        prompt = input("\nEnter your prompt: ")
                        if prompt == "": continue
                        else: break
                    ##  Strip any empty space before or after the prompt:
                    prompt = prompt.strip()
                    print("\nyour prompt: ", prompt)
                    all_words = prompt.split()
                    all_words = [word + " _" if re.search(r'.*[\w]$', word) else word for word in all_words]
                    prompt_text = ' '.join(all_words)
                    print("\nprompt_text_with_underscores: ", prompt_text)
                    ## consists of int tokens for the symbolic token in prompt:
                    encoded_prompt = self.tokenizer.encode( prompt_text )
                    token_sequence_in_prompt = [self.inverse_lookup[int_code] for int_code in encoded_prompt]
                    print("\ntoken_sequence_in_prompt: ", token_sequence_in_prompt)
                    print("\nencoded_prompt: ", encoded_prompt)
                    predicted_word_index_value = torch.zeros(1, dtype=torch.int)
                    stopping_token_code = 46
                    while predicted_word_index_value.item() != stopping_token_code:
                        if len(encoded_prompt) >=  self.context_window_size: 
                            break
                        input_tensor = torch.zeros( 1, self.xformer.max_seq_length, dtype=torch.int )
                        input_tensor[0,0] = 89         ##  The SOS token             
                        if context_buffer_option == "all_zeros":
#                            print("\n\n======================== Choosing all-zeros option for context initialization")
                            input_tensor[0,self.context_buffer_size:self.context_buffer_size + len(encoded_prompt)] = torch.tensor(encoded_prompt)
                        elif context_buffer_option == "get_from_prompt": 
#                            print("\n\n======================== Choosing 'get from prompt' option for context initialization")
                            if len(encoded_prompt) > self.context_buffer_size:
                                input_tensor[0,1:1+self.context_buffer_size] = torch.tensor(encoded_prompt[:self.context_buffer_size])
                                input_tensor[0,self.context_buffer_size:self.context_buffer_size + len(encoded_prompt)] = torch.tensor(encoded_prompt)    
                            else:
                                ## if prompt is too short:
                                padded_encoded_prompt =  encoded_prompt +  [0] * (self.context_buffer_size - len(encoded_prompt))
                                input_tensor[0,1:1+self.context_buffer_size] = torch.tensor(padded_encoded_prompt)
                                input_tensor[0,self.context_buffer_size:self.context_buffer_size + len(encoded_prompt)] = torch.tensor(encoded_prompt)
                        input_tensor = input_tensor.to( dev() )
                        input_tensor = embedding_generator( input_tensor )
                        mask = torch.ones( self.context_buffer_size + len(encoded_prompt), dtype=int)
                        predicted_word_index_value = torch.zeros(1, dtype=torch.int)
                        predicted_word_logprobs, predicted_word_index_value = master_decoder(input_tensor, mask)                     
                        predicted_token =  self.xformer.inverse_lookup[predicted_word_index_value.cpu().numpy()[0]]
                        encoded_prompt.append(predicted_word_index_value.item())
                        if debug: 
                            print("\npredicted token: ", predicted_token)                
                            print("\nencoded_prompt: ", encoded_prompt)                
                    if debug:
                        print("\n\nprompt and its response: ", encoded_prompt)
                    output_string = ""
                    for code in encoded_prompt:
                        output_string += " " + self.xformer.inverse_lookup[code]
                    print("\nencoding of prompt and the response: ", encoded_prompt)
                    print("\nprompt and its response: ", output_string)
                    final_output = detokenizer(output_string) 
                    ## find() returns -1 when no char is "."
                    index_period  =  final_output.find(".")
                    if index_period >= 0 and index_period < len(final_output):
                        print("\ndetokenized sentence completion: ", final_output[:final_output.find(".")+1])
                    else:
                        print("\ndetokenized sentence completion: ", final_output)


        def generate_response_to_prompt_beyond_period(self, context_buffer_option=None, result_file=None):
            """
            This version tries to construct a more elaborate response to a single prompt by going beyond the first period that 
            is encountered.  The first part of the if-else block shown below is for extending the prompt to the first period.
            On the other hand, the else clause is for generating additional sentences beyond the first period.  I have yet to
            clean up the logic for that.xs
            """
            def detokenizer( token_sequence_as_string ):
                regex = r'\s_\s'
                out_words = ""
                try:
                    out_words = re.split(regex, token_sequence_as_string)
                except TypeError as e:
                    print(e)
                    return [""] * len(token_sequence)
                ## Join together the space-separated token fragments into complete words, but make sure 
                ## you do NOT cross punctuation marks:
                new_all_words = []
                for word in out_words:
                     frag = word
                     while re.search( r'\w+\s\w+', frag ):
                         frag =  re.sub(r'(\w+)\s(\w+)', r'\1\2', frag)
                     new_all_words.append(frag)
                ## If a word obtained from the previous step include a fragment that terminates in a 
                ## punctuation mark which can be any of ".?,!]+.?", break it into two or more subwords:
                cleaned_all_words = []
                for word in new_all_words:
                    new_words = []   
                    if any(char in string.punctuation for char in word):
                        parts = re.findall(r'[^.?,!]+.?', word)
                        cleaned_all_words += parts
                    else:
                        cleaned_all_words.append(word)
                return ' '.join(cleaned_all_words)

            if result_file is not None:
                FILE = open(result_file, 'w')
            master_decoder = self.master_decoder
            master_decoder.load_state_dict(torch.load(self.checkpoint_dir + "/" + 
                                                        self.gpt.path_saved_model['decoder'] +  '_' + str(self.checkpoint_index) ))
            master_decoder.to( dev() )     
            embedding_generator = self.gpt.EmbeddingGenerator(self.xformer, self.gpt.embedding_size).to( dev() )
            embedding_generator.load_state_dict(torch.load(self.checkpoint_dir + "/" +
                                                   self.gpt.path_saved_model['embedding_generator'] + '_' + str(self.checkpoint_index)))
            embedding_generator.to( dev() )
            debug = False
            with torch.no_grad():
                interaction_index = 0
                overall_response = ""
                while True:
                    if interaction_index == 0:                                                                              ## (A)

                        context_buffer = np.zeros(shape=(self.context_buffer_size), dtype=int)
                        prompt = input("\n\nEnter your prompt: ")
                        ##  Strip any empty space before or after the prompt:
                        prompt = prompt.strip()
                        print("\n\nyour prompt: ", prompt)
                        all_words = prompt.split()
                        all_words = [word + " _" if re.search(r'.*[\w]$', word) else word for word in all_words]
                        prompt_text = ' '.join(all_words)
                        print("\n\nprompt_text_with_underscores: ", prompt_text)
                        ## consists of int tokens for the symbolic token in prompt:
                        encoded_prompt = self.tokenizer.encode( prompt_text )
                        token_sequence_in_prompt = [self.inverse_lookup[int_code] for int_code in encoded_prompt]
                        print("\n\ntoken_sequence_in_prompt: ", token_sequence_in_prompt)
                        print("\n\nencoded_prompt: ", encoded_prompt)
                        input_tensor = torch.zeros( 1, self.xformer.max_seq_length, dtype=torch.int )
                        input_tensor[0,0] = 89         ##  The SOS token             
                        if context_buffer_option == "all_zeros":
                            input_tensor[0,self.context_buffer_size:self.context_buffer_size + len(encoded_prompt)] = torch.tensor(encoded_prompt)
                        elif context_buffer_option == "get_from_prompt": 
                            if len(encoded_prompt) > self.context_buffer_size:
                                input_tensor[0,1:1+self.context_buffer_size] = torch.tensor(encoded_prompt[:self.context_buffer_size])
                                input_tensor[0,self.context_buffer_size:self.context_buffer_size + len(encoded_prompt)] = torch.tensor(encoded_prompt)    
                            else:
                                ## if prompt is too short:
                                padded_encoded_prompt =  encoded_prompt +  [0] * (self.context_buffer_size - len(encoded_prompt))
                                input_tensor[0,1:1+self.context_buffer_size] = torch.tensor(padded_encoded_prompt)
                                input_tensor[0,self.context_buffer_size:self.context_buffer_size + len(encoded_prompt)] = torch.tensor(encoded_prompt)

                    else:                                                                                                   ## (B)
                        print("\n\n\n[Interaction no. %d]  encoded prompt from PREV ITER: " % (interaction_index+1), encoded_prompt) 
                        context_buffer =  encoded_prompt[-self.context_buffer_size - self.context_buffer_size:-self.context_buffer_size]
                        encoded_prompt = encoded_prompt[-self.context_buffer_size:]
                        print("[Interaction no. %d]  context_buffer: " % (interaction_index+1), context_buffer)
                        print("[Interaction no. %d]  encoded prompt: " % (interaction_index+1), encoded_prompt) 
                        input_tensor = torch.zeros( 1, self.xformer.max_seq_length, dtype=torch.int )
                        input_tensor[0,:self.context_buffer_size] = torch.tensor( context_buffer )
                        input_tensor[0,self.context_buffer_size:self.context_buffer_size + len(encoded_prompt)] = torch.tensor(encoded_prompt)
                    interaction_index += 1               
                    if interaction_index >= 3: 
                        print("\n\nInteraction limit reached")
                        break
                    input_tensor = input_tensor.to( dev() )
                    input_tensor = embedding_generator( input_tensor )
                    mask = torch.ones( self.context_buffer_size + len(encoded_prompt), dtype=int)
                    stopping_token_code = 16
                    predicted_word_index_value = torch.zeros(1, dtype=torch.int)
                    while predicted_word_index_value.item() != stopping_token_code:
                        if len(encoded_prompt) >=  self.context_window_size: 
                            break
                        predicted_word_logprobs, predicted_word_index_value = master_decoder(input_tensor, mask)                 
                        predicted_token =  self.xformer.inverse_lookup[predicted_word_index_value.cpu().numpy()[0]]
                        encoded_prompt.append(predicted_word_index_value.item())
                        if debug: 
                            print("\npredicted token: ", predicted_token)                
                            print("\nencoded_prompt: ", encoded_prompt)                
                        input_tensor = torch.zeros( 1, self.xformer.max_seq_length, dtype=torch.int )
                        input_tensor[0,self.context_buffer_size:self.context_buffer_size + len(encoded_prompt)] = torch.tensor(encoded_prompt)
                        input_tensor = input_tensor.to( dev() )
                        input_tensor = embedding_generator( input_tensor )
                        mask = torch.cat( ( mask, torch.ones(1, dtype=int) ) )                                          
                    if debug:
                        print("\n\nprompt and its response: ", encoded_prompt)
                    output_string = ""
                    for code in encoded_prompt:
                        output_string += " " + self.xformer.inverse_lookup[code]
                    print("\n\nprompt and its response: ", output_string)
                    print("\n\ndetokenized sentence completion: ", detokenizer(output_string))

                    overall_response += output_string

            print("\n\nOverall response to your prompt: ", overall_response)
            print("\n\nOverall detokenized response: ", detokenizer(overall_response))


    ###%%%
    ###########################  ScheduledOptim Code for TransformerFG #############################

    class ScheduledOptim():
        """
        As in the Transformers module of DLStudio, for the scheduling of the learning rate
        during the warm-up phase of training TransformerFG, I have borrowed the class shown below
        from the GitHub code made available by Yu-Hsiang Huang at:

                https://github.com/jadore801120/attention-is-all-you-need-pytorch
        """
        def __init__(self, optimizer, lr_mul, d_model, n_warmup_steps):
            self._optimizer = optimizer
            self.lr_mul = lr_mul
            self.d_model = d_model
            self.n_warmup_steps = n_warmup_steps
            self.n_steps = 0

        def step_and_update_lr(self):
            "Step with the inner optimizer"
            self._update_learning_rate()
            self._optimizer.step()
    
        def zero_grad(self):
            "Zero out the gradients with the inner optimizer"
            self._optimizer.zero_grad()
    
        def _get_lr_scale(self):
            d_model = self.d_model
            n_steps, n_warmup_steps = self.n_steps, self.n_warmup_steps
            return (d_model ** -0.5) * min(n_steps ** (-0.5), n_steps * n_warmup_steps ** (-1.5))
    
        def _update_learning_rate(self):
            ''' Learning rate scheduling per step '''
            self.n_steps += 1
            lr = self.lr_mul * self._get_lr_scale()
            for param_group in self._optimizer.param_groups:
                param_group['lr'] = lr

#############################################################################################################################
##############################################   End of babyGPT Class Definition#############################################

## Placeholder entry point: nothing is executed when this file is run directly as a script.
if __name__ == '__main__': 
    pass