@@ -822,6 +822,12 @@ class BackendServiceImpl final : public backend::Backend::Service {
822822 }
823823
824824 ctx_server.receive_cmpl_results_stream (task_ids, [&](server_task_result_ptr & result) -> bool {
825+ // Check if context is cancelled before processing result
826+ if (context->IsCancelled ()) {
827+ ctx_server.cancel_tasks (task_ids);
828+ return false ;
829+ }
830+
825831 json res_json = result->to_json ();
826832 if (res_json.is_array ()) {
827833 for (const auto & res : res_json) {
@@ -875,13 +881,18 @@ class BackendServiceImpl final : public backend::Backend::Service {
875881 reply.set_message (error_data.value (" content" , " " ));
876882 writer->Write (reply);
877883 return true ;
878- }, [&]() {
879- // NOTE: we should try to check when the writer is closed here
880- return false ;
884+ }, [&context ]() {
885+ // Check if the gRPC context is cancelled
886+ return context-> IsCancelled () ;
881887 });
882888
883889 ctx_server.queue_results .remove_waiting_task_ids (task_ids);
884890
891+ // Check if context was cancelled during processing
892+ if (context->IsCancelled ()) {
893+ return grpc::Status (grpc::StatusCode::CANCELLED, " Request cancelled by client" );
894+ }
895+
885896 return grpc::Status::OK;
886897 }
887898
@@ -1145,6 +1156,14 @@ class BackendServiceImpl final : public backend::Backend::Service {
11451156
11461157
11471158 std::cout << " [DEBUG] Waiting for results..." << std::endl;
1159+
1160+ // Check cancellation before waiting for results
1161+ if (context->IsCancelled ()) {
1162+ ctx_server.cancel_tasks (task_ids);
1163+ ctx_server.queue_results .remove_waiting_task_ids (task_ids);
1164+ return grpc::Status (grpc::StatusCode::CANCELLED, " Request cancelled by client" );
1165+ }
1166+
11481167 ctx_server.receive_multi_results (task_ids, [&](std::vector<server_task_result_ptr> & results) {
11491168 std::cout << " [DEBUG] Received " << results.size () << " results" << std::endl;
11501169 if (results.size () == 1 ) {
@@ -1176,13 +1195,20 @@ class BackendServiceImpl final : public backend::Backend::Service {
11761195 }, [&](const json & error_data) {
11771196 std::cout << " [DEBUG] Error in results: " << error_data.value (" content" , " " ) << std::endl;
11781197 reply->set_message (error_data.value (" content" , " " ));
1179- }, [&]() {
1180- return false ;
1198+ }, [&context]() {
1199+ // Check if the gRPC context is cancelled
1200+ // This is checked every HTTP_POLLING_SECONDS (1 second) during receive_multi_results
1201+ return context->IsCancelled ();
11811202 });
11821203
11831204 ctx_server.queue_results .remove_waiting_task_ids (task_ids);
11841205 std::cout << " [DEBUG] Predict request completed successfully" << std::endl;
11851206
1207+ // Check if context was cancelled during processing
1208+ if (context->IsCancelled ()) {
1209+ return grpc::Status (grpc::StatusCode::CANCELLED, " Request cancelled by client" );
1210+ }
1211+
11861212 return grpc::Status::OK;
11871213 }
11881214
@@ -1234,6 +1260,13 @@ class BackendServiceImpl final : public backend::Backend::Service {
12341260 ctx_server.queue_tasks .post (std::move (tasks));
12351261 }
12361262
1263+ // Check cancellation before waiting for results
1264+ if (context->IsCancelled ()) {
1265+ ctx_server.cancel_tasks (task_ids);
1266+ ctx_server.queue_results .remove_waiting_task_ids (task_ids);
1267+ return grpc::Status (grpc::StatusCode::CANCELLED, " Request cancelled by client" );
1268+ }
1269+
12371270 // get the result
12381271 ctx_server.receive_multi_results (task_ids, [&](std::vector<server_task_result_ptr> & results) {
12391272 for (auto & res : results) {
@@ -1242,12 +1275,18 @@ class BackendServiceImpl final : public backend::Backend::Service {
12421275 }
12431276 }, [&](const json & error_data) {
12441277 error = true ;
1245- }, [&]() {
1246- return false ;
1278+ }, [&context]() {
1279+ // Check if the gRPC context is cancelled
1280+ return context->IsCancelled ();
12471281 });
12481282
12491283 ctx_server.queue_results .remove_waiting_task_ids (task_ids);
12501284
1285+ // Check if context was cancelled during processing
1286+ if (context->IsCancelled ()) {
1287+ return grpc::Status (grpc::StatusCode::CANCELLED, " Request cancelled by client" );
1288+ }
1289+
12511290 if (error) {
12521291 return grpc::Status (grpc::StatusCode::INTERNAL, " Error in receiving results" );
12531292 }
@@ -1325,6 +1364,13 @@ class BackendServiceImpl final : public backend::Backend::Service {
13251364 ctx_server.queue_tasks .post (std::move (tasks));
13261365 }
13271366
1367+ // Check cancellation before waiting for results
1368+ if (context->IsCancelled ()) {
1369+ ctx_server.cancel_tasks (task_ids);
1370+ ctx_server.queue_results .remove_waiting_task_ids (task_ids);
1371+ return grpc::Status (grpc::StatusCode::CANCELLED, " Request cancelled by client" );
1372+ }
1373+
13281374 // Get the results
13291375 ctx_server.receive_multi_results (task_ids, [&](std::vector<server_task_result_ptr> & results) {
13301376 for (auto & res : results) {
@@ -1333,12 +1379,18 @@ class BackendServiceImpl final : public backend::Backend::Service {
13331379 }
13341380 }, [&](const json & error_data) {
13351381 error = true ;
1336- }, [&]() {
1337- return false ;
1382+ }, [&context]() {
1383+ // Check if the gRPC context is cancelled
1384+ return context->IsCancelled ();
13381385 });
13391386
13401387 ctx_server.queue_results .remove_waiting_task_ids (task_ids);
13411388
1389+ // Check if context was cancelled during processing
1390+ if (context->IsCancelled ()) {
1391+ return grpc::Status (grpc::StatusCode::CANCELLED, " Request cancelled by client" );
1392+ }
1393+
13421394 if (error) {
13431395 return grpc::Status (grpc::StatusCode::INTERNAL, " Error in receiving results" );
13441396 }
0 commit comments