
Commit 98e831d

Authored by xuanyuanking, committed by HyukjinKwon
[SPARK-25921][FOLLOW UP][PYSPARK] Fix barrier task run without BarrierTaskContext while python worker reuse
## What changes were proposed in this pull request?

This is the follow-up PR for apache#22962 and contains the following changes:

- Remove `__init__` in TaskContext and BarrierTaskContext.
- Add more comments to explain the fix.
- Rewrite the unit test in a new class.

## How was this patch tested?

New unit test in test_taskcontext.py.

Closes apache#23435 from xuanyuanking/SPARK-25921-follow.

Authored-by: Yuanjian Li <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
1 parent: 270916f
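Removing the empty `__init__` bodies is behavior-preserving because neither class ever relied on them: singleton instances are built with `object.__new__(cls)` inside `_getOrCreate`, which bypasses `__init__`, and a direct constructor call falls back to the no-op `object.__init__`. A minimal sketch of the pattern (a hypothetical `Singleton` class, not the real pyspark ones):

```python
# Hypothetical illustration of why the empty __init__ bodies could be
# dropped; pyspark's TaskContext follows the same singleton shape.

class Singleton(object):
    _instance = None  # class-level cache, as with TaskContext._taskContext

    def __new__(cls):
        if cls._instance is None:
            # object.__new__ builds the instance without running __init__
            cls._instance = object.__new__(cls)
        return cls._instance


a = Singleton()  # __new__ runs, then the inherited no-op object.__init__
b = Singleton()
assert a is b    # same cached instance; an empty __init__ adds nothing
```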

File tree

2 files changed (+46, -25)


python/pyspark/taskcontext.py (+5, -9)
@@ -48,10 +48,6 @@ def __new__(cls):
         cls._taskContext = taskContext = object.__new__(cls)
         return taskContext
 
-    def __init__(self):
-        """Construct a TaskContext, use get instead"""
-        pass
-
     @classmethod
     def _getOrCreate(cls):
         """Internal function to get or create global TaskContext."""
@@ -140,13 +136,13 @@ class BarrierTaskContext(TaskContext):
     _port = None
     _secret = None
 
-    def __init__(self):
-        """Construct a BarrierTaskContext, use get instead"""
-        pass
-
     @classmethod
     def _getOrCreate(cls):
-        """Internal function to get or create global BarrierTaskContext."""
+        """
+        Internal function to get or create global BarrierTaskContext. We need to make sure
+        BarrierTaskContext is returned from here because it is needed in python worker reuse
+        scenario, see SPARK-25921 for more details.
+        """
         if not isinstance(cls._taskContext, BarrierTaskContext):
             cls._taskContext = object.__new__(cls)
         return cls._taskContext
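The new docstring describes the worker-reuse scenario that motivates the `isinstance` check. The sketch below imitates that pattern with hypothetical stand-in classes (the real classes also carry ports, secrets, and barrier RPC state): in a reused worker, a plain context cached by an earlier task must be replaced, not returned, when a barrier task asks for its context.

```python
# Hypothetical stand-ins for TaskContext/BarrierTaskContext, mirroring only
# the caching pattern in the diff above (the real classes do much more).

class Ctx(object):
    _ctx = None  # shared across all tasks executed by a reused worker

    @classmethod
    def _getOrCreate(cls):
        if cls._ctx is None:
            cls._ctx = object.__new__(cls)
        return cls._ctx


class BarrierCtx(Ctx):
    @classmethod
    def _getOrCreate(cls):
        # The isinstance check is the SPARK-25921 fix: a bare `is None`
        # test would hand back the plain Ctx left by a non-barrier task.
        if not isinstance(cls._ctx, BarrierCtx):
            cls._ctx = object.__new__(cls)
        return cls._ctx


Ctx._getOrCreate()                  # a normal task ran in this worker first
ctx = BarrierCtx._getOrCreate()     # a barrier task then reuses the worker
assert isinstance(ctx, BarrierCtx)  # and still gets a proper barrier context
```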

python/pyspark/tests/test_taskcontext.py (+41, -16)
@@ -14,11 +14,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import os
 import random
 import sys
 import time
+import unittest
 
-from pyspark import SparkContext, TaskContext, BarrierTaskContext
+from pyspark import SparkConf, SparkContext, TaskContext, BarrierTaskContext
 from pyspark.testing.utils import PySparkTestCase
 
 
@@ -118,21 +120,6 @@ def context_barrier(x):
         times = rdd.barrier().mapPartitions(f).map(context_barrier).collect()
         self.assertTrue(max(times) - min(times) < 1)
 
-    def test_barrier_with_python_worker_reuse(self):
-        """
-        Verify that BarrierTaskContext.barrier() with reused python worker.
-        """
-        self.sc._conf.set("spark.python.work.reuse", "true")
-        rdd = self.sc.parallelize(range(4), 4)
-        # start a normal job first to start all worker
-        result = rdd.map(lambda x: x ** 2).collect()
-        self.assertEqual([0, 1, 4, 9], result)
-        # make sure `spark.python.work.reuse=true`
-        self.assertEqual(self.sc._conf.get("spark.python.work.reuse"), "true")
-
-        # worker will be reused in this barrier job
-        self.test_barrier()
-
     def test_barrier_infos(self):
         """
         Verify that BarrierTaskContext.getTaskInfos() returns a list of all task infos in the
@@ -149,6 +136,44 @@ def f(iterator):
         self.assertTrue(len(taskInfos[0]) == 4)
 
 
+class TaskContextTestsWithWorkerReuse(unittest.TestCase):
+
+    def setUp(self):
+        class_name = self.__class__.__name__
+        conf = SparkConf().set("spark.python.worker.reuse", "true")
+        self.sc = SparkContext('local[2]', class_name, conf=conf)
+
+    def test_barrier_with_python_worker_reuse(self):
+        """
+        Regression test for SPARK-25921: verify that BarrierTaskContext.barrier() works
+        with reused python workers.
+        """
+        # start a normal job first to start all workers and get all worker pids
+        worker_pids = self.sc.parallelize(range(2), 2).map(lambda x: os.getpid()).collect()
+        # the workers will be reused in this barrier job
+        rdd = self.sc.parallelize(range(10), 2)
+
+        def f(iterator):
+            yield sum(iterator)
+
+        def context_barrier(x):
+            tc = BarrierTaskContext.get()
+            time.sleep(random.randint(1, 10))
+            tc.barrier()
+            return (time.time(), os.getpid())
+
+        result = rdd.barrier().mapPartitions(f).map(context_barrier).collect()
+        times = list(map(lambda x: x[0], result))
+        pids = list(map(lambda x: x[1], result))
+        # check both barrier and worker reuse effect
+        self.assertTrue(max(times) - min(times) < 1)
+        for pid in pids:
+            self.assertTrue(pid in worker_pids)
+
+    def tearDown(self):
+        self.sc.stop()
+
+
 if __name__ == "__main__":
     import unittest
     from pyspark.tests.test_taskcontext import *
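To run only the new test class outside Spark's full test harness, a small unittest runner like the following should work (a sketch assuming `pyspark` is importable and a `local[2]` Spark master can start on the machine):

```python
# Hypothetical standalone runner for the new worker-reuse tests; assumes
# pyspark is on the PYTHONPATH and a local Spark can be launched.
import unittest

from pyspark.tests.test_taskcontext import TaskContextTestsWithWorkerReuse

if __name__ == "__main__":
    suite = unittest.TestLoader().loadTestsFromTestCase(TaskContextTestsWithWorkerReuse)
    unittest.TextTestRunner(verbosity=2).run(suite)
```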
