Skip to content

Commit 679d43c

Browse files
authored
feat: respect context and add request cancellation (#7187)
* feat: respect context Signed-off-by: Ettore Di Giacinto <[email protected]> * workaround fasthttp Signed-off-by: Ettore Di Giacinto <[email protected]> * feat(ui): allow to abort call Signed-off-by: Ettore Di Giacinto <[email protected]> * Refactor Signed-off-by: Ettore Di Giacinto <[email protected]> * chore: improving error Signed-off-by: Ettore Di Giacinto <[email protected]> * Respect context also with MCP Signed-off-by: Ettore Di Giacinto <[email protected]> * Tie to both contexts Signed-off-by: Ettore Di Giacinto <[email protected]> * Make detection more robust Signed-off-by: Ettore Di Giacinto <[email protected]> --------- Signed-off-by: Ettore Di Giacinto <[email protected]>
1 parent 4730b52 commit 679d43c

File tree

8 files changed

+240
-42
lines changed

8 files changed

+240
-42
lines changed

backend/cpp/llama-cpp/grpc-server.cpp

Lines changed: 61 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -822,6 +822,12 @@ class BackendServiceImpl final : public backend::Backend::Service {
822822
}
823823

824824
ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_ptr & result) -> bool {
825+
// Check if context is cancelled before processing result
826+
if (context->IsCancelled()) {
827+
ctx_server.cancel_tasks(task_ids);
828+
return false;
829+
}
830+
825831
json res_json = result->to_json();
826832
if (res_json.is_array()) {
827833
for (const auto & res : res_json) {
@@ -875,13 +881,18 @@ class BackendServiceImpl final : public backend::Backend::Service {
875881
reply.set_message(error_data.value("content", ""));
876882
writer->Write(reply);
877883
return true;
878-
}, [&]() {
879-
// NOTE: we should try to check when the writer is closed here
880-
return false;
884+
}, [&context]() {
885+
// Check if the gRPC context is cancelled
886+
return context->IsCancelled();
881887
});
882888

883889
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
884890

891+
// Check if context was cancelled during processing
892+
if (context->IsCancelled()) {
893+
return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
894+
}
895+
885896
return grpc::Status::OK;
886897
}
887898

@@ -1145,6 +1156,14 @@ class BackendServiceImpl final : public backend::Backend::Service {
11451156

11461157

11471158
std::cout << "[DEBUG] Waiting for results..." << std::endl;
1159+
1160+
// Check cancellation before waiting for results
1161+
if (context->IsCancelled()) {
1162+
ctx_server.cancel_tasks(task_ids);
1163+
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
1164+
return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
1165+
}
1166+
11481167
ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
11491168
std::cout << "[DEBUG] Received " << results.size() << " results" << std::endl;
11501169
if (results.size() == 1) {
@@ -1176,13 +1195,20 @@ class BackendServiceImpl final : public backend::Backend::Service {
11761195
}, [&](const json & error_data) {
11771196
std::cout << "[DEBUG] Error in results: " << error_data.value("content", "") << std::endl;
11781197
reply->set_message(error_data.value("content", ""));
1179-
}, [&]() {
1180-
return false;
1198+
}, [&context]() {
1199+
// Check if the gRPC context is cancelled
1200+
// This is checked every HTTP_POLLING_SECONDS (1 second) during receive_multi_results
1201+
return context->IsCancelled();
11811202
});
11821203

11831204
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
11841205
std::cout << "[DEBUG] Predict request completed successfully" << std::endl;
11851206

1207+
// Check if context was cancelled during processing
1208+
if (context->IsCancelled()) {
1209+
return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
1210+
}
1211+
11861212
return grpc::Status::OK;
11871213
}
11881214

@@ -1234,6 +1260,13 @@ class BackendServiceImpl final : public backend::Backend::Service {
12341260
ctx_server.queue_tasks.post(std::move(tasks));
12351261
}
12361262

1263+
// Check cancellation before waiting for results
1264+
if (context->IsCancelled()) {
1265+
ctx_server.cancel_tasks(task_ids);
1266+
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
1267+
return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
1268+
}
1269+
12371270
// get the result
12381271
ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
12391272
for (auto & res : results) {
@@ -1242,12 +1275,18 @@ class BackendServiceImpl final : public backend::Backend::Service {
12421275
}
12431276
}, [&](const json & error_data) {
12441277
error = true;
1245-
}, [&]() {
1246-
return false;
1278+
}, [&context]() {
1279+
// Check if the gRPC context is cancelled
1280+
return context->IsCancelled();
12471281
});
12481282

12491283
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
12501284

