infinite.brr.gibbs <- function(n.iter=500, vars.to.update=c('Psi','Gamma','Psi.local.shrinkage','Gamma.local.shrinkage','star.deltas','a3a4','brr.rank'), context, prior, genotypes, phenotypes, crossprod.genotypes=NULL) {
	#
	# Runs the MCMC updates of the infinite reduced rank
	# regression model. The phenotype covariance matrix is
	# assumed to be diagonal, with the diagonal elements given
	# in prior$variances.
	#
	# Inputs:
	#   n.iter: number of MCMC iterations
	#   vars.to.update: names of the variables to update. Besides
	#       the defaults, 'Gamma.scale' selects a Metropolis-Hastings
	#       move that only rescales Gamma.
	#   context: list with the current values of the variables
	#       (Psi, Gamma, Psi.local.shrinkage, Gamma.local.shrinkage,
	#       star.deltas, a3a4, brr.rank)
	#   prior: list with the hyperparameters and the current
	#       values of the diagonal elements of the phenotype
	#       covariance matrix (prior$variances)
	#   genotypes: regressor matrix (individuals x SNPs); the
	#       original comment describes centered 0,1,2 coded
	#       genotypes, but nothing here relies on that coding
	#   phenotypes: response matrix (individuals x phenotypes)
	#   crossprod.genotypes: optional pre-computed
	#       crossprod(genotypes); computed here when NULL
	#
	# Outputs (a list with elements):
	#   updated.context: context with the updated variables
	#   traces: per-iteration values of the updated variables
	#       (plus the elements 'accepted' and 'rejected')
	#   cpu.times: elapsed time spent in each update, by variable
	#
	# Requires fit.bayes.lm.diag and mvr.norm.own (bayes_lm.R
	# according to the original note; they must be loaded by the
	# caller because the path is not known here).

	library(MASS)

	if (nrow(genotypes) != nrow(phenotypes)) {
		stop('genotypes and phenotypes must have the same number of rows')
	}

	if (is.null(crossprod.genotypes)) {
		crossprod.genotypes <- crossprod(genotypes)
	}

	# Unlist variables for easier access:
	Psi <- context$Psi
	Gamma <- context$Gamma
	Psi.local.shrinkage <- context$Psi.local.shrinkage
	Gamma.local.shrinkage <- context$Gamma.local.shrinkage
	star.deltas <- context$star.deltas
	star.taus <- cumprod(star.deltas)
	a3 <- context$a3a4[1]
	a4 <- context$a3a4[2]
	brr.rank <- context$brr.rank

	# Unlist hyperparameters and variances for easier access:
	local.shrinkage.nu <- prior$local.shrinkage.nu
	a3.shape <- prior$a3.shape
	a3.rate <- prior$a3.rate
	a3.lower.bound <- prior$a3.lower.bound
	a4.shape <- prior$a4.shape
	a4.rate <- prior$a4.rate
	a4.lower.bound <- prior$a4.lower.bound
	brr.factor.relevance.cutoff <- prior$brr.factor.relevance.cutoff
	alpha0 <- prior$alpha0
	alpha1 <- prior$alpha1
	input.clustering <- prior$input.clustering
	output.clustering <- prior$output.clustering
	n.clusters <- length(unique(output.clustering))
	if (is.null(prior$step.size)) {
		step.size <- 10
	} else {
		step.size <- prior$step.size
	}
	variances <- prior$variances

	# check that output clustering is of the form 1:n.clusters
	if (!all(sort(unique(output.clustering)) == 1:n.clusters)) {
		stop('bad output clustering')
	}

	n.snps <- ncol(genotypes)
	n.pheno <- ncol(phenotypes)

	# crossprod(genotypes, phenotypes) does not change between
	# iterations; it is only needed by the Psi update.
	if (any(vars.to.update=='Psi')) {
		crossprod.geno.pheno <- crossprod(genotypes, phenotypes)
	}

	traces <- list()
	for (name in names(context)) {
		traces[[name]] <- list()
	}
	cpu.times <- rep(0, length(names(context)))
	names(cpu.times) <- names(context)

	accepted <- rep(0, brr.rank)
	rejected <- rep(0, brr.rank)

	for (iter in seq_len(n.iter)) {

		if (any(vars.to.update=='Gamma.scale')) {
			# Metropolis-Hastings update of the scale of Gamma only.
			t1 <- proc.time()

			scaling.factor <- runif(n=1, min=0.95, max=1.05)
			Gamma.star <- Gamma*scaling.factor

			log.prior.Gamma.star <- comp.Gamma.log.prior(Gamma=Gamma.star, Gamma.local.shrinkage=Gamma.local.shrinkage, star.taus=star.taus)
			log.proposal.old.to.new <- 0
			log.likelihood.Gamma.star <- comp.log.likelihood.proportional.Psi.Gamma(Psi=Psi, Gamma=Gamma.star, variances=variances, genotypes=genotypes, phenotypes=phenotypes)

			log.prior.Gamma <- comp.Gamma.log.prior(Gamma=Gamma, Gamma.local.shrinkage=Gamma.local.shrinkage, star.taus=star.taus)
			log.proposal.new.to.old <- 0
			log.likelihood.Gamma <- comp.log.likelihood.proportional.Psi.Gamma(Psi=Psi, Gamma=Gamma, variances=variances, genotypes=genotypes, phenotypes=phenotypes)

			log.numerator <- log.likelihood.Gamma.star + log.prior.Gamma.star + log.proposal.new.to.old
			log.denominator <- log.likelihood.Gamma + log.prior.Gamma + log.proposal.old.to.new
			log.acceptance.ratio <- log.numerator - log.denominator

			acceptance.ratio <- min(1, exp(log.acceptance.ratio))
			if (runif(n=1) < acceptance.ratio) {
				Gamma <- Gamma.star
			}

			traces$Gamma[[iter]] <- Gamma
			t2 <- proc.time()
			cpu.times['Gamma'] <- cpu.times['Gamma'] + (t2[3] - t1[3])
		}


		if (any(vars.to.update=='Gamma')) {
			# Gibbs update of Gamma: one Bayesian linear regression per
			# phenotype with design matrix genotypes %*% Psi.
			t1 <- proc.time()

			Gamma.element.precisions <- Gamma.local.shrinkage * star.taus
			design.matrix <- genotypes %*% Psi
			crossprod.X <- crossprod(design.matrix)

			# Prior variances of the elements:
			Sigma.gamma.p <- 1/Gamma.element.precisions
			fit <- fit.bayes.lm.diag(design.matrix, phenotypes, variances, Sigma.gamma.p, crossprod.X=crossprod.X)

			for (p in seq_len(n.pheno)) {
				# Update the pth column of Gamma
				Gamma[,p] <- mvr.norm.own(mu=fit$posterior.mean[[p]], Sigma=fit$posterior.cov[[p]])
			}

			traces$Gamma[[iter]] <- Gamma
			t2 <- proc.time()
			cpu.times['Gamma'] <- cpu.times['Gamma'] + (t2[3] - t1[3])
		}


		if (any(vars.to.update=='Psi')) {
			# Gibbs update of vec(Psi) (column-major) from its joint
			# Gaussian full conditional.
			t1 <- proc.time()

			Psi.variances <- 1/as.vector(t(t(Psi.local.shrinkage) * star.taus))

			# diag() of a length-one vector would not give a 1x1
			# matrix, hence the special case.
			if (length(Psi.variances)==1) {
				Psi.aux <- 1/Psi.variances
			} else {
				Psi.aux <- diag(1/Psi.variances)
			}
			posterior.cov <- chol2inv(chol( Psi.aux + (Gamma %*% (1/variances*t(Gamma))) %x% crossprod.genotypes ))

			aa <- crossprod.geno.pheno %*% (1/variances*t(Gamma))
			posterior.mean <- posterior.cov %*% as.vector(aa)

			Psi.vec <- mvr.norm.own(mu=posterior.mean, Sigma=posterior.cov)
			Psi <- matrix(Psi.vec, nrow=n.snps, ncol=brr.rank)

			traces$Psi[[iter]] <- Psi
			t2 <- proc.time()
			cpu.times['Psi'] <- cpu.times['Psi'] + (t2[3] - t1[3])
		}


		if (any(vars.to.update=='Psi.local.shrinkage')) {
			# Psi.local.shrinkage[j,h] is the same as \phi_{jh}^{\Psi} in the article.
			t1 <- proc.time()

			# Number of elements in different clusters
			num <- table(input.clustering)
			num <- matrix(rep(num, brr.rank), ncol=brr.rank)
			shape.pars <- (local.shrinkage.nu + num) / 2

			# Sums of squares of elements of Psi in different clusters
			ss <- by(Psi, input.clustering, function(x){colSums(x^2)})
			ss <- matrix(unlist(ss), ncol=brr.rank, byrow=TRUE)

			rate.pars <- (local.shrinkage.nu + t(star.taus * t(ss))) / 2

			Psi.local.shrinkage.elements <- matrix(rgamma(n=prod(dim(rate.pars)), shape=shape.pars, rate=rate.pars), nrow=nrow(rate.pars), ncol=ncol(rate.pars))
			# One sampled value per cluster, expanded back to one row per SNP.
			Psi.local.shrinkage <- Psi.local.shrinkage.elements[input.clustering,, drop=FALSE]

			traces$Psi.local.shrinkage[[iter]] <- Psi.local.shrinkage
			t2 <- proc.time()
			cpu.times['Psi.local.shrinkage'] <- cpu.times['Psi.local.shrinkage'] + (t2[3] - t1[3])
		}


		if (any(vars.to.update=='Gamma.local.shrinkage')) {
			# Gamma.local.shrinkage[j,h] is the same as \phi_{jh}^{\Gamma} in the article.
			t1 <- proc.time()

			# Number of elements in different clusters
			num <- table(output.clustering)
			num <- matrix(rep(num, brr.rank), nrow=brr.rank, byrow=TRUE)
			shape.pars <- (local.shrinkage.nu + num) / 2

			# Sums of squares of elements of Gamma in different clusters
			ss <- by(t(Gamma), output.clustering, function(x){colSums(x^2)})
			ss <- matrix(unlist(ss), ncol=brr.rank, byrow=TRUE)
			ss <- t(ss)

			rate.pars <- (local.shrinkage.nu + star.taus * ss) / 2

			Gamma.local.shrinkage.elements <- matrix(rgamma(n=prod(dim(rate.pars)), shape=shape.pars, rate=rate.pars), nrow=nrow(rate.pars), ncol=ncol(rate.pars))
			# One sampled value per cluster, expanded back to one column per phenotype.
			Gamma.local.shrinkage <- Gamma.local.shrinkage.elements[,output.clustering, drop=FALSE]

			traces$Gamma.local.shrinkage[[iter]] <- Gamma.local.shrinkage
			t2 <- proc.time()
			cpu.times['Gamma.local.shrinkage'] <- cpu.times['Gamma.local.shrinkage'] + (t2[3] - t1[3])
		}


		## Sample star.deltas (and star.taus)
		if (any(vars.to.update=='star.deltas')) {
			# Sample star.deltas[1]
			t1 <- proc.time()

			shape <- a3 + (n.pheno + n.snps) * brr.rank / 2

			phi.psi.vector <- colSums(Psi^2 * Psi.local.shrinkage)
			phi.gamma.vector <- rowSums(Gamma^2 * Gamma.local.shrinkage)

			# Setting the element to 1 gives the product of the
			# remaining deltas in tau.vector.
			star.deltas[1] <- 1
			tau.vector <- cumprod(star.deltas)
			if (!is.na(a3.rate)) {
				# If a3 is not fixed, then use prior distribution
				# Gamma(a3,1) for star.deltas[1]
				rate <- 1 + 1/2 * tau.vector %*% (phi.psi.vector + phi.gamma.vector)
			} else {
				# If a3 is fixed, then use prior distribution
				# Gamma(shape=a3, rate=1/a3) for star.deltas[1]
				rate <- 1/a3 + 1/2 * tau.vector %*% (phi.psi.vector + phi.gamma.vector)
			}
			star.deltas[1] <- rgamma(n=1, shape=shape, rate=rate)

			# Sample rest of the star.deltas
			if (brr.rank>1) {
				for (h in 2:brr.rank) {
					shape <- a4 + (n.pheno+n.snps) * (brr.rank-h+1) / 2
					star.deltas[h] <- 1
					tau.vector <- cumprod(star.deltas)[-seq(1,h-1,by=1)]
					rate <- 1 + 1/2 * tau.vector %*% (phi.psi.vector[-seq(1,h-1,by=1)] + phi.gamma.vector[-seq(1,h-1,by=1)])
					star.deltas[h] <- rgamma(n=1, shape=shape, rate=rate)
				}
			}

			star.taus <- cumprod(star.deltas)
			traces$star.deltas[[iter]] <- star.deltas
			t2 <- proc.time()
			cpu.times['star.deltas'] <- cpu.times['star.deltas'] + (t2[3] - t1[3])
		}


		## Update a3 and a4
		if (any(vars.to.update=='a3a4')) {
			# Joint Metropolis-Hastings move for (a3,a4) on the log
			# scale; skipped when a3.rate is NA (a3 and a4 fixed).
			t1 <- proc.time()

			if (!is.na(a3.rate)) {

				# Move from (a3,a4) to (a3.star, a4.star) is proposed
				a3.proposal.std <- log(a3)/step.size
				a4.proposal.std <- log(a4)/step.size
				log.a3.star <- rnorm(n=1, mean=log(a3), sd=a3.proposal.std)
				log.a4.star <- rnorm(n=1, mean=log(a4), sd=a4.proposal.std)
				a3.star <- exp(log.a3.star)
				a4.star <- exp(log.a4.star)

				# Proposal standard deviations of the reverse move
				a3.star.proposal.std <- log(a3.star)/step.size
				a4.star.proposal.std <- log(a4.star)/step.size

				if ((a3.star>a3.lower.bound) && (a4.star>a4.lower.bound)) {
					# Otherwise the move is automatically rejected.
					log.proposal.prob <- dnorm(log(a3.star), mean=log(a3), sd=a3.proposal.std, log=TRUE) + dnorm(log(a4.star), mean=log(a4), sd=a4.proposal.std, log=TRUE)
					log.inverse.proposal.prob <- dnorm(log(a3), mean=log(a3.star), sd=a3.star.proposal.std, log=TRUE) + dnorm(log(a4), mean=log(a4.star), sd=a4.star.proposal.std, log=TRUE)

					log.prob.current.a3 <- (a3-1) * log(star.deltas[1]) + (a3.shape-1) * log(a3) - a3.rate*a3 - lgamma(a3)

					if (brr.rank>1) {
						log.prob.current.a4 <- (a4-1) * sum(log(star.deltas[-1])) + (a4.shape-1)*log(a4) - a4.rate*a4 - (brr.rank-1) * lgamma(a4)
					} else {
						# value for a4 comes from the prior
						log.prob.current.a4 <- dgamma(a4, shape=a4.shape, rate=a4.rate, log=TRUE)
					}
					log.prob.current <- log.prob.current.a3 + log.prob.current.a4

					log.prob.a3.star <- (a3.star-1)*log(star.deltas[1]) + (a3.shape-1)* log(a3.star) - a3.rate*a3.star - lgamma(a3.star)

					if (brr.rank>1) {
						log.prob.a4.star <- (a4.star-1) * sum(log(star.deltas[-1])) + (a4.shape-1)*log(a4.star) - a4.rate*a4.star - (brr.rank-1) * lgamma(a4.star)
					} else {
						# Value for a4.star comes from the prior
						log.prob.a4.star <- dgamma(a4.star, shape=a4.shape, rate=a4.rate, log=TRUE)
					}
					log.prob.proposed <- log.prob.a3.star + log.prob.a4.star

					log.acceptance.prob <- min(0, log.prob.proposed + log.inverse.proposal.prob - log.prob.current - log.proposal.prob)
					acceptance.prob <- exp(log.acceptance.prob)

					if (runif(n=1,min=0,max=1)<acceptance.prob) {
						# Proposal is accepted
						a3 <- a3.star
						a4 <- a4.star
					}
				}
			}
			traces$a3a4[[iter]] <- c(a3,a4)
			t2 <- proc.time()
			cpu.times['a3a4'] <- cpu.times['a3a4'] + (t2[3] - t1[3])
		}


		## Adapt rank
		if (any(vars.to.update=='brr.rank')) {
			# Perform adaptation of the rank similarly to the infinite FA model
			t1 <- proc.time()

			if (runif(n=1, min=0, max=1) < exp(alpha0 + alpha1*iter)) {

				# Relevance of a component: largest absolute entry of
				# its rank-one contribution Psi[,i] %*% Gamma[i,].
				relevance.score <- rep(NA, brr.rank)
				for (i in seq_len(brr.rank)) {
					relevance.score[i] <- max(abs(Psi[,i,drop=FALSE] %*% Gamma[i,,drop=FALSE]))
				}
				col.relevant <- (relevance.score > brr.factor.relevance.cutoff)

				if (all(col.relevant)) {
					## Add another column from the prior
					values.to.add <- simulate.new.brr.factor(star.deltas=star.deltas, local.shrinkage.nu=local.shrinkage.nu, a4=a4, n.pheno=n.pheno, n.snps=n.snps, input.clustering=input.clustering, output.clustering=output.clustering)

					Psi.local.shrinkage <- cbind(Psi.local.shrinkage, values.to.add$new.Psi.local.shrinkage.col)
					Psi <- cbind(Psi, values.to.add$new.Psi.col)

					Gamma.local.shrinkage <- rbind(Gamma.local.shrinkage, values.to.add$new.Gamma.local.shrinkage.row)
					Gamma <- rbind(Gamma, values.to.add$new.Gamma.row)

					star.deltas <- c(star.deltas, values.to.add$new.star.delta)
					star.taus <- c(star.taus, values.to.add$new.star.tau)
					brr.rank <- brr.rank+1

				} else {
					# Remove the non-relevant columns; however, leave at least one column.
					cols.to.remove <- which(!col.relevant)
					if (length(cols.to.remove)==ncol(Psi)) {
						# Check which column is the most relevant
						# and retain that column.
						retain.this <- which.max(relevance.score)
						cols.to.remove <- setdiff(cols.to.remove, retain.this)
					}

					# With a single, non-relevant component nothing is
					# left to remove. Indexing with -integer(0) would
					# select NO columns at all, so skip the removal.
					if (length(cols.to.remove) > 0) {
						Psi <- Psi[,-cols.to.remove, drop=FALSE]
						Psi.local.shrinkage <- Psi.local.shrinkage[, -cols.to.remove, drop=FALSE]

						Gamma <- Gamma[-cols.to.remove, , drop=FALSE]
						Gamma.local.shrinkage <- Gamma.local.shrinkage[-cols.to.remove, , drop=FALSE]

						star.deltas <- star.deltas[-cols.to.remove]
						star.taus <- cumprod(star.deltas)
						brr.rank <- brr.rank - length(cols.to.remove)
					}
				}
			}
			traces$brr.rank[[iter]] <- brr.rank
			t2 <- proc.time()
			cpu.times['brr.rank'] <- cpu.times['brr.rank'] + (t2[3] - t1[3])
		}

	}

	for (name in names(context)) {
		# Write the current value of each variable back to the context.
		if (name=='a3a4') {
			context$a3a4 <- c(a3,a4)
		} else {
			context[[name]] <- get(name)
		}
	}

	traces$accepted <- accepted
	traces$rejected <- rejected

	to.return <- list(updated.context=context, traces=traces, cpu.times=cpu.times)

	return(to.return)
}

