#!/usr/bin/env python

######################################################################
# Copyright (C) 2013 Jaakko Luttinen
#
# This file is licensed under Version 3.0 of the GNU General Public
# License.
######################################################################

######################################################################
# This script contains experiments for the paper
# 'Fast Variational Bayesian Linear State-Space Model'
# by Jaakko Luttinen (ECML 2013).
# The following experiments were reported in the paper:
# 'small' and 'testbed'
#
# See usage:
# ./run_experiment.py --help
#
# For instance:
# ./run_experiment.py --experiment=testbed --rotate
######################################################################

"""
Demonstrate speed-up rotations for the linear state-space model.
"""

import numpy as np
import scipy
import matplotlib.pyplot as plt

import sys, getopt, os

import datetime

from bayespy.inference.vmp.nodes.gaussian_markov_chain import GaussianMarkovChain
from bayespy.inference.vmp.nodes.gaussian import Gaussian
from bayespy.inference.vmp.nodes.gamma import Gamma
from bayespy.inference.vmp.nodes.normal import Normal
from bayespy.inference.vmp.nodes.dot import Dot
from bayespy.inference.vmp.nodes.gamma import diagonal

from bayespy.utils import utils
from bayespy.utils import random

from bayespy.inference.vmp.vmp import VB
from bayespy.inference.vmp import transformations

import bayespy.plot.plotting as bpplt

import testbed
#import mohsst5

import h5py

def linear_state_space_model(D=3, N=100, M=10):
    """
    Build the BayesPy nodes of the linear state-space model.

    D is used as the dimensionality of the latent states, N as the number of
    time instances of the Markov chain and M as the number of rows (plates)
    of the loading matrix.

    Returns the VB inference object that holds the nodes named 'X', 'C',
    'gamma', 'A', 'alpha', 'tau' and 'Y'.
    """

    # Both hyperparameters of every Gamma prior are set to this small value
    eps = 1e-5

    # ARD prior and the state dynamics matrix
    dynamics_ard = Gamma(eps, eps, plates=(D,), name='alpha')
    dynamics = Gaussian(np.zeros(D),
                        diagonal(dynamics_ard),
                        plates=(D,),
                        name='A')

    # Latent state chain driven by the dynamics matrix
    states = GaussianMarkovChain(
        np.zeros(D),            # mean of x0
        1e-3 * np.identity(D),  # prec of x0
        dynamics,               # dynamics
        np.ones(D),             # innovation
        n=N,                    # time instances
        name='X',
        initialize=False,
    )
    # Placeholder values only: X is the first node to be updated anyway
    states.initialize_from_value(np.zeros((N, D)))

    # ARD prior and the loading (mixing) matrix from latent to observation
    # space
    loadings_ard = Gamma(eps, eps, plates=(D,), name='gamma')
    loadings = Gaussian(np.zeros(D),
                        diagonal(loadings_ard),
                        plates=(M, 1),
                        name='C')
    # C needs a random initialization, and X must be updated before C
    loadings.initialize_from_random()

    # Observation noise
    noise = Gamma(eps, eps, name='tau')

    # Noiseless signal and the observation node
    signal = Dot(loadings, states.as_gaussian(), name='CX')
    observations = Normal(signal, noise, name='Y')

    # The argument order is kept as in the update order noted above (X first)
    nodes = (states, loadings, loadings_ard,
             dynamics, dynamics_ard,
             noise, observations)
    return VB(*nodes)