1285+
// Check if context was cancelled during processing
1286+
if (context->IsCancelled()) {
1287+
return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
1288+
}
1289+
12511290
if (error) {
12521291
return grpc::Status(grpc::StatusCode::INTERNAL, "Error in receiving results");
12531292
}
@@ -1325,6 +1364,13 @@ class BackendServiceImpl final : public backend::Backend::Service {
13251364
ctx_server.queue_tasks.post(std::move(tasks));
13261365
}
13271366

1367+
// Check cancellation before waiting for results
1368+
if (context->IsCancelled()) {
1369+
ctx_server.cancel_tasks(task_ids);
1370+
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
1371+
return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
1372+
}
1373+
13281374
// Get the results
13291375
ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
13301376
for (auto & res : results) {
@@ -1333,12 +1379,18 @@ class BackendServiceImpl final : public backend::Backend::Service {
13331379
}
13341380
}, [&](const json & error_data) {
13351381
error = true;
1336-
}, [&]() {
1337-
return false;
1382+
}, [&context]() {
1383+
// Check if the gRPC context is cancelled
1384+
return context->IsCancelled();
13381385
});
13391386

13401387
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
13411388

1389+
// Check if context was cancelled during processing
1390+
if (context->IsCancelled()) {
1391+
return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
1392+
}
1393+
13421394
if (error) {
13431395
return grpc::Status(grpc::StatusCode::INTERNAL, "Error in receiving results");
13441396
}

core/config/model_config.go

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -93,19 +93,18 @@ type AgentConfig struct {
9393
EnablePlanReEvaluator bool `yaml:"enable_plan_re_evaluator" json:"enable_plan_re_evaluator"`
9494
}
9595

