function alpha = learnDirHyperPar(counts, noalle, lambda, range)
% LEARNDIRHYPERPAR  Estimate the concentration parameter alpha of a
% symmetric Dirichlet prior on allele frequencies.
%
% alpha is the sum of the Dirichlet parameters; a SNP with K alleles uses
%   [p1 ... pK] ~ Dir(alpha/K, ..., alpha/K),
% e.g. [p1 p2 p3] ~ Dir(alpha/3, alpha/3, alpha/3) for three alleles.
% Estimation is based on the Dirichlet-multinomial marginal likelihood of
% the observed allele counts.
%
% Inputs
%   counts - allele counts indexed as counts(allele, snp, cluster). Only
%            the first cluster, counts(:,:,1), is used; it is expected to
%            contain the counts of all sampled strains.
%   noalle - noalle(j) is the number of alleles of SNP j; only rows
%            1:noalle(j) of column j are used.
%   lambda - (optional) rate of the exponential prior on alpha,
%            p(alpha) proportional to exp(-lambda*alpha).
%   range  - (optional) vector of alpha values at which to evaluate the
%            posterior of alpha.
%
% Output
%   lambda omitted or empty: 1-by-nSnps row vector of per-SNP
%       maximum-likelihood estimates of alpha. SNPs with fewer than two
%       alleles get 0, since their likelihood does not depend on alpha.
%   lambda given, range omitted or empty: scalar global MAP estimate of a
%       single alpha shared by all SNPs.
%   lambda and range given: 1-by-numel(range) row vector with the
%       posterior of alpha at the points in range, normalised to sum to
%       one over those points.
%
% The optimisations are carried out over t = log(alpha) so that
% fminsearch can never propose alpha <= 0, where the Dirichlet parameters
% are invalid and gammaln does not accept negative real input.

if nargin < 3
    lambda = [];
end
if nargin < 4
    range = [];
end

counts = counts(:,:,1);   % Only the first cluster is considered.
nSnps = size(counts,2);

if isempty(lambda)
    % Per-SNP maximum-likelihood estimates.
    alpha = zeros(1,nSnps);
    for j = 1:nSnps
        K = noalle(j);
        if K < 2
            continue;   % Likelihood is constant in alpha; leave 0.
        end
        n = counts(1:K,j);
        N = sum(n);
        % Negative log Dirichlet-multinomial likelihood (dropping terms
        % that do not depend on alpha) with parameters alpha/K, expressed
        % in t = log(alpha).
        negLogLik = @(t) -gammaln(exp(t)) + gammaln(exp(t) + N) ...
            - sum(gammaln(exp(t)/K + n)) + K*gammaln(exp(t)/K);
        alpha(j) = exp(fminsearch(negLogLik, 0));   % Start at alpha = 1.
    end
else
    % adjPrior(i,j) = 1/noalle(j) for the alleles of SNP j, 0 elsewhere.
    adjPrior = zeros(size(counts));
    N = zeros(1,nSnps);
    for j = 1:nSnps
        adjPrior(1:noalle(j),j) = 1/noalle(j);
        % Total over the alleles in use only, so that it is consistent
        % with the per-allele terms of the likelihood below.
        N(j) = sum(counts(1:noalle(j),j));
    end
    inUse = adjPrior > 0;
    p = adjPrior(inUse);
    n = counts(inUse);
    % Log posterior of a shared alpha, up to an additive constant:
    % exponential prior plus the summed Dirichlet-multinomial
    % log-likelihoods of all SNPs.
    logPost = @(a) -lambda*a + nSnps*gammaln(a) - sum(gammaln(a + N)) ...
        + sum(gammaln(a*p + n)) - sum(gammaln(a*p));
    if isempty(range)
        % MAP estimate; start at alpha = 1.
        alpha = exp(fminsearch(@(t) -logPost(exp(t)), 0));
    else
        logPostArray = zeros(1,numel(range));
        for i = 1:numel(range)
            logPostArray(i) = logPost(range(i));
        end
        % Subtract the maximum before exponentiating to avoid overflow.
        posteriorArray = exp(logPostArray - max(logPostArray));
        alpha = posteriorArray ./ sum(posteriorArray);
    end
end