Source code for pySPACE.tests.unittests.nodes.splitter.test_traintest_splitter
#!/usr/bin/python
"""
A module that tests the
:mod:`~pySPACE.missions.nodes.splitter.traintest_splitter`
node
:Author: Andrei Ignat (Andrei_Cristian.Ignat@dfki.de)
:Created: 2014/06/03
"""
if __name__ == '__main__':
    # Only needed when this file is executed directly as a script: extend
    # the import path so that the ``pySPACE`` package can be imported below.
    import os
    import sys
    # Directory that contains this test file
    file_path = os.path.dirname(os.path.abspath(__file__))
    # The root of the code: everything in front of the last occurrence of
    # 'pySPACE' in the path, minus the path separator preceding it
    root_end = file_path.rfind('pySPACE') - 1
    sys.path.append(file_path[:root_end])
import unittest
from pySPACE.missions.nodes.splitter.traintest_splitter import *
from pySPACE.resources.data_types.time_series import TimeSeries
import pySPACE.tests.generic_unittest as gen_test
from pySPACE.missions.nodes.source.external_generator_source import *
class TrainTestSplitterTestCase(unittest.TestCase):
    """ Unit tests for the ``TrainTestSplitterNode``

    A two-channel ``TimeSeries`` holding 100 deterministic points is wrapped
    in an ``ExternalGeneratorSourceNode`` and handed to the splitter as its
    input node. Every test configures the splitter with ``train_ratio=0.3``,
    triggers the split by hand and inspects ``splitter.train_data``.
    """

    def setUp(self):
        """ Create the source node that feeds the splitter under test """
        # set up the channels
        self.channel_names = ['Target', 'Standard']
        # fill in the data points according to a given equation:
        # point number i is the pair (2 * i, 13 * i)
        self.points = [(2 * cntr, 13 * cntr) for cntr in range(100)]
        initial_data = TimeSeries(self.points, self.channel_names, 100)
        # since the node was built for online analysis and splitting,
        # we must fool it by giving it the input under the form of a node
        # and not just a e.g. TimeSeries object
        self.input_node = ExternalGeneratorSourceNode()
        self.input_node.set_generator(initial_data)

    def _run_splitter(self, **node_parameters):
        """ Return a splitter built with *node_parameters* after splitting

        The splitter is connected to ``self.input_node`` and its split is
        created explicitly through ``_create_split``.
        """
        splitter = TrainTestSplitterNode(**node_parameters)
        splitter.set_permanent_attributes(input_node=self.input_node)
        splitter._create_split()
        return splitter

    def test_random_split(self):
        """ A random split yields 30 training items that differ from the head """
        splitter = self._run_splitter(train_ratio=0.3, random=True)
        # we check if the split has the correct length
        # (a ratio of 0.3 applied to the 100 generated points)
        self.assertEqual(len(splitter.train_data), 30)
        # and then we check if the split was done in a random way, i.e. the
        # training data is not simply the first 30 generated points
        self.assertNotEqual(splitter.train_data, self.points[:30])

    def test_reverse_split(self):
        """ A non-random, reversed split is compared with the first 30 points """
        splitter = self._run_splitter(
            train_ratio=0.3,
            random=False,
            reverse=True)
        # we check if the split has the correct length
        self.assertEqual(len(splitter.train_data), 30)
        # and then we check that the training data matches the first 30
        # generated points
        # NOTE(review): expectation kept as-is; confirm that it matches the
        # node's definition of ``reverse``
        self.assertEqual(splitter.train_data, self.points[:30])
if __name__ == '__main__':
    # Collect the test cases of this module, addressed by its module name
    loader = unittest.TestLoader()
    suite = loader.loadTestsFromName('test_traintest_splitter')
    # Test the generic initialization of the class methods
    generic_tests = gen_test.ParametrizedTestCase.parametrize(
        current_testcase=gen_test.GenericTestCase,
        node=TrainTestSplitterNode)
    suite.addTest(generic_tests)
    # Run the combined suite with verbose output
    runner = unittest.TextTestRunner(verbosity=2)
    runner.run(suite)