% Qiang Qiu, Rama Chellappa, Compositional Dictionaries for Domain Adaptive Face Recognition, 
% http://arxiv.org/abs/1308.0271
%
% Qiang Qiu, qiu@cs.umd.edu


% ---- Session setup --------------------------------------------------------
clear('all');
warning('off');          % NOTE: silences every warning for the whole session
rand('seed', 1)          % legacy seeding syntax; left as is because switching
                         % to rng() would change the random stream

addpath(genpath('dfr')); % k-svd / dictionary-learning helpers

% PIE features plus per-sample labels. Each column of pie_feats is one image
% (reshaped to IMG_H x IMG_W when the data matrices are built below).
load('PIE-full.mat', 'pie_feats', 'pie_subject', 'pie_pose', 'pie_illum');

% Pose permutation (not referenced further in this script).
posIdx = [8, 4, 11, 7, 1, 13, 2, 9, 10, 5, 6, 12, 3];

% ---- Image geometry -------------------------------------------------------
IMG_H  = 64;  IMG_W  = 48;   % size of the stored images
% Working size. Shrink it (e.g. [32, 24]) if visual reconstruction is not
% required.
IMG_H2 = 64;  IMG_W2 = 48;
dim = prod([IMG_H2 IMG_W2]); % rows occupied by one image in a data column

% ---- Training-set extent --------------------------------------------------
nPose   = 10;
nPerson = 68;
nLight  = 21;

% ---- Dictionary sizes and sparsity levels passed to da_dl (A, B, C) -------
dictsizeA = 10;  dictsizeB = 68;  dictsizeC = 9;
TdataA    = 8;   TdataB    = 20;  TdataC    = 9;


% Create the training matrix Y: P (x I) x S.
% One column per subject. Each column stacks, pose by pose, that subject's
% retained images (dim rows each). gen6form is later given nLight, so it
% presumably expects exactly nLight retained images per (subject, pose).
% TODO confirm against the dataset.
Y = [];
for subid = 1:nPerson
    y1 = [];
    for posid = 1:nPose
        idx = find(pie_pose == posid & pie_subject == subid);
        % Discard illumination 1 (total dark) and 23. pie_illum must be
        % indexed through idx: a loop counter over idx is only a position
        % inside idx, not a global sample index.
        idx = idx(~ismember(pie_illum(idx), [1 23]));

        % Preallocate this (subject, pose) segment instead of growing it.
        feature = zeros(dim*numel(idx), 1);
        for jj = 1:numel(idx)
            img = reshape(pie_feats(:, idx(jj)), IMG_H, IMG_W);
            img = imresize(img, [IMG_H2 IMG_W2], 'bicubic');
            feature((jj-1)*dim+1 : jj*dim) = double(img(:));
        end
        y1 = [y1; feature];
    end

    Y = [Y y1];
end

% Rearrange Y into the six forms consumed by da_dl.
Y6 = gen6form(Y, dim, nPerson, nLight, nPose);

%%
% Iterative triple-sparsity dictionary learning.
% DD is the dictionary used by da_omp3 / da_trans below; A, B and C are not
% used further in this script.
dicIter = 5;   % e.g. 20 -- more iterations can be tried
[DD, A, B, C] = da_dl(Y6, dicIter, dim, ...
    dictsizeA, dictsizeB, dictsizeC, ...
    TdataA, TdataB, TdataC, ...
    nPose, nPerson, nLight);


%% Domain invariant sparse coding for subjects

% Create the test matrix Y: P (x I) x S, built exactly like the training matrix
% but over all 13 poses (training used poses 1..10).
% Later code locates an image via rows (p-1)*nLight*dim + (l-1)*dim + (1:dim),
% so each (subject, pose) must yield exactly nLight retained images.
Y = [];
nTestPerson = 68;
nPose = 13;
for subid = 1:nTestPerson
    y1 = [];
    for posid = 1:nPose
        idx = find(pie_pose == posid & pie_subject == subid);
        % Discard illumination 1 (total dark) and 23. pie_illum must be
        % indexed through idx: a loop counter over idx is only a position
        % inside idx, not a global sample index.
        idx = idx(~ismember(pie_illum(idx), [1 23]));

        % Preallocate this (subject, pose) segment instead of growing it.
        feature = zeros(dim*numel(idx), 1);
        for jj = 1:numel(idx)
            img = reshape(pie_feats(:, idx(jj)), IMG_H, IMG_W);
            img = imresize(img, [IMG_H2 IMG_W2], 'bicubic');
            feature((jj-1)*dim+1 : jj*dim) = double(img(:));
        end
        y1 = [y1; feature];
    end

    Y = [Y y1];
end


%% test alignment
% Sparse-code a reference image of subject s0 and several images of subject
% s1, then recombine their codes with da_trans. The four result stacks are
% plotted at the end.

ompIter = 20;
lambda  = 15;  % 10 is an alternative value

p0 = 10;  s0 = 43;  l0 = 5;   % reference image: pose block p0, image l0, subject s0
s1 = 3;                       % second subject

% Rows of Y holding image l of pose block p. Y is laid out as pose-major
% blocks of nLight images, each dim rows long.
rowsOf = @(p, l) (p-1)*nLight*dim + (l-1)*dim + (1:dim);

y0 = Y(rowsOf(p0, l0), s0);
plotStackIm(y0, dim, IMG_H2, IMG_W2);
[a0, b0, c0, yy0] = da_omp3(y0, DD, dim, ompIter, ...
    dictsizeA, dictsizeB, dictsizeC, lambda, lambda, lambda);