def run_lssm_experiment(y, mask, Q, filename, rotate=True, maxiter=100, debug=False):
    """
    Fit the model Q to the data y and plot the convergence.

    y is the data array; its NaN entries are left out of both the training
    and the test set.  mask is expected to be a boolean array: True entries
    are used for training, the remaining non-NaN entries for testing.  Q is
    the VB object; it must contain the node 'Y' and, when rotate is true,
    also 'A', 'alpha', 'X', 'C' and 'gamma'.  filename is handed to
    Q.set_autosave.  maxiter is the number of VB iterations, and debug turns
    on the bound and gradient checks of the rotation optimizer.

    A figure with the lower bound and the training/test RMSE is created;
    nothing is returned.
    """

    # Split the non-NaN entries into training and test sets
    observed = ~np.isnan(y)
    test_mask = np.logical_and(~mask, observed)
    train_mask = np.logical_and(mask, observed)

    # Observe the training data
    Q['Y'].observe(y, mask=train_mask)

    def rmse(prediction, selection):
        # Root-mean-square error over the selected entries
        return np.sqrt(np.mean((prediction[selection]-y[selection])**2))

    def compute_errors():
        # First moment of the parent of Y serves as the reconstruction
        yh = Q['Y'].parents[0].get_moments()[0]
        return (rmse(yh, train_mask), rmse(yh, test_mask))

    Q.set_autosave(filename, iterations=10)
    Q.callback = compute_errors

    #
    # Run variational Bayesian inference
    #

    if rotate:
        D = Q['A'].plates[0]
        # Rotation of the dynamics matrix together with its ARD prior
        rotA = transformations.RotateGaussianARD(Q['A'], Q['alpha'])
        # Rotation of the states (takes the dynamics rotation along)
        rotX = transformations.RotateGaussianMarkovChain(Q['X'], Q['A'], rotA)
        # Rotation of the loading matrix together with its ARD prior
        rotC = transformations.RotateGaussianARD(Q['C'], Q['gamma'])
        # Optimizer of the joint rotation
        R = transformations.RotationOptimizer(rotX, rotC, D)

        # Extra arguments for checking the rotation, used only when debugging
        kwargs = {
            'check_bound': Q.compute_lowerbound,
            'check_bound_terms': Q.compute_lowerbound_terms,
            'check_gradient': True
            } if debug else {}

        # One rotation round after every VB update
        for _ in range(maxiter):
            Q.update()
            R.rotate(maxiter=10, **kwargs)
    else:
        # Plain VB updates without rotations
        Q.update(repeat=maxiter)

    #
    # Examine the results
    #

    plt.figure()

    plt.subplot(2,1,1)
    plt.semilogx(Q.L)
    plt.title('Lower bound')

    plt.subplot(2,1,2)
    plt.semilogx(Q.callback_output.T)
    plt.legend(['Training set', 'Test set'])
    plt.title('RMSE')


def compare_experiment_results(*filenames):
    """
    Draw the saved curves of several result files into one figure.

    For every HDF5 file, the dataset 'L' (lower bound) and the first two rows
    of 'callback_output' (training and test RMSE) are plotted on logarithmic
    x-axes in three stacked panels.
    """

    def draw(row, series, title):
        # Add one curve to the panel on the given row
        plt.subplot(3, 1, row)
        plt.semilogx(series)
        plt.title(title)

    plt.figure()
    for filename in filenames:
        # The file is closed on leaving the block, also on errors
        with h5py.File(filename, 'r') as f:
            draw(1, f['L'], 'Lower bound')
            draw(2, f['callback_output'][0], 'Training RMSE')
            draw(3, f['callback_output'][1], 'Test RMSE')
    
def load_results(filename):
    """
    Rebuild the model of a saved experiment and load its state.

    The sizes M, N and D are read from the 'phi0' datasets of the nodes 'Y'
    and 'X' in the HDF5 file, a model of that size is constructed, and its
    state is loaded from the same file.  Returns the VB object.
    """

    # The file is kept open while the model is built and loaded, and it is
    # closed on leaving the block (also on errors)
    with h5py.File(filename, 'r') as f:
        saved_nodes = f['nodes']

        # Read the dimensionalities from the file
        (M, N) = saved_nodes['Y']['phi0'].shape
        D = saved_nodes['X']['phi0'].shape[-1]

        Q = linear_state_space_model(D=D, N=N, M=M)

        # Pointing the autosave to the file tells Q where to load from
        Q.set_autosave(filename)
        Q.load()

    return Q


