%%%% colour image inpainting script %%%%

% ---- global options ---------------------------------------------------------
rng(0);                     % fixed seed: masks and initialisations are reproducible
savefile = 'mysave.mat';    % results are saved to this file while the variable exists
display = 1;                % 1: print progress and plot dictionaries / inpainted images
                            % NOTE: this variable shadows MATLAB's built-in display()
cp2wksvd = 1;               % 1: additionally learn a wKSVD dictionary (very slow!)

% ---- experiment parameters --------------------------------------------------
picname = 'mandrill.png';   % image in images/ that is corrupted and then reconstructed

corrtype = 't';             % corruption: 'r' random erasures, 'c' cracks, 't' text
corrlevel = 0.5;            % per-pixel erasure probability (only used for 'r')
runs = 1;                   % number of trials; forced to 1 for 'c' and 't'

% patch size: height s1 x width s2 x colour depth s3
s1 = 8;
s2 = s1;
s3 = 3;

L = 8;                      % upper bound on low-rank atoms, actual number chosen adaptively
                            % (original note: 'for wKSVD L=1')
maxitLR = 20;               % iterations used to learn each low-rank atom
K = 2*s1*s2*s3;             % total dictionary size = twice the patch dimension
maxit = 40;                 % iterations of dictionary learning

% wKSVD is too slow for long runs: only compare when maxit does not exceed 40
if maxit > 40
    cp2wksvd = 0;
end

%%%% load picture and rescale it to exactly 256 rows
% [256 NaN] fixes the row count at 256 and lets imresize choose the column
% count that keeps the aspect ratio. A scalar scale factor 256/rows can miss
% 256 rows by rounding, which would break the 256x256 crack mask.
pic = imread(strcat('images/',picname));
pic = im2double(pic);
pic = imresize(pic,[256 NaN]);   %%% square inputs become 256x256

[d1,d2,d3] = size(pic);
d = s3*s1*s2;                    % patch dimension (entries per s1 x s2 x s3 patch)

%%%% the crack ('c') and text ('t') masks are fixed files, so only one run is
%%%% meaningful. The crack mask needs a 256x256 picture; for any other size,
%%%% fall back to random erasures with probability 0.5.
if corrtype == 'c' || corrtype == 't'
    runs = 1;
end
if corrtype == 'c' && (d1 ~= 256 || d2 ~= 256)
    disp('need square picture for crack mask, 50% erasure mask used');
    corrtype = 'r';      % assignment (the original '==' silently did nothing)
    corrlevel = 0.5;
end

%%%% per-run result vectors (one entry per run)
psnr_noisy     = zeros(runs,1);   % psnr of the corrupted picture
psnr_itkrmm    = zeros(runs,1);   % psnr after inpainting with the itkrmm dictionary
runtime_itkrmm = zeros(runs,1);   % itkrmm learning time
if cp2wksvd == 1
    psnr_wksvd    = zeros(runs,1);   % psnr after inpainting with the wksvd dictionary
    runtime_wksvd = zeros(runs,1);   % wksvd learning time
end

%%%% draw all masks and initial dictionaries up front, since the dictionary
%%%% learning routines may consume random numbers themselves.
%%%% Order matters: masks (rand) come before the initialisations (randn).
switch corrtype
    case 'c'
        % grayscale crack mask 'cracks', copied to every colour channel
        load('images/cracks.mat')
        mask = repmat(cracks, 1, 1, d3);
    case 'r'
        % one random mask per run: 0 (erased) where rand <= corrlevel, 1 otherwise
        masks = ceil(rand(d1,d2,d3,runs)-corrlevel);
    otherwise
        % text mask; images/textmask.mat is expected to define 'mask'
        % NOTE(review): imresize interpolates, so the result may not be strictly 0/1 - confirm
        load('images/textmask.mat')
        mask = imresize(mask,[d1 d2]);
        mask = im2double(mask);
end

% one d x K Gaussian initialisation per run
inits = randn(d,K,runs);

%%%% main loop: for every run corrupt the picture, learn the low rank
%%%% component and the dictionaries from the corrupted patches, then inpaint.
% L and K are overwritten inside the loop (L becomes the number of accepted
% low rank atoms, K the number of remaining dictionary atoms), so keep the
% user settings and restore them at the start of every run.
Lmax = L;   % maximum number of low rank atoms
K0 = K;     % total dictionary size incl. low rank atoms (= size(inits,2))

