 function p = prox_softmax(x, gamma, dir)
%function p = prox_softmax(x, gamma, dir)
%
% This procedure computes the proximity operator of the function:
%
%                   f(x) = gamma * log(sum(exp(x)))
%
% When the input 'x' is an array, the computation can vary as follows:
%  - dir = 0 --> 'x' is processed as a single vector [DEFAULT]
%  - dir > 0 --> 'x' is processed block-wise along the specified direction
%
%  INPUTS
% ========
%  x     - ND array
%  gamma - positive, scalar or ND array compatible with the blocks of 'x'
%          (same size as 'x' with dimension 'dir' reduced to 1; trailing
%          singleton dimensions are ignored). A non-scalar 'gamma' is
%          only accepted when dir > 0.
%  dir   - integer, direction of block-wise processing
%          (omitted, [] or 0 --> 'x' is processed as a single vector)
%
%  OUTPUT
% ========
%  p     - array of the same size as 'x'
%
%  DEPENDENCIES
% ==============
%  Lambert_W.m - located in the folder 'utils'

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Version : 1.0 (27-04-2017)
% Authors : Giovanni Chierchia, Emilie Chouzenoux
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Copyright (C) 2017
%
% This file is part of the codes provided at http://proximity-operator.net
%
% By downloading and/or using any of these files, you implicitly agree to 
% all the terms of the license CeCill-B (available online).
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%


% default inputs
if nargin < 3 || (~isempty(dir) && dir == 0)
    dir = [];
end

% check input
% 'sz_blk' is the size a non-scalar 'gamma' must have: the size of 'x'
% with the processed dimension collapsed to 1.
sz_blk = size(x); sz_blk(dir) = 1;
sz_gam = size(gamma);
% size() drops trailing singleton dimensions, so the two size vectors may
% have different lengths even when the arrays are compatible (e.g. 'x' is
% a-by-b-by-c, dir = 3 and 'gamma' is a-by-b). Pad both with ones before
% comparing, instead of comparing vectors of unequal length.
nd = max(numel(sz_blk), numel(sz_gam));
sz_blk = [sz_blk, ones(1, nd - numel(sz_blk))];
sz_gam = [sz_gam, ones(1, nd - numel(sz_gam))];
if any( gamma(:) <= 0 ) || (~isscalar(gamma) && (isempty(dir) || ~isequal(sz_gam, sz_blk)))
    error('''gamma'' must be positive and either scalar or compatible with the blocks of ''x''')
end
%------%


% linearize
sz = size(x);
if isempty(dir)
    % single-vector mode: treat 'x' as one column and process along dim 1
    x   = x(:);
    dir = 1;
end

% compute the prox
p = x - find_softmax_root(x, gamma, dir);

% revert back
p = reshape(p, sz);



%--------------------------------------------------------------------------
function p = find_softmax_root(x, gamma, dir)
%--------------------------------------------------------------------------
% Newton iteration on the multiplier 'lambda' (one value per block of 'x'
% along dimension 'dir') that solves
%
%       sum(p, dir) - gamma = 0,   p = scaled_prox_entropy(x, gamma, lambda)
%
%  x      - array, processed block-wise along 'dir'
%  gamma  - positive, scalar or array with the size of 'x' reduced along 'dir'
%  dir    - dimension along which the blocks are summed
%
% Returns 'p' as evaluated at the last iterate of 'lambda' that was passed
% to scaled_prox_entropy (i.e. before the final Newton update).
% Issues a warning if the stopping rule is not met within MAX iterations.

TOL = 1e-7;   % relative tolerance on the change of lambda
MAX = 100;    % maximum number of Newton iterations

% init: one multiplier per block
sz = size(x); 
sz(dir) = 1;
lambda = ones(sz);

% Newton
converged = false;
for n = 1:MAX

    % evaluation step
    p   = scaled_prox_entropy(x, gamma, lambda);
    fun = sum(p,dir) - gamma;
    % derivative of 'fun' w.r.t. lambda; it must be reduced along the same
    % dimension as 'fun', otherwise the Newton step mixes different blocks
    der = gamma .* sum(1./(p+1) - 1, dir);
    
    % newton step
    lambda_old = lambda;
    lambda = lambda - fun ./ der;
    
    % stopping rule
    err = abs(lambda - lambda_old) ./ abs(lambda_old);    
    if all(err(:) <= TOL) || all(der(:) == 0)
        converged = true;
        break;
    end
    
end

% warn only when the stopping rule was never satisfied (a loop that stops
% exactly at iteration MAX has still converged)
if ~converged
    warning('Reached the max number of iterations');
end


%--------------------------------------------------------------------------
function p = scaled_prox_entropy(x, gamma, lambda)
%--------------------------------------------------------------------------
% Evaluates, element-wise,
%
%       p = W( gamma .* exp(x - 1 - lambda.*gamma) )
%
% where W is computed by Lambert_W. 'lambda' and 'gamma' are expanded
% against 'x' with bsxfun. Where the argument of W is too large to be
% represented safely, the approximation W(z) ~ log(z) - log(log(z)) is used.

% preliminaries
t = bsxfun(@minus, x-1, lambda.*gamma);
y = exp(t);

% entries for which the exponential is safely below the overflow limit
mask = y < realmax / 1e4;

% argument of W and its logarithm
y = bsxfun(@times, y, gamma);
t = bsxfun(@plus, log(gamma), t);

% the scaling by gamma may still overflow: treat those entries as "large"
mask = mask & isfinite(y);

% approximation of W(gamma*exp(t)), evaluated only on the large entries.
% Evaluating log(t) on all entries would return complex values wherever
% t < 0 and make the whole output complex-typed.
p = t;
big = ~mask;
p(big) = t(big) - log(t(big));

% exact computation of W(gamma*exp(t))
p(mask) = Lambert_W(y(mask));