import numpy as np
from scipy import signal,stats
import xarray as xr
from sklearn import linear_model, gaussian_process
from sklearn.gaussian_process.kernels import RBF, WhiteKernel
from joblib import Parallel, delayed, parallel_backend
#from joblib import load, dump
import tempfile
import shutil
import os
#
import sys
sys.path.append('pyunicorn_timeseries')
from pyunicorn_timeseries.surrogates import Surrogates

def GPR_simple_predict(gp,x_pred):
    '''Predict with a fitted GaussianProcessRegressor; return the predictive mean and standard deviation.'''
    y_pred, sigma = gp.predict(x_pred, return_std=True)
    return y_pred, sigma

def GPR_simple_fit(lsc_min,lsc_max,x_train,y_train):
    '''Fit a GaussianProcessRegressor with an RBF + white-noise kernel;
    lsc_min and lsc_max bound the RBF length scale.'''
    kernel = (
          1 * RBF(length_scale=1,length_scale_bounds=(lsc_min, lsc_max))
        + WhiteKernel(noise_level=1)
        )
    gp = gaussian_process.GaussianProcessRegressor(kernel=kernel)
    gp.fit(x_train,y_train)
    #print(gp.kernel_)
    return gp
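
# A minimal usage sketch of the two GPR helpers above, with synthetic 1-D data
# (function and variable names below are illustrative only).
def _example_gpr_usage():
    rng = np.random.default_rng(0)
    x_train = np.linspace(0, 10, 50)[:, np.newaxis]
    y_train = np.sin(x_train).ravel() + 0.1*rng.standard_normal(50)
    gp = GPR_simple_fit(0.1, 10.0, x_train, y_train)
    x_pred = np.linspace(0, 10, 200)[:, np.newaxis]
    y_pred, sigma = GPR_simple_predict(gp, x_pred)
    return y_pred, sigma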

def combine_reconstructed(Y,X,nens=12,ensemble_length=0,return_slopes=False,percentiles=[5,25,50,75,95]):
    '''
    Build a multiple linear regression model to estimate Y given X.
    Here X is assumed to be a reconstruction of Y obtained with the
    Green's function approach.
    
    X    : numpy.array of shape [time, predictors, maxlags]
    Y    : numpy.array of shape [time]
    nens : integer, default is 12; number of interleaved sub-series
           (e.g. one per calendar month) if one provides the full timeseries
    ensemble_length : integer; if > 0, avoid overfitting by breaking the
           timeseries into overlapping segments of length ensemble_length
           and returning the given percentiles over the segment ensemble.
           NOTE: the ensemble generation could be randomized, taking only
           the number of ensemble members as input; a check that the number
           of members is reasonable given the length of the timeseries
           would then be useful.
    return_slopes : boolean; if True, also return the regression coefficients.
    percentiles   : percentiles returned when ensemble_length > 0.
    
    Returns Y_combined of shape [time, maxlags] (or [time, len(percentiles), maxlags]
    if ensemble_length > 0), and optionally the regression slopes.
    '''
    
    #X=icevar_reconstructed[model+'_'+case+'_'+ens+'_'+icevar][:,:-1,-1]
    #y=icevar_anom[model+'_'+case+'_'+ens+'_'+icevar]
    #lm = linear_model.LinearRegression()
    #model = lm.fit(X,y)
    #predictions = lm.predict(X)
    ntime = Y.shape[0]
    if ensemble_length>0:
        # this will lead to 75% overlap between consecutive segments
        ens_inds = np.arange(0,ntime//nens-ensemble_length,ensemble_length//4).astype(int)
        if return_slopes:
            slopes = np.zeros((len(ens_inds),nens,X.shape[1],X.shape[2]))
        #
        Y_combined = np.zeros((Y.shape[0],len(percentiles),X.shape[-1]))
    else:
        if return_slopes:
            slopes = np.zeros((nens,X.shape[1],X.shape[2]))
        #
        Y_combined = np.zeros((Y.shape[0],X.shape[-1]))
    for j in range(X.shape[-1]):
        for m in range(nens):
            lm = linear_model.LinearRegression()
            if ensemble_length==0:
                linearmodel = lm.fit(X[m::nens,:,j],Y[m::nens])
                Y_combined[m::nens,j] = linearmodel.predict(X[m::nens,:,j])
                if return_slopes:
                    slopes[m,:,j] = linearmodel.coef_
            else:
                dumx = X[m::nens,:,j]
                dumy = Y[m::nens]
                yout = np.zeros((len(ens_inds),ntime//nens))
                for e,ei in enumerate(ens_inds):
                    linearmodel = lm.fit(dumx[ei:ei+ensemble_length,],dumy[ei:ei+ensemble_length])
                    yout[e,:]   = linearmodel.predict(dumx)
                    if return_slopes:
                        slopes[e,m,:,j] = linearmodel.coef_
                #
                Y_combined[m::nens,:,j] = np.nanpercentile(yout,percentiles,axis=0).T
            #B=np.dot(Y[m::12], np.linalg.pinv(X[m::12,:,j].T))
            #Y_combined[m::12,j] = np.dot(B,X[m::12,:,j].T)
    #
    if return_slopes:
        if ensemble_length>0:
            slopes = np.nanmedian(slopes,axis=0).squeeze()
        return Y_combined, slopes
    else:
        return Y_combined
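
# A minimal usage sketch of combine_reconstructed with synthetic data
# (shapes follow the docstring above; names are illustrative only).
def _example_combine_reconstructed():
    rng = np.random.default_rng(0)
    ntime, npred, nlags = 240, 2, 3  # 20 years of monthly data
    Y = rng.standard_normal(ntime)
    # pretend reconstructions: the target plus noise, one per predictor and maxlag
    X = Y[:, np.newaxis, np.newaxis] + 0.5*rng.standard_normal((ntime, npred, nlags))
    # one linear model per calendar month (nens=12), no sub-ensembling
    Y_combined = combine_reconstructed(Y, X, nens=12, ensemble_length=0)
    return Y_combined  # shape (ntime, nlags)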

def Y_from_X_and_G(X,G,minlag=0,maxlags=np.arange(5*12,10*12,12)):
    '''
    Reconstruct the timeseries of Y from the response function G and the timeseries of the predictor(s) X.
    
    X:       numpy.array, a timeseries of the predictors.
             Should be at least of size (ntime, 1)
    G:       numpy.array of shape (n_predictors, len(maxlags), max(maxlags), 12),
             the monthly response functions.
    minlag:  int (default=0), possibility to limit the shortest allowed lag,
             useful for assessing predictability
    maxlags: numpy.array of the longest lags allowed.
    '''
    ntime, nx = X.shape
    Y_reconstructed = np.ones((ntime,nx,len(maxlags)))*np.nan
    # loop over all the predictors
    for p,x0 in enumerate(X.T): #loop over sections
        #print(p)
        for ml,maxlag in enumerate(maxlags): #loop over maxlags - we will make an ensemble
            for m in range(12): #loop over months
                # here we construct an array with entries that start from a given time t and extend maxlags months prior to that
                #varX = np.zeros((ntime//12,maxlag))
                #for t in range(minlag,ntime//12):
                #    tinds = np.arange(t*12+m-maxlag+1,t*12+m+1).astype(np.int)[::-1]
                #    varX[t,:] = x0[tinds]
                #    varX[t,tinds[np.where(tinds<0)]]=0
                #
                # convolution
                #Y_reconstructed[m::12,p,ml] = np.dot(G[p,ml,:maxlag,m], varX.T)
                G0=G[p,ml,:maxlag,m].copy()
                G0[:minlag]=0 # zero out the shortest lags (of limited use in practice)
                #Y_reconstructed[m::12,p,ml] = np.convolve(x0,G0[::-1],mode='same')[m::12]
                # G is specific to each month m and assumes that lag 0 corresponds to month m
                Y_reconstructed[m::12,p,ml] = np.convolve(x0[m:],G0[::-1],mode='same')[::12]
                #Y_reconstructed[m::12,p,ml] = np.convolve(x0,G[p,ml,:maxlag,m][::-1],mode='same')[m::12]
    #
    return Y_reconstructed
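
# A minimal usage sketch of Y_from_X_and_G with a random response function,
# mainly to illustrate the expected shapes (names are illustrative only).
def _example_Y_from_X_and_G():
    rng = np.random.default_rng(0)
    ntime, nx = 240, 2
    maxlags = np.array([24, 36])
    X = rng.standard_normal((ntime, nx))
    # G must be shaped (n_predictors, len(maxlags), max(maxlags), 12 months)
    G = rng.standard_normal((nx, len(maxlags), maxlags.max(), 12))
    Y_rec = Y_from_X_and_G(X, G, maxlags=maxlags)
    return Y_rec  # shape (ntime, nx, len(maxlags))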

def crf_monthly(X,Y,minlag=0,maxlags=np.arange(5*12,10*12,12),filtered=False):
    '''
    Calculate the CRF given X and Y. Note that this will calculate the
    monthly CRFs and form an ensemble over a set of maxlags.
    
    Y:       xarray.DataArray, a monthly timeseries of the target variable
    X:       numpy.array, a timeseries of the predictors.
             Should be at least of size (ntime, 1)
    minlag:  int (default=0), possibility to limit the shortest allowed lag,
             useful for assessing predictability
    maxlags: numpy.array of the longest lags allowed.
    filtered: boolean (default False); if True, highpass-filter the predictor
              with a cutoff period of maxlag months before the fit.
    '''
    ntime, nx = X.shape
    #
    varY = np.zeros(Y.shape)
    for m in range(12):
        varY[m::12]=signal.detrend(Y.values[m::12])
    #
    # initialize
    G = np.ones((nx,len(maxlags),max(maxlags),12))*np.nan
    Gstep = np.ones((nx,len(maxlags),max(maxlags),12))*np.nan
    # loop over all the predictors
    for p,x0 in enumerate(X.T): 
        print(p)
        # loop over maxlags - we will make an ensemble
        for ml,maxlag in enumerate(maxlags):
            if filtered:
                b, a = signal.butter(3, 1.0/maxlag,'highpass')
                x1   = signal.filtfilt(b,a,x0)
            else:
                x1 = x0
            # loop over months
            for m in range(12):
                # here we construct a matrix whose rows start from a given time t and extend maxlag months back in time
                varX = np.zeros((ntime//12,maxlag))
                for t in range(minlag,ntime//12):
                    tinds = np.arange(t*12+m-maxlag+1,t*12+m+1).astype(int)[::-1] # ordered so that lag 0 comes first
                    varX[t,:] = x1[tinds]
                    varX[t,tinds<0] = 0 # set to 0 where the index would reach back before the start of the timeseries
                #
                # solve the linear system
                #
                G[p,ml,:maxlag,m] = np.dot(varY[m::12], np.linalg.pinv(varX.T))
                Gstep[p,ml,:maxlag,m] = np.cumsum(G[p,ml,:maxlag,m])
    #
    Y_reconstructed = Y_from_X_and_G(X,G,minlag=0,maxlags=maxlags)
    Y_combined = combine_reconstructed(varY,Y_reconstructed)
    #
    return G, Gstep, Y_reconstructed, Y_combined
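
# A minimal usage sketch of crf_monthly with synthetic data; short maxlags are
# used to keep the example light (names are illustrative only).
def _example_crf_monthly():
    rng = np.random.default_rng(0)
    ntime = 480  # 40 years of monthly data
    X = rng.standard_normal((ntime, 1))
    Y = xr.DataArray(rng.standard_normal(ntime), dims=('time',))
    G, Gstep, Y_rec, Y_comb = crf_monthly(X, Y, maxlags=np.arange(2*12, 4*12, 12))
    return G  # shape (1, len(maxlags), max(maxlags), 12)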

def crf_monthly2(X,Y,minlag=0,maxlags=np.arange(5*12,10*12,12),percentiles=[5,25,50,75,95],filtered=False,filter_length=0,verbose=1,normalized=True):
    '''
    Calculate the CRF given X and Y. Note that this will calculate the
    monthly CRFs and form an ensemble over a set of maxlags.
    
    Y:       xarray.DataArray, a monthly timeseries of the target variable
    X:       numpy.array, a timeseries of the predictors.
             Should be at least of size (ntime, 1)
    minlag:  int (default=0), possibility to limit the shortest allowed lag,
             useful for assessing predictability
    maxlags: numpy.array of the longest lags allowed.
    percentiles: percentiles at which the response function G will be returned
    filtered: boolean, defaults to False; if True, the detrended target is
              highpass-filtered with a cutoff period of filter_length months.
    filter_length: float, filter length (cutoff period) in months.
    normalized: boolean, defaults to True; if True, normalize the (filtered)
                target by its monthly standard deviation.
    verbose: print progress if nonzero.
    '''
    ntime, nx = X.shape
    #
    varY = signal.detrend(np.reshape(Y.values,(-1,12)),axis=0).flatten()
    if filtered:
        sos3 = signal.butter(3, 1.0/filter_length,'highpass',output='sos')
        y1   = signal.sosfiltfilt(sos3,varY)
    else:
        y1   = varY
    #
    if normalized:
        y1 = (np.reshape(y1,(-1,12))/np.nanstd(np.reshape(y1,(-1,12)),axis=0)).flatten()
    #
    # initialize
    G = np.ones((nx,len(maxlags),max(maxlags),12,len(percentiles)))*np.nan
    Gstep = np.ones((nx,len(maxlags),max(maxlags),12,len(percentiles)))*np.nan
    # loop over all the predictors
    for p,x0 in enumerate(X.T):
        if verbose:
            print('predictor ',p)
        # loop over months - each month is treated as one variable
        for m in range(12):
            #print(m)
            # loop over maxlags 
            for ml,maxlag in enumerate(maxlags):
                #if filtered:
                #    #sos3 = signal.butter(3, 1.0/maxlag,'highpass',output='sos')
                #    sos3 = signal.butter(3, 1.0/max(filter_length,maxlag),'highpass',output='sos')
                #    x1   = signal.sosfiltfilt(sos3,x0)
                #    y1   = signal.sosfiltfilt(sos3,varY)
                #else:
                #y1 = varY
                x1 = x0
                #
                # split up the timeseries into overlapping segments that are maxlag in length  
                #overlapinds = np.arange(maxlag/12-maxlag/24,ntime/12-maxlag/24,max(1,int(maxlag/24))).astype(np.int) #what if we start here at maxlag instead
                overlapinds = np.arange(maxlag/12,ntime/12-maxlag/12,max(1,int(maxlag/24))).astype(int)
                Gdum = np.ones((maxlag,len(overlapinds)))*np.nan
                for ii,oi in enumerate(overlapinds):
                    # for each time t in this segment, build a column that extends maxlag months back in time
                    varX  = np.zeros((maxlag,maxlag//12))
                    jinds = range(oi,oi+maxlag//12) #this is correct, ensures that overlap is maxlag/2
                    for tt, t in enumerate(jinds):
                        tinds = np.arange(t*12+m+1-maxlag,t*12+m+1).astype(int)[::-1] # ordered so that lag 0 comes first
                        varX[:len(tinds),tt] = x1[tinds]
                        #varX[tinds[np.where(tinds<0)],tt] = 0 # set to 0 if the lag is negative
                    #
                    # solve the linear system
                    #
                    #Gdum[:,ii] = np.dot(varY[m::12][jinds], np.linalg.pinv(varX[:,:len(jinds)]))
                    Gdum[:,ii] = np.dot(y1[m::12][jinds], np.linalg.pinv(varX[:,:len(jinds)]))
                #
                # we will only save given percentiles of the Gdum distribution
                G[p,ml,:maxlag,m,:] = np.nanpercentile(Gdum,percentiles,axis=-1).T #np.nanmedian(Gdum,-1) #np.dot(varY[m::12], np.linalg.pinv(varX.T))
                Gstep[p,ml,:maxlag,m,:] = np.cumsum(G[p,ml,:maxlag,m,:],axis=0) #np.nanpercentile(np.nancumsum(Gdum,axis=0),percentiles,axis=-1).T #np.cumsum(G[p,ml,:maxlag,m]) #step function
    # percentiles needs to have 50 in it
    j50 = np.where(np.array(percentiles)==50)[0][0]
    # use the median to reconstruct Y
    Y_reconstructed = Y_from_X_and_G(X,G[:,:,:,:,j50],minlag=0,maxlags=maxlags)
    Y_combined = combine_reconstructed(varY,Y_reconstructed)
    #
    return G, Gstep, Y_reconstructed, Y_combined, varY
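
# A minimal usage sketch of crf_monthly2 with synthetic data; short maxlags are
# used to keep the example light (names are illustrative only).
def _example_crf_monthly2():
    rng = np.random.default_rng(0)
    ntime = 480  # 40 years of monthly data
    X = rng.standard_normal((ntime, 1))
    Y = xr.DataArray(rng.standard_normal(ntime), dims=('time',))
    G, Gstep, Y_rec, Y_comb, Y_anom = crf_monthly2(X, Y, maxlags=np.arange(2*12, 4*12, 12), verbose=0)
    return G  # shape (1, len(maxlags), max(maxlags), 12, len(percentiles))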

def surrogate_loop(nn,X,Y,minlag,maxlags,percentiles,filtered,niters):
    '''
    Single iteration of the surrogate calculation, intended to be run in parallel:
    generate a refined AAFT surrogate of X, compute its CRF, and correlate the
    resulting reconstruction with the target.
    '''
    #print('Surrogate:',nn)
    predictor_surrogate    = Surrogates(X.T)
    predictor_surrogate_ts = predictor_surrogate.refined_AAFT_surrogates(X.T,niters)
    #
    G, Gstep, Y_reconstructed, Y_combined, Y_anom = crf_monthly2(predictor_surrogate_ts.T,Y,minlag=minlag,maxlags=maxlags,percentiles=percentiles,filtered=filtered,verbose=0)
    #
    slopes,rvals,pvals,intercepts,Y_combined2 = correlate_Y_Y_combined_Y_reconstructed(Y_anom,Y_combined,Y_reconstructed)
    #
    #return G, Gstep, Y_reconstructed, Y_combined,slopes,rvals,pvals,intercepts,Y_anom,Y_combined2
    return G, Gstep, rvals

def surrogate_crf_and_correlation(X,Y,minlag=0,maxlags=np.arange(5*12,10*12,12),percentiles=[5,25,50,75,95],filtered=False,n=100,niters=5,parallel='loky',n_jobs=16):
    '''
    Compute CRFs and reconstruction correlations for n surrogate versions of the
    predictors X (refined AAFT surrogates), in order to build a null distribution
    for the correlations obtained with the original predictors.
    
    n       : number of surrogates
    niters  : number of iterations used by the refined AAFT surrogate generation
    parallel: joblib backend name (default 'loky'); n_jobs: number of parallel jobs.
              Note that the serial (parallel=None) branch is currently commented out.
    '''
    #
    G_all               = []
    Gstep_all           = []
    #Y_reconstructed_all = []
    #Y_combined_all      = []
    #Y_anom_all          = []
    #Y_combined2         = []
    rvals_all           = []
    #
    if parallel is not None:
        with parallel_backend(parallel): #,scatter=[X,Y]):
            results = Parallel(n_jobs=n_jobs,verbose=10)(delayed(surrogate_loop)(nn,X,Y,minlag,maxlags,percentiles,filtered,niters) for nn in range(n))
        # results is a list of lists, so just using a zip doesn't work
        for res in results:
            G_all.append(res[0])
            Gstep_all.append(res[1])
            rvals_all.append(res[2])
            #Y_reconstructed_all.append(res[2])
            #Y_combined_all.append(res[3])
            #rvals_all.append(res[5])
    #else:
    #    for nn in range(n):
    #        print('Surrogate:',nn)
    #        predictor_surrogate    = Surrogates(X.T)
    #        predictor_surrogate_ts = predictor_surrogate.refined_AAFT_surrogates(X.T,niters)
    #        #
    #        G, Gstep, Y_reconstructed, Y_combined = crf_monthly2(predictor_surrogate_ts.T,Y,minlag=0,maxlags=np.arange(5*12,10*12,12),percentiles=[5,25,50,75,95],filtered=False)
    #        #
    #        slopes,rvals,pvals,intercepts,BSice_anom,BSice_combined2 = correlate_Y_Y_combined_Y_reconstructed(Y,Y_combined,Y_reconstructed)
    #        #
    #        rvals_all.append(rvals)
    #        G_all.append(G)
    #        Gstep_all.append(Gstep)
    #        Y_reconstructed_all.append(Y_reconstructed)
    #        Y_combined_all.append(Y_combined)
    #
    print('Surrogate done!')
    return np.array(G_all), np.array(Gstep_all), np.array(rvals_all)
    #return np.array(G_all), np.array(Gstep_all), np.array(Y_reconstructed_all), np.array(Y_combined_all), np.array(rvals_all)
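
# A minimal usage sketch of surrogate_crf_and_correlation with synthetic data;
# it assumes the pyunicorn Surrogates import at the top of the module is
# available, and uses few, short surrogates to keep the example light.
def _example_surrogate_crf():
    rng = np.random.default_rng(0)
    ntime = 480  # 40 years of monthly data
    X = rng.standard_normal((ntime, 1))
    Y = xr.DataArray(rng.standard_normal(ntime), dims=('time',))
    G_s, Gstep_s, rvals_s = surrogate_crf_and_correlation(
        X, Y, maxlags=np.arange(2*12, 4*12, 12), n=10, n_jobs=2)
    # rvals_s can then serve as a null distribution for the correlations
    # obtained with the original (non-surrogate) predictors
    return rvals_s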

def correlate_Y_Y_combined_Y_reconstructed(Y_anom,Y_combined,Y_reconstructed,ensemble_length=0,return_slopes=False,percentiles=[5,25,50,75,95]):
    '''
    Calculate the correlation between the target variable (Y_anom) and the
    reconstructed variable. Changed now so that the correlation is based on
    stats.linregress, not on its combination with stats.theilslopes.
    
    Y_anom          :: numpy.array of shape [ntime]
    Y_combined      :: numpy.array of shape [ntime, len(maxlags)] - predictors combined
    Y_reconstructed :: numpy.array of shape [ntime, len(predictors), len(maxlags)] - individual predictors
    ensemble_length :: integer passed to combine_reconstructed in order to avoid overfitting
    
    outputs
    
    slopes, rvals, pvals, intercepts :: output of stats.linregress, with shape
                                        [months, predictors+2, len(maxlags)]. The second dimension is
                                        constructed as follows: indices 0..number_of_predictors-1 give
                                        the correlation between each individual predictor and the target;
                                        index -2 gives the correlation between the target and the
                                        predictors combined at a given maxlag;
                                        index -1 gives the correlation between the target and the
                                        combination of predictors across maxlags (Y_combined2).
    Y_combined2                      :: prediction using a combination of all predictors, keeping
                                        those at (or close to) their maximum correlation.
    '''
    #initialize variables
    ntime,nX,ml = Y_reconstructed.shape
    #
    if return_slopes:
        slopes2=np.zeros((12,nX*ml))
    #
    slopes = np.zeros((12,nX+2,ml))
    rvals = np.zeros((12,nX+2,ml))
    pvals = np.zeros((12,nX+2,ml))
    intercepts = np.zeros((12,nX+2,ml))
    if ensemble_length>0:
        Y_combined2 = np.zeros((ntime,len(percentiles)))
    else:
        Y_combined2 = np.zeros(ntime)
    #
    for m in range(12):
        #dumY = signal.detrend(Y[m::12])
        #Y_anom[m::12]=dumY
        for l in range(ml):
            if np.unique(Y_combined[m::12,:].flatten()).size>2:
                # predictors combined
                #s,ii,s0,s1 = stats.theilslopes(Y_anom[m::12],Y_combined[m::12,l])
                #slope, intercept, r_value, p_value, std_err = stats.linregress(s*Y_combined[m::12,l]+ii, Y_anom[m::12])
                slope, intercept, r_value, p_value, std_err = stats.linregress(Y_combined[m::12,l], Y_anom[m::12])
                slopes[m,-2,l] = slope
                rvals[m,-2,l] = r_value
                pvals[m,-2,l] = p_value
                intercepts[m,-2,l] = intercept
                for p in range(nX):
                    # for individual predictors
                    if False:
                        dum = combine_reconstructed(Y_anom[m::12],Y_reconstructed[m::12,p,l][:,np.newaxis,np.newaxis],nens=1,ensemble_length=ensemble_length,percentiles=percentiles).squeeze()
                        if ensemble_length>0:
                            slope, intercept, r_value, p_value, std_err = stats.linregress(dum[:,percentiles.index(50)], Y_anom[m::12])
                        else:
                            slope, intercept, r_value, p_value, std_err = stats.linregress(dum[:,percentiles.index(50)], Y_anom[m::12])
                    else:
                        #s,ii,s0,s1 = stats.theilslopes(Y_anom[m::12],Y_reconstructed[m::12,p,l])
                        #slope, intercept, r_value, p_value, std_err = stats.linregress(s*Y_reconstructed[m::12,p,l]+ii, Y_anom[m::12])
                        slope, intercept, r_value, p_value, std_err = stats.linregress(Y_reconstructed[m::12,p,l], Y_anom[m::12])
                    slopes[m,p,l]     = slope
                    rvals[m,p,l]      = r_value
                    pvals[m,p,l]      = p_value
                    intercepts[m,p,l] = intercept
        # the above combined approach is rather naive - it simply combines predictors at the same maxlag.
        # However, they might have a higher correlation at some other maxlag depending on the integration time;
        # we can find the lag of maximum correlation and use that instead.
        #if False:
        #    ninds=[]
        #    for p in range(nX):
        #        ninds.append(np.where(max(rvals[m,p,:])==rvals[m,p,:])[0][0])
        #    #
        #    Y_combined2[m::12] = combine_reconstructed(Y_anom[m::12],Y_reconstructed[m::12,range(nX),ninds][:,:,np.newaxis],nens=1,return_slopes=return_slopes).squeeze()
        #else:
        #    # here we actually just use all of the sections and all of the lags
        if return_slopes:
            Y_combined2_dum,slopes2[m::12,:] = combine_reconstructed(Y_anom[m::12],np.reshape(Y_reconstructed[m::12,:,:],(ntime//12,-1))[:,:,np.newaxis],nens=1,ensemble_length=ensemble_length,return_slopes=return_slopes,percentiles=percentiles)
            Y_combined2[m::12,]=Y_combined2_dum.squeeze()
        else:
            rjinds,riinds = np.where(rvals[m,:-2,:]**2>=min(0.1,np.max(rvals[m,:-2,:]**2))) #at least 10% variance explained, if nothing above 10%, then take just max
            Y_combined2[m::12,] = combine_reconstructed(Y_anom[m::12],Y_reconstructed[m::12,rjinds,riinds][:,:,np.newaxis],nens=1,ensemble_length=ensemble_length,percentiles=percentiles).squeeze()
            #Y_combined2[m::12] = combine_reconstructed(Y_anom[m::12],np.reshape(Y_reconstructed[m::12,:,:],(ntime//12,-1))[:,:,np.newaxis],nens=1,ensemble_length=ensemble_length).squeeze()
        #
        if ensemble_length>0:
            s,ii,s0,s1 = stats.theilslopes(Y_anom[m::12],Y_combined2[m::12,percentiles.index(50)])
            slope, intercept, r_value, p_value, std_err = stats.linregress(s*Y_combined2[m::12,percentiles.index(50)]+ii, Y_anom[m::12])
            #slope, intercept, r_value, p_value, std_err = stats.linregress(Y_combined2[m::12,percentiles.index(50)], Y_anom[m::12])
        else:
            s,ii,s0,s1 = stats.theilslopes(Y_anom[m::12],Y_combined2[m::12])
            slope, intercept, r_value, p_value, std_err = stats.linregress(s*Y_combined2[m::12]+ii, Y_anom[m::12])
            #slope, intercept, r_value, p_value, std_err = stats.linregress(Y_combined2[m::12], Y_anom[m::12])
        slopes[m,-1,0]     = slope
        rvals[m,-1,0]      = r_value
        pvals[m,-1,0]      = p_value
        intercepts[m,-1,0] = intercept
    
    if return_slopes:
        return slopes,rvals,pvals,intercepts,Y_combined2,slopes2
    else:
        return slopes,rvals,pvals,intercepts,Y_combined2
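
# A minimal usage sketch of correlate_Y_Y_combined_Y_reconstructed with
# synthetic inputs of the documented shapes (names are illustrative only).
def _example_correlate_Y_Y_combined():
    rng = np.random.default_rng(0)
    ntime, npred, nlags = 240, 2, 3
    Y_anom = rng.standard_normal(ntime)
    Y_reconstructed = Y_anom[:, np.newaxis, np.newaxis] + rng.standard_normal((ntime, npred, nlags))
    Y_combined = Y_reconstructed.mean(axis=1)  # a crude combination across predictors
    slopes, rvals, pvals, intercepts, Y_combined2 = correlate_Y_Y_combined_Y_reconstructed(
        Y_anom, Y_combined, Y_reconstructed)
    return rvals  # shape (12, npred+2, nlags)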

def crf_spatial(X,Y,optlag):
    '''
    NOTE: this might actually work, but the reconstruction step still needs to be
    checked (and there is no need to run the 'combined' step here).
    
    Calculate the CRF given X and Y at the lag of maximum correlation (optlag)
    for each month and predictor. Note that this will calculate the monthly CRFs.
    
    Y:       xarray.DataArray, a monthly timeseries of the target variable
    X:       numpy.array, a timeseries of the predictors.
             Should be at least of size (ntime, 1)
    optlag:  numpy.array of shape [12, n_predictors], the (optimal) lag in months
             to use for each calendar month and predictor.
    '''
    ntime, nx = X.shape
    #
    varY = np.zeros(Y.shape)
    for m in range(12):
        varY[m::12]=signal.detrend(Y.values[m::12])
    #
    # initialize
    G = np.ones((int(np.max(optlag)),12))*np.nan
    Gstep = np.ones((int(np.max(optlag)),12))*np.nan # same shape as G
    for m in range(12):
        maxlag = int(np.max(optlag[m,:]))
        varX  = np.zeros((ntime//12,nx))
        for l,lag in enumerate(optlag[m,:].astype(int)): # this could probably be done without the loop
            # one column per predictor, each shifted by its own optimal lag
            dum=X[m+maxlag-lag::12,l]
            varX[:len(dum),l] = dum
        # solve the linear system
        #
        G[:,m] = np.dot(varY[m+maxlag::12], np.linalg.pinv(varX[:ntime//12-(m+maxlag)//12,:]))
        #
        Gstep[:maxlag,m] = np.cumsum(G[:maxlag,m]) #step function
    #
    # NOTE: the reconstruction step expects a response function of shape
    # (n_predictors, len(maxlags), max(maxlags), 12) and a maxlags array, neither of
    # which is available here, so it is left disabled until checked.
    #Y_reconstructed = Y_from_X_and_G(X,G,minlag=0,maxlags=maxlags)
    #
    return G, Gstep


def main_loop(lags,m,corr_matrix2_mm,X,Y):
    """ 
    
    """
    nt,nx = X.shape
    for jj in range(nx):
        if np.all(np.isfinite(X[:,jj])):
            for l,lag in enumerate(lags):
                r,p = stats.pearsonr(Y[m+(12-lag%12)+lag:-m-1:12],X[m+(12-lag%12):-m-1-lag:12,jj])
                corr_matrix2_mm[m,l,jj] = r


def lagged_correlation(X,Y,lags=np.arange(120).astype(int)):
    '''
    Calculate the lagged correlation between a predictand (Y) and predictors (X),
    separately for each calendar month.
    Returns an array of shape (12, len(lags), n_predictors).
    '''
    nt,nyx  = X.shape
    months  = np.arange(12).astype(int)
    folder1 = tempfile.mkdtemp()
    path2   = os.path.join(folder1, 'corr_matrix2.mmap')
    corr_matrix2_mm    = np.memmap(path2, dtype=float, shape=(len(months),len(lags),nyx), mode='w+')
    corr_matrix2_mm[:] = np.zeros((len(months),len(lags),nyx))
    num_cores = 6
    Parallel(n_jobs=num_cores)(delayed(main_loop)(lags,m,corr_matrix2_mm,X,Y) for m,mon in enumerate(months))
    corr_matrix2 = np.asarray(corr_matrix2_mm)
    try:
        shutil.rmtree(folder1)
    except OSError:
        pass
    #
    return corr_matrix2
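
# A minimal usage sketch of lagged_correlation with synthetic data
# (names are illustrative only; the call spawns a small joblib pool).
def _example_lagged_correlation():
    rng = np.random.default_rng(0)
    nt, nx = 240, 3
    X = rng.standard_normal((nt, nx))
    Y = rng.standard_normal(nt)
    corr = lagged_correlation(X, Y, lags=np.arange(60).astype(int))
    return corr  # shape (12 months, len(lags), n_predictors)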

def p_cor(x,y):
    """
    Uses the scipy stats module to calculate the Pearson correlation coefficient
    :x vector: Input timeseries model for all grid points
    :y vector: Input timeseries obs for all grid points
    """
    #r, p_val = stats.pearsonr(x, y)
    if np.all(np.isfinite(x)):
        return stats.pearsonr(signal.detrend(x), y)[0]
    else:
        return 0
    
# The function we use to apply the correlation per pixel (lon-lat point)
def pearson_correlation(x, y, dim='time'):
    # x = pixel values, y = the timeseries to correlate against, dim = the dimension over which to correlate
    return xr.apply_ufunc(
        p_cor, x , y,
        input_core_dims=[[dim], [dim]],   # dim indicates the dimension over which we don't want to broadcast for x and y, respectively
        vectorize=True, # !Important!
        output_dtypes=[x.dtype],
        dask='parallelized'
        )
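
# A minimal usage sketch of pearson_correlation applied point-wise to a small
# synthetic (time, lat, lon) field (names are illustrative only).
def _example_pearson_correlation():
    rng = np.random.default_rng(0)
    field = xr.DataArray(rng.standard_normal((120, 4, 5)), dims=('time', 'lat', 'lon'))
    index = xr.DataArray(rng.standard_normal(120), dims=('time',))
    rmap = pearson_correlation(field, index, dim='time')
    return rmap  # correlation coefficient at each (lat, lon) point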

def return_grid_in(lon_in,lat_in,lon_b_in,lat_b_in):
    '''
    A quick helper function to create a grid (an xarray.Dataset with cell centers
    and bounds) suitable for regridding. Longitudes are wrapped to [-180, 180].
    '''
    #if np.min(lon_in)
    lon_in[np.where(lon_in>180)] = lon_in[np.where(lon_in>180)]-360
    lon_b_in[np.where(lon_b_in>180)] = lon_b_in[np.where(lon_b_in>180)]-360
    #  
    #
    if len(lon_in.shape)>1:
        lon = xr.DataArray(lon_in,dims=('j1','i1'),name='lon')
        lat = xr.DataArray(lat_in,dims=('j1','i1'),name='lat')
        lon_b = xr.DataArray(lon_b_in,dims=('j2','i2'),name='lon_b')
        lat_b = xr.DataArray(lat_b_in,dims=('j2','i2'),name='lat_b')
        ds = xr.merge([lat,lon,lat_b,lon_b])
    else:
        ds = xr.Dataset({'lat': (['j1'], lat_in),'lon': (['i1'], lon_in),'lat_b':(['j2'], lat_b_in), 'lon_b': (['i2'], lon_b_in), })

    return ds

def perform_regrid(regridder,datain,mask=None):
    '''
    Do the regridding and take care of the land points: if a (0/1) mask is given,
    normalize the regridded field by the regridded mask and set points with
    mask < 0.5 to NaN.
    '''
    if mask is None:
        out = regridder(datain)
    else:
        mask = regridder(mask)
        out  = regridder(datain)/mask
        #
        jinds,iinds = np.where(mask<0.5)
        if len(out.shape)==3:
            out[:,jinds,iinds] = np.nan
        else:
            out[jinds,iinds] = np.nan

    return out
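
# A minimal usage sketch of return_grid_in and perform_regrid. In practice the
# regridder would typically be an xesmf.Regridder built from grids like the one
# below; here a toy 2x2 block-average stands in so the sketch has no extra
# dependencies (all names are illustrative only).
def _example_perform_regrid():
    rng = np.random.default_rng(0)
    # grid description as it could be passed to a regridding package (shown for completeness)
    grid = return_grid_in(np.arange(1, 360, 2.0), np.arange(-89, 90, 2.0),
                          np.arange(0, 360.1, 2.0), np.arange(-90, 90.1, 2.0))
    data  = rng.standard_normal((4, 90, 180))        # (time, lat, lon)
    ocean = np.ones((90, 180)); ocean[:10, :] = 0.0  # toy land-sea mask (1 = ocean, 0 = land)
    def block_mean_regridder(field):
        # coarsen the last two dimensions by a factor of two (stands in for a real regridder)
        return field.reshape(*field.shape[:-2], 45, 2, 90, 2).mean(axis=(-3, -1))
    # note: fully-land points lead to a divide-by-zero that is subsequently set to NaN
    out = perform_regrid(block_mean_regridder, data, mask=ocean)
    return grid, out  # points where the regridded mask < 0.5 are set to NaN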