################################################################################
# This file is a part of the following submission:
#   "Linear state-space model with time-varying dynamics"
#   Luttinen, Raiko, Ilin (ECML 2014)
#
# Copyright (C) 2013-2014 Jaakko Luttinen
#
# This file is licensed under Version 3.0 of the GNU General Public
# License. See LICENSE for a text of the license.
################################################################################

"""
An experiment with a 1-D signal whose frequency changes over time.
"""

import numpy as np
import scipy
import matplotlib.pyplot as plt

from bayespy.utils import misc
from bayespy.utils import random

from bayespy.inference.vmp.vmp import VB
from bayespy.inference.vmp import transformations

import bayespy.plot as bpplt

from bayespy.demos import (lssm,
                           lssm_sd,
                           lssm_tvd)


def simulate_data(N):
    """
    Generate a noisy 1-D sinusoid whose frequency oscillates over time.

    The clean signal is ``sin(a * (t + c*sin(b*t)))`` for ``t = 0..N-1``,
    and the observations add Gaussian noise with standard deviation 0.1
    drawn from NumPy's global random state.

    Parameters
    ----------
    N : int
        Number of time instances.

    Returns
    -------
    tuple
        ``(y, f)``: the noisy observations and the noiseless signal, both
        of length ``N``.
    """

    base_freq = 0.1 * 2*np.pi   # base frequency
    mod_freq = 0.01 * 2*np.pi   # frequency of the frequency change
    mod_depth = 8               # "magnitude" of the frequency change

    times = np.arange(N)
    warped_times = times + mod_depth*np.sin(mod_freq*times)
    clean = np.sin(base_freq * warped_times)
    noisy = clean + 0.1*np.random.randn(N)
    return (noisy, clean)


def plot(N=1000, D=5, K=4, seed=42):
    """
    Plot saved results for all the compared methods.

    Regenerates the true signal with the given ``N`` and ``seed``. These
    should match the values used when the results were produced with
    ``run(autosave=True)``. The function then loads the saved results for
    the constant, time-varying and switching dynamics models. It writes one
    PDF figure for the data alone and one per method.

    Parameters
    ----------
    N : int
        Number of time instances.
    D : int
        Dimensionality of the latent space.
    K : int
        Number of dynamics matrices.
    seed : int
        Seed for NumPy's global random number generator.
    """

    # Seed for random number generator
    np.random.seed(seed)

    # Create data (only the noiseless signal f is plotted below)
    (y, f) = simulate_data(N)

    # The experiment uses a single signal, i.e., the observations have one
    # row (run() adds the row axis with y[None,...]).
    M = 1

    # Locations of the gaps: start and end of each gap created in run()
    lines = []
    for m in range(100, N, 140):
        lines.append(m)
        lines.append(m+15)

    def make_plot(filename, F=None):
        """Plot the true signal, and optionally the node F, into filename."""
        plt.figure(figsize=(40,3))
        # Compare against None explicitly instead of relying on the truth
        # value of a BayesPy node.
        if F is not None:
            bpplt.timeseries_normal(F, scale=2)
        bpplt.timeseries(f, linestyle='-', color='r')
        plt.ylim([-1.5, 1.5])
        plt.yticks([])
        plt.xticks([])
        # Mark the gap boundaries
        for line in lines:
            plt.axvline(line, color='k')
        plt.savefig(filename, bbox_inches='tight')

    # Plot data
    make_plot('fig_toy_data.pdf')

    # LSSM
    Q = lssm.model(M, N, D)
    Q.load(filename='results_toy_lssm.hdf5')
    make_plot('fig_toy_lssm.pdf', F=Q['F'])

    # DLSSM
    Q = lssm_tvd.model(M, N, D, K)
    Q.load(filename='results_toy_lssm-tvd.hdf5')
    make_plot('fig_toy_lssm-tvd.pdf', F=Q['F'])

    # SLSSM
    Q = lssm_sd.model(M=M, N=N, D=D, K=K)
    Q.load(filename='results_toy_lssm-sd.hdf5')
    make_plot('fig_toy_lssm-sd.pdf', F=Q['F'])

    return


