# Source code for pySPACE.missions.nodes.splitter.transfer_splitter

# This Python file uses the following encoding: utf-8
""" Splits data into training and test data

.. todo:: Divide the node into splitting and data set filtering node.
"""
import random

from pySPACE.missions.nodes.base_node import BaseNode
from pySPACE.tools.memoize_generator import MemoizeGenerator

class TransferSplitterNode(BaseNode):
    """ Allow to split data into training and test data sets according to different window definitions

    Splits the available data into disjunct training and test sets. The
    transfer of different training and test window definitions is supported.
    The node was implemented with several use cases in mind:

    - The training set contains instances of 'Standard' and 'Target'
      stimuli but the test set of 'Target' and 'MissedTarget' stimuli.
    - The training set contains instances of 'LRP' with different training
      times and 'NoLRPs', but the test set should contain sliding windows.
      Cross validation should be supported to use the node together with
      parameter optimization node.
    - The use of merged data sets should be possible.

    **Parameters**

        :wdefs_train:
            A list with window definition names (specified in the window
            spec file when the raw data was segmented). All windows that
            belong to one of the window definition are considered when the
            training set(s) is(/are) determined.

        :wdefs_test:
            A list with window definition names (specified in the window
            spec file when the raw data was segmented). All windows that
            belong to one of the window definition are considered when the
            testing set(s) is(/are) determined.

        :split_method:
            One of the following Strings: 'all_data', 'time', 'count',
            'set_flag'.

            - all_data :
                All possible data is used in every split. This results in
                splitting only window definitions that occur in both,
                *wdefs_train* AND *wdefs_test*. Window definitions that
                only occur in either *wdefs_train* or *wdefs_test* are
                retained in every split.
            - time :
                The data is sorted and split according to time. For that
                (*start_time* of last window - *start_time* of first
                window)/*nr_of_splits* is determined. Since time in eeg
                data is relative for every set, ensure that each input
                collection consists only of one data set (is not a merge
                of several sets) or that the change_option has been used.
            - count :
                The data is split according to *num_split_instances*. By
                default only windows specified in both, *wdefs_train* and
                *wdefs_test*, are count. With the parameter *wdefs_split*
                window definition that are count can be specified. If
                *num_split_instances* is not specified, *splits*
                determines how many instances of *wdefs_split* are in one
                split.
            - set_flag :
                When the data has been merged with the concatenate
                operation before, a flag 'new_set' has been inserted to
                the time series specs. Splits are based on this flag,
                i.e. the splits behave like a inter-set cross validation.
                For example you merged 3 sets: 'A', 'B', 'C', then there
                are 3 splits generated: 'A'+'B' vs 'C', 'A'+'C' vs 'B'
                and 'B'+'C' vs 'A'.
            - set_flag_reverse :
                When the data has been merged with the concatenate
                operation before, a flag 'new_set' has been inserted to
                the time series specs. Splits are based on this flag,
                i.e. the splits behave like a reverse-inter-set cross
                validation. For example you merged 3 sets: 'A', 'B', 'C',
                then there are 3 splits generated: 'A' vs 'C'+'B',
                'B' vs 'A'+'C' and 'C' vs 'A'+'B'.

        :random:
            If True, the data is randomized before splitting.

            .. note:: It is not guaranteed that overlapping windows will
                      be in the same split for split methods 'time' and
                      'all_data'!

            (*optional, default: False*)

        :splits:
            The number of splits created internally and the number of
            train-test pairs.

            (*optional, default: 10*)

        :num_split_instances:
            If *split_method* is 'count', *num_split_instances* specifies
            how many instances will be in one split. After splitting one
            split is evaluated according to *wdefs_test* for the test data
            set and the remaining splits according to *wdefs_train*. The
            test split is iterated. If the total number of instances that
            are count is not divisible by *num_split_instances* the last
            split will contain the remaining instances. If in addition
            *splits* is set to 1, only one train-test pair is created with
            *num_split_instances* in the training set.

            (*optional, default: None*)

        :wdefs_split:
            A list with window definition names (specified in the window
            spec file when the raw data was segmented). All windows that
            belong to one of the window definition are counted when
            *split_method* was set to 'count'.

            (*optional, default: None*)

        :reverse:
            If this option is True, the data is split in reverse ordering.

            (*optional, default: False*)

    **Exemplary Call**

    .. code-block:: yaml

        -
            node : TransferSplitter
            parameters :
                 wdefs_train : ['s2', 's1']
                 wdefs_test : ['s5', 's2']
                 split_method : "all_data"
                 splits : 5

    :Author: Anett Seeland (anett.seeland@dfki.de)
    :Created: 2011/04/10
    :LastChange: 2011/11/14 (traintest functionality)

    .. todo:: Divide the node into splitting and data set filtering node.
    """