def small_experiment(rotate=True, D=8, **kwargs):
    """
    Run the small artificial experiment.

    The data is simulated from a four-dimensional linear state-space model
    whose dynamics consist of a rotation by 0.3 radians (two components), a
    random walk (one component) and a component without memory.  rotate and
    the remaining keyword arguments are passed on to run_lssm_experiment; D
    is the latent dimensionality of the fitted model.
    """

    #
    # Simulate some data
    #

    n_rows = 30
    n_times = 400
    n_states = 4

    angle = 0.3
    (cos_w, sin_w) = (np.cos(angle), np.sin(angle))
    dynamics = np.array([[cos_w, -sin_w, 0, 0],
                         [sin_w, cos_w,  0, 0],
                         [0,     0,      1, 0],
                         [0,     0,      0, 0]])
    loadings = np.random.randn(n_rows, n_states)

    states = np.empty((n_times, n_states))
    y = np.empty((n_rows, n_times))
    for t in range(n_times):
        if t == 0:
            # Initial state with a large variance
            states[t] = 10*np.random.randn(n_states)
        else:
            # Propagate the previous state and add innovation noise
            states[t] = np.dot(dynamics, states[t-1]) + np.random.randn(n_states)
        # Noisy observation of the current state
        y[:,t] = np.dot(loadings, states[t]) + 3*np.random.randn(n_rows)

    # Add missing values randomly (keep only 20%)
    mask = random.mask(n_rows, n_times, p=0.2)

    #
    # Create the model
    #

    Q = linear_state_space_model(D=D, N=n_times, M=n_rows)

    # The result file is labelled with the rotation flag and a time stamp
    stamp = datetime.datetime.today().strftime('%Y%m%d%H%M%S')
    filename = ('results/experiment=small_rotations=%s_date=%s.hdf5'
                % (rotate, stamp))

    run_lssm_experiment(y, mask, Q, filename, rotate=rotate, **kwargs)

def large_experiment(rotate=True, D=10, **kwargs):
    """
    Large artificial experiment.

    The dataset is generated by sampling the parameters of the model randomly:
    the dynamics matrix is the orthogonal matrix closest to a randomly
    perturbed identity matrix and the loading matrix is standard normal.

    rotate and the remaining keyword arguments are passed on to
    run_lssm_experiment.  D is both the dimensionality of the simulated
    states and the latent dimensionality of the fitted model.
    """

    #
    # Simulate some data
    #

    N = 1000
    M = 100

    # np.linalg.svd returns the right singular vectors already transposed
    # (Vh).  The product U*Vh is the orthogonal factor of the polar
    # decomposition, that is, the orthogonal matrix closest to the perturbed
    # identity.  Transposing Vh once more would give an orthogonal matrix
    # that is not close to the identity.
    (U, _, Vh) = np.linalg.svd(np.identity(D) + 0.2*np.random.randn(D,D))
    a = np.dot(U, Vh)
    c = np.random.randn(M,D)
    x = np.empty((N,D))
    f = np.empty((M,N))
    y = np.empty((M,N))
    # Initial state with a large variance and its noisy observation
    x[0] = 10*np.random.randn(D)
    f[:,0] = np.dot(c,x[0])
    y[:,0] = f[:,0] + 3*np.random.randn(M)
    for n in range(N-1):
        # Propagate the state, then observe it with noise
        x[n+1] = np.dot(a,x[n]) + np.random.randn(D)
        f[:,n+1] = np.dot(c,x[n+1])
        y[:,n+1] = f[:,n+1] + 3*np.random.randn(M)

    # Add missing values randomly
    mask = random.mask(M, N, p=0.2)
    # Add missing values to a period of time
    mask[:,300:500] = False

    #
    # Create the model
    #

    Q = linear_state_space_model(D=D, N=N, M=M)

    # The result file is labelled with the rotation flag and a time stamp
    filename = ('results/experiment=large_rotations=%s_date=%s.hdf5'
                % (rotate,
                   datetime.datetime.today().strftime('%Y%m%d%H%M%S')))

    run_lssm_experiment(y, mask, Q, filename, rotate=rotate, **kwargs)

def testbed_experiment(rotate=True, D=10, **kwargs):
    """
    Run the experiment on the Helsinki Testbed temperature data.

    rotate and the remaining keyword arguments are passed on to
    run_lssm_experiment; D is the latent dimensionality of the fitted model.
    """

    #
    # Load the data
    #

    # Only the first of the four returned items is needed here
    y, _, _, _ = testbed.load_data()
    M, N = np.shape(y)

    #
    # Create training and testing sets
    #

    mask = random.mask(M, N, p=0.8)

    # In addition, a block of consecutive time instances is removed from the
    # training set at the start of every period.
    # NOTE(review): presumably one day out of every ten days at 10-minute
    # resolution -- confirm against the data.
    gap_length = 6*24
    gap_period = 10*gap_length
    for start in range(0, N, gap_period):
        mask[:,start:start+gap_length] = False

    #
    # Create the model
    #

    Q = linear_state_space_model(D=D, N=N, M=M)

    # The result file is labelled with the rotation flag and a time stamp
    stamp = datetime.datetime.today().strftime('%Y%m%d%H%M%S')
    filename = ('results/experiment=testbed_rotations=%s_date=%s.hdf5'
                % (rotate, stamp))

    run_lssm_experiment(y, mask, Q, filename, rotate=rotate, **kwargs)

    
