function collectedPopStructuresNew = reorderColorsMinEntropy(...
    collectedPopStructures, collectedPartitions, collectedGeneLengths, ...
    collectedSnpPositions)
    % Relabel the clusters of every locus so that colours agree across
    % loci, by greedily minimising the mean per-strain colour entropy.
    %
    % Inputs
    %   "collectedPopStructures" a cell array with population structures for
    %       different genes (one row per strain, one column per SNP)
    %   "collectedPartitions" a matrix with partitions as columns (one
    %       column per locus)
    %   "collectedGeneLengths" a vector with lengths of the genes
    %   "collectedSnpPositions" a cell array with snpPositions for different
    %       genes
    %
    % Output
    %   "collectedPopStructuresNew" a 1 x nLoci cell array. Locus 1 keeps
    %       its labels; every other locus is relabelled and stored as uint8.
    %       Positions in the empty/unknown cluster get the colour
    %       max(collectedPartitions(:)) + 1 at every locus.
    %
    % Loci are processed in order. Within a locus, clusters are visited from
    % largest to smallest (number of members in the partition column) and
    % each takes the still-unused label in 1..maxKnownCluster that yields the
    % lowest mean entropy of the colour frequencies accumulated so far.

    % Update the color of the empty cluster at each locus, after which it
    % won't be changed anymore
    unknownClusterColor = max(collectedPartitions(:)) + 1;
    maxKnownCluster = unknownClusterColor - 1;

    % A locus has an empty/unknown cluster when its structure uses a colour
    % above the largest partition label at that locus. Take the maxima along
    % dim 1 explicitly (a single-row partition matrix would otherwise
    % collapse to one scalar) and force row shapes, because a FOR loop over
    % a column vector runs only once, with the whole column as its value.
    maxLabelPerLocus = max(collectedPartitions, [], 1);
    maxColorPerLocus = reshape(cellfun(@(x) max(x(:)), collectedPopStructures), 1, []);
    hasUnknownClusters = find(maxColorPerLocus > maxLabelPerLocus);

    for geneIndex = hasUnknownClusters
        popStructureNow = collectedPopStructures{geneIndex};
        % The largest colour at such a locus is treated as the unknown
        % cluster and moved to the shared unknownClusterColor.
        popStructureNow(popStructureNow == max(popStructureNow(:))) = unknownClusterColor;
        collectedPopStructures{geneIndex} = popStructureNow;
    end

    nLoci = numel(collectedPopStructures);
    collectedPopStructuresNew = cell(1, nLoci);

    % The first locus defines the reference colouring.
    collectedPopStructuresNew{1} = collectedPopStructures{1};

    % Running per-strain colour frequencies over the loci processed so far.
    cumulativeFrequencies = computeColorFrequenciesAtLocus(...
        collectedPopStructures{1}, maxKnownCluster, ...
        collectedSnpPositions{1}, collectedGeneLengths(1));

    for locusIndex = 2:nLoci
        if rem(locusIndex,10)==0
            disp([num2str(locusIndex) '/' num2str(nLoci) ' loci reordered']);
        end

        frequenciesNow = computeColorFrequenciesAtLocus(...
            collectedPopStructures{locusIndex}, maxKnownCluster, ...
            collectedSnpPositions{locusIndex}, ...
            collectedGeneLengths(locusIndex));

        % Try to identify the one column permutation of frequenciesNow
        % that, when added to the cumulativeFrequencies, results in the
        % smallest average entropy. Start from the most common cluster at
        % the locus, identify its color, then the second-most common, etc.

        % Force a column so the concatenation below also works when the
        % partition column has a single entry.
        origClusterSizes = reshape(histc(collectedPartitions(:,locusIndex), 1:maxKnownCluster), [], 1);
        nClustersInLocus = find(origClusterSizes > 0, 1, 'last');

        % newClusterLabels(k) is the label (1..maxKnownCluster) assigned to
        % cluster k of this locus; 0 means "not assigned yet".
        newClusterLabels = zeros(1, nClustersInLocus);

        aux = [(1:nClustersInLocus)', origClusterSizes(1:nClustersInLocus)];
        aux = sortrows(aux, -2); % Minus -> descending
        orderOfProcessingClusters = aux(:,1)';

        for clusterIndex = orderOfProcessingClusters

            possibleLabels = setdiff(1:maxKnownCluster, newClusterLabels);
            % Mean entropy that would result from giving the cluster each
            % of the still-free labels.
            putativeMeanEntropies = zeros(1, numel(possibleLabels));

            for possibleLabelIndex = 1:numel(possibleLabels)
                possibleLabel = possibleLabels(possibleLabelIndex);

                % Tentatively add the cluster under this label, evaluate,
                % then restore the column exactly.
                savedColumn = cumulativeFrequencies(:,possibleLabel);
                cumulativeFrequencies(:,possibleLabel) = savedColumn + ...
                    frequenciesNow(:,clusterIndex);
                putativeMeanEntropies(possibleLabelIndex) = ...
                    mean(computeEntropies(cumulativeFrequencies));
                cumulativeFrequencies(:,possibleLabel) = savedColumn;
            end

            [~, minIndex] = min(putativeMeanEntropies);
            bestLabel = possibleLabels(minIndex);

            % New label was found: commit it and add the observations of
            % the cluster to the cumulative cluster frequencies.
            newClusterLabels(clusterIndex) = bestLabel;
            cumulativeFrequencies(:,bestLabel) = ...
                cumulativeFrequencies(:,bestLabel) + ...
                frequenciesNow(:,clusterIndex);
        end

        % Apply the relabelling. NOTE(review): uint8 saturates at 255, so
        % this assumes unknownClusterColor <= 255.
        popStructureNow = collectedPopStructures{locusIndex};
        relabelled = zeros(size(popStructureNow), 'uint8');
        for clusterIndex = 1:numel(newClusterLabels)
            relabelled(popStructureNow == clusterIndex) = newClusterLabels(clusterIndex);
        end
        relabelled(popStructureNow == unknownClusterColor) = unknownClusterColor;
        collectedPopStructuresNew{locusIndex} = relabelled;
    end

end



function colorFrequencies = computeColorFrequenciesAtLocus(popStructure, maxColor, snpPositions, totalSequenceLength)
    % Build an nStrains x maxColor table whose entry (s, c) is the total
    % sequence length of the segments of strain s that carry colour c.
    %
    % Segments come from identifySegments: column 1 is the start, column 2
    % the (inclusive) end and column 3 the colour. Segments whose colour is
    % outside 1..maxColor -- in particular the empty cluster, coloured
    % maxColor+1 -- are not recorded.

    nStrains = size(popStructure, 1);
    colorFrequencies = zeros(nStrains, maxColor);

    for s = 1:nStrains
        seg = identifySegments(popStructure, s, snpPositions, totalSequenceLength);
        segColors = seg(:, 3);
        segLengths = seg(:, 2) - seg(:, 1) + 1;

        % Only segments with a known colour contribute.
        kept = find(segColors > 0 & segColors <= maxColor);
        for k = reshape(kept, 1, [])
            c = segColors(k);
            colorFrequencies(s, c) = colorFrequencies(s, c) + segLengths(k);
        end
    end
end



function entropies = computeEntropies(colorFrequencies)
    % Per-row Shannon entropy (natural log) of smoothed colour frequencies.
    %
    % A pseudo-count of 1/nColors is added to every cell, so each row gains
    % a total prior mass of 1. Each row is then normalised to a probability
    % distribution. The result is a column vector with one entropy per row.

    nColors = size(colorFrequencies, 2);
    smoothed = colorFrequencies + 1 / nColors;
    probabilities = bsxfun(@rdivide, smoothed, sum(smoothed, 2));
    entropies = -sum(probabilities .* log(probabilities), 2);

end



function colorArray = arrangeColorsMinEntropy(colorArray, nColors, partition, priorClusterWeight)
    % Locally permute colours at each locus so that the summed per-row
    % entropy of the colour distributions decreases.
    %
    % colorArray         - matrix of colour indices, one row per individual
    %                      and one column per locus; returned with colours
    %                      swapped in place.
    % nColors            - number of colours. Color with index "nColors" is
    %                      the black color. It will not be changed.
    % partition          - cluster label of each individual. Individual k
    %                      receives an extra priorClusterWeight on the
    %                      colour equal to its (positive) label.
    % priorClusterWeight - weight of that cluster prior.
    %
    % For each locus, the pair of non-black colours whose swap gives the
    % largest entropy decrease is swapped repeatedly, until no swap helps.

    nInds = size(colorArray, 1);
    nLoci = size(colorArray, 2);

    % Row normaliser: nLoci counts + total uniform prior of 1 + cluster prior.
    normalizer = nLoci + 1 + priorClusterWeight;
    % Probability mass represented by one locus after normalisation.
    step = 1 / normalizer;

    % Count how many loci carry each colour, per individual.
    colorDist = zeros(nInds, nColors);
    for j = 1:nColors
        colorDist(:, j) = sum(colorArray == j, 2);
    end

    colorDist = colorDist + 1/nColors;  % Prior

    % Cluster prior. Iterate over the positive labels actually present
    % instead of 1..numel(unique(partition)), so that non-contiguous
    % labels still give every individual its cluster weight.
    clusterLabels = unique(partition(partition >= 1));
    for col = reshape(clusterLabels, 1, [])
        strains = find(partition == col);
        colorDist(strains, col) = colorDist(strains, col) + priorClusterWeight;
    end

    colorDist = colorDist ./ normalizer;

    entropies = calcEntropies(colorDist);

    for i = 1:nLoci
        % Go through the loci. Permute two colors if this leads to decreased
        % entropy; repeat until no permutation helps at this locus.
        maxDecrease = -1;
        while maxDecrease < 0
            maxDecrease = 0;
            for color1 = 1:nColors-2

                strains1 = find(colorArray(:,i) == color1);
                entropy1Current = sum(entropies(strains1));

                for color2 = color1+1:nColors-1  % Don't ever change the black color!

                    strains2 = find(colorArray(:,i) == color2);
                    entropy2Current = sum(entropies(strains2));

                    % Distributions of the affected rows after the swap:
                    % strains1 move one locus of mass from color1 to color2,
                    % strains2 move one from color2 to color1.
                    colorDist1 = colorDist(strains1,:);
                    colorDist2 = colorDist(strains2,:);
                    colorDist1(:,color1) = colorDist1(:,color1) - step;
                    colorDist1(:,color2) = colorDist1(:,color2) + step;
                    colorDist2(:,color1) = colorDist2(:,color1) + step;
                    colorDist2(:,color2) = colorDist2(:,color2) - step;

                    % Calculate new entropies for the new rows
                    entropies1 = calcEntropies(colorDist1);
                    entropies2 = calcEntropies(colorDist2);

                    % Negative value means the swap lowers the entropy sum.
                    entropyDecrease = sum(entropies1) + sum(entropies2) - ...
                        (entropy1Current + entropy2Current);

                    if entropyDecrease < maxDecrease
                        maxDecrease = entropyDecrease;
                        proposedSwitch = [color1 color2];
                        proposedColorDist1 = colorDist1;
                        proposedColorDist2 = colorDist2;
                        proposedEntropies1 = entropies1;
                        proposedEntropies2 = entropies2;
                    end
                end
            end
            if maxDecrease < 0
                % Switch the proposed colors and commit the cached state.
                strains1 = find(colorArray(:,i) == proposedSwitch(1));
                strains2 = find(colorArray(:,i) == proposedSwitch(2));
                colorArray(strains1,i) = proposedSwitch(2);
                colorArray(strains2,i) = proposedSwitch(1);

                colorDist(strains1,:) = proposedColorDist1;
                colorDist(strains2,:) = proposedColorDist2;

                entropies(strains1) = proposedEntropies1;
                entropies(strains2) = proposedEntropies2;
            end
        end
    end
end


function entropies = calcEntropies(distMatrix)
    % The function calculates entropies (natural log) for distributions
    % given on rows, returning one value per row.
    %
    % Zero probabilities use the convention 0*log(0) = 0; without it a
    % single empty category would turn the whole row's entropy into NaN.

    terms = distMatrix .* log(distMatrix);
    terms(distMatrix == 0) = 0;
    entropies = -sum(terms, 2);
end