def run(N=1000, D=5, K=4, seed=42, maxiter=100, rotate=True, debug=False,
        autosave=False, precompute=False, dynamics='constant', plot=True,
        monitor=True):
    """
    Run experiment for one method.

    Simulates the toy signal and removes observations: gaps of 15
    consecutive time instances, plus values removed at random outside the
    gaps. It then fits the requested state-space model and prints the RMSE
    of the posterior mean of ``Y`` against the true signal. The RMSE is
    reported separately for the randomly missing values and for the gaps.

    Parameters
    ----------
    N : int
        Number of time instances.
    D : int
        Dimensionality of the latent space.
    K : int
        Number of dynamics matrices (used by 'switching' and 'varying').
    seed : int
        Seed for NumPy's global random number generator.
    maxiter : int
        Maximum number of VB iterations.
    rotate : bool
        Whether to apply the speed-up rotations.
    debug : bool
        Debug the computation of the rotations.
    autosave : bool
        If true, pass a results filename to the inference function's
        ``autosave`` argument.
    precompute : bool
        Precompute some moments when rotating. This is not passed to the
        'switching' model.
    dynamics : {'constant', 'switching', 'varying'}
        Which model to fit.
    plot : bool
        Whether to plot the data and the results.
    monitor : bool
        Whether to monitor variables during VB learning.

    Raises
    ------
    ValueError
        If ``dynamics`` is not one of the supported values.
    """

    # Results file for each supported dynamics model (used when autosaving)
    result_files = {
        'switching': 'results_toy_lssm-sd.hdf5',
        'varying': 'results_toy_lssm-tvd.hdf5',
        'constant': 'results_toy_lssm.hdf5',
    }
    # Fail fast, before simulating data or opening any figures
    if dynamics not in result_files:
        raise ValueError("Unknown dynamics requested")
    filename = result_files[dynamics] if autosave else None

    # Seed for random number generator
    np.random.seed(seed)

    # Create data
    (y, f) = simulate_data(N)

    # Create some gaps of 15 time instances every 140 steps. Slicing already
    # clips at the end of the array; clipping the end to N-1 (as an
    # exclusive bound) would leave the last sample of a trailing gap observed.
    mask_gaps = misc.trues(N)
    for m in range(100, N, 140):
        mask_gaps[m:m+15] = False
    # Randomly missing values. Positions inside the gaps are forced to True
    # so that ~mask_random marks only the random removals outside the gaps.
    mask_random = np.logical_or(random.mask(N, p=0.8),
                                np.logical_not(mask_gaps))
    # A value is observed only if it is neither in a gap nor randomly removed
    mask = np.logical_and(mask_gaps, mask_random)
    # Remove the observations
    y[~mask] = np.nan # BayesPy doesn't require NaNs, they're just for plotting.

    # Add row axes
    y = y[None,...]
    f = f[None,...]
    mask = mask[None,...]
    mask_gaps = mask_gaps[None,...]
    mask_random = mask_random[None,...]

    # Plot observations
    if plot:
        plt.figure()
        bpplt.timeseries(f, linestyle='-', color='b')
        bpplt.timeseries(y, linestyle='None', color='r', marker='.')
        plt.ylim([-2, 2])

    # Run the method (dynamics has been validated above)
    if dynamics == 'switching':
        Q = lssm_sd.infer(y, D, K,
                          mask=mask,
                          maxiter=maxiter,
                          rotate=rotate,
                          debug=debug,
                          update_hyper=10,
                          autosave=filename,
                          monitor=monitor)
    elif dynamics == 'varying':
        Q = lssm_tvd.infer(y, D, K,
                           mask=mask,
                           maxiter=maxiter,
                           rotate=rotate,
                           debug=debug,
                           precompute=precompute,
                           update_hyper=10,
                           start_rotating_weights=10,
                           autosave=filename,
                           monitor=monitor)
    else:  # 'constant'
        Q = lssm.infer(y, D,
                       mask=mask,
                       maxiter=maxiter,
                       rotate=rotate,
                       debug=debug,
                       precompute=precompute,
                       update_hyper=10,
                       autosave=filename,
                       monitor=monitor)

    # Compute RMSE of the posterior mean against the true signal
    Y_mean = Q['Y'].get_moments()[0]
    rmse_random = misc.rmse(Y_mean[~mask_random], f[~mask_random])
    rmse_gaps = misc.rmse(Y_mean[~mask_gaps], f[~mask_gaps])
    print("RMSE for randomly missing values: %f" % rmse_random)
    print("RMSE for gap values: %f" % rmse_gaps)

    if plot:
        # Plot observations
        plt.figure()
        bpplt.timeseries_normal(Q['F'], scale=2)
        bpplt.timeseries(f, linestyle='-', color='b')
        bpplt.timeseries(y, linestyle='None', color='r', marker='.')
        plt.ylim([-2, 2])

        # Plot distributions
        Q.plot()

        plt.show()


