% Start from an empty workspace so that no variable from a previous run leaks
% into this experiment. Open figures are kept (the "close all" is disabled).
clear; %close all;
% Put each project folder, together with all of its subfolders, on the MATLAB
% search path; genpath expands a folder into the list of its subfolders.
addpath( genpath('include') )
addpath( genpath('dataset') )
addpath( genpath('export_fig') )
addpath( genpath('glmnet') )


%%%%%%%%%%%%%%%%%%
%%% parameters %%%
%%%%%%%%%%%%%%%%%%

% --- Learning problem ------------------------------------------------------
dataset    = 'w8a';      % name of the dataset to load (matched case-insensitively)
classifier = 'binary';   % classifier type: 'binary' or 'multi'
loss       = 'logit';    % loss function: 'logit' or 'hinge'
prior      = 'L1';       % regularization: 'L2' or 'L1'
lambda     = 0.1;        % hyper-parameter handed to the trainer with the prior
minibatch  = 500;        % minibatch size

% --- Solver options (forwarded to the trainer as a single struct) ----------
opt = struct( ...
    'tol',  0,     ... % stopping criterion (10^-7 was used previously)
    'iter', 5000,  ... % maximum number of iterations
    'inf',  10000, ... % number of iterations for x^[inf]
    'coef', 1);        % forwarded as-is; its meaning is not visible in this script

% --- Algorithms to run ------------------------------------------------------
% Uncomment an entry to include that algorithm in the comparison.
algo_list = {
    % 'DR'
    % 'DR2'
    % 'FISTA'
    % 'PD'
    %-----%
    'BCDR2'
    % 'SFB'
    % 'RDA'
    % 'BCPD'
    % 'ASYNC'
    % 'PFWD2'
};


%%%%%%%%%%%%%%%%%
%%% TEST BODY %%%
%%%%%%%%%%%%%%%%%


% Load the requested dataset through a lookup table. Each row holds:
%   dataset name | loader handle | value for opt.gamma | number of blocks
% Alternative block counts noted by the author: rcv1 -> 3, 9, 47 ; gisette -> 3, 4, 6.
registry = {
    'w8a',     @load_W8A,           0.1,    1
    'yale',    @load_YALE,          0.001,  1
    'ar',      @load_AR,            0.05,   1
    'mnist',   @load_MNIST,         1,      1
    'rcv1',    @()load_RCV1(20),    1,      9
    %-----%
    'yaleb',   @load_YALEB,         0.007,  2
    'gisette', @load_GISETTE,       0.0001, 3
    'usps',    @load_USPS,          0.01,   1
    'coil20',  @load_COIL20,        0.005,  1
    'bangla',  @load_BANGLA,        0.001,  1
    'covtype', @load_COVTYPE,       0.1,    1
    'sector',  @()load_SECTOR(20),  0.1,    15
    'news',    @load_NEWS,          0.1,    34
    'skin',    @load_SKIN,          0.0001, 1
    'ijcnn1',  @load_IJCNN1,        0.001,  1
};

% Case-insensitive search of the name column; abort when nothing matches.
hit = find( strcmpi(dataset, registry(:,1)), 1 );
if isempty(hit)
    error('Dataset not found');
end

% Only the selected loader is invoked; it returns the train and test sets.
loader = registry{hit,2};
[trainset, testset] = loader();
opt.gamma = registry{hit,3};
blocks    = registry{hit,4};

% Banner with the dataset name in upper case.
fprintf('\n%6s\n-------\n', upper(dataset));

% Remove the feature columns whose entries are all (numerically) zero in the
% training set. The same columns are dropped from the test set so that both
% sets keep the same feature indexing.
% The 'news' and 'sector' datasets are excluded from this step. The original
% test "~news || ~sector" was always true (a name cannot equal both), so the
% exclusion never took effect; the membership test below implements it.
if ~any( strcmpi(dataset, {'news', 'sector'}) )
    % Reduce along dimension 1 explicitly, so that a sample matrix with a
    % single row still yields one flag per column instead of one scalar.
    idx = any( abs(trainset.samples) > 1e-7, 1 );
    trainset.samples = trainset.samples(:,idx);
    testset.samples  = testset.samples(:,idx);
end

