function [snr,xout,Iout,F,Time] = algoritm3MG(kmax,d,jm,qmf,la,delta,Lambda,S,Sigma,psi,I,x0,Mask)
% ========================================================================
% Complex-valued 3MG Algorithm for pMRI, Version 1.0
%
% A. Florescu, E. Chouzenoux, J.-C. Pesquet, P. Ciuciu and S. Ciochina. 
% A Majorize-Minimize Memory Gradient Method for Complex-Valued Inverse 
% Problems. Signal Processing, Vol. 103, pages 285-295, 2014. 
%
% Kindly report any suggestions or corrections to
% emilie.chouzenoux@univ-mlv.fr
%
%----------------------------------------------------------------------
%[snr,xout,Iout,F,Time] = algoritm3MG(kmax,d,jm,qmf,la,delta,Lambda,S,Sigma,psi,I,x0,Mask)
%
% Minimizes, over the pixels selected by Mask, the criterion evaluated by
% critere_gradient (quadratic data fidelity summed over the L coils plus a
% weighted penalty psi on the moduli of the wavelet coefficients). Each
% iteration minimizes a quadratic majorant of the criterion over the
% subspace spanned by the current negative gradient and the previous step.
%
%Input:  kmax: the max number of iterations (must be >= 1)
%        d: the degraded data, one column d(:,l) per coil l = 1..L
%           (normalized internally by sqrt(Lambda(l,l)))
%        jm, qmf: the wavelet parameters; the transform is
%           FWT2_POc(.,J-jm,qmf) with J = log2(Nx), and the
%           Nx/2^jm-by-Ny/2^jm approximation block is not penalized
%        la, delta: the regularization parameters (la: scalar or
%           Nx*Ny-by-1 weight vector; delta: penalty scale)
%        Lambda: the noise covariance matrix (only its diagonal
%           entries Lambda(l,l) are used)
%        S: the Nx-by-Ny-by-L sensitivity matrices
%        Sigma: the Nx-by-Ny sampling matrix, applied elementwise
%           in k-space
%        psi: the penalty function flag as indicated below
%           (1) psi(u) = (1-exp(-u.^2./(2*delta^2))); 
%           (2) psi(u) = (u.^2)./(2*delta^2 + u.^2); 
%           (3) psi(u) = log(1 + (u.^2)./(delta^2)); 
%           (4) psi(u) = sqrt(1 + u^2/delta^2)-1; 
%           (5) psi(u) = 1/2 u^2; 
%       I: the original complex-valued image to recover (used only
%          to compute the SNR)
%       x0: the initial estimate (2D image; only x0(Mask) is used)
%       Mask: logical image selecting the pixels that are optimized
%
%Output: snr: the snr along iterations
%        xout: the restored image (vectorized), i.e. Iout(:)
%        Iout: the restored image (2D), taken at the last iterate
%              where the criterion was evaluated
%        F: the values of the criterion along minimization process
%        Time: per-iteration tic/toc durations; Time(1) = 0 and
%              Time(j) (j > 1) is the duration of iteration j-1, so
%              cumsum(Time) is aligned with F
%========================================================================

%display frequency (iterations)
modaff = 25;

%stopping criterion: stop when norm(grad) < sqrt(numel(x))*prec
prec = 1e-10;

if kmax < 1
    error('algoritm3MG:invalidKmax', 'kmax must be at least 1 (got %g).', kmax);
end

[Nx,Ny,L] = size(S);
x0 = x0(Mask);
J = log2(Nx);

disp('*** START 3MG ALGORITHM ***');

%do not penalize approximation coefficients (coarsest-scale block)
mask = ones(Nx,Ny);
mask(1:Nx/2^jm,1:Ny/2^jm) = 0;
la = la.*mask(:);

x = x0;
stop = sqrt(numel(x))*prec;

%normalization by the per-coil noise standard deviation
for l = 1:L
    d(:,l) = d(:,l)/sqrt(Lambda(l,l));
end

