Skip to content

Commit 4c56ac5

Browse files
scripts: A class which gates progress on metrics satisfying a condition
1 parent 5c16fef commit 4c56ac5

File tree

1 file changed

+159
-0
lines changed

1 file changed

+159
-0
lines changed

scripts/prod/update_config_and_restart_nodes_lib.py

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,19 @@
33
from abc import ABC, abstractmethod
44
import argparse
55
import json
6+
import signal
7+
import socket
68
import subprocess
79
import sys
810
from enum import Enum
11+
from time import sleep
912
from typing import Any, Callable, Optional
1013

1114
import tempfile
1215
import urllib.error
1316
import urllib.parse
1417
import urllib.request
18+
from prometheus_client.parser import text_string_to_metric_families
1519
import yaml
1620
from difflib import unified_diff
1721

@@ -212,6 +216,161 @@ def validate_arguments(args: argparse.Namespace) -> None:
212216
sys.exit(1)
213217

214218

219+
class MetricConditionGater:
    """Gates progress on a metric satisfying a condition.

    The gater opens a ``kubectl port-forward`` from a free local port to the
    pod's metrics port, repeatedly fetches ``/monitoring/metrics`` through it,
    and returns from :meth:`gate` once the supplied condition accepts the
    metric's value.

    This class was meant to be used with counter/gauge metrics. It may not work properly with histogram metrics.
    """

    # Upper bound (seconds) for a single metrics HTTP request, so a stalled
    # port-forward cannot block the polling loop forever.
    _REQUEST_TIMEOUT_SECONDS = 10

    class MetricCondition:
        """A predicate over a metric value plus an optional description used in log messages."""

        def __init__(
            self,
            value_condition: Callable[[Any], bool],
            condition_description: Optional[str] = None,
        ):
            self.value_condition = value_condition
            self.condition_description = condition_description

    def __init__(
        self,
        metric_name: str,
        namespace: str,
        cluster: Optional[str],
        pod: str,
        metrics_port: int,
        refresh_interval_seconds: int = 3,
    ):
        """Store the target metric/pod and reserve a local port for forwarding.

        Args:
            metric_name: Name of the metric family to look up in the scraped text.
            namespace: Passed to ``get_namespace_args`` when building the kubectl command.
            cluster: Passed to ``get_namespace_args``; may be None.
            pod: Name of the pod to port-forward to.
            metrics_port: Port on the pod that serves the metrics.
            refresh_interval_seconds: Delay between polling/retry attempts.
        """
        self.metric_name = metric_name
        self.local_port = self._get_free_port()
        self.namespace = namespace
        self.cluster = cluster
        self.pod = pod
        self.metrics_port = metrics_port
        self.refresh_interval_seconds = refresh_interval_seconds

    @staticmethod
    def _get_free_port() -> int:
        """Return a local TCP port number that the OS reported as free.

        The probing socket is closed before returning, so the port is not
        reserved; another process could still take it before it is used.
        """
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            # Binding to port 0 lets the OS pick any currently free port.
            s.bind(("", 0))
            return s.getsockname()[1]

    def _get_metrics_raw_string(self) -> str:
        """Fetch the raw metrics text through the local forwarded port.

        Retries forever, sleeping ``refresh_interval_seconds`` between
        attempts, until a response with status 200 is received.

        Returns:
            The response body decoded as UTF-8.
        """
        url = f"http://localhost:{self.local_port}/monitoring/metrics"
        while True:
            try:
                with urllib.request.urlopen(
                    url, timeout=self._REQUEST_TIMEOUT_SECONDS
                ) as response:
                    if response.status == 200:
                        return response.read().decode("utf-8")
                    print_colored(f"Failed to get metrics for pod {self.pod}: {response.status}")
            except OSError as e:
                # OSError covers urllib.error.URLError/HTTPError as well as
                # connection resets, remote disconnects and socket timeouts
                # raised while reading the response, all of which are expected
                # while the pod or the port-forward is still coming up.
                print_colored(f"Failed to get metrics for pod {self.pod}: {e}")
            print_colored(
                f"Waiting {self.refresh_interval_seconds} seconds to retry getting metrics...",
                Colors.YELLOW,
            )
            sleep(self.refresh_interval_seconds)

    def _find_metric_value(self, metric_families) -> Optional[Any]:
        """Return the first sample value of the gated metric, or None if unavailable.

        Args:
            metric_families: Iterable of objects exposing ``name`` and ``samples``
                (each sample exposing ``value``), as produced by
                ``text_string_to_metric_families``.

        Returns:
            The value of the first sample of the first family whose name equals
            ``self.metric_name``; None when no such family exists or when it
            has no samples.
        """
        for metric_family in metric_families:
            if metric_family.name != self.metric_name:
                continue
            samples = metric_family.samples
            if not samples:
                # The family is declared but has no samples yet; treat it the
                # same as a missing metric instead of raising IndexError.
                return None
            if len(samples) > 1:
                print_error(
                    f"Multiple samples found for metric {self.metric_name}. Using the first one.",
                )
            return samples[0].value
        return None

    def _poll_until_condition_met(
        self, metric_value_condition: "MetricConditionGater.MetricCondition"
    ):
        """Poll metrics until the condition is met for the metric."""
        condition_description = (
            f"({metric_value_condition.condition_description}) "
            if metric_value_condition.condition_description is not None
            else ""
        )

        while True:
            metrics = self._get_metrics_raw_string()
            val = self._find_metric_value(text_string_to_metric_families(metrics))

            if val is None:
                print_colored(
                    f"Metric '{self.metric_name}' not found in pod {self.pod}. Assuming the node is not ready."
                )
            elif metric_value_condition.value_condition(val):
                print_colored(
                    f"Metric {self.metric_name} condition {condition_description}met (value={val})."
                )
                return
            else:
                print_colored(
                    f"Metric {self.metric_name} condition {condition_description}not met (value={val}). Continuing to wait."
                )

            sleep(self.refresh_interval_seconds)

    @staticmethod
    def _terminate_port_forward_process(pf_process: Optional[subprocess.Popen]):
        """Stop the port-forward process if it is still running.

        Sends terminate first and escalates to kill if the process has not
        exited within 5 seconds. Does nothing when ``pf_process`` is None or
        has already exited.
        """
        if pf_process and pf_process.poll() is None:
            print_colored(f"Terminating kubectl port-forward process (PID: {pf_process.pid})")
            pf_process.terminate()
            try:
                pf_process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                print_colored("Force killing kubectl port-forward process")
                pf_process.kill()
                pf_process.wait()

    def gate(self, metric_value_condition: "MetricConditionGater.MetricCondition"):
        """Wait until the nodes metrics satisfy the condition.

        Raises:
            AssertionError: If the port-forward process exits during startup.
        """
        # This method:
        # 1. Starts kubectl port forwarding to the node and keep it running in the background so we can access the metrics.
        # 2. Calls _poll_until_condition_met.
        # 3. Terminates the port forwarding process when done or when interrupted.
        cmd = [
            "kubectl",
            "port-forward",
            f"pod/{self.pod}",
            f"{self.local_port}:{self.metrics_port}",
        ]
        cmd.extend(get_namespace_args(self.namespace, self.cluster))

        pf_process = None
        # Handlers that were active before this call, restored on exit so a
        # later signal is not routed to a handler bound to a finished gate.
        previous_handlers = {}

        try:
            pf_process = subprocess.Popen(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
            print("Waiting for forwarding to start")
            # Give the forwarding time to start.
            # TODO(guy.f): Consider poll until the forwarding is ready if we see any issues.
            sleep(3)
            if pf_process.poll() is not None:
                # Raised explicitly (rather than via `assert`) so the check
                # survives `python -O`; the exception type is kept unchanged.
                raise AssertionError(
                    f"Port forwarding process exited with code {pf_process.returncode}"
                )

            print(
                f"Forwarding started (from local port {self.local_port} to {self.pod}:{self.metrics_port})"
            )

            # Set up signal handler to ensure forwarding subprocess is terminated on interruption
            def signal_handler(signum, frame):
                self._terminate_port_forward_process(pf_process)
                sys.exit(0)

            for sig in (signal.SIGINT, signal.SIGTERM):
                previous_handlers[sig] = signal.signal(sig, signal_handler)

            self._poll_until_condition_met(metric_value_condition)

        finally:
            self._terminate_port_forward_process(pf_process)
            for sig, previous_handler in previous_handlers.items():
                # signal.signal() returns None for handlers not installed from
                # Python; those cannot be re-installed, so leave them alone.
                if previous_handler is not None:
                    signal.signal(sig, previous_handler)
372+
373+
215374
class NamespaceAndInstructionArgs:
216375
def __init__(
217376
self,

0 commit comments

Comments
 (0)