#!/usr/bin/env python

__version__ = '2.1'
__author__  = "Avinash Kak (kak@purdue.edu)"
__date__    = '2011-March-22'
__url__     = 'http://RVL4.ecn.purdue.edu/~kak/dist2d/BitArray2D-2.1.html'
__copyright__ = "(C) 2011 Avinash Kak. Python Software Foundation."


from BitVector import __version__ as bitvector_version
if bitvector_version.split('.')[0] < '3':
    raise ImportError("The imported BitVector module must be of version 3.0 or higher")

import BitVector
import re

class BitArray2D( object ):                                          #(A1)

    def __init__( self, *args, **kwargs ):                           #(A2)
        if args:                                                     #(A3)
               raise ValueError(                                     #(A4)
                      '''BitArray2D constructor can only be called with
                         keyword arguments for the following keywords:
                         rows, columns, filename, bitstring)''')    
        allowed_keys = 'bitstring','filename','rows','columns'       #(A5)
        keywords_used = kwargs.keys()                                #(A6)
        for keyword in keywords_used:                                #(A7)
            if keyword not in allowed_keys:                          #(A8)
                raise ValueError("Wrong keyword used")               #(A9)
        filename = rows = columns = bitstring = None                #(A10)

        if 'filename' in kwargs  : filename  = kwargs.pop('filename')
        if 'rows' in kwargs      : rows      = kwargs.pop('rows')
        if 'columns' in kwargs   : columns   = kwargs.pop('columns')
        if 'bitstring' in kwargs : bitstring = kwargs.pop('bitstring')
                                                             #(A11 -- A14)
        self.filename  = None                                       #(A15)
        self.rows      = None                                       #(A16)
        self.columns   = None                                       #(A17)
        self.bitstring = None                                       #(A18)
        self.FILEIN    = None                                       #(A19)

        if filename:                                                #(A20)
            if rows or columns or bitstring:                        #(A21)
                raise ValueError(                                   
                  '''When filename is specified, you cannot
                     give values to any other constructor args''')  #(A22)
            self.filename = filename                                #(A23)
            self.rowVectors = []; self.rows = 0; self.columns = 0   #(A24)
            import sys                                              #(A25)
            try:                                                    #(A26)
                if sys.version_info[0] == 3:                        #(A27)
                    self.FILEIN = open( filename, encoding='utf-8' )#(A28)
                else:                                               #(A29)
                    self.FILEIN = open( filename, 'rb' )            #(A30)
            except IOError as e:                                    #(A31)
                print(e.strerror)                                   #(A32)
            return                                                  #(A33)
        elif rows is not None and rows >= 0:                        #(A34)
            if filename or bitstring:                               #(A35)
                raise ValueError(                              
                  '''When number of rows is specified, you cannot
                     give values to any other constructor args except
                     for columns''')                                #(A36)
            if not columns >= 0:                                    #(A37)
                raise ValueError(
                  '''When number of rows is specified, you must also
                     specify a value for the number of columns''')  #(A38)
            self.rows = rows; self.columns = columns                #(A39)
            self.rowVectors = [ BitVector.BitVector( size = self.columns ) \
                                   for i in range( self.rows ) ]    #(A40)
            return                                                  #(A41)
        elif bitstring or bitstring == '':                          #(A42)
            self.rowVectors = [ BitVector.BitVector( bitstring = bits ) \
                         for bits in re.split( '\n', bitstring ) ]  #(A43)
            self.rows = len( self.rowVectors )                      #(A44)
            self.columns = self.rowVectors[0].length()              #(A45)
            rowVecSizes = [ len(x) for x in self.rowVectors ]       #(A46)
            if max( rowVecSizes ) != min( rowVecSizes ):            #(A47)
                raise AttributeError("Your row sizes do not match") #(A48)

    def _add_row(self, bitvector):                                   #(B1)
        if self.columns == 0:                                        #(B2)
            self.columns = bitvector.length()                        #(B3)
        elif self.columns != bitvector.length():                     #(B4)
            raise ValueError("Size wrong for the new row")           #(B5)
        self.rowVectors.append( bitvector )                          #(B6)
        self.rows += 1                                               #(B7)
        
    def __str__( self ):                                             #(C1)
        'To create a print representation'
        if self.rows==0 and self.columns==0:                         #(C2)
            return ''                                                #(C3)
        return '\n'.join( map( str, self.rowVectors ) )              #(C4)

    def __getitem__( self, pos ):                                    #(D1)
        'Get the bit from the designated position'
        if not isinstance( pos, slice ):                             #(D2)
            row,col = ungodel(pos)                                   #(D3)
            if  row >= self.rows or row < -self.rows:                #(D4)
                raise ValueError( "row index range error" )          #(D5)
            if  col >= self.columns or col < -self.columns:          #(D6)
                raise ValueError( "column index range error" )       #(D7)
            if row < 0: row = self.rows + row                        #(D8)
            if col < 0: col = self.columns + col                     #(D9)
            return self.rowVectors[row][col]                        #(D10)
        else:                                                       #(D11)
            if pos.start is None:                                   #(D12)
                start = 0,0                                         #(D13)
            else:                                                   #(D14)
                start = ungodel(pos.start)                          #(D15)
            if pos.stop is None:                                    #(D16)
                stop = self.rows,self.columns                       #(D17)
            else:                                                   #(D18)
                stop = ungodel(pos.stop)                            #(D19)
            result = BitArray2D( rows=0, columns=0 )                #(D20)
            for i in range( start[0], stop[0] ):                    #(D21)
                result._add_row( BitVector.BitVector( bitstring = \
                      str(self.rowVectors[i][start[1]:stop[1]])) )  #(D22)
            return result                                           #(D23)

    def __setitem__(self, pos, item):                                #(E1)
        '''
        This is needed for both slice assignments and for index-based
        assignments.  It checks the type of pos and item to see if the call
        is for slice assignment.  For slice assignment, the second argument
        must be of type slice '[m:n]' whose two numbers m and n are
        produced by calling godel() on the two corners of the rectangular
        regions whose values you want to set by calling this function.  So
        for slice assignments, think of pos as consisting of
        '[(i,j):(k,l)]' where (i,j) defines one corner of the slice and
        (k,l) the other slice.  As you would expect, for slice assignments,
        the argument item must of type BitArray2D.  For index-based
        assignment, the pos consists of the tuple (i,j), the point in the
        array where you want to change the value.  Again, for index-based
        assignment, the last argument will either be 1 or 0.
        '''      
        if (not isinstance( item, BitArray2D )):                     #(E2)
            if isinstance( pos, slice ):                             #(E3)
                raise ValueError("Second arg wrong for assignment")  #(E4)
            i,j = pos                                                #(E5)
            self.rowVectors[i][j] = item                             #(E6)
        # The following section is for slice assignment:
        if isinstance(pos,slice):                                    #(E7)
            if (not isinstance( item, BitArray2D )):                 #(E8)
                raise TypeError('For slice assignment, \
                        the right hand side must be a BitArray2D')   #(E9)
            arg1, arg2 = pos.start, pos.stop                        #(E10)
            i,j = ungodel(arg1)                                     #(E11)
            k,l = ungodel(arg2)                                     #(E12)
            for m in range(i,j):                                    #(E13)
                self.rowVectors[m][k:l] = item.rowVectors[m-i]      #(E14)

    def __getslice__(self, arg1, arg2):                              #(F1)
        '''
        A slice of a 2D array is defined as a rectangular region whose one
        corner is at the (i,j) coordinates, which is represented by the
        mapped integer arg1 produced by calling godel(i,j). The other
        corner of the slice is at the coordinates (k,l) that is represented
        by the integer arg2 produced by calling godel(k,l).  The slice is
        returned as a new BitArray2D instance.
        '''
        i,j = ungodel(arg1)                                          #(F2)
        k,l = ungodel(arg2)                                          #(F3)
        sliceArray = BitArray2D( rows=0, columns=0 )                 #(F4)
        if j > self.rows: j = self.rows                              #(F5)
        if l > self.columns: l = self.columns                        #(F6)
        for x in range(i,k):                                         #(F7)
            bv = self.rowVectors[x]                                  #(F8)
            sliceArray._add_row( bv[j:l] )                           #(F9)
        return sliceArray                                           #(F10)

    def __eq__(self, other):                                         #(G1)
        if self.size() != other.size(): return False                 #(G2)   
        if self.rowVectors != other.rowVectors: return False         #(G3)
        return True                                                  #(G4)

    def __ne__(self, other):                                         #(H1)
        return not self == other                                     #(H2)

    def __and__(self, other):                                        #(I1)
        '''
        Take a bitwise 'AND' of the bit array on which the method is
        invoked with the argument bit array.  Return the result as a new
        bit array.
        '''      
        if self.rows != other.rows or self.columns != other.columns: #(I2)
            raise ValueError("Arguments to AND must be of same size")#(I3)
        resultArray = BitArray2D(rows=0,columns=0)                   #(I4)
        list(map(resultArray._add_row, \
                       [self.rowVectors[i] & other.rowVectors[i] \
                                     for i in range(self.rows)]))    #(I5)
        return resultArray                                           #(I6)

    def __or__(self, other):                                         #(J1)
        '''
        Take a bitwise 'OR' of the bit array on which the method is
        invoked with the argument bit array.  Return the result as a new
        bit array.
        '''
        if self.rows != other.rows or self.columns != other.columns: #(J2)
            raise ValueError("Arguments to OR must be of same size") #(J3)
        resultArray = BitArray2D(rows=0,columns=0)                   #(J4)
        list(map(resultArray._add_row, \
                        [self.rowVectors[i] | other.rowVectors[i] \
                                    for i in range(self.rows)]))     #(J5)
        return resultArray

    def __xor__(self, other):                                        #(K1)
        '''
        Take a bitwise 'XOR' of the bit array on which the method is
        invoked with the argument bit array.  Return the result as a new
        bit array.
        '''
        if self.rows != other.rows or self.columns != other.columns: #(K2)
            raise ValueError("Arguments to XOR must be of same size")#(K3)
        resultArray = BitArray2D(rows=0,columns=0)                   #(K4)
        list(map(resultArray._add_row, \
                     [self.rowVectors[i] ^ other.rowVectors[i] \
                                   for i in range(self.rows)]))      #(K5)
        return resultArray                                           #(K6)

    def __invert__(self):                                            #(L1)
        '''
        Invert the bits in the bit array on which the method is invoked
        and return the result as a new bit array.
        '''
        resultArray = BitArray2D(rows=0,columns=0)                   #(L2)
        list(map(resultArray._add_row, [~self.rowVectors[i] \
                                   for i in range(self.rows)]))      #(L3)
        return resultArray                                           #(L4)

    def deep_copy(self):                                             #(M1)
        'Make a deep copy of a bit array' 
        resultArray = BitArray2D(rows=0,columns=0)                   #(M2)
        list(map(resultArray._add_row, [x.deep_copy() \
                                    for x in self.rowVectors]))      #(M3)
        return resultArray                                           #(M4)

    def size(self):                                                  #(N1)
        return self.rows, self.columns                               #(N2)

    def read_bit_array_from_char_file(self):                         #(P1)
        '''
        This assumes that the bit array is stored in the form of
        ASCII characters 1 and 0 in a text file. We further assume
        that the different rows are separated by the newline character.
        '''
        error_str = "You need to first construct a BitArray2D" + \
                          "instance with a filename as argument"     #(P2)  
        if not self.FILEIN:                                          #(P3)
            raise SyntaxError( error_str )                           #(P4)
        allbits = self.FILEIN.read()                                 #(P5)



        rows = filter( None, re.split('\n', allbits) )               #(P6)
        list(map(self._add_row, [BitVector.BitVector( bitstring = x ) \
                                        for x in rows]))             #(P7)

    def write_bit_array_to_char_file(self, file_out):                #(Q1)
        '''
        Note that this write function for depositing a bit array into
        text file uses the newline as the row delimiter.
        '''
        FILEOUT = open( file_out, 'w' )                              #(Q2)
        for bitvec in self.rowVectors:                               #(Q3)
            FILEOUT.write( str(bitvec) + "\n" )                      #(Q4)
             
    def read_bit_array_from_binary_file(self, rows, columns):        #(R1)
        '''
        This assumes that the bit array is stored in the form of ASCII
        characters 1 and 0 in a text file. We further assume that the
        different rows are separated by the newline character.
        '''
        error_str = "You need to first construct a BitArray2D" + \
                          "instance with a filename as argument"     #(R2)  
        if not self.filename:                                        #(R3)
            raise SyntaxError( error_str )                           #(R4)
        import os.path                                               #(R5)
        filesize = os.path.getsize( self.filename )                  #(R6)
        if (rows * columns) % 8 != 0:                                #(R7)
            raise ValueError("In binary file input mode, rows*cols must" \
                             + " be a multiple of 8" )               #(R8)
        if filesize < int(rows*columns/8):                           #(R9)
            raise ValueError("File has insufficient bytes" )        #(R10)
        bitstring = ''                                              #(R11)
        i = 0                                                       #(R12)
        while i < rows*columns/8:                                   #(R13)
            i += 1                                                  #(R14)
            byte = self.FILEIN.read(1)                              #(R15)
            hexvalue = hex( ord( byte ) )                           #(R16)
            hexvalue = hexvalue[2:]                                 #(R17)
            if len( hexvalue ) == 1:                                #(R18)
                hexvalue = '0' + hexvalue                           #(R19)
            bitstring += BitVector._hexdict[ hexvalue[0] ]          #(R20)
            bitstring += BitVector._hexdict[ hexvalue[1] ]          #(R21)
        
        bv = BitVector.BitVector( bitstring = bitstring )           #(R22)
        list(map(self._add_row, [ bv[m*columns : m*columns+columns] \
                                     for m in range(rows) ]))       #(R23)

    def write_bit_array_to_packed_binary_file(self, file_out):       #(S1)
        '''
        This creates a packed disk storage for your bit array.  But for
        this to work, the total number of bits in your bit array must be a
        multiple of 8 since all file I/O is byte oriented.  Also note that
        now you cannot use any byte as a row delimiter.  So when reading
        such a file back into a bit array, you have to tell the read
        function how many rows and columns to create.
        '''
        err_str = '''Only a bit array whose total number of bits
            is a multiple of 8 can be written to a file.'''          #(S2)
        if self.rows * self.columns % 8:                             #(S3)
            raise ValueError( err_str )                              #(S4)
        FILEOUT = open( file_out, 'wb' )                             #(S5)
        bitstring = ''                                               #(S6)
        for bitvec in self.rowVectors:                               #(S7)
            bitstring += str(bitvec)                                 #(S8)
        compositeBitVec = BitVector.BitVector(bitstring = bitstring) #(S9)
        compositeBitVec.write_to_file( FILEOUT )                    #(S10)

    def shift( self, rowshift, colshift ):                           #(T1)
        '''
        What may make this method confusing at the beginning is the
        orientation of the positive row direction and the positive 
        column direction.  The origin of the array is at the upper
        left hand corner of your display.  Rows are positive going 
        downwards and columns are positive going rightwards:
 
                       X----->  +ve col direction
                       |
                       |
                       |
                       V
                  +ve row direction

        So a positive value for rowshift will shift the array downwards
        and a positive value for colshift will shift it rightwards.
        Just remember that if you want the shifts to seem more intuitive,
        use negative values for the rowshift argument.
        '''
        if rowshift >= 0:                                            #(T2)
            self.rowVectors[rowshift : self.rows] = \
                        self.rowVectors[: self.rows-rowshift]        #(T3)
            self.rowVectors[:rowshift] = \
                    [BitVector.BitVector(size = self.columns) \
                                         for i in range(rowshift)]   #(T4)
            if colshift >= 0:     
                for bitvec in self.rowVectors[:]: \
                                     bitvec.shift_right(colshift)    #(T5)
            else:
                for bitvec in self.rowVectors[:]:                    #(T6)
                    bitvec.shift_left(abs(colshift))                 #(T7)
        else:                                                        #(T8)
            rowshift = abs(rowshift)                                 #(T9)
            self.rowVectors[:self.rows-rowshift] = \
                          self.rowVectors[rowshift : self.rows]     #(T10)
            self.rowVectors[self.rows-rowshift:] = \
                    [BitVector.BitVector(size = self.columns) \
                                         for i in range(rowshift)]  #(T11)
            if colshift >= 0:     
                for bitvec in self.rowVectors[:]: \
                               bitvec.shift_right(colshift)         #(T12)
            else:                                                   #(T13)
                for bitvec in self.rowVectors[:]:                   #(T14)
                    bitvec.shift_left(abs(colshift))                #(T15)
        return self                                                 #(T16)
    
    def dilate( self, m ):                                           #(U1)
        accumArray = BitArray2D(rows=self.rows, columns=self.columns)#(U2)
        for i in range(-m,m+1):                                      #(U3)
            for j in range(-m,m+1):                                  #(U4)
                temp = self.deep_copy()                              #(U5)
                accumArray |=  temp.shift(i,j)                       #(U6)
        return accumArray                                            #(U7)

    def erode( self, m ):                                            #(V1)
        accumArray = BitArray2D(rows=self.rows, columns=self.columns)#(V2)
        for i in range(-m,m+1):                                      #(V3)
            for j in range(-m,m+1):                                  #(V4)
                temp = self.deep_copy()                              #(V5)
                accumArray &=  temp.shift(i,j)                       #(V6)
        return accumArray                                            #(V7)
     