estimate.effect <- function(genotypes, phenotypes, snp, phenotype) {
	#
	# Computes the mean of one phenotype within each genotype
	# class of one SNP.
	#
	# Inputs:
	#   genotypes: matrix whose column 'snp' is compared against
	#       the values 0, 1 and 2
	#   phenotypes: matrix of phenotypes
	#   snp: column of genotypes to group by
	#   phenotype: column of phenotypes to average
	#
	# Outputs:
	#   vector of length three with the phenotype means of the
	#   rows whose genotype equals 0, 1 and 2, respectively
	#   (NaN for a class without any rows)
	#
	snp.values <- genotypes[,snp]
	vapply(0:2, function(genotype.class) {
		members <- which(snp.values==genotype.class)
		mean(phenotypes[members,phenotype])
	}, numeric(1))
}

simulate.phenotypes.from.infinite.brr <- function(prior, context, genotypes, n.pheno) {
	#
	# Simulates phenotypes from the infinite reduced rank
	# regression model. The noise covariance matrix is assumed
	# to be diagonal, with the diagonal elements given in
	# prior$variances.
	#
	# Inputs:
	#   prior: list; only prior$variances (one noise variance
	#       per phenotype) is used
	#   context: list; only context$Psi and context$Gamma are used
	#   genotypes: regressor matrix (one row per individual)
	#   n.pheno: number of phenotypes to simulate
	#
	# Outputs:
	#   matrix of simulated phenotypes with one row per
	#   individual and n.pheno columns
	#
	n.individuals <- nrow(genotypes)
	noise.sd <- sqrt(prior$variances)

	# Means of the phenotypes: individuals x phenotypes
	expected.values <- genotypes %*% context$Psi %*% context$Gamma

	# The means are transposed so that consecutive draws run over
	# the phenotypes of one individual; noise.sd is then recycled
	# once per individual.
	draws <- rnorm(n=n.pheno*n.individuals, mean=t(expected.values), sd=noise.sd)

	# Filling by row puts the draws of one individual on one row.
	simulated <- matrix(draws, nrow=n.individuals, ncol=n.pheno, byrow=TRUE)

	return(simulated)
}



