# Source code for pySPACE.tests.unittests.nodes.sink.test_time_series_sink

#!/usr/bin/python

"""
This module contains unit tests that test storing and loading using the sink and source nodes

:Author: Jan Hendrik Metzen (jhm@informatik.uni-bremen.de)
:Created: 2008/12/18
"""

import os
import shutil
import tempfile
import unittest

import numpy

if __name__ == '__main__':
    import sys
    # When executed as a script, append the root of the code (the parent
    # directory of the last 'pySPACE' path component) to sys.path so that
    # the pySPACE imports further down can be resolved.
    # ``os`` is already imported at module level, so it is not re-imported.
    file_path = os.path.dirname(os.path.abspath(__file__))
    package_pos = file_path.rfind('pySPACE')
    # rfind returns -1 when 'pySPACE' is not part of the path; slicing with
    # [:-2] would then append a truncated, meaningless directory.
    if package_pos > 0:
        sys.path.append(file_path[:package_pos - 1])


try:
    from pySPACE.tests.utils.data.test_data_generation import SimpleTimeSeriesSourceNode
except:
    from pySPACE.missions.nodes.source.test_source_nodes import SimpleTimeSeriesSourceNode
from pySPACE.missions.nodes.source.time_series_source import TimeSeriesSourceNode
from pySPACE.missions.nodes.sink.time_series_sink import TimeSeriesSinkNode

from pySPACE.resources.dataset_defs.base import BaseDataset
        
class TimeSeriesSinkTestCase(unittest.TestCase):
    """ Test both the TimeSeries sink and source nodes

    Time series produced by a :class:`SimpleTimeSeriesSourceNode` are stored
    with a :class:`TimeSeriesSinkNode`, reloaded through
    :class:`TimeSeriesSourceNode` and compared with the originals.
    """

    def test_time_series_storing(self):
        """ Stored and reloaded time series must match the originals one-to-one """
        # Use a private temporary directory rather than './tmp': the old code
        # would reuse and finally delete any pre-existing 'tmp' directory in
        # the current working directory.
        tmp_dir = tempfile.mkdtemp()
        # Registered cleanups run even if an assertion below fails, so the
        # temporary directory is not leaked on test failure.
        self.addCleanup(shutil.rmtree, tmp_dir, ignore_errors=True)

        # Store the time series of a simple source via the sink node
        source = SimpleTimeSeriesSourceNode()
        sink = TimeSeriesSinkNode()
        sink.register_input_node(source)
        sink.set_run_number(0)
        sink.process_current_split()
        result_collection = sink.get_result_dataset()
        result_collection.store(tmp_dir)

        # Reload the stored dataset through a time series source node
        reloaded_collection = BaseDataset.load(tmp_dir)
        reloader = TimeSeriesSourceNode()
        reloader.set_input_dataset(reloaded_collection)

        orig_data = list(source.request_data_for_testing())
        restored_data = list(reloader.request_data_for_testing())

        # Check that the two lists have the same length
        self.assertEqual(len(orig_data), len(restored_data),
                         "Numbers of time series before storing and after "
                         "reloading are not equal!")

        # Check that every original time series has an identical counterpart
        # (same label, same shape and same values) after reloading.
        # numpy.array_equal also compares shapes, whereas ``(a == b).all()``
        # could report a match for merely broadcast-compatible arrays.
        for orig_datapoint, orig_label in orig_data:
            orig_array = orig_datapoint.view(numpy.ndarray)
            found = any(
                orig_label == restored_label
                and numpy.array_equal(orig_array,
                                      restored_datapoint.view(numpy.ndarray))
                for restored_datapoint, restored_label in restored_data)
            self.assertTrue(found,
                            "One of the original time series cannot be "
                            "found after reloading")
if __name__ == '__main__':
    # Load the test case defined in this module directly. The former
    # loadTestsFromName('test_sink_source') named a module that does not
    # match this file (test_time_series_sink).
    suite = unittest.TestLoader().loadTestsFromTestCase(TimeSeriesSinkTestCase)
    unittest.TextTestRunner(verbosity=2).run(suite)