diff --git a/omlx/engine_core.py b/omlx/engine_core.py index e1fdf3eba..c94945231 100644 --- a/omlx/engine_core.py +++ b/omlx/engine_core.py @@ -78,6 +78,22 @@ def get_mlx_executor() -> concurrent.futures.ThreadPoolExecutor: return _global_mlx_executor +_wired_limit_set = False + + +def _ensure_wired_limit() -> None: + """Set Metal wired memory limit once at first engine creation. + + BatchGenerator normally calls mx.set_wired_limit() per-instance, which + races when multiple engines init concurrently (process-global setting). + We call it once here instead. + """ + global _wired_limit_set + if not _wired_limit_set and mx.metal.is_available(): + mx.set_wired_limit(mx.device_info()["max_recommended_working_set_size"]) + _wired_limit_set = True + + @dataclass class EngineConfig: """Configuration for the engine.""" @@ -133,12 +149,23 @@ def __init__( ) self._owns_model = True - # Create scheduler + # Per-engine executor with dedicated mx.Stream (#1248). + # Each EngineCore gets its own thread + GPU stream so different + # models can run scheduler.step() concurrently. + _ensure_wired_limit() + self._mlx_stream = mx.new_thread_local_stream(mx.default_device()) + self._mlx_executor = concurrent.futures.ThreadPoolExecutor( + max_workers=1, + thread_name_prefix=f"mlx-engine-{self._engine_id[:8]}", + ) + + # Create scheduler with per-engine stream scheduler_config = self.config.scheduler_config or SchedulerConfig() self.scheduler = Scheduler( model=model, tokenizer=tokenizer, config=scheduler_config, + stream=self._mlx_stream, ) # Output collectors for low-latency streaming (vLLM pattern) @@ -152,11 +179,6 @@ def __init__( self._start_time: Optional[float] = None self._steps_executed = 0 - # Global single-thread executor shared across ALL engines. - # mlx-lm uses a module-level Metal stream, so concurrent MLX calls - # from different engine threads cause segfaults. See issue #85. - self._mlx_executor = get_mlx_executor() - logger.debug(f"Engine {self._engine_id} initialized") async def start(self) -> None: @@ -698,9 +720,9 @@ def close(self) -> None: self._closed = True - # Both shutdown() and deep_reset() touch generation_stream (directly + # Both shutdown() and deep_reset() touch the engine stream (directly # or via _drain_pending_async_removes / _do_abort_request). The - # stream is bound to the MLX executor thread, so dispatch both + # stream is bound to the engine's executor thread, so dispatch both # through the executor; fall back to a direct call if the executor # is already shut down. for fn in (self.scheduler.shutdown, self.scheduler.deep_reset): @@ -712,6 +734,10 @@ def close(self) -> None: except RuntimeError: pass + if self._mlx_executor is not None: + self._mlx_executor.shutdown(wait=True) + self._mlx_executor = None + # Clear output collectors for collector in self._output_collectors.values(): collector.clear() diff --git a/omlx/patches/mlx_lm_mtp/batch_generator.py b/omlx/patches/mlx_lm_mtp/batch_generator.py index d4a1da54e..301a87a85 100644 --- a/omlx/patches/mlx_lm_mtp/batch_generator.py +++ b/omlx/patches/mlx_lm_mtp/batch_generator.py @@ -308,24 +308,6 @@ class _MtpState: # Helpers # --------------------------------------------------------------------------- -def _get_generation_stream(): - """Return the ``mlx_lm.generate`` module-level generation stream. - - The standard ``GenerationBatch._step`` runs all forward passes inside - ``mx.stream(generation_stream)``; the MTP cycle does the same so the - paged cache writes land on the same stream and ordering is preserved. - The stream lives on the *outer* ``BatchGenerator``, not on - ``GenerationBatch``, so we read it from the module. - - Note: ``mlx_lm.__init__`` re-exports a ``generate`` *function*, so - ``import mlx_lm.generate as mlg`` resolves to the function, not the - module. We use ``sys.modules`` to grab the actual module. - """ - import sys - - return sys.modules["mlx_lm.generate"].generation_stream - - def _resolve_sampler(gen_batch: Any): """Match ``GenerationBatch._step``'s per-sequence sampler resolution (batch=1).""" if gen_batch.samplers and gen_batch.samplers[0] is not None: @@ -507,9 +489,9 @@ def _reconcile_mtp_to_standard(gen_batch: Any, state: _MtpState) -> bool: procs = _proc_list(gen_batch) _set_singleton_mrope_delta(gen_batch) tok_arr = _ensure_uint32(mx.array(list(tokens))) - with mx.stream(_get_generation_stream()): - logits, _, _ = _call_backbone(gen_batch.model, tok_arr[None, :], new_cache) - last_logits = logits[:, -1, :] # (1, vocab) — dist after tokens[-1] + # Inherits the per-engine stream from the enclosing BatchGenerator context. + logits, _, _ = _call_backbone(gen_batch.model, tok_arr[None, :], new_cache) + last_logits = logits[:, -1, :] # (1, vocab) — dist after tokens[-1] if state.queue: next_id, next_lp_1d, _src = state.queue[0] @@ -760,10 +742,10 @@ def _post_init_mtp(gen_batch: Any) -> None: # 1-token backbone forward at main_tok with hidden state. No draft yet, # so no rollback is possible — discard gdn_states. - with mx.stream(_get_generation_stream()): - logits, hidden, _ = _call_backbone( - gen_batch.model, main_tok[:, None], gen_batch.prompt_cache - ) + # Inherits the per-engine stream from the enclosing BatchGenerator context. + logits, hidden, _ = _call_backbone( + gen_batch.model, main_tok[:, None], gen_batch.prompt_cache + ) next_main_logits = logits[:, -1, :] # (1, vocab) — distribution after main_tok next_main_logits = _apply_processors(procs, prev_buf, next_main_logits) @@ -775,8 +757,7 @@ def _post_init_mtp(gen_batch: Any) -> None: mtp_cache = gen_batch.model.make_mtp_cache() hidden_at_main = hidden[:, -1:, :] # (1, 1, H) next_ids = next_main_tok.reshape(1, 1) - with mx.stream(_get_generation_stream()): - mtp_logits = gen_batch.model.mtp_forward(hidden_at_main, next_ids, mtp_cache) + mtp_logits = gen_batch.model.mtp_forward(hidden_at_main, next_ids, mtp_cache) mtp_logits_2d = mtp_logits[:, -1, :] if procs is not None: prev_with_main_and_next = mx.concatenate( @@ -923,15 +904,14 @@ def _run_verify_cycle(gen_batch: Any, state: _MtpState) -> None: # per cycle (negligible vs the forward compute) and keeps the # backbone_ms / sample_ms split accurate. t0 = time.perf_counter() - with mx.stream(_get_generation_stream()): - logits, hidden, gdn_states = _call_backbone( - gen_batch.model, - inputs[None, :], - gen_batch.prompt_cache, - n_confirmed=1, - ) - verify_logits = logits[:, 0, :] - bonus_logits = logits[:, 1, :] + logits, hidden, gdn_states = _call_backbone( + gen_batch.model, + inputs[None, :], + gen_batch.prompt_cache, + n_confirmed=1, + ) + verify_logits = logits[:, 0, :] + bonus_logits = logits[:, 1, :] mx.eval(logits) state.stats.backbone_ms += (time.perf_counter() - t0) * 1000 @@ -1086,11 +1066,10 @@ def _step_mtp( t0 = time.perf_counter() next_ids = next_main_tok.reshape(1, 1) - with mx.stream(_get_generation_stream()): - mtp_logits = gen_batch.model.mtp_forward( - hidden_at_position, next_ids, state.mtp_cache - ) - mtp_logits_2d = mtp_logits[:, -1, :] + mtp_logits = gen_batch.model.mtp_forward( + hidden_at_position, next_ids, state.mtp_cache + ) + mtp_logits_2d = mtp_logits[:, -1, :] if procs is not None and prev_buf is not None: prev_with_next = mx.concatenate( [prev_buf, _ensure_uint32(next_main_tok)] diff --git a/omlx/scheduler.py b/omlx/scheduler.py index 251072350..e260790e4 100644 --- a/omlx/scheduler.py +++ b/omlx/scheduler.py @@ -48,6 +48,10 @@ from .utils.proc_memory import get_phys_footprint from .utils.sampling import make_sampler as omlx_make_sampler +# Module-level alias so Scheduler.__init__ can fall back to mlx-lm's default +# stream when no per-engine stream is provided. +_default_generation_stream = generation_stream + @dataclass class _VLMMTPDecodeState: @@ -98,7 +102,7 @@ class _VLMMTPResponse: _mx_buffer_access_lock = threading.RLock() -def _sync_and_clear_cache(): +def _sync_and_clear_cache(stream=None): """Synchronize in-flight GPU work before clearing the Metal buffer cache. Without synchronization, mx.clear_cache() can release Metal buffers that @@ -114,28 +118,30 @@ def _sync_and_clear_cache(): See: https://github.com/jundot/omlx/issues/300, #888, #1106 """ with _mx_buffer_access_lock: - # Generation_stream may not have in-flight work on the current thread + # The engine stream may not have in-flight work on the current thread # (e.g. external prefill submits to the default stream). On some MLX # builds mx.synchronize raises "There is no Stream(gpu, 0) in current # thread" in that case; swallow it since there is nothing to drain. + target = stream if stream is not None else _default_generation_stream try: - mx.synchronize(generation_stream) + mx.synchronize(target) except RuntimeError: pass mx.synchronize() # default stream mx.clear_cache() -def _safe_sync_generation_stream(): - """mx.synchronize(generation_stream) that tolerates cross-thread calls. +def _safe_sync_stream(stream=None): + """mx.synchronize(stream) that tolerates cross-thread calls. - Generation_stream is owned by the _mlx_executor thread. Teardown paths - that run on the main thread (via EngineCore.close) hit "no Stream in + The per-engine stream is owned by the engine's executor thread. Teardown + paths that run on the main thread (via EngineCore.close) hit "no Stream in current thread" RuntimeError. Swallow that specific case so cleanup can proceed; re-raise anything else so real GPU errors stay visible. """ + target = stream if stream is not None else _default_generation_stream try: - mx.synchronize(generation_stream) + mx.synchronize(target) except RuntimeError as e: if "no Stream" not in str(e): raise @@ -717,6 +723,7 @@ def __init__( model: Any, tokenizer: Any, config: SchedulerConfig | None = None, + stream: Any | None = None, ): """ Initialize the scheduler. @@ -725,6 +732,8 @@ def __init__( model: The MLX model tokenizer: The tokenizer config: Scheduler configuration + stream: Optional mx.Stream for this engine. Falls back to the + module-level _default_generation_stream when not provided. """ self.model = model # Deep-copy the tokenizer so the scheduler owns an independent Rust @@ -735,6 +744,7 @@ def __init__( # Rust RefCell. See: https://github.com/huggingface/tokenizers/issues/537 self.tokenizer = copy.deepcopy(tokenizer) self.config = copy.copy(config) if config else SchedulerConfig() + self._stream = stream if stream is not None else _default_generation_stream # Load additional EOS tokens from generation_config.json. # Some models (e.g. GLM-4.6V) define multiple EOS tokens there @@ -1117,18 +1127,18 @@ def _async_store_cache_worker( without blocking the inference thread. async_eval completes Metal command enqueueing before returning, so all commands are submitted by the time executor.submit() runs. - - This worker calls mx.synchronize(generation_stream) via the - _safe_sync_generation_stream helper to wait on the same - stream where mx.async_eval dispatched the arrays. A bare - mx.synchronize() with no args only blocks on the default - stream (gpu:0) and would leave the dispatched gpu:2 work + - This worker calls mx.synchronize(self._stream) via the + _safe_sync_stream helper to wait on the same stream where + mx.async_eval dispatched the arrays. A bare mx.synchronize() + with no args only blocks on the default stream (gpu:0) and + would leave the dispatched per-engine stream's work unsynchronized, racing the buffer-protocol access below - (#1437). Stream objects are not thread-local in MLX - (Metal device is a global singleton), so - mx.synchronize(stream) is safe cross-thread; it just calls - waitUntilCompleted on the command buffer. + (#1437). Stream objects are not thread-local in MLX (Metal + device is a global singleton), so mx.synchronize(stream) is + safe cross-thread; it just calls waitUntilCompleted on the + command buffer. - bfloat16 view+eval inside _extract_tensor_bytes runs on this - worker's default mx stream, isolated from generation_stream; + worker's default mx stream, isolated from self._stream; the underlying buffer is read-only at this point. - batch_generator.remove(uid) is deferred until this worker completes (handled by _drain_pending_async_removes). @@ -1145,7 +1155,7 @@ def _async_store_cache_worker( # buffer pool mid-read (#1106). with _mx_buffer_access_lock: with self._phase_timer("store_cache_worker_sync"): - _safe_sync_generation_stream() + _safe_sync_stream(self._stream) block_table = self.block_aware_cache.store_cache( request_id, token_sequence_to_store, @@ -1192,7 +1202,7 @@ def _drain_pending_async_removes(self) -> None: ) # Run batch_generator.remove on the inference thread. try: - _safe_sync_generation_stream() + _safe_sync_stream(self._stream) self._remove_uid_from_active_batch(uid) if hasattr(self.model, "unregister_rope_delta"): self.model.unregister_rope_delta(uid) @@ -1607,6 +1617,7 @@ def _create_batch_generator( prefill_batch_size=1, completion_batch_size=self.config.completion_batch_size, prefill_step_size=self.config.prefill_step_size, + stream=self._stream, ) return bg @@ -1968,7 +1979,7 @@ def _do_external_prefill( raise _PrefillAbortedError(abort_uids, processed_tokens) # Reclaim Metal intermediates between prefill chunks. - _sync_and_clear_cache() + _sync_and_clear_cache(self._stream) # Emit final boundary snapshot if prompt lands exactly on boundary. if boundary_enabled: @@ -1982,7 +1993,7 @@ def _do_external_prefill( request, prompt_cache, total_tokens ) - _sync_and_clear_cache() + _sync_and_clear_cache(self._stream) # Restore _rope_deltas after cached VLM prefill (for decode capture) if vlm_embeds is not None and _saved_rope_deltas is not None: @@ -2296,7 +2307,7 @@ def _step_prefill_chunk(self, state: _PrefillState) -> bool: f"{self._memory_hard_limit_bytes / 1024**3:.1f}GB)" ) - _sync_and_clear_cache() + _sync_and_clear_cache(self._stream) return state.tokens_remaining.shape[1] == 0 def _emit_final_boundary_if_needed(self, state: _PrefillState) -> None: @@ -2434,7 +2445,7 @@ def _advance_chunked_prefills( # Prefill complete — emit final boundary snapshot and insert. self._prefill_states.pop(rid, None) self._emit_final_boundary_if_needed(state) - _sync_and_clear_cache() + _sync_and_clear_cache(self._stream) # Ensure a BatchGenerator exists (may not if all requests were # previously in chunked prefill with no running decode). @@ -2942,12 +2953,12 @@ def _extract_boundary_snapshot(self, uid: int) -> list[Any] | None: return None try: - # Synchronize pending generation_stream operations before + # Synchronize pending engine stream operations before # accessing batch cache tensors. with self._phase_timer("boundary_capture_sync"): - _safe_sync_generation_stream() + _safe_sync_stream(self._stream) with self._phase_timer("boundary_capture_extract"): - with mx.stream(generation_stream): + with mx.stream(self._stream): result = self.batch_generator.extract_cache([uid]) if uid not in result: return None @@ -3818,7 +3829,7 @@ def _route_to_vlm_mtp( last_arr = mx.array(last_tokens)[None] # (1, len_last) try: - with mx.stream(generation_stream): + with mx.stream(self._stream): out = lm( last_arr, cache=prefilled_cache, @@ -3958,7 +3969,7 @@ def _step_vlm_mtp(self) -> list[_VLMMTPResponse]: responses: list[_VLMMTPResponse] = [] for uid, state in list(self._vlm_mtp_active.items()): try: - with mx.stream(generation_stream): + with mx.stream(self._stream): token_val = next(state.generator) except StopIteration: # Round loop exited naturally — terminate with prompt cache @@ -4195,11 +4206,11 @@ def _score_progress(processed: int, total: int, phase: str) -> None: logger.debug(f"SpecPrefill: draft cache store failed: {e}") # Free draft cache from memory. Use _sync_and_clear_cache() so - # the generation_stream is drained before Metal buffers are + # the engine stream is drained before Metal buffers are # returned to the pool — a bare mx.clear_cache() here can race # with in-flight async evals and trigger a kernel panic (#557). del used_cache - _sync_and_clear_cache() + _sync_and_clear_cache(self._stream) # Mark scoring complete (auto-removes tracker entry). tracker.update(request.request_id, n_to_score, n_to_score, model_id) @@ -4346,7 +4357,7 @@ def _do_abort_request(self, request_id: str) -> bool: # that replaces references to arrays still used by in-flight # Metal command buffers. Without this barrier the Metal driver # can hit 'completeMemory() prepare count underflow'. - _safe_sync_generation_stream() + _safe_sync_stream(self._stream) self._remove_uid_from_active_batch(uid) if hasattr(self.model, "unregister_rope_delta"): self.model.unregister_rope_delta(uid) @@ -4502,7 +4513,7 @@ def fail_all_requests(self) -> list[str]: # state — mx.synchronize() or mx.clear_cache() can throw a C++ # exception that causes SIGABRT if uncaught (#435). try: - _sync_and_clear_cache() + _sync_and_clear_cache(self._stream) except Exception as e: logger.warning(f"Metal cache clear failed during error recovery: {e}") return failed_ids @@ -4841,13 +4852,13 @@ def _check_specprefill_abort(processed: int) -> None: ) sys_arr = sys_arr[step:] # Use _sync_and_clear_cache() instead of bare - # mx.clear_cache() to flush the generation_stream + # mx.clear_cache() to flush the engine stream # before releasing Metal buffers. A bare call here # can race with in-flight command buffers submitted # by the preceding mx.eval(), triggering the same # 'completeMemory() prepare count underflow' kernel # panic that #435 fixed elsewhere (#557). - _sync_and_clear_cache() + _sync_and_clear_cache(self._stream) if sys_arr.size > 0: _check_specprefill_abort(sys_processed) final_sys = int(sys_arr.size) @@ -5025,7 +5036,7 @@ def _sparse_progress(processed: int, total: int) -> None: if done: self._emit_final_boundary_if_needed(state) - _sync_and_clear_cache() + _sync_and_clear_cache(self._stream) get_prefill_tracker().remove(request.request_id) self._insert_prefilled_request(request, state, scheduled) else: @@ -5424,12 +5435,12 @@ def _process_batch_responses( def _cleanup_finished(self, finished_ids: set[str]) -> None: """Clean up finished requests and store caches for reuse.""" - # Synchronize pending generation_stream operations before cache storage. + # Synchronize pending engine stream operations before cache storage. # store_cache -> mx.save_safetensors triggers implicit mx.eval() which # can conflict with async Metal operations on the generation stream. if finished_ids: with self._phase_timer("cleanup_finished_sync"): - _safe_sync_generation_stream() + _safe_sync_stream(self._stream) # SpecPrefill: restore original RoPE if active request finished for rid in finished_ids: @@ -5482,7 +5493,7 @@ def _cleanup_finished(self, finished_ids: set[str]) -> None: # without blocking; the worker calls # mx.synchronize() to wait before extracting # bytes. - with mx.stream(generation_stream): + with mx.stream(self._stream): with self._phase_timer("store_cache_main_boundary"): boundary_override = self._get_boundary_store_override( request_id, @@ -5623,7 +5634,7 @@ def _cleanup_finished(self, finished_ids: set[str]) -> None: # used by in-flight Metal command buffers from the previous # batch_generator.next() call. Without this barrier the Metal # driver can hit 'completeMemory() prepare count underflow'. - _safe_sync_generation_stream() + _safe_sync_stream(self._stream) self._remove_uid_from_active_batch(uid) if hasattr(self.model, "unregister_rope_delta"): self.model.unregister_rope_delta(uid) @@ -5906,7 +5917,7 @@ def step(self) -> SchedulerOutput: + len(responses) ) if self._tokens_since_clear_cache >= 1024: - _sync_and_clear_cache() + _sync_and_clear_cache(self._stream) self._tokens_since_clear_cache = 0 except _PrefillAbortedError: @@ -5977,7 +5988,7 @@ def step(self) -> SchedulerOutput: should_clear = True self._deferred_clear_at = None if should_clear: - _sync_and_clear_cache() + _sync_and_clear_cache(self._stream) if ( self.config.gc_cleanup_interval > 0 and self._step_counter % self.config.gc_cleanup_interval == 0 diff --git a/tests/test_engine_core.py b/tests/test_engine_core.py index 48abc6df1..c13aaed61 100644 --- a/tests/test_engine_core.py +++ b/tests/test_engine_core.py @@ -978,10 +978,8 @@ def test_get_mlx_executor_returns_singleton(self): executor2 = get_mlx_executor() assert executor1 is executor2 - def test_engines_share_mlx_executor(self, mock_model, mock_tokenizer): - """Multiple EngineCore instances must share a single MLX executor (#85).""" - from omlx.engine_core import get_mlx_executor - + def test_engines_have_per_engine_executors(self, mock_model, mock_tokenizer): + """Each EngineCore must have its own executor (#1248).""" with patch("omlx.engine_core.get_registry") as mock_registry: mock_registry.return_value.acquire.return_value = True @@ -989,8 +987,7 @@ def test_engines_share_mlx_executor(self, mock_model, mock_tokenizer): engine2 = EngineCore(model=mock_model, tokenizer=mock_tokenizer) try: - assert engine1._mlx_executor is engine2._mlx_executor - assert engine1._mlx_executor is get_mlx_executor() + assert engine1._mlx_executor is not engine2._mlx_executor finally: engine1.close() engine2.close() @@ -1046,13 +1043,14 @@ def simulated_step(task_id: str, duration: float = 0.05): ) @pytest.mark.asyncio - async def test_two_engine_loops_serialize_on_shared_executor( + async def test_two_engine_loops_run_concurrently_on_separate_executors( self, mock_model, mock_tokenizer ): - """Two engines running their loops must serialize step() calls (#85). + """Two engines with per-engine executors can run step() concurrently (#1248). - Creates two EngineCore instances with mock schedulers, starts both - engine loops, and verifies their scheduler.step() calls never overlap. + Each EngineCore has its own ThreadPoolExecutor and mx.Stream, so their + scheduler.step() calls can overlap. This test verifies that two engines + actually achieve concurrent execution. """ import threading import time @@ -1107,8 +1105,9 @@ def tracked_step(): assert total_steps >= 4, ( f"Expected at least 4 steps from two engines, got {total_steps}" ) - assert max_concurrent == 1, ( - f"Expected max 1 concurrent step(), got {max_concurrent}. " - f"Two engines ran MLX operations in parallel — would cause " - f"Metal command buffer races in production." + # With per-engine executors (#1248), two engines CAN run concurrently. + # max_concurrent >= 2 means both engines overlapped at least once. + assert max_concurrent >= 2, ( + f"Expected concurrent execution (max_concurrent >= 2), got {max_concurrent}. " + f"Per-engine executors should allow parallel step() calls." ) diff --git a/tests/test_per_engine_threads.py b/tests/test_per_engine_threads.py new file mode 100644 index 000000000..284e426bd --- /dev/null +++ b/tests/test_per_engine_threads.py @@ -0,0 +1,217 @@ +"""Tests for per-engine thread isolation (issue #1248).""" + +from unittest.mock import MagicMock, patch + +import mlx.core as mx +import pytest + +from omlx.engine_core import EngineCore +from omlx.scheduler import Scheduler, SchedulerConfig + + +class TestSchedulerStreamParam: + """Scheduler must accept an explicit stream and use it instead of the + module-level generation_stream.""" + + def test_scheduler_stores_explicit_stream(self): + mock_model = MagicMock() + mock_model.model_type = "test" + mock_tokenizer = MagicMock() + mock_tokenizer.eos_token_id = 0 + + stream = mx.new_thread_local_stream(mx.default_device()) + scheduler = Scheduler( + model=mock_model, + tokenizer=mock_tokenizer, + stream=stream, + ) + assert scheduler._stream is stream + + def test_scheduler_defaults_to_generation_stream(self): + from omlx.scheduler import _default_generation_stream + + mock_model = MagicMock() + mock_model.model_type = "test" + mock_tokenizer = MagicMock() + mock_tokenizer.eos_token_id = 0 + + scheduler = Scheduler( + model=mock_model, + tokenizer=mock_tokenizer, + ) + assert scheduler._stream is _default_generation_stream + + +class TestSchedulerStreamIsolation: + """Scheduler must use self._stream in all GPU stream operations, + never the module-level generation_stream.""" + + def test_no_module_level_generation_stream_in_hot_path(self): + """After migration, scheduler.py should not reference the module-level + generation_stream anywhere in the Scheduler class body except the + __init__ default fallback and comments/docstrings.""" + import inspect + import re + + import omlx.scheduler as sched_mod + source = inspect.getsource(sched_mod.Scheduler) + + # Find bare generation_stream references that aren't: + # - _default_generation_stream (the import alias) + # - Part of a larger word + bare_refs = re.findall( + r'(?