Make GGML asynchronously cancelable
It's now possible to instantly kill jobs. Even if a thread is stuck in a
very long matmul operation, an asynchronous signal is sent that makes
pthread_exit() run immediately. Doing this does not leak memory. When it
happens, a 503 Service Unavailable response is sent to the client, so it
knows to try again with exponential backoff. It'll help with reliability
in the event of a DDoS. It'll also help with task prioritization: with
this capability, we could have an HTTP header where clients volunteer to
be preemptible, which could be useful for jobs running in the background.
jart committed Jul 6, 2024
1 parent d7c8e33 commit b3930aa
Showing 8 changed files with 146 additions and 187 deletions.
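
Before the per-file diffs, here is a minimal standalone sketch of the POSIX pattern the commit message describes, so the changes below are easier to follow. This is not llamafile's code and the names are made up for illustration: the worker switches to asynchronous cancelability so a cancel request interrupts it even in the middle of a tight loop with no cancellation points, and a cleanup handler releases its memory so the cancellation doesn't leak.

#include <pthread.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>

// Cleanup handler: free the worker's scratch buffer even if the thread
// is killed in the middle of its compute loop.
static void release_scratch(void *arg) {
    free(arg);
}

static void *worker(void *arg) {
    (void)arg;
    int old;
    void *scratch = malloc(1 << 20);
    pthread_cleanup_push(release_scratch, scratch);
    // Let cancellation interrupt the loop at any instruction rather than
    // waiting for the next cancellation point (there are none below).
    pthread_setcanceltype(PTHREAD_CANCEL_ASYNCHRONOUS, &old);
    volatile unsigned long spin = 0;
    for (;;)
        spin++;  // stand-in for a very long matmul
    pthread_cleanup_pop(1);  // unreachable; the handler runs on cancel
    return NULL;
}

int main(void) {
    pthread_t t;
    pthread_create(&t, NULL, worker, NULL);
    sleep(1);               // let the "matmul" get stuck
    pthread_cancel(t);      // instantly kill the job
    pthread_join(t, NULL);  // release_scratch() already ran
    puts("worker canceled without leaking");
    return 0;
}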
62 changes: 50 additions & 12 deletions llama.cpp/ggml.c
@@ -1636,7 +1636,7 @@ struct ggml_compute_state_shared {
typedef pthread_t ggml_thread_t;

struct ggml_compute_state {
ggml_thread_t thrd;
_Atomic(ggml_thread_t) thrd;
int ith;
struct ggml_compute_state_shared* shared;
enum ggml_status ec;
@@ -19127,6 +19127,26 @@ static void print_graph(FILE *f, const struct ggml_cgraph *g, int n_threads) {
}
}


struct ggml_compute_cleanup {
int n_threads;
struct ggml_compute_state * workers;
};

static void ggml_compute_canceled(void *arg) {
struct ggml_compute_cleanup *cleanup = arg;
clear_numa_thread_affinity();
for (int j = 1; j < cleanup->n_threads; j++) {
pthread_t t;
if ((t = atomic_exchange_explicit(&cleanup->workers[j].thrd, 0,
memory_order_relaxed))) {
pthread_cancel(t);
const int rc = ggml_thread_join(t, NULL);
GGML_ASSERT(rc == 0);
}
}
}

enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
{
GGML_ASSERT(cplan);
@@ -19170,43 +19190,61 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl
if (n_threads > 1) {
for (int j = 1; j < n_threads; ++j) {
workers[j] = (struct ggml_compute_state) {
.thrd = 0,
.thrd = ATOMIC_VAR_INIT(0),
.ith = j,
.shared = &state_shared,
.ec = GGML_STATUS_SUCCESS,
.is_main_thread = false, // [jart]
};

const int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]);
const int rc = ggml_thread_create((pthread_t *)&workers[j].thrd, NULL,
ggml_graph_compute_thread, &workers[j]);
GGML_ASSERT(rc == 0);
UNUSED(rc);
}
}

int64_t perf_start_cycles;
int64_t perf_start_time_us;
enum ggml_status compute_status;
struct ggml_compute_cleanup cleanup = {n_threads, workers};
pthread_cleanup_push(ggml_compute_canceled, &cleanup);

workers[0].ith = 0;
workers[0].shared = &state_shared;
workers[0].ec = GGML_STATUS_SUCCESS;
workers[0].is_main_thread = true; // [jart]

const int64_t perf_start_cycles = ggml_perf_cycles();
const int64_t perf_start_time_us = ggml_perf_time_us();
perf_start_cycles = ggml_perf_cycles();
perf_start_time_us = ggml_perf_time_us();

// this is a work thread too
ggml_graph_compute_thread(&workers[0]);
enum ggml_status compute_status = workers[0].ec;

// don't leave affinity set on the main thread
clear_numa_thread_affinity();
compute_status = workers[0].ec;

// join or kill thread pool
if (n_threads > 1) {
for (int j = 1; j < n_threads; j++) {
const int rc = ggml_thread_join(workers[j].thrd, NULL);
int cs;
pthread_setcancelstate(PTHREAD_CANCEL_MASKED, &cs);
for (int j = 1; j < n_threads; j++) {
pthread_t t;
if ((t = atomic_exchange_explicit(&workers[j].thrd, 0,
memory_order_relaxed))) {
const int rc = ggml_thread_join(t, NULL);
if (rc == ECANCELED) {
workers[j].thrd = t;
pthread_exit(PTHREAD_CANCELED);
}
GGML_ASSERT(rc == 0);
if (workers[j].ec != GGML_STATUS_SUCCESS)
compute_status = workers[j].ec;
}
}
pthread_setcancelstate(cs, 0);

// don't leave affinity set on the main thread
clear_numa_thread_affinity();

pthread_cleanup_pop(false);

#ifdef LLAMAFILE_SYNC_REPORT
opstart[cgraph->n_nodes] = rdtsc();
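
A detail worth calling out in the ggml.c hunks above: both the cancellation cleanup handler (ggml_compute_canceled) and the normal join loop atomically exchange each worker's thread handle with 0, so whichever path runs first joins that worker and the other path sees 0 and skips it, guaranteeing each thread is reaped exactly once. Below is a small self-contained C illustration of that idiom; it is not from the patch, and like the patch it assumes pthread_t is an integer or pointer type for which 0 is never a valid handle.

#include <pthread.h>
#include <stdatomic.h>
#include <stdio.h>

#define N_WORKERS 4

static _Atomic(pthread_t) handles[N_WORKERS];

static void *worker(void *arg) {
    (void)arg;
    return NULL;
}

// Join every worker whose handle we win; a second caller (e.g. a
// cancellation cleanup handler racing the normal join loop) gets 0
// back from the exchange and skips that thread.
static void reap_workers(void) {
    for (int i = 0; i < N_WORKERS; i++) {
        pthread_t t = atomic_exchange_explicit(&handles[i], (pthread_t)0,
                                               memory_order_relaxed);
        if (t)
            pthread_join(t, NULL);
    }
}

int main(void) {
    for (int i = 0; i < N_WORKERS; i++) {
        pthread_t t;
        pthread_create(&t, NULL, worker, NULL);
        atomic_store_explicit(&handles[i], t, memory_order_relaxed);
    }
    reap_workers();  // normal path
    reap_workers();  // a racing path would find every handle already taken
    puts("each worker joined exactly once");
    return 0;
}

PTHREAD_CANCEL_MASKED, used around the join loop in the patch, appears to be a Cosmopolitan Libc extension rather than POSIX: instead of unwinding the thread, a pending cancellation makes blocking calls fail with ECANCELED, which is why the return value of ggml_thread_join() is checked for that code.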
1 change: 0 additions & 1 deletion llamafile/BUILD.mk
@@ -120,7 +120,6 @@ o/$(MODE)/llamafile/tinyblas_cpu_mixmul_arm82.o: private TARGET_ARCH += -Xaarch6

o/$(MODE)/llamafile/thread_test: \
o/$(MODE)/llamafile/thread_test.o \
o/$(MODE)/llamafile/thread.o \
o/$(MODE)/llamafile/crash.o \
o/$(MODE)/llamafile/dll3.o \

34 changes: 32 additions & 2 deletions llamafile/server/client.cpp
@@ -18,6 +18,7 @@
#include "client.h"

#include <errno.h>
#include <fcntl.h>
#include <limits.h>
#include <string.h>
#include <sys/uio.h>
@@ -215,10 +216,18 @@ Client::transport()
return dispatch();
}

void
Client::begin_response()
{
cleanup();
pthread_testcancel();
should_send_error_if_canceled = false;
}

bool
Client::send_error(int code, const char* reason)
{
cleanup();
begin_response();
if (!reason)
reason = GetHttpReason(code);
LOG("error %d %s", code, reason);
@@ -254,7 +263,7 @@ Client::start_response(char* p, int code, const char* reason)
bool
Client::send_response(char* p0, char* p, string_view content)
{
cleanup();
begin_response();

// append date header
tm tm;
@@ -362,8 +371,29 @@ Client::read_payload()
return true;
}

static void
cancel_http_request(void* arg)
{
Client* client = (Client*)arg;
if (client->should_send_error_if_canceled) {
fcntl(client->fd, F_SETFL, fcntl(client->fd, F_GETFL) | O_NONBLOCK);
client->send_error(503);
}
}

bool
Client::dispatch()
{
bool res;
should_send_error_if_canceled = true;
pthread_cleanup_push(cancel_http_request, this);
res = dispatcher();
pthread_cleanup_pop(false);
return res;
}

bool
Client::dispatcher()
{
if (path() == "/tokenize")
return tokenize();
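
The client.cpp change works because POSIX cleanup handlers run when a thread is canceled: dispatch() pushes cancel_http_request() before running the real dispatcher, and if cancellation arrives before a response was started, the handler answers 503 on the thread's way out (after flipping the socket to O_NONBLOCK so a dying thread can never block in write()). Here is a minimal standalone sketch of that shape, with hypothetical names and stdout standing in for the socket; it is not the server's actual code.

#include <pthread.h>
#include <stdbool.h>
#include <stdio.h>
#include <unistd.h>

// Hypothetical request context standing in for llamafile's Client.
struct request {
    bool should_send_error_if_canceled;
};

// Cleanup handler: if the handler thread is canceled before it started a
// real response, answer 503 so the client knows to retry with backoff.
static void cancel_request(void *arg) {
    struct request *req = arg;
    if (req->should_send_error_if_canceled)
        puts("HTTP/1.1 503 Service Unavailable");  // stand-in for send_error(503)
}

static void *handle_request(void *arg) {
    struct request *req = arg;
    req->should_send_error_if_canceled = true;
    pthread_cleanup_push(cancel_request, req);
    sleep(10);  // stand-in for long inference; sleep() is a cancellation point
    req->should_send_error_if_canceled = false;  // a real response is underway
    puts("HTTP/1.1 200 OK");
    pthread_cleanup_pop(0);
    return NULL;
}

int main(void) {
    struct request req;
    pthread_t t;
    pthread_create(&t, NULL, handle_request, &req);
    sleep(1);
    pthread_cancel(t);      // e.g. the server is shedding load
    pthread_join(t, NULL);  // the cleanup handler prints the 503
    return 0;
}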
5 changes: 4 additions & 1 deletion llamafile/server/client.h
@@ -49,6 +49,7 @@ struct Client
{
int fd = -1;
bool close_connection = false;
bool should_send_error_if_canceled;
size_t unread = 0;
timespec message_started;
HttpMessage msg;
@@ -72,10 +73,11 @@ struct Client
bool read_request() __wur;
bool read_content() __wur;
bool send_continue() __wur;
void begin_response();
bool send(const ctl::string_view) __wur;
void defer_cleanup(void (*)(void*), void*);
bool send_error(int, const char* = nullptr);
char* start_response(char*, int, const char* = nullptr);
bool send_error(int, const char* = nullptr) __wur;
bool send_response(char*, char*, const ctl::string_view) __wur;
bool send2(const ctl::string_view, const ctl::string_view) __wur;
char* append_header(const ctl::string_view, const ctl::string_view);
@@ -87,5 +89,6 @@ struct Client
bool dispatch() __wur;
bool tokenize() __wur;
bool embedding() __wur;
bool dispatcher() __wur;
bool get_embedding_params(EmbeddingParams*);
};
23 changes: 15 additions & 8 deletions llamafile/server/embedding.cpp
@@ -68,11 +68,9 @@ add_token_to_batch(struct llama_batch& batch,
}

void
cleanup_llama_batch(void* arg)
cleanup_float_vector(void* arg)
{
llama_batch* batch = (llama_batch*)arg;
llama_batch_free(*batch);
delete batch;
delete (ctl::vector<float>*)arg;
}

void
@@ -81,6 +79,14 @@ cleanup_token_vector(void* arg)
delete (ctl::vector<llama_token>*)arg;
}

void
cleanup_llama_batch(void* arg)
{
llama_batch* batch = (llama_batch*)arg;
llama_batch_free(*batch);
delete batch;
}

void
cleanup_llama_context(void* arg)
{
@@ -211,7 +217,8 @@ Client::embedding()
LOG("llama_decode failed");
return send_error(500);
}
ctl::vector<float> embeddings(n_embd, 0);
auto embeddings = new ctl::vector<float>(n_embd, 0);
defer_cleanup(cleanup_float_vector, embeddings);
for (int i = 0; i < batch->n_tokens; i++) {
if (!batch->logits[i])
continue;
@@ -221,7 +228,7 @@ Client::embedding()
return send_error(500);
}
normalize_embeddings(
embd, &embeddings[0] + batch->seq_id[i][0] * n_embd, n_embd);
embd, embeddings->data() + batch->seq_id[i][0] * n_embd, n_embd);
}

// serialize tokens to json
@@ -240,12 +247,12 @@ Client::embedding()
p = encode_json(p, count);
p = stpcpy(p, ",\n");
p = stpcpy(p, " \"embedding\": [");
for (size_t i = 0; i < embeddings.size(); ++i) {
for (size_t i = 0; i < embeddings->size(); ++i) {
if (i) {
*p++ = ',';
*p++ = ' ';
}
p = encode_json(p, embeddings[i]);
p = encode_json(p, (*embeddings)[i]);
}
p = stpcpy(p, "]\r\n");
p = stpcpy(p, "}\r\n");
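
In embedding.cpp, buffers that used to live on the stack (the llama_batch and the float and token vectors) are now heap-allocated and registered with the Client's defer_cleanup(), presumably so the cancellation path can release them explicitly instead of relying on C++ stack unwinding, which can't be counted on when a thread is canceled asynchronously inside C code. The diff doesn't show defer_cleanup()'s implementation; the sketch below is one plausible shape for such a deferral list, written in C with hypothetical names, not the actual llamafile code.

#include <pthread.h>
#include <stdio.h>
#include <stdlib.h>

// One plausible shape for a per-request deferral list: resources are
// registered as (callback, argument) pairs and released both on the
// normal exit path and when the serving thread is canceled mid-request.
struct deferred {
    void (*fn)(void *);
    void *arg;
    struct deferred *next;
};

static void defer(struct deferred **list, void (*fn)(void *), void *arg) {
    struct deferred *d = malloc(sizeof(*d));
    d->fn = fn;
    d->arg = arg;
    d->next = *list;
    *list = d;
}

// Cleanup handler: drain the list in LIFO order.
static void run_cleanups(void *arg) {
    struct deferred **list = arg;
    while (*list) {
        struct deferred *d = *list;
        *list = d->next;
        d->fn(d->arg);
        free(d);
    }
}

static void free_floats(void *arg) {
    free(arg);  // stand-in for cleanup_float_vector()
}

static void *serve(void *arg) {
    (void)arg;
    struct deferred *cleanups = NULL;
    pthread_cleanup_push(run_cleanups, &cleanups);
    float *embeddings = calloc(4096, sizeof(float));  // heap, not stack
    defer(&cleanups, free_floats, embeddings);
    // ... decode tokens; the thread may be canceled anywhere in here ...
    pthread_cleanup_pop(1);  // also run the cleanups on the normal path
    return NULL;
}

int main(void) {
    pthread_t t;
    pthread_create(&t, NULL, serve, NULL);
    pthread_join(t, NULL);
    puts("all deferred resources released");
    return 0;
}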
10 changes: 6 additions & 4 deletions llamafile/server/main.cpp
@@ -27,15 +27,13 @@
#include "signals.h"
#include "time.h"

extern "C" void
_pthread_decimate(void);

Server* g_server;
llama_model* g_model;

int
main(int argc, char* argv[])
{
ShowCrashReports();
llamafile_check_cpu();
if (llamafile_has(argv, "--version")) {
puts("llamafile-server v" LLAMAFILE_VERSION_STRING);
@@ -47,6 +45,10 @@ main(int argc, char* argv[])
llamafile_get_flags(argc, argv);
time_init();

// we must disable the llama.cpp logger
// otherwise pthread_cancel() will cause deadlocks
FLAG_log_disable = true;

// load model
llama_model_params mparams = {
.n_gpu_layers = FLAG_n_gpu_layers,
@@ -92,6 +94,6 @@ main(int argc, char* argv[])

// quality assurance
while (!pthread_orphan_np())
_pthread_decimate();
pthread_decimate_np();
CheckForMemoryLeaks();
}