function [x,Crit,SNR,NGrad,Time] = FISTA(y,H,Hadj,a,b,nu,x,I,Nx,Ny,xmin,xmax,NbIt,TimeMax)
% ========================================================================
% FISTA Algorithm, Version 1.0
%
% Kindly report any suggestions or corrections to
% audrey.repetti@univ-mlv.fr
%
%----------------------------------------------------------------------
%
% Minimization of
%
%   G(x) = GD(H(x),y,a,b) + nu TV(x) + i_[xmin;xmax](x)
%
% where GD is the negative log-likelihood of Gaussian signal-dependent
% noise with variance a*H(x)+b, extended quadratically (second-order
% Taylor expansion at 0) to negative entries of H(x) so that G is defined
% on all of R^N, TV is the isotropic total variation and i_[xmin;xmax] is
% the indicator function of the box [xmin,xmax].
%
% Input:  y: the degraded data (same size as H(x))
%         H: the linear degradation operator (function handle)
%         Hadj: the adjoint of the linear degradation operator
%            (function handle)
%         a, b: the Gaussian signal-dependent noise parameters
%         nu: the regularization parameter
%         x: the initialization (vector of Nx*Ny pixels)
%         I: the true image, used only to compute the SNR
%         Nx, Ny: number of columns / rows of the image
%         xmin, xmax: the bounds
%         NbIt: the max number of iterations
%         TimeMax: the max computational time (seconds, wall clock)
%
% Output: x: the restored image (last extrapolated FISTA point)
%         Crit: the values of the criterion along the minimization
%            process; Crit(1) is the value at the initialization
%         SNR: the SNR (dB) of x with respect to I along the process
%         NGrad: NGrad(1) is the norm of the initial data-term gradient;
%            NGrad(k+1) is the norm of xbar - x at iteration k
%         Time: the computation time of each iteration (Time(1) = 0)
%
%========================================================================

% Close leftover figures (including stale progress bars with the same tag).
close all hidden
% The progress bar doubles as a stop button: closing it ends the loop.
hhh = waitbar(0,'Reconstruction FISTA: Time / TimeMax','tag','wait');
runalgo = 1;

disp('****************************************');
disp('FISTA Algorithm');

% Discrete gradient operator and its adjoint.
W = @(x) ComputeTVlin(x,Nx,Ny);
Wadj = @(z) ComputeTVlinAdj(z,Nx,Ny);

% Quadratic extension parameters so that G(.) is defined on all R^N:
% second derivative, first derivative and value at 0 of the data term.
Aquad =  ( a.* y+ b).^2 ./ (b.^3) - (1/2)*(a.^2)./(b.^2) ;
Bquad = -1/2 .*( y.* ( a.* y + 2 * b) ./ (b.^2)) + (1/2) * (a./b) ;
Cquad =  y.^2 ./ ( 2 * b) + (1/2)* log(b) ;

% Step-size constant: maximal curvature of the data term at 0, scaled by
% the operator norm estimate.
L = max( ( (a.*y + b).^2 ./ b.^3 ) - ( (a.^2)./(2*b.^2) )) ;
% NOTE(review): max_singular_value is applied to Hadj(H(.)); if it returns
% ||H'*H|| = ||H||^2, the squaring below yields ||H||^4 (a smaller step
% than necessary). Confirm what max_singular_value returns.
sigma_max_H = max_singular_value(@(x) Hadj(H(x)),rand(Nx*Ny,1),50);
L = L*sigma_max_H^2;
A = L.*ones(Nx*Ny,1);

[Crit(1),Grad] = criterion(x,y,H,Hadj,a,b,Aquad,Bquad,Cquad,nu,Nx,Ny);
SNR(1) = 10*log10(sum(I(:).^2)/sum((I(:)-x).^2));
Time(1) = 0;
NGrad(1) = norm(Grad);

xbarold = x;
t = 1;

