From ac27fcf58acd6c8bc1b3b88591d11276f2b8bf8e Mon Sep 17 00:00:00 2001 From: steve-chavez Date: Fri, 15 Aug 2025 15:20:57 -0500 Subject: [PATCH 1/2] refactor: remove unnecessary palloc in events.c --- src/event.c | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/event.c b/src/event.c index cf7dcf0..2663d15 100644 --- a/src/event.c +++ b/src/event.c @@ -73,12 +73,12 @@ int multi_socket_cb(__attribute__ ((unused)) CURL *easy, curl_socket_t sockfd, i int epoll_op; if(!socketp){ epoll_op = EPOLL_CTL_ADD; - bool *socket_exists = palloc(sizeof(bool)); - curl_multi_assign(wstate->curl_mhandle, sockfd, socket_exists); + bool socket_exists = true; + curl_multi_assign(wstate->curl_mhandle, sockfd, &socket_exists); } else if (what == CURL_POLL_REMOVE){ epoll_op = EPOLL_CTL_DEL; - pfree(socketp); - curl_multi_assign(wstate->curl_mhandle, sockfd, NULL); + bool socket_exists = false; + curl_multi_assign(wstate->curl_mhandle, sockfd, &socket_exists); } else { epoll_op = EPOLL_CTL_MOD; } From 8cbe05a87f064ed630632d60b283437ae21748fe Mon Sep 17 00:00:00 2001 From: steve-chavez Date: Tue, 12 Aug 2025 18:54:49 -0500 Subject: [PATCH 2/2] refactor: remove unnecessary memory context switches Currently memory contexts are used to work around the short memory context of SPI blocks (SPI_connect and SPI_finish), which are used inside some functions. For example: https://github.com/supabase/pg_net/blob/314d00e34e22378ca2b733838f96b4c179d727a9/src/core.c#L139-L170 https://github.com/supabase/pg_net/blob/314d00e34e22378ca2b733838f96b4c179d727a9/src/core.c#L172-L227 But it's unnecessary to do that, instead we can use a single SPI block in the worker loop and share the same memory context for the curl operations. This makes the code simpler and makes memory allocations more visible. --- src/core.c | 202 +++++++++++++++++++-------------------------------- src/core.h | 30 +++++++- src/worker.c | 51 ++++++++++--- 3 files changed, 143 insertions(+), 140 deletions(-) diff --git a/src/core.c b/src/core.c index cae7dc1..78f269a 100644 --- a/src/core.c +++ b/src/core.c @@ -9,13 +9,6 @@ #include "event.h" #include "errors.h" -typedef struct { - int64 id; - StringInfo body; - struct curl_slist* request_headers; - int32 timeout_milliseconds; -} CurlData; - static SPIPlanPtr del_response_plan = NULL; static SPIPlanPtr del_return_queue_plan = NULL; static SPIPlanPtr ins_response_plan = NULL; @@ -52,18 +45,15 @@ static struct curl_slist *pg_text_array_to_slist(ArrayType *array, return headers; } -// We need a different memory context here, as the parent function will have an SPI memory context, which has a shorter lifetime. -static void init_curl_handle(CURLM *curl_mhandle, MemoryContext curl_memctx, int64 id, Datum urlBin, NullableDatum bodyBin, NullableDatum headersBin, Datum methodBin, int32 timeout_milliseconds){ - MemoryContext old_ctx = MemoryContextSwitchTo(curl_memctx); - - CurlData *cdata = palloc(sizeof(CurlData)); - cdata->id = id; +void init_curl_handle(CurlData *cdata, RequestQueueRow row){ + cdata->id = row.id; cdata->body = makeStringInfo(); + cdata->ez_handle = curl_easy_init(); - cdata->timeout_milliseconds = timeout_milliseconds; + cdata->timeout_milliseconds = row.timeout_milliseconds; - if (!headersBin.isnull) { - ArrayType *pgHeaders = DatumGetArrayTypeP(headersBin.value); + if (!row.headersBin.isnull) { + ArrayType *pgHeaders = DatumGetArrayTypeP(row.headersBin.value); struct curl_slist *request_headers = NULL; request_headers = pg_text_array_to_slist(pgHeaders, request_headers); @@ -73,64 +63,55 @@ static void init_curl_handle(CURLM *curl_mhandle, MemoryContext curl_memctx, int cdata->request_headers = request_headers; } - char *url = TextDatumGetCString(urlBin); + cdata->url = TextDatumGetCString(row.url); - char *reqBody = !bodyBin.isnull ? TextDatumGetCString(bodyBin.value) : NULL; + cdata->req_body = !row.bodyBin.isnull ? TextDatumGetCString(row.bodyBin.value) : NULL; - char *method = TextDatumGetCString(methodBin); - if (strcasecmp(method, "GET") != 0 && strcasecmp(method, "POST") != 0 && strcasecmp(method, "DELETE") != 0) { - ereport(ERROR, errmsg("Unsupported request method %s", method)); - } + cdata->method = TextDatumGetCString(row.method); - CURL *curl_ez_handle = curl_easy_init(); - if(!curl_ez_handle) - ereport(ERROR, errmsg("curl_easy_init()")); + if (strcasecmp(cdata->method, "GET") != 0 && strcasecmp(cdata->method, "POST") != 0 && strcasecmp(cdata->method, "DELETE") != 0) { + ereport(ERROR, errmsg("Unsupported request method %s", cdata->method)); + } - if (strcasecmp(method, "GET") == 0) { - if (reqBody) { - EREPORT_CURL_SETOPT(curl_ez_handle, CURLOPT_POSTFIELDS, reqBody); - EREPORT_CURL_SETOPT(curl_ez_handle, CURLOPT_CUSTOMREQUEST, "GET"); + if (strcasecmp(cdata->method, "GET") == 0) { + if (cdata->req_body) { + EREPORT_CURL_SETOPT(cdata->ez_handle, CURLOPT_POSTFIELDS, cdata->req_body); + EREPORT_CURL_SETOPT(cdata->ez_handle, CURLOPT_CUSTOMREQUEST, "GET"); } } - if (strcasecmp(method, "POST") == 0) { - if (reqBody) { - EREPORT_CURL_SETOPT(curl_ez_handle, CURLOPT_POSTFIELDS, reqBody); + if (strcasecmp(cdata->method, "POST") == 0) { + if (cdata->req_body) { + EREPORT_CURL_SETOPT(cdata->ez_handle, CURLOPT_POSTFIELDS, cdata->req_body); } else { - EREPORT_CURL_SETOPT(curl_ez_handle, CURLOPT_POST, 1L); - EREPORT_CURL_SETOPT(curl_ez_handle, CURLOPT_POSTFIELDSIZE, 0L); + EREPORT_CURL_SETOPT(cdata->ez_handle, CURLOPT_POST, 1L); + EREPORT_CURL_SETOPT(cdata->ez_handle, CURLOPT_POSTFIELDSIZE, 0L); } } - if (strcasecmp(method, "DELETE") == 0) { - EREPORT_CURL_SETOPT(curl_ez_handle, CURLOPT_CUSTOMREQUEST, "DELETE"); - if (reqBody) { - EREPORT_CURL_SETOPT(curl_ez_handle, CURLOPT_POSTFIELDS, reqBody); + if (strcasecmp(cdata->method, "DELETE") == 0) { + EREPORT_CURL_SETOPT(cdata->ez_handle, CURLOPT_CUSTOMREQUEST, "DELETE"); + if (cdata->req_body) { + EREPORT_CURL_SETOPT(cdata->ez_handle, CURLOPT_POSTFIELDS, cdata->req_body); } } - EREPORT_CURL_SETOPT(curl_ez_handle, CURLOPT_WRITEFUNCTION, body_cb); - EREPORT_CURL_SETOPT(curl_ez_handle, CURLOPT_WRITEDATA, cdata); - EREPORT_CURL_SETOPT(curl_ez_handle, CURLOPT_HEADER, 0L); - EREPORT_CURL_SETOPT(curl_ez_handle, CURLOPT_URL, url); - EREPORT_CURL_SETOPT(curl_ez_handle, CURLOPT_HTTPHEADER, cdata->request_headers); - EREPORT_CURL_SETOPT(curl_ez_handle, CURLOPT_TIMEOUT_MS, (long) cdata->timeout_milliseconds); - EREPORT_CURL_SETOPT(curl_ez_handle, CURLOPT_PRIVATE, cdata); - EREPORT_CURL_SETOPT(curl_ez_handle, CURLOPT_FOLLOWLOCATION, (long) true); + EREPORT_CURL_SETOPT(cdata->ez_handle, CURLOPT_WRITEFUNCTION, body_cb); + EREPORT_CURL_SETOPT(cdata->ez_handle, CURLOPT_WRITEDATA, cdata); + EREPORT_CURL_SETOPT(cdata->ez_handle, CURLOPT_HEADER, 0L); + EREPORT_CURL_SETOPT(cdata->ez_handle, CURLOPT_URL, cdata->url); + EREPORT_CURL_SETOPT(cdata->ez_handle, CURLOPT_HTTPHEADER, cdata->request_headers); + EREPORT_CURL_SETOPT(cdata->ez_handle, CURLOPT_TIMEOUT_MS, (long) cdata->timeout_milliseconds); + EREPORT_CURL_SETOPT(cdata->ez_handle, CURLOPT_PRIVATE, cdata); + EREPORT_CURL_SETOPT(cdata->ez_handle, CURLOPT_FOLLOWLOCATION, (long) true); if (log_min_messages <= DEBUG2) - EREPORT_CURL_SETOPT(curl_ez_handle, CURLOPT_VERBOSE, 1L); + EREPORT_CURL_SETOPT(cdata->ez_handle, CURLOPT_VERBOSE, 1L); #if LIBCURL_VERSION_NUM >= 0x075500 /* libcurl 7.85.0 */ - EREPORT_CURL_SETOPT(curl_ez_handle, CURLOPT_PROTOCOLS_STR, "http,https"); + EREPORT_CURL_SETOPT(cdata->ez_handle, CURLOPT_PROTOCOLS_STR, "http,https"); #else - EREPORT_CURL_SETOPT(curl_ez_handle, CURLOPT_PROTOCOLS, CURLPROTO_HTTP | CURLPROTO_HTTPS); + EREPORT_CURL_SETOPT(cdata->ez_handle, CURLOPT_PROTOCOLS, CURLPROTO_HTTP | CURLPROTO_HTTPS); #endif - - EREPORT_MULTI( - curl_multi_add_handle(curl_mhandle, curl_ez_handle) - ); - - MemoryContextSwitchTo(old_ctx); } void set_curl_mhandle(WorkerState *wstate){ @@ -141,8 +122,6 @@ void set_curl_mhandle(WorkerState *wstate){ } uint64 delete_expired_responses(char *ttl, int batch_size){ - SPI_connect(); - if (del_response_plan == NULL) { SPIPlanPtr tmp = SPI_prepare("\ WITH\ @@ -178,14 +157,10 @@ uint64 delete_expired_responses(char *ttl, int batch_size){ ereport(ERROR, errmsg("Error expiring response table rows: %s", SPI_result_code_string(ret_code))); } - SPI_finish(); - return affected_rows; } -uint64 consume_request_queue(CURLM *curl_mhandle, int batch_size, MemoryContext curl_memctx){ - SPI_connect(); - +uint64 consume_request_queue(const int batch_size){ if (del_return_queue_plan == NULL) { SPIPlanPtr tmp = SPI_prepare("\ WITH\ @@ -214,47 +189,40 @@ uint64 consume_request_queue(CURLM *curl_mhandle, int batch_size, MemoryContext if (ret_code != SPI_OK_DELETE_RETURNING) ereport(ERROR, errmsg("Error getting http request queue: %s", SPI_result_code_string(ret_code))); - uint64 affected_rows = SPI_processed; - - for (size_t j = 0; j < affected_rows; j++) { - bool tupIsNull = false; - - int64 id = DatumGetInt64(SPI_getbinval(SPI_tuptable->vals[j], SPI_tuptable->tupdesc, 1, &tupIsNull)); - EREPORT_NULL_ATTR(tupIsNull, id); + return SPI_processed; +} - int32 timeout_milliseconds = DatumGetInt32(SPI_getbinval(SPI_tuptable->vals[j], SPI_tuptable->tupdesc, 4, &tupIsNull)); - EREPORT_NULL_ATTR(tupIsNull, timeout_milliseconds); +// This has an implicit dependency on the execution of delete_return_request_queue, +// unfortunately we're not able to make this dependency explicit +// due to the design of SPI (which uses global variables) +RequestQueueRow get_request_queue_row(HeapTuple spi_tupval, TupleDesc spi_tupdesc){ + bool tupIsNull = false; - Datum method = SPI_getbinval(SPI_tuptable->vals[j], SPI_tuptable->tupdesc, 2, &tupIsNull); - EREPORT_NULL_ATTR(tupIsNull, method); + int64 id = DatumGetInt64(SPI_getbinval(spi_tupval, spi_tupdesc, 1, &tupIsNull)); + EREPORT_NULL_ATTR(tupIsNull, id); - Datum url = SPI_getbinval(SPI_tuptable->vals[j], SPI_tuptable->tupdesc, 3, &tupIsNull); - EREPORT_NULL_ATTR(tupIsNull, url); + Datum method = SPI_getbinval(spi_tupval, spi_tupdesc, 2, &tupIsNull); + EREPORT_NULL_ATTR(tupIsNull, method); - NullableDatum headersBin = { - .value = SPI_getbinval(SPI_tuptable->vals[j], SPI_tuptable->tupdesc, 5, &tupIsNull), - .isnull = tupIsNull - }; + Datum url = SPI_getbinval(spi_tupval, spi_tupdesc, 3, &tupIsNull); + EREPORT_NULL_ATTR(tupIsNull, url); - NullableDatum bodyBin = { - .value = SPI_getbinval(SPI_tuptable->vals[j], SPI_tuptable->tupdesc, 6, &tupIsNull), - .isnull = tupIsNull - }; + int32 timeout_milliseconds = DatumGetInt32(SPI_getbinval(spi_tupval, spi_tupdesc, 4, &tupIsNull)); + EREPORT_NULL_ATTR(tupIsNull, timeout_milliseconds); - init_curl_handle(curl_mhandle, curl_memctx, id, url, bodyBin, headersBin, method, timeout_milliseconds); - } + NullableDatum headersBin = { + .value = SPI_getbinval(spi_tupval, spi_tupdesc, 5, &tupIsNull), + .isnull = tupIsNull + }; - SPI_finish(); + NullableDatum bodyBin = { + .value = SPI_getbinval(spi_tupval, spi_tupdesc, 6, &tupIsNull), + .isnull = tupIsNull + }; - return affected_rows; -} - -static void pfree_curl_data(CurlData *cdata){ - if(cdata->body){ - destroyStringInfo(cdata->body); - } - if(cdata->request_headers) //curl_slist_free_all already handles the NULL case, but be explicit about it - curl_slist_free_all(cdata->request_headers); + return (RequestQueueRow){ + id, method, url, timeout_milliseconds, headersBin, bodyBin + }; } static Jsonb *jsonb_headers_from_curl_handle(CURL *ez_handle){ @@ -276,11 +244,14 @@ static Jsonb *jsonb_headers_from_curl_handle(CURL *ez_handle){ return jsonb_headers; } -static void insert_response(CURL *ez_handle, CurlData *cdata, CURLcode curl_return_code){ +void insert_response(CURL *ez_handle, CURLcode curl_return_code){ enum { nparams = 7 }; // using an enum because const size_t nparams doesn't compile Datum vals[nparams]; char nulls[nparams]; MemSet(nulls, 'n', nparams); + CurlData *cdata = NULL; + EREPORT_CURL_GETINFO(ez_handle, CURLINFO_PRIVATE, &cdata); + vals[0] = Int64GetDatum(cdata->id); nulls[0] = ' '; @@ -352,36 +323,15 @@ static void insert_response(CURL *ez_handle, CurlData *cdata, CURLcode curl_retu } } -// Switch back to the curl memory context, which has the curl handles stored -void insert_curl_responses(WorkerState *wstate, MemoryContext curl_memctx){ - MemoryContext old_ctx = MemoryContextSwitchTo(curl_memctx); - int msgs_left=0; - CURLMsg *msg = NULL; - CURLM *curl_mhandle = wstate->curl_mhandle; - - while ((msg = curl_multi_info_read(curl_mhandle, &msgs_left))) { - if (msg->msg == CURLMSG_DONE) { - CURLcode return_code = msg->data.result; - CURL *ez_handle= msg->easy_handle; - CurlData *cdata = NULL; - EREPORT_CURL_GETINFO(ez_handle, CURLINFO_PRIVATE, &cdata); - - SPI_connect(); - insert_response(ez_handle, cdata, return_code); - SPI_finish(); - - pfree_curl_data(cdata); +void pfree_curl_data(CurlData *cdata){ + pfree(cdata->url); + pfree(cdata->method); + if(cdata->req_body) + pfree(cdata->req_body); - int res = curl_multi_remove_handle(curl_mhandle, ez_handle); - if(res != CURLM_OK) - ereport(ERROR, errmsg("curl_multi_remove_handle: %s", curl_multi_strerror(res))); - - curl_easy_cleanup(ez_handle); - } else { - ereport(ERROR, errmsg("curl_multi_info_read(), CURLMsg=%d\n", msg->msg)); - } - } + if(cdata->body) + destroyStringInfo(cdata->body); - MemoryContextSwitchTo(old_ctx); + if(cdata->request_headers) //curl_slist_free_all already handles the NULL case, but be explicit about it + curl_slist_free_all(cdata->request_headers); } - diff --git a/src/core.h b/src/core.h index 7c3b85b..74697b1 100644 --- a/src/core.h +++ b/src/core.h @@ -17,12 +17,38 @@ typedef struct { CURLM *curl_mhandle; } WorkerState; +typedef struct { + int64 id; + StringInfo body; + struct curl_slist* request_headers; + int32 timeout_milliseconds; + char *url; + char *req_body; + char *method; + CURL *ez_handle; +} CurlData; + +typedef struct { + int64 id; + Datum method; + Datum url; + int32 timeout_milliseconds; + NullableDatum headersBin; + NullableDatum bodyBin; +} RequestQueueRow; + uint64 delete_expired_responses(char *ttl, int batch_size); -uint64 consume_request_queue(CURLM *curl_mhandle, int batch_size, MemoryContext curl_memctx); +uint64 consume_request_queue(const int batch_size); -void insert_curl_responses(WorkerState *wstate, MemoryContext curl_memctx); +RequestQueueRow get_request_queue_row(HeapTuple spi_tupval, TupleDesc spi_tupdesc); void set_curl_mhandle(WorkerState *wstate); +void insert_response(CURL *ez_handle, CURLcode curl_return_code); + +void init_curl_handle(CurlData *cdata, RequestQueueRow row); + +void pfree_curl_data(CurlData *cdata); + #endif diff --git a/src/worker.c b/src/worker.c index d4b92ed..bbbf14a 100644 --- a/src/worker.c +++ b/src/worker.c @@ -36,7 +36,6 @@ static char* guc_ttl; static int guc_batch_size; static char* guc_database_name; static char* guc_username; -static MemoryContext CurlMemContext = NULL; #if PG15_GTE static shmem_request_hook_type prev_shmem_request_hook = NULL; @@ -290,15 +289,29 @@ void pg_net_worker(__attribute__ ((unused)) Datum main_arg) { break; } + SPI_connect(); + expired_responses = delete_expired_responses(guc_ttl, guc_batch_size); elog(DEBUG1, "Deleted "UINT64_FORMAT" expired rows", expired_responses); - requests_consumed = consume_request_queue(worker_state->curl_mhandle, guc_batch_size, CurlMemContext); + requests_consumed = consume_request_queue(guc_batch_size); elog(DEBUG1, "Consumed "UINT64_FORMAT" request rows", requests_consumed); if(requests_consumed > 0){ + CurlData *cdatas = palloc(mul_size(sizeof(CurlData), requests_consumed)); + + // initialize curl handles + for (size_t j = 0; j < requests_consumed; j++) { + init_curl_handle(&cdatas[j], get_request_queue_row(SPI_tuptable->vals[j], SPI_tuptable->tupdesc)); + + EREPORT_MULTI( + curl_multi_add_handle(worker_state->curl_mhandle, cdatas[j].ez_handle) + ); + } + + // start curl event loop int running_handles = 0; int maxevents = guc_batch_size + 1; // 1 extra for the timer event events[maxevents]; @@ -334,22 +347,42 @@ void pg_net_worker(__attribute__ ((unused)) Datum main_arg) { &running_handles) ); } - } - insert_curl_responses(worker_state, CurlMemContext); + // insert finished responses + CURLMsg *msg = NULL; int msgs_left=0; + while ((msg = curl_multi_info_read(worker_state->curl_mhandle, &msgs_left))) { + if (msg->msg == CURLMSG_DONE) { + insert_response(msg->easy_handle, msg->data.result); + } else { + ereport(ERROR, errmsg("curl_multi_info_read(), CURLMsg=%d\n", msg->msg)); + } + } elog(DEBUG1, "Pending curl running_handles: %d", running_handles); } while (running_handles > 0); // run while there are curl handles, some won't finish in a single iteration since they could be slow and waiting for a timeout + + // cleanup + for(uint64 i = 0; i < requests_consumed; i++){ + EREPORT_MULTI( + curl_multi_remove_handle(worker_state->curl_mhandle, cdatas[i].ez_handle) + ); + + curl_easy_cleanup(cdatas[i].ez_handle); + + pfree_curl_data(&cdatas[i]); + } + + pfree(cdatas); } + SPI_finish(); + unlock_extension(ext_table_oids); PopActiveSnapshot(); CommitTransactionCommand(); - MemoryContextReset(CurlMemContext); - // slow down queue processing to avoid using too much CPU wait_while_processing_interrupts(WORKER_WAIT_ONE_SECOND, &worker_should_restart); @@ -430,12 +463,6 @@ void _PG_init(void) { prev_shmem_startup_hook = shmem_startup_hook; shmem_startup_hook = net_shmem_startup; - CurlMemContext = AllocSetContextCreate(TopMemoryContext, - "pg_net curl context", - ALLOCSET_DEFAULT_MINSIZE, - ALLOCSET_DEFAULT_INITSIZE, - ALLOCSET_DEFAULT_MAXSIZE); - DefineCustomStringVariable("pg_net.ttl", "time to live for request/response rows", "should be a valid interval type",