################################################################################
# This file is a part of the following submission:
#   "Linear state-space model with time-varying dynamics"
#   Luttinen, Raiko, Ilin (ECML 2014)
#
# Copyright (C) 2013-2014 Jaakko Luttinen
#
# This file is licensed under Version 3.0 of the GNU General Public
# License. See LICENSE for a text of the license.
################################################################################

"""
An experiment with stochastic advection-diffusion process.

The observed process is a stochastic advection-diffusion spatio-temporal
process. The velocity field changes in time so the classical linear state-space
model with constant dynamics is not able to learn the process accurately.  The
linear state-space model with switching dynamics is also unable to learn the
dynamics well.  The proposed linear state-space model with time-varying dynamics
is able to capture the dynamics almost perfectly and to find latent processes
that affect the dynamics.
"""

import datetime

import numpy as np
import scipy
import scipy.sparse
import scipy.sparse.linalg
import matplotlib.pyplot as plt

from bayespy.utils import misc
from bayespy.utils import random

import bayespy.plot as bpplt

from bayespy.utils.covfunc.covariance import covfunc_se as covfunc

from bayespy.demos import (lssm,
                           lssm_sd,
                           lssm_tvd)


def simulate_process(M=100, N=100, T=100, velocity=1e-3, diffusion=1e-5,
                     lengthscale=0.6, noise=1e0, decay=0.9995, verbose=True):
    """
    Simulate stochastic advection-diffusion PDE on a unit torus.

    The boundaries of the unit square are cylindrical, that is, the domain is a
    torus as a two-dimensional manifold (see the NOTE on the x boundary in the
    code: the wrap along the second axis is helical).

    The state lives on an M x N grid.  Every time step (dt = 1/T) solves the
    implicit (backward Euler) finite-difference system

        u_t - dt*D*laplace(u_t) + dt*v_t.grad(u_t) = u_{t-1} + R_t

    with central differences, where D = `diffusion`, v_t is a spatially
    constant velocity following an AR(1) process with coefficient `decay` and
    stationary scale `velocity`, and R_t is spatially correlated Gaussian
    innovation noise (``covfunc_se`` covariance on periodic coordinates,
    scaled by `noise`).

    Parameters
    ----------
    M, N : int
        Grid size along the first and second array axis.
    T : int
        Number of time steps; the step length is 1/T.
    velocity : float
        Stationary scale of the velocity field.
    diffusion : float
        Diffusion coefficient D.
    lengthscale : float
        Lengthscale of the innovation noise covariance.
    noise : float
        Multiplier of the innovation noise.
    decay : float
        AR(1) coefficient of the velocity between consecutive steps.
    verbose : bool
        Print the progress.

    Returns
    -------
    U : ndarray, shape (T, M, N)
        The simulated field after every time step.
    """

    # Grid spacings.  In the raveled state the flat index is k = m*N + n, so
    # the second axis (length N) is the x direction (neighbours k +/- 1) and
    # the first axis (length M) is the y direction (neighbours k +/- N).
    dx = 1.0/N
    dy = 1.0/M
    dt = 1.0/T

    # Total number of spatial grid points
    MN = M*N

    def periodic_cholesky(n):
        # Cholesky factor of the innovation covariance along an axis with n
        # points.  The points are mapped onto the unit circle so that the
        # covariance is periodic; the small jitter keeps the factorization
        # numerically stable.
        h = np.arange(n) / n
        z = np.array([np.sin(2*np.pi*h), np.cos(2*np.pi*h)]).T
        return np.linalg.cholesky(covfunc(1.0, lengthscale, z, z)
                                  + 1e-6*np.identity(n))

    # Separable innovation covariance: Lm acts along the first axis (length
    # M) and Ln along the second axis (length N).
    Lm = periodic_cholesky(M)
    Ln = periodic_cholesky(N)

    def draw_R():
        # One spatially correlated innovation, raveled to length M*N
        return noise * np.ravel(np.dot(Lm,
                                       np.dot(np.random.randn(M, N), Ln.T)))

    # Diffusion
    D = diffusion

    # This innovation is never used (the loop draws a fresh one before every
    # step), but the draw advances the random number generator, so it is kept
    # to produce the same data for a given seed as before.
    draw_R()

    # Initial state: a smooth periodic pattern
    (x, y) = np.meshgrid(np.arange(N), np.arange(M))
    u = np.ravel(np.sin(3*2*np.pi*x/N) * np.sin(3*2*np.pi*y/M))

    # Allocate memory for the process
    U = np.empty((T, M, N))

    # Initial velocity field
    v = velocity*np.random.randn(2)

    # Sparsity pattern of the system matrix.  Row k has five entries: the
    # point itself and its neighbours k+1, k+N, k-1, k-N (modulo M*N).  Each
    # entry is tagged with a code 1..5 so that the positions of the five
    # coefficients inside A.data can be located once and updated in place.
    # Nonzero codes are used so that no entry can be lost as an explicit zero.
    # The index arrays are integers and the shape is explicit: SciPy infers
    # the shape with operator.index(), which rejects float indices.
    #
    # NOTE(review): wrapping k+1 / k-1 modulo M*N connects the end of one row
    # to the start of the *next* row, so the x boundary is helical (a sheared
    # torus) rather than a plain torus.  Kept as is for reproducibility.
    k = np.arange(MN)
    offsets = (0, 1, N, -1, -N)
    i = np.tile(k, len(offsets))
    j = np.concatenate([np.mod(k + offset, MN) for offset in offsets])
    codes = np.repeat(np.arange(1.0, len(offsets) + 1.0), MN)
    A = scipy.sparse.csc_matrix((codes, (i, j)), shape=(MN, MN))
    (ind0, ind1, ind2, ind3, ind4) = [A.data == code
                                      for code in range(1, len(offsets) + 1)]

    # Velocity-independent diffusion coefficients; the diagonal never changes
    cx = D*dt/dx**2
    cy = D*dt/dy**2
    A.data[ind0] = 1 + 2*cx + 2*cy

    for t in range(T):

        # Velocity follows a stationary AR(1) process
        v = decay*v + np.sqrt(1-decay**2)*velocity*np.random.randn(2)

        # Spatial innovation noise
        R = draw_R()

        # Off-diagonal coefficients: diffusion plus central-difference
        # advection for the current velocity
        ax = v[0]*dt/(2*dx)
        ay = v[1]*dt/(2*dy)
        A.data[ind1] = -(cx - ax)
        A.data[ind2] = -(cy - ay)
        A.data[ind3] = -(cx + ax)
        A.data[ind4] = -(cy + ay)

        # Implicit time step
        u = scipy.sparse.linalg.spsolve(A, u + R)

        # Store the solution
        U[t] = np.reshape(u, (M, N))

        if verbose:
            print('\rSimulating SPDE data... %d%%' % (int(100.0*(t+1)/T)),
                  end="", flush=True)

    if verbose:
        print('\rSimulating SPDE data... Done.')

    return U