for k = 1:NbIt

    tic

    % Stopping test on computation time or user cancellation
    if(sum(Time)>TimeMax)||(~runalgo)
        break
    end

    % Forward (gradient) step with the diagonal metric A
    xtemp = x-Grad./A;
    if(k==1)
        % Warm-start dual variables of the inner prox solver
        v1 =  W(xtemp-x) ;
        v2 =  xtemp-x;
    end
    % Backward step: prox of nu*TV + box indicator in the metric A
    [xbar,v1,v2] = prox_tv_ind_metric(xtemp, A, nu, W, Wadj, v1,v2, x, xmin, xmax,Nx,Ny);
    DGrad = xbar-x;

    % FISTA extrapolation.
    % NOTE(review): x is the extrapolated point, which may leave
    % [xmin,xmax]; xbar is the feasible prox output. Crit, SNR and the
    % returned x all refer to the extrapolated point.
    tnew = (1 + sqrt(1 + 4*t^2))/2 ;
    x = xbar + (t-1)/tnew * (xbar - xbarold);
    xbarold = xbar;
    t = tnew;

    [Crit(k+1),Grad] = criterion(x,y,H,Hadj,a,b,Aquad,Bquad,Cquad,nu,Nx,Ny);
    Time(k+1) = toc;
    SNR(k+1) = 10*log10(sum(I(:).^2)/sum((I(:)-x).^2));
    NGrad(k+1) = norm(DGrad);

    % The user closing the progress bar stops the algorithm
    zwk = findobj(allchild(0),'flat','tag','wait');
    if isempty(zwk)
        runalgo = 0;
    else
        % Clamp: the last iteration may overshoot the time budget
        waitbar(min(sum(Time)/TimeMax,1),hhh);
    end
end

% Remove the progress bar (it may already have been closed by the user).
if ishandle(hhh)
    delete(hhh);
end

% Crit(1) is the initial value, so the number of iterations is one less.
disp(['Iteration number = ',num2str(numel(Crit)-1)]);
% tic/toc measures wall-clock time.
disp(['Computation time (cpu) = ', num2str(sum(Time))]);
disp(['Final criterion value = ',num2str(Crit(end))])
disp(['SNR (dB) = ',num2str(SNR(end))])

end