[docs] def __init__(self, wdefs_train, wdefs_test, split_method, wdefs_train_test = None, splits=10, random=False, num_split_instances=None, wdefs_split=None, reverse=False, sort=False, *args, **kwargs): super(TransferSplitterNode, self).__init__(*args, **kwargs) if wdefs_train_test == None: wdefs_train_test = [wdef for wdef in \ wdefs_train if wdef in wdefs_test], self.set_permanent_attributes(wdefs_train = wdefs_train, wdefs_test = wdefs_test, split_method = split_method, splits = splits, random = random, num_split_instances = num_split_instances, wdefs_split = wdefs_split, reverse = reverse, sort = sort, current_split = 0, wdefs_train_test = wdefs_train_test, split_indices_train = None, split_indices_test = None)
[docs] def is_split_node(self): """ Returns whether this is a split node. """ return True
[docs] def use_next_split(self): """ Use the next split of the data into training and test data. Returns True if more splits are available, otherwise False. This method is useful for benchmarking """ if self.current_split + 1 < self.splits: self.current_split = self.current_split + 1 self._log("Benchmarking with split %s/%s" % (self.current_split + 1, self.splits)) return True else: return False
[docs] def train_sweep(self, use_test_data): """ Performs the actual training of the node. .. note:: Split nodes cannot be trained """ raise Exception("Split nodes cannot be trained")
[docs] def request_data_for_training(self, use_test_data): # Create split lazily when required if self.split_indices_train == None: self._create_split() # Create training data generator self.data_for_training = MemoizeGenerator( self.data[i] for i in self.split_indices_train[self.current_split]) return self.data_for_training.fresh()
[docs] def request_data_for_testing(self): # Create split lazily when required if self.split_indices_test == None: self._create_split() # Create test data generator self.data_for_testing = MemoizeGenerator( self.data[i] for i in self.split_indices_test[self.current_split]) return self.data_for_testing.fresh()
    def _create_split(self):
        """ Create the split of the data into training and test data.

        Fills ``self.split_indices_train`` and ``self.split_indices_test``
        (one index list per split, indexing into ``self.data``) according
        to ``self.split_method``. May reduce ``self.splits`` when fewer
        instances than requested splits are available.
        """
        self._log("Splitting data into train and test data")
        # Get training and test data.
        # note: return the data in a list can double the memory requirements!
        train_data = list(self.input_node.request_data_for_training(
                                                        use_test_data=False))
        test_data = list(self.input_node.request_data_for_testing())
        # If there is already a non-empty training set,
        # it means that we are not the first split node in the node chain.
        if len(train_data) > 0:
            if len(test_data) == 0:
                # If there was an All_Train_Splitter before, filter according
                # to wdef_train and return all training data
                self.split_indices_train = \
                    [[ind for ind, (win, lab) in enumerate(train_data)
                      if win.specs['wdef_name'] in self.wdefs_train]]
                self.split_indices_test = [[]]
                self.splits = 1
                self.data = train_data
                self._log("Using all data for training.")
                return
            else:
                raise Exception("No iterated splitting of data sets allowed\n "
                                "(Calling a splitter on a data set that is already "
                                "splitted)")
        # Remember all the data and store it in memory
        # TODO: This might cause problems for large dataset
        self.data = train_data + test_data
        del train_data, test_data
        if self.reverse:
            self.data = self.data[::-1]
        # sort the data according to the start time; the 'time' method
        # relies on chronological ordering, so it always sorts
        if self.sort or self.split_method == 'time':
            self.data.sort(key=lambda swindow: swindow[0].start_time)
        # randomize the data if needed
        if self.random:
            # seed with the run number so every run is reproducible
            r = random.Random(self.run_number)
            if self.split_method == 'set_flag':
                # shuffling would destroy the set boundaries the 'set_flag'
                # method depends on, so randomization is disabled
                self.random = False
                # TODO: log this
            elif self.split_method == 'count':
                if self.wdefs_split == None:
                    self.wdefs_split = self.wdefs_train_test
                # divide the data with respect to the time: group windows
                # into "events" so that overlapping (sliding) windows stay
                # together and only whole events are shuffled
                data_time = dict()
                marker = -1
                last_window_endtime = 0
                for ind, (win, lab) in enumerate(self.data):
                    if win.start_time < last_window_endtime:
                        # overlapping windows or start of a new set
                        # (start times restart when sets were merged)
                        if win.end_time < last_window_endtime:
                            # new set
                            marker += 1
                            data_time[marker] = [(win, lab)]
                        else:
                            # overlapping windows
                            data_time[marker].append((win, lab))
                    else:
                        marker += 1
                        data_time[marker] = [(win, lab)]
                    last_window_endtime = win.end_time
                # randomize order of events by simultaneously keep the
                # order of sliding windows in each event
                # NOTE(review): shuffling dict.values() in place is
                # Python-2-only (Python 3 returns a view) — confirm the
                # supported interpreter before porting.
                data_random = data_time.values()
                r.shuffle(data_random)
                self.data = []
                for l in data_random:
                    self.data.extend(l)
                del data_random, data_time, l
            else:
                r.shuffle(self.data)
        if self.split_method == 'all_data':
            # divide the data with respect to *wdef_train*, *wdef_test* and
            # *wdef_train_test*
            wdef_data = {'wdef_train_test': [], 'wdef_train': [],
                         'wdef_test': []}
            class_labels = []
            for (index, (window, label)) in enumerate(self.data):
                if window.specs['wdef_name'] in self.wdefs_train_test:
                    wdef_data['wdef_train_test'].append(index)
                    if label not in class_labels:
                        class_labels.append(label)
                elif window.specs['wdef_name'] in self.wdefs_train:
                    wdef_data['wdef_train'].append(index)
                elif window.specs['wdef_name'] in self.wdefs_test:
                    wdef_data['wdef_test'].append(index)
                else:
                    import warnings
                    warnings.warn("Found window definition %s, which is " \
                                  "neither in *wdefs_train* nor in " \
                                  "*wdefs_test*. Window %s will be ignored!" \
                                  % (window.specs['wdef_name'], window.tag))
            # check if splitting makes sense
            if wdef_data['wdef_train_test'] == [] and self.splits > 1:
                raise Exception('No instances to split, i.e train-test window'\
                                ' definitions are disjunct!')
            split_indices_train = [[] for i in range(self.splits)]
            split_indices_test = [[] for i in range(self.splits)]
            # calculate splits
            if wdef_data['wdef_train_test'] != []:
                data_size = len(wdef_data['wdef_train_test'])
                # ensure stratified splits if there are several classes
                if len(class_labels) > 1:
                    # divide the data with respect to the class_label
                    # NOTE(review): has_key/iteritems below are Python 2
                    # only.
                    data_labeled = dict()
                    for index in wdef_data['wdef_train_test']:
                        if not data_labeled.has_key(self.data[index][1]):
                            data_labeled[self.data[index][1]] = [index]
                        else:
                            data_labeled[self.data[index][1]].append(index)
                    # have not more splits than instances of every class!
                    min_nr_per_class = min([len(data) for data in
                                            data_labeled.values()])
                    if self.splits > min_nr_per_class:
                        self.splits = min_nr_per_class
                        self._log("Reducing number of splits to %s since no " \
                                  "more instances of one of the classes are " \
                                  "available." % self.splits)
                    # determine the splits of the data: per class, cut the
                    # index list into self.splits contiguous chunks; the
                    # chunk becomes (part of) the test fold, the rest train
                    for label, indices in data_labeled.iteritems():
                        data_size = len(indices)
                        for j in range(self.splits):
                            split_start = \
                                int(round(float(j)*data_size/self.splits))
                            split_end = \
                                int(round(float(j+1)*data_size/self.splits))
                            split_indices_test[j].extend(
                                [i for i in indices[split_start:split_end]
                                 if self.data[i][0].specs['wdef_name']
                                 in self.wdefs_test])
                            split_indices_train[j].extend(
                                [i for i in indices
                                 if i not in split_indices_test[j]])
                else:  # len(class_labels) == 1
                    # have not more splits than instances!
                    if self.splits > data_size:
                        self.splits = data_size
                        self._log("Reducing number of splits to %s since no " \
                                  "more instances of one of the classes are " \
                                  "available." % self.splits)
                    # determine the splits of the data
                    for j in range(self.splits):
                        split_start = \
                            int(round(float(j)*data_size/self.splits))
                        split_end = \
                            int(round(float(j+1)*data_size/self.splits))
                        # means half-open interval [split_start, split_end)
                        split_indices_test[j].extend(
                            wdef_data['wdef_train_test'][split_start:split_end])
                        split_indices_train[j].extend(
                            [i for i in wdef_data['wdef_train_test']
                             if i not in split_indices_test[j]])
            # windows occurring only in wdefs_train / wdefs_test are
            # retained in every split
            for i in range(self.splits):
                split_indices_train[i].extend(wdef_data['wdef_train'])
                split_indices_test[i].extend(wdef_data['wdef_test'])
        elif self.split_method == 'time':
            first_window_start = self.data[0][0].start_time
            last_window_start = self.data[-1][0].start_time
            # ensure, that time can never be greater than self.splits*time!
            time = round((last_window_start-first_window_start)/self.splits+0.5)
            # divide the data according to the time
            data_time = {0: []}
            time_fold = 0
            for (index, (window, label)) in enumerate(self.data):
                if window.start_time > time_fold*time+time:
                    time_fold += 1
                    data_time[time_fold] = [index]
                else:
                    data_time[time_fold].append(index)
            split_indices_train = [[] for i in range(self.splits)]
            split_indices_test = [[] for i in range(self.splits)]
            # fold i tests on time slice i and trains on all other slices
            for i in range(self.splits):
                split_indices_test[i].extend(
                    [index for index in data_time[i]
                     if self.data[index][0].specs['wdef_name']
                     in self.wdefs_test])
                for j in range(self.splits):
                    split_indices_train[i].extend(
                        [index for index in data_time[j]
                         if j != i and self.data[index][0].specs['wdef_name']
                         in self.wdefs_train])
        elif self.split_method == 'count':
            if self.wdefs_split == None:
                self.wdefs_split = self.wdefs_train_test
            if self.num_split_instances == None:
                # l = total number of counted windows; num_split_instances
                # is derived so that *splits* folds result
                l = len([ind for ind, (win, lab)
                         in enumerate(self.data) if win.specs['wdef_name']
                         in self.wdefs_split])
                self.num_split_instances = round(float(l)/self.splits)
            # divide the data according to *num_split_instances*
            data_count = {0: []}
            count = -1
            count_fold = 0
            if self.splits == 1 and len([i for i in range(len(self.data))
                    if self.data[i][0].specs['wdef_name'] in self.wdefs_split])\
                    == self.num_split_instances:
                # all counted windows fit into the single training split
                train_end = len(self.data)
            else:
                for (ind, (win, lab)) in enumerate(self.data):
                    #print ind, win.specs['wdef_name'], lab
                    if win.specs['wdef_name'] in self.wdefs_split:
                        count += 1
                        if self.splits == 1 and \
                                count == self.num_split_instances:
                            train_end = ind
                            break
                        if count != 0 and count % self.num_split_instances == 0:
                            count_fold += 1
                            data_count[count_fold] = [ind]
                        else:
                            data_count[count_fold].append(ind)
                    else:
                        # not-counted windows stay attached to the current
                        # fold
                        data_count[count_fold].append(ind)
            if self.splits != 1:
                # self.num_split_instances*self.splits < l, but in the case
                # when only num_split_instances is specified we can not
                # trust self.splits
                # NOTE(review): `l` is only bound when num_split_instances
                # was None above; if it was given explicitly and the first
                # condition is False, this comparison raises NameError —
                # confirm intended behavior.
                if len(data_count.keys()) == self.splits+1 or \
                        (len(data_count.keys())-1)*self.num_split_instances > l:
                    # merge the undersized remainder fold into its
                    # predecessor
                    data_count[count_fold-1].extend(data_count[count_fold])
                    del data_count[count_fold]
                self.splits = len(data_count.keys())
                split_indices_train = [[] for i in range(self.splits)]
                split_indices_test = [[] for i in range(self.splits)]
                # fold i tests on chunk i and trains on all other chunks
                for i in range(self.splits):
                    split_indices_test[i].extend(
                        [ind for ind in data_count[i]
                         if self.data[ind][0].specs['wdef_name']
                         in self.wdefs_test])
                    for j in range(self.splits):
                        split_indices_train[i].extend(
                            [ind for ind in data_count[j]
                             if j != i and
                             self.data[ind][0].specs['wdef_name']
                             in self.wdefs_train])
            else:  # self.splits == 1
                # single train-test pair: everything before train_end is
                # training material, everything after is test material
                split_indices_train = \
                    [[ind for ind in range(len(self.data[:train_end])) if
                      self.data[ind][0].specs['wdef_name'] in self.wdefs_train]]
                split_indices_test = \
                    [[ind for ind in range(train_end, len(self.data)) if
                      self.data[ind][0].specs['wdef_name'] in self.wdefs_test]]
        elif self.split_method == 'set_flag':
            # divide the data according to *new_set* flag in time series
            # specs
            data_set = {0: []}
            key_fold = 0
            for (ind, (win, lab)) in enumerate(self.data):
                if win.specs['new_set']:
                    key_fold += 1
                    data_set[key_fold] = [ind]
                else:
                    data_set[key_fold].append(ind)
            # one split per merged set: test on set i, train on the rest
            self.splits = len(data_set.keys())
            split_indices_train = [[] for i in range(self.splits)]
            split_indices_test = [[] for i in range(self.splits)]
            for i in range(self.splits):
                split_indices_test[i].extend(
                    [ind for ind in data_set[i]
                     if self.data[ind][0].specs['wdef_name']
                     in self.wdefs_test])
                for j in range(self.splits):
                    split_indices_train[i].extend(
                        [ind for ind in data_set[j]
                         if j != i and self.data[ind][0].specs['wdef_name']
                         in self.wdefs_train])
        elif self.split_method == 'set_flag_reverse':
            # divide the data according to *new_set* flag in time series
            # specs
            data_set = {0: []}
            key_fold = 0
            for (ind, (win, lab)) in enumerate(self.data):
                if win.specs['new_set']:
                    key_fold += 1
                    data_set[key_fold] = [ind]
                else:
                    data_set[key_fold].append(ind)
            # reverse variant: train on set i, test on all other sets
            self.splits = len(data_set.keys())
            split_indices_train = [[] for i in range(self.splits)]
            split_indices_test = [[] for i in range(self.splits)]
            for i in range(self.splits):
                split_indices_train[i].extend(
                    [ind for ind in data_set[i]
                     if self.data[ind][0].specs['wdef_name']
                     in self.wdefs_train])
                for j in range(self.splits):
                    split_indices_test[i].extend(
                        [ind for ind in data_set[j]
                         if j != i and self.data[ind][0].specs['wdef_name']
                         in self.wdefs_test])
        self.split_indices_train = split_indices_train
        self.split_indices_test = split_indices_test
        self._log("Benchmarking with split %s/%s" % (self.current_split + 1,
                                                     self.splits))