|
17 | 17 |
|
18 | 18 | package org.apache.spark.shuffle.celeborn;
|
19 | 19 |
|
| 20 | +import java.io.ByteArrayInputStream; |
| 21 | +import java.io.ObjectInputStream; |
| 22 | +import java.io.ObjectOutputStream; |
20 | 23 | import java.util.HashSet;
|
21 | 24 | import java.util.List;
|
22 | 25 | import java.util.Map;
|
|
36 | 39 | import org.apache.spark.SparkContext;
|
37 | 40 | import org.apache.spark.SparkContext$;
|
38 | 41 | import org.apache.spark.TaskContext;
|
| 42 | +import org.apache.spark.broadcast.Broadcast; |
| 43 | +import org.apache.spark.io.CompressionCodec; |
39 | 44 | import org.apache.spark.scheduler.DAGScheduler;
|
40 | 45 | import org.apache.spark.scheduler.MapStatus;
|
41 | 46 | import org.apache.spark.scheduler.MapStatus$;
|
|
57 | 62 |
|
58 | 63 | import org.apache.celeborn.client.ShuffleClient;
|
59 | 64 | import org.apache.celeborn.common.CelebornConf;
|
| 65 | +import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse; |
60 | 66 | import org.apache.celeborn.common.util.JavaUtils;
|
| 67 | +import org.apache.celeborn.common.util.KeyLock; |
61 | 68 | import org.apache.celeborn.reflect.DynConstructors;
|
62 | 69 | import org.apache.celeborn.reflect.DynFields;
|
63 | 70 | import org.apache.celeborn.reflect.DynMethods;
|
@@ -462,4 +469,117 @@ public static void addSparkListener(SparkListener listener) {
|
462 | 469 | sparkContext.addSparkListener(listener);
|
463 | 470 | }
|
464 | 471 | }
|
| 472 | + |
| 473 | + /** |
| 474 | + * A [[KeyLock]] whose key is a shuffle id to ensure there is only one thread accessing the |
| 475 | + * broadcast belonging to the shuffle id at a time. |
| 476 | + */ |
| 477 | + private static KeyLock<Integer> shuffleBroadcastLock = new KeyLock(); |
| 478 | + |
| 479 | + protected static Map<Integer, Tuple2<Broadcast<GetReducerFileGroupResponse>, byte[]>> |
| 480 | + getReducerFileGroupResponseBroadcasts = JavaUtils.newConcurrentHashMap(); |
| 481 | + |
| 482 | + public static byte[] serializeGetReducerFileGroupResponse( |
| 483 | + Integer shuffleId, GetReducerFileGroupResponse response) { |
| 484 | + SparkContext sparkContext = SparkContext$.MODULE$.getActive().getOrElse(null); |
| 485 | + if (sparkContext == null) { |
| 486 | + LOG.error("Can not get active SparkContext."); |
| 487 | + return null; |
| 488 | + } |
| 489 | + |
| 490 | + return shuffleBroadcastLock.withLock( |
| 491 | + shuffleId, |
| 492 | + () -> { |
| 493 | + Tuple2<Broadcast<GetReducerFileGroupResponse>, byte[]> |
| 494 | + cachedSerializeGetReducerFileGroupResponse = |
| 495 | + getReducerFileGroupResponseBroadcasts.get(shuffleId); |
| 496 | + if (cachedSerializeGetReducerFileGroupResponse != null) { |
| 497 | + return cachedSerializeGetReducerFileGroupResponse._2; |
| 498 | + } |
| 499 | + |
| 500 | + try { |
| 501 | + Broadcast<GetReducerFileGroupResponse> broadcast = |
| 502 | + sparkContext.broadcast( |
| 503 | + response, |
| 504 | + scala.reflect.ClassManifestFactory.fromClass( |
| 505 | + GetReducerFileGroupResponse.class)); |
| 506 | + |
| 507 | + CompressionCodec codec = CompressionCodec.createCodec(sparkContext.conf()); |
| 508 | + // Using `org.apache.commons.io.output.ByteArrayOutputStream` instead of the standard |
| 509 | + // one |
| 510 | + // This implementation doesn't reallocate the whole memory block but allocates |
| 511 | + // additional buffers. This way no buffers need to be garbage collected and |
| 512 | + // the contents don't have to be copied to the new buffer. |
| 513 | + org.apache.commons.io.output.ByteArrayOutputStream out = |
| 514 | + new org.apache.commons.io.output.ByteArrayOutputStream(); |
| 515 | + try (ObjectOutputStream oos = |
| 516 | + new ObjectOutputStream(codec.compressedOutputStream(out))) { |
| 517 | + oos.writeObject(broadcast); |
| 518 | + } |
| 519 | + byte[] _serializeResult = out.toByteArray(); |
| 520 | + getReducerFileGroupResponseBroadcasts.put( |
| 521 | + shuffleId, Tuple2.apply(broadcast, _serializeResult)); |
| 522 | + return _serializeResult; |
| 523 | + } catch (Throwable e) { |
| 524 | + LOG.error( |
| 525 | + "Failed to serialize GetReducerFileGroupResponse for shuffle: " + shuffleId, e); |
| 526 | + return null; |
| 527 | + } |
| 528 | + }); |
| 529 | + } |
| 530 | + |
| 531 | + public static GetReducerFileGroupResponse deserializeGetReducerFileGroupResponse( |
| 532 | + Integer shuffleId, byte[] bytes) { |
| 533 | + SparkContext sparkContext = SparkContext$.MODULE$.getActive().getOrElse(null); |
| 534 | + if (sparkContext == null) { |
| 535 | + LOG.error("Can not get active SparkContext."); |
| 536 | + return null; |
| 537 | + } |
| 538 | + |
| 539 | + return shuffleBroadcastLock.withLock( |
| 540 | + shuffleId, |
| 541 | + () -> { |
| 542 | + GetReducerFileGroupResponse response = null; |
| 543 | + |
| 544 | + try { |
| 545 | + CompressionCodec codec = CompressionCodec.createCodec(sparkContext.conf()); |
| 546 | + try (ObjectInputStream objIn = |
| 547 | + new ObjectInputStream( |
| 548 | + codec.compressedInputStream(new ByteArrayInputStream(bytes)))) { |
| 549 | + Broadcast<GetReducerFileGroupResponse> broadcast = |
| 550 | + (Broadcast<GetReducerFileGroupResponse>) objIn.readObject(); |
| 551 | + response = broadcast.value(); |
| 552 | + LOG.info( |
| 553 | + "Shuffle {} broadcast GetReducerFileGroupResponse size={}.", |
| 554 | + shuffleId, |
| 555 | + bytes.length); |
| 556 | + } |
| 557 | + } catch (Throwable e) { |
| 558 | + LOG.error( |
| 559 | + "Failed to deserialize GetReducerFileGroupResponse for shuffle: " + shuffleId, e); |
| 560 | + } |
| 561 | + return response; |
| 562 | + }); |
| 563 | + } |
| 564 | + |
| 565 | + public static void invalidateSerializedGetReducerFileGroupResponse(Integer shuffleId) { |
| 566 | + shuffleBroadcastLock.withLock( |
| 567 | + shuffleId, |
| 568 | + () -> { |
| 569 | + try { |
| 570 | + Tuple2<Broadcast<GetReducerFileGroupResponse>, byte[]> |
| 571 | + cachedSerializeGetReducerFileGroupResponse = |
| 572 | + getReducerFileGroupResponseBroadcasts.remove(shuffleId); |
| 573 | + if (cachedSerializeGetReducerFileGroupResponse != null) { |
| 574 | + cachedSerializeGetReducerFileGroupResponse._1().destroy(); |
| 575 | + } |
| 576 | + } catch (Throwable e) { |
| 577 | + LOG.error( |
| 578 | + "Failed to invalidate serialized GetReducerFileGroupResponse for shuffle: " |
| 579 | + + shuffleId, |
| 580 | + e); |
| 581 | + } |
| 582 | + return null; |
| 583 | + }); |
| 584 | + } |
465 | 585 | }
|
0 commit comments