function [G,dF,Hx] = criterion(x,y,H,Hadj,a,b,Aq,Bq,Cq,nu,Nx,Ny)
% criterion  Objective value (data term + nu*TV) and data-term gradient.
%
%   Inputs:  x: current image (vector of Nx*Ny pixels)
%            y: the data (same size as H(x))
%            H, Hadj: degradation operator and its adjoint (handles)
%            a, b: noise parameters, scalars or arrays the size of H(x)
%            Aq, Bq, Cq: quadratic-extension coefficients (size of y)
%            nu: regularization parameter
%            Nx, Ny: number of columns / rows of the image
%   Outputs: G: data term + nu*TV(x) (the box indicator is not included)
%            dF: gradient of the data term, Hadj(phi'(H(x)))
%            Hx: H(x), returned so callers can reuse it

Hx = H(x);
Hxneg = (Hx < 0 ) ;
Hxpos = (Hx >= 0 ) ;

% Restrict the noise parameters to the nonnegative entries when they are
% given per entry; scalars are used as is.
apos = a;
bpos = b;
if ~isscalar(a), apos = a(Hxpos); end
if ~isscalar(b), bpos = b(Hxpos); end

% Negative entries: quadratic extension 1/2*Aq*u^2 + Bq*u + Cq
F = sum(1/2 .* Aq(Hxneg).* Hx(Hxneg).^2 +Bq(Hxneg) .* Hx(Hxneg) + Cq(Hxneg));
% Nonnegative entries: (y-u)^2/(2(a*u+b)) + 1/2*log(a*u+b)
F = F + sum( ( y(Hxpos) - Hx(Hxpos) ).^2 ./ ( 2 * ( apos .* Hx(Hxpos) + bpos )) ...
        + (1/2).* log(apos.* Hx(Hxpos) + bpos) ) ;
R = ComputeTV(x,Nx,Ny);
G = F + nu*R;

% Preallocate so the derivative always has the shape of Hx (growing an
% undeclared variable by logical indexing yields a row vector whose length
% is only the last assigned index).
dphiHx = zeros(size(Hx));
dphiHx(Hxneg) =  Aq(Hxneg) .* Hx(Hxneg) + Bq(Hxneg) ;
dphiHx(Hxpos) = ( ( Hx(Hxpos) - y(Hxpos) ) .* ( apos .* ( Hx(Hxpos) + y(Hxpos) ) + 2 * bpos ) ) ./ ( 2 *(apos .* Hx(Hxpos) + bpos).^2 ) ...
        + apos./(2.*(apos.*Hx(Hxpos) + bpos)) ;

dF = Hadj(dphiHx(:));

end

function TV = ComputeTV(x,nx,ny)
% ComputeTV  Isotropic total variation of an image stored as a vector.
%   TV = ComputeTV(x,nx,ny) reshapes x into an ny-by-nx image X (ny rows,
%   nx columns, the same layout as ComputeTVlin) and returns
%   sum(sqrt(U.^2 + V.^2)), where U holds differences between adjacent
%   columns and V between adjacent rows. The first column of U and the
%   first row of V copy X itself, matching ComputeTVlin.
%
%   The image must be reshaped as ny-by-nx: the column index below runs
%   up to nx and the row index up to ny. Reshaping as nx-by-ny errors or
%   gives a wrong value for non-square images.

X = reshape(x,ny,nx);
U = X;
U(:,2:nx) = X(:,2:nx)-X(:,1:(nx-1));
V = X;
V(2:ny,:) = X(2:ny,:)-X(1:(ny-1),:);
TV = sum(sqrt(U(:).^2 + V(:).^2));

end

function Dx = ComputeTVlin(x,nx,ny)
% ComputeTVlin  Discrete gradient of an image stored as a vector.
%   Dx = ComputeTVlin(x,nx,ny) reshapes x into an ny-by-nx image (ny rows,
%   nx columns) and returns [Du(:); Dv(:)]:
%     Du(:,j) = X(:,j) - X(:,j-1) for j >= 2, Du(:,1) = X(:,1)
%     Dv(i,:) = X(i,:) - X(i-1,:) for i >= 2, Dv(1,:) = X(1,:)
%   The first column/row copies X itself rather than being zero.

img = reshape(x,ny,nx);

% Column differences, with the first column of the image kept as is.
dh = [img(:,1), diff(img,1,2)];
% Row differences, with the first row of the image kept as is.
dv = [img(1,:); diff(img,1,1)];

Dx = [dh(:); dv(:)];

end

function Dtz = ComputeTVlinAdj(z,nx,ny)
% ComputeTVlinAdj  Adjoint of the discrete gradient ComputeTVlin.
%   Dtz = ComputeTVlinAdj(z,nx,ny) takes z = [zc; zr], where zc =
%   z(1:nx*ny) pairs with the column differences and zr = z(nx*ny+1:end)
%   with the row differences. Each part is reshaped to ny-by-nx. The
%   function applies the transposed difference operators and returns
%   their sum as a column vector of length nx*ny.

npix = nx*ny;
Zcol = reshape(z(1:npix),ny,nx);
Zrow = reshape(z(npix+1:end),ny,nx);

% Transpose of the column differences: Z(:,j) - Z(:,j+1), last column kept.
U = [Zcol(:,1:(nx-1)) - Zcol(:,2:nx), Zcol(:,nx)];
% Transpose of the row differences: Z(i,:) - Z(i+1,:), last row kept.
V = [Zrow(1:(ny-1),:) - Zrow(2:ny,:); Zrow(ny,:)];

Dtz = reshape(U + V,[],1);

end