|
17 | 17 |
|
18 | 18 | package org.apache.spark.shuffle.celeborn;
|
19 | 19 |
|
| 20 | +import java.io.ByteArrayInputStream; |
| 21 | +import java.io.ObjectInputStream; |
| 22 | +import java.io.ObjectOutputStream; |
20 | 23 | import java.util.HashSet;
|
21 | 24 | import java.util.List;
|
22 | 25 | import java.util.Map;
|
23 | 26 | import java.util.Set;
|
24 | 27 | import java.util.concurrent.ConcurrentHashMap;
|
| 28 | +import java.util.concurrent.atomic.AtomicInteger; |
25 | 29 | import java.util.concurrent.atomic.LongAdder;
|
26 | 30 | import java.util.stream.Collectors;
|
27 | 31 |
|
|
36 | 40 | import org.apache.spark.SparkContext;
|
37 | 41 | import org.apache.spark.SparkContext$;
|
38 | 42 | import org.apache.spark.TaskContext;
|
| 43 | +import org.apache.spark.broadcast.Broadcast; |
| 44 | +import org.apache.spark.io.CompressionCodec; |
39 | 45 | import org.apache.spark.scheduler.DAGScheduler;
|
40 | 46 | import org.apache.spark.scheduler.MapStatus;
|
41 | 47 | import org.apache.spark.scheduler.MapStatus$;
|
|
57 | 63 |
|
58 | 64 | import org.apache.celeborn.client.ShuffleClient;
|
59 | 65 | import org.apache.celeborn.common.CelebornConf;
|
| 66 | +import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse; |
60 | 67 | import org.apache.celeborn.common.util.JavaUtils;
|
| 68 | +import org.apache.celeborn.common.util.KeyLock; |
61 | 69 | import org.apache.celeborn.reflect.DynConstructors;
|
62 | 70 | import org.apache.celeborn.reflect.DynFields;
|
63 | 71 | import org.apache.celeborn.reflect.DynMethods;
|
@@ -462,4 +470,120 @@ public static void addSparkListener(SparkListener listener) {
|
462 | 470 | sparkContext.addSparkListener(listener);
|
463 | 471 | }
|
464 | 472 | }
|
| 473 | + |
| 474 | + /** |
| 475 | + * A [[KeyLock]] whose key is a shuffle id to ensure there is only one thread accessing the |
| 476 | + * broadcast belonging to the shuffle id at a time. |
| 477 | + */ |
| 478 | + private static KeyLock<Integer> shuffleBroadcastLock = new KeyLock(); |
| 479 | + |
| 480 | + @VisibleForTesting |
| 481 | + public static AtomicInteger getReducerFileGroupResponseBroadcastNum = new AtomicInteger(); |
| 482 | + |
| 483 | + protected static Map<Integer, Tuple2<Broadcast<GetReducerFileGroupResponse>, byte[]>> |
| 484 | + getReducerFileGroupResponseBroadcasts = JavaUtils.newConcurrentHashMap(); |
| 485 | + |
| 486 | + public static byte[] serializeGetReducerFileGroupResponse( |
| 487 | + Integer shuffleId, GetReducerFileGroupResponse response) { |
| 488 | + SparkContext sparkContext = SparkContext$.MODULE$.getActive().getOrElse(null); |
| 489 | + if (sparkContext == null) { |
| 490 | + LOG.error("Can not get active SparkContext."); |
| 491 | + return null; |
| 492 | + } |
| 493 | + |
| 494 | + return shuffleBroadcastLock.withLock( |
| 495 | + shuffleId, |
| 496 | + () -> { |
| 497 | + Tuple2<Broadcast<GetReducerFileGroupResponse>, byte[]> |
| 498 | + cachedSerializeGetReducerFileGroupResponse = |
| 499 | + getReducerFileGroupResponseBroadcasts.get(shuffleId); |
| 500 | + if (cachedSerializeGetReducerFileGroupResponse != null) { |
| 501 | + return cachedSerializeGetReducerFileGroupResponse._2; |
| 502 | + } |
| 503 | + |
| 504 | + try { |
| 505 | + LOG.info("Broadcasting GetReducerFileGroupResponse for shuffle: {}", shuffleId); |
| 506 | + Broadcast<GetReducerFileGroupResponse> broadcast = |
| 507 | + sparkContext.broadcast( |
| 508 | + response, |
| 509 | + scala.reflect.ClassManifestFactory.fromClass( |
| 510 | + GetReducerFileGroupResponse.class)); |
| 511 | + |
| 512 | + CompressionCodec codec = CompressionCodec.createCodec(sparkContext.conf()); |
| 513 | + // Using `org.apache.commons.io.output.ByteArrayOutputStream` instead of the standard |
| 514 | + // one |
| 515 | + // This implementation doesn't reallocate the whole memory block but allocates |
| 516 | + // additional buffers. This way no buffers need to be garbage collected and |
| 517 | + // the contents don't have to be copied to the new buffer. |
| 518 | + org.apache.commons.io.output.ByteArrayOutputStream out = |
| 519 | + new org.apache.commons.io.output.ByteArrayOutputStream(); |
| 520 | + try (ObjectOutputStream oos = |
| 521 | + new ObjectOutputStream(codec.compressedOutputStream(out))) { |
| 522 | + oos.writeObject(broadcast); |
| 523 | + } |
| 524 | + byte[] _serializeResult = out.toByteArray(); |
| 525 | + getReducerFileGroupResponseBroadcasts.put( |
| 526 | + shuffleId, Tuple2.apply(broadcast, _serializeResult)); |
| 527 | + getReducerFileGroupResponseBroadcastNum.incrementAndGet(); |
| 528 | + return _serializeResult; |
| 529 | + } catch (Throwable e) { |
| 530 | + LOG.error( |
| 531 | + "Failed to serialize GetReducerFileGroupResponse for shuffle: " + shuffleId, e); |
| 532 | + return null; |
| 533 | + } |
| 534 | + }); |
| 535 | + } |
| 536 | + |
| 537 | + public static GetReducerFileGroupResponse deserializeGetReducerFileGroupResponse( |
| 538 | + Integer shuffleId, byte[] bytes) { |
| 539 | + SparkContext sparkContext = SparkContext$.MODULE$.getActive().getOrElse(null); |
| 540 | + if (sparkContext == null) { |
| 541 | + LOG.error("Can not get active SparkContext."); |
| 542 | + return null; |
| 543 | + } |
| 544 | + |
| 545 | + return shuffleBroadcastLock.withLock( |
| 546 | + shuffleId, |
| 547 | + () -> { |
| 548 | + GetReducerFileGroupResponse response = null; |
| 549 | + LOG.info( |
| 550 | + "Deserializing GetReducerFileGroupResponse broadcast for shuffle: {}", shuffleId); |
| 551 | + |
| 552 | + try { |
| 553 | + CompressionCodec codec = CompressionCodec.createCodec(sparkContext.conf()); |
| 554 | + try (ObjectInputStream objIn = |
| 555 | + new ObjectInputStream( |
| 556 | + codec.compressedInputStream(new ByteArrayInputStream(bytes)))) { |
| 557 | + Broadcast<GetReducerFileGroupResponse> broadcast = |
| 558 | + (Broadcast<GetReducerFileGroupResponse>) objIn.readObject(); |
| 559 | + response = broadcast.value(); |
| 560 | + } |
| 561 | + } catch (Throwable e) { |
| 562 | + LOG.error( |
| 563 | + "Failed to deserialize GetReducerFileGroupResponse for shuffle: " + shuffleId, e); |
| 564 | + } |
| 565 | + return response; |
| 566 | + }); |
| 567 | + } |
| 568 | + |
| 569 | + public static void invalidateSerializedGetReducerFileGroupResponse(Integer shuffleId) { |
| 570 | + shuffleBroadcastLock.withLock( |
| 571 | + shuffleId, |
| 572 | + () -> { |
| 573 | + try { |
| 574 | + Tuple2<Broadcast<GetReducerFileGroupResponse>, byte[]> |
| 575 | + cachedSerializeGetReducerFileGroupResponse = |
| 576 | + getReducerFileGroupResponseBroadcasts.remove(shuffleId); |
| 577 | + if (cachedSerializeGetReducerFileGroupResponse != null) { |
| 578 | + cachedSerializeGetReducerFileGroupResponse._1().destroy(); |
| 579 | + } |
| 580 | + } catch (Throwable e) { |
| 581 | + LOG.error( |
| 582 | + "Failed to invalidate serialized GetReducerFileGroupResponse for shuffle: " |
| 583 | + + shuffleId, |
| 584 | + e); |
| 585 | + } |
| 586 | + return null; |
| 587 | + }); |
| 588 | + } |
465 | 589 | }
|
0 commit comments