Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CELEBORN-1792] MemoryManager resume should use pinnedDirectMemory instead of usedDirectMemory #3018

Closed
wants to merge 18 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@

package org.apache.celeborn.common.network.util;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ThreadFactory;
Expand Down Expand Up @@ -47,6 +49,8 @@ public class NettyUtils {
private static final ByteBufAllocator[] _sharedByteBufAllocator = new ByteBufAllocator[2];
private static final ConcurrentHashMap<String, Integer> allocatorsIndex =
JavaUtils.newConcurrentHashMap();
private static final List<PooledByteBufAllocator> pooledByteBufAllocators = new ArrayList<>();

/** Creates a new ThreadFactory which prefixes each thread with the given name. */
public static ThreadFactory createThreadFactory(String threadPoolPrefix) {
return new DefaultThreadFactory(threadPoolPrefix, true);
Expand Down Expand Up @@ -141,6 +145,9 @@ public static synchronized ByteBufAllocator getSharedByteBufAllocator(
_sharedByteBufAllocator[index] =
createByteBufAllocator(
conf.networkMemoryAllocatorPooled(), true, allowCache, conf.networkAllocatorArenas());
if (conf.networkMemoryAllocatorPooled()) {
pooledByteBufAllocators.add((PooledByteBufAllocator) _sharedByteBufAllocator[index]);
}
if (source != null) {
new NettyMemoryMetrics(
_sharedByteBufAllocator[index],
Expand Down Expand Up @@ -178,6 +185,9 @@ public static ByteBufAllocator getByteBufAllocator(
conf.preferDirectBufs(),
allowCache,
arenas);
if (conf.getCelebornConf().networkMemoryAllocatorPooled()) {
pooledByteBufAllocators.add((PooledByteBufAllocator) allocator);
}
if (source != null) {
String poolName = "default-netty-pool";
Map<String, String> labels = new HashMap<>();
Expand All @@ -196,4 +206,8 @@ public static ByteBufAllocator getByteBufAllocator(
}
return allocator;
}

public static List<PooledByteBufAllocator> getAllPooledByteBufAllocators() {
return pooledByteBufAllocators;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1278,9 +1278,12 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
def workerDirectMemoryRatioToPauseReplicate: Double =
get(WORKER_DIRECT_MEMORY_RATIO_PAUSE_REPLICATE)
def workerDirectMemoryRatioToResume: Double = get(WORKER_DIRECT_MEMORY_RATIO_RESUME)
def workerPinnedMemoryRatioToResume: Double = get(WORKER_PINNED_MEMORY_RATIO_RESUME)
def workerPartitionSorterDirectMemoryRatioThreshold: Double =
get(WORKER_PARTITION_SORTER_DIRECT_MEMORY_RATIO_THRESHOLD)
def workerDirectMemoryPressureCheckIntervalMs: Long = get(WORKER_DIRECT_MEMORY_CHECK_INTERVAL)
def workerPinnedMemoryCheckEnabled: Boolean = get(WORKER_PINNED_MEMORY_CHECK_ENABLED)
def workerPinnedMemoryCheckIntervalMs: Long = get(WORKER_PINNED_MEMORY_CHECK_INTERVAL)
def workerDirectMemoryReportIntervalSecond: Long = get(WORKER_DIRECT_MEMORY_REPORT_INTERVAL)
def workerDirectMemoryTrimChannelWaitInterval: Long =
get(WORKER_DIRECT_MEMORY_TRIM_CHANNEL_WAIT_INTERVAL)
Expand Down Expand Up @@ -3711,6 +3714,24 @@ object CelebornConf extends Logging {
.timeConf(TimeUnit.MILLISECONDS)
.createWithDefaultString("10ms")

val WORKER_PINNED_MEMORY_CHECK_ENABLED: ConfigEntry[Boolean] =
buildConf("celeborn.worker.monitor.pinnedMemory.check.enabled")
.categories("worker")
.doc("If true, MemoryManager will check worker should resume by pinned memory used.")
.version("0.6.0")
.booleanConf
.createWithDefaultString("true")

val WORKER_PINNED_MEMORY_CHECK_INTERVAL: ConfigEntry[Long] =
buildConf("celeborn.worker.monitor.pinnedMemory.check.interval")
.categories("worker")
.doc("Interval of worker direct pinned memory checking, " +
"only takes effect when celeborn.network.memory.allocator.pooled and " +
"celeborn.worker.monitor.pinnedMemory.check.enabled are enabled.")
.version("0.6.0")
.timeConf(TimeUnit.MILLISECONDS)
.createWithDefaultString("10s")

val WORKER_DIRECT_MEMORY_REPORT_INTERVAL: ConfigEntry[Long] =
buildConf("celeborn.worker.monitor.memory.report.interval")
.withAlternative("celeborn.worker.memory.reportInterval")
Expand Down Expand Up @@ -3860,6 +3881,16 @@ object CelebornConf extends Logging {
.doubleConf
.createWithDefault(0.7)

val WORKER_PINNED_MEMORY_RATIO_RESUME: ConfigEntry[Double] =
buildConf("celeborn.worker.pinnedMemoryRatioToResume")
.categories("worker")
.doc("If pinned memory usage is less than this limit, worker will resume, " +
"only takes effect when celeborn.network.memory.allocator.pooled and " +
"celeborn.worker.monitor.pinnedMemory.check.enabled are enabled")
.version("0.6.0")
.doubleConf
.createWithDefault(0.3)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can add a new conf for pinnedMemoryToResume and keep exist conf for directMemoryRatioToResume


val WORKER_MEMORY_FILE_STORAGE_MAX_FILE_SIZE: ConfigEntry[Long] =
buildConf("celeborn.worker.memoryFileStorage.maxFileSize")
.categories("worker")
Expand Down
3 changes: 3 additions & 0 deletions docs/configuration/worker.md
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,12 @@ license: |
| celeborn.worker.monitor.memory.report.interval | 10s | false | Interval of worker direct memory tracker reporting to log. | 0.3.0 | celeborn.worker.memory.reportInterval |
| celeborn.worker.monitor.memory.trimChannelWaitInterval | 1s | false | Wait time after worker trigger channel to trim cache. | 0.3.0 | |
| celeborn.worker.monitor.memory.trimFlushWaitInterval | 1s | false | Wait time after worker trigger StorageManger to flush data. | 0.3.0 | |
| celeborn.worker.monitor.pinnedMemory.check.enabled | true | false | If true, MemoryManager will check worker should resume by pinned memory used. | 0.6.0 | |
| celeborn.worker.monitor.pinnedMemory.check.interval | 10s | false | Interval of worker direct pinned memory checking, only takes effect when celeborn.network.memory.allocator.pooled and celeborn.worker.monitor.pinnedMemory.check.enabled are enabled. | 0.6.0 | |
| celeborn.worker.partition.initial.readBuffersMax | 1024 | false | Max number of initial read buffers | 0.3.0 | |
| celeborn.worker.partition.initial.readBuffersMin | 1 | false | Min number of initial read buffers | 0.3.0 | |
| celeborn.worker.partitionSorter.directMemoryRatioThreshold | 0.1 | false | Max ratio of partition sorter's memory for sorting, when reserved memory is higher than max partition sorter memory, partition sorter will stop sorting. If this value is set to 0, partition files sorter will skip memory check and ServingState check. | 0.2.0 | |
| celeborn.worker.pinnedMemoryRatioToResume | 0.3 | false | If pinned memory usage is less than this limit, worker will resume, only takes effect when celeborn.network.memory.allocator.pooled and celeborn.worker.monitor.pinnedMemory.check.enabled are enabled | 0.6.0 | |
| celeborn.worker.push.heartbeat.enabled | false | false | enable the heartbeat from worker to client when pushing data | 0.3.0 | |
| celeborn.worker.push.io.threads | &lt;undefined&gt; | false | Netty IO thread number of worker to handle client push data. The default threads number is the number of flush thread. | 0.2.0 | |
| celeborn.worker.push.port | 0 | false | Server port for Worker to receive push data request from ShuffleClient. | 0.2.0 | |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,14 @@
import com.google.common.base.Preconditions;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.util.internal.PlatformDependent;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.metrics.source.AbstractSource;
import org.apache.celeborn.common.network.util.NettyUtils;
import org.apache.celeborn.common.protocol.TransportModuleConstants;
import org.apache.celeborn.common.util.ThreadUtils;
import org.apache.celeborn.common.util.Utils;
Expand All @@ -50,7 +52,8 @@ public class MemoryManager {
@VisibleForTesting public long maxDirectMemory;
private final long pausePushDataThreshold;
private final long pauseReplicateThreshold;
private final double resumeRatio;
private final double directMemoryResumeRatio;
private final double pinnedMemoryResumeRatio;
private final long maxSortMemory;
private final int forceAppendPauseSpentTimeThreshold;
private final List<MemoryPressureListener> memoryPressureListeners = new ArrayList<>();
Expand Down Expand Up @@ -93,6 +96,9 @@ public class MemoryManager {
private long memoryFileStorageThreshold;
private final LongAdder memoryFileStorageCounter = new LongAdder();
private final StorageManager storageManager;
private boolean pinnedMemoryCheckEnabled;
private long pinnedMemoryCheckInterval;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To avoid frequently calling a pinned memory counter, I think you can cache the last pinned memory value and refresh it periodically. And exporting the pinned memory value to the metrics.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here is another PR introducing pinnedMemory metrics #3019

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

getPinnedMemory is not called very frequently. It is called once every pinnedMemoryCheckInterval. The default is 10 seconds.

private long pinnedMemoryLastCheckTime = 0L;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The default value for this is 0.


@VisibleForTesting
public static MemoryManager initialize(CelebornConf conf) {
Expand Down Expand Up @@ -120,11 +126,14 @@ public static MemoryManager instance() {
private MemoryManager(CelebornConf conf, StorageManager storageManager, AbstractSource source) {
double pausePushDataRatio = conf.workerDirectMemoryRatioToPauseReceive();
double pauseReplicateRatio = conf.workerDirectMemoryRatioToPauseReplicate();
this.resumeRatio = conf.workerDirectMemoryRatioToResume();
this.directMemoryResumeRatio = conf.workerDirectMemoryRatioToResume();
this.pinnedMemoryResumeRatio = conf.workerPinnedMemoryRatioToResume();
double maxSortMemRatio = conf.workerPartitionSorterDirectMemoryRatioThreshold();
double readBufferRatio = conf.workerDirectMemoryRatioForReadBuffer();
double memoryFileStorageRatio = conf.workerDirectMemoryRatioForMemoryFilesStorage();
long checkInterval = conf.workerDirectMemoryPressureCheckIntervalMs();
this.pinnedMemoryCheckEnabled = conf.workerPinnedMemoryCheckEnabled();
this.pinnedMemoryCheckInterval = conf.workerPinnedMemoryCheckIntervalMs();
long reportInterval = conf.workerDirectMemoryReportIntervalSecond();
double readBufferTargetRatio = conf.readBufferTargetRatio();
long readBufferTargetUpdateInterval = conf.readBufferTargetUpdateInterval();
Expand All @@ -148,9 +157,10 @@ private MemoryManager(CelebornConf conf, StorageManager storageManager, Abstract
pauseReplicateRatio,
CelebornConf.WORKER_DIRECT_MEMORY_RATIO_PAUSE_RECEIVE().key(),
pausePushDataRatio));
Preconditions.checkArgument(pausePushDataRatio > resumeRatio);
Preconditions.checkArgument(pausePushDataRatio > directMemoryResumeRatio);
if (memoryFileStorageRatio > 0) {
Preconditions.checkArgument(resumeRatio > (readBufferRatio + memoryFileStorageRatio));
Preconditions.checkArgument(
directMemoryResumeRatio > (readBufferRatio + memoryFileStorageRatio));
}

maxSortMemory = ((long) (maxDirectMemory * maxSortMemRatio));
Expand Down Expand Up @@ -282,7 +292,7 @@ private MemoryManager(CelebornConf conf, StorageManager storageManager, Abstract
Utils.bytesToString(readBufferThreshold),
Utils.bytesToString(readBufferTarget),
Utils.bytesToString(memoryFileStorageThreshold),
resumeRatio);
directMemoryResumeRatio);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can add pinned memory resume ratio here. It is an important parameter for memory manager.

}

public boolean shouldEvict(boolean aggressiveMemoryFileEvictEnabled, double evictRatio) {
Expand All @@ -305,7 +315,7 @@ public ServingState currentServingState() {
return ServingState.PUSH_PAUSED;
}
// trigger resume
if (memoryUsage / (double) (maxDirectMemory) < resumeRatio) {
if (memoryUsage / (double) (maxDirectMemory) < directMemoryResumeRatio) {
isPaused = false;
return ServingState.NONE_PAUSED;
}
Expand All @@ -315,69 +325,68 @@ public ServingState currentServingState() {
}

@VisibleForTesting
protected void switchServingState() {
public void switchServingState() {
ServingState lastState = servingState;
servingState = currentServingState();
if (lastState == servingState) {
if (servingState != ServingState.NONE_PAUSED) {
logger.info("Serving state transformed from {} to {}", lastState, servingState);
switch (servingState) {
case PUSH_PAUSED:
if (canResumeByPinnedMemory()) {
resumeByPinnedMemory(servingState);
} else {
pausePushDataCounter.increment();
if (lastState == ServingState.PUSH_AND_REPLICATE_PAUSED) {
logger.info("Trigger action: RESUME REPLICATE");
resumeReplicate();
} else {
logger.info("Trigger action: PAUSE PUSH");
pausePushDataStartTime = System.currentTimeMillis();
memoryPressureListeners.forEach(
memoryPressureListener ->
memoryPressureListener.onPause(TransportModuleConstants.PUSH_MODULE));
}
}
logger.debug("Trigger action: TRIM");
trimCounter += 1;
// force to append pause spent time even we are in pause state
trimAllListeners();
if (trimCounter >= forceAppendPauseSpentTimeThreshold) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lost trimCounter+=1

logger.debug(
"Trigger action: TRIM for {} times, force to append pause spent time.", trimCounter);
appendPauseSpentTime(servingState);
}
trimAllListeners();
}
return;
}
logger.info("Serving state transformed from {} to {}", lastState, servingState);
switch (servingState) {
case PUSH_PAUSED:
pausePushDataCounter.increment();
if (lastState == ServingState.PUSH_AND_REPLICATE_PAUSED) {
logger.info("Trigger action: RESUME REPLICATE");
memoryPressureListeners.forEach(
memoryPressureListener ->
memoryPressureListener.onResume(TransportModuleConstants.REPLICATE_MODULE));
} else if (lastState == ServingState.NONE_PAUSED) {
logger.info("Trigger action: PAUSE PUSH");
pausePushDataStartTime = System.currentTimeMillis();
memoryPressureListeners.forEach(
memoryPressureListener ->
memoryPressureListener.onPause(TransportModuleConstants.PUSH_MODULE));
}
trimAllListeners();
break;
case PUSH_AND_REPLICATE_PAUSED:
pausePushDataAndReplicateCounter.increment();
if (lastState == ServingState.NONE_PAUSED) {
if (canResumeByPinnedMemory()) {
resumeByPinnedMemory(servingState);
} else {
pausePushDataAndReplicateCounter.increment();
logger.info("Trigger action: PAUSE PUSH");
pausePushDataAndReplicateStartTime = System.currentTimeMillis();
memoryPressureListeners.forEach(
memoryPressureListener ->
memoryPressureListener.onPause(TransportModuleConstants.PUSH_MODULE));
logger.info("Trigger action: PAUSE REPLICATE");
memoryPressureListeners.forEach(
memoryPressureListener ->
memoryPressureListener.onPause(TransportModuleConstants.REPLICATE_MODULE));
}
logger.info("Trigger action: PAUSE REPLICATE");
memoryPressureListeners.forEach(
memoryPressureListener ->
memoryPressureListener.onPause(TransportModuleConstants.REPLICATE_MODULE));
logger.debug("Trigger action: TRIM");
trimAllListeners();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

if (trimCounter >= forceAppendPauseSpentTimeThreshold) {
logger.debug(
"Trigger action: TRIM for {} times, force to append pause spent time.", trimCounter);
appendPauseSpentTime(servingState);
}
break;
case NONE_PAUSED:
// resume from paused mode, append pause spent time
appendPauseSpentTime(lastState);
if (lastState == ServingState.PUSH_AND_REPLICATE_PAUSED) {
logger.info("Trigger action: RESUME REPLICATE");
memoryPressureListeners.forEach(
memoryPressureListener ->
memoryPressureListener.onResume(TransportModuleConstants.REPLICATE_MODULE));
resumeReplicate();
resumePush();
appendPauseSpentTime(lastState);
} else if (lastState == ServingState.PUSH_PAUSED) {
resumePush();
appendPauseSpentTime(lastState);
}
logger.info("Trigger action: RESUME PUSH");
memoryPressureListeners.forEach(
memoryPressureListener ->
memoryPressureListener.onResume(TransportModuleConstants.PUSH_MODULE));
}
}

Expand Down Expand Up @@ -436,6 +445,16 @@ public long getMemoryUsage() {
return getNettyUsedDirectMemory() + sortMemoryCounter.get();
}

public long getAllocatedMemory() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method should be renamed to getPinnedMemory. The allocated memory is the netty memory counter.

return getNettyPinnedDirectMemory() + sortMemoryCounter.get();
}

public long getNettyPinnedDirectMemory() {
return NettyUtils.getAllPooledByteBufAllocators().stream()
.mapToLong(PooledByteBufAllocator::pinnedDirectMemory)
.sum();
}

public AtomicLong getSortMemoryCounter() {
return sortMemoryCounter;
}
Expand Down Expand Up @@ -557,6 +576,47 @@ public static void reset() {
_INSTANCE = null;
}

private void resumeByPinnedMemory(ServingState servingState) {
switch (servingState) {
case PUSH_AND_REPLICATE_PAUSED:
logger.info(
"Serving State is PUSH_AND_REPLICATE_PAUSED, but resume by lower pinned memory {}",
getNettyPinnedDirectMemory());
resumeReplicate();
resumePush();
case PUSH_PAUSED:
logger.info(
"Serving State is PUSH_PAUSED, but resume by lower pinned memory {}",
getNettyPinnedDirectMemory());
resumePush();
}
}

private boolean canResumeByPinnedMemory() {
if (pinnedMemoryCheckEnabled
&& System.currentTimeMillis() - pinnedMemoryLastCheckTime >= pinnedMemoryCheckInterval
&& getAllocatedMemory() / (double) (maxDirectMemory) < pinnedMemoryResumeRatio) {
pinnedMemoryLastCheckTime = System.currentTimeMillis();
return true;
} else {
return false;
}
}

private void resumePush() {
logger.info("Trigger action: RESUME PUSH");
memoryPressureListeners.forEach(
memoryPressureListener ->
memoryPressureListener.onResume(TransportModuleConstants.PUSH_MODULE));
}

private void resumeReplicate() {
logger.info("Trigger action: RESUME REPLICATE");
memoryPressureListeners.forEach(
memoryPressureListener ->
memoryPressureListener.onResume(TransportModuleConstants.REPLICATE_MODULE));
}

public interface MemoryPressureListener {
void onPause(String moduleName);

Expand Down
Loading
Loading