initialize.infinite.brr <- function(local.shrinkage.nu=3, a3.shape=18, a3.rate=2, a3.lower.bound=2, a4.shape=18, a4.rate=2, a4.lower.bound=3, brr.factor.relevance.cutoff=0.01, alpha0=-1, alpha1=-5E-4, a.sigma=1, b.sigma=1, brr.rank=3, n.snps=50, n.pheno=10, input.clustering=NULL, output.clustering=NULL, step.size=10, rare.maf.threshold=0.01) {
	#
	# Initializes the sparse infinite Bayesian reduced rank
	# regression by simulating all variables from their priors.
	#
	# Inputs:
	#   local.shrinkage.nu, a3.shape, a3.rate, a3.lower.bound,
	#   a4.shape, a4.rate, a4.lower.bound,
	#   brr.factor.relevance.cutoff, alpha0, alpha1, step.size,
	#   rare.maf.threshold: stored as such in the returned prior
	#   a.sigma, b.sigma: shape and rate of the gamma distribution
	#       from which the noise precisions are simulated. If
	#       a.sigma is NA, the variances are set to NA.
	#   brr.rank: initial rank
	#   n.snps, n.pheno: numbers of SNPs and phenotypes
	#   input.clustering: optional cluster label for each SNP
	#       (any labels; default: every SNP in its own cluster)
	#   output.clustering: optional cluster index for each
	#       phenotype (used directly as a column index, so the
	#       indices should be of the form 1:n.clusters; default:
	#       every phenotype in its own cluster)
	#
	# Outputs:
	#   context: a list with fields Psi, Gamma,
	#       Psi.local.shrinkage, Gamma.local.shrinkage,
	#       star.deltas, a3a4, brr.rank
	#
	#   prior: a list with fields local.shrinkage.nu,
	#       a3.shape, a3.rate, a3.lower.bound,
	#       a4.shape, a4.rate, a4.lower.bound,
	#       brr.factor.relevance.cutoff, alpha0,
	#       alpha1, variances, step.size, rare.maf.threshold,
	#       input.clustering (a factor), output.clustering
	#   (Note: variances is actually not a hyperparameter;
	#   however, it is required when updating the low-rank
	#   regression coefficient matrix, but it is not updated
	#   itself.)

	## Hyperparameters:
	if (!is.na(a.sigma)) {
		precisions <- rgamma(n=n.pheno, shape=a.sigma, rate=b.sigma)
		variances <- 1/precisions
	} else {
		variances <- NA
	}

	prior <- list()
	prior$local.shrinkage.nu <- local.shrinkage.nu
	prior$a3.shape <- a3.shape
	prior$a3.rate <- a3.rate
	prior$a3.lower.bound <- a3.lower.bound
	prior$a4.shape <- a4.shape
	prior$a4.rate <- a4.rate
	prior$a4.lower.bound <- a4.lower.bound
	prior$brr.factor.relevance.cutoff <- brr.factor.relevance.cutoff
	prior$alpha0 <- alpha0
	prior$alpha1 <- alpha1
	prior$variances <- variances
	prior$step.size <- step.size
	prior$rare.maf.threshold <- rare.maf.threshold

	if (!is.null(input.clustering)) {
		# Clustering for SNPs has been given
		if (length(input.clustering)==n.snps) {
			prior$input.clustering <- factor(input.clustering)
		} else {
			stop('Length of input clustering must be equal to #SNPS.')
		}
	} else {
		# Default: every SNP in its own group.
		input.clustering <- 1:n.snps
		prior$input.clustering <- factor(input.clustering)
	}
	if (!is.null(output.clustering)) {
		# Clustering of phenotypes has been given
		if (length(output.clustering)==n.pheno) {
			prior$output.clustering <- output.clustering
		} else {
			stop('Length of output clustering must equal #phenotypes.')
		}
	} else {
		# Default: every phenotype in its own group
		output.clustering <- 1:n.pheno
		prior$output.clustering <- output.clustering
	}

	## Variables to update:
	# a3 and a4 are drawn from their gamma priors, redrawing
	# until the value is not below the lower bound.
	a3 <- -1
	while (a3 < a3.lower.bound) {
		a3 <- rgamma(n=1, shape=a3.shape, rate=a3.rate)
	}
	a4 <- -1
	while (a4 < a4.lower.bound) {
		a4 <- rgamma(n=1, shape=a4.shape, rate=a4.rate)
	}

	star.deltas <- rep(0, brr.rank)
	star.deltas[1] <- rgamma(n=1, shape=a3, rate=1)
	if (brr.rank>1) {
		star.deltas[2:brr.rank] <- rgamma(n=brr.rank-1, shape=a4, rate=1)
	}
	star.taus <- cumprod(star.deltas)

	# Rows of the per-cluster shrinkage matrix are addressed with
	# the integer codes of the factor, so that arbitrary cluster
	# labels (e.g. 10, 20 or character names) work. Indexing with
	# the raw labels would go out of bounds for such input.
	snp.cluster.index <- as.integer(prior$input.clustering)
	n.snp.clusters <- nlevels(prior$input.clustering)
	Psi.local.shrinkage.parameters <- matrix(rgamma(n=n.snp.clusters*brr.rank, shape=local.shrinkage.nu/2, rate=local.shrinkage.nu/2), nrow=n.snp.clusters, ncol=brr.rank)

	Psi.local.shrinkage <- Psi.local.shrinkage.parameters[snp.cluster.index,, drop=FALSE]

	# Column h of the precisions is scaled by star.taus[h].
	Psi.precisions <- t(t(Psi.local.shrinkage)*star.taus)
	Psi.sd <- 1/sqrt(Psi.precisions)
	Psi <- matrix(rnorm(n=n.snps*brr.rank, mean=0, sd=Psi.sd), nrow=n.snps, ncol=brr.rank)

	n.pheno.clusters <- length(unique(output.clustering))
	Gamma.local.shrinkage.parameters <- matrix(rgamma(n=brr.rank*n.pheno.clusters, shape=local.shrinkage.nu/2, rate=local.shrinkage.nu/2), nrow=brr.rank, ncol=n.pheno.clusters)

	Gamma.local.shrinkage <- Gamma.local.shrinkage.parameters[,output.clustering, drop=FALSE]

	# Row h of the precisions is scaled by star.taus[h].
	Gamma.precisions <- Gamma.local.shrinkage * star.taus
	Gamma.sd <- 1/sqrt(Gamma.precisions)
	Gamma <- matrix(rnorm(n=brr.rank*n.pheno, mean=0, sd=Gamma.sd), nrow=brr.rank, ncol=n.pheno)

	context <- list()
	context$Psi <- Psi
	context$Gamma <- Gamma
	context$Psi.local.shrinkage <- Psi.local.shrinkage
	context$Gamma.local.shrinkage <- Gamma.local.shrinkage
	context$star.deltas <- star.deltas
	context$a3a4 <- c(a3,a4)
	context$brr.rank <- brr.rank

	return(list(context=context, prior=prior))
}