def simulate_data(filename=None,
                  resolution=30,
                  M=50,
                  N=1000,
                  diffusion=1e-6,
                  velocity=4e-3,
                  innovation_noise=1e-3,
                  innovation_lengthscale=0.6,
                  noise_ratio=1e-1,
                  decay=0.9997,
                  burnin=1000,
                  verbose=True,
                  thin=20):
    """
    Generate data for the experiment.

    First, simulate stochastic advection-diffusion partial differential
    equation.  Second, discard the beginning of the simulation (burn-in) and
    thin the time series.  Also, corrupt the observations with Gaussian noise.

    Parameters
    ----------
    filename : ignored
        Not used; kept for backward compatibility.
    resolution : int
        The process is simulated on a resolution x resolution grid.
    M : int
        Number of observation stations, placed on distinct random grid points.
    N : int
        Number of time instances returned after burn-in and thinning.
    diffusion, velocity, innovation_noise, innovation_lengthscale, decay
        Process parameters, rescaled below before being passed on to
        `simulate_process`.
    noise_ratio : float
        Observation noise standard deviation relative to the standard
        deviation of the noiseless observations.
    burnin : int
        Number of simulation steps discarded from the beginning.
    verbose : bool
        Print the progress of the simulation.
    thin : int
        Keep every `thin`-th simulation step after the burn-in.

    Returns
    -------
    U : ndarray, shape (N, resolution, resolution)
        The thinned process after the burn-in.
    Y : ndarray, shape (M, N)
        Noisy observations at the stations.
    F : ndarray, shape (M, N)
        Noiseless values of the process at the stations.
    X : ndarray, shape (M, 2)
        Grid indices of the stations: column 0 indexes axis 1 of U and
        column 1 indexes axis 2 of U.

    Raises
    ------
    ValueError
        If more stations are requested than there are grid points.
    """

    # Checked before the (expensive) simulation: with too few grid points the
    # station selection below would silently return fewer than M stations.
    if M > resolution*resolution:
        raise ValueError("Cannot place %d stations on a %dx%d grid"
                         % (M, resolution, resolution))

    # Simulate the process

    # Because simulate_process simulates a unit square, its time parameter only
    # changes the time resolution but the time length is always 1.  Thus, in
    # order to get the effect of a long time period, we need to multiply
    # parameters by the time length.  However, let the thinning parameter
    # control the temporal resolution.
    time_length = N + burnin/thin
    diffusion *= time_length
    velocity *= time_length
    # Per simulation step, so that the decay per thinned step equals `decay`
    decay = decay ** (1/thin)
    innovation_noise *= np.sqrt(N)
    U = simulate_process(resolution,
                         resolution,
                         burnin+N*thin,
                         diffusion=diffusion,
                         velocity=velocity,
                         noise=innovation_noise,
                         lengthscale=innovation_lengthscale,
                         verbose=verbose,
                         decay=decay)

    # Put some stations randomly on distinct grid points
    stations = np.random.permutation(resolution*resolution)[:M]
    (x2, x1) = np.divmod(stations, resolution)

    # Get noisy observations
    U = U[burnin::thin]
    F = U[:, x1, x2].T
    std = np.std(F)
    Y = F + noise_ratio*std*np.random.randn(*np.shape(F))
    X = np.array([x1, x2]).T

    return (U, Y, F, X)