plist = [7 2 9];
ilist = [11 11 11];

[YY, YY1, YY2, YY0] = deal([]);
for ii = 1:numel(plist)
    disp(['ii: ' num2str(ii)]);
    p1 = plist(ii);
    l1 = ilist(ii);
    rows = rowsOf(p1, l1);

    % Original images of both subjects at (p1, l1).
    y1 = Y(rows, s1);
    y0 = Y(rows, s0);
    YY  = [YY y1];
    YY0 = [YY0 y0];

    [a1, b1, c1, yy1] = da_omp3(y1, DD, dim, ompIter, ...
        dictsizeA, dictsizeB, dictsizeC, lambda, lambda, lambda);

    % Codes mixed from the reference (a0, c0) and the current image (b1).
    yy = da_trans(DD, a0, b1, c0, dictsizeA, dictsizeC, dim);
    YY1 = [YY1 yy];

    % Codes mixed the other way round (a1, b0, c1).
    yy = da_trans(DD, a1, b0, c1, dictsizeA, dictsizeC, dim);
    YY2 = [YY2 yy];
end
clear rowsOf rows

plotStackIm2(YY,  dim, IMG_H2, IMG_W2);
plotStackIm2(YY1, dim, IMG_H2, IMG_W2);
plotStackIm2(YY2, dim, IMG_H2, IMG_W2);
plotStackIm2(YY0, dim, IMG_H2, IMG_W2);
 
 
 
%% Test Classification
% Sparse-code every (subject, pose, image) column segment of the test matrix
% Y over DD with da_omp3. The three codes of each sample are collected
% column-wise in Pda (a1), Yda (b1) and Ida (c1). SS/PP/LL hold the matching
% subject / pose / image indices.
% (Previously this whole loop was duplicated, which coded every sample twice
% and appended duplicate columns and labels.)
nTestPerson = 68;
ompIter = 20;
lambda  = 5;  %10

nTest = nTestPerson * nPose * nLight;
SS = zeros(1, nTest);
PP = zeros(1, nTest);
LL = zeros(1, nTest);
YY = zeros(dim, nTest);   % each y1 below is a dim x 1 slice of Y
Yda = [];                 % code lengths come from da_omp3; left to grow
Pda = [];
Ida = [];

k = 0;
for subid = 1:nTestPerson
    disp(['subid: ' num2str(subid)]);
    for posid = 1:nPose
        disp(['posid: ' num2str(posid)]);
        for illid = 1:nLight
            k = k + 1;
            s1 = subid; p1 = posid; l1 = illid;
            SS(k) = s1; PP(k) = p1; LL(k) = l1;

            y1 = Y((p1-1)*nLight*dim + (l1-1)*dim+1 : (p1-1)*nLight*dim + l1*dim, s1);
            [a1, b1, c1, yy1] = da_omp3(y1, DD, dim, ompIter, dictsizeA, dictsizeB, dictsizeC, lambda, lambda, lambda );

            Yda = [Yda b1];
            Pda = [Pda a1];
            Ida = [Ida c1];
            YY(:, k) = y1;
        end
    end
end


%% subject Classification
% Cross-pose recognition: the subject codes (Yda) of probe pose probPose are
% matched against those of gallery pose galleryPose with knnprobe (k = 1).
% normcol presumably normalises each column. TODO confirm in dfr/.
[dataN] = normcol(Yda);
label = SS;

probPose = 2;
galleryPose = 9;

idx = find(PP==probPose);
probe = dataN(:,idx)';
plabel = label(idx);

idx2 = find(PP==galleryPose);
gallery = dataN(:,idx2)';
glabel = label(idx2);

[IDXpred, nn_accuracy, IDXdist] = knnprobe(probe, plabel, gallery, glabel, 1);
% Label the output so it can be told apart from the pose/illumination results.
fprintf(1,'DA NN acc (subject): %f \n', nn_accuracy);


%% Pose estimation
% Leave-one-subject-out: pose codes (Pda) of subject probSub are classified
% with knnprobe (k = 1) against the pose codes of all other subjects.
[dataN] = normcol(Pda);
label = PP;
probSub = 37;

idx = find(SS==probSub);
probe = dataN(:,idx)';
plabel = label(idx);

idx2 = find(SS~=probSub);
gallery = dataN(:,idx2)';
glabel = label(idx2);

[IDXpred, nn_accuracy, IDXdist] = knnprobe(probe, plabel, gallery, glabel, 1);
% Label the output so it can be told apart from the subject/illumination results.
fprintf(1,'DA NN acc (pose): %f \n', nn_accuracy);

%% illumination estimation
% Leave-one-subject-out: illumination codes (Ida) of subject probSub are
% classified with knnprobe (k = 1) against those of all other subjects.
[dataN] = normcol(Ida);
label = LL;
probSub = 37;

idx = find(SS==probSub);
probe = dataN(:,idx)';
plabel = label(idx);

idx2 = find(SS~=probSub);
gallery = dataN(:,idx2)';
glabel = label(idx2);

[IDXpred, nn_accuracy, IDXdist] = knnprobe(probe, plabel, gallery, glabel, 1);
% Label the output so it can be told apart from the subject/pose results.
fprintf(1,'DA NN acc (illumination): %f \n', nn_accuracy);

