Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 24 additions & 7 deletions jupyter_scheduler/executors.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,18 +74,16 @@ def execute(self):
"""
pass

@classmethod
@abstractmethod
def supported_features(cls) -> Dict[JobFeature, bool]:
def supported_features(self) -> Dict[JobFeature, bool]:
"""Returns a configuration of supported features
by the execution engine. Implementors are expected
to override this to return a dictionary of supported
job creation features.
"""
pass

@classmethod
def validate(cls, input_path: str) -> bool:
def validate(self, input_path: str) -> bool:
"""Returns True if notebook has valid metadata to execute, False otherwise"""
return True

Expand Down Expand Up @@ -132,10 +130,16 @@ def execute(self):
nb = add_parameters(nb, job.parameters)

staging_dir = os.path.dirname(self.staging_paths["input"])

ep = ExecutePreprocessor(
kernel_name=nb.metadata.kernelspec["name"], store_widget_state=True, cwd=staging_dir
kernel_name=nb.metadata.kernelspec["name"],
store_widget_state=True,
cwd=staging_dir,
)

if self.supported_features().get(JobFeature.track_cell_execution, False):
ep.on_cell_executed = self.__update_completed_cells_hook(ep)

try:
ep.preprocess(nb, {"metadata": {"path": staging_dir}})
except CellExecutionError as e:
Expand All @@ -144,6 +148,18 @@ def execute(self):
self.add_side_effects_files(staging_dir)
self.create_output_files(job, nb)

def __update_completed_cells_hook(self, ep: ExecutePreprocessor):
"""Returns a hook that runs on every cell execution, regardless of success or failure. Updates the completed_cells for the job."""

def update_completed_cells(cell, cell_index, execute_reply):
with self.db_session() as session:
session.query(Job).filter(Job.job_id == self.job_id).update(
{"completed_cells": ep.code_cells_executed}
)
session.commit()

return update_completed_cells

def add_side_effects_files(self, staging_dir: str):
"""Scan for side effect files potentially created after input file execution and update the job's packaged_files with these files"""
input_notebook = os.path.relpath(self.staging_paths["input"])
Expand Down Expand Up @@ -173,7 +189,7 @@ def create_output_files(self, job: DescribeJob, notebook_node):
with fsspec.open(self.staging_paths[output_format], "w", encoding="utf-8") as f:
f.write(output)

def supported_features(cls) -> Dict[JobFeature, bool]:
def supported_features(self) -> Dict[JobFeature, bool]:
return {
JobFeature.job_name: True,
JobFeature.output_formats: True,
Expand All @@ -188,9 +204,10 @@ def supported_features(cls) -> Dict[JobFeature, bool]:
JobFeature.output_filename_template: False,
JobFeature.stop_job: True,
JobFeature.delete_job: True,
JobFeature.track_cell_execution: False,
}

def validate(cls, input_path: str) -> bool:
def validate(self, input_path: str) -> bool:
with open(input_path, encoding="utf-8") as f:
nb = nbformat.read(f, as_version=4)
try:
Expand Down
3 changes: 3 additions & 0 deletions jupyter_scheduler/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ class DescribeJob(BaseModel):
downloaded: bool = False
package_input_folder: Optional[bool] = None
packaged_files: Optional[List[str]] = []
completed_cells: Optional[int] = None

class Config:
orm_mode = True
Expand Down Expand Up @@ -193,6 +194,7 @@ class UpdateJob(BaseModel):
status: Optional[Status] = None
name: Optional[str] = None
compute_type: Optional[str] = None
completed_cells: Optional[int] = None


class DeleteJob(BaseModel):
Expand Down Expand Up @@ -295,3 +297,4 @@ class JobFeature(str, Enum):
output_filename_template = "output_filename_template"
stop_job = "stop_job"
delete_job = "delete_job"
track_cell_execution = "track_cell_execution"
1 change: 1 addition & 0 deletions jupyter_scheduler/orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ class Job(CommonColumns, Base):
url = Column(String(256), default=generate_jobs_url)
pid = Column(Integer)
idempotency_token = Column(String(256))
completed_cells = Column(Integer, nullable=True)
# All new columns added to this table must be nullable to ensure compatibility during database migrations.
# Any default values specified for new columns will be ignored during the migration process.

Expand Down
2 changes: 1 addition & 1 deletion jupyter_scheduler/tests/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def execute(self):
def process(self):
pass