def plot(M=100, N=2000, D=30, K=5,
         filename_tvd='results_ecml2014/demo_dlssm02_seed=5_drift-A=1_drift-C=0_D=30_K=5_date=20140412012554.hdf5',
    #filename_tvd='results_ecml2014/demo_dlssm02_seed=1_drift-A=1_drift-C=0_D=30_K=5_date=20140412012551.hdf5',
         filename_sd='results_ecml2014/demo_dlssm02_seed=5_switch_D=30_K=5_date=20140412012616.hdf5'):
    #filename_sd='results_ecml2014/demo_dlssm02_seed=2_switch_D=30_K=5_date=20140412012603.hdf5'):
    #filename_sd='results_ecml2014/demo_dlssm02_seed=3_switch_D=30_K=5_date=20140412012540.hdf5'):
    #filename_sd='results_ecml2014/demo_dlssm02_seed=1_switch_D=30_K=5_date=20140412012519.hdf5'):
    """
    Plot the latent processes of saved results as publication figures.

    Two figures are written to the working directory:

    * ``fig_spde_s.pdf``: q(S) of the LSSM with time-varying dynamics, loaded
      from `filename_tvd`.
    * ``fig_spde_z.pdf``: the switching-state probabilities (presumably
      p(z_t = k)) of the LSSM with switching dynamics, loaded from
      `filename_sd`.

    M, N, D and K are used to construct the models before the saved state is
    loaded into them, so presumably they must match the saved results.  The
    subplot layout of the first figure is hard-coded for five S components,
    and the x ticks of both figures assume N = 2000.  The commented-out
    filenames are alternative result files from other seeds.
    """

    # Figure size in inches (80% of 6 x 5)
    figw = 6 * 0.8
    figh = 5 * 0.8
    
    #
    # Plot S
    #
    
    # Build the time-varying-dynamics model and load the saved posterior
    Q = lssm_tvd.model(M, N, D, K)
    Q.load(filename=filename_tvd)

    plt.figure(figsize=(figw, figh))
    # The node draws its own stacked subplots; the loop below assumes there
    # are five rows (one per component) - TODO confirm for K != 5
    Q['S'].plot()

    # Clear the figure-level title
    plt.suptitle('')

    # Restyle the five stacked subplots: label each component, keep only the
    # zero y tick and set the y-range (fixed to (-3, 3) for components 2 and
    # 5, otherwise both current limits scaled by 1.2).  Subplots 2-5 share the
    # x-axis of the first one; the returned axes object is not used.
    for i in range(5):
        if i == 0:
            ax0 = plt.subplot(5,1,1)
        else:
            axes = plt.subplot(5,1,i+1, sharex=ax0)
        plt.ylabel('$s_%d(t)$' % (i+1))
        plt.yticks([0])
        lims = plt.ylim()
        if i == 1 or i == 4:
            plt.ylim((-3, 3))
        else:
            plt.ylim((1.2*lims[0], 1.2*lims[1]))

    # Time label and tick positions on the bottom subplot (ticks assume N=2000)
    ax = plt.subplot(5,1,5, sharex=ax0)
    plt.xlabel('time $t$')
    plt.xticks([0, 500, 1000, 1500, 2000])
    #ax.set_xticklabels([0, 500, 1000, 1500, 2000])

    plt.savefig('fig_spde_s.pdf', bbox_inches='tight')

    #
    # Plot switching model Z
    #
    
    # Build the switching-dynamics model and load the saved posterior
    Q = lssm_sd.model(M=M, N=N, D=D, K=K)
    Q.load(filename=filename_sd)

    plt.figure(figsize=(figw, figh))
    # NOTE(review): relies on bayespy internals - parent 3 of the state node X
    # and its private _message_to_child().  The first element is presumably a
    # (time, K) array of switching-state probabilities; verify when upgrading
    # bayespy.
    z = Q['X'].parents[3]._message_to_child()[0]

    # One subplot per state; only the bottom one keeps its x tick labels
    for i in range(K):
        plt.subplot(K,1,i+1)
        plt.plot(z[:,i], color='k')
        ## plt.fill_between(np.arange(N-1),
        ##                  0.0 + 0.5*z[:,i],
        ##                  0.5 - 0.5*z[:,i],
        ##                  facecolor=(0.3,0.3,0.3,1),
        ##                  edgecolor=(0,0,0,0),
        ##                  linewidth=0,
        ##                  interpolate=True)
        plt.yticks([0, 1])
        plt.ylim([-0.2, 1.2])
        plt.xticks([0, 500, 1000, 1500, 2000])
        
        plt.ylabel('$p(z_t=%d)$' % (i+1))
        if i < K-1:
            plt.setp(plt.gca().get_xticklabels(), visible=False)

    # Clear the figure-level title and stack the subplots without spacing
    plt.suptitle('')
    plt.subplots_adjust(hspace=0)


    ## for i in range(5):
    ##     if i == 0:
    ##         ax0 = plt.subplot(5,1,1)
    ##     else:
    ##         axes = plt.subplot(5,1,i+1, sharex=ax0)
    ##     plt.ylabel('$z_%d(t)$' % (i+1))
    ##     ## plt.yticks([0])
    ##     ## lims = plt.ylim()
    ##     ## if i == 2 or i == 3:
    ##     ##     plt.ylim((-3, 3))
    ##     ## else:
    ##     ##     plt.ylim((1.2*lims[0], 1.2*lims[1]))

    ## ax = plt.subplot(5,1,5, sharex=ax0)
    ## plt.xlabel('time $t$')
    ## ax.set_xticklabels([0, 500, 1000, 1500, 2000])
    # Applies to the current (bottom) subplot
    plt.xlabel('time $t$')

    plt.savefig('fig_spde_z.pdf', bbox_inches='tight')

    return


