""" xDAWN and variants for enhancing event-related potentials """
import os
import cPickle
from copy import deepcopy
import numpy
from scipy.linalg import qr
from pySPACE.missions.nodes.spatial_filtering.spatial_filtering \
import SpatialFilteringNode
from pySPACE.resources.data_types.time_series import TimeSeries
from pySPACE.resources.dataset_defs.stream import StreamDataset
from pySPACE.tools.filesystem import create_directory
import logging
class XDAWNNode(SpatialFilteringNode):
    """ xDAWN spatial filter for enhancing event-related potentials.

    xDAWN tries to construct spatial filters such that the
    signal-to-signal plus noise ratio is maximized. This spatial filter is
    particularly suited for paradigms where classification is based on
    event-related potentials.

    For more details on xDAWN, please refer to
    http://www.icp.inpg.fr/~rivetber/Publications/references/Rivet2009a.pdf

    **References**

        ========= ==============================================================
        main      source: xDAWN
        ========= ==============================================================
        author    Rivet, B. and Souloumiac, A. and Attina, V. and Gibert, G.
        journal   Biomedical Engineering, IEEE Transactions on
        title     `xDAWN Algorithm to Enhance Evoked Potentials: Application to Brain-Computer Interface <http://dx.doi.org/10.1109/TBME.2009.2012869>`_
        year      2009
        month     aug.
        volume    56
        number    8
        pages     2035-2043
        doi       10.1109/TBME.2009.2012869
        ISSN      0018-9294
        ========= ==============================================================

        ========= ==============================================================
        minor     source: adaptive xDAWN
        ========= ==============================================================
        author    Woehrle, H. and Krell, M. M. and Straube, S. and Kim, S. U. and Kirchner, E. A. and Kirchner, F.
        title     `An Adaptive Spatial Filter for User-Independent Single Trial Detection of Event-Related Potentials <http://dx.doi.org/10.1109/TBME.2015.2402252>`_
        journal   IEEE Transactions on Biomedical Engineering
        publisher IEEE
        doi       10.1109/TBME.2015.2402252
        volume    62
        issue     7
        pages     1696-1705
        year      2015
        ========= ==============================================================

    **Parameters**

        :erp_class_label: Label of the class for which an ERP should be evoked.
            For instance "Target" for a P300 oddball paradigm.

            (*recommended, default: 'Target'*)

        :retained_channels: Determines how many of the pseudo channels
            are retained. Default is None (or the string 'None'), which means
            "all channels". Values larger than the number of channels are
            reduced to that number; during training, values below 1 are
            raised to 1.

            (*optional, default: None*)

        :load_filter_path: An absolute path from which the spatial filters can
            be loaded. If not specified, these filters are learned from the
            training data. The file may either contain the tuple written by
            :meth:`store_state` (filters, spatial distributions, time courses)
            or a bare filter matrix with one filter per column.
            Only load trusted files: unpickling can execute arbitrary code.

            (*optional, default: None*)

        :visualize_pattern: If value is true, a visualization of the learned
            spatial filters is stored.
            The visualization is divided into two components.
            First of all each transformation is visualized separately.
            Since the visualization itself may not be so meaningful,
            there exists another combined visualization, which shows
            the filter (u_i) with the underlying spatial distribution
            (w_i, parameter names taken from paper).
            The number of filters equals the number of original channels.
            Normally only the first channels matter and the rest corresponds to
            different noise components.
            To avoid storing too many pictures, the *retained_channels*
            parameter is used to restrict the number.

            (*optional, default: False*)

    **Exemplary Call**

    .. code-block:: yaml

        -
            node : xDAWN
            parameters:
                erp_class_label : "Target"
                retained_channels : 32
                store : True

    :Author: Jan Hendrik Metzen (jhm@informatik.uni-bremen.de)
    :Created: 2011/07/05
    """
    def __init__(self, erp_class_label=None, retained_channels=None,
                 load_filter_path=None, visualize_pattern=False, **kwargs):
        # Must be set before constructor of superclass is called
        self.trainable = (load_filter_path is None)
        super(XDAWNNode, self).__init__(retained_channels=retained_channels,
                                        **kwargs)
        if erp_class_label is None:
            erp_class_label = "Target"
            self._log("No ERP class label given. Using default: 'Target'.",
                      level=logging.CRITICAL)
        # Specifications may pass the string 'None'; _train and _execute
        # already treat it like None, so accept it here as well instead of
        # crashing in int().
        if retained_channels in [None, 'None']:
            retained_channels = None
        else:
            retained_channels = int(retained_channels)
        filters = None
        # Load patterns from file if requested
        if load_filter_path is not None:
            # NOTE(review): unpickling can execute arbitrary code; only load
            # filter files from trusted sources.
            with open(load_filter_path, 'rb') as filters_file:
                filters = cPickle.load(filters_file)
            # store_state pickles the tuple (filters, wi, ai); the filter
            # matrix is its first element. A bare matrix is used as is.
            if isinstance(filters, tuple):
                filters = filters[0]
        self.set_permanent_attributes(
            # Label of the class for which an ERP should be evoked.
            erp_class_label=erp_class_label,
            # The channel names
            channel_names=None,
            # Matrices for storing data and stimuli
            X=None,
            D=None,
            SNR=None,
            # The number of channels that will be retained
            retained_channels=retained_channels,
            # whether this node is trainable
            trainable=self.trainable,
            # After training is finished, this attribute will contain
            # the spatial filters (one per column) that are used to project
            # the data onto a lower dimensional subspace
            filters=filters,
            # Determines whether the filters are visualized when stored
            visualize_pattern=visualize_pattern,
            xDAWN_channel_names=None,
        )
        if self.visualize_pattern:
            self.set_permanent_attributes(store=True)

    def is_trainable(self):
        """ Returns whether this node is trainable. """
        return self.trainable

    def is_supervised(self):
        """ Returns whether this node requires supervised training """
        return self.trainable

    def _train(self, data, label):
        """ Train node on given example *data* for class *label*.

        The first sample fixes ``channel_names`` and clamps
        ``retained_channels`` to [1, number of channels]. Every sample is
        appended to the data matrix ``X``. An identity block (ERP class) or
        zero block (other classes) is appended to the stimulus matrix ``D``.
        """
        # If this is the first data sample we obtain
        if self.channel_names is None:
            self.channel_names = data.channel_names
            num_channels = len(self.channel_names)
            if self.retained_channels in [None, 'None']:
                self.retained_channels = num_channels
            else:
                self.retained_channels = int(self.retained_channels)
            if num_channels < self.retained_channels:
                self.retained_channels = num_channels
                self._log("Too many channels chosen for the retained channels! "
                          "Replaced by maximum number.", level=logging.CRITICAL)
            elif self.retained_channels < 1:
                # The message promises a replacement, so actually do it.
                self.retained_channels = 1
                self._log("Too little channels chosen for the retained channels! "
                          "Replaced by minimum number (1).", level=logging.CRITICAL)
        # Iteratively construct Toeplitz matrix D and data matrix X
        num_samples = data.shape[0]
        if label == self.erp_class_label:
            D = numpy.eye(num_samples)
        else:
            D = numpy.zeros((num_samples, num_samples))
        if self.X is None:
            self.X = deepcopy(data)
            self.D = D
        else:
            self.X = numpy.vstack((self.X, data))
            self.D = numpy.vstack((self.D, D))

    def _stop_training(self, debug=False):
        """ Compute the xDAWN filters from the accumulated ``X`` and ``D``.

        Computes X = Qx Rx and D = Qd Rd (economic QR) and the SVD
        Qd^T Qx = Phi Lambda Psi^T. Afterwards:

        * ``filters`` has one spatial filter u_i = Rx^-1 Psi_i per column,
        * ``wi`` has one spatial distribution w_i = Rx^T Psi_i per row,
        * ``ai`` has one time course a_i = Rd^-1 Phi_i Lambda_i per row,
        * ``SNR[i]`` is |D a_i|^2 / |X u_i|^2,

        and ``X``/``D`` are released.
        """
        import re
        import scipy
        # The QR keyword for economic mode changed in scipy 0.9. Only the
        # leading digits of each version component are parsed, so tags like
        # "0.19.0rc1" do not break int() (and a list is compared, not a
        # Python 3 map object).
        scipy_version = [int(re.match(r"\d*", part).group() or 0)
                         for part in scipy.__version__.split('.')[:3]]
        if scipy_version >= [0, 9, 0]:
            economic = {'mode': 'economic'}
        else:
            economic = {'econ': True}
        # NOTE: the economic decomposition is required since otherwise
        # the memory consumption is excessive.
        # overwrite_a=False: X and D are read again below for the SNR, so
        # the decomposition must not be allowed to destroy them.
        Qx, Rx = qr(self.X, overwrite_a=False, **economic)
        Qd, Rd = qr(self.D, overwrite_a=False, **economic)
        # Singular value decomposition of Qd.T Qx
        # NOTE: full_matrices=True required since otherwise we do not get
        # num_channels filters.
        self.Phi, self.Lambda, self.Psi = \
            numpy.linalg.svd(numpy.dot(Qd.T, Qx), full_matrices=True)
        self.Psi = self.Psi.T
        # Loop invariants: invert the triangular factors only once.
        Rx_inv = numpy.linalg.inv(Rx)
        Rd_inv = numpy.linalg.inv(Rd)
        num_time_courses = self.Phi.shape[1]
        SNR = numpy.zeros(self.X.shape[1])
        filters = []
        spatial_patterns = []
        time_courses = []
        ai = None
        for i in range(self.Psi.shape[1]):
            # Spatial filter with index i is Rx^-1 * Psi_i
            ui = numpy.dot(Rx_inv, self.Psi[:, i])
            wi = numpy.dot(Rx.T, self.Psi[:, i])
            if i < num_time_courses:
                ai = numpy.dot(numpy.dot(Rd_inv, self.Phi[:, i]),
                               self.Lambda[i])
                time_courses.append(ai)
            filters.append(ui)
            spatial_patterns.append(wi)
            # NOTE(review): for i >= num_time_courses the last computed
            # time course is reused here (unchanged from the original).
            a = numpy.dot(self.D, ai.T)
            b = numpy.dot(self.X, ui)
            SNR[i] = numpy.dot(a.T, a) / numpy.dot(b.T, b)
        # Stack once instead of growing the arrays in every iteration.
        self.filters = numpy.column_stack(filters)
        self.wi = numpy.vstack(spatial_patterns)
        self.ai = numpy.vstack(time_courses)
        self.SNR = SNR
        self.D = None
        self.X = None

    def _execute(self, data):
        """ Apply the learned spatial filters to the given data point.

        Projects *data* onto the first ``retained_channels`` filter columns
        and returns a TimeSeries with channels named xDAWN000, xDAWN001, ...
        """
        if self.channel_names is None:
            self.channel_names = data.channel_names
        num_channels = len(self.channel_names)
        if self.retained_channels in [None, 'None']:
            self.retained_channels = num_channels
        if num_channels < self.retained_channels:
            self.retained_channels = num_channels
            self._log("To many channels chosen for the retained channels! "
                      "Replaced by maximum number.", level=logging.CRITICAL)
        elif self.retained_channels < 1:
            # Without this, a non-positive value would silently select a
            # wrong (possibly negative-index) slice of the filters.
            self.retained_channels = 1
            self._log("Too little channels chosen for the retained channels! "
                      "Replaced by minimum number (1).", level=logging.CRITICAL)
        data_array = data.view(numpy.ndarray)
        # Project the data using the learned spatial filters
        projected_data = numpy.dot(data_array,
                                   self.filters[:, :self.retained_channels])
        if self.xDAWN_channel_names is None:
            self.xDAWN_channel_names = ["xDAWN%03d" % i
                                        for i in range(self.retained_channels)]
        return TimeSeries(projected_data, self.xDAWN_channel_names,
                          data.sampling_frequency, data.start_time,
                          data.end_time, data.name, data.marker_name)

    def store_state(self, result_dir, index=None):
        """ Stores this node in the given directory *result_dir*.

        If ``store`` is set, the tuple (filters, wi, ai) is pickled to
        ``<result_dir>/<ClassName>/patterns_sp<split>.pickle``. If
        ``visualize_pattern`` is also set, plots are written as well.
        The superclass' store_state is called in any case.
        """
        if self.store:
            try:
                node_dir = os.path.join(result_dir, self.__class__.__name__)
                create_directory(node_dir)
                # This node only stores the learned spatial filters
                name = "%s_sp%s.pickle" % ("patterns", self.current_split)
                with open(os.path.join(node_dir, name), "wb") as result_file:
                    result_file.write(cPickle.dumps((self.filters, self.wi,
                                                     self.ai), protocol=2))
                # Store spatial filter plots if desired
                if self.visualize_pattern:
                    self._store_pattern_plots(node_dir)
            except Exception as e:
                print(e)
                raise
        super(XDAWNNode, self).store_state(result_dir)

    def _store_pattern_plots(self, node_dir):
        """ Write the per-component and combined xDAWN plots to *node_dir*.

        Each retained component k is the outer product of spatial
        distribution ``wi[k]`` and time course ``ai[k]`` (channels x time).
        Writes signal_component<kk>.png per component, the CSPNode filter
        plots and signal_complete.png (sum of all components).
        """
        from pySPACE.missions.nodes.spatial_filtering.csp import CSPNode
        import pylab
        # Compute, accumulate and analyze signal components estimated by xDAWN
        vmin = numpy.inf
        vmax = -numpy.inf
        signal_components = []
        complete_signal = numpy.zeros((self.wi.shape[1], self.ai.shape[1]))
        for filter_index in range(self.retained_channels):
            signal_component = numpy.outer(self.wi[filter_index, :],
                                           self.ai[filter_index, :])
            vmin = min(signal_component.min(), vmin)
            vmax = max(signal_component.max(), vmax)
            signal_components.append(signal_component)
            complete_signal += signal_component
        for index, signal_component in enumerate(signal_components):
            pylab.figure(0, figsize=(18, 8))
            pylab.gcf().clear()
            # Plot spatial distribution
            ax = pylab.axes([0.0, 0.0, 0.2, 0.5])
            CSPNode._plot_spatial_values(ax, self.wi[index, :],
                                         self.channel_names,
                                         'Spatial distribution')
            # Plot spatial filter
            ax = pylab.axes([0.0, 0.5, 0.2, 0.5])
            CSPNode._plot_spatial_values(ax, self.filters[:, index],
                                         self.channel_names,
                                         'Spatial filter')
            # Plot signal component in electrode coordinate system
            self._plotTimeSeriesInEC(signal_component, vmin=vmin, vmax=vmax,
                                     bb=(0.2, 1.0, 0.0, 1.0))
            pylab.savefig("%s%ssignal_component%02d.png"
                          % (node_dir, os.sep, index))
        CSPNode._store_spatial_filter_plots(
            self.filters[:, :self.retained_channels],
            self.channel_names, node_dir)
        # Plot entire signal
        pylab.figure(0, figsize=(15, 8))
        pylab.gcf().clear()
        self._plotTimeSeriesInEC(
            complete_signal,
            file_name="%s%ssignal_complete.png" % (node_dir, os.sep))
        pylab.savefig("%s%ssignal_complete.png" % (node_dir, os.sep))

    def _plotTimeSeriesInEC(self, values, vmin=None, vmax=None,
                            bb=(0.0, 1.0, 0.0, 1.0), file_name=None):
        """ Plot *values* (channels x time) in the electrode coordinate system.

        Each channel's curve is drawn in a small axes placed at the 2-D
        projected position of that electrode, inside the figure box
        *bb* = (left, right, bottom, top). *vmin*/*vmax* fix the common
        y-range (defaulting to the data range). *file_name* is accepted
        but not used here; the caller saves the figure.
        """
        import pylab
        ec = self.get_metadata("electrode_coordinates")
        if ec is None:
            ec = StreamDataset.ec
        ec_2d = StreamDataset.project2d(ec)
        # x and y coordinates of electrodes in the order of the channels
        x = numpy.array([ec_2d[key][0] for key in self.channel_names])
        y = numpy.array([ec_2d[key][1] for key in self.channel_names])
        # Determine min and max values
        if vmin is None:
            vmin = values.min()
        if vmax is None:
            vmax = values.max()
        width = (bb[1] - bb[0])
        height = (bb[3] - bb[2])
        for channel_index, channel_name in enumerate(self.channel_names):
            ax = pylab.axes([x[channel_index]/(1.2*(x.max() - x.min()))*width +
                             bb[0] + width/2 - 0.025,
                             y[channel_index]/(1.2*(y.max() - y.min()))*height +
                             bb[2] + height/2 - 0.0375, 0.05, 0.075])
            ax.plot(values[channel_index, :], color='k', lw=1)
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_ylim((vmin, vmax))
            ax.text(values.shape[1]/2, vmax*.8, channel_name,
                    horizontalalignment='center', verticalalignment='center')
