Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CELEBORN-1792] MemoryManager resume should use pinnedDirectMemory instead of usedDirectMemory #3018

Closed
wants to merge 18 commits into from
Original file line number Diff line number Diff line change
@@ -17,8 +17,10 @@

package org.apache.celeborn.common.network.util;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ThreadFactory;
@@ -47,6 +49,8 @@ public class NettyUtils {
private static final ByteBufAllocator[] _sharedByteBufAllocator = new ByteBufAllocator[2];
private static final ConcurrentHashMap<String, Integer> allocatorsIndex =
JavaUtils.newConcurrentHashMap();
private static final List<PooledByteBufAllocator> pooledByteBufAllocators = new ArrayList<>();

/** Creates a new ThreadFactory which prefixes each thread with the given name. */
public static ThreadFactory createThreadFactory(String threadPoolPrefix) {
return new DefaultThreadFactory(threadPoolPrefix, true);
@@ -141,6 +145,9 @@ public static synchronized ByteBufAllocator getSharedByteBufAllocator(
_sharedByteBufAllocator[index] =
createByteBufAllocator(
conf.networkMemoryAllocatorPooled(), true, allowCache, conf.networkAllocatorArenas());
if (conf.networkMemoryAllocatorPooled()) {
pooledByteBufAllocators.add((PooledByteBufAllocator) _sharedByteBufAllocator[index]);
}
if (source != null) {
new NettyMemoryMetrics(
_sharedByteBufAllocator[index],
@@ -178,6 +185,9 @@ public static ByteBufAllocator getByteBufAllocator(
conf.preferDirectBufs(),
allowCache,
arenas);
if (conf.getCelebornConf().networkMemoryAllocatorPooled()) {
pooledByteBufAllocators.add((PooledByteBufAllocator) allocator);
}
if (source != null) {
String poolName = "default-netty-pool";
Map<String, String> labels = new HashMap<>();
@@ -196,4 +206,8 @@ public static ByteBufAllocator getByteBufAllocator(
}
return allocator;
}

/**
 * Returns a read-only view of the pooled allocators registered by this class.
 *
 * <p>The returned list is backed by the internal registry, so allocators registered after this
 * call are still visible through it, but callers cannot add or remove entries.
 *
 * @return an unmodifiable view of all registered {@link PooledByteBufAllocator} instances
 */
public static List<PooledByteBufAllocator> getAllPooledByteBufAllocators() {
  // Hand out a view instead of the backing list so the registry can only be changed here.
  return Collections.unmodifiableList(pooledByteBufAllocators);
}
}
Original file line number Diff line number Diff line change
@@ -1278,9 +1278,12 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
def workerDirectMemoryRatioToPauseReplicate: Double =
get(WORKER_DIRECT_MEMORY_RATIO_PAUSE_REPLICATE)
def workerDirectMemoryRatioToResume: Double = get(WORKER_DIRECT_MEMORY_RATIO_RESUME)
def workerPinnedMemoryRatioToResume: Double = get(WORKER_PINNED_MEMORY_RATIO_RESUME)
def workerPartitionSorterDirectMemoryRatioThreshold: Double =
get(WORKER_PARTITION_SORTER_DIRECT_MEMORY_RATIO_THRESHOLD)
def workerDirectMemoryPressureCheckIntervalMs: Long = get(WORKER_DIRECT_MEMORY_CHECK_INTERVAL)
def workerPinnedMemoryCheckEnabled: Boolean = get(WORKER_PINNED_MEMORY_CHECK_ENABLED)
def workerPinnedMemoryCheckIntervalMs: Long = get(WORKER_PINNED_MEMORY_CHECK_INTERVAL)
def workerDirectMemoryReportIntervalSecond: Long = get(WORKER_DIRECT_MEMORY_REPORT_INTERVAL)
def workerDirectMemoryTrimChannelWaitInterval: Long =
get(WORKER_DIRECT_MEMORY_TRIM_CHANNEL_WAIT_INTERVAL)
@@ -3711,6 +3714,24 @@ object CelebornConf extends Logging {
.timeConf(TimeUnit.MILLISECONDS)
.createWithDefaultString("10ms")

// Worker-side switch for the pinned-memory based resume check; defaults to "true".
val WORKER_PINNED_MEMORY_CHECK_ENABLED: ConfigEntry[Boolean] =
buildConf("celeborn.worker.monitor.pinnedMemory.check.enabled")
.categories("worker")
.doc("If true, MemoryManager will check worker should resume by pinned memory used.")
.version("0.6.0")
.booleanConf
.createWithDefaultString("true")

// Interval between pinned-memory checks on the worker; a time config declared with
// TimeUnit.MILLISECONDS and a default of "10s".
val WORKER_PINNED_MEMORY_CHECK_INTERVAL: ConfigEntry[Long] =
buildConf("celeborn.worker.monitor.pinnedMemory.check.interval")
.categories("worker")
.doc("Interval of worker direct pinned memory checking, " +
"only takes effect when celeborn.network.memory.allocator.pooled and " +
"celeborn.worker.monitor.pinnedMemory.check.enabled are enabled.")
.version("0.6.0")
.timeConf(TimeUnit.MILLISECONDS)
.createWithDefaultString("10s")

val WORKER_DIRECT_MEMORY_REPORT_INTERVAL: ConfigEntry[Long] =
buildConf("celeborn.worker.monitor.memory.report.interval")
.withAlternative("celeborn.worker.memory.reportInterval")
@@ -3860,6 +3881,16 @@ object CelebornConf extends Logging {
.doubleConf
.createWithDefault(0.7)

// Pinned-memory ratio below which the worker may resume; defaults to 0.3.
val WORKER_PINNED_MEMORY_RATIO_RESUME: ConfigEntry[Double] =
buildConf("celeborn.worker.pinnedMemoryRatioToResume")
.categories("worker")
.doc("If pinned memory usage is less than this limit, worker will resume, " +
"only takes effect when celeborn.network.memory.allocator.pooled and " +
"celeborn.worker.monitor.pinnedMemory.check.enabled are enabled")
.version("0.6.0")
.doubleConf
.createWithDefault(0.3)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can add a new conf for pinnedMemoryToResume and keep the existing conf for directMemoryRatioToResume.


val WORKER_MEMORY_FILE_STORAGE_MAX_FILE_SIZE: ConfigEntry[Long] =
buildConf("celeborn.worker.memoryFileStorage.maxFileSize")
.categories("worker")
3 changes: 3 additions & 0 deletions docs/configuration/worker.md
Original file line number Diff line number Diff line change
@@ -144,9 +144,12 @@ license: |
| celeborn.worker.monitor.memory.report.interval | 10s | false | Interval of worker direct memory tracker reporting to log. | 0.3.0 | celeborn.worker.memory.reportInterval |
| celeborn.worker.monitor.memory.trimChannelWaitInterval | 1s | false | Wait time after worker trigger channel to trim cache. | 0.3.0 | |
| celeborn.worker.monitor.memory.trimFlushWaitInterval | 1s | false | Wait time after worker trigger StorageManger to flush data. | 0.3.0 | |
| celeborn.worker.monitor.pinnedMemory.check.enabled | true | false | If true, MemoryManager will check worker should resume by pinned memory used. | 0.6.0 | |
| celeborn.worker.monitor.pinnedMemory.check.interval | 10s | false | Interval of worker direct pinned memory checking, only takes effect when celeborn.network.memory.allocator.pooled and celeborn.worker.monitor.pinnedMemory.check.enabled are enabled. | 0.6.0 | |
| celeborn.worker.partition.initial.readBuffersMax | 1024 | false | Max number of initial read buffers | 0.3.0 | |
| celeborn.worker.partition.initial.readBuffersMin | 1 | false | Min number of initial read buffers | 0.3.0 | |
| celeborn.worker.partitionSorter.directMemoryRatioThreshold | 0.1 | false | Max ratio of partition sorter's memory for sorting, when reserved memory is higher than max partition sorter memory, partition sorter will stop sorting. If this value is set to 0, partition files sorter will skip memory check and ServingState check. | 0.2.0 | |
| celeborn.worker.pinnedMemoryRatioToResume | 0.3 | false | If pinned memory usage is less than this limit, worker will resume, only takes effect when celeborn.network.memory.allocator.pooled and celeborn.worker.monitor.pinnedMemory.check.enabled are enabled | 0.6.0 | |
| celeborn.worker.push.heartbeat.enabled | false | false | enable the heartbeat from worker to client when pushing data | 0.3.0 | |
| celeborn.worker.push.io.threads | &lt;undefined&gt; | false | Netty IO thread number of worker to handle client push data. The default threads number is the number of flush thread. | 0.2.0 | |
| celeborn.worker.push.port | 0 | false | Server port for Worker to receive push data request from ShuffleClient. | 0.2.0 | |
Original file line number Diff line number Diff line change
@@ -30,12 +30,14 @@
import com.google.common.base.Preconditions;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.util.internal.PlatformDependent;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.metrics.source.AbstractSource;
import org.apache.celeborn.common.network.util.NettyUtils;
import org.apache.celeborn.common.protocol.TransportModuleConstants;
import org.apache.celeborn.common.util.ThreadUtils;
import org.apache.celeborn.common.util.Utils;
@@ -50,7 +52,8 @@ public class MemoryManager {
@VisibleForTesting public long maxDirectMemory;
private final long pausePushDataThreshold;
private final long pauseReplicateThreshold;
private final double resumeRatio;
private final double directMemoryResumeRatio;
private final double pinnedMemoryResumeRatio;
private final long maxSortMemory;
private final int forceAppendPauseSpentTimeThreshold;
private final List<MemoryPressureListener> memoryPressureListeners = new ArrayList<>();
@@ -93,6 +96,9 @@ public class MemoryManager {
private long memoryFileStorageThreshold;
private final LongAdder memoryFileStorageCounter = new LongAdder();
private final StorageManager storageManager;
private boolean pinnedMemoryCheckEnabled;
private long pinnedMemoryCheckInterval;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To avoid frequently calling the pinned memory counter, I think you can cache the last pinned memory value and refresh it periodically. You could also export the pinned memory value to the metrics.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here is another PR introducing pinnedMemory metrics #3019

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

getPinnedMemory is not called very frequently. It is called once every pinnedMemoryCheckInterval. The default is 10 seconds.

private long pinnedMemoryLastCheckTime = 0;

@VisibleForTesting
public static MemoryManager initialize(CelebornConf conf) {
@@ -120,11 +126,14 @@ public static MemoryManager instance() {
private MemoryManager(CelebornConf conf, StorageManager storageManager, AbstractSource source) {
double pausePushDataRatio = conf.workerDirectMemoryRatioToPauseReceive();
double pauseReplicateRatio = conf.workerDirectMemoryRatioToPauseReplicate();
this.resumeRatio = conf.workerDirectMemoryRatioToResume();
this.directMemoryResumeRatio = conf.workerDirectMemoryRatioToResume();
this.pinnedMemoryResumeRatio = conf.workerPinnedMemoryRatioToResume();
double maxSortMemRatio = conf.workerPartitionSorterDirectMemoryRatioThreshold();
double readBufferRatio = conf.workerDirectMemoryRatioForReadBuffer();
double memoryFileStorageRatio = conf.workerDirectMemoryRatioForMemoryFilesStorage();
long checkInterval = conf.workerDirectMemoryPressureCheckIntervalMs();
this.pinnedMemoryCheckEnabled = conf.workerPinnedMemoryCheckEnabled();
this.pinnedMemoryCheckInterval = conf.workerPinnedMemoryCheckIntervalMs();
long reportInterval = conf.workerDirectMemoryReportIntervalSecond();
double readBufferTargetRatio = conf.readBufferTargetRatio();
long readBufferTargetUpdateInterval = conf.readBufferTargetUpdateInterval();
@@ -148,9 +157,10 @@ private MemoryManager(CelebornConf conf, StorageManager storageManager, Abstract
pauseReplicateRatio,
CelebornConf.WORKER_DIRECT_MEMORY_RATIO_PAUSE_RECEIVE().key(),
pausePushDataRatio));
Preconditions.checkArgument(pausePushDataRatio > resumeRatio);
Preconditions.checkArgument(pausePushDataRatio > directMemoryResumeRatio);
if (memoryFileStorageRatio > 0) {
Preconditions.checkArgument(resumeRatio > (readBufferRatio + memoryFileStorageRatio));
Preconditions.checkArgument(
directMemoryResumeRatio > (readBufferRatio + memoryFileStorageRatio));
}

maxSortMemory = ((long) (maxDirectMemory * maxSortMemRatio));
@@ -275,14 +285,16 @@ private MemoryManager(CelebornConf conf, StorageManager storageManager, Abstract
+ "pause replication memory: {}, "
+ "read buffer memory limit: {} target: {}, "
+ "memory shuffle storage limit: {}, "
+ "resume memory ratio: {}",
+ "resume by direct memory ratio: {}, "
+ "resume by pinned memory ratio: {}",
Utils.bytesToString(maxDirectMemory),
Utils.bytesToString(pausePushDataThreshold),
Utils.bytesToString(pauseReplicateThreshold),
Utils.bytesToString(readBufferThreshold),
Utils.bytesToString(readBufferTarget),
Utils.bytesToString(memoryFileStorageThreshold),
resumeRatio);
directMemoryResumeRatio,
pinnedMemoryResumeRatio);
}

public boolean shouldEvict(boolean aggressiveMemoryFileEvictEnabled, double evictRatio) {
@@ -305,7 +317,7 @@ public ServingState currentServingState() {
return ServingState.PUSH_PAUSED;
}
// trigger resume
if (memoryUsage / (double) (maxDirectMemory) < resumeRatio) {
if (memoryUsage / (double) (maxDirectMemory) < directMemoryResumeRatio) {
isPaused = false;
return ServingState.NONE_PAUSED;
}
@@ -315,69 +327,70 @@ public ServingState currentServingState() {
}

@VisibleForTesting
protected void switchServingState() {
public void switchServingState() {
ServingState lastState = servingState;
servingState = currentServingState();
if (lastState == servingState) {
if (servingState != ServingState.NONE_PAUSED) {
logger.info("Serving state transformed from {} to {}", lastState, servingState);
switch (servingState) {
case PUSH_PAUSED:
if (canResumeByPinnedMemory()) {
resumeByPinnedMemory(servingState);
} else {
pausePushDataCounter.increment();
if (lastState == ServingState.PUSH_AND_REPLICATE_PAUSED) {
logger.info("Trigger action: RESUME REPLICATE");
resumeReplicate();
} else {
logger.info("Trigger action: PAUSE PUSH");
pausePushDataStartTime = System.currentTimeMillis();
memoryPressureListeners.forEach(
memoryPressureListener ->
memoryPressureListener.onPause(TransportModuleConstants.PUSH_MODULE));
}
}
logger.debug("Trigger action: TRIM");
trimCounter += 1;
// force to append pause spent time even we are in pause state
trimAllListeners();
if (trimCounter >= forceAppendPauseSpentTimeThreshold) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lost trimCounter+=1

logger.debug(
"Trigger action: TRIM for {} times, force to append pause spent time.", trimCounter);
appendPauseSpentTime(servingState);
}
trimAllListeners();
}
return;
}
logger.info("Serving state transformed from {} to {}", lastState, servingState);
switch (servingState) {
case PUSH_PAUSED:
pausePushDataCounter.increment();
if (lastState == ServingState.PUSH_AND_REPLICATE_PAUSED) {
logger.info("Trigger action: RESUME REPLICATE");
memoryPressureListeners.forEach(
memoryPressureListener ->
memoryPressureListener.onResume(TransportModuleConstants.REPLICATE_MODULE));
} else if (lastState == ServingState.NONE_PAUSED) {
logger.info("Trigger action: PAUSE PUSH");
pausePushDataStartTime = System.currentTimeMillis();
memoryPressureListeners.forEach(
memoryPressureListener ->
memoryPressureListener.onPause(TransportModuleConstants.PUSH_MODULE));
}
trimAllListeners();
break;
case PUSH_AND_REPLICATE_PAUSED:
pausePushDataAndReplicateCounter.increment();
if (lastState == ServingState.NONE_PAUSED) {
if (canResumeByPinnedMemory()) {
resumeByPinnedMemory(servingState);
} else {
pausePushDataAndReplicateCounter.increment();
logger.info("Trigger action: PAUSE PUSH");
pausePushDataAndReplicateStartTime = System.currentTimeMillis();
memoryPressureListeners.forEach(
memoryPressureListener ->
memoryPressureListener.onPause(TransportModuleConstants.PUSH_MODULE));
logger.info("Trigger action: PAUSE REPLICATE");
memoryPressureListeners.forEach(
memoryPressureListener ->
memoryPressureListener.onPause(TransportModuleConstants.REPLICATE_MODULE));
}
logger.info("Trigger action: PAUSE REPLICATE");
memoryPressureListeners.forEach(
memoryPressureListener ->
memoryPressureListener.onPause(TransportModuleConstants.REPLICATE_MODULE));
logger.debug("Trigger action: TRIM");
trimCounter += 1;
trimAllListeners();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

if (trimCounter >= forceAppendPauseSpentTimeThreshold) {
logger.debug(
"Trigger action: TRIM for {} times, force to append pause spent time.", trimCounter);
appendPauseSpentTime(servingState);
}
break;
case NONE_PAUSED:
// resume from paused mode, append pause spent time
appendPauseSpentTime(lastState);
if (lastState == ServingState.PUSH_AND_REPLICATE_PAUSED) {
logger.info("Trigger action: RESUME REPLICATE");
memoryPressureListeners.forEach(
memoryPressureListener ->
memoryPressureListener.onResume(TransportModuleConstants.REPLICATE_MODULE));
resumeReplicate();
resumePush();
appendPauseSpentTime(lastState);
} else if (lastState == ServingState.PUSH_PAUSED) {
resumePush();
appendPauseSpentTime(lastState);
}
logger.info("Trigger action: RESUME PUSH");
memoryPressureListeners.forEach(
memoryPressureListener ->
memoryPressureListener.onResume(TransportModuleConstants.PUSH_MODULE));
}
}

@@ -436,6 +449,16 @@ public long getMemoryUsage() {
return getNettyUsedDirectMemory() + sortMemoryCounter.get();
}

/**
 * Total pinned memory: the pinned direct memory reported by the Netty allocators plus the
 * current value of the sort memory counter.
 *
 * @return the pinned memory in bytes
 */
public long getPinnedMemory() {
  final long nettyPinned = getNettyPinnedDirectMemory();
  final long sortMemory = sortMemoryCounter.get();
  return nettyPinned + sortMemory;
}

/**
 * Sums {@code pinnedDirectMemory()} over every pooled allocator known to {@link NettyUtils}.
 *
 * @return the accumulated pinned direct memory; {@code 0} when no allocator is registered
 */
public long getNettyPinnedDirectMemory() {
  long totalPinned = 0L;
  for (PooledByteBufAllocator pooledAllocator : NettyUtils.getAllPooledByteBufAllocators()) {
    totalPinned += pooledAllocator.pinnedDirectMemory();
  }
  return totalPinned;
}

public AtomicLong getSortMemoryCounter() {
return sortMemoryCounter;
}
@@ -557,6 +580,47 @@ public static void reset() {
_INSTANCE = null;
}

/**
 * Resumes the modules paused in {@code servingState} because pinned memory is low.
 *
 * <p>For {@code PUSH_AND_REPLICATE_PAUSED} both the replicate and push listeners are resumed;
 * for {@code PUSH_PAUSED} only the push listeners are resumed. No action is taken for any other
 * state.
 *
 * @param servingState the paused state the worker is currently reported to be in
 */
private void resumeByPinnedMemory(ServingState servingState) {
  switch (servingState) {
    case PUSH_AND_REPLICATE_PAUSED:
      logger.info(
          "Serving State is PUSH_AND_REPLICATE_PAUSED, but resume by lower pinned memory {}",
          getNettyPinnedDirectMemory());
      resumeReplicate();
      resumePush();
      // Stop here: falling through would log the PUSH_PAUSED message and resume push twice.
      break;
    case PUSH_PAUSED:
      logger.info(
          "Serving State is PUSH_PAUSED, but resume by lower pinned memory {}",
          getNettyPinnedDirectMemory());
      resumePush();
      break;
    default:
      // No resume action is defined for the remaining states.
      break;
  }
}

/**
 * Decides whether the worker may resume based on pinned memory.
 *
 * <p>Returns {@code true} only when the pinned memory check is enabled, at least
 * {@code pinnedMemoryCheckInterval} milliseconds have elapsed since the last recorded check, and
 * the pinned memory ratio is below {@code pinnedMemoryResumeRatio}.
 *
 * <p>NOTE(review): the last-check timestamp is refreshed only when this method returns
 * {@code true}; a check that finds pinned memory too high does not reset the interval — confirm
 * this is intended.
 *
 * @return whether a resume by pinned memory is allowed right now
 */
private boolean canResumeByPinnedMemory() {
  if (!pinnedMemoryCheckEnabled) {
    return false;
  }
  final long elapsedSinceLastCheck = System.currentTimeMillis() - pinnedMemoryLastCheckTime;
  if (elapsedSinceLastCheck < pinnedMemoryCheckInterval) {
    return false;
  }
  final double pinnedRatio = getPinnedMemory() / (double) (maxDirectMemory);
  if (pinnedRatio < pinnedMemoryResumeRatio) {
    pinnedMemoryLastCheckTime = System.currentTimeMillis();
    return true;
  }
  return false;
}

/** Logs the action and calls {@code onResume} for the push module on every registered listener. */
private void resumePush() {
  logger.info("Trigger action: RESUME PUSH");
  for (MemoryPressureListener listener : memoryPressureListeners) {
    listener.onResume(TransportModuleConstants.PUSH_MODULE);
  }
}

/**
 * Logs the action and calls {@code onResume} for the replicate module on every registered
 * listener.
 */
private void resumeReplicate() {
  logger.info("Trigger action: RESUME REPLICATE");
  for (MemoryPressureListener listener : memoryPressureListeners) {
    listener.onResume(TransportModuleConstants.REPLICATE_MODULE);
  }
}

public interface MemoryPressureListener {
void onPause(String moduleName);

Original file line number Diff line number Diff line change
@@ -19,6 +19,7 @@ package org.apache.celeborn.service.deploy.memory

import scala.concurrent.duration.DurationInt

import org.mockito.{Mockito, MockitoSugar}
import org.scalatest.concurrent.Eventually.eventually
import org.scalatest.concurrent.Futures.{interval, timeout}

@@ -27,8 +28,8 @@ import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.CelebornConf.{WORKER_DIRECT_MEMORY_RATIO_PAUSE_RECEIVE, WORKER_DIRECT_MEMORY_RATIO_PAUSE_REPLICATE}
import org.apache.celeborn.common.protocol.TransportModuleConstants
import org.apache.celeborn.service.deploy.worker.memory.MemoryManager
import org.apache.celeborn.service.deploy.worker.memory.MemoryManager.MemoryPressureListener
import org.apache.celeborn.service.deploy.worker.memory.MemoryManager.ServingState
import org.apache.celeborn.service.deploy.worker.memory.MemoryManager.{MemoryPressureListener, ServingState}

class MemoryManagerSuite extends CelebornFunSuite {

// reset the memory manager before each test
@@ -153,6 +154,68 @@ class MemoryManagerSuite extends CelebornFunSuite {
assert(memoryManager.getPausePushDataAndReplicateTime.longValue() > 0)
}

// Drives MemoryManager.switchServingState() by hand with stubbed memory readings and checks
// both the reported serving state and whether the push/replicate listeners are paused.
test("[CELEBORN-1792] Test MemoryManager resume by pinned memory") {
val conf = new CelebornConf()
// Long direct-memory check interval and zero pinned-memory check interval; presumably so that
// only the explicit switchServingState() calls below change state and the pinned-memory check
// is never throttled -- TODO confirm.
conf.set(CelebornConf.WORKER_DIRECT_MEMORY_CHECK_INTERVAL.key, "300s")
conf.set(CelebornConf.WORKER_PINNED_MEMORY_CHECK_INTERVAL.key, "0")
MemoryManager.reset()
// Spy on the real instance so getMemoryUsage / getNettyPinnedDirectMemory can be stubbed.
val memoryManager = MockitoSugar.spy(MemoryManager.initialize(conf))
val maxDirectorMemory = memoryManager.maxDirectMemory
// Thresholds computed from the configured pause ratios and the max direct memory.
val pushThreshold =
(conf.workerDirectMemoryRatioToPauseReceive * maxDirectorMemory).longValue()
val replicateThreshold =
(conf.workerDirectMemoryRatioToPauseReplicate * maxDirectorMemory).longValue()

val pushListener = new MockMemoryPressureListener(TransportModuleConstants.PUSH_MODULE)
val replicateListener =
new MockMemoryPressureListener(TransportModuleConstants.REPLICATE_MODULE)
memoryManager.registerMemoryListener(pushListener)
memoryManager.registerMemoryListener(replicateListener)

// NONE PAUSED -> PAUSE PUSH
// Memory usage is above the push threshold but pinned memory is zero: the state is reported
// as PUSH_PAUSED while neither listener is expected to be paused.
Mockito.when(memoryManager.getNettyPinnedDirectMemory).thenReturn(0L)
Mockito.when(memoryManager.getMemoryUsage).thenReturn(pushThreshold + 1)
memoryManager.switchServingState()
assert(!pushListener.isPause)
assert(!replicateListener.isPause)
assert(memoryManager.servingState == ServingState.PUSH_PAUSED)

// KEEP PAUSE PUSH
// Pinned memory is now high as well, so the push listener is expected to be paused.
Mockito.when(memoryManager.getNettyPinnedDirectMemory).thenReturn(pushThreshold + 1)
memoryManager.switchServingState()
assert(pushListener.isPause)
assert(!replicateListener.isPause)
assert(memoryManager.servingState == ServingState.PUSH_PAUSED)

// Memory usage drops to zero: back to NONE_PAUSED with no listener paused.
Mockito.when(memoryManager.getMemoryUsage).thenReturn(0L)
memoryManager.switchServingState()
assert(!pushListener.isPause)
assert(!replicateListener.isPause)
assert(memoryManager.servingState == ServingState.NONE_PAUSED)

// NONE PAUSED -> PAUSE PUSH AND REPLICATE
// Memory usage is above the replicate threshold but pinned memory is zero: the state is
// reported as PUSH_AND_REPLICATE_PAUSED while neither listener is expected to be paused.
Mockito.when(memoryManager.getNettyPinnedDirectMemory).thenReturn(0L)
Mockito.when(memoryManager.getMemoryUsage).thenReturn(replicateThreshold + 1)
memoryManager.switchServingState()
assert(!pushListener.isPause)
assert(!replicateListener.isPause)
assert(memoryManager.servingState == ServingState.PUSH_AND_REPLICATE_PAUSED)

// KEEP PAUSE PUSH AND REPLICATE
// Pinned memory is now high as well, so both listeners are expected to be paused.
Mockito.when(memoryManager.getNettyPinnedDirectMemory).thenReturn(replicateThreshold + 1)
memoryManager.switchServingState()
assert(pushListener.isPause)
assert(replicateListener.isPause)
assert(memoryManager.servingState == ServingState.PUSH_AND_REPLICATE_PAUSED)

// Memory usage drops to zero: back to NONE_PAUSED with both listeners resumed.
Mockito.when(memoryManager.getMemoryUsage).thenReturn(0L)
memoryManager.switchServingState()
assert(!pushListener.isPause)
assert(!replicateListener.isPause)
assert(memoryManager.servingState == ServingState.NONE_PAUSED)
MemoryManager.reset()
}

class MockMemoryPressureListener(
val belongModuleName: String,
var isPause: Boolean = false) extends MemoryPressureListener {