if __name__ == '__main__':
    import getopt
    import sys

    def _print_usage():
        """Print the command-line options understood by this script."""
        print('python toy_experiment.py <options>')
        print('--n=<INT>        Number of time instance')
        print('--d=<INT>        Dimensionality of the latent space')
        print('--k=<INT>        Number of dynamics matrices')
        print('--seed=<INT>     Seed for the random number generator')
        print('--dynamics=...   [constant] / switching / varying')
        print('--no-rotation    Do not apply speed-up rotations')
        print('--maxiter=<INT>  Maximum number of VB iterations')
        print('--autosave       Save the results')
        print('--debug          Debug the computation of the rotation')
        print('--plot           Only plot existing results')
        print('--no-monitor     Do not monitor variables during VB learning')
        print('--precompute     Precompute some moments when rotating. May '
              'speed up or slow down.')
        print('--help           Show this help message')

    try:
        opts, _ = getopt.getopt(sys.argv[1:],
                                "",
                                [
                                    "n=",
                                    "d=",
                                    "k=",
                                    "seed=",
                                    "autosave",
                                    "dynamics=",
                                    "maxiter=",
                                    "debug",
                                    "precompute",
                                    "plot",
                                    "no-monitor",
                                    "no-rotation",
                                    "help",
                                ])
    except getopt.GetoptError as err:
        # Tell the user which option was wrong before showing the usage
        print(err)
        _print_usage()
        sys.exit(2)

    # --help is advertised below, so handle it explicitly with exit status 0
    if any(opt == "--help" for opt, arg in opts):
        _print_usage()
        sys.exit(0)

    print("By default, this function uses constant dynamics and speed-up "
          "rotations")
    print("You may also choose --dynamics=switching or --dynamics=varying")
    print("See --help for more help")

    kwargs = {}
    do_plotting = False
    for opt, arg in opts:
        if opt == "--no-rotation":
            kwargs["rotate"] = False
        elif opt == "--maxiter":
            kwargs["maxiter"] = int(arg)
        elif opt == "--debug":
            kwargs["debug"] = True
        elif opt == "--autosave":
            kwargs["autosave"] = True
        elif opt == "--precompute":
            kwargs["precompute"] = True
        elif opt == "--dynamics":
            kwargs["dynamics"] = arg
        elif opt == "--plot":
            do_plotting = True
        elif opt == "--no-monitor":
            kwargs["monitor"] = False
        elif opt == "--n":
            kwargs["N"] = int(arg)
        elif opt == "--d":
            kwargs["D"] = int(arg)
        elif opt == "--k":
            kwargs["K"] = int(arg)
        elif opt == "--seed":
            kwargs["seed"] = int(arg)
        else:
            raise ValueError("Unhandled argument given")

    if do_plotting:
        # plot() accepts only the data/model size and seed options; forward
        # those so the regenerated signal matches the saved results.
        plot_kwargs = {key: kwargs[key] for key in ("N", "D", "K", "seed")
                       if key in kwargs}
        plot(**plot_kwargs)
    else:
        run(**kwargs)
