# Start from an empty workspace so objects left over from earlier sessions
# cannot influence this run.
rm(list = ls())

## UPDATE THIS PATH BASED ON WHERE YOU HAVE SAVED THE CODE
## (keep the trailing slash: script paths are appended to it without a separator)
base.path <- 'C:/Work/rich_phenotype/reduced_rank_regression/gene_metabolome/'


### Load packages and all project code
library(MASS)
library(fields)
setwd(base.path)

# Source every project script, in this order. source() evaluates in the
# global environment by default, so the definitions end up in the workspace.
invisible(lapply(
  c(
    'brrr/full_low_rank_brr.R',
    'sparse_fa/sparse_fa.R',
    'infinite_brr/infinite_brr.R',
    'bayes_lm/bayes_lm.R',
    'common/mvrnorm_own.R',
    'common/PosDef.R',
    'common/auxiliary_functions.R',
    'common/preprocess_data.R'
  ),
  function(script) source(paste0(base.path, script))
))



## Step 1: generate simulated data
# "True" parameters of the simulated data set
n.patients <- 3000
n.pheno <- 25
n.snps <- 15
noise.model.rank <- 10
# brrr.rank <- 1 was used in the GWA analysis in the article.
brrr.rank <- 2
# Assume that there are 2 known confounders;
# the parameters for their effects will be learnt.
n.confounders <- 2

# Draw a model from the prior
true.model <- initialize.from.prior(
  n.pheno = n.pheno,
  n.snps = n.snps,
  n.patients = n.patients,
  fa.rank = noise.model.rank,
  brr.rank = brrr.rank,
  n.confounders = n.confounders
)

# Simulate data from the true model
data <- simulate.from.full.low.rank.brr(true.model)
# Sum of the per-column variances of the scaled genotype matrix,
# stored before preprocessing is applied
data$total.unpruned.snp.variation <- sum(apply(scale(data$genotypes), 2, var))

# Preprocess data as in the paper
data <- preprocess.data(
  raw.data = data,
  n.snps.to.keep = n.snps,
  permutation = NULL
)

print('Data preprocessed')

# INFORMATIVE PRIOR for PTVE
# With these values the prior expectation is that the first component
# explains [0.3-0.999] of the total variation with probability 0.98.
a4.shape <- 4.1
a4.rate <- 0.31
a4.lower.bound <- 3

# The prior distribution of a3 is chosen such that the expected total
# proportion of variation explained has the quantiles specified here:
tpve.quantiles <- c(0.5, 0.99)

# The values below give the shrinkage priors that were used in the genome-wide
# analysis presented in the article. (Note that the code uses the abbreviation
# TPVE for the total proportion of variation explained, whereas the article
# uses the abbreviation PTVE.)
#tpve.quantile.values <- c(0.000001, 0.001)

# The values below are obtained by fixing the median PTVE to the true PTVE.
# They are meant to be used only in the present toy data illustration. (With
# the shrinkage priors the effects are more biased towards zero, as expected.)
ptve.data <-
  sum(diag(var(
    data$genotypes %*% true.model$brr$context$Psi %*% true.model$brr$context$Gamma
  ))) /
  sum(diag(var(data$phenotypes)))
tpve.quantile.values <- c(ptve.data, ptve.data^0.7)

# Lower bound and starting values for a3 / a4
a3.lower.bound <- 2
a3.init.value <- 3000
a4.init.value <- 4.5

# Limits tell where a3 must lie for mean TPVE to lie in the specified interval.
limits <- compute.a3.interval(
  tpve = tpve.quantile.values,
  a4.shape = a4.shape,
  a4.rate = a4.rate,
  a4.lower.bound = a4.lower.bound,
  a3.lower.bound = a3.lower.bound,
  local.shrinkage.nu = 3,
  snp.total.variation = data$total.unpruned.snp.variation,
  phenotype.avg.variance = mean(apply(data$phenotypes, 2, var))
)

# The a3 quantiles are one minus the TPVE quantiles, taken in reverse order.
a3.quantiles <- 1 - rev(tpve.quantiles)
ans <- determine.gamma.pars(
  quantiles = a3.quantiles,
  values = c(limits$a3.min, limits$a3.max)
)
a3.shape <- ans$shape
a3.rate <- ans$rate

				
# Simulate from the prior distribution of TPVE and have a look
# at the distribution:
# tpve.prior.samples <- simulate.tpve.from.prior(a3.shape=a3.shape, a3.rate=a3.rate, a3.lower.bound=a3.lower.bound, a4.shape=a4.shape, a4.rate=a4.rate, a4.lower.bound=a4.lower.bound, n.simulations=100, brr.rank=brrr.rank, data)


