function [marginals, logPosterior, states] = ...
    hmmRecombinationAnalysis(c, preCalculatedTransMatrices, strainIndex, clusterIndices)
    % HMMRECOMBINATIONANALYSIS Reads the given input data and carries out a
    % scaled alpha-beta (forward-backward) recursion over the SNPs.
    %
    % Outputs:
    %   "marginals"    - (nStates x nSnps) matrix; column i is the elementwise
    %                    product of the alpha-hat and beta-hat vectors at SNP
    %                    i, i.e. the marginal probabilities of the hidden
    %                    states.
    %   "logPosterior" - sum of the logs of the scaling coefficients produced
    %                    by the alpha recursion, i.e. the log-likelihood given
    %                    the single fixed parameter value of
    %                    "priorProbOfNoBreaks".
    %   "states"       - (1 x nSnps) realization of states simulated from the
    %                    posterior distribution of the HMM.
    %
    % Inputs:
    %   "c" - struct in "reduced BAPS output" format (reduced means that only
    %         sequence positions with variability are retained), with fields
    %         "PARTITION", "npops", "data", "noalle", "SUMCOUNTS", "COUNTS",
    %         "adjprior", "priorTerm", "snpPositions".
    %   "preCalculatedTransMatrices" - output of the
    %         preCalculateTransMatrices function.
    %   "strainIndex" - index (row of c.data) of the strain to be analysed.
    %   "clusterIndices" - used only when "strainIndex" is empty: two cluster
    %         indices. The analysis then investigates whether the two
    %         clusters could be combined (nStates = 2).
    %
    % Errors:
    %   hmmRecombinationAnalysis:missingInput   - neither a strain nor two
    %                                             clusters were given.
    %   hmmRecombinationAnalysis:negativeCounts - c.COUNTS has negative
    %                                             entries.

    if nargin < 4
        clusterIndices = [];
    end
    if isempty(strainIndex) && numel(clusterIndices) < 2
        error('hmmRecombinationAnalysis:missingInput', ...
            'Either strainIndex or two clusterIndices must be given.');
    end

    snpData = c.data;
    nSnps = size(snpData, 2);

    % Negative allele counts mean the input is corrupted; fail loudly instead
    % of dropping into the debugger (which would hang non-interactive runs).
    if any(c.COUNTS(:) < 0)
        error('hmmRecombinationAnalysis:negativeCounts', ...
            'c.COUNTS contains negative entries.');
    end

    if ~isempty(strainIndex)
        % NOTE: the observations of the strain itself are NOT removed from
        % the cluster counts (that leave-one-out adjustment is disabled), so
        % the raw counts are used directly.
        counts = c.COUNTS;
        sumCounts = c.SUMCOUNTS;

        priorCounts = c.adjprior;    % Perks prior
        priorCounts(priorCounts == 1) = 0;
        priorSumCounts = ones(1, length(c.noalle));

        % Add cluster-dependent over-dispersion
        nClusters = size(c.COUNTS, 3);
        nAlleleRows = size(c.COUNTS, 1);
        priorCounts = repmat(priorCounts, [1 1 nClusters]);
        priorSumCounts = repmat(priorSumCounts, [nClusters, 1]);

        homeCluster = c.PARTITION(strainIndex);
        homeClusterSizeVector = sumCounts(homeCluster, :);

        for clusterIndex = 1:(nClusters-1) % The last cluster is the empty cluster
            if clusterIndex == homeCluster
                continue
            end
            clusterSizeVector = sumCounts(clusterIndex, :);
            % At loci where this cluster has more observations than the
            % (non-empty) home cluster, scale the prior counts by 1.1 times
            % the ratio of the two cluster sizes.
            lociForDispersion = find(homeClusterSizeVector > 0 & ...
                clusterSizeVector > homeClusterSizeVector);
            overDispersions = 1.1 .* clusterSizeVector(lociForDispersion) ./ ...
                homeClusterSizeVector(lociForDispersion);
            % Replicate the row of factors over all allele rows.
            priorCounts(:, lociForDispersion, clusterIndex) = ...
                priorCounts(:, lociForDispersion, clusterIndex) .* ...
                overDispersions(ones(1, nAlleleRows), :);
            priorSumCounts(clusterIndex, lociForDispersion) = ...
                priorSumCounts(clusterIndex, lociForDispersion) .* overDispersions;
        end

        % Calculate emission probabilities
        emissionProbs = calcEmissionProbs(counts, sumCounts, priorCounts, ...
            priorSumCounts, c.data(strainIndex, :));

        nStates = size(c.COUNTS, 3);

    else
        % In this case, compare the two clusters, given as input argument.

        counts1 = c.COUNTS(:, :, clusterIndices(1));
        counts2 = c.COUNTS(:, :, clusterIndices(2));
        sumCounts1 = c.SUMCOUNTS(clusterIndices(1), :);
        sumCounts2 = c.SUMCOUNTS(clusterIndices(2), :);

        alphaHyperParameter = 1;

        % Scale the Perks prior by the hyperparameter, leaving entries that
        % equal 1 untouched (they are temporarily masked with NaN).
        adjustedPriorCounts = c.adjprior; % Perks
        adjustedPriorCounts(adjustedPriorCounts == 1) = nan;
        adjustedPriorCounts = adjustedPriorCounts .* alphaHyperParameter;
        adjustedPriorCounts(isnan(adjustedPriorCounts)) = 1;
        priorSumCounts = ones(1, length(c.noalle)) .* alphaHyperParameter;

        emissionProbs = calcEmissionProbsForMergingTwoClusters(counts1, ...
            sumCounts1, counts2, sumCounts2, adjustedPriorCounts, priorSumCounts);

        nStates = 2;

    end

    nPositions = length(c.snpPositions);

    % ALPHA RECURSION
    completeAlphaHatList = zeros(nStates, nSnps);
    completeCoefList = zeros(1, nSnps);

    initTransitionMatrix = 1/nStates .* ones(nStates); % Uniform prior for the overall first state.
    initAlphaHat = [];

    [alphaHats, coefs] = hmmAlphaHatRecursion(c.snpPositions, ...
        preCalculatedTransMatrices, emissionProbs, initTransitionMatrix, ...
        initAlphaHat);

    % Add calculated alpha hats and coefficients to list
    completeAlphaHatList(:, 1:nPositions) = alphaHats;
    completeCoefList(1:nPositions) = coefs;

    % BETA RECURSION
    completeBetaHatList = zeros(nStates, nSnps);
    % Append a column of ones after the last SNP for the beta recursion.
    emissionProbs = [emissionProbs ones(nStates, 1)];

    snpIndicesInList = 1:nPositions;
    % The trailing 1 pairs with the extra column of ones appended above.
    coefs = [completeCoefList(snpIndicesInList) 1];

    initTransitionMatrix = [];  % Passed empty to the beta recursion.
    initBetaHat = [];

    betaHats = hmmBetaHatRecursion(coefs, c.snpPositions, ...
        preCalculatedTransMatrices, emissionProbs, initTransitionMatrix, ...
        initBetaHat);

    % Add calculated beta hats to list
    completeBetaHatList(:, snpIndicesInList) = betaHats;

    % Calculate marginal distributions
    marginals = completeAlphaHatList .* completeBetaHatList;
    logPosterior = sum(log(completeCoefList));

    states = simulateRealization(marginals, completeCoefList, ...
        completeBetaHatList, emissionProbs, c.snpPositions, ...
        preCalculatedTransMatrices);