def mohsst5_experiment(rotate=True, D=20, **kwargs):
    """
    An experiment using MOHSST5 sea-surface temperature data.

    Unfortunately, the data cannot be shared along this code because of the
    license.

    rotate and the remaining keyword arguments are passed on to
    run_lssm_experiment; D is the latent dimensionality of the fitted model.

    Raises ImportError if the mohsst5 data module is not available.
    """

    # The module-level import of mohsst5 is disabled in this file, so the
    # module is imported here.  This keeps the other experiments usable
    # without the data, and gives this one a clear ImportError instead of a
    # NameError.
    import mohsst5

    #
    # Load the data
    #

    # The second returned item is not needed here
    (y, _) = mohsst5.load_data()
    (M, N) = np.shape(y)

    # Create training and testing sets
    mask = random.mask(M, N, p=0.7)

    #
    # Create the model
    #

    Q = linear_state_space_model(D=D, N=N, M=M)

    # The result file is labelled with the rotation flag and a time stamp
    filename = ('results/experiment=mohsst5_rotations=%s_date=%s.hdf5'
                % (rotate,
                   datetime.datetime.today().strftime('%Y%m%d%H%M%S')))

    run_lssm_experiment(y, mask, Q, filename,
                        rotate=rotate,
                        **kwargs)

    

def parse_options(argv):
    """
    Parse the command-line options of the script.

    argv is the list of arguments without the program name.  Returns the
    tuple (experiment, rotate, seed, kwargs): the experiment name (one of
    'small', 'large', 'testbed' and 'mohsst5'), the rotation flag, the seed
    of the random number generator, and the keyword arguments for the
    experiment function ('maxiter' always, 'D' and 'debug' only if given).

    On an unrecognized option, the usage is printed and the program exits
    with status 2.  An unknown experiment name raises Exception.
    """

    try:
        # '-e' and '-m' take an argument (trailing colon), '-r' is a flag
        opts, _ = getopt.getopt(argv,
                                "e:rm:",
                                ["experiment=","rotate","maxiter=","D=","seed=","debug"])
    except getopt.GetoptError:
        print('run_experiment.py <options>')
        print('--experiment=<...>\t The experiment name: small/large/testbed/mohsst5 [default=small]')
        print('--rotate\t Apply rotations')
        print('--maxiter=N\t Maximum number of VB iterations [default=1000]')
        print('--D=N\t Dimensionality of the latent space [default depends on the experiment]')
        print('--seed=N\t Seed (integer) for the random number generator [default=42]')
        print('--debug\t Check that the rotations are implemented correctly')
        sys.exit(2)

    rotate = False
    experiment = 'small'
    kwargs = {'maxiter': 1000}
    seed = 42
    for opt, arg in opts:
        if opt in ("-r", "--rotate"):
            rotate = True
        elif opt in ("-m", "--maxiter"):
            kwargs['maxiter'] = int(arg)
        elif opt in ("-e", "--experiment"):
            if arg not in ("small", "large", "testbed", "mohsst5"):
                raise Exception("Unknown experiment type requested")
            experiment = arg
        elif opt in ("--D",):
            kwargs['D'] = int(arg)
        elif opt in ("--seed",):
            seed = int(arg)
        elif opt in ("--debug",):
            kwargs['debug'] = True

    return (experiment, rotate, seed, kwargs)


if __name__ == '__main__':
    # This makes it possible to run the file as a script.

    (experiment_name, rotate, seed, kwargs) = parse_options(sys.argv[1:])

    # Map the experiment names to the functions that run them
    experiments = {
        'small': small_experiment,
        'large': large_experiment,
        'testbed': testbed_experiment,
        'mohsst5': mohsst5_experiment,
    }
    experiment = experiments[experiment_name]

    print("Random number generator seed set to %d." % seed)
    np.random.seed(seed)

    # Create the directory for the results
    os.makedirs('results', exist_ok=True)

    # Run experiment
    experiment(rotate=rotate, **kwargs)

    # Show results
    plt.show()