simulate.new.brr.factor <- function(star.deltas, local.shrinkage.nu, a4, n.pheno, n.snps, input.clustering, output.clustering) {
	#
	# Simulates the parameters of one additional component
	# (a new column of Psi and a new row of Gamma, with their
	# shrinkage parameters) to be appended after the existing ones.
	#
	# Inputs:
	#   star.deltas: star.deltas of the existing components
	#   local.shrinkage.nu: the local shrinkage parameters are
	#       drawn from Gamma(nu/2, rate=nu/2)
	#   a4: shape of the Gamma(a4, 1) draw of the new star.delta
	#   n.pheno, n.snps: numbers of phenotypes and SNPs
	#   input.clustering, output.clustering: cluster of each SNP
	#       and of each phenotype; used as index vectors into the
	#       per-cluster draws
	#
	# Outputs (a list with elements):
	#   new.star.delta, new.star.tau,
	#   new.Psi.local.shrinkage.col, new.Psi.col,
	#   new.Gamma.local.shrinkage.row, new.Gamma.row

	# One shrinkage value per cluster, expanded to one per element.
	draw.local.shrinkage <- function(clustering) {
		n.groups <- length(unique(clustering))
		per.cluster <- rgamma(n=n.groups, shape=local.shrinkage.nu/2, rate=local.shrinkage.nu/2)
		per.cluster[clustering]
	}

	# Zero-mean normal draws with the given precisions.
	draw.coefficients <- function(n.values, precisions) {
		rnorm(n=n.values, mean=0, sd=1/sqrt(precisions))
	}

	# The added component is never the first one, hence shape a4.
	delta <- rgamma(n=1, shape=a4, rate=1)
	tau <- prod(star.deltas) * delta

	# The order of the draws below is: Psi shrinkage, Psi column,
	# Gamma shrinkage, Gamma row.
	Psi.shrinkage <- draw.local.shrinkage(input.clustering)
	Psi.values <- draw.coefficients(n.snps, Psi.shrinkage * tau)

	Gamma.shrinkage <- draw.local.shrinkage(output.clustering)
	Gamma.values <- draw.coefficients(n.pheno, Gamma.shrinkage * tau)

	return(list(new.star.delta=delta, new.star.tau=tau, new.Psi.local.shrinkage.col=Psi.shrinkage, new.Psi.col=Psi.values, new.Gamma.local.shrinkage.row=Gamma.shrinkage, new.Gamma.row=Gamma.values))
}