% Append a copy of the last feature column to both sets.
% NOTE(review): as written, this condition is equivalent to
% "~strcmpi(dataset,'yaleb')" (the 'rcv1' term is redundant), i.e. the column
% is duplicated for every dataset except 'yaleb'. It may have been meant to
% apply only to specific datasets - confirm the intent before changing it.
if strcmpi(dataset,'rcv1')  || ~strcmpi(dataset,'yaleb') %|| ~strcmpi(dataset,'sector')
    trainset.samples = [trainset.samples trainset.samples(:,end)];
     testset.samples = [ testset.samples  testset.samples(:,end)];
end

%------------
% if strcmpi(dataset, 'mnist')
%     mask1 = false;
%     mask2 = false;
%     for c = 4:10
%         mask1 = mask1 | (trainset.labels == c);
%         mask2 = mask2 | ( testset.labels == c);
%     end
%     trainset.samples(mask1,:) = [];
%     trainset.labels (mask1)   = [];
%      testset.samples(mask2,:) = [];
%      testset.labels (mask2)   = [];
% end
%-------

% Wait for a key press before starting the (potentially long) optimization.
disp('PAUSE - To execute the algorithms, press a key !')
pause

% One figure with three stacked axes:
%   ax1 - distance to x^[inf] versus time
%   ax2 - criterion versus time
%   ax3 - learned coefficients
figure;
ax1 = subplot(3,1,1);
ax2 = subplot(3,1,2);
ax3 = subplot(3,1,3);

% Run every selected algorithm, report its accuracy and add its curves.
for idx_algo = 1:length(algo_list)

    % Reset the random number generator so each algorithm starts from the
    % same generator state.
    rng('default')
    algo = algo_list{idx_algo};

    % perform the optimization
    if strcmpi(classifier, 'binary')
        [w,it,time,crit,conv] = train_classifier_binary(trainset, loss, minibatch, blocks, prior, lambda, algo, opt);
    else
        [w,it,time,crit] = train_classifier_multiclass(trainset, loss, minibatch, prior, lambda, algo, opt);
        % The multiclass trainer is called with four outputs, so there is no
        % distance-to-limit curve. Without this placeholder, "conv" below
        % would resolve to MATLAB's builtin convolution function and fail.
        % NaNs keep the plot call valid and simply draw nothing.
        conv = nan(size(time));
    end

    % define the predictor: sign of the score for a single weight vector,
    % otherwise the argmax over the columns of the score matrix
    if size(w,2) == 1
        predict = @(u,samples)   sign(samples * u);
    else
        predict = @(u,samples) argmax(samples * u, 2);
    end

    % test the classifier
    y_train = predict(w, trainset.samples);
    y_test  = predict(w,  testset.samples);

    % assess the accuracy (number of misclassified samples)
    train_errors = sum(trainset.labels ~= y_train);
     test_errors = sum( testset.labels ~= y_test);

    % Sparsity is reported as the percentage of coefficients with magnitude below 1e-2.
    fprintf('%6s (%s, loss: %s, prior: %s, lambda: %1.1e) - Training errors: %2.2f %% - Test errors: %2.2f %% - Time: %4.2f s - Iter.: %4d - Crit.: %.2f - Sparsity: %.2f %%\n', ...
        algo, classifier, loss, prior, lambda, ...
        100 * train_errors / size(trainset.samples,1), ...
        100 * test_errors / size(testset.samples,1), ...
        time(end), it, crit(end), ...
        100 * sum(abs(w(:))<1e-2) / numel(w) );

    % visualization
    % Each axes is plotted first and put on hold afterwards, so the first
    % semilogy call sets the log scale; labels are set after plotting because
    % a plot on a non-held axes resets them.
    semilogy(ax1, time, conv, 'DisplayName', algo, 'linewidth', 2);
    hold(ax1, 'on');
    ylabel(ax1, 'Distance to x^{[\infty]}');
    xlabel(ax1, 'Time (sec.)');
    title(ax1, classifier);

    semilogy(ax2, time, crit, 'DisplayName', algo, 'linewidth', 2);
    hold(ax2, 'on');
    ylabel(ax2, 'Criterion');
    xlabel(ax2, 'Time (sec.)');

    plot(ax3, w(:), 'DisplayName', algo, 'linewidth', 2);
    hold(ax3, 'on');
    ylabel(ax3, 'Coefficients');
    xlabel(ax3, 'Position');
end

% One legend entry per algorithm (taken from each line's DisplayName).
legend(ax1, 'show');
legend(ax2, 'show');
legend(ax3, 'show');