# Informative initialization of the Bayesian reduced rank regression model.
# rare.maf.threshold is used for bookkeeping: keep track of the proportion of
# variance explained by SNPs with MAF less than 'rare.maf.threshold'.
init.brr.model <- initialize.informative.brr(
  data = data,
  a3.init.value = a3.init.value,
  a4.init.value = a4.init.value,
  local.shrinkage.nu = 3,
  a3.shape = a3.shape,
  a3.rate = a3.rate,
  a3.lower.bound = a3.lower.bound,
  a4.shape = a4.shape,
  a4.rate = a4.rate,
  a4.lower.bound = a4.lower.bound,
  brr.factor.relevance.cutoff = 0.01,
  alpha0 = -1,
  alpha1 = -5E-4,
  a.sigma = 1,
  b.sigma = 1,
  brr.rank = brrr.rank,
  input.clustering = NULL,
  # Spelled out as FALSE: the alias F is an ordinary variable that any
  # sourced script could have reassigned.
  initialize.Psi.Gamma.proposal = FALSE,
  adaptation.interval = 100,
  adaptation.intervals.total = 10,
  rare.maf.threshold = 0.01
)


# Informative initialization of the FA-part
# (the factor analysis model used as the noise model).
# The rank is taken from noise.model.rank (the rank used to simulate the data)
# rather than a second hard-coded 10, so the two cannot drift apart.
fa <- initialize.fa.from.prior(
  a1.shape = 18,
  a1.rate = 2,
  a1.lower.bound = 2,
  a2.shape = 18,
  a2.rate = 2,
  a2.lower.bound = 3,
  a.sigma = 2.2,
  b.sigma = 0.3,
  local.shrinkage.nu = 3,
  factor.relevance.cutoff = 0.001,
  alpha0 = -1,
  alpha1 = -0.005,
  rank = noise.model.rank,
  n.patients = nrow(data$genotypes),
  n.pheno = ncol(data$phenotypes)
)

# Residual phenotypes after removing the initial BRR fit; the FA model is
# updated on these for 200 Gibbs iterations.
Y <- data$phenotypes -
  data$genotypes %*% init.brr.model$context$Psi %*% init.brr.model$context$Gamma
init.fa.model <- sparse.fa.gibbs(n.iter = 200, fa$context, fa$prior, Y)
# Hand the noise variances from the fitted FA model over to the BRR prior.
init.brr.model$prior$variances <- init.fa.model$context$variances


# Initialize full Bayesian reduced rank regression model (combining the noise
# model to the reduced rank regression model) for MCMC
init.model <- list()
init.model$fa$context <- init.fa.model$context
init.model$fa$prior <- fa$prior
init.model$brr$context <- init.brr.model$context
init.model$brr$prior <- init.brr.model$prior
init.model$A <- NA
# NOTE(review): the message mentions TPVE prior sampling, but the
# simulate.tpve.from.prior call earlier in this script is commented out.
print('model initialized, tpve prior sampled')


# Run MCMC to learn the full BRRR model from data
n.iterations <- 1000
thinning <- 5

# Elapsed-time stamps taken around the sampler call
tX <- proc.time()[3]
mcmc.output <- gibbs.full.low.rank.brr(
  model = init.model,
  data = data,
  n.iter = n.iterations,
  thin = thinning,
  fixed.brr.rank = brrr.rank,
  brr.vars.to.record = c('Gamma', 'Psi', 'a3a4', 'brr.rank', 'maf.group.tpve'),
  simple.brr.update = TRUE,
  fix.gamma.iteration = n.iterations
)
tY <- proc.time()[3]
print(paste0('Gibbs run time: ', tY - tX))

# Remove burn-in: burnin is set to half of n.iterations / thinning
mcmc.output <- remove.burnin(
  mcmc.output = mcmc.output,
  burnin = round(n.iterations / thinning / 2)
)

# Plot comparison of true and estimated coefficient matrix
comp <- check.mcmc.result(
  true.model$brr$context,
  mcmc.output = mcmc.output,
  name = 'coefMat'
)


# Trace of PTVE
ptve.samples <- unlist(mcmc.output$traces$tpve)

# Mean proportion of total variation explained by rare variants
ptve.rare <- mean(unlist(lapply(
  mcmc.output$traces$maf.group.tpve,
  function(group.tpve) group.tpve[['rare']]
)))