initialize.informative.brr <- function(data, a3.init.value=3000, a4.init.value=30, local.shrinkage.nu=3, a3.shape=1.4, a3.rate=5.4e-5, a3.lower.bound=2, a4.shape=4.1, a4.rate=0.31, a4.lower.bound=2.7, brr.factor.relevance.cutoff=0.01, alpha0=-1, alpha1=-5E-4, a.sigma=1, b.sigma=1, brr.rank=2, input.clustering=NULL, output.clustering=NULL, long=FALSE, step.size=10, rare.maf.threshold=0.01, initialize.Psi.Gamma.proposal=TRUE, adaptation.interval=100, adaptation.intervals.total=10, prior.effect.tpve.median = 0.0001) {
	#
	# Initializes the sparse infinite Bayesian reduced rank
	# regression from the data: Psi and Gamma are derived from
	# the fit returned by sample.brr.simple instead of being
	# simulated from the prior.
	#
	# Inputs:
	#   data: list with fields genotypes and phenotypes (and,
	#       optionally, crossprod.genotypes, which is passed on
	#       to infinite.brr.gibbs when long is TRUE)
	#   a3.init.value, a4.init.value: initial values of a3 and a4
	#
	#   local.shrinkage.nu, a3.shape, a3.rate, a3.lower.bound,
	#   a4.shape, a4.rate, a4.lower.bound,
	#   brr.factor.relevance.cutoff, alpha0, alpha1, step.size,
	#   rare.maf.threshold: stored as such in the returned prior
	#   a.sigma, b.sigma: shape and rate of the gamma distribution
	#       from which the noise precisions are simulated. If
	#       a.sigma is NA, the variances are set to NA.
	#   brr.rank: rank of the initial solution
	#   input.clustering, output.clustering: optional clusterings
	#       of the SNPs and the phenotypes (default: every
	#       element in its own cluster)
	#   long: if TRUE, 20 iterations of infinite.brr.gibbs are run
	#       on Psi, Gamma and the local shrinkage parameters
	#   initialize.Psi.Gamma.proposal: if TRUE, a per-component
	#       proposal distribution for (Psi, Gamma) is stored in the
	#       prior and the initial values are sampled from it;
	#       otherwise the initial values come from the SVD of the
	#       mean effect matrix
	#   adaptation.interval, adaptation.intervals.total: stored in
	#       prior$Psi.Gamma.adaptation.pars
	#   prior.effect.tpve.median: passed on to sample.brr.simple
	#
	# Outputs:
	#   context: a list with fields Psi, Gamma,
	#       Psi.local.shrinkage, Gamma.local.shrinkage,
	#       star.deltas, a3a4, brr.rank
	#
	#   prior: a list with fields local.shrinkage.nu,
	#       a3.shape, a3.rate, a3.lower.bound,
	#       a4.shape, a4.rate, a4.lower.bound,
	#       brr.factor.relevance.cutoff, alpha0,
	#       alpha1, variances, step.size, rare.maf.threshold,
	#       input.clustering, output.clustering and, when
	#       requested, Psi.Gamma.proposal and
	#       Psi.Gamma.adaptation.pars
	#   (Note: variances is actually not a hyperparameter;
	#   however, it is required when updating the low-rank
	#   regression coefficient matrix, but it is not updated
	#   itself.)
	#
	#   blm.mean.effects: mean effect matrix returned by
	#       sample.brr.simple

	n.snps <- ncol(data$genotypes)
	n.pheno <- ncol(data$phenotypes)

	## Hyperparameters:
	if (!is.na(a.sigma)) {
		precisions <- rgamma(n=n.pheno, shape=a.sigma, rate=b.sigma)
		variances <- 1/precisions
	} else {
		variances <- NA
	}

	prior <- list()
	prior$local.shrinkage.nu <- local.shrinkage.nu
	prior$a3.shape <- a3.shape
	prior$a3.rate <- a3.rate
	prior$a3.lower.bound <- a3.lower.bound
	prior$a4.shape <- a4.shape
	prior$a4.rate <- a4.rate
	prior$a4.lower.bound <- a4.lower.bound
	prior$brr.factor.relevance.cutoff <- brr.factor.relevance.cutoff
	prior$alpha0 <- alpha0
	prior$alpha1 <- alpha1
	prior$variances <- variances
	prior$step.size <- step.size
	prior$rare.maf.threshold <- rare.maf.threshold

	if (!is.null(input.clustering)) {
		# Clustering for SNPs has been given
		if (length(input.clustering)==n.snps) {
			prior$input.clustering <- factor(input.clustering)
		} else {
			stop('Length of input clustering must be equal to #SNPS.')
		}
	} else {
		# Default: every SNP in its own group.
		input.clustering <- 1:n.snps
		prior$input.clustering <- factor(input.clustering)
	}
	if (!is.null(output.clustering)) {
		# Clustering of phenotypes has been given
		if (length(output.clustering)==n.pheno) {
			prior$output.clustering <- output.clustering
		} else {
			stop('Length of output clustering must equal #phenotypes.')
		}
	} else {
		# Default: every phenotype in its own group
		output.clustering <- 1:n.pheno
		prior$output.clustering <- output.clustering
	}

	## Variables to update:
	a3 <- a3.init.value
	a4 <- a4.init.value

	# More samples are drawn when the proposal distribution is
	# estimated from them.
	if (initialize.Psi.Gamma.proposal) {
		n.iter <- 100
	} else {
		n.iter <- 2
	}
	res <- sample.brr.simple(data$genotypes, data$phenotypes, brr.rank=brr.rank, n.iter=n.iter, prior.effect.tpve.median=prior.effect.tpve.median)

	if (initialize.Psi.Gamma.proposal) {
		prior$Psi.Gamma.proposal <- res$Psi.Gamma.dist
		for (comp.index in seq_len(brr.rank)) {
			prior$Psi.Gamma.proposal[[comp.index]]$fixed.step.size <- FALSE
			prior$Psi.Gamma.proposal[[comp.index]]$step.size.scaling <- 0.1
		}
		prior$Psi.Gamma.adaptation.pars <- list()
		prior$Psi.Gamma.adaptation.pars$interval <- adaptation.interval # See how many accepts within this many iterations.
		prior$Psi.Gamma.adaptation.pars$n.intervals.total <- adaptation.intervals.total # After this many intervals, stop adaptation.
		prior$Psi.Gamma.adaptation.pars$accepted <- rep(0, brr.rank)
		prior$Psi.Gamma.adaptation.pars$rejected <- rep(0, brr.rank)

		Psi <- res$Psi.mean
		Gamma <- res$Gamma.mean
		for (comp.index in seq_len(brr.rank)) {
			# Sample the column of Psi and the row of Gamma of this
			# component jointly; 0.001 is added to the diagonal of
			# the covariance.
			sampled.value <- mvr.norm.own(mu=res$Psi.Gamma.dist[[comp.index]]$mean, Sigma=res$Psi.Gamma.dist[[comp.index]]$cov+diag(n.snps+n.pheno)*0.001)
			Psi[,comp.index] <- sampled.value[1:n.snps]
			Gamma[comp.index,] <- sampled.value[-(1:n.snps)]
		}

	} else {
		prior$Psi.Gamma.proposal <- NULL
		prior$Psi.Gamma.adaptation.pars <- NULL

		# Rank-brr.rank factorization of the mean effect matrix;
		# the singular values are split evenly between Psi and Gamma.
		svd.res <- svd(res$mean.effect.matrix)

		if (brr.rank>1) {
			Psi <- svd.res$u[,1:brr.rank] %*% diag(svd.res$d[1:brr.rank]^(1/2))
			Gamma <- diag(svd.res$d[1:brr.rank]^(1/2)) %*% t(svd.res$v[,1:brr.rank])
		} else {
			Psi <- svd.res$u[,1,drop=FALSE] * svd.res$d[1]^(1/2)
			Gamma <- t(svd.res$v[,1,drop=FALSE]) * svd.res$d[1]^(1/2)
		}
	}

	Psi.local.shrinkage <- matrix(1, nrow=n.snps, ncol=brr.rank)
	Gamma.local.shrinkage <- matrix(1, nrow=brr.rank, ncol=n.pheno)

	for (rank.index in seq_len(brr.rank)) {
		# Go through each rank, make the corresponding
		# variances in Psi and Gamma equal. The product
		# Psi %*% Gamma is not changed by the rescaling.

		if (nrow(Psi)==1) {
			# Special case of having only one genotype: use the
			# squared value of the current column.
			v.Psi <- Psi[1,rank.index]^2
		} else {
			v.Psi <- var(Psi[,rank.index])
		}
		v.Gamma <- var(Gamma[rank.index,])
		Psi.multiplier <- (v.Gamma/v.Psi)^(1/4)
		Psi[,rank.index] <- Psi[,rank.index] * Psi.multiplier
		Gamma[rank.index,] <- Gamma[rank.index,] / Psi.multiplier
	}
	if (nrow(Psi)==1) {
		Psi.variances <- as.vector(Psi^2)
	} else {
		Psi.variances <- apply(Psi, 2, var)
	}
	star.taus <- 1/Psi.variances

	# star.taus is the cumulative product of star.deltas, so
	# star.deltas[h] is star.taus[h] divided by the product of
	# ALL the preceding star.deltas. From the second component on,
	# the value is not allowed to be smaller than 10.
	star.deltas <- rep(0, brr.rank)
	star.deltas[1] <- star.taus[1]
	if (brr.rank>1) {
		for (rank.index in 2:brr.rank) {
			star.deltas[rank.index] <- max(star.taus[rank.index] / prod(star.deltas[seq_len(rank.index-1)]), 10)
		}
	}

	context <- list()
	context$Psi <- Psi
	context$Gamma <- Gamma
	context$Psi.local.shrinkage <- Psi.local.shrinkage
	context$Gamma.local.shrinkage <- Gamma.local.shrinkage
	context$star.deltas <- star.deltas
	context$a3a4 <- c(a3,a4)
	context$brr.rank <- brr.rank

	# Optionally refine Psi, Gamma and the local shrinkage
	# parameters with a short run of the sampler.
	if (long==TRUE) {
		aa <- infinite.brr.gibbs(n.iter=20, vars.to.update=c('Psi','Gamma','Psi.local.shrinkage','Gamma.local.shrinkage'), context=context, prior=prior, genotypes=data$genotypes, phenotypes=data$phenotypes, crossprod.genotypes=data$crossprod.genotypes)
		context <- aa$updated.context
	}

	return(list(context=context, prior=prior, blm.mean.effects=res$mean.effect.matrix))
}




