Skip to content

Commit 0b3679e

Browse files
scripts: A class which gates progress on metrics satisfying a condition
1 parent dd55b92 commit 0b3679e

File tree

1 file changed

+159
-0
lines changed

1 file changed

+159
-0
lines changed

scripts/prod/update_config_and_restart_nodes_lib.py

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,18 @@
66
import sys
77
from abc import ABC, abstractmethod
88
from enum import Enum
9+
from time import sleep
910
from typing import Any, Callable, Optional
1011

12+
import signal
13+
import socket
1114
import tempfile
1215
import urllib.error
1316
import urllib.parse
1417
import urllib.request
1518
import yaml
1619
from difflib import unified_diff
20+
from prometheus_client.parser import text_string_to_metric_families
1721

1822

1923
class Colors(Enum):
@@ -212,6 +216,161 @@ def validate_arguments(args: argparse.Namespace) -> None:
212216
sys.exit(1)
213217

214218

219+
class MetricConditionGater:
    """Gates progress on a metric satisfying a condition.

    This class was meant to be used with counter/gauge metrics. It may not work properly with
    histogram metrics.

    Usage: construct the gater and call `gate()`. `gate()` forwards a local port to the pod's
    metrics port with `kubectl port-forward`, polls the metrics endpoint until the condition
    holds, and then tears the forwarding down.
    """

    # Upper bound (seconds) on a single metrics HTTP request, so that a stalled forwarded
    # connection is retried instead of blocking the gate forever.
    _HTTP_TIMEOUT_SECONDS = 10

    class MetricCondition:
        """A predicate over a metric value, with an optional description used in log messages."""

        def __init__(
            self,
            value_condition: Callable[[Any], bool],
            condition_description: Optional[str] = None,
        ):
            self.value_condition = value_condition
            self.condition_description = condition_description

    def __init__(
        self,
        metric_name: str,
        namespace: str,
        cluster: Optional[str],
        pod: str,
        metrics_port: int,
        metric_value_condition: "MetricConditionGater.MetricCondition",
        refresh_interval_seconds: int = 3,
    ):
        """Store the gate configuration and reserve a local port number for port forwarding.

        Args:
            metric_name: Name of the metric family to look up in the scraped metrics.
            namespace: Passed to `get_namespace_args` when building the kubectl command.
            cluster: Passed to `get_namespace_args` when building the kubectl command.
            pod: Name of the pod to forward to (used as `pod/<pod>`).
            metrics_port: Port on the pod that serves the metrics.
            metric_value_condition: Condition the metric value has to satisfy.
            refresh_interval_seconds: Seconds to sleep between polls and between fetch retries.
        """
        self.metric_name = metric_name
        self.local_port = self._get_free_port()
        self.namespace = namespace
        self.cluster = cluster
        self.pod = pod
        self.metrics_port = metrics_port
        self.metric_value_condition = metric_value_condition
        self.refresh_interval_seconds = refresh_interval_seconds

    @staticmethod
    def _get_free_port() -> int:
        """Return a port number that was free at the time of the call.

        The probing socket is closed before returning, so the port is not reserved: another
        process may still grab it before it is used.
        """
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            # Binding to port 0 lets the OS pick an unused port.
            s.bind(("", 0))
            return s.getsockname()[1]

    def _get_metrics_raw_string(self) -> str:
        """Fetch the raw metrics text through the forwarded local port.

        Retries forever, sleeping `refresh_interval_seconds` between attempts, until a response
        with status 200 is received.
        """
        url = f"http://localhost:{self.local_port}/monitoring/metrics"
        while True:
            try:
                with urllib.request.urlopen(url, timeout=self._HTTP_TIMEOUT_SECONDS) as response:
                    if response.status == 200:
                        return response.read().decode("utf-8")
                    print_colored(f"Failed to get metrics for pod {self.pod}: {response.status}")
            except OSError as e:
                # URLError is an OSError subclass. Catching OSError also covers connection
                # resets and timeouts raised while reading the response, which urllib does not
                # wrap in URLError and which would otherwise abort the retry loop.
                print_colored(f"Failed to get metrics for pod {self.pod}: {e}")
            print_colored(
                f"Waiting {self.refresh_interval_seconds} seconds to retry getting metrics...",
                Colors.YELLOW,
            )
            sleep(self.refresh_interval_seconds)

    @staticmethod
    def _find_metric_samples(metric_families, metric_name: str) -> list:
        """Return the samples of the first family named `metric_name`, or [] if there is none."""
        for metric_family in metric_families:
            if metric_family.name == metric_name:
                return list(metric_family.samples)
        return []

    def _poll_until_condition_met(self):
        """Poll metrics until the condition is met for the metric."""
        description = self.metric_value_condition.condition_description
        condition_description = f"({description}) " if description is not None else ""

        while True:
            metrics = self._get_metrics_raw_string()
            samples = self._find_metric_samples(
                text_string_to_metric_families(metrics), self.metric_name
            )

            if len(samples) > 1:
                print_error(
                    f"Multiple samples found for metric {self.metric_name}. Using the first one.",
                )
            # A family without samples is treated like a missing metric rather than indexing
            # into an empty list.
            val = samples[0].value if samples else None

            if val is None:
                print_colored(
                    f"Metric '{self.metric_name}' not found in pod {self.pod}. Assuming the node is not ready."
                )
            elif self.metric_value_condition.value_condition(val):
                print_colored(
                    f"Metric {self.metric_name} condition {condition_description}met (value={val})."
                )
                return
            else:
                print_colored(
                    f"Metric {self.metric_name} condition {condition_description}not met (value={val}). Continuing to wait."
                )

            sleep(self.refresh_interval_seconds)

    @staticmethod
    def _terminate_port_forward_process(pf_process: Optional[subprocess.Popen]):
        """Stop the port-forward process if it is still running.

        Sends terminate first and escalates to kill if the process has not exited within
        5 seconds. Does nothing when `pf_process` is None or has already exited.
        """
        if pf_process and pf_process.poll() is None:
            print_colored(f"Terminating kubectl port-forward process (PID: {pf_process.pid})")
            pf_process.terminate()
            try:
                pf_process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                print_colored("Force killing kubectl port-forward process")
                pf_process.kill()
                pf_process.wait()

    def gate(self):
        """Wait until the nodes metrics satisfy the condition.

        Raises:
            RuntimeError: If the port forwarding process exits during the startup wait.
        """
        # This method:
        # 1. Starts kubectl port forwarding to the node and keep it running in the background so we can access the metrics.
        # 2. Calls _poll_until_condition_met.
        # 3. Terminates the port forwarding process when done or when interrupted.
        cmd = [
            "kubectl",
            "port-forward",
            f"pod/{self.pod}",
            f"{self.local_port}:{self.metrics_port}",
        ]
        cmd.extend(get_namespace_args(self.namespace, self.cluster))

        pf_process = None

        # Signal handler to ensure forwarding subprocess is terminated on interruption.
        # It reads `pf_process` at call time, so it is safe to install before the process exists.
        def signal_handler(signum, frame):
            self._terminate_port_forward_process(pf_process)
            sys.exit(0)

        previous_handlers = {}

        try:
            # Install the handlers before starting the subprocess so that a SIGTERM during the
            # startup wait does not leave an orphaned port-forward process behind.
            for sig in (signal.SIGINT, signal.SIGTERM):
                previous_handlers[sig] = signal.signal(sig, signal_handler)

            pf_process = subprocess.Popen(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
            print("Waiting for forwarding to start")
            # Give the forwarding time to start.
            # TODO(guy.f): Consider poll until the forwarding is ready if we see any issues.
            sleep(3)
            # Explicit raise instead of assert: asserts are stripped when running with -O.
            if pf_process.poll() is not None:
                raise RuntimeError(
                    f"Port forwarding process exited with code {pf_process.returncode}"
                )

            print(
                f"Forwarding started (from local port {self.local_port} to {self.pod}:{self.metrics_port})"
            )

            self._poll_until_condition_met()

        finally:
            self._terminate_port_forward_process(pf_process)
            # Restore the handlers that were active before gating, so a later Ctrl-C is not
            # routed to a handler that refers to this (already finished) gate.
            for sig, handler in previous_handlers.items():
                # signal.signal() returns None when the previous handler was not installed
                # from Python; such a handler cannot be re-installed from here.
                if handler is not None:
                    signal.signal(sig, handler)
373+
215374
class NamespaceAndInstructionArgs:
216375
def __init__(
217376
self,

0 commit comments

Comments
 (0)