Time(1) = 0;
for k = 1:kmax
    
    tic;
    
    [F(k),dF,Iout,bxk] = critere_gradient(x,d,J,jm,qmf,Lambda,L,S,Sigma,Nx,Ny,psi,la,delta,Mask);
    NormGrad(k) = norm(dF(:));
    snr(k) = SNRcalc(Iout,I);
    xout = Iout(:);
    
    if mod(k,modaff)==1
        fprintf(1,'iteration %d, criterion = %g, snr = %g, NormGrad = %g \n',k,F(k),snr(k),NormGrad(k));
    end
    
    %wavelet coefficients of the gradient (computed once per iteration)
    dF2D = subsamp_adj(dF,Mask);
    wcg = FWT2_POc(dF2D,J-jm,qmf);
    wcg = wcg(:);
    
    %search directions D and their wavelet-domain images wcD
    if k == 1 %no memory: negative gradient only
        D = -dF;
        wcD = -wcg;
    else %memory gradient: negative gradient and previous step
        D = [-dF dx];
        wcD = [-wcg dw];
    end
    
    %minimize the quadratic majorant over span(D); pinv keeps the step
    %finite (zero) if B is singular, e.g. when the gradient vanishes
    B = majorant_matrix(D,wcD,bxk,S,Sigma,Lambda,Mask,Nx,Ny,L);
    u = -pinv(B)*(D'*dF);
    dx = D*u;
    dw = wcD*u; %wavelet image of dx, reused as memory at next iteration
    
    %update
    x = x + dx;
    
    Time(k+1) = toc;
    
    if NormGrad(k) < stop
        break
    end    
end
Time(end) = [];

disp('*** END 3MG ALGORITHM ***');
disp(['Iteration number = ',num2str(k)]);
disp(['Computation time (cpu) = ', num2str(sum(Time))]);
disp(['Final criterion value = ',num2str(F(end))])
disp(['Final SNR value = ',num2str(snr(end))])
disp('****************************************');

end


function B = majorant_matrix(D,wcD,bxk,S,Sigma,Lambda,Mask,Nx,Ny,L)
% Curvature matrix of the quadratic majorant restricted to span(D):
%   B = 2*sum_l (H_l*D)'*(H_l*D) + wcD'*diag(bxk)*wcD
% where H_l*v = Sigma.*fft2(S_l.*subsamp_adj(v,Mask))/(sqrt(Nx*Ny)*sqrt(Lambda(l,l))).
% NOTE(review): critere_gradient applies subsamp(.,Sigma) instead of
% Sigma.*; both give the same Gram matrix assuming Sigma is a 0/1 mask.
% D: one search direction per column; wcD: their wavelet coefficients.
p = size(D,2);

%back-project each direction to a 2D image once (does not depend on l)
D2D = zeros(Nx,Ny,p);
for m = 1:p
    D2D(:,:,m) = subsamp_adj(D(:,m),Mask);
end

B_fid = 0;
HD = zeros(Nx*Ny,p);
for l = 1:L
    Sl = S(:,:,l);
    scale = sqrt(Nx*Ny)*sqrt(Lambda(l,l));
    for m = 1:p
        HlDm = Sigma.*fft2(Sl.*D2D(:,:,m))/scale;
        HD(:,m) = HlDm(:);
    end
    B_fid = B_fid + HD'*HD;
end

%penalty term: bxk replicated column-wise (no implicit expansion needed)
B = 2.*B_fid + wcD'*(bxk(:,ones(p,1)).*wcD);

end


function [F,dF,Iout,bxk] = critere_gradient(x,d,J,jm,qmf,Lambda,L,S,Sigma,Nx,Ny,psi,la,delta,Mask)
% Criterion value, gradient and majorant curvature weights at x.
%
%   X    = subsamp_adj(x, Mask)                     (current 2D image)
%   z_l  = subsamp(fft2(S_l.*X)/sqrt(Nx*Ny), Sigma)/sqrt(Lambda(l,l)) - d(:,l)
%   wc   = FWT2_POc(X, J-jm, qmf)
%   F    = sum_l norm(z_l)^2 + sum(la .* psi(abs(wc)))
%
% d is expected to be already divided by sqrt(Lambda(l,l)) column-wise,
% as done in algoritm3MG.
%
% Outputs:
%   F    : criterion value
%   dF   : gradient w.r.t. x (same size as x): the Mask-restriction of
%          2*sum_l H_l'*z_l + IWT2_POc(bxk.*wc, J-jm, qmf)
%   Iout : X, the current iterate as a 2D image
%   bxk  : la .* psi'(|wc|)./|wc|, per-coefficient curvature weights
%          (Nx*Ny-by-1)
%
% Errors if psi is not one of 1..5.

%DATA FIDELITY

Iout = subsamp_adj(x,Mask);

dF = zeros(size(x));
F = 0;
for l = 1:L
    Sl = S(:,:,l);
    dltemp = fft2(Sl.*Iout)/sqrt(Nx*Ny);
    Hxl = subsamp(dltemp,Sigma)/sqrt(Lambda(l,l));
    zl = Hxl - d(:,l);
    F = F + norm(zl)^2;
    
    %adjoint of the forward operator applied to the residual, times 2
    Zl2D = subsamp_adj(zl,Sigma);
    gl = (2*ifft2(Zl2D.*Sigma).*conj(Sl)*sqrt(Nx*Ny))/sqrt(Lambda(l,l));
    dF = dF + gl(Mask);
end

%PENALIZATION

%wavelet transform
wc = FWT2_POc(Iout,J-jm,qmf);
wc = wc(:);

%shared quantities: squared moduli of the coefficients and delta^2
a2 = abs(wc).^2;
d2 = delta.^2;

%penalty function: phixk = psi(|wc|), bxk = psi'(|wc|)./|wc|
switch psi
    case 1
        phixk = (1-exp(-a2./(2*d2)));
        bxk = (1./d2).*exp(-a2./(2*d2));
    case 2
        phixk = a2./(2*d2 + a2);
        bxk = (4*d2)./(2*d2 + a2).^2;
    case 3
        phixk = log(1 + a2./d2);
        bxk = 2./(d2 + a2);
    case 4
        phixk = (1 + a2./d2).^(1/2)-1;
        bxk = (1./d2).*(1+a2./d2).^(-1/2);
    case 5
        phixk = 1/2*a2;
        bxk = ones(size(wc));
    otherwise
        error('critere_gradient:invalidPsi', ...
            'psi must be 1, 2, 3, 4 or 5 (got %s).', mat2str(psi));
end

bxk = la.*bxk;
phixk = la.*phixk;
%gradient of the penalty w.r.t. the (complex) wavelet coefficients
dphixk = wc.*bxk;

%inverse wavelet transform (back to the image domain)
dphixk = reshape(dphixk,Nx,Ny);
Wtdphixk = IWT2_POc(dphixk,J-jm,qmf);

%compute gradient and criterion
dF = dF + Wtdphixk(Mask);
F = F + sum(phixk);

end