sample.brr.simple <- function(genotypes, phenotypes, brr.rank=2, n.iter=100, prior.effect.tpve.median) {
	#
	# Fits an independent Bayesian linear model to each phenotype
	# (fit.bayes.lm.diag), samples effect matrices from the
	# fitted posteriors and factorizes every sample into a
	# rank-brr.rank product Psi %*% Gamma with the SVD.
	#
	# Inputs:
	#   genotypes: regressor matrix (individuals x SNPs)
	#   phenotypes: response matrix (individuals x phenotypes)
	#   brr.rank: rank of the factorization
	#   n.iter: number of sampled effect matrices
	#   prior.effect.tpve.median: scales the prior variance of the
	#       effects (see prior.var below)
	#
	# Outputs (a list with elements):
	#   Gamma.mean, Gamma.cov, Psi.mean, Psi.cov: means and
	#       covariances of the sampled Gamma and Psi
	#   Psi.Gamma.dist: for each component, the mean and the
	#       covariance of the stacked vector (column of Psi,
	#       row of Gamma)
	#   tpve.mean, tpve.sd: mean and standard deviation, over the
	#       samples, of the ratio of the value returned by
	#       compute.amount.total.variance.explained to the total
	#       phenotype variance
	#   mean.effect.matrix: posterior means of the effects
	#       (SNPs x phenotypes)
	#
	# Requires fit.bayes.lm.diag, mvr.norm.own and
	# compute.amount.total.variance.explained, which are not
	# defined in this part of the file.

	n.pheno <- ncol(phenotypes)
	n.snps <- ncol(genotypes)
	pheno.vars <- apply(phenotypes,2,var)
	geno.vars <- apply(genotypes, 2, var)
	prior.var <- prior.effect.tpve.median / sum(geno.vars)*sum(pheno.vars)
	crossprod.X <- crossprod(genotypes)
	total.variation.in.data <- sum(pheno.vars)

	fitted <- list()

	mean.effect.matrix <- matrix(NA, nrow=n.snps, ncol=n.pheno)

	for (pheno.index in seq_len(n.pheno)) {
		fitted[[pheno.index]] <- fit.bayes.lm.diag(X=genotypes, y=phenotypes[,pheno.index], noise.var=pheno.vars[pheno.index], prior.var=prior.var, crossprod.X=crossprod.X)

		mean.effect.matrix[,pheno.index] <- fitted[[pheno.index]]$posterior.mean
	}

	Gamma.list <- list()
	Psi.list <- list()
	tpve.vec <- rep(NA, n.iter)

	for (i in seq_len(n.iter)) {
		effect.matrix <- matrix(NA, nrow=n.snps, ncol=n.pheno)
		for (pheno.index in seq_len(n.pheno)) {
			# Simulate effect matrix from the posterior
			effect.matrix[,pheno.index] <- mvr.norm.own(mu=fitted[[pheno.index]]$posterior.mean, Sigma=fitted[[pheno.index]]$posterior.cov)
		}

		# The singular values are split evenly between Psi and Gamma.
		svd.res <- svd(effect.matrix)

		if (brr.rank>1) {
			Psi <- svd.res$u[,1:brr.rank] %*% diag(svd.res$d[1:brr.rank]^(1/2))
			Gamma <- diag(svd.res$d[1:brr.rank]^(1/2)) %*% t(svd.res$v[,1:brr.rank])
		} else {
			Psi <- svd.res$u[,1,drop=FALSE] * svd.res$d[1]^(1/2)
			Gamma <- t(svd.res$v[,1,drop=FALSE]) * svd.res$d[1]^(1/2)
		}

		for (rank.index in seq_len(brr.rank)) {
			# Go through each rank, make the corresponding
			# variances in Psi and Gamma equal. The product
			# Psi %*% Gamma is not changed by the rescaling.

			if (nrow(Psi)==1) {
				# Special case of having only one genotype: use the
				# squared value of the current column.
				v.Psi <- Psi[1,rank.index]^2
			} else {
				v.Psi <- var(Psi[,rank.index])
			}
			v.Gamma <- var(Gamma[rank.index,])
			Psi.multiplier <- (v.Gamma/v.Psi)^(1/4)
			Psi[,rank.index] <- Psi[,rank.index] * Psi.multiplier
			Gamma[rank.index,] <- Gamma[rank.index,] / Psi.multiplier
		}

		Psi.list[[i]] <- Psi
		Gamma.list[[i]] <- Gamma

		total.variation.explained <- compute.amount.total.variance.explained(genotypes=genotypes, Psi=Psi, Gamma=Gamma)

		tpve.vec[i] <- total.variation.explained / total.variation.in.data
	}

	# Estimate the joint distribution for the
	# Psi and Gamma parameters related to the
	# i:th component.
	Gamma.array <- array(unlist(Gamma.list), dim=c(brr.rank,n.pheno,n.iter))
	Gamma.mean <- apply(Gamma.array, c(1,2), mean)

	# One row per iteration; Gamma is stored row by row.
	Gamma.mat <- matrix(unlist(lapply(Gamma.list, function(x){t(x)})), nrow=n.iter, ncol=brr.rank*n.pheno, byrow=TRUE)
	Gamma.cov <- cov(Gamma.mat)

	Psi.array <- array(unlist(Psi.list), dim=c(n.snps, brr.rank, n.iter))
	Psi.mean <- apply(Psi.array, c(1,2), mean)

	# One row per iteration; Psi is stored column by column.
	Psi.mat <- matrix(unlist(Psi.list), nrow=n.iter, ncol=brr.rank*n.snps, byrow=TRUE)
	Psi.cov <- cov(Psi.mat)

	Psi.Gamma.dist <- list()

	for (comp.index in seq_len(brr.rank)) {
		Psi.Gamma.dist[[comp.index]] <- list()

		# A matrix where columns represent iterations.
		# At each column there are the values of the
		# comp.index:th column of Psi. The explicit matrix()
		# keeps the shape when a dimension has length one.
		Psi.pars.in.comp <- matrix(Psi.array[,comp.index,], nrow=n.snps, ncol=n.iter)

		# In columns there are the comp.index:th row
		# of Gamma in different iterations.
		Gamma.pars.in.comp <- matrix(Gamma.array[comp.index,,], nrow=n.pheno, ncol=n.iter)

		Psi.Gamma.pars <- rbind(Psi.pars.in.comp, Gamma.pars.in.comp)

		Psi.Gamma.dist[[comp.index]]$mean <- apply(Psi.Gamma.pars,1,mean)
		Psi.Gamma.dist[[comp.index]]$cov <- cov(t(Psi.Gamma.pars))
	}

	tpve.mean <- mean(tpve.vec)
	tpve.sd <- sd(tpve.vec)

	return(list(Gamma.mean=Gamma.mean, Gamma.cov=Gamma.cov, Psi.mean=Psi.mean, Psi.cov=Psi.cov, Psi.Gamma.dist=Psi.Gamma.dist, tpve.mean=tpve.mean, tpve.sd=tpve.sd, mean.effect.matrix=mean.effect.matrix))

}