#------------------------  End of Class Definition -----------------------

#--------------------------- Ancillary Functions  ------------------------

def godel(i,j):                                                      #(W1)
    return 2**i*(2*j + 1)-1                                          #(W2)

def ungodel(m):                                                      #(X1)
    i,q = 0,m+1                                                      #(X2)
    while not q&1:                                                   #(X3)
        q >>= 1                                                      #(X4)
        i += 1                                                       #(X5)
    j = ((m+1)/2**i - 1)/2                                           #(X6)
    return int(i),int(j)                                             #(X7)

#------------------------     Test Code Follows    -----------------------

if __name__ == '__main__':

    print("\nConstructing an empty 2D bit array:")
    ba = BitArray2D( rows=0, columns=0 )
    print(ba)

    print("\nConstructing a bit array of size 10x10 with zero bits -- ba:")
    ba = BitArray2D( rows = 10, columns = 10 )
    print(ba)


    print("\nConstructing a bit array from a bit string -- ba2:")
    ba2 = BitArray2D( bitstring = "111\n110\n111" )
    print(ba2)                    

    print("\nPrint a specific bit in the array -- bit at 1,2 in ba2:")
    print( ba2[ godel(1,2) ] )

    print("\nSet a specific bit in the array --- set bit (0,1) of ba2:")   
    ba2[0,1] = 0
    print(ba2)

    print("\nExperiments in slice getting and setting:")
    print("Printing an array -- ba3:")
    ba3 = BitArray2D( bitstring = "111111\n110111\n111111\n111111\n111111\n111111" )
    print(ba3)
    ba4 = ba3[godel(2,3) : godel(4,5)]
    print("Printing a slice of the larger array -- slice b4 of ba3:")
    print(ba4)
    ba5 = BitArray2D( rows = 5, columns = 5 )
    print("\nPrinting an array for demonstrating slice setting:")
    print(ba5)
    ba5[godel(2, 2+ba2.rows) : godel(2,2+ba2.columns)] = ba2
    print("\nSetting a slice of the array - setting slice of ba5 to ba2:")
    print(ba5)
    print("\nConstructing a deep copy of ba, will call it ba6:")
    ba6 = ba.deep_copy()
    ba6[ godel(3,3+ba2.rows) : godel(3,3+ba2.columns) ] = ba2
    print("Setting a slice of the larger array -- set slice of ba6 to ba2:")
    print(ba6)

    print("\nExperiment in bitwise AND:")
    ba5 = ba.deep_copy()
    ba7 = ba5 & ba6
    print("Displaying bitwise AND of ba5 and ba6  --- ba7:")
    print(ba7)

    print("\nExperiment in bitwise OR:")
    ba7 = ba5 | ba6
    print("Displaying bitwise OR of ba5 and ba6  --- ba7:")
    print(ba7)

    print("\nExperiment in bitwise XOR:")
    ba7 = ba5 ^ ba6
    print("Displaying bitwise XOR of ba5 and ba6  --- ba7:")
    print(ba7)

    print("\nExperiment in bitwise negation:")
    ba7 = ~ba5
    print("Displaying bitwise negation of ba5 --- ba7:")
    print(ba7)

    print("\nSanity check (A & ~A => all zeros):")
    print(ba5 & ~ba5)

    print("\nConstruct bit array from a char file with ASCII 1's and 0's:" )
    ba8 = BitArray2D( filename = "Examples/data.txt" )
    ba8.read_bit_array_from_char_file()
    print("The bit array as read from the file -- ba8:")
    print(ba8)

    print("\nConstruct bit array from a packed binary file:")
    ba9 = BitArray2D( filename = "Examples/data_binary.dat" )
    ba9.read_bit_array_from_binary_file(rows = 5, columns = 8)
    print("The bit array as read from the file -- ba9:")
    print("size of ba9: " + str(ba9.size()))
    print(ba9)

    print("\nTest the equality and inequality operators:")
    ba10 = BitArray2D( bitstring = "111\n110" )
    ba11 = BitArray2D( bitstring = "111\n110" )
    print("ba10 is equal to ba11 is: " + str(ba10 == ba11))
    ba12 = BitArray2D( bitstring = "111\n111" )
    print("ba10 is equal to ba12 is: " + str(ba10 == ba12))

    print("\nTest shifting a bit array:")
    print("printing ba13:")
    ba13 = ba9.deep_copy()
    print(ba13)
    ba13.shift(rowshift=-2, colshift=2)
    print("The shifted version of ba9:")
    print(ba13)

    print("\nTest dilation:")
    ba14 = BitArray2D( filename = "Examples/data2.txt" )
    ba14.read_bit_array_from_char_file()
    print("Array before dilation:")
    print(ba14)
    ba15= ba14.dilate(1)
    print("Array after dilation:")
    print(ba15)

    print("\nTest erosion:")
    ba16 = BitArray2D( filename = "Examples/data2.txt" )
    ba16.read_bit_array_from_char_file()
    print("Array before erosion:")
    print(ba16)
    ba17= ba16.erode(1)
    print("Array after erosion:")
    print(ba17)

    print("\nExperiments with writing array to char file:")
    ba17.write_bit_array_to_char_file("out1.txt")

    print("\nExperiments with writing array to packed binary file:")    
    ba9.write_bit_array_to_packed_binary_file("out2.dat")