
""" The module that trains pyspace flows.
"""
import os
import glob
import multiprocessing
import time
import shutil
import logging
import yaml
import datetime
import re

from pySPACE.environments.chains.node_chain import NodeChain, NodeChainFactory
from pySPACE.resources.dataset_defs.base import BaseDataset
from pySPACE.resources.dataset_defs.time_series import TimeSeriesDataset
from pySPACE.environments.live import eeg_stream_manager

online_logger = logging.getLogger("OnlineLogger")


class LiveTrainer(object):
    """ Performs all tasks of pySPACE live that are related to the training process.

    The trained flows are stored in the ``flow_storage`` directory.
    """

    def __init__(self,
                 flow_storage="flow_storage",
                 prewindowed_data_directory="prewindowed_data_storage"):
        self.training_active_potential = {}
        #: path to the storage location for node_chain definitions and pickles
        self.flow_storage = flow_storage
        #: path to the storage location for prewindowed data
        self.prewindowed_data_directory = prewindowed_data_directory
        #: stores the node_chain definitions as dictionaries
        self.node_chain_definitions = {}
        #: stores the executable node_chains
        self.node_chains = {}
        self.train_process = {}
        self.prewindowed_data = {}
        self.queue = {}
        self.data_stream_process = {}
        self.window_stream = {}
        self.target_shown = {}
        self.last_target_data = {}
        self.marker_windower = {}
        self.training_paused_potential = multiprocessing.Value('b', False)
        self.nullmarker_stride_ms = None

    def set_controller(self, controller):
        """ Set the reference to the calling controller """
        self.controller = controller

    def set_eeg_stream_manager(self, stream_manager):
        """ Set the stream manager that provides the training data """
        self.stream_manager = stream_manager

    def prepare_training(self, training_files, potentials, operation,
                         nullmarker_stride_ms=None):
        """ Prepares pySPACE live for training.

        Prepares everything for the training of pySPACE live,
        i.e. creates flows based on the dataflow specs
        and configures them.
        """
        online_logger.info("Preparing Training")
        self.potentials = potentials
        self.operation = operation
        self.nullmarker_stride_ms = nullmarker_stride_ms
        if self.nullmarker_stride_ms is None:
            online_logger.warn('Nullmarker stride interval is %s. '
                               'You can specify it in your parameter file.'
                               % self.nullmarker_stride_ms)
        else:
            online_logger.info('Nullmarker stride interval is set to %s ms'
                               % self.nullmarker_stride_ms)

        online_logger.info("Creating flows..")
        for key in self.potentials.keys():
            spec_base = self.potentials[key]["configuration"].spec_dir
            if self.operation == "train":
                self.potentials[key]["node_chain"] = \
                    os.path.join(spec_base, self.potentials[key]["node_chain"])
                online_logger.info("node_chain_spec: "
                                   + self.potentials[key]["node_chain"])
            elif self.operation in ("prewindowing", "prewindowing_offline"):
                self.potentials[key]["prewindowing_flow"] = \
                    os.path.join(spec_base, self.potentials[key]["prewindowing_flow"])
                online_logger.info("prewindowing_dataflow_spec: "
                                   + self.potentials[key]["prewindowing_flow"])
            elif self.operation == "prewindowed_train":
                self.potentials[key]["postprocess_flow"] = \
                    os.path.join(spec_base, self.potentials[key]["postprocess_flow"])
                online_logger.info("postprocessing_dataflow_spec: "
                                   + self.potentials[key]["postprocess_flow"])

            self.training_active_potential[key] = multiprocessing.Value("b", False)

        online_logger.info("Path variables set for NodeChains")

        # check if multiple potentials are given for training
        if isinstance(training_files, list):
            self.training_data = training_files
        else:
            self.training_data = [training_files]

        # Training is done in separate processes; the time series windows are
        # sent to these processes via one queue per potential.
        online_logger.info("Initializing Queues")
        for key in self.potentials.keys():
            self.queue[key] = multiprocessing.Queue()

        def flow_generator(key):
            """ Create a generator that yields all flow windows """
            # Yield all windows until a None item is found in the queue
            while True:
                window = self.queue[key].get(block=True, timeout=None)
                if window is None:
                    break
                yield window

        # Create the actual data flows
        for key in self.potentials.keys():
            if self.operation == "train":
                self.node_chains[key] = NodeChainFactory.flow_from_yaml(
                    Flow_Class=NodeChain,
                    flow_spec=open(self.potentials[key]["node_chain"]))
                self.node_chains[key][0].set_generator(flow_generator(key))
                flow = open(self.potentials[key]["node_chain"])
            elif self.operation in ("prewindowing", "prewindowing_offline"):
                online_logger.info("loading prewindowing flow..")
                online_logger.info("file: " + str(self.potentials[key]["prewindowing_flow"]))
                self.node_chains[key] = NodeChainFactory.flow_from_yaml(
                    Flow_Class=NodeChain,
                    flow_spec=open(self.potentials[key]["prewindowing_flow"]))
                self.node_chains[key][0].set_generator(flow_generator(key))
                flow = open(self.potentials[key]["prewindowing_flow"])
            elif self.operation == "prewindowed_train":
                self.node_chains[key] = NodeChainFactory.flow_from_yaml(
                    Flow_Class=NodeChain,
                    flow_spec=open(self.potentials[key]["postprocess_flow"]))
                replace_start_and_end_markers = False

                final_collection = TimeSeriesDataset()
                final_collection_path = os.path.join(
                    self.prewindowed_data_directory, key, "all_train_data")
                # delete the previous training collection
                if os.path.exists(final_collection_path):
                    online_logger.info("deleting old training data collection for " + key)
                    shutil.rmtree(final_collection_path)

                # load all prewindowed collections and
                # append their data to the final collection
                prewindowed_sets = \
                    glob.glob(os.path.join(self.prewindowed_data_directory, key, "*"))
                if len(prewindowed_sets) == 0:
                    online_logger.error(
                        "Could not find any data, please do the prewindowing first!")
                    raise Exception("no prewindowed data found for " + key)

                online_logger.info("concatenating prewindowed data from "
                                   + str(prewindowed_sets))

                for set_index, dataset_dir in enumerate(prewindowed_sets):
                    collection = BaseDataset.load(dataset_dir)
                    data = collection.get_data(0, 0, "train")
                    for sample_index, (sample, label) in enumerate(data):
                        if replace_start_and_end_markers:
                            # in case we concatenate multiple 'Window'-labeled
                            # sets we have to remove every start and end marker
                            for k in sample.marker_name.keys():
                                # find 'S 8' or 'S 9' (case-insensitive,
                                # up to two spaces)
                                m = re.match(r"^s\s{0,2}[89]$", k, re.IGNORECASE)
                                if m is not None:
                                    online_logger.info(str("remove %s from %d %d"
                                        % (m.group(), set_index, sample_index)))
                                    del sample.marker_name[m.group()]

                            if set_index == len(prewindowed_sets) - 1 and \
                                    sample_index == len(data) - 1:
                                # insert the end marker
                                sample.marker_name["S 9"] = [0.0]
                                online_logger.info("added endmarker "
                                    + str(set_index) + " " + str(sample_index))

                            if set_index == 0 and sample_index == 0:
                                # insert the start marker
                                sample.marker_name["S 8"] = [0.0]
                                online_logger.info("added startmarker "
                                    + str(set_index) + " " + str(sample_index))

                        final_collection.add_sample(sample, label, True)

                # save the final collection (just for debugging)
                os.mkdir(final_collection_path)
                final_collection.store(final_collection_path)
                online_logger.info("stored final collection at " + final_collection_path)

                # load the final collection again for training
                online_logger.info("loading data from " + final_collection_path)
                self.prewindowed_data[key] = BaseDataset.load(final_collection_path)
                self.node_chains[key][0].set_input_dataset(self.prewindowed_data[key])

                flow = open(self.potentials[key]["postprocess_flow"])

            # create a window stream for every potential
            if self.operation == "prewindowing":
                window_spec_file = os.path.join(
                    spec_base, "node_chains", "windower",
                    self.potentials[key]["windower_spec_path_train"])
                self.window_stream[key] = \
                    self.stream_manager.request_window_stream(
                        window_spec_file,
                        nullmarker_stride_ms=self.nullmarker_stride_ms)
            elif self.operation in ("prewindowing_offline", "train"):
                # window streams are created per dataset in stream_data()
                pass

            self.node_chain_definitions[key] = yaml.load(flow)
            flow.close()

        # TODO: check if the prewindowing flow is still needed when using the stream mode!
        if self.operation == "train":
            online_logger.info("Removing old flows...")
            try:
                shutil.rmtree(self.flow_storage)
            except OSError:
                online_logger.info("Could not delete flow storage directory")
            os.mkdir(self.flow_storage)
        elif self.operation in ("prewindowing", "prewindowing_offline"):
            # follow this policy:
            # - delete prewindowed data older than 12 hours
            # - always delete trained/stored flows
            now = datetime.datetime.now()
            then = now - datetime.timedelta(hours=12)

            if not os.path.exists(self.prewindowed_data_directory):
                os.mkdir(self.prewindowed_data_directory)
            if not os.path.exists(self.flow_storage):
                os.mkdir(self.flow_storage)

            for key in self.potentials.keys():
                found = self.find_files_older_than(
                    then, os.path.join(self.prewindowed_data_directory, key))
                if found is not None:
                    for f in found:
                        online_logger.info(str("recursively deleting files in '%s'" % f))
                        try:
                            shutil.rmtree(os.path.abspath(f))
                        except Exception:
                            # TODO: find a smart solution for this!
                            pass  # the directory was probably already deleted

                if os.path.exists(os.path.join(
                        self.prewindowed_data_directory, key, "all_train_data")):
                    shutil.rmtree(os.path.join(
                        self.prewindowed_data_directory, key, "all_train_data"))
                    online_logger.info("deleted concatenated training data for " + key)

        online_logger.info("Training preparations finished")
        return 0
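
    # Illustrative only: a minimal sketch of the per-potential dictionary that
    # prepare_training() consumes. Every key below is read somewhere in this
    # class, but the concrete values (potential name, file names, marker
    # labels) are hypothetical placeholders, not pySPACE defaults.
    #
    #   potentials = {
    #       "P300": {
    #           "configuration": conf,  # object exposing a .spec_dir attribute
    #           "node_chain": "p300_flow.yaml",               # operation "train"
    #           "prewindowing_flow": "p300_prewin.yaml",      # "prewindowing(_offline)"
    #           "postprocess_flow": "p300_postprocess.yaml",  # "prewindowed_train"
    #           "windower_spec_path_train": "p300_windower.yaml",
    #           "positive_event": "S 16",   # marker label of target windows
    #           "negative_event": "S 48",   # marker label of standard windows
    #           # optional: "trigger_event", "activation_label",
    #           # "deactivation_label", "ignore_num_first_examples"
    #       },
    #   }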

    def find_files_older_than(self, then, directory):
        """ Recursively find files in `directory` that are older than `then`
        and collect their base paths """
        found = None
        for r, d, f in os.walk(directory):
            for file_name in f:
                if file_name.startswith("."):
                    continue
                abs_file = os.path.abspath(os.path.join(r, file_name))
                if os.path.getmtime(abs_file) < time.mktime(then.timetuple()):
                    if found is None:
                        found = list()
                    online_logger.debug(str("%s -> adding -> %s" % (file_name, r)))
                    found.append(r)
        online_logger.info(str("paths to delete: %s" % found))
        return found
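
    # Example with hypothetical paths, mirroring the cleanup policy in
    # prepare_training(): collect prewindowed data of a potential "P300"
    # that is older than twelve hours (`trainer` is a LiveTrainer instance):
    #
    #   then = datetime.datetime.now() - datetime.timedelta(hours=12)
    #   stale = trainer.find_files_older_than(
    #       then, os.path.join("prewindowed_data_storage", "P300"))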

    def training_fct(self, key):
        """ Function that performs the actual training """
        self.training_active_potential[key].value = True
        online_logger.info(key + " " + self.operation + " started")
        if self.operation in ("train", "prewindowed_train"):
            self.node_chains[key].train()
        elif self.operation in ("prewindowing", "prewindowing_offline"):
            result_collection = {}
            self.node_chains[key][-1].process_current_split()
            result_collection[key] = self.node_chains[key][-1].get_result_dataset()

            save_dir = os.path.abspath(
                os.path.join(self.prewindowed_data_directory, key))
            if not os.path.exists(save_dir):
                os.mkdir(save_dir)

            if result_collection[key] is not None:
                online_logger.info("storing result collection for " + key)
                now = datetime.datetime.now()
                now_folder = str("%04d%02d%02d-%02d%02d%02d" %
                                 (now.year, now.month, now.day,
                                  now.hour, now.minute, now.second))
                p = os.path.abspath(os.path.join(
                    self.prewindowed_data_directory, key, now_folder))
                if not os.path.exists(p):
                    os.mkdir(p)
                result_collection[key].store(p)
                online_logger.info(key + " prewindowed data stored!")
            else:
                online_logger.warn(str("result collection for %s was None"
                                       " - nothing stored.." % key))

        online_logger.info(key + " " + self.operation + " finished")

        online_logger.info("Storing " + key + " flow model...")
        self.node_chains[key].save("%s/%s.pickle"
            % (self.flow_storage, self.operation + "_flow_" + key))
        f = open("%s/%s.yaml"
            % (self.flow_storage, self.operation + "_flow_" + key), "w")
        yaml.dump(self.node_chain_definitions[key], f, default_flow_style=False)
        f.close()
        online_logger.info(key + " flow model stored!")
        self.training_active_potential[key].value = False
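
    # For orientation: with the default flow_storage and a hypothetical
    # potential named "P300", a "train" run of training_fct() leaves behind
    #
    #   flow_storage/train_flow_P300.pickle  # the pickled, executable node chain
    #   flow_storage/train_flow_P300.yaml    # the matching node-chain definition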

    def triggered_queue_filler_training(self, data, label, key):
        """ Queue a target window only after the trigger event was seen """
        if label in self.potentials[key]['positive_event']:
            self.target_shown[key] = True
            self.last_target_data[key] = data
        elif label in self.potentials[key]['trigger_event']:
            if self.target_shown[key]:
                self.queue[key].put((self.last_target_data[key],
                                     self.potentials[key]['positive_event']))
                self.target_shown[key] = False
        elif label in self.potentials[key]['negative_event']:
            self.queue[key].put((data, label))

    def classification_thread(self, key):
        """ Read the windows from the window stream and fill the training queues """
        window_counter = 0
        active = False
        for data, label in self.window_stream[key]:
            if self.training_paused_potential.value:
                break
            online_logger.info("Got instance number " + str(window_counter)
                               + " with class %s" % label)
            window_counter += 1
            # Skip the first few training examples since there might be no
            # clear distinction between standards and targets yet
            if "ignore_num_first_examples" in self.potentials[key]:
                if window_counter < int(self.potentials[key]["ignore_num_first_examples"]):
                    online_logger.info("Ignoring first " + str(window_counter)
                                       + " " + key + " training samples")
                    continue

            if "trigger_event" in self.potentials[key]:
                self.triggered_queue_filler_training(data, label, key)
            # distribution is performed only if it was activated beforehand
            elif "activation_label" in self.potentials[key]:
                if label in self.potentials[key]["activation_label"]:
                    online_logger.warn("Detection of " + key + " started")
                    active = True
                if label in self.potentials[key]["positive_event"] and active:
                    self.event_queue.put(self.potentials[key]["positive_event"])
                    self.queue[key].put((data, label))
                if label in self.potentials[key]["deactivation_label"]:
                    online_logger.warn("Detection of " + key + " stopped")
                    active = False
            else:
                if label in self.potentials[key]["positive_event"]:
                    self.queue[key].put((data, label))
                elif label in self.potentials[key]["negative_event"]:
                    self.queue[key].put((data, label))

        online_logger.info("Streaming data finished")
        online_logger.debug("Submit stream end data item...")
        # self.queue is a dict of queues, so the end marker has to go into
        # the queue of this potential
        self.queue[key].put(None)
        online_logger.debug("Stream end data item submitted")

    def stream_data(self, key):
        """ Forward the data to the worker processes """
        spec_base = self.potentials[key]["configuration"].spec_dir
        window_spec_file = {}
        if self.operation == "prewindowing":
            online_logger.info(str("streaming data for %s started" % key))
            self.classification_thread(key)
            # all done!
            self.queue[key].put(None)
            online_logger.info(str("%s for %s finished" % (self.operation, key)))

        elif self.operation == "prewindowing_offline":
            data_set_count = 0
            # create a local stream manager
            local_streaming = eeg_stream_manager.LiveEegStreamManager(online_logger)
            for train_dataset in self.training_data:
                # skip the remaining datasets if we are not supposed
                # to train any further
                if self.training_paused_potential.value:
                    continue
                # stream the local file
                local_streaming.stream_local_file(train_dataset)
                # create the window stream
                window_spec_file[key] = os.path.join(
                    spec_base, "node_chains", "windower",
                    self.potentials[key]["windower_spec_path_train"])
                self.window_stream[key] = \
                    local_streaming.request_window_stream(
                        window_spec_file[key],
                        nullmarker_stride_ms=self.nullmarker_stride_ms)
                # process the data
                online_logger.info(str("streaming data for %s started" % key))
                self.classification_thread(key)
                data_set_count += 1
                online_logger.info(str("dataset %d completely streamed for %s"
                                       % (data_set_count, key)))
                local_streaming.stop()
            # all done!
            self.queue[key].put(None)
            online_logger.info(str("training for %s finished" % key))

        elif self.operation == "train":
            data_set_count = 0
            local_streaming = eeg_stream_manager.LiveEegStreamManager(online_logger)
            for train_dataset in self.training_data:
                if self.training_paused_potential.value:
                    continue
                online_logger.info("Start streaming training dataset " + train_dataset)
                # start the EEG client
                local_streaming.stream_local_file(train_dataset)
                # create the windower
                online_logger.info("window specs: "
                                   + str(self.potentials[key]["windower_spec_path_train"]))
                window_spec_file[key] = os.path.join(
                    spec_base, "node_chains", "windower",
                    self.potentials[key]["windower_spec_path_train"])
                self.window_stream[key] = \
                    local_streaming.request_window_stream(
                        window_spec_file[key],
                        nullmarker_stride_ms=self.nullmarker_stride_ms)
                online_logger.info(key + " windower: " + str(self.window_stream[key]))
                self.classification_thread(key)
                data_set_count += 1
                online_logger.info(str("dataset %d completely streamed for %s"
                                       % (data_set_count, key)))
                #local_streaming.stop()
            self.queue[key].put(None)

    def start_training(self, operation, profiling=False):
        """ Train flows on the streamed data """
        for key in self.potentials.keys():
            assert not self.training_active_potential[key].value
            # stream the data
            if self.operation in ("train", "prewindowing", "prewindowing_offline"):
                self.data_stream_process[key] = \
                    multiprocessing.Process(target=self.stream_data, args=(key,))
                self.data_stream_process[key].start()
                time.sleep(0.1)

            if key not in self.train_process.keys():
                # start one training process per potential
                self.train_process[key] = \
                    multiprocessing.Process(target=self.training_fct, args=(key,))
                # start all processes
                self.train_process[key].start()
                self.training_paused_potential.value = False

        # wait until the training processes are set up and running;
        # they should be running after 30 seconds
        setup_timer = 0
        while not self.is_training_active():
            if setup_timer > 30:
                online_logger.error("Training processes not started")
                raise RuntimeError("Training processes not started")
            else:
                time.sleep(1)
                setup_timer += 1

        if self.operation in ("train", "prewindowing",
                              "prewindowing_offline", "prewindowed_train"):
            for key in self.train_process.iterkeys():
                while self.train_process[key].is_alive():
                    time.sleep(1)

    def is_training_active(self):
        """ Return whether the training is still running """
        active = False
        alive = False
        for key in self.potentials.keys():
            active |= self.training_active_potential[key].value
        for key in self.train_process.iterkeys():
            alive |= self.train_process[key].is_alive()
        if not alive:
            active = False
        return active

    def process_external_command(self, command):
        """ Process an external stop command """
        if command == "STOP":
            for key in self.data_stream_process.keys():
                self.data_stream_process[key].terminate()
            self.pause_training()

    def pause_training(self):
        """ Pause the training phase """
        self.training_paused_potential.value = True

    def stop_training(self):
        """ Force the end of the training """
        # We stop the training by disconnecting the data stream from it
        def read(**kwargs):
            online_logger.info("Canceling EEG transfer")
            return 0

        online_logger.info("Stopping training ...")
        # wait until the training has finished
        for key in self.potentials.keys():
            online_logger.info("Check if training is still active ...")
            while self.is_training_active():
                time.sleep(1)
                online_logger.info("Training is still active ...")
            self.train_process[key].join()
        online_logger.info("Training finished")
        return 0
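
# A minimal usage sketch, not part of the module: how a LiveTrainer is
# typically driven. The file name, the potentials dictionary, and the stride
# are hypothetical; see prepare_training() for the keys each potential needs.
#
#   trainer = LiveTrainer()
#   trainer.set_eeg_stream_manager(
#       eeg_stream_manager.LiveEegStreamManager(online_logger))
#   trainer.prepare_training(["session1.eeg"], potentials,
#                            "prewindowing_offline", nullmarker_stride_ms=1000)
#   trainer.start_training("prewindowing_offline")
#   trainer.stop_training()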