# Source code for pySPACE.missions.nodes.spatial_filtering.fda
""" Fisher's Discriminant Analysis and variants for spatial filtering """
import os
import cPickle
try:
import mdp
except:
pass
import numpy
from pySPACE.missions.nodes.spatial_filtering.spatial_filtering import SpatialFilteringNode
from pySPACE.resources.data_types.time_series import TimeSeries
from pySPACE.tools.filesystem import create_directory
import logging
class FDAFilterNode(SpatialFilteringNode):
    """ Reuse the implementation of Fisher's Discriminant Analysis provided by mdp

    This node implements the supervised Fisher's discriminant
    analysis algorithm for spatial filtering.

    **Parameters**

        :retained_channels: Determines how many of the FDA pseudo channels
            are retained. Default is None which means "all channels".

            (*optional, default: None*)

        :load_path: An absolute path from which the FDA filter
            is loaded.
            If not specified, this matrix is learned from the training data.

            (*optional, default: None*)

    **Exemplary Call**

    .. code-block:: yaml

        -
            node : FDAFilter
            parameters:
                retained_channels : 42

    :Author: Jan Hendrik Metzen (jhm@informatik.uni-bremen.de)
    :Created: 2010/02/17
    """

    def __init__(self, retained_channels=None, load_path=None, **kwargs):
        """ Set up the node; load pickled filters if *load_path* is given

        The node is trainable (and supervised) exactly when no *load_path*
        is given, i.e. when the filters have to be learned from data.
        """
        # Must be set before the constructor of the superclass is called
        self.trainable = (load_path is None)

        super(FDAFilterNode, self).__init__(**kwargs)

        # Load filters from file if requested
        filters = None
        if load_path is not None:
            # Pickles are binary data, so open in binary mode and make sure
            # the file handle is closed again.
            # NOTE(review): unpickling can execute arbitrary code; only use
            # load paths that point to trusted files.
            with open(load_path, 'rb') as filters_file:
                filters = cPickle.load(filters_file)

        self.set_permanent_attributes(
            trainable=self.trainable,
            # The number of channels that will be retained
            retained_channels=retained_channels,
            # Gather all data instances passed during training
            data=None,
            # Remember the classes of the data
            labels=None,
            # After training is finished, this node will contain
            # a projection matrix that is used to project
            # the data onto a lower dimensional subspace
            filters=filters,
            new_channel_names=None,
            channel_names=None
        )

    def is_trainable(self):
        """ Returns whether this node is trainable. """
        return self.trainable

    def is_supervised(self):
        """ Returns whether this node requires supervised training """
        return self.trainable

    def _train(self, data, label):
        """ Remember *data* and *label* for later learning of filters.

        Every row of *data* is stored as one training sample and gets
        *label* assigned. The filters are computed in `_stop_training`.
        """
        if self.channel_names is None:
            self.channel_names = data.channel_names

        # One label per row of the current data instance
        new_labels = [label] * data.shape[0]

        # Simply gather all data and do the actual training in _stop_training.
        # The identity check is required here: comparing an array with
        # ``== None`` is evaluated element-wise by numpy and has no
        # unambiguous truth value.
        if self.data is None:
            # Multiplication with 1.0 yields a floating point copy
            self.data = 1.0 * data
            self.labels = new_labels
            self.channel_names = data.channel_names
        else:
            self.data = 1.0 * numpy.vstack([self.data, data])
            self.labels.extend(new_labels)

    def _stop_training(self, debug=False):
        """ Use the collected data to learn the transformation matrix

        The resulting matrix (attribute ``v`` of the mdp node) is stored
        in ``self.filters``.
        """
        fda_node = mdp.nodes.FDANode()
        labels = numpy.array(self.labels)
        # mdp's FDANode has two training phases. When it is trained by hand,
        # train/stop_training has to be called once per phase, so the
        # repetition is intentional.
        for _ in range(2):
            fda_node.train(self.data, labels)
            fda_node.stop_training()

        self.filters = fda_node.v

    def _execute(self, data):
        """ Execute learned transformation on *data*.

        Projects *data* onto the first ``retained_channels`` columns of
        ``self.filters`` and returns the result as a new `TimeSeries`
        with channel names ``fda000``, ``fda001``, ...
        """
        # We must have computed (or loaded) the projection matrix.
        # Identity check, since ``array != None`` is element-wise in numpy.
        assert self.filters is not None

        if self.retained_channels is None:
            self.retained_channels = data.shape[1]
        if self.channel_names is None:
            self.channel_names = data.channel_names
        if len(self.channel_names) < self.retained_channels:
            # More channels requested than available: fall back to all
            self.retained_channels = len(self.channel_names)
            self._log("To many channels chosen for the retained channels! "
                      "Replaced by maximum number.",
                      level=logging.CRITICAL)

        # Project the data using the learned FDA
        projected_data = numpy.dot(data,
                                   self.filters[:, :self.retained_channels])

        # The channel names are generated once and reused afterwards
        if self.new_channel_names is None:
            self.new_channel_names = ["fda%03d" % i
                                      for i in range(self.retained_channels)]

        return TimeSeries(projected_data, self.new_channel_names,
                          data.sampling_frequency, data.start_time,
                          data.end_time, data.name, data.marker_name)

    def store_state(self, result_dir, index=None):
        """ Stores the projection in the given directory *result_dir*

        Nothing is written unless ``self.store`` is true. The filter matrix
        is pickled to ``<result_dir>/<class name>/projection_sp<split>.pickle``.
        *index* is not used by this node.
        """
        if self.store:
            node_dir = os.path.join(result_dir, self.__class__.__name__)
            create_directory(node_dir)
            name = "%s_sp%s.pickle" % ("projection", self.current_split)
            # The projection matrix of this node is kept in ``self.filters``
            # (this class never sets a ``projection`` attribute). Storing it
            # produces a file that can be passed back in as *load_path*.
            with open(os.path.join(node_dir, name), "wb") as result_file:
                result_file.write(cPickle.dumps(self.filters, protocol=2))
# Both names are aliases for the same node class, e.g. ``node : FDAFilter``
# as in the exemplary call of the class documentation.
_NODE_MAPPING = dict.fromkeys(("FDAFilter", "FDA"), FDAFilterNode)