for run = 1 : runs
    L = Lmax;
    K = K0;
    if corrtype == 'r'
        mask = masks(:,:,:,run);
    end

    corrpic = mask.*pic;   %%% corrupted picture
    psnr_noisy(run) = psnr(corrpic,pic);

    %%% get patches of picture and mask (locations needed to reassemble)
    [corrpatches, ploc] = pic2patches(corrpic,s1,s2,s3);
    [maskpatches, ~] = pic2patches(mask,s1,s2,s3);

    %%% learning with mask info
    %%% learn low rank atoms (itkrmm runtime includes this step)
    tic
    lrc = [];
    if display == 1 && L > 0
        disp('learning low rank component using mask info');
    end
    data = corrpatches.*maskpatches;
    masksq = maskpatches.^2;
    dataenergy = sum(sum(data.^2,1));   % loop invariant, computed once
    s = Inf;                            % ratio of the last accepted atom; Inf accepts the first

    % adaptive choice of the low rank component: the ratio
    % (energy captured by the new low rank atom)/(energy captured by a dictionary atom)
    % must drop below 0.8 times its previous value for the atom to be added;
    % as soon as the ratio stabilises, no further low rank atoms are added
    for ll = 1 : Lmax
        % initialise low rank atom orthogonal to the atoms found so far
        inatoml = inits(:,ll,run);
        if ll > 1
            inatoml = inatoml - lrc*lrc'*inatoml;
        end
        inatoml = inatoml/norm(inatoml);
        atoml = rec_lratom(corrpatches,maskpatches,lrc,maxitLR,inatoml);

        % energy captured by a dictionary atom
        temp2 = dataenergy/(K - ll);

        % energy captured by the new low rank atom:
        % sum_n <atom, data_n>^2 / ||mask_n .* atom||^2.
        % Patches where the atom is completely masked carry no information
        % (the term would be 0/0 = NaN and block every further atom), so skip them.
        atomdata = atoml'*data;
        atomnorms = (atoml.^2)'*masksq;
        observed = atomnorms > 0;
        temp = sum(atomdata(observed).^2 ./ atomnorms(observed));

        if (temp/temp2) < 0.8*s
            lrc = [lrc, atoml];
            s = temp/temp2;
        else
            break;
        end
    end
    % set the sizes in every case, including when all Lmax atoms were accepted
    % or Lmax == 0 (previously S was then undefined)
    L = size(lrc,2);
    if Lmax > 0
        disp(['Low rank component is ' ,num2str(L)]);
    end
    K = K0 - L;   %%% dictionary size K - L (without low rank atoms)
    S = s1 - L;   %%% sparsity level in the learning step

    %%% learn dictionary with itkrmm
    if display == 1
        disp('learning dictionary using mask info (itkrmm)');
    end
    %%% initialise with K random columns (the same ones wksvd uses below),
    %%% orthogonalised against the low rank component and normalised
    dico = inits(:,L+1:K+L,run);
    if L > 0
        dico = dico - lrc*lrc'*dico;
    end
    dico = dico*diag(1./sqrt(sum(dico.*dico)));
    dico = itkrmm(corrpatches, maskpatches, K, S, lrc, maxit, dico);
    runtime_itkrmm(run) = toc;
    if display == 1
        imagesc(showdico([lrc,dico]));
        title('itkrmm dictionary (and low rank component)');
        drawnow;
    end

    %%% learn dictionary with wksvd
    if cp2wksvd == 1
        param = struct();   % fresh options: nothing (e.g. preserveDCAtom) leaks from earlier runs
        if display == 1
            disp('learning wksvd dictionary');
            param.displayProgress = 1;
        else
            param.displayProgress = 0;
        end
        param.K = K+L;              % dictionary size incl. low rank atoms
        param.dSparsity = S+L;
        param.numIteration = maxit;
        param.InitializationMethod = 'GivenMatrix';
        tic
        indico = inits(:,L+1:K+L,run);
        if L == 1
            param.preserveDCAtom = 1; %%% also ensures orthogonality of initialisation to lrc
        elseif L > 1
            param.preserveDCAtom = 1;
            indico = indico - lrc*lrc'*indico;
        end
        indico = indico*diag(1./sqrt(sum(indico.*indico)));
        param.initialDictionary = [lrc,indico];
        [dico_wksvd, output] = wKSVD(corrpatches, maskpatches, param);
        runtime_wksvd(run) = toc;
        if display == 1
            imagesc(showdico(dico_wksvd));
            title('wksvd dictionary (and low rank component)');
            drawnow;
        end
    end

    %%%% inpainting
    if display == 1
        disp('inpainting with itkrmm dictionary');
    end
    lrcdico = [lrc,dico];

    coeff = aOMPm(lrcdico, corrpatches, maskpatches);
    inppatches = lrcdico*coeff;
    inppic = patches2pic(inppatches, ploc, s1, s2, s3);
    psnr_itkrmm(run) = psnr(inppic,pic);
    if display == 1
        imagesc(inppic)
        % [] concatenation keeps the space after the colon (strcat strips it)
        title(['psnr using itkrmm: ',num2str(psnr_itkrmm(run))]);
        drawnow;
    end

    if cp2wksvd == 1
        if display == 1
            disp('inpainting with wksvd dictionary');
        end
        coeff_wksvd = aOMPm(dico_wksvd, corrpatches, maskpatches);
        inppatches_wksvd = dico_wksvd*coeff_wksvd;
        inppic_wksvd = patches2pic(inppatches_wksvd, ploc, s1, s2, s3);
        psnr_wksvd(run) = psnr(inppic_wksvd,pic);
        if display == 1
            imagesc(inppic_wksvd)
            title(['psnr using wksvd: ',num2str(psnr_wksvd(run))]);
            drawnow;
        end
    end

    %%% save intermediate results after every run (random masks only for 'r')
    if exist('savefile','var')
        if corrtype == 'r'
            save(savefile,'lrc*','dico*','psnr*','inppic*','masks','runtime*');
        else
            save(savefile,'lrc*','dico*','psnr*','inppic*','runtime*');
        end
    end
end