end




%------------------------------------------------------------


function states = simulateRealization(marginals, completeCoefList, ...
    completeBetaHatList, emissionProbs, snpPositions,...
    preCalculatedTransMatrices)
    % SIMULATEREALIZATION Draws one sequence of hidden states from the HMM
    % posterior, SNP by SNP.
    %
    % The first state is drawn from column 1 of "marginals". Each later state
    % i is drawn with weights proportional to
    %   emissionProbs(:,i) .* transDist .* completeBetaHatList(:,i),
    % where transDist is the row states(i-1) of the transition matrix
    % preCalculatedTransMatrices{snpPositions(i) - snpPositions(i-1)}.
    %
    % "marginals"           - (nStates x nSnps); only its first column and
    %                         its number of columns are used.
    % "completeCoefList"    - currently unused; kept for interface
    %                         compatibility.
    % "completeBetaHatList" - (nStates x nSnps) beta-hat vectors.
    % "emissionProbs"       - (nStates x >= nSnps) emission probabilities.
    % "snpPositions"        - SNP positions; consecutive differences index
    %                         "preCalculatedTransMatrices".
    %
    % Returns "states", a (1 x nSnps) vector of state indices.

    nSnps = size(marginals, 2);
    states = zeros(1, nSnps);
    if nSnps == 0
        return
    end

    % RNG draw order (one scalar, then a column of nSnps) matches the
    % original implementation so seeded runs stay reproducible.
    states(1) = drawStateIndex(marginals(:, 1), rand);

    snpDistances = diff(snpPositions);

    randNumbers = rand(nSnps, 1);
    for i = 2:nSnps
        % Transition distribution from the (i-1)th state to the ith SNP.
        transDist = (preCalculatedTransMatrices{snpDistances(i-1)}(states(i-1), :))';

        % Unnormalized conditional distribution of the i:th state given the
        % (i-1)th state; drawStateIndex normalizes it.
        condDist = emissionProbs(:, i) .* transDist .* completeBetaHatList(:, i);

        states(i) = drawStateIndex(condDist, randNumbers(i));
    end
end

function idx = drawStateIndex(weights, u)
    % DRAWSTATEINDEX Inverse-CDF draw: for u in [0,1) returns index k with
    % probability weights(k) / sum(weights).
    % The threshold is scaled by the final cumulative sum (instead of
    % assuming the weights sum to exactly 1), so rounding can no longer
    % leave "find" with no match.
    cumWeights = cumsum(weights(:));
    total = cumWeights(end);
    if ~(total > 0) || ~isfinite(total)
        error('simulateRealization:invalidDistribution', ...
            'Cannot sample from weights with a non-positive or non-finite total.');
    end
    idx = find(cumWeights > u * total, 1, 'first');
    if isempty(idx)
        % u*total rounded up to total: fall back to the last index with
        % positive weight.
        idx = find(weights(:) > 0, 1, 'last');
    end
end



%-------------------------------------------------------

function clusterSizes = determineClusterSizes(partition)
    % DETERMINECLUSTERSIZES Counts how many elements carry each cluster label.
    %
    % "partition" is a vector of cluster labels, assumed to be positive
    % integers. Entry j of the returned (1 x max(partition)) row vector is
    % the number of elements labelled j. Labels that do not occur (emptied
    % clusters) get size 0, so entry j always refers to cluster j.
    % Previously the length was the number of distinct labels, which dropped
    % the highest-labelled clusters whenever a label was missing.
    % Non-positive labels are not counted. An empty partition yields
    % zeros(1,0).
    nClusters = max([0; partition(:)]);
    clusterSizes = zeros(1, nClusters);
    for j = 1:nClusters
        clusterSizes(j) = nnz(partition == j);
    end
end