96-
func (c *MCPConfig) MCPConfigFromYAML() (MCPGenericConfig[MCPRemoteServers], MCPGenericConfig[MCPSTDIOServers]) {
96+
func (c *MCPConfig) MCPConfigFromYAML() (MCPGenericConfig[MCPRemoteServers], MCPGenericConfig[MCPSTDIOServers], error) {
9797
var remote MCPGenericConfig[MCPRemoteServers]
9898
var stdio MCPGenericConfig[MCPSTDIOServers]
9999

100100
if err := yaml.Unmarshal([]byte(c.Servers), &remote); err != nil {
101-
return remote, stdio
101+
return remote, stdio, err
102102
}
103103

104104
if err := yaml.Unmarshal([]byte(c.Stdio), &stdio); err != nil {
105-
return remote, stdio
105+
return remote, stdio, err
106106
}
107-
108-
return remote, stdio
107+
return remote, stdio, nil
109108
}
110109

111110
type MCPGenericConfig[T any] struct {

core/http/endpoints/openai/chat.go

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@ package openai
33
import (
44
"bufio"
55
"bytes"
6+
"context"
67
"encoding/json"
78
"fmt"
9+
"net"
810
"time"
911

1012
"github.com/gofiber/fiber/v2"
@@ -22,6 +24,59 @@ import (
2224
"github.com/valyala/fasthttp"
2325
)
2426

27+
// NOTE: this is a bad WORKAROUND! We should find a better way to handle this.
28+
// Fasthttp doesn't support context cancellation from the caller
29+
// for non-streaming requests, so we need to monitor the connection directly.
30+
// Monitor connection for client disconnection during non-streaming requests
31+
// We access the connection directly via c.Context().Conn() to monitor it
32+
// during ComputeChoices execution, not after the response is sent
33+
// see: https://github.com/mudler/LocalAI/pull/7187#issuecomment-3506720906
34+
func handleConnectionCancellation(c *fiber.Ctx, cancelFunc func(), requestCtx context.Context) {
35+
var conn net.Conn = c.Context().Conn()
36+
if conn == nil {
37+
return
38+
}
39+
40+
go func() {
41+
defer func() {
42+
// Clear read deadline when goroutine exits
43+
conn.SetReadDeadline(time.Time{})
44+
}()
45+
46+
buf := make([]byte, 1)
47+
// Use a short read deadline to periodically check if connection is closed
48+
// Without a deadline, Read() would block indefinitely waiting for data
49+
// that will never come (client is waiting for response, not sending more data)
50+
ticker := time.NewTicker(100 * time.Millisecond)
51+
defer ticker.Stop()
52+
53+
for {
54+
select {
55+
case <-requestCtx.Done():
56+
// Request completed or was cancelled - exit goroutine
57+
return
58+
case <-ticker.C:
59+
// Set a short deadline - if connection is closed, read will fail immediately
60+
// If connection is open but no data, it will timeout and we check again
61+
conn.SetReadDeadline(time.Now().Add(50 * time.Millisecond))
62+
_, err := conn.Read(buf)
63+
if err != nil {
64+
// Check if it's a timeout (connection still open, just no data)
65+
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
66+
// Timeout is expected - connection is still open, just no data to read
67+
// Continue the loop to check again
68+
continue
69+
}
70+
// Connection closed or other error - cancel the context to stop gRPC call
71+
log.Debug().Msgf("Calling cancellation function")
72+
cancelFunc()
73+
return
74+
}
75+
}
76+
}
77+
}()
78+
}
79+
2580
// ChatEndpoint is the OpenAI Completion API endpoint https://platform.openai.com/docs/api-reference/chat/create
2681
// @Summary Generate a chat completions for a given prompt and model.
2782
// @Param request body schema.OpenAIRequest true "query params"
@@ -358,6 +413,11 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
358413
LOOP:
359414
for {
360415
select {
416+
case <-input.Context.Done():
417+
// Context was cancelled (client disconnected or request cancelled)
418+
log.Debug().Msgf("Request context cancelled, stopping stream")
419+
input.Cancel()
420+
break LOOP
361421
case ev := <-responses:
362422
if len(ev.Choices) == 0 {
363423
log.Debug().Msgf("No choices in the response, skipping")
@@ -511,6 +571,10 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
511571

512572
}
513573

574+
// NOTE: this is a workaround as fasthttp
575+
// context cancellation does not fire in non-streaming requests
576+
handleConnectionCancellation(c, input.Cancel, input.Context)
577+
514578
result, tokenUsage, err := ComputeChoices(
515579
input,
516580
predInput,

core/http/endpoints/openai/mcp.go

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package openai
22

33
import (
4+
"context"
45
"encoding/json"
56
"errors"
67
"fmt"
@@ -50,12 +51,15 @@ func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader,
5051
}
5152

5253
// Get MCP config from model config
53-
remote, stdio := config.MCP.MCPConfigFromYAML()
54+
remote, stdio, err := config.MCP.MCPConfigFromYAML()
55+
if err != nil {
56+
return fmt.Errorf("failed to get MCP config: %w", err)
57+
}
5458

5559
// Check if we have tools in cache, or we have to have an initial connection
5660
sessions, err := mcpTools.SessionsFromMCPConfig(config.Name, remote, stdio)
5761
if err != nil {
58-
return err
62+
return fmt.Errorf("failed to get MCP sessions: %w", err)
5963
}
6064

6165
if len(sessions) == 0 {
@@ -73,6 +77,10 @@ func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader,
7377
if appConfig.ApiKeys != nil {
7478
apiKey = appConfig.ApiKeys[0]
7579
}
80+
81+
ctxWithCancellation, cancel := context.WithCancel(ctx)
82+
defer cancel()
83+
handleConnectionCancellation(c, cancel, ctxWithCancellation)
7684
// TODO: instead of connecting to the API, we should just wire this internally
7785
// and act like completion.go.
7886
// We can do this as cogito expects an interface and we can create one that
@@ -83,7 +91,7 @@ func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader,
8391
cogito.WithStatusCallback(func(s string) {
8492
log.Debug().Msgf("[model agent] [model: %s] Status: %s", config.Name, s)
8593
}),
86-
cogito.WithContext(ctx),
94+
cogito.WithContext(ctxWithCancellation),
8795
cogito.WithMCPs(sessions...),
8896
cogito.WithIterations(3), // default to 3 iterations
8997
cogito.WithMaxAttempts(3), // default to 3 attempts

core/http/middleware/request.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,17 @@ func (re *RequestExtractor) SetOpenAIRequest(ctx *fiber.Ctx) error {
161161
correlationID := ctx.Get("X-Correlation-ID", uuid.New().String())
162162
ctx.Set("X-Correlation-ID", correlationID)
163163

164+
//c1, cancel := context.WithCancel(re.applicationConfig.Context)
165+
// Use the application context as parent to ensure cancellation on app shutdown
166+
// We'll monitor the Fiber context separately and cancel our context when the request is canceled
164167
c1, cancel := context.WithCancel(re.applicationConfig.Context)
168+
// Monitor the Fiber context and cancel our context when it's canceled
169+
// This ensures we respect request cancellation without causing panics
170+
go func() {
171+
<-ctx.Context().Done()
172+
// Fiber context was canceled (request completed or client disconnected)
173+
cancel()
174+
}()
165175
// Add the correlation ID to the new context
166176
ctxWithCorrelationID := context.WithValue(c1, CorrelationIDKey, correlationID)
167177

0 commit comments

Comments
 (0)