comp.Psi.log.prior <- function(Psi, Psi.local.shrinkage, star.taus) {
	#
	# Log prior density of Psi: independent zero-mean normal
	# elements, where the precision of Psi[j,h] is
	# Psi.local.shrinkage[j,h] * star.taus[h].
	#
	# Inputs:
	#   Psi: coefficient matrix
	#   Psi.local.shrinkage: matrix of the same size as Psi
	#   star.taus: one value per column of Psi
	#
	# Outputs:
	#   sum of the element-wise log densities

	# The double transpose multiplies column h by star.taus[h].
	element.precisions <- t(t(Psi.local.shrinkage) * star.taus)
	element.sds <- sqrt(1/element.precisions)

	log.densities <- dnorm(x=Psi, mean=0, sd=element.sds, log=TRUE)

	return(sum(log.densities))
}


comp.Gamma.log.prior <- function(Gamma, Gamma.local.shrinkage, star.taus) {
	#
	# Log prior density of Gamma: independent zero-mean normal
	# elements, where the precision of Gamma[h,j] is
	# Gamma.local.shrinkage[h,j] * star.taus[h].
	#
	# Inputs:
	#   Gamma: coefficient matrix
	#   Gamma.local.shrinkage: matrix of the same size as Gamma
	#   star.taus: one value per row of Gamma
	#
	# Outputs:
	#   sum of the element-wise log densities

	# Recycling star.taus down the columns multiplies row h
	# by star.taus[h].
	element.precisions <- Gamma.local.shrinkage * star.taus
	element.sds <- sqrt(1/element.precisions)

	log.densities <- dnorm(x=Gamma, mean=0, sd=element.sds, log=TRUE)

	return(sum(log.densities))
}

comp.log.likelihood.proportional.Psi.Gamma <- function(Psi, Gamma, variances, genotypes, phenotypes) {
	#
	# Log likelihood of the phenotypes up to the terms that do
	# not depend on Psi and Gamma: the sum over all individuals
	# and phenotypes of -0.5 * residual^2 / variance, where the
	# residuals are phenotypes - genotypes %*% Psi %*% Gamma.
	#
	# Inputs:
	#   Psi, Gamma: factors of the coefficient matrix
	#   variances: noise variance of each phenotype
	#   genotypes: regressor matrix (one row per individual)
	#   phenotypes: response matrix (one row per individual)
	#
	# Outputs:
	#   the (unnormalized) log likelihood

	n.rows <- nrow(genotypes)

	# Filling by row repeats the vector of variances on every
	# row, so column j holds variances[j].
	variance.grid <- matrix(variances, nrow=n.rows, ncol=length(variances), byrow=TRUE)

	residuals <- phenotypes - genotypes %*% Psi %*% Gamma

	return(sum(-0.5 / variance.grid * residuals^2))
}

