# Source code for pySPACE.tests.unittests.nodes.splitter.test_all_train_splitter
#!/usr/bin/python
"""
A module that tests the
:mod:`~pyspace.pySPACE.missions.nodes.splitter.all_train_splitter`
node
:Author: Andrei Ignat (Andrei_Cristian.Ignat@dfki.de)
:Created: 2014/06/04
"""
if __name__ == '__main__':
    # Only when executed directly as a script: extend the import path so that
    # the ``pySPACE`` package imported below can be found.
    import os
    import sys

    # The root of the code: cut this file's directory just before the last
    # occurrence of 'pySPACE' (the ``- 1`` also drops the path separator in
    # front of it), which leaves the directory containing the package.
    file_path = os.path.dirname(os.path.abspath(__file__))
    sys.path.append(file_path[:file_path.rfind('pySPACE') - 1])
import unittest
from pySPACE.missions.nodes.splitter.all_train_splitter import *
from pySPACE.resources.data_types.time_series import TimeSeries
import pySPACE.tests.generic_unittest as gen_test
from pySPACE.missions.nodes.source.external_generator_source import *
class AllTrainSplitterTestCase(unittest.TestCase):
    """Check that ``AllTrainSplitterNode`` treats every sample as training data.

    The test itself is an embarrassingly simple one, since the only thing
    the node should do is hand out all the data points as training data.
    The scenario is therefore: create an input node, feed it some data
    points, and verify that the splitter offers nothing for testing and
    everything for training.
    """

    def setUp(self):
        """Create the sample points and wrap them in a source node."""
        # set up the channels
        self.channel_names = ['Target', 'Standard']
        # fill in the data points according to a given equation
        self.points = [(2 * counter, 13 * counter) for counter in range(100)]
        # NOTE(review): the third argument is presumably the sampling
        # frequency expected by TimeSeries -- confirm against its signature.
        initial_data = TimeSeries(self.points, self.channel_names, 100)
        # since the node was built for online analysis and splitting,
        # we must fool it by giving it the input under the form of a node
        # and not just a e.g. TimeSeries object
        self.input_node = ExternalGeneratorSourceNode()
        self.input_node.set_generator(initial_data)

    def test_split(self):
        """The test set must be empty; the training set must hold all points."""
        splitter = AllTrainSplitterNode()
        splitter.set_permanent_attributes(input_node=self.input_node)
        # check if the test data set is empty
        self.assertEqual(list(splitter.request_data_for_testing()), [])
        # check if the train data set contains all the data points
        self.assertEqual(
            list(splitter.request_data_for_training(True)),
            self.points)
if __name__ == '__main__':
    # Script mode: collect the tests of this module by its module name ...
    suite = unittest.TestLoader().loadTestsFromName('test_all_train_splitter')
    # ... add the generic unittests, parametrized for the node under test ...
    suite.addTest(gen_test.ParametrizedTestCase.parametrize(
        current_testcase=gen_test.GenericTestCase, node=AllTrainSplitterNode))
    # ... and run everything with verbose per-test output.
    unittest.TextTestRunner(verbosity=2).run(suite)