def _create_masks(M, N, gap=15, interval=100, p=0.8):
    """
    Create the observation masks (True = observed) for an M x N data matrix.

    Returns ``(mask_gaps, mask_random, mask)``:

    * ``mask_gaps`` is False in the gaps: `gap` consecutive time instances
      starting at every `interval`-th time instance.  The gaps never include
      the last time instance N-1.
    * ``mask_random`` equals ``random.mask(M, N, p=p)`` outside the gaps and is
      True inside the gaps, so its False entries mark only the randomly
      missing values.
    * ``mask`` is the conjunction of the two and is what the model is given.
    """
    mask_gaps = misc.trues((M, N))
    for start in range(interval, N, interval):
        mask_gaps[:, start:min(start + gap, N - 1)] = False
    mask_random = np.logical_or(random.mask(M, N, p=p),
                                np.logical_not(mask_gaps))
    mask = np.logical_and(mask_gaps, mask_random)
    return (mask_gaps, mask_random, mask)


def _autosave_filename(dynamics, seed, D, K):
    """
    Return a timestamped HDF5 filename under ``results/`` for a configuration.

    The constant-dynamics model has no K, so K is left out of that name.  The
    seed and K are formatted with ``%s`` so that ``None`` is accepted.
    """
    date = datetime.datetime.today().strftime('%Y%m%d%H%M%S')
    if dynamics == 'constant':
        return ('results/results_spde_seed=%s_dynamics=constant_D=%d_date=%s.hdf5'
                % (seed, D, date))
    return ('results/results_spde_seed=%s_dynamics=%s_D=%d_K=%s_date=%s.hdf5'
            % (seed, dynamics, D, K, date))


