Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/ray/tests/test_core_worker_fault_tolerance.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def inject_cancel_remote_task_rpc_failure(monkeypatch, request):
deterministic_failure = request.param
monkeypatch.setenv(
"RAY_testing_rpc_failure",
"CoreWorkerService.grpc_client.CancelRemoteTask=1:"
"CoreWorkerService.grpc_client.RequestOwnerToCancelTask=1:"
+ ("100:0" if deterministic_failure == "request" else "0:100"),
)

Expand Down
56 changes: 55 additions & 1 deletion python/ray/tests/test_raylet_fault_tolerance.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
import pytest

import ray
from ray._private.test_utils import wait_for_condition
from ray._common.test_utils import SignalActor, wait_for_condition
from ray.core.generated import autoscaler_pb2
from ray.exceptions import GetTimeoutError, TaskCancelledError
from ray.util.placement_group import placement_group, remove_placement_group
from ray.util.scheduling_strategies import (
NodeAffinitySchedulingStrategy,
Expand Down Expand Up @@ -180,5 +181,58 @@ def verify_process_killed():
wait_for_condition(verify_process_killed, timeout=30)


@pytest.fixture
def inject_cancel_local_task_rpc_failure(monkeypatch, request):
    """Enable RPC failure injection for the raylet CancelLocalTask client call.

    The fixture is parametrized indirectly: ``request.param == "request"``
    selects the ``100:0`` setting, any other value selects ``0:100``. The
    resulting spec is exported through the ``RAY_testing_rpc_failure``
    environment variable for the duration of the test.

    NOTE(review): presumably the fields after ``=`` are
    ``max_failures:request_failure_pct:response_failure_pct`` -- confirm
    against the code that parses ``RAY_testing_rpc_failure``.
    """
    if request.param == "request":
        failure_probabilities = "100:0"
    else:
        failure_probabilities = "0:100"

    monkeypatch.setenv(
        "RAY_testing_rpc_failure",
        "NodeManagerService.grpc_client.CancelLocalTask=1:" + failure_probabilities,
    )


@pytest.mark.parametrize(
    "inject_cancel_local_task_rpc_failure",
    ["request", "response"],
    indirect=True,
)
@pytest.mark.parametrize("force_kill", [True, False])
def test_cancel_local_task_rpc_retry_and_idempotency(
    inject_cancel_local_task_rpc_failure, force_kill, shutdown_only
):
    """Cancel a blocked task while CancelLocalTask RPC failures are injected.

    A task that waits on a never-sent signal is started, confirmed to still be
    pending, and then cancelled with ``force=force_kill``. Fetching its result
    afterwards must raise ``TaskCancelledError``. When ``force_kill`` is true
    the test additionally waits for the recorded worker process to disappear.
    """
    ray.init(num_cpus=2)
    signal_actor = SignalActor.remote()

    @ray.remote(num_cpus=1)
    def get_pid():
        return os.getpid()

    @ray.remote(num_cpus=1)
    def blocking_task():
        # Nothing in this test sends the signal, so this call only returns
        # if the task is cancelled.
        return ray.get(signal_actor.wait.remote())

    # NOTE(review): this PID comes from a separate task; the kill check at the
    # end assumes blocking_task runs on the same worker process -- confirm.
    target_pid = ray.get(get_pid.remote())

    pending_ref = blocking_task.remote()

    # The task must still be pending before the cancel is issued.
    with pytest.raises(GetTimeoutError):
        ray.get(pending_ref, timeout=1)

    ray.cancel(pending_ref, force=force_kill)

    with pytest.raises(TaskCancelledError):
        ray.get(pending_ref, timeout=10)

    if not force_kill:
        return

    # A forced cancel is expected to terminate the worker process itself.
    wait_for_condition(lambda: not psutil.pid_exists(target_pid), timeout=30)


if __name__ == "__main__":
    # Running the module directly executes its tests through pytest and
    # uses pytest's return value as the process exit status.
    exit_code = pytest.main(["-sv", __file__])
    sys.exit(exit_code)
6 changes: 3 additions & 3 deletions src/mock/ray/core_worker/core_worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,9 @@ class MockCoreWorker : public CoreWorker {
rpc::SendReplyCallback send_reply_callback),
(override));
MOCK_METHOD(void,
HandleCancelRemoteTask,
(rpc::CancelRemoteTaskRequest request,
rpc::CancelRemoteTaskReply *reply,
HandleRequestOwnerToCancelTask,
(rpc::RequestOwnerToCancelTaskRequest request,
rpc::RequestOwnerToCancelTaskReply *reply,
rpc::SendReplyCallback send_reply_callback),
(override));
MOCK_METHOD(void,
Expand Down
5 changes: 5 additions & 0 deletions src/mock/ray/raylet_client/raylet_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,11 @@ class MockRayletClientInterface : public RayletClientInterface {
(const rpc::ClientCallback<rpc::GlobalGCReply> &callback),
(override));
MOCK_METHOD(int64_t, GetPinsInFlight, (), (const, override));
MOCK_METHOD(void,
CancelLocalTask,
(const rpc::CancelLocalTaskRequest &request,
const rpc::ClientCallback<rpc::CancelLocalTaskReply> &callback),
(override));
};

} // namespace ray
6 changes: 3 additions & 3 deletions src/mock/ray/rpc/worker/core_worker_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,9 @@ class MockCoreWorkerClientInterface : public CoreWorkerClientInterface {
const ClientCallback<CancelTaskReply> &callback),
(override));
MOCK_METHOD(void,
CancelRemoteTask,
(CancelRemoteTaskRequest && request,
const ClientCallback<CancelRemoteTaskReply> &callback),
RequestOwnerToCancelTask,
(RequestOwnerToCancelTaskRequest && request,
const ClientCallback<RequestOwnerToCancelTaskReply> &callback),
(override));
MOCK_METHOD(void,
GetCoreWorkerStats,
Expand Down
12 changes: 7 additions & 5 deletions src/ray/core_worker/core_worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2439,12 +2439,13 @@ Status CoreWorker::CancelTask(const ObjectID &object_id,
}

if (obj_addr.SerializeAsString() != rpc_address_.SerializeAsString()) {
// We don't have CancelRemoteTask for actor_task_submitter_
// We don't have RequestOwnerToCancelTask for actor_task_submitter_
// because it requires the same implementation.
RAY_LOG(DEBUG).WithField(object_id)
<< "Request to cancel a task of object to an owner "
<< obj_addr.SerializeAsString();
normal_task_submitter_->CancelRemoteTask(object_id, obj_addr, force_kill, recursive);
normal_task_submitter_->RequestOwnerToCancelTask(
object_id, obj_addr, force_kill, recursive);
return Status::OK();
}

Expand Down Expand Up @@ -3878,9 +3879,10 @@ void CoreWorker::ProcessSubscribeForRefRemoved(
reference_counter_->SubscribeRefRemoved(object_id, contained_in_id, owner_address);
}

void CoreWorker::HandleCancelRemoteTask(rpc::CancelRemoteTaskRequest request,
rpc::CancelRemoteTaskReply *reply,
rpc::SendReplyCallback send_reply_callback) {
void CoreWorker::HandleRequestOwnerToCancelTask(
rpc::RequestOwnerToCancelTaskRequest request,
rpc::RequestOwnerToCancelTaskReply *reply,
rpc::SendReplyCallback send_reply_callback) {
auto status = CancelTask(ObjectID::FromBinary(request.remote_object_id()),
request.force_kill(),
request.recursive());
Expand Down
6 changes: 3 additions & 3 deletions src/ray/core_worker/core_worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -1208,9 +1208,9 @@ class CoreWorker {
rpc::SendReplyCallback send_reply_callback);

/// Implements gRPC server handler.
void HandleCancelRemoteTask(rpc::CancelRemoteTaskRequest request,
rpc::CancelRemoteTaskReply *reply,
rpc::SendReplyCallback send_reply_callback);
void HandleRequestOwnerToCancelTask(rpc::RequestOwnerToCancelTaskRequest request,
rpc::RequestOwnerToCancelTaskReply *reply,
rpc::SendReplyCallback send_reply_callback);

/// Implements gRPC server handler.
void HandlePlasmaObjectReady(rpc::PlasmaObjectReadyRequest request,
Expand Down
3 changes: 3 additions & 0 deletions src/ray/core_worker/core_worker_process.cc
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,8 @@ std::shared_ptr<CoreWorker> CoreWorkerProcessImpl::CreateCoreWorker(

auto actor_task_submitter = std::make_unique<ActorTaskSubmitter>(
*core_worker_client_pool,
*raylet_client_pool,
gcs_client,
*memory_store,
*task_manager,
*actor_creator,
Expand Down Expand Up @@ -535,6 +537,7 @@ std::shared_ptr<CoreWorker> CoreWorkerProcessImpl::CreateCoreWorker(
local_raylet_rpc_client,
core_worker_client_pool,
raylet_client_pool,
gcs_client,
std::move(lease_policy),
memory_store,
*task_manager,
Expand Down
2 changes: 1 addition & 1 deletion src/ray/core_worker/core_worker_rpc_proxy.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class CoreWorkerServiceHandlerProxy : public rpc::CoreWorkerServiceHandler {
RAY_CORE_WORKER_RPC_PROXY(ReportGeneratorItemReturns)
RAY_CORE_WORKER_RPC_PROXY(KillActor)
RAY_CORE_WORKER_RPC_PROXY(CancelTask)
RAY_CORE_WORKER_RPC_PROXY(CancelRemoteTask)
RAY_CORE_WORKER_RPC_PROXY(RequestOwnerToCancelTask)
RAY_CORE_WORKER_RPC_PROXY(RegisterMutableObjectReader)
RAY_CORE_WORKER_RPC_PROXY(GetCoreWorkerStats)
RAY_CORE_WORKER_RPC_PROXY(LocalGC)
Expand Down
2 changes: 1 addition & 1 deletion src/ray/core_worker/grpc_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ void CoreWorkerGrpcService::InitServerCallFactories(
max_active_rpcs_per_handler_,
ClusterIdAuthType::NO_AUTH);
RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService,
CancelRemoteTask,
RequestOwnerToCancelTask,
max_active_rpcs_per_handler_,
ClusterIdAuthType::NO_AUTH);
RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService,
Expand Down
6 changes: 3 additions & 3 deletions src/ray/core_worker/grpc_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,9 @@ class CoreWorkerServiceHandler : public DelayedServiceHandler {
CancelTaskReply *reply,
SendReplyCallback send_reply_callback) = 0;

virtual void HandleCancelRemoteTask(CancelRemoteTaskRequest request,
CancelRemoteTaskReply *reply,
SendReplyCallback send_reply_callback) = 0;
virtual void HandleRequestOwnerToCancelTask(RequestOwnerToCancelTaskRequest request,
RequestOwnerToCancelTaskReply *reply,
SendReplyCallback send_reply_callback) = 0;

virtual void HandleRegisterMutableObjectReader(
RegisterMutableObjectReaderRequest request,
Expand Down
5 changes: 5 additions & 0 deletions src/ray/core_worker/task_submission/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ ray_cc_library(
"//src/ray/common:protobuf_utils",
"//src/ray/core_worker:actor_creator",
"//src/ray/core_worker_rpc_client:core_worker_client_pool",
"//src/ray/gcs_rpc_client:gcs_client",
"//src/ray/raylet_rpc_client:raylet_client_interface",
"//src/ray/raylet_rpc_client:raylet_client_pool",
"//src/ray/rpc:rpc_callback_types",
"//src/ray/util:time",
"@com_google_absl//absl/base:core_headers",
Expand All @@ -96,7 +99,9 @@ ray_cc_library(
"//src/ray/core_worker:memory_store",
"//src/ray/core_worker:task_manager_interface",
"//src/ray/core_worker_rpc_client:core_worker_client_pool",
"//src/ray/gcs_rpc_client:gcs_client",
"//src/ray/raylet_rpc_client:raylet_client_interface",
"//src/ray/raylet_rpc_client:raylet_client_pool",
"//src/ray/util:time",
"@com_google_absl//absl/base:core_headers",
],
Expand Down
92 changes: 70 additions & 22 deletions src/ray/core_worker/task_submission/actor_task_submitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -998,6 +998,7 @@ void ActorTaskSubmitter::CancelTask(TaskSpecification task_spec, bool recursive)

// If there's no client, it means actor is not created yet.
// Retry in 1 second.
rpc::Address client_address;
{
absl::MutexLock lock(&mu_);
RAY_LOG(DEBUG).WithField(task_id) << "Task was sent to an actor. Send a cancel RPC.";
Expand All @@ -1007,34 +1008,81 @@ void ActorTaskSubmitter::CancelTask(TaskSpecification task_spec, bool recursive)
RetryCancelTask(task_spec, recursive, 1000);
return;
}
client_address = queue->second.client_address_.value();
}

rpc::CancelTaskRequest request;
const auto node_id = NodeID::FromBinary(client_address.node_id());
const auto executor_worker_id = client_address.worker_id();

auto do_cancel_local_task = [this,
task_spec = std::move(task_spec),
task_id,
force_kill,
recursive,
executor_worker_id = std::move(executor_worker_id)](
const rpc::GcsNodeInfo &node_info) mutable {
rpc::Address raylet_address;
raylet_address.set_node_id(node_info.node_id());
raylet_address.set_ip_address(node_info.node_manager_address());
raylet_address.set_port(node_info.node_manager_port());

rpc::CancelLocalTaskRequest request;
request.set_intended_task_id(task_spec.TaskIdBinary());
request.set_force_kill(force_kill);
request.set_recursive(recursive);
request.set_caller_worker_id(task_spec.CallerWorkerIdBinary());
auto client = core_worker_client_pool_.GetOrConnect(*queue->second.client_address_);
client->CancelTask(request,
[this, task_spec = std::move(task_spec), recursive, task_id](
const Status &status, const rpc::CancelTaskReply &reply) {
RAY_LOG(DEBUG).WithField(task_spec.TaskId())
<< "CancelTask RPC response received with status "
<< status.ToString();

// Keep retrying every 2 seconds until a task is officially
// finished.
if (!task_manager_.GetTaskSpec(task_id)) {
// Task is already finished.
RAY_LOG(DEBUG).WithField(task_spec.TaskId())
<< "Task is finished. Stop a cancel request.";
return;
}

if (!reply.attempt_succeeded()) {
RetryCancelTask(task_spec, recursive, 2000);
}
});
request.set_executor_worker_id(executor_worker_id);

auto raylet_client = raylet_client_pool_.GetOrConnectByAddress(raylet_address);
raylet_client->CancelLocalTask(
request,
[this, task_spec = std::move(task_spec), recursive, task_id](
const Status &status, const rpc::CancelLocalTaskReply &reply) mutable {
RAY_LOG(DEBUG).WithField(task_spec.TaskId())
<< "CancelTask RPC response received with status " << status.ToString();

// Keep retrying every 2 seconds until a task is officially
// finished.
if (!task_manager_.GetTaskSpec(task_id)) {
// Task is already finished.
RAY_LOG(DEBUG).WithField(task_spec.TaskId())
<< "Task is finished. Stop a cancel request.";
return;
}

if (!reply.attempt_succeeded()) {
RetryCancelTask(std::move(task_spec), recursive, 2000);
}
});
};

// Check GCS node cache. If node info is not in the cache, query the GCS instead.
auto *node_info = gcs_client_->Nodes().Get(node_id, /*filter_dead_nodes=*/false);
if (node_info == nullptr) {
gcs_client_->Nodes().AsyncGetAll(
[do_cancel_local_task = std::move(do_cancel_local_task), node_id](
const Status &status, std::vector<rpc::GcsNodeInfo> &&nodes) mutable {
if (!status.ok()) {
RAY_LOG(INFO) << "Failed to get node info from GCS";
return;
}
if (nodes.empty() || nodes[0].state() != rpc::GcsNodeInfo::ALIVE) {
RAY_LOG(INFO).WithField(node_id)
<< "Not sending CancelLocalTask because node is dead";
return;
}
do_cancel_local_task(nodes[0]);
},
-1,
{node_id});
return;
}
if (node_info->state() == rpc::GcsNodeInfo::DEAD) {
RAY_LOG(INFO).WithField(node_id)
<< "Not sending CancelLocalTask because node is dead";
return;
}
do_cancel_local_task(*node_info);
}

bool ActorTaskSubmitter::QueueGeneratorForResubmit(const TaskSpecification &spec) {
Expand Down
11 changes: 10 additions & 1 deletion src/ray/core_worker/task_submission/actor_task_submitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ class ActorTaskSubmitterInterface {
class ActorTaskSubmitter : public ActorTaskSubmitterInterface {
public:
ActorTaskSubmitter(rpc::CoreWorkerClientPool &core_worker_client_pool,
rpc::RayletClientPool &raylet_client_pool,
std::shared_ptr<gcs::GcsClient> gcs_client,
CoreWorkerMemoryStore &store,
TaskManagerInterface &task_manager,
ActorCreatorInterface &actor_creator,
Expand All @@ -76,6 +78,8 @@ class ActorTaskSubmitter : public ActorTaskSubmitterInterface {
instrumented_io_context &io_service,
std::shared_ptr<ReferenceCounterInterface> reference_counter)
: core_worker_client_pool_(core_worker_client_pool),
raylet_client_pool_(raylet_client_pool),
gcs_client_(std::move(gcs_client)),
actor_creator_(actor_creator),
resolver_(store, task_manager, actor_creator, tensor_transport_getter),
task_manager_(task_manager),
Expand Down Expand Up @@ -300,7 +304,7 @@ class ActorTaskSubmitter : public ActorTaskSubmitterInterface {
int64_t num_restarts_due_to_lineage_reconstructions_ = 0;
/// Whether this actor exits by spot preemption.
bool preempted_ = false;
/// The RPC client address.
/// The RPC client address of the actor.
std::optional<rpc::Address> client_address_;
/// The intended worker ID of the actor.
std::string worker_id_;
Expand Down Expand Up @@ -411,6 +415,11 @@ class ActorTaskSubmitter : public ActorTaskSubmitterInterface {
/// Pool for producing new core worker clients.
rpc::CoreWorkerClientPool &core_worker_client_pool_;

/// Pool for producing new raylet clients.
rpc::RayletClientPool &raylet_client_pool_;

std::shared_ptr<gcs::GcsClient> gcs_client_;

ActorCreatorInterface &actor_creator_;

/// Mutex to protect the various maps below.
Expand Down
Loading