def supported_features(cls) -> Dict[JobFeature, bool]:
def supported_features(self) -> Dict[JobFeature, bool]:
return {
JobFeature.job_name: True,
JobFeature.output_formats: True,
Expand Down
253 changes: 253 additions & 0 deletions jupyter_scheduler/tests/test_execution_manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import shutil
from pathlib import Path
from typing import Tuple
from unittest.mock import MagicMock, patch

import pytest

Expand Down Expand Up @@ -58,3 +59,255 @@ def test_add_side_effects_files(

job = jp_scheduler_db.query(Job).filter(Job.job_id == job_id).one()
assert side_effect_file_name in job.packaged_files


def test_default_execution_manager_cell_tracking_hook_not_set_by_default():
"""Test that DefaultExecutionManager does NOT set up on_cell_executed hook when track_cell_execution is disabled by default"""
job_id = "test-job-id"

with patch.object(DefaultExecutionManager, "model") as mock_model:
with patch("jupyter_scheduler.executors.open", mock=MagicMock()):
with patch("jupyter_scheduler.executors.nbformat.read") as mock_nb_read:
with patch.object(DefaultExecutionManager, "add_side_effects_files"):
with patch.object(DefaultExecutionManager, "create_output_files"):
# Mock notebook
mock_nb = MagicMock()
mock_nb.metadata.kernelspec = {"name": "python3"}
mock_nb_read.return_value = mock_nb

# Mock model
mock_model.parameters = None
mock_model.output_formats = []

# Create manager
manager = DefaultExecutionManager(
job_id=job_id,
root_dir="/test",
db_url="sqlite:///:memory:",
staging_paths={"input": "/test/input.ipynb"},
)

# Patch ExecutePreprocessor
with patch(
"jupyter_scheduler.executors.ExecutePreprocessor"
) as mock_ep_class:
mock_ep = MagicMock()
mock_ep_class.return_value = mock_ep

# Execute
manager.execute()

# Verify ExecutePreprocessor was created
mock_ep_class.assert_called_once()

# Verify patching method was never called
mock_model.__update_completed_cells_hook.assert_not_called()


def test_update_completed_cells_hook():
"""Test the __update_completed_cells_hook method"""
job_id = "test-job-id"

# Create manager
manager = DefaultExecutionManager(
job_id=job_id,
root_dir="/test",
db_url="sqlite:///:memory:",
staging_paths={"input": "/test/input.ipynb"},
)

# Mock db_session
mock_db_session = MagicMock()
mock_session_context = MagicMock()
mock_db_session.return_value.__enter__.return_value = mock_session_context
manager._db_session = mock_db_session

# Mock ExecutePreprocessor
mock_ep = MagicMock()
mock_ep.code_cells_executed = 5

# Get the hook function
hook_func = manager._DefaultExecutionManager__update_completed_cells_hook(mock_ep)

# Call the hook
mock_cell = MagicMock()
mock_execute_reply = MagicMock()
hook_func(mock_cell, 2, mock_execute_reply)

# Verify database update was called
mock_session_context.query.assert_called_once_with(Job)
mock_session_context.query.return_value.filter.return_value.update.assert_called_once_with(
{"completed_cells": 5}
)
mock_session_context.commit.assert_called_once()


def test_update_completed_cells_hook_database_error():
"""Test that database errors in the hook are handled"""
job_id = "test-job-id"

# Create manager
manager = DefaultExecutionManager(
job_id=job_id,
root_dir="/test",
db_url="sqlite:///:memory:",
staging_paths={"input": "/test/input.ipynb"},
)

# Mock db_session with error
mock_db_session = MagicMock()
mock_session_context = MagicMock()
mock_session_context.query.return_value.filter.return_value.update.side_effect = Exception(
"DB Error"
)
mock_db_session.return_value.__enter__.return_value = mock_session_context
manager._db_session = mock_db_session

# Mock ExecutePreprocessor
mock_ep = MagicMock()
mock_ep.code_cells_executed = 3

# Get the hook function
hook_func = manager._DefaultExecutionManager__update_completed_cells_hook(mock_ep)

# Call the hook - should raise exception
mock_cell = MagicMock()
mock_execute_reply = MagicMock()

with pytest.raises(Exception, match="DB Error"):
hook_func(mock_cell, 1, mock_execute_reply)


def test_supported_features_includes_track_cell_execution():
"""Test that DefaultExecutionManager supports track_cell_execution feature"""
manager = DefaultExecutionManager(
job_id="test-job-id",
root_dir="/test",
db_url="sqlite:///:memory:",
staging_paths={"input": "/test/input.ipynb"},
)
features = manager.supported_features()

from jupyter_scheduler.models import JobFeature

assert JobFeature.track_cell_execution in features
assert features[JobFeature.track_cell_execution] is False


def test_hook_uses_correct_job_id():
"""Test that the hook uses the correct job_id in database queries"""
job_id = "specific-job-id-456"

# Create manager
manager = DefaultExecutionManager(
job_id=job_id,
root_dir="/test",
db_url="sqlite:///:memory:",
staging_paths={"input": "/test/input.ipynb"},
)

# Mock db_session
mock_db_session = MagicMock()
mock_session_context = MagicMock()
mock_db_session.return_value.__enter__.return_value = mock_session_context
manager._db_session = mock_db_session

# Mock ExecutePreprocessor
mock_ep = MagicMock()
mock_ep.code_cells_executed = 7

# Get the hook function
hook_func = manager._DefaultExecutionManager__update_completed_cells_hook(mock_ep)

# Call the hook
mock_cell = MagicMock()
mock_execute_reply = MagicMock()
hook_func(mock_cell, 3, mock_execute_reply)

# Verify the correct job_id is used in the filter
# The filter call should contain a condition that matches Job.job_id == job_id
filter_call = mock_session_context.query.return_value.filter.call_args[0][0]
# This is a SQLAlchemy comparison object, so we need to check its properties
assert hasattr(filter_call, "right")
assert filter_call.right.value == job_id


def test_cell_tracking_disabled_when_feature_false():
"""Test that cell tracking hook is not set when track_cell_execution feature is False"""
job_id = "test-job-id"

# Create a custom execution manager class with track_cell_execution = False
class DisabledTrackingExecutionManager(DefaultExecutionManager):
def supported_features(self):
features = super().supported_features()
from jupyter_scheduler.models import JobFeature

features[JobFeature.track_cell_execution] = False
return features

# Create manager with disabled tracking
manager = DisabledTrackingExecutionManager(
job_id=job_id,
root_dir="/test",
db_url="sqlite:///:memory:",
staging_paths={"input": "/test/input.ipynb"},
)

# Mock ExecutePreprocessor and track calls to __update_completed_cells_hook
with patch.object(
manager, "_DefaultExecutionManager__update_completed_cells_hook"
) as mock_hook_method:
with patch.object(DisabledTrackingExecutionManager, "model") as mock_model:
with patch("jupyter_scheduler.executors.open", mock=MagicMock()):
with patch("jupyter_scheduler.executors.nbformat.read") as mock_nb_read:
with patch.object(DisabledTrackingExecutionManager, "add_side_effects_files"):
with patch.object(DisabledTrackingExecutionManager, "create_output_files"):
with patch(
"jupyter_scheduler.executors.ExecutePreprocessor"
) as mock_ep_class:
# Mock notebook
mock_nb = MagicMock()
mock_nb.metadata.kernelspec = {"name": "python3"}
mock_nb_read.return_value = mock_nb

# Mock model
mock_model.parameters = None
mock_model.output_formats = []

mock_ep = MagicMock()
mock_ep_class.return_value = mock_ep

# Execute
manager.execute()

# Verify ExecutePreprocessor was created
mock_ep_class.assert_called_once()

# Verify the hook method was NOT called when feature is disabled
mock_hook_method.assert_not_called()


def test_disabled_tracking_feature_support():
"""Test that custom execution manager can disable track_cell_execution feature"""

# Create a custom execution manager class with track_cell_execution = False
class DisabledTrackingExecutionManager(DefaultExecutionManager):
def supported_features(self):
features = super().supported_features()
from jupyter_scheduler.models import JobFeature

features[JobFeature.track_cell_execution] = False
return features

manager = DisabledTrackingExecutionManager(
job_id="test-job-id",
root_dir="/test",
db_url="sqlite:///:memory:",
staging_paths={"input": "/test/input.ipynb"},
)
features = manager.supported_features()

from jupyter_scheduler.models import JobFeature

assert JobFeature.track_cell_execution in features
assert features[JobFeature.track_cell_execution] is False
Loading
Loading