def run(M=100, N=2000, D=30, K=5, rotate=True, maxiter=200, seed=42,
        debug=False, autosave=False, precompute=False, resolution=30,
        dynamics='constant', animation=False, plot=True, lengthscale=1.0, innovation=1e-4,
        monitor=True, verbose=True):
    """
    Run the SPDE experiment for one linear state-space model variant.

    Simulates the advection-diffusion data, hides part of the observations
    (regular gaps and randomly missing values), fits the model selected by
    `dynamics` and prints the RMSE of the reconstruction of the noiseless
    process, separately for the randomly missing values and for the gaps.

    Parameters
    ----------
    M : int
        Number of observation stations (dimensionality of the data vectors).
    N : int
        Number of time instances.
    D : int
        Dimensionality of the latent state vectors.
    K : int or None
        Dimensionality of the latent mixing weights space (used by the
        switching and varying models).
    rotate, maxiter, debug, precompute, monitor
        Passed on to the inference routine (`precompute` is not used with
        switching dynamics).
    seed : int or None
        Seed for NumPy's global random number generator; None: do not seed.
    autosave : bool
        Save the VB results under ``results/`` with a timestamped filename.
    resolution : int
        Grid resolution of the SPDE simulation.
    dynamics : {'constant', 'switching', 'varying'}
        Which state-space model to fit.
    animation : bool
        Show and save an animation of the simulated process before inference.
    plot : bool
        Plot the reconstruction and the latent processes.
    lengthscale, innovation : float
        Lengthscale and magnitude of the spatial innovation noise.
    verbose : bool
        Print the progress of the simulation.

    Raises
    ------
    ValueError
        If `dynamics` is not one of the supported values.
    """

    # Checked up front so that a typo does not cost a full simulation
    if dynamics not in ('constant', 'switching', 'varying'):
        raise ValueError("Unknown dynamics requested")

    # Seed for random number generator
    if seed is not None:
        np.random.seed(seed)

    # Velocity field changes
    decay = 1.0 - 5e-3

    # Simulate data
    (U, y, f, X) = simulate_data(M=M,
                                 N=N,
                                 resolution=resolution,
                                 burnin=2000,
                                 thin=20,
                                 velocity=6e-2,
                                 diffusion=1e-4,
                                 decay=decay,
                                 innovation_noise=innovation,
                                 innovation_lengthscale=lengthscale,
                                 verbose=verbose,
                                 noise_ratio=5e-1)

    if animation:
        # Show the stations, save a video of the simulated process and show
        # it.  NOTE: this does not exit; the inference below still runs.
        plt.ion()
        plt.figure(figsize=(6,6))
        # NOTE(review): X[:, 0] indexes the image rows of U but is drawn on
        # the horizontal axis, so the markers may appear transposed relative
        # to imshow - verify.
        plt.plot(X[:,0], X[:,1], 'kx')
        plt.yticks([])
        plt.xticks([])
        plt.tight_layout()
        if seed == 1:
            # Generate a snapshot image for the publication (needs N > 1917)
            vmax = np.max(np.abs(U[1917]))
            plt.imshow(U[1917],
                       interpolation='nearest',
                       cmap='RdBu_r',
                       vmin=-vmax,
                       vmax=vmax)
            plt.savefig('fig_spde_snapshot.pdf', bbox_inches='tight')
            plt.ioff()

        # Save the process as a video animation
        plt.title('t = 0')
        plt.tight_layout()
        anim = bpplt.matrix_animation(U)
        bpplt.save_animation(anim, 'video_spde_%s.mp4' % seed)
        plt.ioff()

        plt.show()
        plt.ioff()

    # Hide observations: regular gaps plus randomly missing values
    (mask_gaps, mask_random, mask) = _create_masks(M, N)

    filename = _autosave_filename(dynamics, seed, D, K) if autosave else None

    # Run the method
    if dynamics == 'switching':
        Q = lssm_sd.infer(y, D, K,
                          mask=mask,
                          maxiter=maxiter,
                          rotate=rotate,
                          debug=debug,
                          update_hyper=5,
                          autosave=filename,
                          monitor=monitor)
    elif dynamics == 'varying':
        Q = lssm_tvd.infer(y, D, K,
                           mask=mask,
                           maxiter=maxiter,
                           rotate=rotate,
                           debug=debug,
                           precompute=precompute,
                           update_hyper=5,
                           start_rotating_weights=10,
                           autosave=filename,
                           monitor=monitor)
    else:
        Q = lssm.infer(y, D,
                       mask=mask,
                       maxiter=maxiter,
                       rotate=rotate,
                       debug=debug,
                       precompute=precompute,
                       update_hyper=5,
                       autosave=filename,
                       monitor=monitor)

    if plot:
        # Posterior of F against the true process and the visible observations
        ym = y.copy()
        ym[~mask] = np.nan
        plt.figure()
        bpplt.timeseries_normal(Q['F'], scale=2)
        bpplt.timeseries(f, linestyle='-', color='b')
        bpplt.timeseries(ym, linestyle='None', color='r', marker='.')

        # Plot latent space
        Q.plot('X')

        if dynamics == 'varying':
            Q.plot('S')

    # Compute RMSE against the noiseless process f, using the first moment of
    # Y (presumably the posterior predictive mean)
    Q['Y'].update()
    rmse_random = misc.rmse(Q['Y'].get_moments()[0][~mask_random],
                            f[~mask_random])
    rmse_gaps = misc.rmse(Q['Y'].get_moments()[0][~mask_gaps],
                          f[~mask_gaps])
    print("RMSE for randomly missing values: %f" % rmse_random)
    print("RMSE for gap values: %f" % rmse_gaps)

    plt.show()

    return