class SparseXDAWNNode(XDAWNNode):
    """ Sparse xDAWN spatial filter for enhancing event-related potentials.

    xDAWN tries to construct spatial filters such that the
    signal-to-signal plus noise ratio (SSNR) is maximized. This spatial filter
    is particularly suited for paradigms where classification is based on
    event-related potentials. In contrast to the standard xDAWN algorithm,
    this node tries to minimize the number of electrodes that have non-zero
    weights in the spatial filters while at the same time trying to maximize
    the signal-to-signal plus noise ratio. This property is used for electrode
    selection, i.e. only those electrodes need to be set that obtained
    non-zero weights.

    For more details on Sparse xDAWN, please refer to
    http://www.gipsa-lab.inpg.fr/~bertrand.rivet/references/RivetEMBC10.pdf

    .. todo:: Two more sentences about Sparse_XDAWN

    **Parameters**

        :`lambda_`: Determines the relative influence of the two objectives
            (maximization of SSNR and minimization of electrodes with non-zero
            weights). If `lambda_` is 0, only the SSNR is relevant (like in
            standard xDAWN). The larger `lambda_`, the weaker is the influence
            of the SSNR.

        :erp_class_label: Label of the class for which an ERP should be evoked.
            For instance "Target" for a P300 oddball paradigm.

            (*recommended, default:'Target'*)

        :num_selected_electrodes: Determines how many electrodes keep a non-zero
            weight. Defaults to all channels; values outside
            [1, number of channels] are clamped to that range.

    **Exemplary Call**

    .. code-block:: yaml

        -
            node : Sparse_xDAWN
            parameters :
                lambda_ : 0.1
                erp_class_label : "Target"
                num_selected_electrodes : 2
                store : True

    :Author: Jan Hendrik Metzen (jhm@informatik.uni-bremen.de)
    :Created: 2011/08/22
    """
    def __init__(self, lambda_, erp_class_label='Target',
                 num_selected_electrodes=None, **kwargs):
        # All channels are kept during training; the electrode selection is
        # controlled by num_selected_electrodes, so retained_channels is
        # dropped if given.
        kwargs.pop('retained_channels', None)
        super(SparseXDAWNNode, self).__init__(erp_class_label=erp_class_label,
                                              retained_channels=None,
                                              load_filter_path=None,
                                              visualize_pattern=False,
                                              **kwargs)
        self.set_permanent_attributes(
            lambda_=lambda_, num_selected_electrodes=num_selected_electrodes)

    def _stop_training(self, debug=False):
        """ Select the electrodes with the largest sparse-xDAWN weights.

        Optimizes the sparse objective, keeps the
        ``num_selected_electrodes`` largest absolute weights and stores
        their indices and names in ``selected_indices`` and
        ``selected_channels``. ``X`` and ``D`` are released afterwards.
        """
        if self.num_selected_electrodes is None:
            self.num_selected_electrodes = self.retained_channels
        num_channels = self.X.shape[1]
        num_selected = int(self.num_selected_electrodes)
        # Outside [1, num_channels] the threshold lookup below would raise
        # (too many) or keep every electrode (zero), so clamp explicitly.
        if not 1 <= num_selected <= num_channels:
            clamped = min(max(num_selected, 1), num_channels)
            self._log("num_selected_electrodes=%s is out of range; using %s."
                      % (num_selected, clamped), level=logging.CRITICAL)
            num_selected = clamped
        # Estimate of the signal for class 1 (the erp_class_label class)
        A_1 = numpy.dot(numpy.dot(numpy.linalg.inv(numpy.dot(self.D.T, self.D)),
                                  self.D.T),
                        self.X)
        # Estimate of Sigma 1 and Sigma X
        sigma_1 = numpy.dot(numpy.dot(numpy.dot(A_1.T, self.D.T),
                                      self.D), A_1)
        sigma_X = numpy.dot(self.X.T, self.X)

        def objective_function(v_1, lambda_):
            """ The objective function from the paper from Rivet et al. """
            a = numpy.dot(numpy.dot(v_1.T, sigma_1), v_1)  # 0-d, skip trace!
            b = numpy.dot(numpy.dot(v_1.T, sigma_X), v_1)  # 0-d, skip trace!
            c = numpy.linalg.norm(v_1, 1) / numpy.linalg.norm(v_1, 2)
            return a / b - lambda_*c

        # Compute the non-pruned weights
        v_1 = self._gradient_optimization(
            objective_function=lambda x: objective_function(x, self.lambda_),
            sigma_1=sigma_1, sigma_X=sigma_X, max_evals=25000)
        # Prune weight vector such that only num_selected entries stay
        # != 0 (those with the largest absolute weight; ties may keep more)
        threshold = numpy.sort(numpy.absolute(v_1))[-num_selected]
        v_1[numpy.absolute(v_1) < threshold] = 0
        v_1 /= numpy.linalg.norm(v_1, 2)
        # Determine indices and names of electrodes with non-zero weights
        self.selected_indices = list(numpy.where(numpy.absolute(v_1) > 0)[0])
        self.selected_channels = [self.channel_names[index]
                                  for index in self.selected_indices]
        # Training matrices are no longer needed (same as XDAWNNode).
        self.X = None
        self.D = None

    def _gradient_optimization(self, objective_function, sigma_1, sigma_X,
                               max_evals=25000):
        """ Maximize *objective_function* by restarted subgradient ascent.

        Each restart begins at a random unit vector and follows the
        subgradient with an adaptive step size rho (grown by 1/0.9 on
        success, shrunk by 0.9 on failure) until rho drops below 1e-5.
        Restarts continue until *max_evals* step evaluations are spent.
        The best unit-norm weight vector found is returned.
        """
        num_channels = self.X.shape[1]
        best_f_value = -numpy.inf
        best_v_1 = None
        evals = 0
        # Start several repetitions at random start states
        while True:
            # Initialize electrode weight vector randomly
            v_1 = numpy.random.random(num_channels)
            v_1 /= numpy.linalg.norm(v_1, 2)
            # Set initial learning rate
            rho = 1.0
            # Gradient ascent until we are very close to a local maximum
            while rho > 10**-5:
                # Some intermediate results
                a = numpy.dot(sigma_X, v_1)
                b = numpy.dot(v_1.T, a)
                c = numpy.dot(sigma_1, v_1)
                d = numpy.dot(v_1.T, c)
                # diag(sign(v)) @ ones == sign(v), without an n x n matrix
                e = numpy.sign(v_1) / numpy.linalg.norm(v_1, 2)
                f = (numpy.linalg.norm(v_1, 1)
                     / (numpy.dot(v_1.T, v_1)**1.5)) * v_1
                # Subgradient components
                sg1 = 2.0/b*(c - d/b*a)
                sg2 = e - f
                # Construct subgradient
                subgradient = sg1 - self.lambda_ * sg2
                # Search for a learning rate such that following the gradient
                # does not bring us too far ahead of the optimum
                v_1_old = numpy.array(v_1)
                old_f_value = objective_function(v_1)
                while True:
                    evals += 1
                    # Update and renormalize weight vector v
                    v_1 += rho * subgradient
                    v_1 /= numpy.linalg.norm(v_1, 2)
                    # Check if the current learning rate is too large
                    if objective_function(v_1) >= old_f_value:
                        # Not followed gradient too far, increase learning
                        # rate and break
                        rho /= 0.9
                        break
                    # Reduce learning rate and restore original v_1
                    rho *= 0.9
                    v_1 = numpy.array(v_1_old)
                    # If the learning rate becomes too low, we break
                    if rho < 10**-5:
                        break
                # Break if we have spent the allowed time searching the maximum
                if evals >= max_evals:
                    break
            # Check if we have found a new optimum in this repetition
            f_value = objective_function(v_1)
            if f_value > best_f_value:
                best_f_value = f_value
                best_v_1 = v_1
            # Return if we have spent the allowed time searching the maximum
            if evals >= max_evals:
                return best_v_1

    def _execute(self, data):
        """ Project the data onto the selected channels. """
        projected_data = data[:, self.selected_indices]
        return TimeSeries(projected_data, self.selected_channels,
                          data.sampling_frequency, data.start_time,
                          data.end_time, data.name, data.marker_name)

    def store_state(self, result_dir, index=None):
        """ Stores this node in the given directory *result_dir*.

        Writes the list of selected channel names as text to
        ``<result_dir>/<ClassName>/electrode_selection_sp<split>.txt``.
        """
        if self.store:
            node_dir = os.path.join(result_dir, self.__class__.__name__)
            create_directory(node_dir)
            # This node only stores which electrodes have been selected
            name = "%s_sp%s.txt" % ("electrode_selection", self.current_split)
            # "w": the former mode string "wi" is not a valid file mode.
            with open(os.path.join(node_dir, name), "w") as result_file:
                result_file.write(str(self.selected_channels))

    def get_filters(self):
        """ Not supported: Sparse xDAWN only selects electrodes. """
        raise NotImplementedError("Sparse xDAWN is yet not fitting for ranking "
                                  "electrode selection.")
