function f = set_loss_multiclass(loss, samples, labels, minibatch, flag)
%SET_LOSS_MULTICLASS Build the structure describing a multiclass loss.
%
%   f = set_loss_multiclass(loss, samples, labels, minibatch)
%   f = set_loss_multiclass(loss, samples, labels, minibatch, 'normalize')
%
%   Inputs:
%     loss      - 'logit' or 'hinge' (case-insensitive); anything else errors.
%     samples   - data matrix with one sample per row and one feature per column.
%     labels    - one class index per sample (row or column vector). The values
%                 are used as column subscripts, so they are expected to be
%                 positive integers 1..K.
%     minibatch - forwarded to random_selection by f.select.
%     flag      - optional; when equal to 'normalize' (case-insensitive) the
%                 samples are divided by the operator norm, so the normalized
%                 operator has norm 1.
%
%   Fields of f:
%     dir_op, adj_op         - operator w -> (scores minus labelled-class score)
%                              and its adjoint, over all samples
%     dir_blk_op, adj_blk_op - the same, restricted to a subset of samples
%     select                 - minibatch selector (random_selection, 'pooling')
%     fun, grad, prox        - loss value, gradient and proximity operator,
%                              delegated to the softmax or max/hinge helpers
%     beta                   - squared norm of the (possibly normalized) operator
%     lips                   - (op_norm*N)^2, i.e. the squared norm of the
%                              unnormalized operator in both modes

% force the labels into a column: sub2ind requires all subscript vectors to
% have the same shape, and the row subscripts built below are columns
labels = labels(:);

% problem sizes
K = max(labels);               % number of classes (labels index columns 1..K)
M = size(samples,2);           % number of features
L = size(samples,1);           % number of samples

% operator norm of the unnormalized operator
f.dir_op = @(w) loss_dir(w, samples, labels);
f.adj_op = @(v) loss_adj(v, samples, labels);
op_norm = compute_linop_norm([M K], f);

% optional normalization: rescale the samples so the operator has unit norm
N = 1;
if nargin == 5 && strcmpi(flag, 'normalize')
    N = op_norm;
    samples = samples / N;
    op_norm = 1;
end

% linear operator (re-created because anonymous functions capture variable
% values at creation time, and the samples may have been rescaled above)
f.dir_op = @(w) loss_dir(w, samples, labels);
f.adj_op = @(v) loss_adj(v, samples, labels);

% block linear operator
f.dir_blk_op = @(w,idx) block_loss_dir(w, idx, samples, labels);
f.adj_blk_op = @(v,idx) block_loss_adj(v, idx, samples, labels);

% minibatch selector
f.select = @(it,args) random_selection(it, args, L, minibatch, 'pooling');

% loss, gradient and proximity operator; N undoes the sample rescaling
if strcmpi(loss, 'logit')
    %----------------------------------------------------
    f.fun  = @(v)         fun_softmax(N*v, 1, 1);
    f.grad = @(v)        grad_softmax(N*v, N, 1);
    f.prox = @(v, gamma) prox_softmax(N*v, gamma*N^2, 1)/N;
    %----------------------------------------------------
elseif strcmpi(loss, 'hinge')
    %----------------------------------------------------
    % offsets: 0 at the labelled class of each sample, -1 elsewhere (K-by-L)
    r = - ones(K,L);
    i = sub2ind( [K L], labels, (1:L)' );
    r(i) = 0;
    %----------------------------------------------------
    f.fun  = @(v)         fun_max(N*v - r, 1, 1);
    f.grad = @(v)      grad_hinge(N*v - r, N);
    f.prox = @(v, gamma) prox_max(N*v - r, gamma*N^2, 1)/N + r/N;
    %----------------------------------------------------
else
    error('Loss not supported');
end

% symmetric matrix (required by primal algorithms)
%----------------------------------------------------
%f.Q = samples' * samples;         % TODO: fix
%----------------------------------------------------


% operator norm (required by primal-dual algorithms)
f.beta = op_norm^2;

% Lipschitz constant (required by gradient algorithms)
f.lips = (op_norm * N)^2;





%--------------------------------------------------------------------------
function v = loss_dir(x, u, z)
%--------------------------------------------------------------------------
% Forward operator over every row of u: delegates to block_loss_dir with
% the row selection 1..size(u,1).

v = block_loss_dir(x, 1:size(u,1), u, z);


%--------------------------------------------------------------------------
function x = loss_adj(v, u, z)
%--------------------------------------------------------------------------
% Adjoint operator over every row of u: delegates to block_loss_adj with
% the row selection 1..size(u,1).

x = block_loss_adj(v, 1:size(u,1), u, z);


%--------------------------------------------------------------------------
function v = block_loss_dir(x, idx, u, z)
%--------------------------------------------------------------------------
% Forward operator restricted to the rows idx of u.
%
%   x   - weight matrix with size(u,2) rows and one column per class
%   idx - indices (or logical mask) of the selected rows of u
%   u   - data matrix, one sample per row
%   z   - per-row class labels (row or column vector); z(idx) are used as
%         column subscripts into u(idx,:)*x
%
%   Returns a (number of classes)-by-(number of selected rows) matrix whose
%   column j holds the scores of selected row j minus its labelled score.

% scores: one row per selected sample, one column per class
scores = u(idx,:) * x;
[nrows, ncls] = size(scores);

% linear positions of the labelled-class entries; the labels are reshaped to
% a column so that they match the (1:nrows)' row subscripts in sub2ind
cls = reshape(z(idx), [], 1);
pos = sub2ind([nrows ncls], (1:nrows)', cls);

% subtract each row's labelled score from the whole row, then transpose
v = bsxfun(@minus, scores, scores(pos))';


%--------------------------------------------------------------------------
function x = block_loss_adj(v, idx, u, z)
%--------------------------------------------------------------------------
% Adjoint of block_loss_dir for the same selection idx of the rows of u.
%
%   v   - (number of classes)-by-(number of selected rows) matrix, shaped
%         like the output of block_loss_dir
%   idx - indices (or logical mask) of the selected rows of u
%   u   - data matrix, one sample per row
%   z   - per-row class labels (row or column vector)
%
%   Returns a size(u,2)-by-(number of classes) matrix.

% back to one row per selected sample, one column per class
w = v';
[nrows, ncls] = size(w);

% linear positions of the labelled-class entries; the labels are reshaped to
% a column so that they match the (1:nrows)' row subscripts in sub2ind
cls = reshape(z(idx), [], 1);
pos = sub2ind([nrows ncls], (1:nrows)', cls);

% adjoint of "subtract the labelled score from every entry of the row":
% the labelled entry of each row additionally receives minus the row sum
w(pos) = w(pos) - sum(w, 2);

% map back through the selected rows of u
x = u(idx,:)' * w;