From f8150c014a386306069301169f19acccafea2b21 Mon Sep 17 00:00:00 2001 From: joshlee Date: Wed, 22 Oct 2025 21:31:10 +0000 Subject: [PATCH 01/18] Make CancelTask RPC Fault Tolerant Signed-off-by: joshlee --- .../tests/test_core_worker_fault_tolerance.py | 50 +++++++++++ src/ray/core_worker/core_worker_process.cc | 2 + .../core_worker/task_submission/BUILD.bazel | 3 + .../task_submission/actor_task_submitter.cc | 53 +++++++----- .../task_submission/actor_task_submitter.h | 13 ++- .../task_submission/normal_task_submitter.cc | 23 ++--- .../task_submission/normal_task_submitter.h | 5 +- src/ray/protobuf/core_worker.proto | 5 +- src/ray/protobuf/node_manager.proto | 23 +++++ src/ray/raylet/node_manager.cc | 85 +++++++++++++++++++ src/ray/raylet/node_manager.h | 4 + src/ray/raylet_rpc_client/raylet_client.cc | 12 +++ src/ray/raylet_rpc_client/raylet_client.h | 4 + .../raylet_client_interface.h | 4 + .../rpc/node_manager/node_manager_server.h | 7 +- 15 files changed, 256 insertions(+), 37 deletions(-) diff --git a/python/ray/tests/test_core_worker_fault_tolerance.py b/python/ray/tests/test_core_worker_fault_tolerance.py index 6ab8cf9ba5de..ebc4f25e88ad 100644 --- a/python/ray/tests/test_core_worker_fault_tolerance.py +++ b/python/ray/tests/test_core_worker_fault_tolerance.py @@ -1,3 +1,4 @@ +import os import sys import numpy as np @@ -9,6 +10,8 @@ from ray.exceptions import GetTimeoutError, TaskCancelledError from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy +import psutil + @pytest.mark.parametrize( "allow_out_of_order_execution", @@ -207,5 +210,52 @@ def remote_wait(sg): ray.get(inner, timeout=10) +@pytest.fixture +def inject_cancel_local_task_rpc_failure(monkeypatch, request): + deterministic_failure = request.param + monkeypatch.setenv( + "RAY_testing_rpc_failure", + "NodeManagerService.grpc_client.CancelLocalTask=1:" + + ("100:0" if deterministic_failure == "request" else "0:100"), + ) + + +@pytest.mark.parametrize( + "inject_cancel_local_task_rpc_failure", ["request", "response"], indirect=True +) +@pytest.mark.parametrize("force_kill", [True, False]) +def test_cancel_local_task_rpc_retry_and_idempotency( + inject_cancel_local_task_rpc_failure, force_kill, shutdown_only +): + ray.init(num_cpus=2) + signaler = SignalActor.remote() + + @ray.remote(num_cpus=1) + def get_pid(): + return os.getpid() + + @ray.remote(num_cpus=1) + def blocking_task(): + return ray.get(signaler.wait.remote()) + + worker_pid = ray.get(get_pid.remote()) + + blocking_ref = blocking_task.remote() + + with pytest.raises(GetTimeoutError): + ray.get(blocking_ref, timeout=1) + + ray.cancel(blocking_ref, force=force_kill) + + with pytest.raises(TaskCancelledError): + ray.get(blocking_ref, timeout=10) + if force_kill: + + def verify_process_killed(): + return not psutil.pid_exists(worker_pid) + + wait_for_condition(verify_process_killed, timeout=30) + + if __name__ == "__main__": pytest.main([__file__, "-v", "-s"]) diff --git a/src/ray/core_worker/core_worker_process.cc b/src/ray/core_worker/core_worker_process.cc index 3ebdf00877f5..d5e886662b9c 100644 --- a/src/ray/core_worker/core_worker_process.cc +++ b/src/ray/core_worker/core_worker_process.cc @@ -498,6 +498,8 @@ std::shared_ptr<CoreWorker> CoreWorkerProcessImpl::CreateCoreWorker( auto actor_task_submitter = std::make_unique<ActorTaskSubmitter>( *core_worker_client_pool, + *raylet_client_pool, + gcs_client, *memory_store, *task_manager, *actor_creator, diff --git a/src/ray/core_worker/task_submission/BUILD.bazel b/src/ray/core_worker/task_submission/BUILD.bazel index 
387fba21552c..8d6ac2dc9a1f 100644 --- a/src/ray/core_worker/task_submission/BUILD.bazel +++ b/src/ray/core_worker/task_submission/BUILD.bazel @@ -71,6 +71,9 @@ ray_cc_library( "//src/ray/common:protobuf_utils", "//src/ray/core_worker:actor_creator", "//src/ray/core_worker_rpc_client:core_worker_client_pool", + "//src/ray/gcs_rpc_client:gcs_client", + "//src/ray/raylet_rpc_client:raylet_client_interface", + "//src/ray/raylet_rpc_client:raylet_client_pool", "//src/ray/rpc:rpc_callback_types", "//src/ray/util:time", "@com_google_absl//absl/base:core_headers", diff --git a/src/ray/core_worker/task_submission/actor_task_submitter.cc b/src/ray/core_worker/task_submission/actor_task_submitter.cc index e0cc2fb20d75..b8589ef28196 100644 --- a/src/ray/core_worker/task_submission/actor_task_submitter.cc +++ b/src/ray/core_worker/task_submission/actor_task_submitter.cc @@ -272,6 +272,7 @@ void ActorTaskSubmitter::CancelDependencyResolution(const TaskID &task_id) { void ActorTaskSubmitter::DisconnectRpcClient(ClientQueue &queue) { queue.client_address_ = std::nullopt; + queue.raylet_address_ = std::nullopt; // If the actor on the worker is dead, the worker is also dead. core_worker_client_pool_.Disconnect(WorkerID::FromBinary(queue.worker_id_)); queue.worker_id_.clear(); @@ -336,7 +337,12 @@ void ActorTaskSubmitter::ConnectActor(const ActorID &actor_id, // So new RPCs go out with the right intended worker id to the right address. queue->second.worker_id_ = address.worker_id(); queue->second.client_address_ = address; + NodeID node_id = NodeID::FromBinary(address.node_id()); + auto node_info = gcs_client_->Nodes().GetNodeAddressAndLiveness(node_id); + RAY_CHECK(node_info != nullptr); + queue->second.raylet_address_ = rpc::RayletClientPool::GenerateRayletAddress( + node_id, node_info->node_manager_address(), node_info->node_manager_port()); SendPendingTasks(actor_id); } @@ -1008,32 +1014,35 @@ void ActorTaskSubmitter::CancelTask(TaskSpecification task_spec, bool recursive) return; } - rpc::CancelTaskRequest request; + rpc::CancelLocalTaskRequest request; request.set_intended_task_id(task_spec.TaskIdBinary()); request.set_force_kill(force_kill); request.set_recursive(recursive); request.set_caller_worker_id(task_spec.CallerWorkerIdBinary()); - auto client = core_worker_client_pool_.GetOrConnect(*queue->second.client_address_); - client->CancelTask(request, - [this, task_spec = std::move(task_spec), recursive, task_id]( - const Status &status, const rpc::CancelTaskReply &reply) { - RAY_LOG(DEBUG).WithField(task_spec.TaskId()) - << "CancelTask RPC response received with status " - << status.ToString(); - - // Keep retrying every 2 seconds until a task is officially - // finished. - if (!task_manager_.GetTaskSpec(task_id)) { - // Task is already finished. - RAY_LOG(DEBUG).WithField(task_spec.TaskId()) - << "Task is finished. 
Stop a cancel request."; + return; + } + + if (!reply.attempt_succeeded()) { + RetryCancelTask(task_spec, recursive, 2000); + } + }); } } diff --git a/src/ray/core_worker/task_submission/actor_task_submitter.h b/src/ray/core_worker/task_submission/actor_task_submitter.h index f225397768be..a2089d5376d5 100644 --- a/src/ray/core_worker/task_submission/actor_task_submitter.h +++ b/src/ray/core_worker/task_submission/actor_task_submitter.h @@ -67,6 +67,8 @@ class ActorTaskSubmitterInterface { class ActorTaskSubmitter : public ActorTaskSubmitterInterface { public: ActorTaskSubmitter(rpc::CoreWorkerClientPool &core_worker_client_pool, + rpc::RayletClientPool &raylet_client_pool, + std::shared_ptr<gcs::GcsClient> gcs_client, CoreWorkerMemoryStore &store, TaskManagerInterface &task_manager, ActorCreatorInterface &actor_creator, @@ -76,6 +78,8 @@ class ActorTaskSubmitter : public ActorTaskSubmitterInterface { instrumented_io_context &io_service, std::shared_ptr<ReferenceCounterInterface> reference_counter) : core_worker_client_pool_(core_worker_client_pool), + raylet_client_pool_(raylet_client_pool), + gcs_client_(std::move(gcs_client)), actor_creator_(actor_creator), resolver_(store, task_manager, actor_creator, tensor_transport_getter), task_manager_(task_manager), @@ -300,8 +304,10 @@ class ActorTaskSubmitter : public ActorTaskSubmitterInterface { int64_t num_restarts_due_to_lineage_reconstructions_ = 0; /// Whether this actor exits by spot preemption. bool preempted_ = false; - /// The RPC client address. + /// The RPC client address of the actor. std::optional<rpc::Address> client_address_; + /// The local raylet addres of the actor. + std::optional<rpc::Address> raylet_address_; /// The intended worker ID of the actor. std::string worker_id_; /// The actor is out of scope but the death info is not published @@ -411,6 +417,11 @@ class ActorTaskSubmitter : public ActorTaskSubmitterInterface { /// Pool for producing new core worker clients. rpc::CoreWorkerClientPool &core_worker_client_pool_; + /// Pool for producing new raylet clients. + rpc::RayletClientPool &raylet_client_pool_; + + std::shared_ptr<gcs::GcsClient> gcs_client_; + ActorCreatorInterface &actor_creator_; /// Mutex to protect the various maps below. 
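The retry contract above is the crux of the change: the owner no longer fire-and-forgets a CancelTask RPC at the executor worker; it sends CancelLocalTask to the executor's local raylet and, as long as task_manager_ still knows the task, re-issues the request until an attempt succeeds. A minimal, self-contained sketch of that contract in plain C++ — FakeRaylet and its drop-the-first-request behavior are hypothetical stand-ins mirroring the RAY_testing_rpc_failure "...CancelLocalTask=1:100:0" injection in the test above, not Ray code:

#include <cstdio>

struct CancelLocalTaskReply {
  bool attempt_succeeded = false;
};

// Hypothetical raylet that loses the first cancel request and succeeds on the
// retry, like the deterministic request failure injected in the test.
struct FakeRaylet {
  int calls = 0;
  CancelLocalTaskReply CancelLocalTask() {
    ++calls;
    return CancelLocalTaskReply{/*attempt_succeeded=*/calls > 1};
  }
};

int main() {
  FakeRaylet raylet;
  bool task_finished = false;  // stands in for task_manager_.GetTaskSpec(task_id)
  int attempts = 0;
  // Owner-side loop; the real submitter does this asynchronously by re-arming
  // a timer via RetryCancelTask(task_spec, recursive, /*delay_ms=*/2000).
  while (!task_finished) {
    ++attempts;
    task_finished = raylet.CancelLocalTask().attempt_succeeded;
  }
  std::printf("cancel acknowledged after %d attempts\n", attempts);  // prints 2
  return 0;
}

The inline loop is only for illustration; the actual submitter returns immediately and schedules the next attempt on its io_context so cancellation never blocks the caller.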
diff --git a/src/ray/core_worker/task_submission/normal_task_submitter.cc b/src/ray/core_worker/task_submission/normal_task_submitter.cc index 49191d9b8077..77eb3da2c882 100644 --- a/src/ray/core_worker/task_submission/normal_task_submitter.cc +++ b/src/ray/core_worker/task_submission/normal_task_submitter.cc @@ -179,7 +179,8 @@ void NormalTaskSubmitter::OnWorkerIdle( task_spec.GetMutableMessage().set_lease_grant_timestamp_ms(current_sys_time_ms()); task_spec.EmitTaskMetrics(); - executing_tasks_.emplace(task_spec.TaskId(), addr); + executing_tasks_.emplace(task_spec.TaskId(), + std::make_pair(addr, lease_entry.addr)); PushNormalTask( addr, client, scheduling_key, std::move(task_spec), assigned_resources); } @@ -665,7 +666,8 @@ void NormalTaskSubmitter::CancelTask(TaskSpecification task_spec, SchedulingKey scheduling_key(task_spec.GetSchedulingClass(), task_spec.GetDependencyIds(), task_spec.GetRuntimeEnvHash()); - std::shared_ptr<rpc::CoreWorkerClientInterface> client = nullptr; + std::shared_ptr<RayletClientInterface> raylet_client = nullptr; + rpc::Address executor_worker_address; { absl::MutexLock lock(&mu_); generators_to_resubmit_.erase(task_id); @@ -700,9 +702,8 @@ void NormalTaskSubmitter::CancelTask(TaskSpecification task_spec, // This will get removed either when the RPC call to cancel is returned, when all // dependencies are resolved, or when dependency resolution is successfully cancelled. RAY_CHECK(cancelled_tasks_.emplace(task_id).second); - auto rpc_client = executing_tasks_.find(task_id); - - if (rpc_client == executing_tasks_.end()) { + auto rpc_client_address = executing_tasks_.find(task_id); + if (rpc_client_address == executing_tasks_.end()) { if (failed_tasks_pending_failure_cause_.contains(task_id)) { // We are waiting for the task failure cause. Do not fail it here; instead, // wait for the cause to come in and then handle it appropriately. @@ -723,22 +724,24 @@ void NormalTaskSubmitter::CancelTask(TaskSpecification task_spec, return; } // Looks for an RPC handle for the worker executing the task. 
- client = core_worker_client_pool_->GetOrConnect(rpc_client->second); + raylet_client = + raylet_client_pool_->GetOrConnectByAddress(rpc_client_address->second.second); + executor_worker_address = rpc_client_address->second.first; } - RAY_CHECK(client != nullptr); - auto request = rpc::CancelTaskRequest(); + auto request = rpc::CancelLocalTaskRequest(); request.set_intended_task_id(task_spec.TaskIdBinary()); request.set_force_kill(force_kill); request.set_recursive(recursive); request.set_caller_worker_id(task_spec.CallerWorkerIdBinary()); - client->CancelTask( + request.set_executor_worker_address(executor_worker_address.SerializeAsString()); + raylet_client->CancelLocalTask( request, [this, task_spec = std::move(task_spec), scheduling_key = std::move(scheduling_key), force_kill, - recursive](const Status &status, const rpc::CancelTaskReply &reply) mutable { + recursive](const Status &status, const rpc::CancelLocalTaskReply &reply) mutable { absl::MutexLock lock(&mu_); RAY_LOG(DEBUG) << "CancelTask RPC response received for " << task_spec.TaskId() << " with status " << status.ToString(); diff --git a/src/ray/core_worker/task_submission/normal_task_submitter.h b/src/ray/core_worker/task_submission/normal_task_submitter.h index 60e2c6db01cb..9ec27f9c45a8 100644 --- a/src/ray/core_worker/task_submission/normal_task_submitter.h +++ b/src/ray/core_worker/task_submission/normal_task_submitter.h @@ -348,7 +348,10 @@ class NormalTaskSubmitter { absl::flat_hash_set<TaskID> cancelled_tasks_ ABSL_GUARDED_BY(mu_); // Keeps track of where currently executing tasks are being run. - absl::flat_hash_map<TaskID, rpc::Address> executing_tasks_ ABSL_GUARDED_BY(mu_); + // The first address is the executor, the second address is the local raylet of the + // executor. + absl::flat_hash_map<TaskID, std::pair<rpc::Address, rpc::Address>> executing_tasks_ + ABSL_GUARDED_BY(mu_); // Generators that are currently running and need to be resubmitted. absl::flat_hash_set<TaskID> generators_to_resubmit_ ABSL_GUARDED_BY(mu_); diff --git a/src/ray/protobuf/core_worker.proto b/src/ray/protobuf/core_worker.proto index 2bec20b67ffe..aa0c05ce21f5 100644 --- a/src/ray/protobuf/core_worker.proto +++ b/src/ray/protobuf/core_worker.proto @@ -529,8 +529,9 @@ service CoreWorkerService { // Failure: TODO: Never retries rpc KillActor(KillActorRequest) returns (KillActorReply); - // Request from owner worker to executor worker to cancel a task. - // Failure: Will retry, TODO: Needs tests for failure behavior. + // Request from local raylet to executor worker to cancel a task. + // Failure: Idempotent, does not retry for network failures. However, requests + // should only be sent via CancelLocalTask from the raylet, which does implement retries. rpc CancelTask(CancelTaskRequest) returns (CancelTaskReply); // Request from a worker to the owner worker to issue a cancellation. diff --git a/src/ray/protobuf/node_manager.proto b/src/ray/protobuf/node_manager.proto index 05a3ea9be154..83f1b85200f6 100644 --- a/src/ray/protobuf/node_manager.proto +++ b/src/ray/protobuf/node_manager.proto @@ -406,6 +406,26 @@ message GetWorkerPIDsReply { repeated int32 pids = 1; } +message CancelLocalTaskRequest { + // ID of task that should be killed. + bytes intended_task_id = 1; + // Whether to kill the worker. + bool force_kill = 2; + // Whether to recursively cancel tasks. + bool recursive = 3; + // The worker ID of the caller. + bytes caller_worker_id = 4; + // The worker address of the executor. + bytes executor_worker_address = 5; +} + +message CancelLocalTaskReply { + // Whether the requested task is the currently running task. 
+ bool requested_task_running = 1; + // Whether the task is canceled. + bool attempt_succeeded = 2; +} + // Service for inter-node-manager communication. service NodeManagerService { // Handle the case when GCS restarted. @@ -525,4 +545,7 @@ service NodeManagerService { // Failure: Will retry with the default timeout 1000ms. If fails, reply return an empty // list. rpc GetWorkerPIDs(GetWorkerPIDsRequest) returns (GetWorkerPIDsReply); + // Forwards the CancelTask request from the caller core worker to the executor. + // Failure: Retries; it's idempotent. + rpc CancelLocalTask(CancelLocalTaskRequest) returns (CancelLocalTaskReply); } diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 90eaa40d4985..6537fed35a98 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -3339,4 +3339,89 @@ std::unique_ptr NodeManager::CreateRuntimeEnvAgentManager( add_process_to_system_cgroup_hook_); } +void NodeManager::HandleCancelLocalTask(rpc::CancelLocalTaskRequest request, + rpc::CancelLocalTaskReply *reply, + rpc::SendReplyCallback send_reply_callback) { + rpc::Address executor_address; + if (!executor_address.ParseFromString(request.executor_worker_address())) { + send_reply_callback( + Status::Invalid("Failed to parse executor worker address"), nullptr, nullptr); + return; + } + + auto worker = worker_pool_.GetRegisteredWorker( + WorkerID::FromBinary(executor_address.worker_id())); + // If the worker is not registered, then it must have already been killed. + if (!worker || worker->IsDead()) { + reply->set_attempt_succeeded(true); + reply->set_requested_task_running(false); + send_reply_callback(Status::OK(), nullptr, nullptr); + return; + } + + WorkerID worker_id = worker->WorkerId(); + + rpc::CancelTaskRequest cancel_task_request; + cancel_task_request.set_intended_task_id(request.intended_task_id()); + cancel_task_request.set_force_kill(request.force_kill()); + cancel_task_request.set_recursive(request.recursive()); + cancel_task_request.set_caller_worker_id(request.caller_worker_id()); + if (!request.force_kill()) { + worker->rpc_client()->CancelTask( + cancel_task_request, + [reply, send_reply_callback](const Status &status, + const rpc::CancelTaskReply &cancel_task_reply) { + if (!status.ok()) { + send_reply_callback( + Status::Invalid("Failed to cancel task"), nullptr, nullptr); + } else { + reply->set_attempt_succeeded(cancel_task_reply.attempt_succeeded()); + reply->set_requested_task_running(cancel_task_reply.requested_task_running()); + send_reply_callback(Status::OK(), nullptr, nullptr); + } + }); + return; + } + auto timer = execute_after( + io_service_, + [this, reply, send_reply_callback, worker_id]() { + auto current_worker = worker_pool_.GetRegisteredWorker(worker_id); + if (current_worker) { + // If the worker is still alive, force-kill it. + RAY_LOG(INFO) << "Worker with PID=" << current_worker->GetProcess().GetId() + << " did not exit after " + << RayConfig::instance().kill_worker_timeout_milliseconds() + << "ms, force killing with SIGKILL."; + DestroyWorker(current_worker, + rpc::WorkerExitType::INTENDED_SYSTEM_EXIT, + "Actor killed by GCS", + /*force=*/true); + } + reply->set_attempt_succeeded(true); + reply->set_requested_task_running(false); + send_reply_callback(Status::OK(), nullptr, nullptr); + }, + std::chrono::milliseconds( + RayConfig::instance().kill_worker_timeout_milliseconds())); + + worker->rpc_client()->CancelTask( + cancel_task_request, + [task_id = request.intended_task_id(), timer, reply, send_reply_callback](
+ const ray::Status &status, const rpc::CancelTaskReply &cancel_task_reply) { + if (!status.ok()) { + std::ostringstream stream; + stream << "CancelTask RPC failed for task " << task_id << ": " + << status.ToString(); + const auto &msg = stream.str(); + RAY_LOG(DEBUG) << msg; + // NOTE: We'll escalate the graceful shutdown to SIGKILL which is done by the + // timer above + } else { + reply->set_attempt_succeeded(cancel_task_reply.attempt_succeeded()); + reply->set_requested_task_running(cancel_task_reply.requested_task_running()); + } + send_reply_callback(Status::OK(), nullptr, nullptr); + }); +} + } // namespace ray::raylet diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index 701105b535a8..aae14ba64cdd 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -304,6 +304,10 @@ class NodeManager : public rpc::NodeManagerServiceHandler, rpc::DrainRayletReply *reply, rpc::SendReplyCallback send_reply_callback) override; + void HandleCancelLocalTask(rpc::CancelLocalTaskRequest request, + rpc::CancelLocalTaskReply *reply, + rpc::SendReplyCallback send_reply_callback) override; + private: FRIEND_TEST(NodeManagerStaticTest, TestHandleReportWorkerBacklog); diff --git a/src/ray/raylet_rpc_client/raylet_client.cc b/src/ray/raylet_rpc_client/raylet_client.cc index 1a937c51d521..73e65c5536eb 100644 --- a/src/ray/raylet_rpc_client/raylet_client.cc +++ b/src/ray/raylet_rpc_client/raylet_client.cc @@ -491,5 +491,17 @@ void RayletClient::GetWorkerPIDs( timeout_ms); } +void RayletClient::CancelLocalTask( + const rpc::CancelLocalTaskRequest &request, + const rpc::ClientCallback<rpc::CancelLocalTaskReply> &callback) { + INVOKE_RETRYABLE_RPC_CALL(retryable_grpc_client_, + NodeManagerService, + CancelLocalTask, + request, + callback, + grpc_client_, + /*method_timeout_ms*/ -1); +} + } // namespace rpc } // namespace ray diff --git a/src/ray/raylet_rpc_client/raylet_client.h b/src/ray/raylet_rpc_client/raylet_client.h index e5273c8ce36c..eef1d5101c5f 100644 --- a/src/ray/raylet_rpc_client/raylet_client.h +++ b/src/ray/raylet_rpc_client/raylet_client.h @@ -170,6 +170,10 @@ class RayletClient : public RayletClientInterface { void GetWorkerPIDs(const gcs::OptionalItemCallback<std::vector<int32_t>> &callback, int64_t timeout_ms); + void CancelLocalTask( + const rpc::CancelLocalTaskRequest &request, + const rpc::ClientCallback<rpc::CancelLocalTaskReply> &callback) override; + protected: /// gRPC client to the NodeManagerService. 
std::shared_ptr<GrpcClient<NodeManagerService>> grpc_client_; diff --git a/src/ray/raylet_rpc_client/raylet_client_interface.h b/src/ray/raylet_rpc_client/raylet_client_interface.h index 713c69c76c9b..9130072ea616 100644 --- a/src/ray/raylet_rpc_client/raylet_client_interface.h +++ b/src/ray/raylet_rpc_client/raylet_client_interface.h @@ -213,6 +213,10 @@ class RayletClientInterface { virtual int64_t GetPinsInFlight() const = 0; + virtual void CancelLocalTask( + const rpc::CancelLocalTaskRequest &request, + const rpc::ClientCallback<rpc::CancelLocalTaskReply> &callback) = 0; + virtual ~RayletClientInterface() = default; }; diff --git a/src/ray/rpc/node_manager/node_manager_server.h b/src/ray/rpc/node_manager/node_manager_server.h index fba7780afc69..2a38a456f770 100644 --- a/src/ray/rpc/node_manager/node_manager_server.h +++ b/src/ray/rpc/node_manager/node_manager_server.h @@ -59,7 +59,8 @@ class ServerCallFactory; RAY_NODE_MANAGER_RPC_SERVICE_HANDLER(GetWorkerFailureCause) \ RAY_NODE_MANAGER_RPC_SERVICE_HANDLER(RegisterMutableObject) \ RAY_NODE_MANAGER_RPC_SERVICE_HANDLER(PushMutableObject) \ - RAY_NODE_MANAGER_RPC_SERVICE_HANDLER(GetWorkerPIDs) + RAY_NODE_MANAGER_RPC_SERVICE_HANDLER(GetWorkerPIDs) \ + RAY_NODE_MANAGER_RPC_SERVICE_HANDLER(CancelLocalTask) /// Interface of the `NodeManagerService`, see `src/ray/protobuf/node_manager.proto`. class NodeManagerServiceHandler { @@ -187,6 +188,10 @@ class NodeManagerServiceHandler { virtual void HandleGetWorkerPIDs(GetWorkerPIDsRequest request, GetWorkerPIDsReply *reply, SendReplyCallback send_reply_callback) = 0; + + virtual void HandleCancelLocalTask(CancelLocalTaskRequest request, + CancelLocalTaskReply *reply, + SendReplyCallback send_reply_callback) = 0; }; /// The `GrpcService` for `NodeManagerService`. From 0a630a71023f715ff0aeff1a6473dbe00aba9bc8 Mon Sep 17 00:00:00 2001 From: joshlee Date: Wed, 22 Oct 2025 21:53:36 +0000 Subject: [PATCH 02/18] Addressing comments Signed-off-by: joshlee --- .../task_submission/actor_task_submitter.h | 2 +- src/ray/raylet/node_manager.cc | 16 +++++++--------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/ray/core_worker/task_submission/actor_task_submitter.h b/src/ray/core_worker/task_submission/actor_task_submitter.h index a2089d5376d5..9acc98a6ac2c 100644 --- a/src/ray/core_worker/task_submission/actor_task_submitter.h +++ b/src/ray/core_worker/task_submission/actor_task_submitter.h @@ -306,7 +306,7 @@ class ActorTaskSubmitter : public ActorTaskSubmitterInterface { bool preempted_ = false; /// The RPC client address of the actor. std::optional<rpc::Address> client_address_; - /// The local raylet addres of the actor. + /// The local raylet address of the actor. std::optional<rpc::Address> raylet_address_; /// The intended worker ID of the actor. 
std::string worker_id_; diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 6537fed35a98..68cc634ffe66 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -3394,7 +3394,7 @@ void NodeManager::HandleCancelLocalTask(rpc::CancelLocalTaskRequest request, << "ms, force killing with SIGKILL."; DestroyWorker(current_worker, rpc::WorkerExitType::INTENDED_SYSTEM_EXIT, - "Actor killed by GCS", + "Force-killed by ray.cancel(force=True)", /*force=*/true); } reply->set_attempt_succeeded(true); @@ -3409,18 +3409,16 @@ void NodeManager::HandleCancelLocalTask(rpc::CancelLocalTaskRequest request, [task_id = request.intended_task_id(), timer, reply, send_reply_callback]( const ray::Status &status, const rpc::CancelTaskReply &cancel_task_reply) { if (!status.ok()) { - std::ostringstream stream; - stream << "CancelTask RPC failed for task " << task_id << ": " - << status.ToString(); - const auto &msg = stream.str(); - RAY_LOG(DEBUG) << msg; + RAY_LOG(DEBUG) << "CancelTask RPC failed for task " << task_id << ": " + << status.ToString(); // NOTE: We'll escalate the graceful shutdown to SIGKILL which is done by the // timer above - } else { - reply->set_attempt_succeeded(cancel_task_reply.attempt_succeeded()); - reply->set_requested_task_running(cancel_task_reply.requested_task_running()); + return; } + reply->set_attempt_succeeded(cancel_task_reply.attempt_succeeded()); + reply->set_requested_task_running(cancel_task_reply.requested_task_running()); send_reply_callback(Status::OK(), nullptr, nullptr); + timer->cancel(); }); } From 8ae4e3aa9e1589050474e680746029d29047a305 Mon Sep 17 00:00:00 2001 From: joshlee Date: Wed, 22 Oct 2025 22:12:49 +0000 Subject: [PATCH 03/18] clean up and cpp test failures Signed-off-by: joshlee --- .../tests/test_core_worker_fault_tolerance.py | 50 ---------------- .../ray/tests/test_raylet_fault_tolerance.py | 59 ++++++++++++++++++- src/mock/ray/raylet_client/raylet_client.h | 5 ++ .../raylet_rpc_client/fake_raylet_client.h | 3 + 4 files changed, 66 insertions(+), 51 deletions(-) diff --git a/python/ray/tests/test_core_worker_fault_tolerance.py b/python/ray/tests/test_core_worker_fault_tolerance.py index ebc4f25e88ad..6ab8cf9ba5de 100644 --- a/python/ray/tests/test_core_worker_fault_tolerance.py +++ b/python/ray/tests/test_core_worker_fault_tolerance.py @@ -1,4 +1,3 @@ -import os import sys import numpy as np @@ -10,8 +9,6 @@ from ray.exceptions import GetTimeoutError, TaskCancelledError from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy -import psutil - @pytest.mark.parametrize( "allow_out_of_order_execution", @@ -210,52 +207,5 @@ def remote_wait(sg): ray.get(inner, timeout=10) -@pytest.fixture -def inject_cancel_local_task_rpc_failure(monkeypatch, request): - deterministic_failure = request.param - monkeypatch.setenv( - "RAY_testing_rpc_failure", - "NodeManagerService.grpc_client.CancelLocalTask=1:" - + ("100:0" if deterministic_failure == "request" else "0:100"), - ) - - -@pytest.mark.parametrize( - "inject_cancel_local_task_rpc_failure", ["request", "response"], indirect=True -) -@pytest.mark.parametrize("force_kill", [True, False]) -def test_cancel_local_task_rpc_retry_and_idempotency( - inject_cancel_local_task_rpc_failure, force_kill, shutdown_only -): - ray.init(num_cpus=2) - signaler = SignalActor.remote() - - @ray.remote(num_cpus=1) - def get_pid(): - return os.getpid() - - @ray.remote(num_cpus=1) - def blocking_task(): - return ray.get(signaler.wait.remote()) - - worker_pid = 
ray.get(get_pid.remote()) - - blocking_ref = blocking_task.remote() - - with pytest.raises(GetTimeoutError): - ray.get(blocking_ref, timeout=1) - - ray.cancel(blocking_ref, force=force_kill) - - with pytest.raises(TaskCancelledError): - ray.get(blocking_ref, timeout=10) - if force_kill: - - def verify_process_killed(): - return not psutil.pid_exists(worker_pid) - - wait_for_condition(verify_process_killed, timeout=30) - - if __name__ == "__main__": pytest.main([__file__, "-v", "-s"]) diff --git a/python/ray/tests/test_raylet_fault_tolerance.py b/python/ray/tests/test_raylet_fault_tolerance.py index 5aa7d5f7f7e0..7c797673b98d 100644 --- a/python/ray/tests/test_raylet_fault_tolerance.py +++ b/python/ray/tests/test_raylet_fault_tolerance.py @@ -1,16 +1,20 @@ +import os import sys import pytest import ray -from ray._private.test_utils import wait_for_condition +from ray._common.test_utils import SignalActor, wait_for_condition from ray.core.generated import autoscaler_pb2 +from ray.exceptions import GetTimeoutError, TaskCancelledError from ray.util.placement_group import placement_group, remove_placement_group from ray.util.scheduling_strategies import ( NodeAffinitySchedulingStrategy, PlacementGroupSchedulingStrategy, ) +import psutil + @pytest.mark.parametrize("deterministic_failure", ["request", "response"]) def test_request_worker_lease_idempotent( @@ -138,5 +142,58 @@ def task(): assert result == "success" +@pytest.fixture +def inject_cancel_local_task_rpc_failure(monkeypatch, request): + deterministic_failure = request.param + monkeypatch.setenv( + "RAY_testing_rpc_failure", + "NodeManagerService.grpc_client.CancelLocalTask=1:" + + ("100:0" if deterministic_failure == "request" else "0:100"), + ) + + +@pytest.mark.parametrize( + "inject_cancel_local_task_rpc_failure", ["request", "response"], indirect=True +) +@pytest.mark.parametrize("force_kill", [True, False]) +def test_cancel_local_task_rpc_retry_and_idempotency( + inject_cancel_local_task_rpc_failure, force_kill, shutdown_only +): + """Test that CancelLocalTask RPC retries work correctly. + + Verify that the RPC is idempotent when network failures occur. + When force_kill=True, verify the worker process is actually killed using psutil. 
+ """ + ray.init(num_cpus=2) + signaler = SignalActor.remote() + + @ray.remote(num_cpus=1) + def get_pid(): + return os.getpid() + + @ray.remote(num_cpus=1) + def blocking_task(): + return ray.get(signaler.wait.remote()) + + worker_pid = ray.get(get_pid.remote()) + + blocking_ref = blocking_task.remote() + + with pytest.raises(GetTimeoutError): + ray.get(blocking_ref, timeout=1) + + ray.cancel(blocking_ref, force=force_kill) + + with pytest.raises(TaskCancelledError): + ray.get(blocking_ref, timeout=10) + + if force_kill: + + def verify_process_killed(): + return not psutil.pid_exists(worker_pid) + + wait_for_condition(verify_process_killed, timeout=30) + + if __name__ == "__main__": sys.exit(pytest.main(["-sv", __file__])) diff --git a/src/mock/ray/raylet_client/raylet_client.h b/src/mock/ray/raylet_client/raylet_client.h index 9a2c2d06b8b9..af12386de43c 100644 --- a/src/mock/ray/raylet_client/raylet_client.h +++ b/src/mock/ray/raylet_client/raylet_client.h @@ -151,6 +151,11 @@ class MockRayletClientInterface : public RayletClientInterface { (const rpc::ClientCallback &callback), (override)); MOCK_METHOD(int64_t, GetPinsInFlight, (), (const, override)); + MOCK_METHOD(void, + CancelLocalTask, + (const rpc::CancelLocalTaskRequest &request, + const rpc::ClientCallback &callback), + (override)); }; } // namespace ray diff --git a/src/ray/raylet_rpc_client/fake_raylet_client.h b/src/ray/raylet_rpc_client/fake_raylet_client.h index ff4fdc2e227f..db25e4b9193d 100644 --- a/src/ray/raylet_rpc_client/fake_raylet_client.h +++ b/src/ray/raylet_rpc_client/fake_raylet_client.h @@ -285,6 +285,9 @@ class FakeRayletClient : public RayletClientInterface { int64_t GetPinsInFlight() const override { return 0; } + void CancelLocalTask(const CancelLocalTaskRequest &request, + const ClientCallback &callback) override {} + int num_workers_requested = 0; int num_workers_returned = 0; int num_workers_disconnected = 0; From a733422e7444be1cbe7abe4bdf6c9fc8a889db26 Mon Sep 17 00:00:00 2001 From: joshlee Date: Thu, 23 Oct 2025 20:39:19 +0000 Subject: [PATCH 04/18] Addressing comments Signed-off-by: joshlee --- src/ray/core_worker/core_worker_process.cc | 1 + .../task_submission/actor_task_submitter.cc | 64 +++++++-- .../task_submission/actor_task_submitter.h | 2 - .../task_submission/normal_task_submitter.cc | 132 +++++++++++------- .../task_submission/normal_task_submitter.h | 15 +- .../tests/normal_task_submitter_test.cc | 1 + src/ray/core_worker/tests/core_worker_test.cc | 1 + src/ray/protobuf/node_manager.proto | 4 +- src/ray/raylet/node_manager.cc | 21 +-- 9 files changed, 155 insertions(+), 86 deletions(-) diff --git a/src/ray/core_worker/core_worker_process.cc b/src/ray/core_worker/core_worker_process.cc index d5e886662b9c..9f155cbf9882 100644 --- a/src/ray/core_worker/core_worker_process.cc +++ b/src/ray/core_worker/core_worker_process.cc @@ -538,6 +538,7 @@ std::shared_ptr CoreWorkerProcessImpl::CreateCoreWorker( local_raylet_rpc_client, core_worker_client_pool, raylet_client_pool, + gcs_client, std::move(lease_policy), memory_store, *task_manager, diff --git a/src/ray/core_worker/task_submission/actor_task_submitter.cc b/src/ray/core_worker/task_submission/actor_task_submitter.cc index b8589ef28196..e4d3413997ba 100644 --- a/src/ray/core_worker/task_submission/actor_task_submitter.cc +++ b/src/ray/core_worker/task_submission/actor_task_submitter.cc @@ -272,7 +272,6 @@ void ActorTaskSubmitter::CancelDependencyResolution(const TaskID &task_id) { void ActorTaskSubmitter::DisconnectRpcClient(ClientQueue &queue) 
{ queue.client_address_ = std::nullopt; - queue.raylet_address_ = std::nullopt; // If the actor on the worker is dead, the worker is also dead. core_worker_client_pool_.Disconnect(WorkerID::FromBinary(queue.worker_id_)); queue.worker_id_.clear(); @@ -337,12 +336,6 @@ void ActorTaskSubmitter::ConnectActor(const ActorID &actor_id, // So new RPCs go out with the right intended worker id to the right address. queue->second.worker_id_ = address.worker_id(); queue->second.client_address_ = address; - NodeID node_id = NodeID::FromBinary(address.node_id()); - - auto node_info = gcs_client_->Nodes().GetNodeAddressAndLiveness(node_id); - RAY_CHECK(node_info != nullptr); - queue->second.raylet_address_ = rpc::RayletClientPool::GenerateRayletAddress( - node_id, node_info->node_manager_address(), node_info->node_manager_port()); SendPendingTasks(actor_id); } @@ -1004,6 +997,7 @@ void ActorTaskSubmitter::CancelTask(TaskSpecification task_spec, bool recursive) // If there's no client, it means actor is not created yet. // Retry in 1 second. + rpc::Address client_address; { absl::MutexLock lock(&mu_); RAY_LOG(DEBUG).WithField(task_id) << "Task was sent to an actor. Send a cancel RPC."; @@ -1013,20 +1007,36 @@ void ActorTaskSubmitter::CancelTask(TaskSpecification task_spec, bool recursive) RetryCancelTask(task_spec, recursive, 1000); return; } + client_address = queue->second.client_address_.value(); + } + + const auto node_id = NodeID::FromBinary(client_address.node_id()); + const auto executor_worker_id = client_address.worker_id(); + + auto do_cancel_local_task = [this, + task_spec = std::move(task_spec), + task_id, + force_kill, + recursive, + executor_worker_id = std::move(executor_worker_id)]( + const rpc::GcsNodeInfo &node_info) mutable { + rpc::Address raylet_address; + raylet_address.set_node_id(node_info.node_id()); + raylet_address.set_ip_address(node_info.node_manager_address()); + raylet_address.set_port(node_info.node_manager_port()); rpc::CancelLocalTaskRequest request; request.set_intended_task_id(task_spec.TaskIdBinary()); request.set_force_kill(force_kill); request.set_recursive(recursive); request.set_caller_worker_id(task_spec.CallerWorkerIdBinary()); - request.set_executor_worker_address( - queue->second.client_address_->SerializeAsString()); - auto raylet_client = - raylet_client_pool_.GetOrConnectByAddress(*queue->second.raylet_address_); + request.set_executor_worker_id(executor_worker_id); + + auto raylet_client = raylet_client_pool_.GetOrConnectByAddress(raylet_address); raylet_client->CancelLocalTask( request, [this, task_spec = std::move(task_spec), recursive, task_id]( - const Status &status, const rpc::CancelLocalTaskReply &reply) { + const Status &status, const rpc::CancelLocalTaskReply &reply) mutable { RAY_LOG(DEBUG).WithField(task_spec.TaskId()) << "CancelTask RPC response received with status " << status.ToString(); @@ -1040,10 +1050,38 @@ void ActorTaskSubmitter::CancelTask(TaskSpecification task_spec, bool recursive) } if (!reply.attempt_succeeded()) { - RetryCancelTask(task_spec, recursive, 2000); + RetryCancelTask(std::move(task_spec), recursive, 2000); } }); + }; + + // Check GCS node cache. If node info is not in the cache, query the GCS instead. 
+ auto *node_info = gcs_client_->Nodes().Get(node_id, /*filter_dead_nodes=*/false); + if (node_info == nullptr) { + gcs_client_->Nodes().AsyncGetAll( + [do_cancel_local_task = std::move(do_cancel_local_task), node_id]( + const Status &status, std::vector<rpc::GcsNodeInfo> &&nodes) mutable { + if (!status.ok()) { + RAY_LOG(INFO) << "Failed to get node info from GCS"; + return; + } + if (nodes.empty() || nodes[0].state() != rpc::GcsNodeInfo::ALIVE) { + RAY_LOG(INFO).WithField(node_id) + << "Not sending CancelLocalTask because node is dead"; + return; + } + do_cancel_local_task(nodes[0]); + }, + -1, + {node_id}); + return; + } + if (node_info->state() == rpc::GcsNodeInfo::DEAD) { + RAY_LOG(INFO).WithField(node_id) + << "Not sending CancelLocalTask because node is dead"; + return; } + do_cancel_local_task(*node_info); } bool ActorTaskSubmitter::QueueGeneratorForResubmit(const TaskSpecification &spec) { diff --git a/src/ray/core_worker/task_submission/actor_task_submitter.h b/src/ray/core_worker/task_submission/actor_task_submitter.h index 9acc98a6ac2c..300ee55d50f9 100644 --- a/src/ray/core_worker/task_submission/actor_task_submitter.h +++ b/src/ray/core_worker/task_submission/actor_task_submitter.h @@ -306,8 +306,6 @@ class ActorTaskSubmitter : public ActorTaskSubmitterInterface { bool preempted_ = false; /// The RPC client address of the actor. std::optional<rpc::Address> client_address_; - /// The local raylet address of the actor. - std::optional<rpc::Address> raylet_address_; /// The intended worker ID of the actor. std::string worker_id_; /// The actor is out of scope but the death info is not published diff --git a/src/ray/core_worker/task_submission/normal_task_submitter.cc b/src/ray/core_worker/task_submission/normal_task_submitter.cc index 77eb3da2c882..8ead50e0aa45 100644 --- a/src/ray/core_worker/task_submission/normal_task_submitter.cc +++ b/src/ray/core_worker/task_submission/normal_task_submitter.cc @@ -179,8 +179,7 @@ void NormalTaskSubmitter::OnWorkerIdle( task_spec.GetMutableMessage().set_lease_grant_timestamp_ms(current_sys_time_ms()); task_spec.EmitTaskMetrics(); - executing_tasks_.emplace(task_spec.TaskId(), - std::make_pair(addr, lease_entry.addr)); + executing_tasks_.emplace(task_spec.TaskId(), addr); PushNormalTask( addr, client, scheduling_key, std::move(task_spec), assigned_resources); } @@ -666,7 +665,6 @@ void NormalTaskSubmitter::CancelTask(TaskSpecification task_spec, SchedulingKey scheduling_key(task_spec.GetSchedulingClass(), task_spec.GetDependencyIds(), task_spec.GetRuntimeEnvHash()); - std::shared_ptr<RayletClientInterface> raylet_client = nullptr; rpc::Address executor_worker_address; { absl::MutexLock lock(&mu_); @@ -723,22 +721,92 @@ void NormalTaskSubmitter::CancelTask(TaskSpecification task_spec, } return; } - // Looks for an RPC handle for the worker executing the task. 
- raylet_client = - raylet_client_pool_->GetOrConnectByAddress(rpc_client_address->second.second); - executor_worker_address = rpc_client_address->second.first; + executor_worker_address = rpc_client_address->second; } - auto request = rpc::CancelLocalTaskRequest(); - request.set_intended_task_id(task_spec.TaskIdBinary()); - request.set_force_kill(force_kill); - request.set_recursive(recursive); - request.set_caller_worker_id(task_spec.CallerWorkerIdBinary()); - request.set_executor_worker_address(executor_worker_address.SerializeAsString()); - raylet_client->CancelLocalTask( - request, - [this, - task_spec = std::move(task_spec), - scheduling_key = std::move(scheduling_key), - force_kill, - recursive](const Status &status, const rpc::CancelLocalTaskReply &reply) mutable { - absl::MutexLock lock(&mu_); - RAY_LOG(DEBUG) << "CancelTask RPC response received for " << task_spec.TaskId() - << " with status " << status.ToString(); - cancelled_tasks_.erase(task_spec.TaskId()); - - // Retry is not attempted if !status.ok() because force-kill may kill the worker - // before the reply is sent. - if (!status.ok()) { - RAY_LOG(DEBUG) << "Failed to cancel a task due to " << status.ToString(); - return; - } - - if (!reply.attempt_succeeded()) { - if (reply.requested_task_running()) { - // Retry cancel request if failed. - if (cancel_retry_timer_.expiry().time_since_epoch() <= - std::chrono::high_resolution_clock::now().time_since_epoch()) { - cancel_retry_timer_.expires_after(boost::asio::chrono::milliseconds( - RayConfig::instance().cancellation_retry_ms())); + const auto node_id = NodeID::FromBinary(executor_worker_address.node_id()); + const auto executor_worker_id = + WorkerID::FromBinary(executor_worker_address.worker_id()); + + auto do_cancel_local_task = [this, + task_spec = std::move(task_spec), + scheduling_key = std::move(scheduling_key), + executor_worker_address, + executor_worker_id, + force_kill, + recursive](const rpc::GcsNodeInfo &node_info) mutable { + rpc::Address raylet_address; + raylet_address.set_node_id(node_info.node_id()); + raylet_address.set_ip_address(node_info.node_manager_address()); + raylet_address.set_port(node_info.node_manager_port()); + + rpc::CancelLocalTaskRequest request; + request.set_intended_task_id(task_spec.TaskIdBinary()); + request.set_force_kill(force_kill); + request.set_recursive(recursive); + request.set_caller_worker_id(task_spec.CallerWorkerIdBinary()); + request.set_executor_worker_id(executor_worker_id.Binary()); + + auto raylet_client = raylet_client_pool_->GetOrConnectByAddress(raylet_address); + raylet_client->CancelLocalTask( + request, + [this, + task_spec = std::move(task_spec), + scheduling_key = std::move(scheduling_key), + force_kill, + recursive](const Status &status, + const rpc::CancelLocalTaskReply &reply) mutable { + absl::MutexLock lock(&mu_); + RAY_LOG(DEBUG) << "CancelTask RPC response received for " << task_spec.TaskId() + << " with status " << status.ToString(); + cancelled_tasks_.erase(task_spec.TaskId()); + + if (!reply.attempt_succeeded()) { + if (reply.requested_task_running()) { + if (cancel_retry_timer_.expiry().time_since_epoch() <= + std::chrono::high_resolution_clock::now().time_since_epoch()) { + cancel_retry_timer_.expires_after(boost::asio::chrono::milliseconds( + RayConfig::instance().cancellation_retry_ms())); + } + cancel_retry_timer_.async_wait(boost::bind(&NormalTaskSubmitter::CancelTask, + this, + std::move(task_spec), + force_kill, + recursive)); + } else { + RAY_LOG(DEBUG) << "Attempt to cancel task " << 
task_spec.TaskId() + << " in a worker that doesn't have this task."; } + } + }); + }; + + auto *node_info = gcs_client_->Nodes().Get(node_id, /*filter_dead_nodes=*/false); + if (node_info == nullptr) { + gcs_client_->Nodes().AsyncGetAll( + [do_cancel_local_task = std::move(do_cancel_local_task), node_id]( + const Status &status, std::vector<rpc::GcsNodeInfo> &&nodes) mutable { + if (!status.ok()) { + RAY_LOG(INFO) << "Failed to get node info from GCS"; + return; + } + if (nodes.empty() || nodes[0].state() != rpc::GcsNodeInfo::ALIVE) { + RAY_LOG(INFO).WithField(node_id) + << "Not sending CancelLocalTask because node is dead"; + return; + } + do_cancel_local_task(nodes[0]); + }, + -1, + {node_id}); + return; + } + if (node_info->state() == rpc::GcsNodeInfo::DEAD) { + RAY_LOG(INFO).WithField(node_id) + << "Not sending CancelLocalTask because node is dead"; + return; + } + do_cancel_local_task(*node_info); } void NormalTaskSubmitter::CancelRemoteTask(const ObjectID &object_id, diff --git a/src/ray/core_worker/task_submission/normal_task_submitter.h b/src/ray/core_worker/task_submission/normal_task_submitter.h index 9ec27f9c45a8..c25aca0a127e 100644 --- a/src/ray/core_worker/task_submission/normal_task_submitter.h +++ b/src/ray/core_worker/task_submission/normal_task_submitter.h @@ -34,6 +34,11 @@ #include "ray/raylet_rpc_client/raylet_client_pool.h" namespace ray { + +namespace gcs { +class GcsClient; +} // namespace gcs + namespace core { // The task queues are keyed on resource shape & function descriptor @@ -85,6 +90,7 @@ class NormalTaskSubmitter { std::shared_ptr<RayletClientInterface> local_raylet_client, std::shared_ptr<rpc::CoreWorkerClientPool> core_worker_client_pool, std::shared_ptr<rpc::RayletClientPool> raylet_client_pool, + std::shared_ptr<gcs::GcsClient> gcs_client, std::unique_ptr<LeasePolicyInterface> lease_policy, std::shared_ptr<CoreWorkerMemoryStore> store, TaskManagerInterface &task_manager, @@ -99,6 +105,7 @@ class NormalTaskSubmitter { : rpc_address_(std::move(rpc_address)), local_raylet_client_(std::move(local_raylet_client)), raylet_client_pool_(std::move(raylet_client_pool)), + gcs_client_(std::move(gcs_client)), lease_policy_(std::move(lease_policy)), resolver_(*store, task_manager, *actor_creator, tensor_transport_getter), task_manager_(task_manager), @@ -242,6 +249,9 @@ class NormalTaskSubmitter { /// Raylet client pool for producing new clients to request leases from remote nodes. std::shared_ptr<rpc::RayletClientPool> raylet_client_pool_; + /// GCS client for checking node liveness. + std::shared_ptr<gcs::GcsClient> gcs_client_; + /// Provider of worker leasing decisions for the first lease request (not on /// spillback). std::unique_ptr<LeasePolicyInterface> lease_policy_; @@ -348,10 +358,7 @@ class NormalTaskSubmitter { absl::flat_hash_set<TaskID> cancelled_tasks_ ABSL_GUARDED_BY(mu_); // Keeps track of where currently executing tasks are being run. - // The first address is the executor, the second address is the local raylet of the - // executor. - absl::flat_hash_map<TaskID, std::pair<rpc::Address, rpc::Address>> executing_tasks_ - ABSL_GUARDED_BY(mu_); + absl::flat_hash_map<TaskID, rpc::Address> executing_tasks_ ABSL_GUARDED_BY(mu_); // Generators that are currently running and need to be resubmitted. 
absl::flat_hash_set<TaskID> generators_to_resubmit_ ABSL_GUARDED_BY(mu_); diff --git a/src/ray/core_worker/task_submission/tests/normal_task_submitter_test.cc b/src/ray/core_worker/task_submission/tests/normal_task_submitter_test.cc index 99c331c5ff7d..62e387284dab 100644 --- a/src/ray/core_worker/task_submission/tests/normal_task_submitter_test.cc +++ b/src/ray/core_worker/task_submission/tests/normal_task_submitter_test.cc @@ -484,6 +484,7 @@ class NormalTaskSubmitterTest : public testing::Test { raylet_client, client_pool, raylet_client_pool, + /*gcs_client=*/nullptr, std::move(lease_policy), store, *task_manager, diff --git a/src/ray/core_worker/tests/core_worker_test.cc b/src/ray/core_worker/tests/core_worker_test.cc index f07d9da02944..69753e4bd64f 100644 --- a/src/ray/core_worker/tests/core_worker_test.cc +++ b/src/ray/core_worker/tests/core_worker_test.cc @@ -212,6 +212,7 @@ class CoreWorkerTest : public ::testing::Test { fake_local_raylet_rpc_client, core_worker_client_pool, raylet_client_pool, + /*gcs_client=*/nullptr, std::move(lease_policy), memory_store_, *task_manager_, diff --git a/src/ray/protobuf/node_manager.proto b/src/ray/protobuf/node_manager.proto index 83f1b85200f6..37c5ee63bf7d 100644 --- a/src/ray/protobuf/node_manager.proto +++ b/src/ray/protobuf/node_manager.proto @@ -415,8 +415,8 @@ message CancelLocalTaskRequest { bool recursive = 3; // The worker ID of the caller. bytes caller_worker_id = 4; - // The worker address of the executor. - bytes executor_worker_address = 5; + // The worker ID of the executor. + bytes executor_worker_id = 5; } diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 68cc634ffe66..ba7803489c03 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -3342,15 +3342,9 @@ std::unique_ptr NodeManager::CreateRuntimeEnvAgentManager( void NodeManager::HandleCancelLocalTask(rpc::CancelLocalTaskRequest request, rpc::CancelLocalTaskReply *reply, rpc::SendReplyCallback send_reply_callback) { - rpc::Address executor_address; - if (!executor_address.ParseFromString(request.executor_worker_address())) { - send_reply_callback( - Status::Invalid("Failed to parse executor worker address"), nullptr, nullptr); - return; - } + auto executor_worker_id = WorkerID::FromBinary(request.executor_worker_id()); - auto worker = worker_pool_.GetRegisteredWorker( - WorkerID::FromBinary(executor_address.worker_id())); + auto worker = worker_pool_.GetRegisteredWorker(executor_worker_id); // If the worker is not registered, then it must have already been killed. if (!worker || worker->IsDead()) { reply->set_attempt_succeeded(true); @@ -3371,14 +3365,9 @@ void NodeManager::HandleCancelLocalTask(rpc::CancelLocalTaskRequest request, cancel_task_request, [reply, send_reply_callback](const Status &status, const rpc::CancelTaskReply &cancel_task_reply) { - if (!status.ok()) { - send_reply_callback( - Status::Invalid("Failed to cancel task"), nullptr, nullptr); - } else { - reply->set_attempt_succeeded(cancel_task_reply.attempt_succeeded()); - reply->set_requested_task_running(cancel_task_reply.requested_task_running()); - send_reply_callback(Status::OK(), nullptr, nullptr); - } + reply->set_attempt_succeeded(cancel_task_reply.attempt_succeeded()); + reply->set_requested_task_running(cancel_task_reply.requested_task_running()); + send_reply_callback(Status::OK(), nullptr, nullptr); }); return; } From 8a2e428dd91c982c4a45a05d48699718fbcf9982 Mon Sep 17 00:00:00 2001 From: joshlee Date: Thu, 23 Oct 2025 
22:07:40 +0000 Subject: [PATCH 05/18] Fix broken cpp tests Signed-off-by: joshlee --- .../core_worker/task_submission/BUILD.bazel | 2 + .../task_submission/actor_task_submitter.cc | 1 + .../task_submission/tests/BUILD.bazel | 2 + .../tests/actor_task_submitter_test.cc | 11 +++ .../tests/direct_actor_transport_test.cc | 8 ++ .../tests/normal_task_submitter_test.cc | 85 ++++++++++++++++--- src/ray/core_worker/tests/core_worker_test.cc | 4 +- src/ray/protobuf/core_worker.proto | 4 +- 8 files changed, 100 insertions(+), 17 deletions(-) diff --git a/src/ray/core_worker/task_submission/BUILD.bazel b/src/ray/core_worker/task_submission/BUILD.bazel index 8d6ac2dc9a1f..3934174949ef 100644 --- a/src/ray/core_worker/task_submission/BUILD.bazel +++ b/src/ray/core_worker/task_submission/BUILD.bazel @@ -99,7 +99,9 @@ ray_cc_library( "//src/ray/core_worker:memory_store", "//src/ray/core_worker:task_manager_interface", "//src/ray/core_worker_rpc_client:core_worker_client_pool", + "//src/ray/gcs_rpc_client:gcs_client", "//src/ray/raylet_rpc_client:raylet_client_interface", + "//src/ray/raylet_rpc_client:raylet_client_pool", "//src/ray/util:time", "@com_google_absl//absl/base:core_headers", ], diff --git a/src/ray/core_worker/task_submission/actor_task_submitter.cc b/src/ray/core_worker/task_submission/actor_task_submitter.cc index e4d3413997ba..cac3304594ed 100644 --- a/src/ray/core_worker/task_submission/actor_task_submitter.cc +++ b/src/ray/core_worker/task_submission/actor_task_submitter.cc @@ -336,6 +336,7 @@ void ActorTaskSubmitter::ConnectActor(const ActorID &actor_id, // So new RPCs go out with the right intended worker id to the right address. queue->second.worker_id_ = address.worker_id(); queue->second.client_address_ = address; + SendPendingTasks(actor_id); } diff --git a/src/ray/core_worker/task_submission/tests/BUILD.bazel b/src/ray/core_worker/task_submission/tests/BUILD.bazel index e00d9cdf4714..e41d110ff5ea 100644 --- a/src/ray/core_worker/task_submission/tests/BUILD.bazel +++ b/src/ray/core_worker/task_submission/tests/BUILD.bazel @@ -35,6 +35,7 @@ ray_cc_test( deps = [ "//:ray_mock", "//src/ray/core_worker/task_submission:actor_task_submitter", + "//src/ray/raylet_rpc_client:raylet_client_pool", "@com_google_googletest//:gtest", "@com_google_googletest//:gtest_main", ], @@ -54,6 +55,7 @@ ray_cc_test( "//src/ray/core_worker:reference_counter", "//src/ray/core_worker:task_manager", "//src/ray/core_worker_rpc_client:fake_core_worker_client", + "//src/ray/raylet_rpc_client:raylet_client_pool", "@com_google_googletest//:gtest", "@com_google_googletest//:gtest_main", ], diff --git a/src/ray/core_worker/task_submission/tests/actor_task_submitter_test.cc b/src/ray/core_worker/task_submission/tests/actor_task_submitter_test.cc index e1536ef89785..de5708ee59fd 100644 --- a/src/ray/core_worker/task_submission/tests/actor_task_submitter_test.cc +++ b/src/ray/core_worker/task_submission/tests/actor_task_submitter_test.cc @@ -22,9 +22,11 @@ #include "gtest/gtest.h" #include "mock/ray/core_worker/reference_counter.h" #include "mock/ray/core_worker/task_manager_interface.h" +#include "mock/ray/gcs_client/gcs_client.h" #include "ray/common/test_utils.h" #include "ray/core_worker/fake_actor_creator.h" #include "ray/core_worker_rpc_client/fake_core_worker_client.h" +#include "ray/raylet_rpc_client/raylet_client_pool.h" namespace ray::core { @@ -88,13 +90,20 @@ class ActorTaskSubmitterTest : public ::testing::TestWithParam { ActorTaskSubmitterTest() : client_pool_(std::make_shared<rpc::CoreWorkerClientPool>( [&](const rpc::Address 
&addr) { return worker_client_; })), + raylet_client_pool_(std::make_shared<rpc::RayletClientPool>( + [](const rpc::Address &) -> std::shared_ptr<RayletClientInterface> { + return nullptr; + })), worker_client_(std::make_shared()), store_(std::make_shared<CoreWorkerMemoryStore>(io_context)), task_manager_(std::make_shared()), + mock_gcs_client_(std::make_shared<gcs::MockGcsClient>()), io_work(io_context.get_executor()), reference_counter_(std::make_shared()), submitter_( *client_pool_, + *raylet_client_pool_, + mock_gcs_client_, *store_, *task_manager_, actor_creator_, @@ -110,9 +119,11 @@ class ActorTaskSubmitterTest : public ::testing::TestWithParam { int64_t last_queue_warning_ = 0; FakeActorCreator actor_creator_; std::shared_ptr<rpc::CoreWorkerClientPool> client_pool_; + std::shared_ptr<rpc::RayletClientPool> raylet_client_pool_; std::shared_ptr worker_client_; std::shared_ptr<CoreWorkerMemoryStore> store_; std::shared_ptr task_manager_; + std::shared_ptr<gcs::MockGcsClient> mock_gcs_client_; instrumented_io_context io_context; boost::asio::executor_work_guard<boost::asio::io_context::executor_type> io_work; std::shared_ptr reference_counter_; diff --git a/src/ray/core_worker/task_submission/tests/direct_actor_transport_test.cc b/src/ray/core_worker/task_submission/tests/direct_actor_transport_test.cc index 75e1a8034180..10c0c4512755 100644 --- a/src/ray/core_worker/task_submission/tests/direct_actor_transport_test.cc +++ b/src/ray/core_worker/task_submission/tests/direct_actor_transport_test.cc @@ -22,6 +22,7 @@ #include "mock/ray/gcs_client/gcs_client.h" #include "ray/core_worker/actor_creator.h" #include "ray/core_worker/task_submission/actor_task_submitter.h" +#include "ray/raylet_rpc_client/raylet_client_pool.h" namespace ray { namespace core { @@ -38,10 +39,16 @@ class DirectTaskTransportTest : public ::testing::Test { task_manager = std::make_shared(); client_pool = std::make_shared<rpc::CoreWorkerClientPool>( [&](const rpc::Address &) { return nullptr; }); + raylet_client_pool = std::make_shared<rpc::RayletClientPool>( + [](const rpc::Address &) -> std::shared_ptr<RayletClientInterface> { + return nullptr; + }); memory_store = DefaultCoreWorkerMemoryStoreWithThread::Create(); reference_counter = std::make_shared(); actor_task_submitter = std::make_unique<ActorTaskSubmitter>( *client_pool, + *raylet_client_pool, + gcs_client, *memory_store, *task_manager, *actor_creator, @@ -81,6 +88,7 @@ class DirectTaskTransportTest : public ::testing::Test { boost::asio::executor_work_guard<boost::asio::io_context::executor_type> io_work; std::unique_ptr<ActorTaskSubmitter> actor_task_submitter; std::shared_ptr<rpc::CoreWorkerClientPool> client_pool; + std::shared_ptr<rpc::RayletClientPool> raylet_client_pool; std::unique_ptr memory_store; std::shared_ptr task_manager; std::unique_ptr actor_creator; diff --git a/src/ray/core_worker/task_submission/tests/normal_task_submitter_test.cc b/src/ray/core_worker/task_submission/tests/normal_task_submitter_test.cc index 62e387284dab..bd1ffe542d7a 100644 --- a/src/ray/core_worker/task_submission/tests/normal_task_submitter_test.cc +++ b/src/ray/core_worker/task_submission/tests/normal_task_submitter_test.cc @@ -25,6 +25,7 @@ #include "gtest/gtest.h" #include "mock/ray/core_worker/memory_store.h" #include "mock/ray/core_worker/task_manager_interface.h" +#include "mock/ray/gcs_client/gcs_client.h" #include "ray/common/task/task_spec.h" #include "ray/common/task/task_util.h" #include "ray/common/test_utils.h" @@ -123,12 +124,6 @@ class MockWorkerClient : public rpc::FakeCoreWorkerClient { return true; } - void CancelTask(const rpc::CancelTaskRequest &request, - const rpc::ClientCallback<rpc::CancelTaskReply> &callback) override { - kill_requests.push_front(request); - cancel_callbacks.push_back(callback); - } - void ReplyCancelTask(Status status = Status::OK(), bool attempt_succeeded = true, bool requested_task_running = false) { @@ -386,6 +381,24 @@ class MockRayletClient : public 
rpc::FakeRayletClient { return GenericPopCallbackInLock(cancel_callbacks); } + void CancelLocalTask( + const rpc::CancelLocalTaskRequest &request, + const rpc::ClientCallback<rpc::CancelLocalTaskReply> &callback) override { + cancel_local_task_requests.push_front(request); + cancel_local_task_callbacks.push_back(callback); + } + + void ReplyCancelLocalTask(Status status = Status::OK(), + bool attempt_succeeded = true, + bool requested_task_running = false) { + auto &callback = cancel_local_task_callbacks.front(); + rpc::CancelLocalTaskReply reply; + reply.set_attempt_succeeded(attempt_succeeded); + reply.set_requested_task_running(requested_task_running); + callback(status, std::move(reply)); + cancel_local_task_callbacks.pop_front(); + } + ~MockRayletClient() = default; // Protects all internal fields. @@ -404,6 +417,9 @@ class MockRayletClient : public rpc::FakeRayletClient { std::list<rpc::ClientCallback<rpc::CancelTaskReply>> cancel_callbacks = {}; std::list<rpc::ClientCallback<rpc::GetWorkerFailureCauseReply>> get_task_failure_cause_callbacks = {}; + std::list<rpc::CancelLocalTaskRequest> cancel_local_task_requests = {}; + std::list<rpc::ClientCallback<rpc::CancelLocalTaskReply>> cancel_local_task_callbacks = + {}; }; class MockLeasePolicy : public LeasePolicyInterface { @@ -449,7 +465,8 @@ class NormalTaskSubmitterTest : public testing::Test { task_manager(std::make_unique()), actor_creator(std::make_shared()), lease_policy(std::make_unique<MockLeasePolicy>()), - lease_policy_ptr(lease_policy.get()) { + lease_policy_ptr(lease_policy.get()), + mock_gcs_client_(std::make_shared<gcs::MockGcsClient>()) { address.set_node_id(local_node_id.Binary()); lease_policy_ptr->SetNodeID(local_node_id); } @@ -484,7 +501,7 @@ class NormalTaskSubmitterTest : public testing::Test { raylet_client, client_pool, raylet_client_pool, - /*gcs_client=*/nullptr, + mock_gcs_client_, std::move(lease_policy), store, *task_manager, @@ -511,6 +528,7 @@ class NormalTaskSubmitterTest : public testing::Test { // the submitter. std::unique_ptr<MockLeasePolicy> lease_policy; MockLeasePolicy *lease_policy_ptr = nullptr; + std::shared_ptr<gcs::MockGcsClient> mock_gcs_client_; instrumented_io_context io_context; }; @@ -644,6 +662,20 @@ TEST_F(NormalTaskSubmitterTest, TestCancellationWhileHandlingTaskFailure) { // the task cancellation races between ReplyPushTask and ReplyGetWorkerFailureCause. 
// For an example of a python integration test, see // https://github.com/ray-project/ray/blob/2b6807f4d9c4572e6309f57bc404aa641bc4b185/python/ray/tests/test_cancel.py#L35 + + // Set up GCS node mock to return node as alive + using testing::_; + using testing::Return; + + rpc::GcsNodeInfo node_info; + node_info.set_node_id(local_node_id.Binary()); + node_info.set_node_manager_address("127.0.0.1"); + node_info.set_node_manager_port(9999); + node_info.set_state(rpc::GcsNodeInfo::ALIVE); + + EXPECT_CALL(*mock_gcs_client_->mock_node_accessor, Get(_, false)) + .WillRepeatedly(Return(&node_info)); + auto submitter = CreateNormalTaskSubmitter(std::make_shared(1)); @@ -1442,12 +1474,14 @@ void TestSchedulingKey(const std::shared_ptr store, auto actor_creator = std::make_shared(); auto lease_policy = std::make_unique(); lease_policy->SetNodeID(local_node_id); + auto mock_gcs_client = std::make_shared(); instrumented_io_context io_context; NormalTaskSubmitter submitter( address, raylet_client, client_pool, raylet_client_pool, + mock_gcs_client, std::move(lease_policy), store, *task_manager, @@ -1711,6 +1745,16 @@ TEST_F(NormalTaskSubmitterTest, TestWorkerLeaseTimeout) { } TEST_F(NormalTaskSubmitterTest, TestKillExecutingTask) { + rpc::GcsNodeInfo node_info; + node_info.set_node_id(local_node_id.Binary()); + node_info.set_node_manager_address("127.0.0.1"); + node_info.set_node_manager_port(9999); + node_info.set_state(rpc::GcsNodeInfo::ALIVE); + + EXPECT_CALL(*mock_gcs_client_->mock_node_accessor, Get(local_node_id, false)) + .WillOnce(testing::Return(&node_info)) + .WillOnce(testing::Return(&node_info)); + auto submitter = CreateNormalTaskSubmitter(std::make_shared(1)); TaskSpecification task = BuildEmptyTaskSpec(); @@ -1720,7 +1764,8 @@ TEST_F(NormalTaskSubmitterTest, TestKillExecutingTask) { // Try force kill, exiting the worker submitter.CancelTask(task, true, false); - ASSERT_EQ(worker_client->kill_requests.front().intended_task_id(), task.TaskIdBinary()); + ASSERT_EQ(raylet_client->cancel_local_task_requests.front().intended_task_id(), + task.TaskIdBinary()); ASSERT_TRUE(worker_client->ReplyPushTask(Status::IOError("workerdying"), true)); ASSERT_TRUE(raylet_client->ReplyGetWorkerFailureCause()); ASSERT_EQ(worker_client->callbacks.size(), 0); @@ -1738,7 +1783,8 @@ TEST_F(NormalTaskSubmitterTest, TestKillExecutingTask) { // Try non-force kill, worker returns normally submitter.CancelTask(task, false, false); ASSERT_TRUE(worker_client->ReplyPushTask()); - ASSERT_EQ(worker_client->kill_requests.front().intended_task_id(), task.TaskIdBinary()); + ASSERT_EQ(raylet_client->cancel_local_task_requests.front().intended_task_id(), + task.TaskIdBinary()); ASSERT_EQ(worker_client->callbacks.size(), 0); ASSERT_EQ(raylet_client->num_workers_returned, 1); ASSERT_EQ(raylet_client->num_workers_returned_exiting, 0); @@ -1818,6 +1864,17 @@ TEST_F(NormalTaskSubmitterTest, TestQueueGeneratorForResubmit) { TEST_F(NormalTaskSubmitterTest, TestCancelBeforeAfterQueueGeneratorForResubmit) { // Cancel -> failed queue generator for resubmit -> cancel reply -> successful queue for // resubmit -> push task reply -> honor the cancel not the queued resubmit. 
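+  // The GCS node mock below is needed because CancelTask now looks up the
+  // executor's raylet address in the GCS node cache and sends CancelLocalTask
+  // to that raylet, instead of sending a CancelTask RPC straight to the worker.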
+
+  rpc::GcsNodeInfo node_info;
+  node_info.set_node_id(local_node_id.Binary());
+  node_info.set_node_manager_address("127.0.0.1");
+  node_info.set_node_manager_port(9999);
+  node_info.set_state(rpc::GcsNodeInfo::ALIVE);
+
+  EXPECT_CALL(*mock_gcs_client_->mock_node_accessor, Get(local_node_id, false))
+      .WillOnce(testing::Return(&node_info))
+      .WillOnce(testing::Return(&node_info));
+
   auto submitter = CreateNormalTaskSubmitter(std::make_shared(1));
   TaskSpecification task = BuildEmptyTaskSpec();
@@ -1825,7 +1882,7 @@ TEST_F(NormalTaskSubmitterTest, TestCancelBeforeAfterQueueGeneratorForResubmit
   ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1234, local_node_id));
   submitter.CancelTask(task, /*force_kill=*/false, /*recursive=*/true);
   ASSERT_FALSE(submitter.QueueGeneratorForResubmit(task));
-  worker_client->ReplyCancelTask();
+  raylet_client->ReplyCancelLocalTask();
   ASSERT_TRUE(submitter.QueueGeneratorForResubmit(task));
   ASSERT_TRUE(worker_client->ReplyPushTask(Status::OK(),
                                            /*exit=*/false,
@@ -1843,9 +1900,9 @@ TEST_F(NormalTaskSubmitterTest, TestCancelBeforeAfterQueueGeneratorForResubmit
   ASSERT_TRUE(submitter.QueueGeneratorForResubmit(task2));
   submitter.CancelTask(task2, /*force_kill=*/false, /*recursive=*/true);
   ASSERT_TRUE(worker_client->ReplyPushTask());
-  worker_client->ReplyCancelTask(Status::OK(),
-                                 /*attempt_succeeded=*/true,
-                                 /*requested_task_running=*/false);
+  raylet_client->ReplyCancelLocalTask(Status::OK(),
+                                      /*attempt_succeeded=*/true,
+                                      /*requested_task_running=*/false);
   ASSERT_EQ(task_manager->num_tasks_complete, 1);
   ASSERT_EQ(task_manager->num_tasks_failed, 1);
   ASSERT_EQ(task_manager->num_generator_failed_and_resubmitted, 0);
diff --git a/src/ray/core_worker/tests/core_worker_test.cc b/src/ray/core_worker/tests/core_worker_test.cc
index 69753e4bd64f..3bdbaba792ba 100644
--- a/src/ray/core_worker/tests/core_worker_test.cc
+++ b/src/ray/core_worker/tests/core_worker_test.cc
@@ -212,7 +212,7 @@ class CoreWorkerTest : public ::testing::Test {
         fake_local_raylet_rpc_client,
         core_worker_client_pool,
         raylet_client_pool,
-        /*gcs_client=*/nullptr,
+        mock_gcs_client_,
         std::move(lease_policy),
         memory_store_,
         *task_manager_,
@@ -227,6 +227,8 @@
     auto actor_task_submitter = std::make_unique(
         *core_worker_client_pool,
+        *raylet_client_pool,
+        mock_gcs_client_,
         *memory_store_,
         *task_manager_,
         *actor_creator_,
diff --git a/src/ray/protobuf/core_worker.proto b/src/ray/protobuf/core_worker.proto
index aa0c05ce21f5..6590620eedc6 100644
--- a/src/ray/protobuf/core_worker.proto
+++ b/src/ray/protobuf/core_worker.proto
@@ -530,8 +530,8 @@ service CoreWorkerService {
   rpc KillActor(KillActorRequest) returns (KillActorReply);
 
   // Request from local raylet to executor worker to cancel a task.
-  // Failure: Idempotent, does not retry for network failures. However requests
-  // should only be sent via CancelLocalTask from the raylet which does implement retries
+  // Failure: Idempotent, does not retry. However, requests should only be sent
+  // via CancelLocalTask from the raylet, which does implement retries.
   rpc CancelTask(CancelTaskRequest) returns (CancelTaskReply);
 
   // Request from a worker to the owner worker to issue a cancellation.
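The comment above pins down the failure contract this series is built on: CancelTask
(raylet -> executor worker) is idempotent but never retried by its caller, because the
retries live one hop up in CancelLocalTask (owner -> raylet). A minimal, self-contained
sketch of the owner-side shape of that contract follows; every name in it is a
hypothetical stand-in, not the actual Ray API:

    #include <functional>
    #include <iostream>
    #include <string>

    struct CancelReply {
      bool attempt_succeeded;
      bool requested_task_running;
    };
    using CancelCallback = std::function<void(bool rpc_ok, const CancelReply &)>;

    // Owner -> raylet hop: retryable because the raylet-side handler is
    // idempotent, so a duplicate delivery after a dropped reply is harmless.
    void CancelLocalTask(const std::string &task_id, const CancelCallback &done) {
      // ... a retrying transport layer would sit here ...
      done(/*rpc_ok=*/true, CancelReply{false, true});
    }

    void OwnerCancel(const std::string &task_id) {
      CancelLocalTask(task_id, [task_id](bool rpc_ok, const CancelReply &reply) {
        if (!rpc_ok) {
          // The retryable client only surfaces an error once the target node
          // is dead; node-death handling fails the task, so no retry here.
          return;
        }
        if (!reply.attempt_succeeded && reply.requested_task_running) {
          std::cout << "task " << task_id << " still running, retry later\n";
        }
      });
    }

    int main() { OwnerCancel("task-1"); }

Keeping the retry loop in exactly one hop is what makes the end-to-end path safe:
only the idempotent RPC can ever be delivered twice.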
From 7d4ab2e2bff872d81a21324c5b6474e1bb644548 Mon Sep 17 00:00:00 2001 From: joshlee Date: Fri, 7 Nov 2025 23:34:35 +0000 Subject: [PATCH 06/18] Clean up Signed-off-by: joshlee --- .../tests/test_core_worker_fault_tolerance.py | 2 +- src/mock/ray/core_worker/core_worker.h | 6 +++--- src/mock/ray/rpc/worker/core_worker_client.h | 6 +++--- src/ray/core_worker/core_worker.cc | 12 +++++++----- src/ray/core_worker/core_worker.h | 6 +++--- src/ray/core_worker/core_worker_rpc_proxy.h | 2 +- src/ray/core_worker/grpc_service.cc | 2 +- src/ray/core_worker/grpc_service.h | 6 +++--- .../task_submission/normal_task_submitter.cc | 14 +++++++------- .../task_submission/normal_task_submitter.h | 8 ++++---- .../core_worker_client.h | 2 +- .../core_worker_client_interface.h | 6 +++--- .../fake_core_worker_client.h | 5 +++-- src/ray/protobuf/core_worker.proto | 6 +++--- src/ray/raylet/node_manager.cc | 18 ++++++++++++------ 15 files changed, 55 insertions(+), 46 deletions(-) diff --git a/python/ray/tests/test_core_worker_fault_tolerance.py b/python/ray/tests/test_core_worker_fault_tolerance.py index e35cad21b67b..9681bb8c5443 100644 --- a/python/ray/tests/test_core_worker_fault_tolerance.py +++ b/python/ray/tests/test_core_worker_fault_tolerance.py @@ -170,7 +170,7 @@ def inject_cancel_remote_task_rpc_failure(monkeypatch, request): deterministic_failure = request.param monkeypatch.setenv( "RAY_testing_rpc_failure", - "CoreWorkerService.grpc_client.CancelRemoteTask=1:" + "CoreWorkerService.grpc_client.RequestOwnerToCancelTask=1:" + ("100:0" if deterministic_failure == "request" else "0:100"), ) diff --git a/src/mock/ray/core_worker/core_worker.h b/src/mock/ray/core_worker/core_worker.h index 563d7f3d3f6c..1ba83cb01a3d 100644 --- a/src/mock/ray/core_worker/core_worker.h +++ b/src/mock/ray/core_worker/core_worker.h @@ -89,9 +89,9 @@ class MockCoreWorker : public CoreWorker { rpc::SendReplyCallback send_reply_callback), (override)); MOCK_METHOD(void, - HandleCancelRemoteTask, - (rpc::CancelRemoteTaskRequest request, - rpc::CancelRemoteTaskReply *reply, + HandleRequestOwnerToCancelTask, + (rpc::RequestOwnerToCancelTaskRequest request, + rpc::RequestOwnerToCancelTaskReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); MOCK_METHOD(void, diff --git a/src/mock/ray/rpc/worker/core_worker_client.h b/src/mock/ray/rpc/worker/core_worker_client.h index cd293cebbd93..e537a35f7255 100644 --- a/src/mock/ray/rpc/worker/core_worker_client.h +++ b/src/mock/ray/rpc/worker/core_worker_client.h @@ -87,9 +87,9 @@ class MockCoreWorkerClientInterface : public CoreWorkerClientInterface { const ClientCallback &callback), (override)); MOCK_METHOD(void, - CancelRemoteTask, - (CancelRemoteTaskRequest && request, - const ClientCallback &callback), + RequestOwnerToCancelTask, + (RequestOwnerToCancelTaskRequest && request, + const ClientCallback &callback), (override)); MOCK_METHOD(void, GetCoreWorkerStats, diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 1421a32ffed6..90fe0e886859 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -2439,12 +2439,13 @@ Status CoreWorker::CancelTask(const ObjectID &object_id, } if (obj_addr.SerializeAsString() != rpc_address_.SerializeAsString()) { - // We don't have CancelRemoteTask for actor_task_submitter_ + // We don't have RequestOwnerToCancelTask for actor_task_submitter_ // because it requires the same implementation. 
RAY_LOG(DEBUG).WithField(object_id) << "Request to cancel a task of object to an owner " << obj_addr.SerializeAsString(); - normal_task_submitter_->CancelRemoteTask(object_id, obj_addr, force_kill, recursive); + normal_task_submitter_->RequestOwnerToCancelTask( + object_id, obj_addr, force_kill, recursive); return Status::OK(); } @@ -3878,9 +3879,10 @@ void CoreWorker::ProcessSubscribeForRefRemoved( reference_counter_->SubscribeRefRemoved(object_id, contained_in_id, owner_address); } -void CoreWorker::HandleCancelRemoteTask(rpc::CancelRemoteTaskRequest request, - rpc::CancelRemoteTaskReply *reply, - rpc::SendReplyCallback send_reply_callback) { +void CoreWorker::HandleRequestOwnerToCancelTask( + rpc::RequestOwnerToCancelTaskRequest request, + rpc::RequestOwnerToCancelTaskReply *reply, + rpc::SendReplyCallback send_reply_callback) { auto status = CancelTask(ObjectID::FromBinary(request.remote_object_id()), request.force_kill(), request.recursive()); diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index e2463a3e7076..2065524131fd 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -1208,9 +1208,9 @@ class CoreWorker { rpc::SendReplyCallback send_reply_callback); /// Implements gRPC server handler. - void HandleCancelRemoteTask(rpc::CancelRemoteTaskRequest request, - rpc::CancelRemoteTaskReply *reply, - rpc::SendReplyCallback send_reply_callback); + void HandleRequestOwnerToCancelTask(rpc::RequestOwnerToCancelTaskRequest request, + rpc::RequestOwnerToCancelTaskReply *reply, + rpc::SendReplyCallback send_reply_callback); /// Implements gRPC server handler. void HandlePlasmaObjectReady(rpc::PlasmaObjectReadyRequest request, diff --git a/src/ray/core_worker/core_worker_rpc_proxy.h b/src/ray/core_worker/core_worker_rpc_proxy.h index 48865d81b7b9..808da2cde210 100644 --- a/src/ray/core_worker/core_worker_rpc_proxy.h +++ b/src/ray/core_worker/core_worker_rpc_proxy.h @@ -56,7 +56,7 @@ class CoreWorkerServiceHandlerProxy : public rpc::CoreWorkerServiceHandler { RAY_CORE_WORKER_RPC_PROXY(ReportGeneratorItemReturns) RAY_CORE_WORKER_RPC_PROXY(KillActor) RAY_CORE_WORKER_RPC_PROXY(CancelTask) - RAY_CORE_WORKER_RPC_PROXY(CancelRemoteTask) + RAY_CORE_WORKER_RPC_PROXY(RequestOwnerToCancelTask) RAY_CORE_WORKER_RPC_PROXY(RegisterMutableObjectReader) RAY_CORE_WORKER_RPC_PROXY(GetCoreWorkerStats) RAY_CORE_WORKER_RPC_PROXY(LocalGC) diff --git a/src/ray/core_worker/grpc_service.cc b/src/ray/core_worker/grpc_service.cc index adb5b62786d4..b0e4d9ce37d7 100644 --- a/src/ray/core_worker/grpc_service.cc +++ b/src/ray/core_worker/grpc_service.cc @@ -77,7 +77,7 @@ void CoreWorkerGrpcService::InitServerCallFactories( max_active_rpcs_per_handler_, ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, - CancelRemoteTask, + RequestOwnerToCancelTask, max_active_rpcs_per_handler_, ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, diff --git a/src/ray/core_worker/grpc_service.h b/src/ray/core_worker/grpc_service.h index d605f5176533..b90349f6f009 100644 --- a/src/ray/core_worker/grpc_service.h +++ b/src/ray/core_worker/grpc_service.h @@ -96,9 +96,9 @@ class CoreWorkerServiceHandler : public DelayedServiceHandler { CancelTaskReply *reply, SendReplyCallback send_reply_callback) = 0; - virtual void HandleCancelRemoteTask(CancelRemoteTaskRequest request, - CancelRemoteTaskReply *reply, - SendReplyCallback send_reply_callback) = 0; + virtual void 
HandleRequestOwnerToCancelTask(RequestOwnerToCancelTaskRequest request,
+                                              RequestOwnerToCancelTaskReply *reply,
+                                              SendReplyCallback send_reply_callback) = 0;
 
   virtual void HandleRegisterMutableObjectReader(
       RegisterMutableObjectReaderRequest request,
diff --git a/src/ray/core_worker/task_submission/normal_task_submitter.cc b/src/ray/core_worker/task_submission/normal_task_submitter.cc
index 026100f665b4..b72f111c51e4 100644
--- a/src/ray/core_worker/task_submission/normal_task_submitter.cc
+++ b/src/ray/core_worker/task_submission/normal_task_submitter.cc
@@ -809,18 +809,18 @@ void NormalTaskSubmitter::CancelTask(TaskSpecification task_spec,
   do_cancel_local_task(*node_info);
 }
 
-void NormalTaskSubmitter::CancelRemoteTask(const ObjectID &object_id,
-                                           const rpc::Address &worker_addr,
-                                           bool force_kill,
-                                           bool recursive) {
+void NormalTaskSubmitter::RequestOwnerToCancelTask(const ObjectID &object_id,
+                                                   const rpc::Address &worker_addr,
+                                                   bool force_kill,
+                                                   bool recursive) {
   auto client = core_worker_client_pool_->GetOrConnect(worker_addr);
-  auto request = rpc::CancelRemoteTaskRequest();
+  auto request = rpc::RequestOwnerToCancelTaskRequest();
   request.set_force_kill(force_kill);
   request.set_recursive(recursive);
   request.set_remote_object_id(object_id.Binary());
-  client->CancelRemoteTask(
+  client->RequestOwnerToCancelTask(
       std::move(request),
-      [](const Status &status, const rpc::CancelRemoteTaskReply &reply) {
+      [](const Status &status, const rpc::RequestOwnerToCancelTaskReply &reply) {
         if (!status.ok()) {
           RAY_LOG(ERROR) << "Failed to cancel remote task: " << status.ToString();
         }
diff --git a/src/ray/core_worker/task_submission/normal_task_submitter.h b/src/ray/core_worker/task_submission/normal_task_submitter.h
index f03a3309f61c..c0600db332c0 100644
--- a/src/ray/core_worker/task_submission/normal_task_submitter.h
+++ b/src/ray/core_worker/task_submission/normal_task_submitter.h
@@ -133,10 +133,10 @@ class NormalTaskSubmitter {
   /// It is used when an object ID is not owned by the current process.
   /// We cannot cancel the task in this case because we don't have enough
   /// information to cancel a task.
-  void CancelRemoteTask(const ObjectID &object_id,
-                        const rpc::Address &worker_addr,
-                        bool force_kill,
-                        bool recursive);
+  void RequestOwnerToCancelTask(const ObjectID &object_id,
+                                const rpc::Address &worker_addr,
+                                bool force_kill,
+                                bool recursive);
 
   /// Queue the streaming generator up for resubmission.
/// \return true if the task is still executing and the submitter agrees to resubmit
diff --git a/src/ray/core_worker_rpc_client/core_worker_client.h b/src/ray/core_worker_rpc_client/core_worker_client.h
index b9fa6b2ea71f..be6c84698283 100644
--- a/src/ray/core_worker_rpc_client/core_worker_client.h
+++ b/src/ray/core_worker_rpc_client/core_worker_client.h
@@ -86,7 +86,7 @@ class CoreWorkerClient : public std::enable_shared_from_this<CoreWorkerClient>,
 
   VOID_RETRYABLE_RPC_CLIENT_METHOD(retryable_grpc_client_,
                                    CoreWorkerService,
-                                   CancelRemoteTask,
+                                   RequestOwnerToCancelTask,
                                    grpc_client_,
                                    /*method_timeout_ms*/ -1,
                                    override)
diff --git a/src/ray/core_worker_rpc_client/core_worker_client_interface.h b/src/ray/core_worker_rpc_client/core_worker_client_interface.h
index 80c37f709d34..e4c7f2f97cb1 100644
--- a/src/ray/core_worker_rpc_client/core_worker_client_interface.h
+++ b/src/ray/core_worker_rpc_client/core_worker_client_interface.h
@@ -76,9 +76,9 @@ class CoreWorkerClientInterface : public pubsub::SubscriberClientInterface {
   virtual void CancelTask(const CancelTaskRequest &request,
                           const ClientCallback<CancelTaskReply> &callback) = 0;
 
-  virtual void CancelRemoteTask(
-      CancelRemoteTaskRequest &&request,
-      const ClientCallback<CancelRemoteTaskReply> &callback) = 0;
+  virtual void RequestOwnerToCancelTask(
+      RequestOwnerToCancelTaskRequest &&request,
+      const ClientCallback<RequestOwnerToCancelTaskReply> &callback) = 0;
 
   virtual void RegisterMutableObjectReader(
       const RegisterMutableObjectReaderRequest &request,
diff --git a/src/ray/core_worker_rpc_client/fake_core_worker_client.h b/src/ray/core_worker_rpc_client/fake_core_worker_client.h
index 368cda8e8628..41ac5546f44d 100644
--- a/src/ray/core_worker_rpc_client/fake_core_worker_client.h
+++ b/src/ray/core_worker_rpc_client/fake_core_worker_client.h
@@ -82,8 +82,9 @@ class FakeCoreWorkerClient : public CoreWorkerClientInterface {
   void CancelTask(const CancelTaskRequest &request,
                   const ClientCallback<CancelTaskReply> &callback) override {}
 
-  void CancelRemoteTask(CancelRemoteTaskRequest &&request,
-                        const ClientCallback<CancelRemoteTaskReply> &callback) override {}
+  void RequestOwnerToCancelTask(
+      RequestOwnerToCancelTaskRequest &&request,
+      const ClientCallback<RequestOwnerToCancelTaskReply> &callback) override {}
 
   void RegisterMutableObjectReader(
       const RegisterMutableObjectReaderRequest &request,
diff --git a/src/ray/protobuf/core_worker.proto b/src/ray/protobuf/core_worker.proto
index eb915aff2d05..fae2024705c2 100644
--- a/src/ray/protobuf/core_worker.proto
+++ b/src/ray/protobuf/core_worker.proto
@@ -293,7 +293,7 @@ message CancelTaskReply {
   bool attempt_succeeded = 2;
 }
 
-message CancelRemoteTaskRequest {
+message RequestOwnerToCancelTaskRequest {
   // Object ID of the remote task that should be killed.
   bytes remote_object_id = 1;
   // Whether to kill the worker.
@@ -302,7 +302,7 @@
   bool recursive = 3;
 }
 
-message CancelRemoteTaskReply {}
+message RequestOwnerToCancelTaskReply {}
 
 message GetCoreWorkerStatsRequest {
   // The ID of the worker this message is intended for.
@@ -536,7 +536,7 @@ service CoreWorkerService {
 
   // Request from a worker to the owner worker to issue a cancellation.
   // Failure: Retries, it's idempotent.
-  rpc CancelRemoteTask(CancelRemoteTaskRequest) returns (CancelRemoteTaskReply);
+  rpc RequestOwnerToCancelTask(RequestOwnerToCancelTaskRequest) returns (RequestOwnerToCancelTaskReply);
 
   // From raylet to get metrics from its workers.
   // Failure: Should not fail, always from local raylet.
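The "Retries, it's idempotent" promise above holds because the owner keys cancellation
on the immutable object ID of the task's result: a duplicate RequestOwnerToCancelTask
delivery finds the task already cancelled (or already finished) and degenerates into a
no-op. A self-contained sketch of that property, with hypothetical names:

    #include <iostream>
    #include <string>
    #include <unordered_set>

    class OwnerSketch {
     public:
      // Safe to invoke any number of times for the same object ID, which is
      // what lets the client side wrap this RPC in a retrying stub.
      void RequestOwnerToCancelTask(const std::string &remote_object_id) {
        if (!cancelled_.insert(remote_object_id).second) {
          return;  // duplicate from a client retry: nothing left to do
        }
        std::cout << "cancelling task producing " << remote_object_id << "\n";
      }

     private:
      std::unordered_set<std::string> cancelled_;
    };

    int main() {
      OwnerSketch owner;
      owner.RequestOwnerToCancelTask("obj-1");
      owner.RequestOwnerToCancelTask("obj-1");  // retried after a lost reply
    }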
diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc
index 3d07ca3a1d67..c8f0d991c6ad 100644
--- a/src/ray/raylet/node_manager.cc
+++ b/src/ray/raylet/node_manager.cc
@@ -3472,9 +3472,11 @@ void NodeManager::HandleCancelLocalTask(rpc::CancelLocalTaskRequest request,
         });
     return;
   }
+  std::shared_ptr<bool> replied = std::make_shared<bool>(false);
+
   auto timer = execute_after(
       io_service_,
-      [this, reply, send_reply_callback, worker_id]() {
+      [this, reply, send_reply_callback, worker_id, replied]() {
         auto current_worker = worker_pool_.GetRegisteredWorker(worker_id);
         if (current_worker) {
           // If the worker is still alive, force kill it
@@ -3487,6 +3489,8 @@ void NodeManager::HandleCancelLocalTask(rpc::CancelLocalTaskRequest request,
               "Force-killed by ray.cancel(force=True)",
               /*force=*/true);
         }
+
+        *replied = true;
         reply->set_attempt_succeeded(true);
         reply->set_requested_task_running(false);
         send_reply_callback(Status::OK(), nullptr, nullptr);
@@ -3496,7 +3500,7 @@ void NodeManager::HandleCancelLocalTask(rpc::CancelLocalTaskRequest request,
 
   worker->rpc_client()->CancelTask(
       cancel_task_request,
-      [task_id = request.intended_task_id(), timer, reply, send_reply_callback](
+      [task_id = request.intended_task_id(), timer, reply, send_reply_callback, replied](
           const ray::Status &status, const rpc::CancelTaskReply &cancel_task_reply) {
         if (!status.ok()) {
           RAY_LOG(DEBUG) << "CancelTask RPC failed for task " << task_id << ": "
                          << status.ToString();
           // NOTE: We'll escalate the graceful shutdown to SIGKILL which is done by the
           // timer above
           return;
         }
-        reply->set_attempt_succeeded(cancel_task_reply.attempt_succeeded());
-        reply->set_requested_task_running(cancel_task_reply.requested_task_running());
-        send_reply_callback(Status::OK(), nullptr, nullptr);
-        timer->cancel();
+        if (!*replied) {
+          reply->set_attempt_succeeded(cancel_task_reply.attempt_succeeded());
+          reply->set_requested_task_running(cancel_task_reply.requested_task_running());
+          send_reply_callback(Status::OK(), nullptr, nullptr);
+          timer->cancel();
+        }
       });
 }
 

From 9070db534b59501834b339234f71fc3562ff1600 Mon Sep 17 00:00:00 2001
From: joshlee
Date: Fri, 7 Nov 2025 23:56:26 +0000
Subject: [PATCH 07/18] lint

Signed-off-by: joshlee

---
 src/ray/protobuf/core_worker.proto | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/ray/protobuf/core_worker.proto b/src/ray/protobuf/core_worker.proto
index fae2024705c2..2654871004f5 100644
--- a/src/ray/protobuf/core_worker.proto
+++ b/src/ray/protobuf/core_worker.proto
@@ -536,7 +536,8 @@ service CoreWorkerService {
 
   // Request from a worker to the owner worker to issue a cancellation.
   // Failure: Retries, it's idempotent.
-  rpc RequestOwnerToCancelTask(RequestOwnerToCancelTaskRequest) returns (RequestOwnerToCancelTaskReply);
+  rpc RequestOwnerToCancelTask(RequestOwnerToCancelTaskRequest)
+      returns (RequestOwnerToCancelTaskReply);
 
   // From raylet to get metrics from its workers.
   // Failure: Should not fail, always from local raylet.
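The replied flag added to the node manager above is the heart of making CancelLocalTask
safe to retry: the handler arms both a force-kill timer and a CancelTask RPC to the
worker, and exactly one of the two may answer the gRPC. A self-contained sketch of that
single-reply pattern follows (hypothetical names; a plain bool behind a shared_ptr is
enough because, as in the raylet, both paths run on a single io_service thread):

    #include <functional>
    #include <iostream>
    #include <memory>

    using SendReply = std::function<void(const char *who)>;

    void ArmBothPaths(const SendReply &send_reply) {
      auto replied = std::make_shared<bool>(false);

      auto on_timer = [replied, send_reply]() {
        if (*replied) return;  // the worker's reply already won the race
        *replied = true;
        send_reply("timer, after escalating to a force kill");
      };
      auto on_rpc_reply = [replied, send_reply](bool ok) {
        // On RPC failure, fall through silently and let the timer escalate.
        if (*replied || !ok) return;
        *replied = true;
        send_reply("worker's CancelTask reply");
      };

      on_rpc_reply(true);  // whichever callback runs first sends the reply
      on_timer();          // the loser observes *replied == true and returns
    }

    int main() {
      ArmBothPaths([](const char *who) { std::cout << "replied by " << who << "\n"; });
    }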
From dcec398ee586617e83df54c5efd1a100231ebaff Mon Sep 17 00:00:00 2001 From: joshlee Date: Wed, 12 Nov 2025 21:57:14 +0000 Subject: [PATCH 08/18] Addressing comments Signed-off-by: joshlee --- .../task_submission/actor_task_submitter.cc | 97 +++++++------- .../task_submission/normal_task_submitter.cc | 122 +++++++++--------- .../tests/normal_task_submitter_test.cc | 18 ++- 3 files changed, 119 insertions(+), 118 deletions(-) diff --git a/src/ray/core_worker/task_submission/actor_task_submitter.cc b/src/ray/core_worker/task_submission/actor_task_submitter.cc index cac3304594ed..c596f0372bc0 100644 --- a/src/ray/core_worker/task_submission/actor_task_submitter.cc +++ b/src/ray/core_worker/task_submission/actor_task_submitter.cc @@ -998,7 +998,8 @@ void ActorTaskSubmitter::CancelTask(TaskSpecification task_spec, bool recursive) // If there's no client, it means actor is not created yet. // Retry in 1 second. - rpc::Address client_address; + NodeID node_id; + std::string executor_worker_id; { absl::MutexLock lock(&mu_); RAY_LOG(DEBUG).WithField(task_id) << "Task was sent to an actor. Send a cancel RPC."; @@ -1008,60 +1009,60 @@ void ActorTaskSubmitter::CancelTask(TaskSpecification task_spec, bool recursive) RetryCancelTask(task_spec, recursive, 1000); return; } - client_address = queue->second.client_address_.value(); + node_id = NodeID::FromBinary(queue->second.client_address_.value().node_id()); + executor_worker_id = queue->second.client_address_.value().worker_id(); } - const auto node_id = NodeID::FromBinary(client_address.node_id()); - const auto executor_worker_id = client_address.worker_id(); - - auto do_cancel_local_task = [this, - task_spec = std::move(task_spec), - task_id, - force_kill, - recursive, - executor_worker_id = std::move(executor_worker_id)]( - const rpc::GcsNodeInfo &node_info) mutable { - rpc::Address raylet_address; - raylet_address.set_node_id(node_info.node_id()); - raylet_address.set_ip_address(node_info.node_manager_address()); - raylet_address.set_port(node_info.node_manager_port()); - - rpc::CancelLocalTaskRequest request; - request.set_intended_task_id(task_spec.TaskIdBinary()); - request.set_force_kill(force_kill); - request.set_recursive(recursive); - request.set_caller_worker_id(task_spec.CallerWorkerIdBinary()); - request.set_executor_worker_id(executor_worker_id); - - auto raylet_client = raylet_client_pool_.GetOrConnectByAddress(raylet_address); - raylet_client->CancelLocalTask( - request, - [this, task_spec = std::move(task_spec), recursive, task_id]( - const Status &status, const rpc::CancelLocalTaskReply &reply) mutable { - RAY_LOG(DEBUG).WithField(task_spec.TaskId()) - << "CancelTask RPC response received with status " << status.ToString(); - - // Keep retrying every 2 seconds until a task is officially - // finished. - if (!task_manager_.GetTaskSpec(task_id)) { - // Task is already finished. - RAY_LOG(DEBUG).WithField(task_spec.TaskId()) - << "Task is finished. 
Stop a cancel request."; - return; - } + auto do_cancel_local_task = + [this, + task_spec = std::move(task_spec), + task_id, + force_kill, + recursive, + executor_worker_id](const rpc::GcsNodeAddressAndLiveness &node_info) mutable { + rpc::Address raylet_address; + raylet_address.set_node_id(node_info.node_id()); + raylet_address.set_ip_address(node_info.node_manager_address()); + raylet_address.set_port(node_info.node_manager_port()); + + rpc::CancelLocalTaskRequest request; + request.set_intended_task_id(task_spec.TaskIdBinary()); + request.set_force_kill(force_kill); + request.set_recursive(recursive); + request.set_caller_worker_id(task_spec.CallerWorkerIdBinary()); + request.set_executor_worker_id(executor_worker_id); + + auto raylet_client = raylet_client_pool_.GetOrConnectByAddress(raylet_address); + raylet_client->CancelLocalTask( + request, + [this, task_spec = std::move(task_spec), recursive, task_id]( + const Status &status, const rpc::CancelLocalTaskReply &reply) mutable { + RAY_LOG(DEBUG).WithField(task_spec.TaskId()) + << "CancelTask RPC response received with status " << status.ToString(); + + // Keep retrying every 2 seconds until a task is officially + // finished. + if (!task_manager_.GetTaskSpec(task_id)) { + // Task is already finished. + RAY_LOG(DEBUG).WithField(task_spec.TaskId()) + << "Task is finished. Stop a cancel request."; + return; + } - if (!reply.attempt_succeeded()) { - RetryCancelTask(std::move(task_spec), recursive, 2000); - } - }); - }; + if (!reply.attempt_succeeded()) { + RetryCancelTask(std::move(task_spec), recursive, 2000); + } + }); + }; // Check GCS node cache. If node info is not in the cache, query the GCS instead. - auto *node_info = gcs_client_->Nodes().Get(node_id, /*filter_dead_nodes=*/false); + auto *node_info = gcs_client_->Nodes().GetNodeAddressAndLiveness( + node_id, /*filter_dead_nodes=*/false); if (node_info == nullptr) { - gcs_client_->Nodes().AsyncGetAll( + gcs_client_->Nodes().AsyncGetAllNodeAddressAndLiveness( [do_cancel_local_task = std::move(do_cancel_local_task), node_id]( - const Status &status, std::vector &&nodes) mutable { + const Status &status, + std::vector &&nodes) mutable { if (!status.ok()) { RAY_LOG(INFO) << "Failed to get node info from GCS"; return; diff --git a/src/ray/core_worker/task_submission/normal_task_submitter.cc b/src/ray/core_worker/task_submission/normal_task_submitter.cc index b72f111c51e4..b322a08a80b3 100644 --- a/src/ray/core_worker/task_submission/normal_task_submitter.cc +++ b/src/ray/core_worker/task_submission/normal_task_submitter.cc @@ -665,7 +665,8 @@ void NormalTaskSubmitter::CancelTask(TaskSpecification task_spec, SchedulingKey scheduling_key(task_spec.GetSchedulingClass(), task_spec.GetDependencyIds(), task_spec.GetRuntimeEnvHash()); - rpc::Address executor_worker_address; + NodeID node_id; + std::string executor_worker_id; { absl::MutexLock lock(&mu_); generators_to_resubmit_.erase(task_id); @@ -721,71 +722,72 @@ void NormalTaskSubmitter::CancelTask(TaskSpecification task_spec, } return; } - executor_worker_address = rpc_client_address->second; + node_id = NodeID::FromBinary(rpc_client_address->second.node_id()); + executor_worker_id = rpc_client_address->second.worker_id(); } - const auto node_id = NodeID::FromBinary(executor_worker_address.node_id()); - const auto executor_worker_id = - WorkerID::FromBinary(executor_worker_address.worker_id()); - - auto do_cancel_local_task = [this, - task_spec = std::move(task_spec), - scheduling_key = std::move(scheduling_key), - 
executor_worker_address, - executor_worker_id, - force_kill, - recursive](const rpc::GcsNodeInfo &node_info) mutable { - rpc::Address raylet_address; - raylet_address.set_node_id(node_info.node_id()); - raylet_address.set_ip_address(node_info.node_manager_address()); - raylet_address.set_port(node_info.node_manager_port()); - - rpc::CancelLocalTaskRequest request; - request.set_intended_task_id(task_spec.TaskIdBinary()); - request.set_force_kill(force_kill); - request.set_recursive(recursive); - request.set_caller_worker_id(task_spec.CallerWorkerIdBinary()); - request.set_executor_worker_id(executor_worker_id.Binary()); - - auto raylet_client = raylet_client_pool_->GetOrConnectByAddress(raylet_address); - raylet_client->CancelLocalTask( - request, - [this, - task_spec = std::move(task_spec), - scheduling_key = std::move(scheduling_key), - force_kill, - recursive](const Status &status, - const rpc::CancelLocalTaskReply &reply) mutable { - absl::MutexLock lock(&mu_); - RAY_LOG(DEBUG) << "CancelTask RPC response received for " << task_spec.TaskId() - << " with status " << status.ToString(); - cancelled_tasks_.erase(task_spec.TaskId()); - - if (!reply.attempt_succeeded()) { - if (reply.requested_task_running()) { - if (cancel_retry_timer_.expiry().time_since_epoch() <= - std::chrono::high_resolution_clock::now().time_since_epoch()) { - cancel_retry_timer_.expires_after(boost::asio::chrono::milliseconds( - RayConfig::instance().cancellation_retry_ms())); + auto do_cancel_local_task = + [this, + task_spec = std::move(task_spec), + scheduling_key = std::move(scheduling_key), + executor_worker_id, + force_kill, + recursive](const rpc::GcsNodeAddressAndLiveness &node_info) mutable { + rpc::Address raylet_address; + raylet_address.set_node_id(node_info.node_id()); + raylet_address.set_ip_address(node_info.node_manager_address()); + raylet_address.set_port(node_info.node_manager_port()); + + rpc::CancelLocalTaskRequest request; + request.set_intended_task_id(task_spec.TaskIdBinary()); + request.set_force_kill(force_kill); + request.set_recursive(recursive); + request.set_caller_worker_id(task_spec.CallerWorkerIdBinary()); + request.set_executor_worker_id(executor_worker_id); + + auto raylet_client = raylet_client_pool_->GetOrConnectByAddress(raylet_address); + raylet_client->CancelLocalTask( + request, + [this, + task_spec = std::move(task_spec), + scheduling_key = std::move(scheduling_key), + force_kill, + recursive](const Status &status, + const rpc::CancelLocalTaskReply &reply) mutable { + absl::MutexLock lock(&mu_); + RAY_LOG(DEBUG) << "CancelTask RPC response received for " + << task_spec.TaskId() << " with status " + << status.ToString(); + cancelled_tasks_.erase(task_spec.TaskId()); + + if (!reply.attempt_succeeded()) { + if (reply.requested_task_running()) { + if (cancel_retry_timer_.expiry().time_since_epoch() <= + std::chrono::high_resolution_clock::now().time_since_epoch()) { + cancel_retry_timer_.expires_after(boost::asio::chrono::milliseconds( + RayConfig::instance().cancellation_retry_ms())); + } + cancel_retry_timer_.async_wait( + boost::bind(&NormalTaskSubmitter::CancelTask, + this, + std::move(task_spec), + force_kill, + recursive)); + } else { + RAY_LOG(DEBUG) << "Attempt to cancel task " << task_spec.TaskId() + << " in a worker that doesn't have this task."; + } } - cancel_retry_timer_.async_wait(boost::bind(&NormalTaskSubmitter::CancelTask, - this, - std::move(task_spec), - force_kill, - recursive)); - } else { - RAY_LOG(DEBUG) << "Attempt to cancel task " << task_spec.TaskId() - 
<< " in a worker that doesn't have this task."; - } - } - }); - }; + }); + }; - auto *node_info = gcs_client_->Nodes().Get(node_id, /*filter_dead_nodes=*/false); + auto *node_info = gcs_client_->Nodes().GetNodeAddressAndLiveness( + node_id, /*filter_dead_nodes=*/false); if (node_info == nullptr) { - gcs_client_->Nodes().AsyncGetAll( + gcs_client_->Nodes().AsyncGetAllNodeAddressAndLiveness( [do_cancel_local_task = std::move(do_cancel_local_task), node_id]( - const Status &status, std::vector &&nodes) mutable { + const Status &status, + std::vector &&nodes) mutable { if (!status.ok()) { RAY_LOG(INFO) << "Failed to get node info from GCS"; return; diff --git a/src/ray/core_worker/task_submission/tests/normal_task_submitter_test.cc b/src/ray/core_worker/task_submission/tests/normal_task_submitter_test.cc index e656c8ccb6e3..a68a4bae7fb4 100644 --- a/src/ray/core_worker/task_submission/tests/normal_task_submitter_test.cc +++ b/src/ray/core_worker/task_submission/tests/normal_task_submitter_test.cc @@ -385,7 +385,7 @@ class MockRayletClient : public rpc::FakeRayletClient { void CancelLocalTask( const rpc::CancelLocalTaskRequest &request, const rpc::ClientCallback &callback) override { - cancel_local_task_requests.push_front(request); + cancel_local_task_requests.push_back(request); cancel_local_task_callbacks.push_back(callback); } @@ -414,13 +414,12 @@ class MockRayletClient : public rpc::FakeRayletClient { int num_get_task_failure_causes = 0; int reported_backlog_size = 0; std::map reported_backlogs; - std::list> callbacks = {}; - std::list> cancel_callbacks = {}; + std::list> callbacks; + std::list> cancel_callbacks; std::list> - get_task_failure_cause_callbacks = {}; - std::list cancel_local_task_requests = {}; - std::list> cancel_local_task_callbacks = - {}; + get_task_failure_cause_callbacks; + std::list cancel_local_task_requests; + std::list> cancel_local_task_callbacks; }; class MockLeasePolicy : public LeasePolicyInterface { @@ -668,7 +667,6 @@ TEST_F(NormalTaskSubmitterTest, TestCancellationWhileHandlingTaskFailure) { // Set up GCS node mock to return node as alive using testing::_; - using testing::Return; rpc::GcsNodeInfo node_info; node_info.set_node_id(local_node_id.Binary()); @@ -677,7 +675,7 @@ TEST_F(NormalTaskSubmitterTest, TestCancellationWhileHandlingTaskFailure) { node_info.set_state(rpc::GcsNodeInfo::ALIVE); EXPECT_CALL(*mock_gcs_client_->mock_node_accessor, Get(_, false)) - .WillRepeatedly(Return(&node_info)); + .WillRepeatedly(testing::Return(&node_info)); auto submitter = CreateNormalTaskSubmitter(std::make_shared(1)); @@ -1788,7 +1786,7 @@ TEST_F(NormalTaskSubmitterTest, TestKillExecutingTask) { // Try non-force kill, worker returns normally submitter.CancelTask(task, false, false); ASSERT_TRUE(worker_client->ReplyPushTask()); - ASSERT_EQ(raylet_client->cancel_local_task_requests.front().intended_task_id(), + ASSERT_EQ(raylet_client->cancel_local_task_requests.back().intended_task_id(), task.TaskIdBinary()); ASSERT_EQ(worker_client->callbacks.size(), 0); ASSERT_EQ(raylet_client->num_workers_returned, 1); From 9df37aa3a63b61a3d513ada90b910094e60e108a Mon Sep 17 00:00:00 2001 From: joshlee Date: Wed, 12 Nov 2025 22:11:26 +0000 Subject: [PATCH 09/18] Fix cpp test failures Signed-off-by: joshlee --- .../tests/normal_task_submitter_test.cc | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/ray/core_worker/task_submission/tests/normal_task_submitter_test.cc b/src/ray/core_worker/task_submission/tests/normal_task_submitter_test.cc index 
a68a4bae7fb4..905294bac05e 100644 --- a/src/ray/core_worker/task_submission/tests/normal_task_submitter_test.cc +++ b/src/ray/core_worker/task_submission/tests/normal_task_submitter_test.cc @@ -668,13 +668,13 @@ TEST_F(NormalTaskSubmitterTest, TestCancellationWhileHandlingTaskFailure) { // Set up GCS node mock to return node as alive using testing::_; - rpc::GcsNodeInfo node_info; + rpc::GcsNodeAddressAndLiveness node_info; node_info.set_node_id(local_node_id.Binary()); node_info.set_node_manager_address("127.0.0.1"); node_info.set_node_manager_port(9999); node_info.set_state(rpc::GcsNodeInfo::ALIVE); - EXPECT_CALL(*mock_gcs_client_->mock_node_accessor, Get(_, false)) + EXPECT_CALL(*mock_gcs_client_->mock_node_accessor, GetNodeAddressAndLiveness(_, false)) .WillRepeatedly(testing::Return(&node_info)); auto submitter = @@ -1748,13 +1748,14 @@ TEST_F(NormalTaskSubmitterTest, TestWorkerLeaseTimeout) { } TEST_F(NormalTaskSubmitterTest, TestKillExecutingTask) { - rpc::GcsNodeInfo node_info; + rpc::GcsNodeAddressAndLiveness node_info; node_info.set_node_id(local_node_id.Binary()); node_info.set_node_manager_address("127.0.0.1"); node_info.set_node_manager_port(9999); node_info.set_state(rpc::GcsNodeInfo::ALIVE); - EXPECT_CALL(*mock_gcs_client_->mock_node_accessor, Get(local_node_id, false)) + EXPECT_CALL(*mock_gcs_client_->mock_node_accessor, + GetNodeAddressAndLiveness(local_node_id, false)) .WillOnce(testing::Return(&node_info)) .WillOnce(testing::Return(&node_info)); @@ -1868,13 +1869,14 @@ TEST_F(NormalTaskSubmitterTest, TestCancelBeforeAfterQueueGeneratorForResubmit) // Cancel -> failed queue generator for resubmit -> cancel reply -> successful queue for // resubmit -> push task reply -> honor the cancel not the queued resubmit. - rpc::GcsNodeInfo node_info; + rpc::GcsNodeAddressAndLiveness node_info; node_info.set_node_id(local_node_id.Binary()); node_info.set_node_manager_address("127.0.0.1"); node_info.set_node_manager_port(9999); node_info.set_state(rpc::GcsNodeInfo::ALIVE); - EXPECT_CALL(*mock_gcs_client_->mock_node_accessor, Get(local_node_id, false)) + EXPECT_CALL(*mock_gcs_client_->mock_node_accessor, + GetNodeAddressAndLiveness(local_node_id, false)) .WillOnce(testing::Return(&node_info)) .WillOnce(testing::Return(&node_info)); From d846b90024642a84f98259fc30ebc11bdf59845e Mon Sep 17 00:00:00 2001 From: joshlee Date: Thu, 13 Nov 2025 22:32:15 +0000 Subject: [PATCH 10/18] Addressing comments Signed-off-by: joshlee --- .../task_submission/actor_task_submitter.cc | 15 +++++++-- .../task_submission/normal_task_submitter.cc | 16 +++++++--- src/ray/raylet/node_manager.cc | 31 +++++++++++++------ 3 files changed, 45 insertions(+), 17 deletions(-) diff --git a/src/ray/core_worker/task_submission/actor_task_submitter.cc b/src/ray/core_worker/task_submission/actor_task_submitter.cc index c596f0372bc0..d4c4aa84bb18 100644 --- a/src/ray/core_worker/task_submission/actor_task_submitter.cc +++ b/src/ray/core_worker/task_submission/actor_task_submitter.cc @@ -1037,9 +1037,18 @@ void ActorTaskSubmitter::CancelTask(TaskSpecification task_spec, bool recursive) request, [this, task_spec = std::move(task_spec), recursive, task_id]( const Status &status, const rpc::CancelLocalTaskReply &reply) mutable { - RAY_LOG(DEBUG).WithField(task_spec.TaskId()) - << "CancelTask RPC response received with status " << status.ToString(); - + if (!status.ok()) { + RAY_LOG(DEBUG) << "CancelLocalTask RPC failed for task " + << task_spec.TaskId() << ": " << status.ToString() + << " due to node death"; + return; + } else 
{ + RAY_LOG(DEBUG) << "CancelLocalTask RPC response received for " + << task_spec.TaskId() + << " with attempt_succeeded: " << reply.attempt_succeeded() + << " requested_task_running: " + << reply.requested_task_running(); + } // Keep retrying every 2 seconds until a task is officially // finished. if (!task_manager_.GetTaskSpec(task_id)) { diff --git a/src/ray/core_worker/task_submission/normal_task_submitter.cc b/src/ray/core_worker/task_submission/normal_task_submitter.cc index b322a08a80b3..e01adfb3b310 100644 --- a/src/ray/core_worker/task_submission/normal_task_submitter.cc +++ b/src/ray/core_worker/task_submission/normal_task_submitter.cc @@ -755,11 +755,19 @@ void NormalTaskSubmitter::CancelTask(TaskSpecification task_spec, recursive](const Status &status, const rpc::CancelLocalTaskReply &reply) mutable { absl::MutexLock lock(&mu_); - RAY_LOG(DEBUG) << "CancelTask RPC response received for " - << task_spec.TaskId() << " with status " - << status.ToString(); cancelled_tasks_.erase(task_spec.TaskId()); - + if (!status.ok()) { + RAY_LOG(DEBUG) << "CancelLocalTask RPC failed for task " + << task_spec.TaskId() << ": " << status.ToString() + << " due to node death"; + return; + } else { + RAY_LOG(DEBUG) << "CancelLocalTask RPC response received for " + << task_spec.TaskId() + << " with attempt_succeeded: " << reply.attempt_succeeded() + << " requested_task_running: " + << reply.requested_task_running(); + } if (!reply.attempt_succeeded()) { if (reply.requested_task_running()) { if (cancel_retry_timer_.expiry().time_since_epoch() <= diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index c8f0d991c6ad..ac81efc33f10 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -3400,6 +3400,9 @@ void NodeManager::HandleKillLocalActor(rpc::KillLocalActorRequest request, auto timer = execute_after( io_service_, [this, send_reply_callback, worker_id, replied]() { + if (*replied) { + return; + } auto current_worker = worker_pool_.GetRegisteredWorker(worker_id); if (current_worker) { // If the worker is still alive, force kill it @@ -3425,12 +3428,16 @@ void NodeManager::HandleKillLocalActor(rpc::KillLocalActorRequest request, timer, send_reply_callback, replied](const ray::Status &status, const rpc::KillActorReply &) { - if (!status.ok() && !*replied) { + if (*replied) { + return; + } + if (!status.ok()) { std::ostringstream stream; stream << "KillActor RPC failed for actor " << actor_id << ": " << status.ToString(); const auto &msg = stream.str(); RAY_LOG(DEBUG) << msg; + *replied = true; timer->cancel(); send_reply_callback(Status::Invalid(msg), nullptr, nullptr); } @@ -3478,6 +3485,9 @@ void NodeManager::HandleCancelLocalTask(rpc::CancelLocalTaskRequest request, io_service_, [this, reply, send_reply_callback, worker_id, replied]() { auto current_worker = worker_pool_.GetRegisteredWorker(worker_id); + if (*replied) { + return; + } if (current_worker) { // If the worker is still alive, force kill it RAY_LOG(INFO) << "Worker with PID=" << current_worker->GetProcess().GetId() @@ -3489,7 +3499,6 @@ void NodeManager::HandleCancelLocalTask(rpc::CancelLocalTaskRequest request, "Force-killed by ray.cancel(force=True)", /*force=*/true); } - *replied = true; reply->set_attempt_succeeded(true); reply->set_requested_task_running(false); @@ -3502,19 +3511,21 @@ void NodeManager::HandleCancelLocalTask(rpc::CancelLocalTaskRequest request, cancel_task_request, [task_id = request.intended_task_id(), timer, reply, send_reply_callback, replied]( const ray::Status 
&status, const rpc::CancelTaskReply &cancel_task_reply) { + if (*replied) { + return; + } if (!status.ok()) { - RAY_LOG(DEBUG) << "CancelTask RPC failed for task " << task_id << ": " - << status.ToString(); + RAY_LOG(DEBUG) << "CancelTask RPC failed for task " + << TaskID::FromBinary(task_id) << ": " << status.ToString(); // NOTE: We'll escalate the graceful shutdown to SIGKILL which is done by the // timer above return; } - if (!*replied) { - reply->set_attempt_succeeded(cancel_task_reply.attempt_succeeded()); - reply->set_requested_task_running(cancel_task_reply.requested_task_running()); - send_reply_callback(Status::OK(), nullptr, nullptr); - timer->cancel(); - } + *replied = true; + reply->set_attempt_succeeded(cancel_task_reply.attempt_succeeded()); + reply->set_requested_task_running(cancel_task_reply.requested_task_running()); + send_reply_callback(Status::OK(), nullptr, nullptr); + timer->cancel(); }); } From 873a17c82ce4acf14f6d27385cb71c32f2865b32 Mon Sep 17 00:00:00 2001 From: joshlee Date: Thu, 13 Nov 2025 22:33:29 +0000 Subject: [PATCH 11/18] Addressing comments Signed-off-by: joshlee --- .../core_worker/task_submission/actor_task_submitter.cc | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/ray/core_worker/task_submission/actor_task_submitter.cc b/src/ray/core_worker/task_submission/actor_task_submitter.cc index d4c4aa84bb18..e57448f03cb0 100644 --- a/src/ray/core_worker/task_submission/actor_task_submitter.cc +++ b/src/ray/core_worker/task_submission/actor_task_submitter.cc @@ -1051,13 +1051,6 @@ void ActorTaskSubmitter::CancelTask(TaskSpecification task_spec, bool recursive) } // Keep retrying every 2 seconds until a task is officially // finished. - if (!task_manager_.GetTaskSpec(task_id)) { - // Task is already finished. - RAY_LOG(DEBUG).WithField(task_spec.TaskId()) - << "Task is finished. 
Stop a cancel request."; - return; - } - if (!reply.attempt_succeeded()) { RetryCancelTask(std::move(task_spec), recursive, 2000); } From a253c81cf403e5e747125335c8d00ae0bccc29ab Mon Sep 17 00:00:00 2001 From: joshlee Date: Fri, 14 Nov 2025 08:54:39 +0000 Subject: [PATCH 12/18] fix build error Signed-off-by: joshlee --- src/ray/core_worker/task_submission/actor_task_submitter.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ray/core_worker/task_submission/actor_task_submitter.cc b/src/ray/core_worker/task_submission/actor_task_submitter.cc index e57448f03cb0..a1c243e6fdd8 100644 --- a/src/ray/core_worker/task_submission/actor_task_submitter.cc +++ b/src/ray/core_worker/task_submission/actor_task_submitter.cc @@ -1035,7 +1035,7 @@ void ActorTaskSubmitter::CancelTask(TaskSpecification task_spec, bool recursive) auto raylet_client = raylet_client_pool_.GetOrConnectByAddress(raylet_address); raylet_client->CancelLocalTask( request, - [this, task_spec = std::move(task_spec), recursive, task_id]( + [this, task_spec = std::move(task_spec), recursive]( const Status &status, const rpc::CancelLocalTaskReply &reply) mutable { if (!status.ok()) { RAY_LOG(DEBUG) << "CancelLocalTask RPC failed for task " From 9d5cf6f83a9f63f0fcf744320dc19cc618e0023d Mon Sep 17 00:00:00 2001 From: joshlee Date: Fri, 14 Nov 2025 19:36:58 +0000 Subject: [PATCH 13/18] Addressing comments Signed-off-by: joshlee --- .../core_worker/task_submission/actor_task_submitter.cc | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/ray/core_worker/task_submission/actor_task_submitter.cc b/src/ray/core_worker/task_submission/actor_task_submitter.cc index a1c243e6fdd8..cd3e62a2ee03 100644 --- a/src/ray/core_worker/task_submission/actor_task_submitter.cc +++ b/src/ray/core_worker/task_submission/actor_task_submitter.cc @@ -1014,12 +1014,8 @@ void ActorTaskSubmitter::CancelTask(TaskSpecification task_spec, bool recursive) } auto do_cancel_local_task = - [this, - task_spec = std::move(task_spec), - task_id, - force_kill, - recursive, - executor_worker_id](const rpc::GcsNodeAddressAndLiveness &node_info) mutable { + [this, task_spec = std::move(task_spec), force_kill, recursive, executor_worker_id]( + const rpc::GcsNodeAddressAndLiveness &node_info) mutable { rpc::Address raylet_address; raylet_address.set_node_id(node_info.node_id()); raylet_address.set_ip_address(node_info.node_manager_address()); From c8e0ed6aa764cd2c802e35e6f2468c5f84039067 Mon Sep 17 00:00:00 2001 From: joshlee Date: Tue, 18 Nov 2025 01:32:13 +0000 Subject: [PATCH 14/18] Addressing comments Signed-off-by: joshlee --- .../task_submission/actor_task_submitter.cc | 65 ++++++++++--------- .../task_submission/normal_task_submitter.cc | 63 ++++++++++-------- 2 files changed, 71 insertions(+), 57 deletions(-) diff --git a/src/ray/core_worker/task_submission/actor_task_submitter.cc b/src/ray/core_worker/task_submission/actor_task_submitter.cc index cd3e62a2ee03..616385e238ee 100644 --- a/src/ray/core_worker/task_submission/actor_task_submitter.cc +++ b/src/ray/core_worker/task_submission/actor_task_submitter.cc @@ -1053,35 +1053,42 @@ void ActorTaskSubmitter::CancelTask(TaskSpecification task_spec, bool recursive) }); }; - // Check GCS node cache. If node info is not in the cache, query the GCS instead. 
- auto *node_info = gcs_client_->Nodes().GetNodeAddressAndLiveness( - node_id, /*filter_dead_nodes=*/false); - if (node_info == nullptr) { - gcs_client_->Nodes().AsyncGetAllNodeAddressAndLiveness( - [do_cancel_local_task = std::move(do_cancel_local_task), node_id]( - const Status &status, - std::vector &&nodes) mutable { - if (!status.ok()) { - RAY_LOG(INFO) << "Failed to get node info from GCS"; - return; - } - if (nodes.empty() || nodes[0].state() != rpc::GcsNodeInfo::ALIVE) { - RAY_LOG(INFO).WithField(node_id) - << "Not sending CancelLocalTask because node is dead"; - return; - } - do_cancel_local_task(nodes[0]); - }, - -1, - {node_id}); - return; - } - if (node_info->state() == rpc::GcsNodeInfo::DEAD) { - RAY_LOG(INFO).WithField(node_id) - << "Not sending CancelLocalTask because node is dead"; - return; - } - do_cancel_local_task(*node_info); + // Cancel can execute on the user's python thread, but the GCS node cache is updated on + // the io service thread and is not thread-safe. Hence we need to post the entire + // cache access to the io service thread. + io_service_.post( + [this, do_cancel_local_task = std::move(do_cancel_local_task), node_id]() mutable { + // Check GCS node cache. If node info is not in the cache, query the GCS instead. + auto *node_info = gcs_client_->Nodes().GetNodeAddressAndLiveness( + node_id, /*filter_dead_nodes=*/false); + if (node_info == nullptr) { + gcs_client_->Nodes().AsyncGetAllNodeAddressAndLiveness( + [do_cancel_local_task = std::move(do_cancel_local_task), node_id]( + const Status &status, + std::vector &&nodes) mutable { + if (!status.ok()) { + RAY_LOG(INFO) << "Failed to get node info from GCS"; + return; + } + if (nodes.empty() || nodes[0].state() != rpc::GcsNodeInfo::ALIVE) { + RAY_LOG(INFO).WithField(node_id) + << "Not sending CancelLocalTask because node is dead"; + return; + } + do_cancel_local_task(nodes[0]); + }, + -1, + {node_id}); + return; + } + if (node_info->state() == rpc::GcsNodeInfo::DEAD) { + RAY_LOG(INFO).WithField(node_id) + << "Not sending CancelLocalTask because node is dead"; + return; + } + do_cancel_local_task(*node_info); + }, + "ActorTaskSubmitter.CancelTask"); } bool ActorTaskSubmitter::QueueGeneratorForResubmit(const TaskSpecification &spec) { diff --git a/src/ray/core_worker/task_submission/normal_task_submitter.cc b/src/ray/core_worker/task_submission/normal_task_submitter.cc index 4e9747781b97..c0771d914abe 100644 --- a/src/ray/core_worker/task_submission/normal_task_submitter.cc +++ b/src/ray/core_worker/task_submission/normal_task_submitter.cc @@ -789,34 +789,41 @@ void NormalTaskSubmitter::CancelTask(TaskSpecification task_spec, }); }; - auto *node_info = gcs_client_->Nodes().GetNodeAddressAndLiveness( - node_id, /*filter_dead_nodes=*/false); - if (node_info == nullptr) { - gcs_client_->Nodes().AsyncGetAllNodeAddressAndLiveness( - [do_cancel_local_task = std::move(do_cancel_local_task), node_id]( - const Status &status, - std::vector &&nodes) mutable { - if (!status.ok()) { - RAY_LOG(INFO) << "Failed to get node info from GCS"; - return; - } - if (nodes.empty() || nodes[0].state() != rpc::GcsNodeInfo::ALIVE) { - RAY_LOG(INFO).WithField(node_id) - << "Not sending CancelLocalTask because node is dead"; - return; - } - do_cancel_local_task(nodes[0]); - }, - -1, - {node_id}); - return; - } - if (node_info->state() == rpc::GcsNodeInfo::DEAD) { - RAY_LOG(INFO).WithField(node_id) - << "Not sending CancelLocalTask because node is dead"; - return; - } - do_cancel_local_task(*node_info); + // Cancel can execute on the 
user's python thread, but the GCS node cache is updated on + // the io service thread and is not thread-safe. Hence we need to post the entire + // cache access to the io service thread. + boost::asio::post( + cancel_retry_timer_.get_executor(), + [this, do_cancel_local_task = std::move(do_cancel_local_task), node_id]() mutable { + auto *node_info = gcs_client_->Nodes().GetNodeAddressAndLiveness( + node_id, /*filter_dead_nodes=*/false); + if (node_info == nullptr) { + gcs_client_->Nodes().AsyncGetAllNodeAddressAndLiveness( + [do_cancel_local_task = std::move(do_cancel_local_task), node_id]( + const Status &status, + std::vector &&nodes) mutable { + if (!status.ok()) { + RAY_LOG(INFO) << "Failed to get node info from GCS"; + return; + } + if (nodes.empty() || nodes[0].state() != rpc::GcsNodeInfo::ALIVE) { + RAY_LOG(INFO).WithField(node_id) + << "Not sending CancelLocalTask because node is dead"; + return; + } + do_cancel_local_task(nodes[0]); + }, + -1, + {node_id}); + return; + } + if (node_info->state() == rpc::GcsNodeInfo::DEAD) { + RAY_LOG(INFO).WithField(node_id) + << "Not sending CancelLocalTask because node is dead"; + return; + } + do_cancel_local_task(*node_info); + }); } void NormalTaskSubmitter::RequestOwnerToCancelTask(const ObjectID &object_id, From 49250fbf70477137f05d0d5397fb9eafe37b55b1 Mon Sep 17 00:00:00 2001 From: joshlee Date: Tue, 18 Nov 2025 01:35:43 +0000 Subject: [PATCH 15/18] Bad merge conflict fix Signed-off-by: joshlee --- .../tests/direct_actor_transport_test.cc | 290 +++++++++--------- 1 file changed, 142 insertions(+), 148 deletions(-) diff --git a/src/ray/core_worker/task_submission/tests/direct_actor_transport_test.cc b/src/ray/core_worker/task_submission/tests/direct_actor_transport_test.cc index 571871735823..be7ee28586a6 100644 --- a/src/ray/core_worker/task_submission/tests/direct_actor_transport_test.cc +++ b/src/ray/core_worker/task_submission/tests/direct_actor_transport_test.cc @@ -23,168 +23,162 @@ #include "ray/core_worker/reference_counter.h" #include "ray/core_worker/reference_counter_interface.h" #include "ray/core_worker/task_submission/actor_task_submitter.h" -<<<<<<< HEAD -#include "ray/raylet_rpc_client/raylet_client_pool.h" -======= #include "ray/observability/fake_metric.h" #include "ray/pubsub/fake_publisher.h" #include "ray/pubsub/fake_subscriber.h" ->>>>>>> ae94ff496a308c52100cd99b1857836b739498e0 +#include "ray/raylet_rpc_client/raylet_client_pool.h" - namespace ray { - namespace core { - using ::testing::_; +namespace ray { +namespace core { +using ::testing::_; - class DirectTaskTransportTest : public ::testing::Test { - public: - DirectTaskTransportTest() : io_work(io_context.get_executor()) {} +class DirectTaskTransportTest : public ::testing::Test { + public: + DirectTaskTransportTest() : io_work(io_context.get_executor()) {} - void SetUp() override { - gcs_client = std::make_shared(); - actor_creator = std::make_unique(gcs_client->Actors()); + void SetUp() override { + gcs_client = std::make_shared(); + actor_creator = std::make_unique(gcs_client->Actors()); - task_manager = std::make_shared(); - client_pool = std::make_shared( - [&](const rpc::Address &) { return nullptr; }); - raylet_client_pool = std::make_shared( - [](const rpc::Address &) -> std::shared_ptr { - return nullptr; - }); - memory_store = DefaultCoreWorkerMemoryStoreWithThread::Create(); - publisher = std::make_unique(); - subscriber = std::make_unique(); - reference_counter = std::make_shared( - rpc::Address(), - publisher.get(), - subscriber.get(), - 
/*is_node_dead=*/[](const NodeID &) { return false; }, - fake_owned_object_count_gauge, - fake_owned_object_size_gauge, - /*lineage_pinning_enabled=*/false); - actor_task_submitter = std::make_unique( - *client_pool, - *raylet_client_pool, - gcs_client, - *memory_store, - *task_manager, - *actor_creator, - [](const ObjectID &object_id) { return rpc::TensorTransport::OBJECT_STORE; }, - nullptr, - io_context, - reference_counter); - } + task_manager = std::make_shared(); + client_pool = std::make_shared( + [&](const rpc::Address &) { return nullptr; }); + raylet_client_pool = std::make_shared( + [](const rpc::Address &) -> std::shared_ptr { + return nullptr; + }); + memory_store = DefaultCoreWorkerMemoryStoreWithThread::Create(); + publisher = std::make_unique(); + subscriber = std::make_unique(); + reference_counter = std::make_shared( + rpc::Address(), + publisher.get(), + subscriber.get(), + /*is_node_dead=*/[](const NodeID &) { return false; }, + fake_owned_object_count_gauge, + fake_owned_object_size_gauge, + /*lineage_pinning_enabled=*/false); + actor_task_submitter = std::make_unique( + *client_pool, + *raylet_client_pool, + gcs_client, + *memory_store, + *task_manager, + *actor_creator, + [](const ObjectID &object_id) { return rpc::TensorTransport::OBJECT_STORE; }, + nullptr, + io_context, + reference_counter); + } - TaskSpecification GetActorTaskSpec(const ActorID &actor_id) { - rpc::TaskSpec task_spec; - task_spec.set_type(rpc::TaskType::ACTOR_TASK); - task_spec.mutable_actor_task_spec()->set_actor_id(actor_id.Binary()); - task_spec.set_task_id( - TaskID::ForActorTask(JobID::FromInt(10), TaskID::Nil(), 0, actor_id).Binary()); - return TaskSpecification(task_spec); - } + TaskSpecification GetActorTaskSpec(const ActorID &actor_id) { + rpc::TaskSpec task_spec; + task_spec.set_type(rpc::TaskType::ACTOR_TASK); + task_spec.mutable_actor_task_spec()->set_actor_id(actor_id.Binary()); + task_spec.set_task_id( + TaskID::ForActorTask(JobID::FromInt(10), TaskID::Nil(), 0, actor_id).Binary()); + return TaskSpecification(task_spec); + } - TaskSpecification GetActorCreationTaskSpec(const ActorID &actor_id) { - rpc::TaskSpec task_spec; - task_spec.set_task_id(TaskID::ForActorCreationTask(actor_id).Binary()); - task_spec.set_type(rpc::TaskType::ACTOR_CREATION_TASK); - rpc::ActorCreationTaskSpec actor_creation_task_spec; - actor_creation_task_spec.set_actor_id(actor_id.Binary()); - task_spec.mutable_actor_creation_task_spec()->CopyFrom(actor_creation_task_spec); - return TaskSpecification(task_spec); - } + TaskSpecification GetActorCreationTaskSpec(const ActorID &actor_id) { + rpc::TaskSpec task_spec; + task_spec.set_task_id(TaskID::ForActorCreationTask(actor_id).Binary()); + task_spec.set_type(rpc::TaskType::ACTOR_CREATION_TASK); + rpc::ActorCreationTaskSpec actor_creation_task_spec; + actor_creation_task_spec.set_actor_id(actor_id.Binary()); + task_spec.mutable_actor_creation_task_spec()->CopyFrom(actor_creation_task_spec); + return TaskSpecification(task_spec); + } - protected: - bool CheckSubmitTask(TaskSpecification task) { - actor_task_submitter->SubmitTask(task); - return 1 == io_context.poll_one(); - } + protected: + bool CheckSubmitTask(TaskSpecification task) { + actor_task_submitter->SubmitTask(task); + return 1 == io_context.poll_one(); + } - protected: - instrumented_io_context io_context; - boost::asio::executor_work_guard io_work; - std::unique_ptr actor_task_submitter; - std::shared_ptr client_pool; - std::shared_ptr raylet_client_pool; - std::unique_ptr memory_store; - 
std::shared_ptr task_manager; - std::unique_ptr actor_creator; - std::shared_ptr gcs_client; - std::unique_ptr publisher; - std::unique_ptr subscriber; - ray::observability::FakeGauge fake_owned_object_count_gauge; - ray::observability::FakeGauge fake_owned_object_size_gauge; - std::shared_ptr reference_counter; - }; + protected: + instrumented_io_context io_context; + boost::asio::executor_work_guard io_work; + std::unique_ptr actor_task_submitter; + std::shared_ptr client_pool; + std::shared_ptr raylet_client_pool; + std::unique_ptr memory_store; + std::shared_ptr task_manager; + std::unique_ptr actor_creator; + std::shared_ptr gcs_client; + std::unique_ptr publisher; + std::unique_ptr subscriber; + ray::observability::FakeGauge fake_owned_object_count_gauge; + ray::observability::FakeGauge fake_owned_object_size_gauge; + std::shared_ptr reference_counter; +}; - TEST_F(DirectTaskTransportTest, ActorCreationOk) { - auto actor_id = ActorID::FromHex("f4ce02420592ca68c1738a0d01000000"); - auto creation_task_spec = GetActorCreationTaskSpec(actor_id); - EXPECT_CALL(*task_manager, CompletePendingTask(creation_task_spec.TaskId(), _, _, _)); - actor_task_submitter->SubmitActorCreationTask(creation_task_spec); - gcs_client->mock_actor_accessor->async_create_actor_callback_( - Status::OK(), rpc::CreateActorReply()); - } +TEST_F(DirectTaskTransportTest, ActorCreationOk) { + auto actor_id = ActorID::FromHex("f4ce02420592ca68c1738a0d01000000"); + auto creation_task_spec = GetActorCreationTaskSpec(actor_id); + EXPECT_CALL(*task_manager, CompletePendingTask(creation_task_spec.TaskId(), _, _, _)); + actor_task_submitter->SubmitActorCreationTask(creation_task_spec); + gcs_client->mock_actor_accessor->async_create_actor_callback_(Status::OK(), + rpc::CreateActorReply()); +} - TEST_F(DirectTaskTransportTest, ActorCreationFail) { - auto actor_id = ActorID::FromHex("f4ce02420592ca68c1738a0d01000000"); - auto creation_task_spec = GetActorCreationTaskSpec(actor_id); - EXPECT_CALL(*task_manager, CompletePendingTask(_, _, _, _)).Times(0); - EXPECT_CALL( - *task_manager, - FailPendingTask( - creation_task_spec.TaskId(), rpc::ErrorType::ACTOR_CREATION_FAILED, _, _)); - actor_task_submitter->SubmitActorCreationTask(creation_task_spec); - gcs_client->mock_actor_accessor->async_create_actor_callback_( - Status::IOError(""), rpc::CreateActorReply()); - } +TEST_F(DirectTaskTransportTest, ActorCreationFail) { + auto actor_id = ActorID::FromHex("f4ce02420592ca68c1738a0d01000000"); + auto creation_task_spec = GetActorCreationTaskSpec(actor_id); + EXPECT_CALL(*task_manager, CompletePendingTask(_, _, _, _)).Times(0); + EXPECT_CALL( + *task_manager, + FailPendingTask( + creation_task_spec.TaskId(), rpc::ErrorType::ACTOR_CREATION_FAILED, _, _)); + actor_task_submitter->SubmitActorCreationTask(creation_task_spec); + gcs_client->mock_actor_accessor->async_create_actor_callback_(Status::IOError(""), + rpc::CreateActorReply()); +} - TEST_F(DirectTaskTransportTest, ActorRegisterFailure) { - auto actor_id = ActorID::FromHex("f4ce02420592ca68c1738a0d01000000"); - ASSERT_TRUE(ObjectID::IsActorID(ObjectID::ForActorHandle(actor_id))); - ASSERT_EQ(actor_id, ObjectID::ToActorID(ObjectID::ForActorHandle(actor_id))); - auto creation_task_spec = GetActorCreationTaskSpec(actor_id); - auto task_spec = GetActorTaskSpec(actor_id); - auto task_arg = task_spec.GetMutableMessage().add_args(); - auto inline_obj_ref = task_arg->add_nested_inlined_refs(); - inline_obj_ref->set_object_id(ObjectID::ForActorHandle(actor_id).Binary()); - 
actor_creator->AsyncRegisterActor(creation_task_spec, nullptr); - ASSERT_TRUE(actor_creator->IsActorInRegistering(actor_id)); - actor_task_submitter->AddActorQueueIfNotExists(actor_id, - -1, - /*allow_out_of_order_execution*/ false, - /*fail_if_actor_unreachable*/ true, - /*owned*/ false); - ASSERT_TRUE(CheckSubmitTask(task_spec)); - EXPECT_CALL(*task_manager, - FailOrRetryPendingTask(task_spec.TaskId(), - rpc::ErrorType::DEPENDENCY_RESOLUTION_FAILED, - _, - _, - _, - _)); - gcs_client->mock_actor_accessor->async_register_actor_callback_(Status::IOError("")); - } +TEST_F(DirectTaskTransportTest, ActorRegisterFailure) { + auto actor_id = ActorID::FromHex("f4ce02420592ca68c1738a0d01000000"); + ASSERT_TRUE(ObjectID::IsActorID(ObjectID::ForActorHandle(actor_id))); + ASSERT_EQ(actor_id, ObjectID::ToActorID(ObjectID::ForActorHandle(actor_id))); + auto creation_task_spec = GetActorCreationTaskSpec(actor_id); + auto task_spec = GetActorTaskSpec(actor_id); + auto task_arg = task_spec.GetMutableMessage().add_args(); + auto inline_obj_ref = task_arg->add_nested_inlined_refs(); + inline_obj_ref->set_object_id(ObjectID::ForActorHandle(actor_id).Binary()); + actor_creator->AsyncRegisterActor(creation_task_spec, nullptr); + ASSERT_TRUE(actor_creator->IsActorInRegistering(actor_id)); + actor_task_submitter->AddActorQueueIfNotExists(actor_id, + -1, + /*allow_out_of_order_execution*/ false, + /*fail_if_actor_unreachable*/ true, + /*owned*/ false); + ASSERT_TRUE(CheckSubmitTask(task_spec)); + EXPECT_CALL( + *task_manager, + FailOrRetryPendingTask( + task_spec.TaskId(), rpc::ErrorType::DEPENDENCY_RESOLUTION_FAILED, _, _, _, _)); + gcs_client->mock_actor_accessor->async_register_actor_callback_(Status::IOError("")); +} - TEST_F(DirectTaskTransportTest, ActorRegisterOk) { - auto actor_id = ActorID::FromHex("f4ce02420592ca68c1738a0d01000000"); - ASSERT_TRUE(ObjectID::IsActorID(ObjectID::ForActorHandle(actor_id))); - ASSERT_EQ(actor_id, ObjectID::ToActorID(ObjectID::ForActorHandle(actor_id))); - auto creation_task_spec = GetActorCreationTaskSpec(actor_id); - auto task_spec = GetActorTaskSpec(actor_id); - auto task_arg = task_spec.GetMutableMessage().add_args(); - auto inline_obj_ref = task_arg->add_nested_inlined_refs(); - inline_obj_ref->set_object_id(ObjectID::ForActorHandle(actor_id).Binary()); - actor_creator->AsyncRegisterActor(creation_task_spec, nullptr); - ASSERT_TRUE(actor_creator->IsActorInRegistering(actor_id)); - actor_task_submitter->AddActorQueueIfNotExists(actor_id, - -1, - /*allow_out_of_order_execution*/ false, - /*fail_if_actor_unreachable*/ true, - /*owned*/ false); - ASSERT_TRUE(CheckSubmitTask(task_spec)); - EXPECT_CALL(*task_manager, FailOrRetryPendingTask(_, _, _, _, _, _)).Times(0); - gcs_client->mock_actor_accessor->async_register_actor_callback_(Status::OK()); - } +TEST_F(DirectTaskTransportTest, ActorRegisterOk) { + auto actor_id = ActorID::FromHex("f4ce02420592ca68c1738a0d01000000"); + ASSERT_TRUE(ObjectID::IsActorID(ObjectID::ForActorHandle(actor_id))); + ASSERT_EQ(actor_id, ObjectID::ToActorID(ObjectID::ForActorHandle(actor_id))); + auto creation_task_spec = GetActorCreationTaskSpec(actor_id); + auto task_spec = GetActorTaskSpec(actor_id); + auto task_arg = task_spec.GetMutableMessage().add_args(); + auto inline_obj_ref = task_arg->add_nested_inlined_refs(); + inline_obj_ref->set_object_id(ObjectID::ForActorHandle(actor_id).Binary()); + actor_creator->AsyncRegisterActor(creation_task_spec, nullptr); + ASSERT_TRUE(actor_creator->IsActorInRegistering(actor_id)); + 
actor_task_submitter->AddActorQueueIfNotExists(actor_id, + -1, + /*allow_out_of_order_execution*/ false, + /*fail_if_actor_unreachable*/ true, + /*owned*/ false); + ASSERT_TRUE(CheckSubmitTask(task_spec)); + EXPECT_CALL(*task_manager, FailOrRetryPendingTask(_, _, _, _, _, _)).Times(0); + gcs_client->mock_actor_accessor->async_register_actor_callback_(Status::OK()); +} - } // namespace core +} // namespace core } // namespace ray From 3dbcc22bead5c1d28e622297828b16adedc5684a Mon Sep 17 00:00:00 2001 From: joshlee Date: Tue, 18 Nov 2025 05:46:20 +0000 Subject: [PATCH 16/18] Addressing comments Signed-off-by: joshlee --- .../task_submission/normal_task_submitter.cc | 101 +++++++++--------- 1 file changed, 49 insertions(+), 52 deletions(-) diff --git a/src/ray/core_worker/task_submission/normal_task_submitter.cc b/src/ray/core_worker/task_submission/normal_task_submitter.cc index c0771d914abe..94a956e7ff71 100644 --- a/src/ray/core_worker/task_submission/normal_task_submitter.cc +++ b/src/ray/core_worker/task_submission/normal_task_submitter.cc @@ -665,66 +665,63 @@ void NormalTaskSubmitter::CancelTask(TaskSpecification task_spec, SchedulingKey scheduling_key(task_spec.GetSchedulingClass(), task_spec.GetDependencyIds(), task_spec.GetRuntimeEnvHash()); - NodeID node_id; - std::string executor_worker_id; - { - absl::MutexLock lock(&mu_); - generators_to_resubmit_.erase(task_id); - // For idempotency. - if (cancelled_tasks_.contains(task_id)) { - // The task cancel is already in progress. We don't need to do anything. - return; - } + absl::MutexLock lock(&mu_); + generators_to_resubmit_.erase(task_id); - task_manager_.MarkTaskCanceled(task_id); - if (!task_manager_.IsTaskPending(task_id)) { - // The task is finished or failed so marking the task as cancelled is sufficient. - return; - } + // For idempotency. + if (cancelled_tasks_.contains(task_id)) { + // The task cancel is already in progress. We don't need to do anything. + return; + } - auto &scheduling_key_entry = scheduling_key_entries_[scheduling_key]; - auto &scheduling_tasks = scheduling_key_entry.task_queue; - // This cancels tasks that have completed dependencies and are awaiting - // a worker lease. - if (!scheduling_tasks.empty()) { - for (auto spec = scheduling_tasks.begin(); spec != scheduling_tasks.end(); spec++) { - if (spec->TaskId() == task_id) { - scheduling_tasks.erase(spec); - CancelWorkerLeaseIfNeeded(scheduling_key); - task_manager_.FailPendingTask(task_id, rpc::ErrorType::TASK_CANCELLED); - return; - } - } - } + task_manager_.MarkTaskCanceled(task_id); + if (!task_manager_.IsTaskPending(task_id)) { + // The task is finished or failed so marking the task as cancelled is sufficient. + return; + } - // This will get removed either when the RPC call to cancel is returned, when all - // dependencies are resolved, or when dependency resolution is successfully cancelled. - RAY_CHECK(cancelled_tasks_.emplace(task_id).second); - auto rpc_client_address = executing_tasks_.find(task_id); - if (rpc_client_address == executing_tasks_.end()) { - if (failed_tasks_pending_failure_cause_.contains(task_id)) { - // We are waiting for the task failure cause. Do not fail it here; instead, - // wait for the cause to come in and then handle it appropriately. - } else { - // This case is reached for tasks that have unresolved dependencies. - if (resolver_.CancelDependencyResolution(task_id)) { - // ResolveDependencies callback will never be called if dependency resolution - // was successfully cancelled, so need to remove from the set here. 
- cancelled_tasks_.erase(task_id); - } + auto &scheduling_key_entry = scheduling_key_entries_[scheduling_key]; + auto &scheduling_tasks = scheduling_key_entry.task_queue; + // This cancels tasks that have completed dependencies and are awaiting + // a worker lease. + if (!scheduling_tasks.empty()) { + for (auto spec = scheduling_tasks.begin(); spec != scheduling_tasks.end(); spec++) { + if (spec->TaskId() == task_id) { + scheduling_tasks.erase(spec); + CancelWorkerLeaseIfNeeded(scheduling_key); task_manager_.FailPendingTask(task_id, rpc::ErrorType::TASK_CANCELLED); + return; } - if (scheduling_key_entry.CanDelete()) { - // We can safely remove the entry keyed by scheduling_key from the - // scheduling_key_entries_ hashmap. - scheduling_key_entries_.erase(scheduling_key); + } + } + + // This will get removed either when the RPC call to cancel is returned, when all + // dependencies are resolved, or when dependency resolution is successfully cancelled. + RAY_CHECK(cancelled_tasks_.emplace(task_id).second); + auto rpc_client_address = executing_tasks_.find(task_id); + if (rpc_client_address == executing_tasks_.end()) { + if (failed_tasks_pending_failure_cause_.contains(task_id)) { + // We are waiting for the task failure cause. Do not fail it here; instead, + // wait for the cause to come in and then handle it appropriately. + } else { + // This case is reached for tasks that have unresolved dependencies. + if (resolver_.CancelDependencyResolution(task_id)) { + // ResolveDependencies callback will never be called if dependency resolution + // was successfully cancelled, so need to remove from the set here. + cancelled_tasks_.erase(task_id); } - return; + task_manager_.FailPendingTask(task_id, rpc::ErrorType::TASK_CANCELLED); } - node_id = NodeID::FromBinary(rpc_client_address->second.node_id()); - executor_worker_id = rpc_client_address->second.worker_id(); + if (scheduling_key_entry.CanDelete()) { + // We can safely remove the entry keyed by scheduling_key from the + // scheduling_key_entries_ hashmap. 
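Aside on the hunk above: the `cancelled_tasks_` set guarded by `mu_` is what makes `CancelTask` idempotent. The first caller inserts the task id and proceeds; repeated or concurrent cancels find the id already present and return early, and the RPC reply callback erases it so a scheduled retry can start over. A minimal sketch of that guard pattern, using illustrative names (`CancelGuard`, plain string task ids) rather than Ray's actual types:

#include <mutex>
#include <string>
#include <unordered_set>

// Illustrative sketch of the cancelled_tasks_ idempotency guard; CancelGuard
// and the std::string task ids are stand-ins, not Ray types.
class CancelGuard {
 public:
  // Returns true only for the first caller; concurrent or repeated cancels
  // of the same task see the id already present and back off.
  bool TryBegin(const std::string &task_id) {
    std::lock_guard<std::mutex> lock(mu_);
    return in_flight_.insert(task_id).second;
  }

  // Invoked from the RPC reply callback so a later retry can start fresh.
  void End(const std::string &task_id) {
    std::lock_guard<std::mutex> lock(mu_);
    in_flight_.erase(task_id);
  }

 private:
  std::mutex mu_;
  std::unordered_set<std::string> in_flight_;
};

Holding the lock only around the set keeps the guard cheap; the cancel RPC itself is issued outside of it.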
+ scheduling_key_entries_.erase(scheduling_key); + } + return; } + auto node_id = NodeID::FromBinary(rpc_client_address->second.node_id()); + auto executor_worker_id = rpc_client_address->second.worker_id(); auto do_cancel_local_task = [this, @@ -754,7 +751,7 @@ void NormalTaskSubmitter::CancelTask(TaskSpecification task_spec, force_kill, recursive](const Status &status, const rpc::CancelLocalTaskReply &reply) mutable { - absl::MutexLock lock(&mu_); + absl::MutexLock callback_lock(&mu_); cancelled_tasks_.erase(task_spec.TaskId()); if (!status.ok()) { RAY_LOG(DEBUG) << "CancelLocalTask RPC failed for task " From 73445ab7ecd423e9960b6631b65381ca3e6bc455 Mon Sep 17 00:00:00 2001 From: joshlee Date: Tue, 18 Nov 2025 21:12:17 +0000 Subject: [PATCH 17/18] Fix cpp test Signed-off-by: joshlee --- .../tests/normal_task_submitter_test.cc | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/src/ray/core_worker/task_submission/tests/normal_task_submitter_test.cc b/src/ray/core_worker/task_submission/tests/normal_task_submitter_test.cc index e8a50de3271b..0f418ea33156 100644 --- a/src/ray/core_worker/task_submission/tests/normal_task_submitter_test.cc +++ b/src/ray/core_worker/task_submission/tests/normal_task_submitter_test.cc @@ -466,7 +466,8 @@ class NormalTaskSubmitterTest : public testing::Test { actor_creator(std::make_shared()), lease_policy(std::make_unique()), lease_policy_ptr(lease_policy.get()), - mock_gcs_client_(std::make_shared()) { + mock_gcs_client_(std::make_shared()), + io_work_(boost::asio::make_work_guard(io_context)) { address.set_node_id(local_node_id.Binary()); lease_policy_ptr->SetNodeID(local_node_id); } @@ -531,6 +532,7 @@ class NormalTaskSubmitterTest : public testing::Test { MockLeasePolicy *lease_policy_ptr = nullptr; std::shared_ptr mock_gcs_client_; instrumented_io_context io_context; + boost::asio::executor_work_guard io_work_; ray::observability::FakeHistogram fake_scheduler_placement_time_ms_histogram_; }; @@ -688,6 +690,9 @@ TEST_F(NormalTaskSubmitterTest, TestCancellationWhileHandlingTaskFailure) { ASSERT_TRUE(worker_client->ReplyPushTask(Status::IOError("oops"))); // Cancel the task while GetWorkerFailureCause has not been completed. submitter.CancelTask(task, true, false); + // ReplyPushTask removes the task from the executing_tasks_ map hence + // we don't need to trigger CancelLocalTask RPC. + RAY_CHECK(!io_context.poll_one()); // Completing the GetWorkerFailureCause call. Check that the reply runs without error // and FailPendingTask is not called. 
ASSERT_TRUE(raylet_client->ReplyGetWorkerFailureCause()); @@ -1770,6 +1775,7 @@ TEST_F(NormalTaskSubmitterTest, TestKillExecutingTask) { // Try force kill, exiting the worker submitter.CancelTask(task, true, false); + RAY_CHECK(io_context.poll_one()); ASSERT_EQ(raylet_client->cancel_local_task_requests.front().intended_task_id(), task.TaskIdBinary()); ASSERT_TRUE(worker_client->ReplyPushTask(Status::IOError("workerdying"), true)); @@ -1788,6 +1794,7 @@ TEST_F(NormalTaskSubmitterTest, TestKillExecutingTask) { // Try non-force kill, worker returns normally submitter.CancelTask(task, false, false); + RAY_CHECK(io_context.poll_one()); ASSERT_TRUE(worker_client->ReplyPushTask()); ASSERT_EQ(raylet_client->cancel_local_task_requests.back().intended_task_id(), task.TaskIdBinary()); @@ -1810,6 +1817,9 @@ TEST_F(NormalTaskSubmitterTest, TestKillPendingTask) { submitter.SubmitTask(task); submitter.CancelTask(task, true, false); + // We haven't been granted a worker lease yet, so the task is not executing. + // So we don't need to trigger CancelLocalTask RPC. + RAY_CHECK(!io_context.poll_one()); ASSERT_EQ(worker_client->kill_requests.size(), 0); ASSERT_EQ(worker_client->callbacks.size(), 0); ASSERT_EQ(raylet_client->num_workers_returned, 0); @@ -1838,6 +1848,9 @@ TEST_F(NormalTaskSubmitterTest, TestKillResolvingTask) { submitter.SubmitTask(task); ASSERT_EQ(task_manager->num_inlined_dependencies, 0); submitter.CancelTask(task, true, false); + // We haven't been granted a worker lease yet, so the task is not executing. + // So we don't need to trigger CancelLocalTask RPC. + RAY_CHECK(!io_context.poll_one()); auto data = GenerateRandomObject(); store->Put(*data, obj1, /*has_reference=*/true); WaitForObjectIdInMemoryStore(*store, obj1); @@ -1888,6 +1901,7 @@ TEST_F(NormalTaskSubmitterTest, TestCancelBeforeAfterQueueGeneratorForResubmit) submitter.SubmitTask(task); ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1234, local_node_id)); submitter.CancelTask(task, /*force_kill=*/false, /*recursive=*/true); + RAY_CHECK(io_context.poll_one()); ASSERT_FALSE(submitter.QueueGeneratorForResubmit(task)); raylet_client->ReplyCancelLocalTask(); ASSERT_TRUE(submitter.QueueGeneratorForResubmit(task)); @@ -1906,6 +1920,7 @@ TEST_F(NormalTaskSubmitterTest, TestCancelBeforeAfterQueueGeneratorForResubmit) ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1234, local_node_id)); ASSERT_TRUE(submitter.QueueGeneratorForResubmit(task2)); submitter.CancelTask(task2, /*force_kill=*/false, /*recursive=*/true); + RAY_CHECK(io_context.poll_one()); ASSERT_TRUE(worker_client->ReplyPushTask()); raylet_client->ReplyCancelLocalTask(Status::OK(), /*attempt_succeeded=*/true, From f737eefafd183b458fa0b15981056f5e8f5800a4 Mon Sep 17 00:00:00 2001 From: joshlee Date: Wed, 19 Nov 2025 23:53:12 +0000 Subject: [PATCH 18/18] Addressing comments Signed-off-by: joshlee --- .../ray/tests/test_raylet_fault_tolerance.py | 2 +- src/ray/core_worker/core_worker_process.cc | 2 +- .../core_worker/task_submission/BUILD.bazel | 13 +++ .../task_submission/actor_task_submitter.cc | 76 +++++----------- .../task_submission/actor_task_submitter.h | 4 +- .../task_submission/normal_task_submitter.cc | 72 ++++------------ .../task_submission/normal_task_submitter.h | 6 +- .../task_submission/task_submission_util.h | 86 +++++++++++++++++++ .../tests/normal_task_submitter_test.cc | 30 ++++--- src/ray/core_worker/tests/core_worker_test.cc | 2 +- .../raylet_rpc_client/fake_raylet_client.h | 14 +-- 11 files changed, 173 insertions(+), 134 deletions(-) 
create mode 100644 src/ray/core_worker/task_submission/task_submission_util.h diff --git a/python/ray/tests/test_raylet_fault_tolerance.py b/python/ray/tests/test_raylet_fault_tolerance.py index b9d1089b9724..8c4597da8d78 100644 --- a/python/ray/tests/test_raylet_fault_tolerance.py +++ b/python/ray/tests/test_raylet_fault_tolerance.py @@ -269,7 +269,7 @@ def test_cancel_local_task_rpc_retry_and_idempotency( Verify that the RPC is idempotent when network failures occur. When force_kill=True, verify the worker process is actually killed using psutil. """ - ray.init(num_cpus=2) + ray.init(num_cpus=1) signaler = SignalActor.remote() @ray.remote(num_cpus=1) diff --git a/src/ray/core_worker/core_worker_process.cc b/src/ray/core_worker/core_worker_process.cc index 97d16d988e44..35c6fbd4e7b0 100644 --- a/src/ray/core_worker/core_worker_process.cc +++ b/src/ray/core_worker/core_worker_process.cc @@ -556,7 +556,7 @@ std::shared_ptr CoreWorkerProcessImpl::CreateCoreWorker( // OBJECT_STORE. return rpc::TensorTransport::OBJECT_STORE; }, - boost::asio::steady_timer(io_service_), + io_service_, *scheduler_placement_time_ms_histogram_); auto report_locality_data_callback = [this]( diff --git a/src/ray/core_worker/task_submission/BUILD.bazel b/src/ray/core_worker/task_submission/BUILD.bazel index 9873169b4cf8..6856592a792b 100644 --- a/src/ray/core_worker/task_submission/BUILD.bazel +++ b/src/ray/core_worker/task_submission/BUILD.bazel @@ -53,6 +53,17 @@ ray_cc_library( ], ) +ray_cc_library( + name = "task_submission_util", + hdrs = ["task_submission_util.h"], + visibility = [":__subpackages__"], + deps = [ + "//src/ray/common:asio", + "//src/ray/common:id", + "//src/ray/gcs_rpc_client:gcs_client", + ], +) + ray_cc_library( name = "actor_task_submitter", srcs = ["actor_task_submitter.cc"], @@ -66,6 +77,7 @@ ray_cc_library( ":dependency_resolver", ":out_of_order_actor_submit_queue", ":sequential_actor_submit_queue", + ":task_submission_util", "//src/ray/common:asio", "//src/ray/common:id", "//src/ray/common:protobuf_utils", @@ -93,6 +105,7 @@ ray_cc_library( ], deps = [ ":dependency_resolver", + ":task_submission_util", "//src/ray/common:id", "//src/ray/common:lease", "//src/ray/common:protobuf_utils", diff --git a/src/ray/core_worker/task_submission/actor_task_submitter.cc b/src/ray/core_worker/task_submission/actor_task_submitter.cc index 616385e238ee..8cf769caa03f 100644 --- a/src/ray/core_worker/task_submission/actor_task_submitter.cc +++ b/src/ray/core_worker/task_submission/actor_task_submitter.cc @@ -21,6 +21,7 @@ #include #include "ray/common/protobuf_utils.h" +#include "ray/core_worker/task_submission/task_submission_util.h" #include "ray/util/time.h" namespace ray { @@ -912,17 +913,16 @@ std::string ActorTaskSubmitter::DebugString(const ActorID &actor_id) const { return stream.str(); } -void ActorTaskSubmitter::RetryCancelTask(TaskSpecification task_spec, - bool recursive, - int64_t milliseconds) { +void ActorTaskSubmitter::RetryCancelTask(TaskSpecification task_spec, bool recursive) { + auto delay_ms = RayConfig::instance().cancellation_retry_ms(); RAY_LOG(DEBUG).WithField(task_spec.TaskId()) - << "Task cancelation will be retried in " << milliseconds << " ms"; + << "Task cancelation will be retried in " << delay_ms << " ms"; execute_after( io_service_, [this, task_spec = std::move(task_spec), recursive] { CancelTask(task_spec, recursive); }, - std::chrono::milliseconds(milliseconds)); + std::chrono::milliseconds(delay_ms)); } void ActorTaskSubmitter::CancelTask(TaskSpecification task_spec, bool 
recursive) { @@ -997,7 +997,7 @@ void ActorTaskSubmitter::CancelTask(TaskSpecification task_spec, bool recursive) // an executor tells us to stop retrying. // If there's no client, it means actor is not created yet. - // Retry in 1 second. + // Retry after the configured delay. NodeID node_id; std::string executor_worker_id; { @@ -1006,7 +1006,7 @@ void ActorTaskSubmitter::CancelTask(TaskSpecification task_spec, bool recursive) auto queue = client_queues_.find(actor_id); RAY_CHECK(queue != client_queues_.end()); if (!queue->second.client_address_.has_value()) { - RetryCancelTask(task_spec, recursive, 1000); + RetryCancelTask(task_spec, recursive); return; } node_id = NodeID::FromBinary(queue->second.client_address_.value().node_id()); @@ -1034,61 +1034,29 @@ void ActorTaskSubmitter::CancelTask(TaskSpecification task_spec, bool recursive) [this, task_spec = std::move(task_spec), recursive]( const Status &status, const rpc::CancelLocalTaskReply &reply) mutable { if (!status.ok()) { - RAY_LOG(DEBUG) << "CancelLocalTask RPC failed for task " - << task_spec.TaskId() << ": " << status.ToString() - << " due to node death"; + RAY_LOG(INFO) << "CancelLocalTask RPC failed for task " + << task_spec.TaskId() << ": " << status.ToString() + << " due to node death"; return; } else { - RAY_LOG(DEBUG) << "CancelLocalTask RPC response received for " - << task_spec.TaskId() - << " with attempt_succeeded: " << reply.attempt_succeeded() - << " requested_task_running: " - << reply.requested_task_running(); + RAY_LOG(INFO) << "CancelLocalTask RPC response received for " + << task_spec.TaskId() + << " with attempt_succeeded: " << reply.attempt_succeeded() + << " requested_task_running: " + << reply.requested_task_running(); } - // Keep retrying every 2 seconds until a task is officially - // finished. + // Keep retrying until a task is officially finished. if (!reply.attempt_succeeded()) { - RetryCancelTask(std::move(task_spec), recursive, 2000); + RetryCancelTask(std::move(task_spec), recursive); } }); }; - // Cancel can execute on the user's python thread, but the GCS node cache is updated on - // the io service thread and is not thread-safe. Hence we need to post the entire - // cache access to the io service thread. - io_service_.post( - [this, do_cancel_local_task = std::move(do_cancel_local_task), node_id]() mutable { - // Check GCS node cache. If node info is not in the cache, query the GCS instead. 
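Aside: the `RetryCancelTask` rewrite above re-arms cancellation through `execute_after` using the `cancellation_retry_ms` config rather than a caller-supplied delay. The sketch below shows the underlying delayed-retry shape in plain Boost.Asio; `RetryAfter` is a stand-in for Ray's `execute_after`, whose real signature is not shown in this patch:

#include <boost/asio.hpp>
#include <chrono>
#include <functional>
#include <iostream>
#include <memory>

// Sketch of a delayed retry on an io_context. RetryAfter is a stand-in for
// Ray's execute_after; the shared_ptr keeps the timer alive until it fires.
void RetryAfter(boost::asio::io_context &io,
                std::chrono::milliseconds delay,
                std::function<void()> fn) {
  auto timer = std::make_shared<boost::asio::steady_timer>(io, delay);
  timer->async_wait(
      [timer, fn = std::move(fn)](const boost::system::error_code &ec) {
        if (!ec) fn();  // ec is set if the timer was cancelled; skip then.
      });
}

int main() {
  boost::asio::io_context io;
  RetryAfter(io, std::chrono::milliseconds(100),
             [] { std::cout << "retrying CancelTask\n"; });
  io.run();  // Blocks until the timer fires and the handler runs.
  return 0;
}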
- auto *node_info = gcs_client_->Nodes().GetNodeAddressAndLiveness( - node_id, /*filter_dead_nodes=*/false); - if (node_info == nullptr) { - gcs_client_->Nodes().AsyncGetAllNodeAddressAndLiveness( - [do_cancel_local_task = std::move(do_cancel_local_task), node_id]( - const Status &status, - std::vector &&nodes) mutable { - if (!status.ok()) { - RAY_LOG(INFO) << "Failed to get node info from GCS"; - return; - } - if (nodes.empty() || nodes[0].state() != rpc::GcsNodeInfo::ALIVE) { - RAY_LOG(INFO).WithField(node_id) - << "Not sending CancelLocalTask because node is dead"; - return; - } - do_cancel_local_task(nodes[0]); - }, - -1, - {node_id}); - return; - } - if (node_info->state() == rpc::GcsNodeInfo::DEAD) { - RAY_LOG(INFO).WithField(node_id) - << "Not sending CancelLocalTask because node is dead"; - return; - } - do_cancel_local_task(*node_info); - }, - "ActorTaskSubmitter.CancelTask"); + PostCancelLocalTask(gcs_client_, + io_service_, + node_id, + std::move(do_cancel_local_task), + "ActorTaskSubmitter.CancelTask"); } bool ActorTaskSubmitter::QueueGeneratorForResubmit(const TaskSpecification &spec) { diff --git a/src/ray/core_worker/task_submission/actor_task_submitter.h b/src/ray/core_worker/task_submission/actor_task_submitter.h index 1c9d4d755d0d..16769e27f7d3 100644 --- a/src/ray/core_worker/task_submission/actor_task_submitter.h +++ b/src/ray/core_worker/task_submission/actor_task_submitter.h @@ -236,8 +236,8 @@ class ActorTaskSubmitter : public ActorTaskSubmitterInterface { /// \param recursive If true, it will cancel all child tasks. void CancelTask(TaskSpecification task_spec, bool recursive); - /// Retry the CancelTask in milliseconds. - void RetryCancelTask(TaskSpecification task_spec, bool recursive, int64_t milliseconds); + /// Retry the CancelTask after a configured delay. + void RetryCancelTask(TaskSpecification task_spec, bool recursive); /// Queue the streaming generator up for resubmission. 
/// \return true if the task is still executing and the submitter agrees to resubmit diff --git a/src/ray/core_worker/task_submission/normal_task_submitter.cc b/src/ray/core_worker/task_submission/normal_task_submitter.cc index 94a956e7ff71..e6bda55bb640 100644 --- a/src/ray/core_worker/task_submission/normal_task_submitter.cc +++ b/src/ray/core_worker/task_submission/normal_task_submitter.cc @@ -22,8 +22,10 @@ #include #include "absl/strings/str_format.h" +#include "ray/common/asio/asio_util.h" #include "ray/common/lease/lease_spec.h" #include "ray/common/protobuf_utils.h" +#include "ray/core_worker/task_submission/task_submission_util.h" #include "ray/util/time.h" namespace ray { @@ -754,30 +756,26 @@ void NormalTaskSubmitter::CancelTask(TaskSpecification task_spec, absl::MutexLock callback_lock(&mu_); cancelled_tasks_.erase(task_spec.TaskId()); if (!status.ok()) { - RAY_LOG(DEBUG) << "CancelLocalTask RPC failed for task " - << task_spec.TaskId() << ": " << status.ToString() - << " due to node death"; + RAY_LOG(INFO) << "CancelLocalTask RPC failed for task " + << task_spec.TaskId() << ": " << status.ToString() + << " due to node death"; return; } else { - RAY_LOG(DEBUG) << "CancelLocalTask RPC response received for " - << task_spec.TaskId() - << " with attempt_succeeded: " << reply.attempt_succeeded() - << " requested_task_running: " - << reply.requested_task_running(); + RAY_LOG(INFO) << "CancelLocalTask RPC response received for " + << task_spec.TaskId() + << " with attempt_succeeded: " << reply.attempt_succeeded() + << " requested_task_running: " + << reply.requested_task_running(); } if (!reply.attempt_succeeded()) { if (reply.requested_task_running()) { - if (cancel_retry_timer_.expiry().time_since_epoch() <= - std::chrono::high_resolution_clock::now().time_since_epoch()) { - cancel_retry_timer_.expires_after(boost::asio::chrono::milliseconds( - RayConfig::instance().cancellation_retry_ms())); - } - cancel_retry_timer_.async_wait( - boost::bind(&NormalTaskSubmitter::CancelTask, - this, - std::move(task_spec), - force_kill, - recursive)); + execute_after( + io_service_, + [this, task_spec = std::move(task_spec), force_kill, recursive] { + CancelTask(task_spec, force_kill, recursive); + }, + std::chrono::milliseconds( + RayConfig::instance().cancellation_retry_ms())); } else { RAY_LOG(DEBUG) << "Attempt to cancel task " << task_spec.TaskId() << " in a worker that doesn't have this task."; @@ -786,41 +784,7 @@ void NormalTaskSubmitter::CancelTask(TaskSpecification task_spec, }); }; - // Cancel can execute on the user's python thread, but the GCS node cache is updated on - // the io service thread and is not thread-safe. Hence we need to post the entire - // cache access to the io service thread. 
- boost::asio::post( - cancel_retry_timer_.get_executor(), - [this, do_cancel_local_task = std::move(do_cancel_local_task), node_id]() mutable { - auto *node_info = gcs_client_->Nodes().GetNodeAddressAndLiveness( - node_id, /*filter_dead_nodes=*/false); - if (node_info == nullptr) { - gcs_client_->Nodes().AsyncGetAllNodeAddressAndLiveness( - [do_cancel_local_task = std::move(do_cancel_local_task), node_id]( - const Status &status, - std::vector &&nodes) mutable { - if (!status.ok()) { - RAY_LOG(INFO) << "Failed to get node info from GCS"; - return; - } - if (nodes.empty() || nodes[0].state() != rpc::GcsNodeInfo::ALIVE) { - RAY_LOG(INFO).WithField(node_id) - << "Not sending CancelLocalTask because node is dead"; - return; - } - do_cancel_local_task(nodes[0]); - }, - -1, - {node_id}); - return; - } - if (node_info->state() == rpc::GcsNodeInfo::DEAD) { - RAY_LOG(INFO).WithField(node_id) - << "Not sending CancelLocalTask because node is dead"; - return; - } - do_cancel_local_task(*node_info); - }); + PostCancelLocalTask(gcs_client_, io_service_, node_id, std::move(do_cancel_local_task)); } void NormalTaskSubmitter::RequestOwnerToCancelTask(const ObjectID &object_id, diff --git a/src/ray/core_worker/task_submission/normal_task_submitter.h b/src/ray/core_worker/task_submission/normal_task_submitter.h index c0600db332c0..d8090de144bc 100644 --- a/src/ray/core_worker/task_submission/normal_task_submitter.h +++ b/src/ray/core_worker/task_submission/normal_task_submitter.h @@ -101,7 +101,7 @@ class NormalTaskSubmitter { const JobID &job_id, std::shared_ptr lease_request_rate_limiter, const TensorTransportGetter &tensor_transport_getter, - boost::asio::steady_timer cancel_timer, + instrumented_io_context &io_service, ray::observability::MetricInterface &scheduler_placement_time_ms_histogram) : rpc_address_(std::move(rpc_address)), local_raylet_client_(std::move(local_raylet_client)), @@ -117,7 +117,7 @@ class NormalTaskSubmitter { core_worker_client_pool_(std::move(core_worker_client_pool)), job_id_(job_id), lease_request_rate_limiter_(std::move(lease_request_rate_limiter)), - cancel_retry_timer_(std::move(cancel_timer)), + io_service_(io_service), scheduler_placement_time_ms_histogram_(scheduler_placement_time_ms_histogram) {} /// Schedule a task for direct submission to a worker. @@ -373,7 +373,7 @@ class NormalTaskSubmitter { std::shared_ptr lease_request_rate_limiter_; // Retries cancelation requests if they were not successful. - boost::asio::steady_timer cancel_retry_timer_ ABSL_GUARDED_BY(mu_); + instrumented_io_context &io_service_; ray::observability::MetricInterface &scheduler_placement_time_ms_histogram_; }; diff --git a/src/ray/core_worker/task_submission/task_submission_util.h b/src/ray/core_worker/task_submission/task_submission_util.h new file mode 100644 index 000000000000..617ace7959b1 --- /dev/null +++ b/src/ray/core_worker/task_submission/task_submission_util.h @@ -0,0 +1,86 @@ +// Copyright 2025 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include +#include +#include + +#include "ray/common/asio/instrumented_io_context.h" +#include "ray/common/id.h" +#include "ray/gcs_rpc_client/gcs_client.h" + +namespace ray { +namespace core { + +/// Post a CancelLocalTask operation after checking GCS node cache for node liveness. +/// The reason we query the GCS is that we don't store the address of the raylet in the +/// task submission path. Since it's only needed in cancellation, we query the GCS if it's +/// needed rather than pollute the hot path. +/// +/// \param gcs_client GCS client to query node information. +/// \param io_service IO service to post the cancel operation to. +/// \param node_id The local node ID of where the task is executing on +/// \param cancel_callback Callback containing CancelLocalTask RPC to invoke if the node +/// is alive. +inline void PostCancelLocalTask( + std::shared_ptr gcs_client, + instrumented_io_context &io_service, + const NodeID &node_id, + std::function cancel_callback, + const std::string &operation_name = "") { + // Cancel can execute on the user's python thread, but the GCS node cache is updated on + // the io service thread and is not thread-safe. Hence we need to post the entire + // cache access to the io service thread. + io_service.post( + [gcs_client, cancel_callback = std::move(cancel_callback), node_id]() mutable { + // Check GCS node cache. If node info is not in the cache, query the GCS instead. + auto *node_info = + gcs_client->Nodes().GetNodeAddressAndLiveness(node_id, + /*filter_dead_nodes=*/false); + if (node_info == nullptr) { + gcs_client->Nodes().AsyncGetAllNodeAddressAndLiveness( + [cancel_callback = std::move(cancel_callback), node_id]( + const Status &status, + std::vector &&nodes) mutable { + if (!status.ok()) { + RAY_LOG(INFO) << "Failed to get node info from GCS"; + return; + } + if (nodes.empty() || nodes[0].state() != rpc::GcsNodeInfo::ALIVE) { + RAY_LOG(INFO).WithField(node_id) + << "Not sending CancelLocalTask because node is dead"; + return; + } + cancel_callback(nodes[0]); + }, + -1, + {node_id}); + return; + } + if (node_info->state() == rpc::GcsNodeInfo::DEAD) { + RAY_LOG(INFO).WithField(node_id) + << "Not sending CancelLocalTask because node is dead"; + return; + } + cancel_callback(*node_info); + }, + operation_name); +} + +} // namespace core +} // namespace ray diff --git a/src/ray/core_worker/task_submission/tests/normal_task_submitter_test.cc b/src/ray/core_worker/task_submission/tests/normal_task_submitter_test.cc index 0f418ea33156..e7ee13445f0c 100644 --- a/src/ray/core_worker/task_submission/tests/normal_task_submitter_test.cc +++ b/src/ray/core_worker/task_submission/tests/normal_task_submitter_test.cc @@ -513,7 +513,7 @@ class NormalTaskSubmitterTest : public testing::Test { JobID::Nil(), rate_limiter, [](const ObjectID &object_id) { return rpc::TensorTransport::OBJECT_STORE; }, - boost::asio::steady_timer(io_context), + io_context, fake_scheduler_placement_time_ms_histogram_); } @@ -691,8 +691,10 @@ TEST_F(NormalTaskSubmitterTest, TestCancellationWhileHandlingTaskFailure) { // Cancel the task while GetWorkerFailureCause has not been completed. submitter.CancelTask(task, true, false); // ReplyPushTask removes the task from the executing_tasks_ map hence - // we don't need to trigger CancelLocalTask RPC. - RAY_CHECK(!io_context.poll_one()); + // we shouldn't have triggered CancelLocalTask RPC. 
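Stepping out of the test diff for a moment: the `PostCancelLocalTask` helper added above works by thread confinement rather than locking. The GCS node cache is only ever touched from the io-service thread, so a cancel arriving on another thread (e.g. the user's Python thread) posts the whole lookup instead. A minimal sketch of that pattern, assuming illustrative names (`NodeCache`, string ids) rather than Ray's actual types:

#include <boost/asio.hpp>
#include <functional>
#include <string>
#include <unordered_map>

// Thread confinement by posting: the cache below is only touched from the
// io_context's thread, so cross-thread callers post the lookup rather than
// taking a lock. NodeCache and the string ids are stand-ins, not Ray types.
class NodeCache {
 public:
  explicit NodeCache(boost::asio::io_context &io) : io_(io) {}

  // Safe to call from any thread; the map access itself runs on the io thread.
  void Lookup(std::string node_id,
              std::function<void(const std::string *)> cb) {
    boost::asio::post(io_, [this, node_id = std::move(node_id),
                            cb = std::move(cb)]() {
      auto it = cache_.find(node_id);
      cb(it == cache_.end() ? nullptr : &it->second);
    });
  }

  // Must only be called from the io thread (e.g. from a GCS subscription).
  void Update(std::string node_id, std::string address) {
    cache_[std::move(node_id)] = std::move(address);
  }

 private:
  boost::asio::io_context &io_;
  std::unordered_map<std::string, std::string> cache_;  // io thread only.
};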
+ while (io_context.poll_one()) { + } + ASSERT_EQ(raylet_client->num_cancel_local_task_requested, 0); // Completing the GetWorkerFailureCause call. Check that the reply runs without error // and FailPendingTask is not called. ASSERT_TRUE(raylet_client->ReplyGetWorkerFailureCause()); @@ -1499,7 +1501,7 @@ void TestSchedulingKey(const std::shared_ptr store, JobID::Nil(), std::make_shared(1), [](const ObjectID &object_id) { return rpc::TensorTransport::OBJECT_STORE; }, - boost::asio::steady_timer(io_context), + io_context, fake_scheduler_placement_time_ms_histogram_); submitter.SubmitTask(same1); @@ -1775,7 +1777,7 @@ TEST_F(NormalTaskSubmitterTest, TestKillExecutingTask) { // Try force kill, exiting the worker submitter.CancelTask(task, true, false); - RAY_CHECK(io_context.poll_one()); + ASSERT_TRUE(io_context.poll_one()); ASSERT_EQ(raylet_client->cancel_local_task_requests.front().intended_task_id(), task.TaskIdBinary()); ASSERT_TRUE(worker_client->ReplyPushTask(Status::IOError("workerdying"), true)); @@ -1794,7 +1796,7 @@ TEST_F(NormalTaskSubmitterTest, TestKillExecutingTask) { // Try non-force kill, worker returns normally submitter.CancelTask(task, false, false); - RAY_CHECK(io_context.poll_one()); + ASSERT_TRUE(io_context.poll_one()); ASSERT_TRUE(worker_client->ReplyPushTask()); ASSERT_EQ(raylet_client->cancel_local_task_requests.back().intended_task_id(), task.TaskIdBinary()); @@ -1818,8 +1820,10 @@ TEST_F(NormalTaskSubmitterTest, TestKillPendingTask) { submitter.SubmitTask(task); submitter.CancelTask(task, true, false); // We haven't been granted a worker lease yet, so the task is not executing. - // So we don't need to trigger CancelLocalTask RPC. - RAY_CHECK(!io_context.poll_one()); + // So we shouldn't have triggered CancelLocalTask RPC. + while (io_context.poll_one()) { + } + ASSERT_EQ(raylet_client->num_cancel_local_task_requested, 0); ASSERT_EQ(worker_client->kill_requests.size(), 0); ASSERT_EQ(worker_client->callbacks.size(), 0); ASSERT_EQ(raylet_client->num_workers_returned, 0); @@ -1849,8 +1853,10 @@ TEST_F(NormalTaskSubmitterTest, TestKillResolvingTask) { ASSERT_EQ(task_manager->num_inlined_dependencies, 0); submitter.CancelTask(task, true, false); // We haven't been granted a worker lease yet, so the task is not executing. - // So we don't need to trigger CancelLocalTask RPC. - RAY_CHECK(!io_context.poll_one()); + // So we shouldn't have triggered CancelLocalTask RPC. 
+ while (io_context.poll_one()) { + } + ASSERT_EQ(raylet_client->num_cancel_local_task_requested, 0); auto data = GenerateRandomObject(); store->Put(*data, obj1, /*has_reference=*/true); WaitForObjectIdInMemoryStore(*store, obj1); @@ -1901,7 +1907,7 @@ TEST_F(NormalTaskSubmitterTest, TestCancelBeforeAfterQueueGeneratorForResubmit) submitter.SubmitTask(task); ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1234, local_node_id)); submitter.CancelTask(task, /*force_kill=*/false, /*recursive=*/true); - RAY_CHECK(io_context.poll_one()); + ASSERT_TRUE(io_context.poll_one()); ASSERT_FALSE(submitter.QueueGeneratorForResubmit(task)); raylet_client->ReplyCancelLocalTask(); ASSERT_TRUE(submitter.QueueGeneratorForResubmit(task)); @@ -1920,7 +1926,7 @@ TEST_F(NormalTaskSubmitterTest, TestCancelBeforeAfterQueueGeneratorForResubmit) ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1234, local_node_id)); ASSERT_TRUE(submitter.QueueGeneratorForResubmit(task2)); submitter.CancelTask(task2, /*force_kill=*/false, /*recursive=*/true); - RAY_CHECK(io_context.poll_one()); + ASSERT_TRUE(io_context.poll_one()); ASSERT_TRUE(worker_client->ReplyPushTask()); raylet_client->ReplyCancelLocalTask(Status::OK(), /*attempt_succeeded=*/true, diff --git a/src/ray/core_worker/tests/core_worker_test.cc b/src/ray/core_worker/tests/core_worker_test.cc index b67b8286b743..31b5c2cc6f17 100644 --- a/src/ray/core_worker/tests/core_worker_test.cc +++ b/src/ray/core_worker/tests/core_worker_test.cc @@ -227,7 +227,7 @@ class CoreWorkerTest : public ::testing::Test { JobID::Nil(), lease_request_rate_limiter, [](const ObjectID &object_id) { return rpc::TensorTransport::OBJECT_STORE; }, - boost::asio::steady_timer(io_service_), + io_service_, fake_scheduler_placement_time_ms_histogram_); auto actor_task_submitter = std::make_unique( diff --git a/src/ray/raylet_rpc_client/fake_raylet_client.h b/src/ray/raylet_rpc_client/fake_raylet_client.h index b08ed3ce90d8..37645bd7aa62 100644 --- a/src/ray/raylet_rpc_client/fake_raylet_client.h +++ b/src/ray/raylet_rpc_client/fake_raylet_client.h @@ -291,7 +291,9 @@ class FakeRayletClient : public RayletClientInterface { int64_t GetPinsInFlight() const override { return 0; } void CancelLocalTask(const CancelLocalTaskRequest &request, - const ClientCallback &callback) override {} + const ClientCallback &callback) override { + num_cancel_local_task_requested += 1; + } int num_workers_requested = 0; int num_workers_returned = 0; @@ -299,17 +301,17 @@ class FakeRayletClient : public RayletClientInterface { int num_leases_canceled = 0; int num_release_unused_workers = 0; int num_get_task_failure_causes = 0; + int num_lease_requested = 0; + int num_return_requested = 0; + int num_commit_requested = 0; + int num_cancel_local_task_requested = 0; + int num_release_unused_bundles_requested = 0; NodeID node_id_ = NodeID::FromRandom(); std::vector killed_actors; std::list> drain_raylet_callbacks = {}; std::list> callbacks = {}; std::list> cancel_callbacks = {}; std::list> release_callbacks = {}; - int num_lease_requested = 0; - int num_return_requested = 0; - int num_commit_requested = 0; - - int num_release_unused_bundles_requested = 0; std::list> lease_callbacks = {}; std::list> commit_callbacks = {}; std::list> return_callbacks = {};
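A closing note on the test plumbing this series relies on: the fake raylet client counts each `CancelLocalTask` request and parks the reply so the test chooses when it completes, while posted work is driven explicitly with `io_context.poll_one()`. A compact sketch of both halves under those assumptions (names are illustrative, not Ray's):

#include <boost/asio.hpp>
#include <cassert>
#include <functional>

// Counting fake in the style of FakeRayletClient: record the request, defer
// the reply so the test controls when the callback runs.
struct FakeClient {
  int num_cancel_requested = 0;
  std::function<void()> pending_reply;

  void CancelLocalTask(std::function<void()> on_reply) {
    num_cancel_requested += 1;
    pending_reply = std::move(on_reply);
  }
};

int main() {
  boost::asio::io_context io;
  FakeClient client;
  bool replied = false;

  // Submitter-style code posts the RPC to the io_context instead of calling
  // it inline.
  boost::asio::post(io, [&] {
    client.CancelLocalTask([&] { replied = true; });
  });

  // Nothing ran yet: posted work waits for an explicit poll, which is what
  // makes ASSERT_TRUE(io_context.poll_one()) deterministic in the tests.
  assert(client.num_cancel_requested == 0);
  assert(io.poll_one() == 1);  // Exactly one handler executed.
  assert(client.num_cancel_requested == 1);

  client.pending_reply();  // The test decides when the raylet "replies".
  assert(replied);

  while (io.poll_one()) {
  }  // Drain pattern used when asserting no RPC was triggered.
  return 0;
}

Draining with `while (poll_one()) {}` before asserting the counter is zero is what lets the negative tests above prove the RPC was never posted, not merely not yet run.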