################################################################################
# This file is a part of the following submission:
#   "Linear state-space model with time-varying dynamics"
#   Luttinen, Raiko, Ilin (ECML 2014)
#
# Copyright (C) 2013-2014 Jaakko Luttinen
#
# This file is licensed under Version 3.0 of the GNU General Public
# License. See LICENSE for a text of the license.
################################################################################

"""
An experiment with NCDC's dataset GSOD using variants of the LSSM
"""

import numpy as np
import scipy
import matplotlib.pyplot as plt

import datetime

import bayespy.plot as bpplt

from bayespy.utils import misc
from bayespy.utils import random

import gsod

import sys, getopt

from bayespy.demos import lssm, lssm_sd, lssm_tvd


def get_data(seed=42, variable='temperature', remove_mean=False):
    """
    Return GSOD data from the European area in 2000-2009.

    Parameters
    ----------
    seed : int or None
        Seed given to ``np.random.seed`` before the test gaps and the
        randomly removed test values are drawn.
    variable : str
        GSOD variable to load, passed on to ``gsod.load_data``.
    remove_mean : bool
        If True, subtract each station's mean (``misc.mean`` along the last
        axis, which the original code notes ignores NaNs) from its series.

    Returns
    -------
    y : ndarray
        Data matrix, one row per selected station and one column per day.
    mask_gaps : ndarray of bool
        False inside the randomly drawn multi-day gaps, True elsewhere.
    mask_random : ndarray of bool
        False at values chosen by ``random.mask(M, N, p=0.8)``, except that
        every entry inside a gap is forced to True, so the two test sets do
        not overlap.
    time
        Time information as returned by ``gsod.load_data``.
    coordinates : ndarray
        Coordinates of the selected stations.
    """

    print("seed = ", seed)
    np.random.seed(seed)

    # Load Global Summary of Day dataset
    (y, coordinates, time) = gsod.load_data(begin_year=2000,
                                            end_year=2009,
                                            remove_unlocated=True,
                                            nan_threshold=0.2,
                                            variable=variable)

    # Pick approximately European stations (column 0 is presumably latitude
    # and column 1 longitude, judging by the bounds).
    c = coordinates
    ind = (c[:, 0] > 35) & (c[:, 0] < 72) & (c[:, 1] > -13) & (c[:, 1] < 33)
    y = y[ind, :]
    coordinates = coordinates[ind, :]

    (M, N) = np.shape(y)
    print("Dataset size:", M, N)

    if remove_mean:
        # Remove mean from each station. This ignores NaNs.
        y = y - misc.mean(y, axis=-1, keepdims=True)

    # Draw randomly 300 2-day gaps (these non-overlapping gaps must have at
    # least 10 days between them).
    n_gaps = 300
    gap_length = 2
    min_separation = 10
    gap_starts = random.intervals(N, gap_length,
                                  amount=n_gaps,
                                  gap=min_separation)

    # Mark these gaps missing (False = held out for testing).
    mask_gaps = misc.trues((M, N))
    for start in gap_starts:
        mask_gaps[:, start:start + gap_length] = False

    # Randomly missing values. Entries inside the gaps are set True here so
    # that they belong only to the gap test set.
    mask_random = np.logical_or(random.mask(M, N, p=0.8),
                                ~mask_gaps)

    return (y, mask_gaps, mask_random, time, coordinates)


def _results_filename(variable, dynamics, seed, D, K):
    """
    Return the HDF5 autosave path for one experiment run.

    ``seed`` and ``K`` are formatted with ``%s`` rather than ``%d`` so that
    ``None`` (accepted by :func:`run` for ``seed``, and produced for ``K`` by
    ``--k=0``) does not raise a TypeError.  For integer values the resulting
    name is identical to ``%d`` formatting.
    """
    date = datetime.datetime.today().strftime('%Y%m%d%H%M%S')
    if dynamics == 'constant':
        # The constant-dynamics model is not given K, so K is not in the name.
        return ('results/experiment=gsod_variable=%s_dynamics=constant'
                '_seed=%s_D=%d_date=%s.hdf5'
                % (variable, seed, D, date))
    return ('results/experiment=gsod_variable=%s_dynamics=%s'
            '_seed=%s_D=%d_K=%s_date=%s.hdf5'
            % (variable, dynamics, seed, D, K, date))


