66from __future__ import (print_function , division , unicode_literals ,
77 absolute_import )
88from builtins import open
9+ import pytest
910
11+ from .... import config
1012from ... import engine as pe
1113from ....interfaces import base as nib
1214from ....interfaces .utility import IdentityInterface , Function , Merge
@@ -45,19 +47,15 @@ class IncrementOutputSpec(nib.TraitedSpec):
4547 output1 = nib .traits .Int (desc = 'ouput' )
4648
4749
48- class IncrementInterface (nib .BaseInterface ):
50+ class IncrementInterface (nib .SimpleInterface ):
4951 input_spec = IncrementInputSpec
5052 output_spec = IncrementOutputSpec
5153
5254 def _run_interface (self , runtime ):
5355 runtime .returncode = 0
56+ self ._results ['output1' ] = self .inputs .input1 + self .inputs .inc
5457 return runtime
5558
56- def _list_outputs (self ):
57- outputs = self ._outputs ().get ()
58- outputs ['output1' ] = self .inputs .input1 + self .inputs .inc
59- return outputs
60-
6159
6260_sums = []
6361
@@ -73,23 +71,19 @@ class SumOutputSpec(nib.TraitedSpec):
7371 operands = nib .traits .List (nib .traits .Int , desc = 'operands' )
7472
7573
76- class SumInterface (nib .BaseInterface ):
74+ class SumInterface (nib .SimpleInterface ):
7775 input_spec = SumInputSpec
7876 output_spec = SumOutputSpec
7977
8078 def _run_interface (self , runtime ):
81- runtime .returncode = 0
82- return runtime
83-
84- def _list_outputs (self ):
8579 global _sum
8680 global _sum_operands
87- outputs = self . _outputs (). get ()
88- outputs ['operands' ] = self .inputs .input1
89- _sum_operands . append ( outputs [ 'operands' ] )
90- outputs [ 'output1' ] = sum (self .inputs .input1 )
91- _sums .append (outputs [ 'output1' ] )
92- return outputs
81+ runtime . returncode = 0
82+ self . _results ['operands' ] = self .inputs .input1
83+ self . _results [ 'output1' ] = sum ( self . inputs . input1 )
84+ _sum_operands . append (self .inputs .input1 )
85+ _sums .append (sum ( self . inputs . input1 ) )
86+ return runtime
9387
9488
9589_set_len = None
@@ -148,35 +142,47 @@ def _list_outputs(self):
148142 return outputs
149143
150144
151- def test_join_expansion (tmpdir ):
145+ @pytest .mark .parametrize ('needed_outputs' , [True , False ])
146+ def test_join_expansion (tmpdir , needed_outputs ):
147+ global _sums
148+ global _sum_operands
149+ global _products
152150 tmpdir .chdir ()
153151
152+ # Clean up, just in case some other test modified them
153+ _products = []
154+ _sum_operands = []
155+ _sums = []
156+
157+ config .set ('execution' , 'remove_unnecessary_outputs' , ['false' , 'true' ][needed_outputs ])
154158 # Make the workflow.
155159 wf = pe .Workflow (name = 'test' )
156160 # the iterated input node
157161 inputspec = pe .Node (IdentityInterface (fields = ['n' ]), name = 'inputspec' )
158162 inputspec .iterables = [('n' , [1 , 2 ])]
159163 # a pre-join node in the iterated path
160164 pre_join1 = pe .Node (IncrementInterface (), name = 'pre_join1' )
161- wf .connect (inputspec , 'n' , pre_join1 , 'input1' )
162165 # another pre-join node in the iterated path
163166 pre_join2 = pe .Node (IncrementInterface (), name = 'pre_join2' )
164- wf .connect (pre_join1 , 'output1' , pre_join2 , 'input1' )
165167 # the join node
166168 join = pe .JoinNode (
167169 SumInterface (),
168170 joinsource = 'inputspec' ,
169171 joinfield = 'input1' ,
170172 name = 'join' )
171- wf .connect (pre_join2 , 'output1' , join , 'input1' )
172173 # an uniterated post-join node
173174 post_join1 = pe .Node (IncrementInterface (), name = 'post_join1' )
174- wf .connect (join , 'output1' , post_join1 , 'input1' )
175175 # a post-join node in the iterated path
176176 post_join2 = pe .Node (ProductInterface (), name = 'post_join2' )
177- wf .connect (join , 'output1' , post_join2 , 'input1' )
178- wf .connect (pre_join1 , 'output1' , post_join2 , 'input2' )
179177
178+ wf .connect ([
179+ (inputspec , pre_join1 , [('n' , 'input1' )]),
180+ (pre_join1 , pre_join2 , [('output1' , 'input1' )]),
181+ (pre_join1 , post_join2 , [('output1' , 'input2' )]),
182+ (pre_join2 , join , [('output1' , 'input1' )]),
183+ (join , post_join1 , [('output1' , 'input1' )]),
184+ (join , post_join2 , [('output1' , 'input1' )]),
185+ ])
180186 result = wf .run ()
181187
182188 # the two expanded pre-join predecessor nodes feed into one join node
@@ -185,8 +191,8 @@ def test_join_expansion(tmpdir):
185191 # the expanded graph contains 2 * 2 = 4 iteration pre-join nodes, 1 join
186192 # node, 1 non-iterated post-join node and 2 * 1 iteration post-join nodes.
187193 # Nipype factors away the IdentityInterface.
188- assert len (
189- result . nodes ()) == 8 , "The number of expanded nodes is incorrect."
194+ assert len (result . nodes ()) == 8 , "The number of expanded nodes is incorrect."
195+
190196 # the join Sum result is (1 + 1 + 1) + (2 + 1 + 1)
191197 assert len (_sums ) == 1 , "The number of join outputs is incorrect"
192198 assert _sums [
@@ -199,6 +205,7 @@ def test_join_expansion(tmpdir):
199205 "The number of iterated post-join outputs is incorrect"
200206
201207
208+
202209def test_node_joinsource (tmpdir ):
203210 """Test setting the joinsource to a Node."""
204211 tmpdir .chdir ()
0 commit comments