Skip to content

Commit

Permalink
fix: dispatcher for early requests (#3147)
Browse files Browse the repository at this point in the history
Co-authored-by: Aaron Pham <[email protected]>
  • Loading branch information
sauyon and aarnphm authored Oct 26, 2022
1 parent 8dec9c9 commit 2c4b191
Showing 1 changed file with 180 additions and 15 deletions.
195 changes: 180 additions & 15 deletions src/bentoml/_internal/marshal/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,28 +42,28 @@ class Optimizer:
N_SKIPPED_SAMPLE = 2 # amount of outbound info skipped after init
INTERVAL_REFRESH_PARAMS = 5 # seconds between each params refreshing

def __init__(self):
def __init__(self, max_latency: float):
"""
assume the outbound duration follows duration = o_a * n + o_b
(all in seconds)
"""
self.o_stat: collections.deque[tuple[int, float, float]] = collections.deque(
maxlen=self.N_KEPT_SAMPLE
) # to store outbound stat data
self.o_a = 2
self.o_b = 1
self.o_a = min(2, max_latency * 2.0 / 30)
self.o_b = min(1, max_latency * 1.0 / 30)

self.wait = 0.01 # the avg wait time before outbound called

self._refresh_tb = TokenBucket(2) # to limit params refresh interval
self._outbound_counter = 0
self.outbound_counter = 0

def log_outbound(self, n: int, wait: float, duration: float):
if (
self._outbound_counter <= self.N_SKIPPED_SAMPLE
): # skip inaccurate info at beginning
self._outbound_counter += 1
return
if self.outbound_counter <= self.N_SKIPPED_SAMPLE + 4:
self.outbound_counter += 1
# skip inaccurate info at beginning
if self.outbound_counter <= self.N_SKIPPED_SAMPLE:
return

self.o_stat.append((n, duration, wait))

Expand Down Expand Up @@ -118,7 +118,7 @@ def __init__(
"""
self.max_latency_in_ms = max_latency_in_ms / 1000.0
self.fallback = fallback
self.optimizer = Optimizer()
self.optimizer = Optimizer(self.max_latency_in_ms)
self.max_batch_size = int(max_batch_size)
self.tick_interval = 0.001

Expand Down Expand Up @@ -172,6 +172,174 @@ async def controller(self):
"""
A standalone coroutine to wait/dispatch calling.
"""
logger.debug("Starting dispatcher optimizer training...")
# warm up the model
while self.optimizer.outbound_counter <= self.optimizer.N_SKIPPED_SAMPLE:
try:
async with self._wake_event: # block until there's any request in queue
await self._wake_event.wait_for(self._queue.__len__)

n = len(self._queue)
now = time.time()
w0 = now - self._queue[0][0]

# only cancel requests if there are more than enough for training
if (
n
> self.optimizer.N_SKIPPED_SAMPLE
- self.optimizer.outbound_counter
+ 6
and w0 >= self.max_latency_in_ms
):
# we're being very conservative and only canceling requests if they have already timed out
self._queue.popleft()[2].cancel()
continue
# don't try to be smart here, just serve the first few requests
if self._sema.is_locked():
await asyncio.sleep(self.tick_interval)
continue

n_call_out = 1
# call
self._sema.acquire()
inputs_info = tuple(self._queue.pop() for _ in range(n_call_out))
self._loop.create_task(self.outbound_call(inputs_info))
except asyncio.CancelledError:
return
except Exception as e: # pylint: disable=broad-except
logger.error(traceback.format_exc(), exc_info=e)

logger.debug("Dispatcher finished warming up model.")

while self.optimizer.outbound_counter <= self.optimizer.N_SKIPPED_SAMPLE + 1:
try:
# step 1: attempt to serve a single request immediately
async with self._wake_event: # block until there's any request in queue
await self._wake_event.wait_for(self._queue.__len__)

n = len(self._queue)
now = time.time()
w0 = now - self._queue[0][0]

# only cancel requests if there are more than enough for training
if n > 6 and w0 >= self.max_latency_in_ms:
# we're being very conservative and only canceling requests if they have already timed out
self._queue.popleft()[2].cancel()
continue
if self._sema.is_locked():
await asyncio.sleep(self.tick_interval)
continue

n_call_out = 1
# call
self._sema.acquire()
inputs_info = tuple(self._queue.pop() for _ in range(n_call_out))
self._loop.create_task(self.outbound_call(inputs_info))
except asyncio.CancelledError:
return
except Exception as e: # pylint: disable=broad-except
logger.error(traceback.format_exc(), exc_info=e)

logger.debug("Dispatcher finished optimizer training request 1.")
self.optimizer.trigger_refresh()

if self.max_batch_size >= 2:
# we will attempt to keep the second request served within this time
step_2_wait = min(
self.max_latency_in_ms * 0.95,
5 * (self.optimizer.o_a + self.optimizer.o_b),
)

# step 2: attempt to serve 2 requests
while (
self.optimizer.outbound_counter <= self.optimizer.N_SKIPPED_SAMPLE + 2
):
try:
async with self._wake_event: # block until there's any request in queue
await self._wake_event.wait_for(self._queue.__len__)

n = len(self._queue)
dt = self.tick_interval
now = time.time()
w0 = now - self._queue[0][0]
a = self.optimizer.o_a
b = self.optimizer.o_b

# only cancel requests if there are more than enough for training
if n > 5 and w0 >= self.max_latency_in_ms:
# we're being very conservative and only canceling requests if they have already timed out
self._queue.popleft()[2].cancel()
continue
if n < 2 and (2 * a + b) + w0 <= step_2_wait:
await asyncio.sleep(self.tick_interval)
continue
if self._sema.is_locked():
await asyncio.sleep(self.tick_interval)
continue

n_call_out = min(n, 2)
# call
self._sema.acquire()
inputs_info = tuple(self._queue.pop() for _ in range(n_call_out))
self._loop.create_task(self.outbound_call(inputs_info))
except asyncio.CancelledError:
return
except Exception as e: # pylint: disable=broad-except
logger.error(traceback.format_exc(), exc_info=e)

logger.debug("Dispatcher finished optimizer training request 2.")
self.optimizer.trigger_refresh()

if self.max_batch_size >= 3:
# step 3: attempt to serve 3 requests

# we will attempt to keep the third request served within this time
step_3_wait = min(
self.max_latency_in_ms * 0.95,
7 * (self.optimizer.o_a + self.optimizer.o_b),
)
while (
self.optimizer.outbound_counter <= self.optimizer.N_SKIPPED_SAMPLE + 3
):
try:
async with self._wake_event: # block until there's any request in queue
await self._wake_event.wait_for(self._queue.__len__)

n = len(self._queue)
dt = self.tick_interval
now = time.time()
w0 = now - self._queue[0][0]
a = self.optimizer.o_a
b = self.optimizer.o_b

# only cancel requests if there are more than enough for training
if n > 3 and w0 >= self.max_latency_in_ms:
# we're being very conservative and only canceling requests if they have already timed out
self._queue.popleft()[2].cancel()
continue
if n < 3 and (3 * a + b) + w0 <= step_3_wait:
await asyncio.sleep(self.tick_interval)
continue

n_call_out = min(n, 3)
# call
self._sema.acquire()
inputs_info = tuple(self._queue.pop() for _ in range(n_call_out))
self._loop.create_task(self.outbound_call(inputs_info))
except asyncio.CancelledError:
return
except Exception as e: # pylint: disable=broad-except
logger.error(traceback.format_exc(), exc_info=e)

logger.debug("Dispatcher finished optimizer training request 3.")
self.optimizer.trigger_refresh()

if self.optimizer.o_a + self.optimizer.o_b >= self.max_latency_in_ms:
logger.warning(
"BentoML has detected that a service has a max latency that is likely too low for serving. If many 429 errors are encountered, try raising the 'runner.max_latency' in your BentoML configuration YAML file."
)
logger.debug("Dispatcher optimizer training complete.")

while True:
try:
async with self._wake_event: # block until there's any request in queue
Expand Down Expand Up @@ -199,16 +367,13 @@ async def controller(self):
await asyncio.sleep(self.tick_interval)
continue

n_call_out = min(
self.max_batch_size,
n,
)
n_call_out = min(self.max_batch_size, n)
# call
self._sema.acquire()
inputs_info = tuple(self._queue.pop() for _ in range(n_call_out))
self._loop.create_task(self.outbound_call(inputs_info))
except asyncio.CancelledError:
break
return
except Exception as e: # pylint: disable=broad-except
logger.error(traceback.format_exc(), exc_info=e)

Expand Down

0 comments on commit 2c4b191

Please sign in to comment.