def run(D=80, K=6, rotate=True, maxiter=200, seed=42, debug=False,
        precompute=False, dynamics='constant', plot=True, autosave=False,
        monitor=True, variable='temperature'):
    """
    Run the experiment for one method.

    Loads the data via :func:`get_data`, fits the LSSM variant selected by
    ``dynamics`` on the training values and prints the RMSE of the
    predictions on the two held-out test sets (gaps and random values).

    Parameters
    ----------
    D : int
        Dimensionality of the latent state vectors.
    K : int or None
        Dimensionality of the mixing-weight space.  Only passed to the
        'switching' and 'varying' models.
    rotate, debug, precompute, monitor :
        Passed on to the ``infer`` function of the chosen model
        (``precompute`` is not passed to the 'switching' model).
    maxiter : int
        Maximum number of VB iterations.
    seed : int or None
        Random seed.  If None, the global NumPy RNG is not reseeded here.
    dynamics : {'constant', 'switching', 'varying'}
        Which model to fit.
    plot : bool
        Plot the posterior distributions after inference.
    autosave : bool
        If True, autosave results to an HDF5 file under ``results/``.
    variable : str
        GSOD variable to use.  Every variable except 'temperature' has its
        per-station mean removed.

    Raises
    ------
    ValueError
        If ``dynamics`` is not one of the supported values.
    """
    # Fail fast, before the (expensive) dataset load.
    if dynamics not in ('constant', 'switching', 'varying'):
        raise ValueError("Unknown dynamics requested")

    # Seed for random number generator
    if seed is not None:
        np.random.seed(seed)

    # Get data for the experiment
    (y, mask_gaps, mask_random, _, _) = get_data(
        seed=seed,
        variable=variable,
        remove_mean=(variable != 'temperature'))

    # Entries that are missing in the data cannot be test values: mark them
    # True (not held out) in both test masks.
    mask_data = ~np.isnan(y)
    mask_gaps = np.logical_or(mask_gaps, ~mask_data)
    mask_random = np.logical_or(mask_random, ~mask_data)
    # Training mask: observed and in neither test set.
    mask = np.logical_and(mask_data,
                          np.logical_and(mask_gaps, mask_random))
    y_true = y.copy()
    y[~mask] = np.nan  # BayesPy doesn't require NaNs, they're just for plotting.

    filename = (_results_filename(variable, dynamics, seed, D, K)
                if autosave else None)

    # Keyword arguments shared by all three inference functions.
    common = dict(mask=mask,
                  maxiter=maxiter,
                  rotate=rotate,
                  debug=debug,
                  autosave=filename,
                  plot_C=False,
                  monitor=monitor)

    # Run the method
    if dynamics == 'switching':
        Q = lssm_sd.infer(y, D, K,
                          update_hyper=10,
                          **common)
    elif dynamics == 'varying':
        Q = lssm_tvd.infer(y, D, K,
                           precompute=precompute,
                           update_hyper=10,
                           start_rotating_weights=30,
                           **common)
    else:  # 'constant' (validated above)
        Q = lssm.infer(y, D,
                       precompute=precompute,
                       update_hyper=5,
                       **common)

    if plot:
        # Plot posterior distributions
        Q.plot()
        plt.show()

    # Compute RMSE against the original values at the held-out entries,
    # using the first moment of Y (presumably the posterior mean).
    Q['Y'].update()
    y_pred = Q['Y'].get_moments()[0]
    rmse_gaps = misc.rmse(y_pred[~mask_gaps], y_true[~mask_gaps])
    rmse_random = misc.rmse(y_pred[~mask_random], y_true[~mask_random])
    print("RMSE for gap values: %f" % rmse_gaps)
    print("RMSE for random values: %f" % rmse_random)

    return


if __name__ == '__main__':

    def _print_usage():
        """Print the command-line options understood by this script."""
        print('python gsod_experiment.py <options>')
        print('--d=<INT>           Dimensionality of the latent vectors in the model')
        print('--k=<INT>           Dimensionality of the latent mixing weights space '
              '(0 gives K=None)')
        print('--dynamics=...      [constant] / switching / varying')
        print('--variable=...      temperature / pressure / seapressure')
        print('--no-rotation       Do not apply speed-up rotations')
        print('--no-monitor        Do not plot distributions during VB learning')
        print('--autosave          Auto save results during iteration')
        print('--maxiter=<INT>     Maximum number of VB iterations')
        print('--seed=<INT>        Seed (integer) for the random number generator')
        print('--debug             Check that the rotations are implemented correctly')
        print('--no-plot           Do not plot stuff')
        print('--precompute        Precompute some moments when rotating. May '
              'speed up or slow down.')
        print('--help              Show this message and exit')

    # sys and getopt come from the module-level imports.
    try:
        opts, _ = getopt.getopt(sys.argv[1:],
                                "",
                                [
                                    "d=",
                                    "k=",
                                    "dynamics=",
                                    "variable=",
                                    "seed=",
                                    "maxiter=",
                                    "autosave",
                                    "debug",
                                    "precompute",
                                    "no-plot",
                                    "no-monitor",
                                    "no-rotation",
                                    "help",
                                ])
    except getopt.GetoptError as err:
        # Tell the user which option was rejected before showing the usage.
        print(err)
        _print_usage()
        sys.exit(2)

    # Argument-less switches: option -> (keyword of run(), value to set).
    switches = {
        "--no-rotation": ("rotate", False),
        "--debug": ("debug", True),
        "--no-monitor": ("monitor", False),
        "--autosave": ("autosave", True),
        "--precompute": ("precompute", True),
        "--no-plot": ("plot", False),
    }

    kwargs = {}
    for opt, arg in opts:
        if opt == "--help":
            _print_usage()
            sys.exit(0)
        elif opt in switches:
            key, value = switches[opt]
            kwargs[key] = value
        elif opt == "--maxiter":
            kwargs["maxiter"] = int(arg)
        elif opt == "--variable":
            kwargs["variable"] = arg
        elif opt == "--seed":
            kwargs["seed"] = int(arg)
        elif opt == "--d":
            kwargs["D"] = int(arg)
        elif opt == "--k":
            # --k=0 is a sentinel for K=None.
            k = int(arg)
            kwargs["K"] = None if k == 0 else k
        elif opt == "--dynamics":
            kwargs["dynamics"] = arg
        else:
            raise ValueError("Unhandled argument given")

    run(**kwargs)

