function c = updateCountsWithStrain(c, strainIndex, newStates, oldStates, changeOnlyTheseSites)
% UPDATECOUNTSWITHSTRAIN Move one strain's counts between clusters.
%   c = updateCountsWithStrain(c, strainIndex, newStates, oldStates, changeOnlyTheseSites)
%   updates c.COUNTS and c.SUMCOUNTS after the cluster assignment of the
%   strain with index strainIndex has changed. At every site that has
%   changed (oldStates ~= newStates) and is selected by
%   changeOnlyTheseSites, the strain's contribution is subtracted from its
%   old cluster and added to its new cluster.
%
%   Inputs:
%     c                    - struct with fields COUNTS (indexed here as
%                            COUNTS(row, site, cluster)), SUMCOUNTS (indexed
%                            here as SUMCOUNTS(cluster, site)) and data
%                            (forwarded to computeDiffInCounts).
%     strainIndex          - index of the strain being moved; forwarded to
%                            computeDiffInCounts.
%     newStates, oldStates - per-site cluster indices after / before the
%                            update (row or column vectors). If either is
%                            empty, every site is treated as changed. An
%                            empty oldStates skips the removal step; an
%                            empty newStates skips the addition step.
%     changeOnlyTheseSites - logical mask over sites restricting which sites
%                            are updated. Optional; defaults to all sites.
%
%   Output:
%     c - the input struct with COUNTS and SUMCOUNTS updated.

nSites = size(c.COUNTS, 2);

% The strain's per-site contribution, presumably a
% size(c.COUNTS,1)-by-nSites matrix -- TODO confirm against computeDiffInCounts.
diffInCounts = computeDiffInCounts(strainIndex, size(c.COUNTS,1), nSites, c.data);
% Sum explicitly along dimension 1. A plain sum() of a single-row matrix
% would collapse it to a scalar and break the per-site indexing below.
diffInSumCounts = sum(diffInCounts, 1);

if nargin < 5
    changeOnlyTheseSites = true(1, nSites);
end

% Use row vectors throughout. This gives the element-wise masks below
% matching shapes, and makes the for-loops iterate over individual
% cluster indices instead of over whole columns.
oldStates = oldStates(:)';
newStates = newStates(:)';

if isempty(oldStates) || isempty(newStates)
    % Every site has changed.
    siteChanged = true(1, nSites);
else
    siteChanged = (oldStates ~= newStates);
end

% Restrict the update to the requested sites.
siteChanged = siteChanged & changeOnlyTheseSites(:)';

% Remove the strain's counts at each changed site from its old cluster.
if ~isempty(oldStates)
    losingClusters = unique(oldStates(siteChanged));
    for losingIndex = losingClusters(:)'
        losingSites = (oldStates == losingIndex) & siteChanged;
        c.COUNTS(:, losingSites, losingIndex) = ...
            c.COUNTS(:, losingSites, losingIndex) ...
            - diffInCounts(:, losingSites);

        c.SUMCOUNTS(losingIndex, losingSites) = ...
            c.SUMCOUNTS(losingIndex, losingSites) ...
            - diffInSumCounts(losingSites);
    end
end

% Add the strain's counts at each changed site to its new cluster.
if ~isempty(newStates)
    gainingClusters = unique(newStates(siteChanged));
    for gainingIndex = gainingClusters(:)'
        gainingSites = (newStates == gainingIndex) & siteChanged;
        c.COUNTS(:, gainingSites, gainingIndex) = ...
            c.COUNTS(:, gainingSites, gainingIndex) ...
            + diffInCounts(:, gainingSites);

        c.SUMCOUNTS(gainingIndex, gainingSites) = ...
            c.SUMCOUNTS(gainingIndex, gainingSites) ...
            + diffInSumCounts(gainingSites);
    end
end
