metal : reuse graphs #14570

Draft · wants to merge 3 commits into base: gg/llama-reuse-graphs

8 changes: 8 additions & 0 deletions common/arg.cpp
@@ -1464,6 +1464,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.swa_full = true;
}
).set_env("LLAMA_ARG_SWA_FULL"));
add_opt(common_arg(
{"--graph-reuse", "-gr"},
string_format("reuse previous compute graphs when possible (default: %s)"
"[(more info)](https://github.com/ggml-org/llama.cpp/pull/14482)", params.graph_reuse ? "true" : "false"),
[](common_params & params) {
params.graph_reuse = true;
}
).set_env("LLAMA_ARG_GRAPH_REUSE"));
add_opt(common_arg(
{"--no-context-shift"},
string_format("disables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"),
1 change: 1 addition & 0 deletions common/common.cpp
@@ -1157,6 +1157,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
cparams.no_perf = params.no_perf;
cparams.op_offload = !params.no_op_offload;
cparams.swa_full = params.swa_full;
cparams.graph_reuse = params.graph_reuse;

cparams.type_k = params.cache_type_k;
cparams.type_v = params.cache_type_v;
1 change: 1 addition & 0 deletions common/common.h
@@ -330,6 +330,7 @@ struct common_params {
bool no_perf = false; // disable performance metrics
bool ctx_shift = true; // context shift on infinite text generation
bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
bool graph_reuse = false; // reuse previous compute graphs when possible

bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
bool use_mmap = true; // use mmap for faster loads
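For orientation, the plumbing above is: `--graph-reuse` / `-gr` (or the `LLAMA_ARG_GRAPH_REUSE` env var) sets `common_params::graph_reuse`, which `common_context_params_to_llama()` copies into `llama_context_params::graph_reuse`. A minimal sketch of enabling it through the public API — the llama.h calls are the standard ones, but the `graph_reuse` context field exists only on the base branch (`gg/llama-reuse-graphs`), so treat this as illustrative rather than a supported interface:

```c
#include "llama.h"

int main(void) {
    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file("model.gguf", mparams);
    if (model == NULL) {
        return 1;
    }

    llama_context_params cparams = llama_context_default_params();
    cparams.graph_reuse = true; // equivalent of --graph-reuse / LLAMA_ARG_GRAPH_REUSE=1

    llama_context * ctx = llama_init_from_model(model, cparams);
    if (ctx == NULL) {
        llama_model_free(model);
        return 1;
    }

    // ... decode as usual; graph reuse is transparent to the caller ...

    llama_free(ctx);
    llama_model_free(model);
    return 0;
}
```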
258 changes: 205 additions & 53 deletions ggml/src/ggml-metal/ggml-metal.m
@@ -821,13 +821,23 @@ static void ggml_metal_mem_pool_clear(struct ggml_metal_mem_pool * mem_pool) {

// the callback given to the thread pool
void (^encode_async)(size_t ith);
void (^encode_next)(void);

// n_cb command buffers + 1 used by the main thread
struct ggml_metal_command_buffer cmd_bufs[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
struct ggml_metal_command_buffer cmd_bufs_next[2];

// abort ggml_metal_graph_compute if callback returns true
ggml_abort_callback abort_callback;
void * abort_callback_data;

// reuse info
int i_next;

int n_nodes_max;
int n_nodes_prev;

struct ggml_tensor * cg_nodes;
};

// MSL code
@@ -1084,13 +1094,21 @@ @implementation GGMLMetalClass

ctx->gf = nil;
ctx->encode_async = nil;
ctx->encode_next = nil;
for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
ctx->cmd_bufs[i].obj = nil;

ctx->cmd_bufs[i].mem_pool = ggml_metal_mem_pool_init();
ctx->cmd_bufs[i].mem_pool->device = device;
}

for (int i = 0; i < 2; ++i) {
ctx->cmd_bufs_next[i].obj = nil;

ctx->cmd_bufs_next[i].mem_pool = ggml_metal_mem_pool_init();
ctx->cmd_bufs_next[i].mem_pool->device = device;
}

#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
if (@available(macOS 10.12, iOS 16.0, *)) {
GGML_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, device.recommendedMaxWorkingSetSize / 1e6);
@@ -1521,6 +1539,13 @@ @implementation GGMLMetalClass
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
}

ctx->i_next = 0;

ctx->n_nodes_max = 16384;
ctx->n_nodes_prev = -1;

ctx->cg_nodes = ggml_aligned_malloc(ctx->n_nodes_max * sizeof(struct ggml_tensor));

return ctx;
}
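The snapshot buffer allocated above holds up to `n_nodes_max` full `ggml_tensor` structs, i.e. on the order of a few MiB per Metal context (the struct is a few hundred bytes). A trivial way to check the exact footprint on a given build, assuming only ggml.h:

```c
#include <stdio.h>
#include "ggml.h"

int main(void) {
    const size_t n_nodes_max = 16384; // same cap as used in this PR
    const size_t bytes = n_nodes_max * sizeof(struct ggml_tensor);
    printf("graph snapshot buffer: %zu bytes (%.1f MiB)\n", bytes, bytes / (1024.0 * 1024.0));
    return 0;
}
```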

@@ -1532,6 +1557,7 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
}

Block_release(ctx->encode_async);
Block_release(ctx->encode_next);

[ctx->queue release];

@@ -1541,8 +1567,13 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
ggml_metal_mem_pool_free(ctx->cmd_bufs[i].mem_pool);
}

ggml_metal_mem_pool_free(ctx->cmd_bufs_next[0].mem_pool);
ggml_metal_mem_pool_free(ctx->cmd_bufs_next[1].mem_pool);

dispatch_release(ctx->d_queue);

ggml_aligned_free(ctx->cg_nodes, ctx->n_nodes_max * sizeof(struct ggml_tensor));

free(ctx);
}

@@ -5448,6 +5479,39 @@ static enum ggml_status ggml_metal_graph_compute(
struct ggml_backend_metal_context * ctx = backend->context;
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;

//const int64_t t_start = ggml_time_us();

/////////////////////////////////////////////////////
// hacky way to determine that the graph is the same as the previous one
//
bool can_reuse = true;

if (gf->n_nodes > ctx->n_nodes_max) {
can_reuse = false;
}

if (gf->n_nodes != ctx->n_nodes_prev) {
can_reuse = false;
}

if (can_reuse) {
for (int i = 0; i < gf->n_nodes; ++i) {
if (memcmp(gf->nodes[i], ctx->cg_nodes + i, sizeof(struct ggml_tensor)) != 0) {
can_reuse = false;
break;
}
}
}

if (!can_reuse) {
ctx->n_nodes_prev = gf->n_nodes;

// snapshot only graphs that fit in the pre-allocated cg_nodes buffer
if (gf->n_nodes <= ctx->n_nodes_max) {
for (int i = 0; i < gf->n_nodes; ++i) {
memcpy(ctx->cg_nodes + i, gf->nodes[i], sizeof(struct ggml_tensor));
}
}
}
//////////////////////////////////////////////////////

// number of nodes encoded by the main thread (empirically determined)
const int n_main = 128;
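The reuse check above is deliberately blunt: it byte-compares every node's full `ggml_tensor` against a snapshot taken on the previous call, so any change in op, shape, sources, op params, or backing buffer invalidates reuse and falls back to a full re-encode. A self-contained sketch of the same idea with stand-in types (not ggml's real structs; the capacity guard on the snapshot side is an addition of this sketch):

```c
#include <stdbool.h>
#include <string.h>

#define MAX_NODES 16384

// stand-in for struct ggml_tensor: whatever ends up byte-compared defines what
// counts as "the same graph" (op, shapes, sources, op params, data pointers, ...)
struct node_desc {
    int    op;
    long   ne[4];
    void * src[4];
    void * data;
    char   op_params[64];
};

static struct node_desc prev_nodes[MAX_NODES];
static int prev_n_nodes = -1;

static bool graph_can_reuse(const struct node_desc * nodes, int n_nodes) {
    if (n_nodes > MAX_NODES || n_nodes != prev_n_nodes) {
        return false;
    }
    for (int i = 0; i < n_nodes; ++i) {
        if (memcmp(&nodes[i], &prev_nodes[i], sizeof(nodes[i])) != 0) {
            return false; // any difference forces a full re-encode
        }
    }
    return true;
}

static void graph_snapshot(const struct node_desc * nodes, int n_nodes) {
    if (n_nodes > MAX_NODES) {
        prev_n_nodes = -1; // too large to track: never report reuse
        return;
    }
    prev_n_nodes = n_nodes;
    memcpy(prev_nodes, nodes, (size_t) n_nodes * sizeof(nodes[0]));
}

int main(void) {
    struct node_desc graph[2] = {
        { .op = 1, .ne = {4096, 1, 1, 1} },
        { .op = 2, .ne = {4096, 1, 1, 1} },
    };

    bool reuse = graph_can_reuse(graph, 2); // false: nothing snapshotted yet
    graph_snapshot(graph, 2);
    reuse = graph_can_reuse(graph, 2);      // true: identical to the snapshot
    (void) reuse;

    return 0;
}
```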

@@ -5492,78 +5556,126 @@ static enum ggml_status ggml_metal_graph_compute(
}
}

// the main thread commits the first few commands immediately
// cmd_buf[n_cb]
{
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
ctx->cmd_bufs[n_cb].obj = cmd_buf;

[cmd_buf enqueue];
ctx->encode_async(n_cb);
}

// prepare the rest of the command buffers asynchronously
// cmd_buf[0.. n_cb)
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
ctx->cmd_bufs[cb_idx].obj = cmd_buf;
if (!can_reuse) {
// the main thread commits the first few commands immediately
// cmd_buf[n_cb]
{
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
ctx->cmd_bufs[n_cb].obj = cmd_buf;

// always enqueue the first two command buffers
// enqueue all of the command buffers if we don't need to abort
if (cb_idx < 2 || ctx->abort_callback == NULL) {
[cmd_buf enqueue];
ctx->encode_async(n_cb);
}
}

dispatch_apply(n_cb, ctx->d_queue, ctx->encode_async);

// wait for completion and check status of each command buffer
// needed to detect if the device ran out-of-memory for example (#1881)
{
id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
[cmd_buf waitUntilCompleted];
// prepare the rest of the command buffers asynchronously
// cmd_buf[0.. n_cb)
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
ctx->cmd_bufs[cb_idx].obj = cmd_buf;

MTLCommandBufferStatus status = [cmd_buf status];
if (status != MTLCommandBufferStatusCompleted) {
GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
if (status == MTLCommandBufferStatusError) {
GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
// always enqueue the first two command buffers
// enqueue all of the command buffers if we don't need to abort
if (cb_idx < 2 || ctx->abort_callback == NULL) {
[cmd_buf enqueue];
}

return GGML_STATUS_FAILED;
}
}

for (int i = 0; i < n_cb; ++i) {
id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
[cmd_buf waitUntilCompleted];
dispatch_apply(n_cb, ctx->d_queue, ctx->encode_async);

MTLCommandBufferStatus status = [cmd_buf status];
if (status != MTLCommandBufferStatusCompleted) {
GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
if (status == MTLCommandBufferStatusError) {
GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
// encode the command buffer for the next iter while the GPU has already started
{
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
[cmd_buf retain];
if (ctx->cmd_bufs_next[ctx->i_next].obj != nil) {
[ctx->cmd_bufs_next[ctx->i_next].obj release];
}
ctx->cmd_bufs_next[ctx->i_next].obj = cmd_buf;

return GGML_STATUS_FAILED;
ctx->encode_next();
}

id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil);
if (!next_buffer) {
continue;
// wait for completion and check status of each command buffer
// needed to detect if the device ran out-of-memory for example (#1881)
{
id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
[cmd_buf waitUntilCompleted];

MTLCommandBufferStatus status = [cmd_buf status];
if (status != MTLCommandBufferStatusCompleted) {
GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
if (status == MTLCommandBufferStatusError) {
GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
}

return GGML_STATUS_FAILED;
}
}

const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);
if (next_queued) {
continue;
for (int i = 0; i < n_cb; ++i) {
id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
[cmd_buf waitUntilCompleted];

MTLCommandBufferStatus status = [cmd_buf status];
if (status != MTLCommandBufferStatusCompleted) {
GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
if (status == MTLCommandBufferStatusError) {
GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
}

return GGML_STATUS_FAILED;
}

id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil);
if (!next_buffer) {
continue;
}

const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);
if (next_queued) {
continue;
}

if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) {
GGML_LOG_INFO("%s: command buffer %d aborted", __func__, i);
return GGML_STATUS_ABORTED;
}

[next_buffer commit];
}
} else {
struct ggml_metal_command_buffer cmd_buf_cur = ctx->cmd_bufs_next[(ctx->i_next + 1)%2];

// directly submit the command buffer that we have prepared in the previous iteration
[ctx->cmd_bufs_next[(ctx->i_next + 1)%2].obj commit];

if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) {
GGML_LOG_INFO("%s: command buffer %d aborted", __func__, i);
return GGML_STATUS_ABORTED;
// encode the command buffer for the next iter
{
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
[cmd_buf retain];
if (ctx->cmd_bufs_next[ctx->i_next].obj != nil) {
[ctx->cmd_bufs_next[ctx->i_next].obj release];
}
ctx->cmd_bufs_next[ctx->i_next].obj = cmd_buf;

ctx->encode_next();
}

[next_buffer commit];
// wait for completion and check status of each command buffer
// needed to detect if the device ran out-of-memory for example (#1881)
{
id<MTLCommandBuffer> cmd_buf = cmd_buf_cur.obj;
[cmd_buf waitUntilCompleted];

MTLCommandBufferStatus status = [cmd_buf status];
if (status != MTLCommandBufferStatusCompleted) {
GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, ctx->i_next, status);
if (status == MTLCommandBufferStatusError) {
GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
}

return GGML_STATUS_FAILED;
}
}
}

if (!should_capture && ctx->capture_started) {
@@ -5572,6 +5684,8 @@ static enum ggml_status ggml_metal_graph_compute(
}
}

//printf(" time = %.3f ms\n", (float)(ggml_time_us() - t_start)/1000.0f);

return GGML_STATUS_SUCCESS;
}
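When the fingerprint matches, the function switches to the double-buffered path above: it commits the command buffer that was encoded during the previous call, encodes the buffer for the next call while the GPU is already running, and only then waits for completion. The same encode-ahead step at the end of the non-reuse path is what primes `cmd_bufs_next` for the first reused iteration. A control-flow sketch of the rotation — plain C stand-ins, not the Metal API; `cmd_buf_t`, `gpu_commit` and `gpu_wait` are hypothetical, and unlike the real code no fresh `MTLCommandBuffer` is created per call:

```c
#include <stdio.h>

typedef struct { int id; } cmd_buf_t;   // stand-in for id<MTLCommandBuffer>

static cmd_buf_t bufs_next[2];          // double-buffered "next" command buffers
static int i_next = 0;                  // index of the buffer to encode into next

static void gpu_commit(const cmd_buf_t * b) { printf("commit   cmd_buf %d\n", b->id); }
static void gpu_wait  (const cmd_buf_t * b) { printf("wait for cmd_buf %d\n", b->id); }

// stands in for ctx->encode_next: encode the whole graph into bufs_next[i_next],
// then rotate the index
static void encode_next(void) {
    printf("encode   cmd_buf %d\n", bufs_next[i_next].id);
    i_next = (i_next + 1) % 2;
}

// one ggml_metal_graph_compute() call on the reuse path
static void compute_reused_graph(void) {
    const cmd_buf_t * cur = &bufs_next[(i_next + 1) % 2]; // prepared during the previous call

    gpu_commit(cur);  // the GPU starts executing the pre-encoded graph immediately
    encode_next();    // the CPU encodes the buffer for the next call in parallel
    gpu_wait(cur);    // block until this call's work is done, then check its status
}

int main(void) {
    bufs_next[0].id = 0;
    bufs_next[1].id = 1;

    // the first, non-reused call ends by encoding one buffer ahead;
    // model that priming step here
    encode_next();

    for (int iter = 0; iter < 3; ++iter) {
        compute_reused_graph();
    }

    return 0;
}
```

Running it prints the commit/encode/wait interleaving, which is the overlap that hides graph-encoding cost behind GPU execution on identical graphs.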

@@ -5919,6 +6033,10 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
Block_release(ctx->encode_async);
}

if (ctx->encode_next) {
Block_release(ctx->encode_next);
}

ctx->encode_async = Block_copy(^(size_t iter) {
const int cb_idx = iter;
const int n_cb_l = ctx->n_cb;
@@ -5967,6 +6085,40 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
[cmd_buf commit];
}
});

ctx->encode_next = Block_copy(^(void) {
id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs_next[ctx->i_next].obj;

id<MTLComputeCommandEncoder> encoder = [cmd_buf computeCommandEncoder];

int node_start = 0;
int node_end = ctx->gf->n_nodes;

const bool should_capture = ctx->capture_next_compute;

struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs_next[ctx->i_next].mem_pool;
ggml_metal_mem_pool_reset(mem_pool);

for (int idx = node_start; idx < node_end; ++idx) {
if (should_capture) {
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
}

const bool res = ggml_metal_encode_node(backend, idx, encoder, mem_pool);

if (should_capture) {
[encoder popDebugGroup];
}

if (!res) {
break;
}
}

[encoder endEncoding];

ctx->i_next = (ctx->i_next + 1) % 2;
});
}

static struct ggml_backend_i ggml_backend_metal_i = {