"""Generate a "pairs plot" a.k.a "scatterplot matrix" a.k.a "matrix of bivariate scatter plots." This shows a scatterplot for each EEG channel against each other channel for a short EEG segment. It also shows a histogram illustrating the distribution of each channel and the person correlation for each channel pair. by Elliott Forney idfah@cs.colostate.edu Colorado State University 2013.04.04 """ import matplotlib.pyplot as plt import numpy as np import scipy.stats as stats import json def pairsPlot(sig, chans): if np.ndim(sig) == 1: ndim = 1 else: ndim = sig.shape[1] mx = np.max(np.abs(sig)) fig = plt.figure() for r in xrange(ndim): for c in xrange(ndim): #ax = fig.add_subplot(ndim,ndim,r*ndim+c+1) ax = fig.add_subplot(ndim,ndim,r*ndim+c+1+ndim*(ndim-(r%ndim)*2-1)) sigx = sig[:,r] sigy = sig[:,c] if (r == c): plt.hist(sigx, bins=30, normed=False) ax.set_xlim(-mx/2.0,mx/2.0) else: ax.scatter(sigx, sigy, alpha=0.5, s=10, marker='.') ax.plot((-mx,mx),(-mx,mx), color='grey', linestyle='dashed') pearsonr,pearsonp = stats.pearsonr(sigx,sigy) pearsons = ".%2d" % np.round(pearsonr*100) ax.text(0.9,0.1,pearsons, transform=ax.transAxes, horizontalalignment='right', verticalalignment='bottom', fontsize=8) ax.set_ylim(-mx,mx) ax.set_xlim(-mx,mx) if r == 0: #if r == ndim-1: ax.set_xlabel(chans[c]) ax.set_xticks([]) else: #ax.set_xlabel('') ax.get_xaxis().set_visible(False) if c == 0: ax.set_ylabel(chans[r]) ax.set_yticks([]) else: #ax.set_ylabel('') ax.get_yaxis().set_visible(False) return fig def loadData(fileName='s20-gammasys-gifford-unimpaired.json', ns=60): # load from json dataHandle = open(fileName, 'r') data = json.load(dataHandle) dataHandle.close() # pull out first session data = data[0] # get sample rate sampRate = data['sample rate'] # get channel info chans = data['channels'] nchan = len(chans) # eeg signals from first trial sig = np.array(data['eeg']['trial 1']).T # get first ns seconds of data, ommit trigger channel sig = sig[sampRate:(ns*sampRate),:nchan] # demean sig -= np.mean(sig, axis=0) return sig,chans sig,chans = loadData() pairsPlot(sig, chans) plt.show()