############################################################################
# Copyright (C) 2013 Jaakko Luttinen (jaakko.luttinen@aalto.fi)
#
# This file is licensed under Version 3.0 of the GNU General Public License.
############################################################################

import numpy as np
import scipy
import matplotlib.pyplot as plt

import h5py

from functools import reduce

# Row indices of stations flagged as badly corrupted in the raw data set.
_CORRUPTED_STATIONS = (24, 31, 45, 49, 66, 99, 100, 101, 110, 135, 141, 159)

# Accepted coordinate ranges (exclusive bounds) for columns 0 and 1 of the
# coordinate array.  NOTE(review): presumably longitude/latitude in degrees --
# confirm against the data set.
_COORD0_RANGE = (22.5, 26.8)
_COORD1_RANGE = (59.7, 61.0)


def _station_mask(y, coordinates, remove_empty, remove_altitudes,
                  remove_unlocated, remove_corrupted):
    """Return a boolean mask over the first axis of `y` selecting kept stations."""
    M = np.shape(y)[0]
    # Use the builtin `bool`: the `np.bool` alias was removed in NumPy 1.24.
    mask = np.ones(M, dtype=bool)

    if remove_corrupted:
        # Remove badly corrupted stations.
        # Further candidates based on LSSM results (currently NOT removed):
        # 7, 13, 16, 26, 103, 136, 188
        mask[list(_CORRUPTED_STATIONS)] = False

    if remove_empty:
        # Remove stations with no observations (all values NaN along the
        # last axis).
        mask = np.logical_and(mask, ~np.all(np.isnan(y), axis=-1))

    if remove_altitudes:
        # Use only the lowest altitude from each location. Assume that the
        # stations are sorted by location and lower altitude is before the
        # higher.  A station is kept only if its first two coordinates differ
        # from those of the preceding station; the first station is always
        # kept.
        moved = np.diff(coordinates[:, :2], axis=0)
        first_at_location = np.ones(M, dtype=bool)
        first_at_location[1:] = (moved != 0).any(axis=1)
        mask = np.logical_and(mask, first_at_location)

    if remove_unlocated:
        # Remove stations with invalid coordinates, i.e. outside the
        # accepted bounding box.
        c0 = coordinates[:, 0]
        c1 = coordinates[:, 1]
        inside = ((c0 > _COORD0_RANGE[0]) & (c0 < _COORD0_RANGE[1]) &
                  (c1 > _COORD1_RANGE[0]) & (c1 < _COORD1_RANGE[1]))
        mask = np.logical_and(mask, inside)

    return mask


def load_data(remove_empty=True,
              remove_altitudes=True,
              remove_unlocated=True,
              remove_corrupted=True,
              filename='testbed_temperature.hdf5'):
    """
    Load the testbed temperature data and filter out unwanted stations.

    Reads the datasets 'temperature', 'coordinates' and 'time' from an HDF5
    file and drops rows (stations) of the temperature and coordinate arrays
    according to the enabled filters.

    Parameters
    ----------
    remove_empty : bool
        Drop stations whose observations are all NaN.
    remove_altitudes : bool
        Keep only the first station of each run of consecutive stations that
        share the same first two coordinates (assumed to be the lowest
        altitude at that location).
    remove_unlocated : bool
        Drop stations whose first two coordinates fall outside the accepted
        bounding box.
    remove_corrupted : bool
        Drop a fixed list of stations known to be badly corrupted.
    filename : str
        Path of the HDF5 file to read.

    Returns
    -------
    y : ndarray
        Temperature observations of the kept stations (stations on the first
        axis).
    time : ndarray
        The 'time' dataset, returned unfiltered.
    coordinates : ndarray
        Coordinates of the kept stations.
    index : ndarray
        Row indices of the kept stations in the original data set.
    """

    #
    # Load temperature data
    #
    # `[...]` reads each dataset fully into memory, so the file can be closed
    # as soon as the block exits.
    with h5py.File(filename, 'r') as data:
        y = data['temperature'][...]
        coordinates = data['coordinates'][...]
        time = data['time'][...]

    #
    # Preprocess
    #
    mask = _station_mask(y, coordinates,
                         remove_empty=remove_empty,
                         remove_altitudes=remove_altitudes,
                         remove_unlocated=remove_unlocated,
                         remove_corrupted=remove_corrupted)

    # Original row numbers of the stations that survive the filtering.
    index = np.flatnonzero(mask)
    y = y[mask]
    coordinates = coordinates[mask]

    return (y, time, coordinates, index)