class AXDAWNNode(XDAWNNode):
    """ Adaptive xDAWN spatial filter for enhancing event-related potentials.

    In general, the adaptive xDAWN algorithm works as the conventional xDAWN
    algorithm, but is adapted to be able to evolve over time.
    Therefore, instead of using the QR and SV decomposition, this node uses the
    generalized eigendecomposition to find the optimal filters.

    The methods are based on iteratively computing the generalized
    eigendecomposition (GED) with the algorithm from "Fast RLS-like algorithm
    for generalized eigendecomposition and its applications" (2004)
    by Yadunandana N. Rao, Jose C. Principe and Tan F. Wong.

    In general, this works as follows:

    - The noise and signal autocorrelation matrices are adapted with more
      incoming samples.
    - The inverse noise autocorrelation is updated.
    - The weight vectors (i.e. generalized eigenvectors) are updated.
    - These are used to get the actual filters.

    Optionally, update coefficients can be used for adapting the filter
    estimate.

    For using regularization techniques, the noise autocorrelation is
    initialized with the regularization matrix instead of using zeros.

    **References**

        ========= ==============================================================
        main      source: axDAWN
        ========= ==============================================================
        author    Woehrle, H. and Krell, M. M. and Straube, S. and Kim, S. U. and Kirchner, E. A. and Kirchner, F.
        title     `An Adaptive Spatial Filter for User-Independent Single Trial Detection of Event-Related Potentials <http://dx.doi.org/10.1109/TBME.2015.2402252>`_
        journal   IEEE Transactions on Biomedical Engineering
        publisher IEEE
        doi       10.1109/TBME.2015.2402252
        volume    62
        issue     7
        pages     1696-1705
        year      2015
        ========= ==============================================================

        ========= ==============================================================
        main      source: raxDAWN
        ========= ==============================================================
        author    Krell, M. M. and Seeland, A. and Woehrle, H.
        title     `raxDAWN: Circumventing Overfitting of the Adaptive xDAWN`
        book      Proceedings of the International Congress on Neurotechnology, Electronics and Informatics
        publisher SciTePress
        doi       10.5220/0005657500680075
        year      2015
        ========= ==============================================================

    **Parameters**

        :lambda_signal: Update coefficient for weighting
            old samples of the signal.

            (*optional, default: 1.0*)

        :lambda_noise: Forgetting factor for weighting old samples of the noise.

            (*optional, default: 1.0*)

        :comp_type: Type of computation.
            Either use iterative GED (*rls*) or the eigh function from scipy
            (*eig*). *eig* will not enable an iterative procedure and
            is just integrated for comparison with the original method and for
            testing the incremental approach. Depending on the scipy version,
            the :func:`scipy.linalg.eigh` function might raise an error or
            deliver unexpected results. Other values are rejected.

            (*optional, default: rls*)

        :delta: Factor for identity matrix in initialization of
            inverse correlation matrix.

            (*optional, default: 0.25*)

        :w_ini: Factor for random filter initialization

            (*optional, default: 0.01*)

        :regularization: Currently only *Tikhonov* regularization is
            implemented. By default no regularization is active using *False*.
            For the regularization, the *lambda_reg* parameter should be
            optimized.

            (*optional, default: False*)

        :lambda_reg:
            Positive regularization constant to weight between
            signal-plus-noise energy and chosen regularization term
            (see also the *regularization* parameter).
            Values between 100 and 1000 seem to be appropriate.
            Values below 1 won't have a real effect.
            This parameter should be roughly optimized, when used.

            (*optional, default: 100*)

    **Exemplary Call**

    .. code-block:: yaml

        -
            node : axDAWN
            parameters:
                erp_class_label : "Target"
                retained_channels : 32
                store : True
                lambda_signal : 0.99
                lambda_noise : 0.99
                lambda_reg : 100

    :Author: Hendrik Woehrle (hendrik.woehrle@dfki.de)
    :Created: 2012/05/25
    """
    def __init__(self,
                 comp_type="rls",
                 lambda_signal=1.0,
                 lambda_noise=1.0,
                 delta=0.25,
                 w_ini=0.01,
                 regularization=False,
                 lambda_reg=100,
                 **kwargs):
        super(AXDAWNNode, self).__init__(**kwargs)
        delta = float(delta)
        lambda_reg = float(lambda_reg)
        if not delta > 0:
            raise NotImplementedError("Delta < 0 is not supported.")
        if not lambda_reg > 0:
            raise NotImplementedError("Lambda_reg < 0 is not supported.")
        # Any other value would silently leave the filters at their random
        # initialization, so reject it like an unknown regularization.
        if comp_type not in ("rls", "eig"):
            raise NotImplementedError(
                "'%s' is not supported. Use 'rls' or 'eig'!" % comp_type)
        self.set_permanent_attributes(
            class_labels=[],
            lambda_signal=lambda_signal,
            lambda_noise=lambda_noise,
            # Weights used during the update; they stay 1.0 during the
            # initial training and switch to lambda_* in _stop_training.
            predict_lambda_signal=1.0,
            predict_lambda_noise=1.0,
            delta=delta,
            w_ini=w_ini,
            regularization=regularization,
            lambda_reg=lambda_reg,
            ai=None,
            R1=None,
            R2=None,
            R2inv=None,
            filters=None,
            num_noise=0,
            num_signals=0,
            comp_type=comp_type,
            num_train_items=0)

    def _check_retained_channels(self, num_channels):
        """ Make ``retained_channels`` an int in [1, *num_channels*].

        None or the string 'None' means all channels.
        """
        if self.retained_channels in [None, 'None']:
            self.retained_channels = num_channels
        else:
            self.retained_channels = int(self.retained_channels)
        if self.retained_channels > num_channels:
            self.retained_channels = num_channels
            self._log("To many channels chosen for the retained channels! "
                      "Replaced by maximum number.", level=logging.CRITICAL)
        elif self.retained_channels < 1:
            self.retained_channels = 1
            self._log("Too little channels chosen for the retained channels! "
                      "Replaced by minimum number (1).", level=logging.CRITICAL)

    def initialize_filters(self, data):
        """ Filter initialization which requires the first data sample.

        Does nothing once ``ai`` has been set. Otherwise it:

        * resolves ``retained_channels`` from the number of columns of
          *data*, which was previously used unresolved and crashed for None,
        * seeds numpy's global RNG with ``run_number``,
        * allocates ``ai``, ``R1``, ``R2``/``R2inv`` (the latter two
          depending on ``regularization``),
        * draws random initial weights ``wi`` (channels x retained_channels),
          which also serve as the initial ``filters``.

        :raises NotImplementedError: for an unknown ``regularization``.
        """
        if self.ai is not None:
            return
        num_channels = data.shape[1]
        self._check_retained_channels(num_channels)
        identity = numpy.eye(num_channels, num_channels)
        if not self.regularization:
            R2 = numpy.zeros((num_channels, num_channels))
            R2inv = self.delta * identity
        elif self.regularization == "Tikhonov":
            R2 = self.lambda_reg * identity
            R2inv = 1 / self.lambda_reg * identity
        else:
            raise NotImplementedError(
                "'%s' is not supported. Use 'Tikhonov' or False!"
                % self.regularization)
        numpy.random.seed(self.run_number)
        self.R2 = R2
        self.R2inv = R2inv
        self.ai = numpy.zeros(data.shape)
        self.R1 = numpy.zeros((self.retained_channels, num_channels,
                               num_channels))
        self.wi = self.w_ini * numpy.random.rand(num_channels,
                                                 self.retained_channels)
        self.filters = self.wi

    def _train(self, data, class_label):
        """ Incremental update procedure.

        This method is used for initial training and incremental training.
        ERP-class samples update the running signal estimate ``ai`` and
        ``R1[0] = ai^T ai``. Every sample updates the noise correlation
        ``R2`` and its inverse. Once at least one ERP sample was seen, the
        filters are recomputed according to ``comp_type``.
        """
        self.num_train_items += 1
        if class_label not in self.class_labels:
            self.class_labels.append(class_label)
        # Remember the channel names (needed e.g. for the plots in
        # store_state) before dropping to a plain array.
        if self.channel_names is None:
            self.channel_names = getattr(data, "channel_names", None)
        data = data.view(numpy.ndarray)
        self.initialize_filters(data)
        if class_label == self.erp_class_label:
            # a target => update signal estimation
            self.num_signals += 1
            self.ai = self.predict_lambda_signal * self.ai + \
                (data - self.ai) / self.num_signals
            self.R1[0] = numpy.dot(self.ai.T, self.ai)
        else:
            self.num_noise += 1
        # The noise estimation is updated with every sample
        self.adapt_inverse_noise_correlation(data)
        # we should have a first "target", before we really compute the weights
        if self.num_signals == 0:
            return
        if self.comp_type == "eig":
            self._update_filters_eig()
        elif self.comp_type == "rls":
            self._update_filters_rls()

    def _update_filters_eig(self):
        """ Filters from the generalized eigenproblem R1[0] v = d R2 v.

        Eigenvectors are stored as columns in descending eigenvalue order,
        the same layout _execute expects.
        """
        from scipy.linalg import eigh
        # ``scipy`` itself was never bound in this module, and eigh accepts
        # no ``right`` keyword; both made this path fail.
        D, V = eigh(self.R1[0], self.R2)
        D = D.real
        V = V.real
        # Eigenvectors sorted by descending eigenvalue
        order = numpy.argsort(D)[::-1]
        # Columns (not V.T): _execute computes dot(data, filters[:, :k]).
        self.filters = V[:, order]

    def _update_filters_rls(self):
        """ One step of the RLS-like GED algorithm of Rao and Principe.

        Updates each weight vector ``wi[:, i]`` (deflating R1 by the
        previous vectors) and derives ``filters`` scaled so that
        w^T R2 w = 1.
        """
        num_channels = self.R1[0].shape[1]
        identity = numpy.eye(num_channels)
        for i in range(self.retained_channels):
            if i > 0:
                # Deflate the signal correlation by the previous vector
                w_old = numpy.reshape(self.wi[:, i-1], (num_channels, 1))
                R_old = numpy.reshape(self.R1[i-1, :],
                                      (num_channels, num_channels))
                r_num = numpy.dot(R_old, numpy.dot(w_old, w_old.T))
                r_denom = numpy.dot(numpy.dot(w_old.T, R_old), w_old)
                R_new = numpy.dot(identity - r_num / r_denom, R_old)
                self.R1[i] = R_new
            else:
                R_new = self.R1[0]
            w_new = numpy.reshape(self.wi[:, i], (num_channels, 1))
            w_num = numpy.dot(numpy.dot(w_new.T, self.R2), w_new)
            w_denom = numpy.dot(numpy.dot(w_new.T, R_new), w_new)
            w_sol_w = numpy.dot(numpy.dot(self.R2inv, R_new), w_new)
            w_sol_scale = w_num/w_denom * w_sol_w
            w_norm = w_sol_scale / numpy.linalg.norm(w_sol_scale)
            self.wi[:, i] = w_norm[:, 0]
        denom_factors = \
            numpy.diag(numpy.dot(numpy.dot(self.wi.T, self.R2), self.wi))
        # Scale every weight vector by 1/sqrt(w^T R2 w).
        # NOTE(review): the column order is reversed relative to wi (the
        # first filter is the last computed weight vector), unlike the
        # descending order of the eig path; kept as is.
        self.filters = self.wi[:, ::-1] * numpy.sqrt(1 / denom_factors[::-1])

    def adapt_inverse_noise_correlation(self, data):
        """ Add the rows of *data* to ``R2`` and update ``R2inv``.

        For every row x: R2 <- l*R2 + x x^T with l = predict_lambda_noise,
        and R2inv is updated with the Sherman-Morrison formula instead of
        a full inversion.
        """
        Ri = self.R2inv
        lam = self.predict_lambda_noise
        for i in range(data.shape[0]):
            u = data[[i], :].T
            vt = u.T
            self.R2 = lam * self.R2 + numpy.dot(u, vt)
            Riu = numpy.dot(Ri, u)
            vtRi = numpy.dot(vt, Ri)
            denom = 1.0 + 1.0/lam * numpy.dot(vt, Riu)
            Ri = 1.0/lam * Ri - 1.0/lam**2 * numpy.dot(Riu, vtRi) / denom
        self.R2inv = Ri

    def store_state(self, result_dir, index=None):
        """ Stores this node in the given directory *result_dir* """
        if self.store:
            super(AXDAWNNode, self).store_state(result_dir)

    def _stop_training(self, debug=False):
        """ Switch to the configured forgetting factors for later updates. """
        self.predict_lambda_signal = self.lambda_signal
        self.predict_lambda_noise = self.lambda_noise

    def _inc_train(self, data, label):
        """ Incremental training uses the same update as training. """
        self._train(data, label)

    def _execute(self, data):
        """ Apply the learned spatial filters to the given data point.

        Filters are initialized first if no training sample was seen.
        Returns a TimeSeries with channels xDAWN000, xDAWN001, ...
        """
        self.initialize_filters(data)
        if self.channel_names is None:
            self.channel_names = data.channel_names
            self._check_retained_channels(len(self.channel_names))
        data_array = data.view(numpy.ndarray)
        # Project the data using the learned spatial filters
        projected_data = numpy.dot(data_array,
                                   self.filters[:, :self.retained_channels])
        if self.xDAWN_channel_names is None:
            self.xDAWN_channel_names = \
                ["xDAWN%03d" % i for i in range(self.retained_channels)]
        return TimeSeries(projected_data, self.xDAWN_channel_names,
                          data.sampling_frequency, data.start_time,
                          data.end_time, data.name, data.marker_name)
# Node names usable in node specifications, mapped to the classes above.
_NODE_MAPPING = dict(xDAWN=XDAWNNode,
                     axDAWN=AXDAWNNode,
                     Sparse_xDAWN=SparseXDAWNNode)