if __name__ == '__main__':
    import getopt
    import sys

    # Options that take a value: option -> (keyword argument of run(), converter)
    valued_options = {
        "--m": ("M", int),
        "--n": ("N", int),
        "--d": ("D", int),
        # --k=0 is passed on to run() as K=None
        "--k": ("K", lambda arg: int(arg) or None),
        "--dynamics": ("dynamics", str),
        "--lengthscale": ("lengthscale", float),
        "--innovation": ("innovation", float),
        "--resolution": ("resolution", int),
        "--seed": ("seed", int),
        "--maxiter": ("maxiter", int),
    }

    # Switches without a value: option -> (keyword argument of run(), value)
    switch_options = {
        "--debug": ("debug", True),
        "--animation": ("animation", True),
        "--precompute": ("precompute", True),
        "--no-plot": ("plot", False),
        "--no-monitor": ("monitor", False),
        "--no-verbose": ("verbose", False),
        "--no-rotation": ("rotate", False),
        "--autosave": ("autosave", True),
    }

    usage = [
        'python spde_experiment.py <options>',
        '--m=<INT>           Dimensionality of data vectors',
        '--n=<INT>           Number of data vectors',
        '--d=<INT>           Dimensionality of the latent vectors in the model',
        '--k=<INT>           Dimensionality of the latent mixing weights space (0: none)',
        '--dynamics=...      [constant] / switching / varying',
        '--resolution=<INT>  Grid resolution for the SPDE simulation',
        '--lengthscale=<FLT> Spatial innovation noise lengthscale',
        '--innovation=<FLT>  Magnitude of the spatial innovation noise',
        '--no-rotation       Do not apply speed-up rotations',
        '--maxiter=<INT>     Maximum number of VB iterations',
        '--seed=<INT>        Seed (integer) for the random number generator',
        '--autosave          Save the VB results automatically',
        '--debug             Check that the rotations are implemented correctly',
        '--animation         Show and save the SPDE process animation',
        '--no-plot           Do not plot the process and results',
        '--no-monitor        Do not monitor variables during learning',
        '--no-verbose        Do not print the progress of the simulation',
        '--precompute        Precompute some moments when rotating. May '
        'speed up or slow down.',
    ]

    # getopt expects the long option names without the leading dashes; a
    # trailing "=" marks an option that takes a value
    long_options = ([opt[2:] + "=" for opt in valued_options]
                    + [opt[2:] for opt in switch_options])
    try:
        opts, _ = getopt.getopt(sys.argv[1:], "", long_options)
    except getopt.GetoptError as err:
        print(err)
        print('\n'.join(usage))
        sys.exit(2)

    kwargs = {}
    for opt, arg in opts:
        if opt in valued_options:
            (key, convert) = valued_options[opt]
            kwargs[key] = convert(arg)
        elif opt in switch_options:
            (key, value) = switch_options[opt]
            kwargs[key] = value
        else:
            raise ValueError("Unhandled argument given")

    run(**kwargs)
