function f = set_loss_binary_blkdiag(loss, samples, labels, minibatch_size, num_blocks, flag)
%SET_LOSS_BINARY_BLKDIAG Binary-classification loss on block-diagonalized samples.
%
%   f = set_loss_binary_blkdiag(loss, samples, labels, minibatch_size, num_blocks)
%   f = set_loss_binary_blkdiag(..., 'normalize')
%
%   Inputs:
%     loss           - 'logit' or 'hinge' (case-insensitive); anything else errors.
%     samples        - L-by-N data matrix (L samples, N features). N must be a
%                      multiple of num_blocks. Columns are split (column-major
%                      reshape) into num_blocks consecutive groups of
%                      B = N/num_blocks features, stored as an L-by-B-by-num_blocks array.
%     labels         - combined elementwise (bsxfun) with samples*w; presumably
%                      an L-by-1 vector of +/-1 labels -- TODO confirm with callers.
%     minibatch_size - forwarded to random_selection.
%     num_blocks     - number of feature blocks.
%     flag           - optional. If 'normalize' (case-insensitive), samples are
%                      divided by M = the largest matrix 2-norm among the blocks.
%                      The scale M is folded back into fun/grad/prox.
%
%   Output struct fields:
%     dir_op(w)          - -labels .* (samples*w), computed page-wise with mtimesx.
%     adj_op(v)          - samples' * (-labels .* v), computed page-wise.
%     dir_blk_op(w,idx)  - dir_op restricted to rows idx.
%     adj_blk_op(v,idx)  - adj_op restricted to rows idx.
%     select(it,args)    - minibatch selector (random_selection, 'pooling' mode).
%     fun, grad, prox    - loss value, gradient and proximity operator.
%     Q                  - page-wise samples'*samples (B-by-B-by-num_blocks).
%
%   Depends on mtimesx (page-wise matrix multiply) and on the helpers
%   random_selection, fun_/grad_/prox_logit and fun_/grad_/prox_hinge.

% problem size
L = size(samples,1);       % num. of samples
N = size(samples,2);       % num. of features
if mod(N, num_blocks) ~= 0
    error('Number of features (%d) is not a multiple of num_blocks (%d)', N, num_blocks);
end
B = N / num_blocks;        % size of blocks

% block-diagonalization: columns (b-1)*B+1 .. b*B become page b
samples = reshape(samples, [L,B,num_blocks]);

% normalization (flag is the 6th input, so it only exists when nargin >= 6)
M = 1;
if nargin >= 6 && strcmpi(flag, 'normalize')
    blk_norms = zeros(1,num_blocks);
    for b=1:num_blocks
        blk_norms(b) = norm(samples(:,:,b));   % matrix 2-norm of block b
    end
    M = max(blk_norms);
    if M > 0
        samples = samples / M;
    else
        % all-zero samples: nothing to rescale, avoid division by zero
        M = 1;
    end
end

% linear operator
f.dir_op = @(w) bsxfun( @times, -labels, mtimesx(samples,w) );     % -labels .* (samples  * w);
f.adj_op = @(v) mtimesx( samples, 'T', bsxfun(@times,-labels,v) ); % samples' * (-labels .* v);

% block linear operator (restricted to the rows listed in idx)
f.dir_blk_op = @(w,idx) bsxfun( @times, -labels(idx,:), mtimesx(samples(idx,:,:),w) );     % -labels(idx,:) .* (samples(idx,:)  * w);
f.adj_blk_op = @(v,idx) mtimesx( samples(idx,:,:), 'T', bsxfun(@times,-labels(idx,:),v) ); % samples(idx,:)' * (-labels(idx,:) .* v);

% minibatch selector
f.select = @(it,args) random_selection(it, args, L, minibatch_size, 'pooling');

% loss, gradient and proximity operator.
% Samples were scaled by 1/M, so inputs are multiplied back by M here.
% NOTE(review): fun sums v over dim 3 before evaluating, but grad is applied
% to v as-is -- confirm that callers pass an already-summed v to grad.
if strcmpi(loss, 'logit')
    f.fun  = @(v)         fun_logit(M*sum(v,3), 1);
    f.grad = @(v)        grad_logit(M*v, M);
    f.prox = @(v, gamma) vector_prox(v, gamma, @(x,gam) prox_logit(M*x, gam*M^2)/M );
elseif strcmpi(loss, 'hinge')
    f.fun  = @(v)         fun_hinge(1 + M*sum(v,3), 1);
    f.grad = @(v)        grad_hinge(1 + M*v, M);
    f.prox = @(v, gamma) vector_prox(v, gamma, @(x,gam) prox_hinge(1 + M*x, gam*M^2)/M - 1/M );
else
    error('Loss not supported');
end

% symmetric matrix (required by primal algorithms)
f.Q = mtimesx(samples, 'T', samples); % samples' * samples;




function p = vector_prox(x, gamma, scalar_prox)
%VECTOR_PROX Apply a prox defined on the dim-3 sum of x, and spread the correction evenly.
%
%   Let nb = size(x,3) and xsum = sum(x,3). The scalar_prox handle is evaluated
%   once, as scalar_prox(xsum, nb*gamma). The result is divided by nb, and the
%   difference from xsum/nb is added (bsxfun broadcast) to every page of x.
%
%   Inputs:
%     x           - array whose 3rd dimension indexes blocks.
%     gamma       - prox step size; it is scaled by nb before being passed on.
%     scalar_prox - handle @(z, step) returning an array the same size as z.
%
%   Output:
%     p           - same size as x.

nb    = size(x, 3);
xsum  = sum(x, 3);
shift = scalar_prox(xsum, nb*gamma) / nb - xsum / nb;   % common per-page correction
p     = bsxfun(@plus, x, shift);