Skip to content

Commit

Permalink
Update GGML_HIP_UMA (#473)
Browse files Browse the repository at this point in the history
Add UMA config for higher speed, as in ggerganov/llama.cpp#7414,
but with 2 changes:

- Remove UMA build option
- Use it in all cases when hipMalloc fails with an out-of-memory error ('not enough memory')

Another change is to look for 'hipcc' on Linux instead of 'amdclang++'
  • Loading branch information
Djip007 committed Jun 20, 2024
1 parent c38feb4 commit a28250b
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 19 deletions.
31 changes: 22 additions & 9 deletions llama.cpp/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -159,13 +159,7 @@
#define cudaHostRegisterReadOnly hipHostRegisterReadOnly
#define cudaHostUnregister hipHostUnregister
#define cudaLaunchHostFunc hipLaunchHostFunc
#ifdef GGML_HIP_UMA
#define cudaMalloc hipMallocManaged
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size)
#else
#define cudaMalloc hipMalloc
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
#endif
#define cudaMemcpy hipMemcpy
#define cudaMemcpyAsync hipMemcpyAsync
#define cudaMemcpyPeerAsync hipMemcpyPeerAsync
Expand Down Expand Up @@ -10866,6 +10860,25 @@ int ggml_cuda_get_device() {
return id;
}

// Allocate `size` bytes of device memory and store the pointer in `*ptr`.
//
// On HIP builds (GGML_USE_HIPBLAS) the allocation is first attempted with
// hipMalloc. If that reports hipErrorOutOfMemory, it is retried with
// hipMallocManaged, and on success the managed range is advised as
// coarse-grained for `device`. On all other builds this is a plain cudaMalloc.
//
// Parameters:
//   ptr    - receives the allocated pointer on success.
//   size   - number of bytes to allocate.
//   device - device id used for the log message and for hipMemAdvise. This
//            function does not select the device itself; callers are expected
//            to have made it current beforehand (NOTE(review): confirm for
//            every call site).
//
// Returns the status of the last allocation attempt. A hipMemAdvise failure is
// not returned to the caller; it is routed through CUDA_CHECK.
static inline cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
#if defined(GGML_USE_HIPBLAS)
    auto res = hipMalloc(ptr, size);
    if (res != hipErrorOutOfMemory) {
        // Success, or a failure that the managed-memory fallback cannot help with.
        return res;
    }

    // The failed hipMalloc is about to be handled by the fallback below, so
    // drop its recorded status; otherwise a later "get last error" check could
    // report this out-of-memory condition even though the allocation succeeded.
    (void) hipGetLastError();

    // Not enough space in VRAM => retry with managed (UMA/HMM) memory.
    GGML_CUDA_LOG_INFO(" Device %d: can not alloc %u MB on VRAM try alloc on HMM\n", device, (uint32_t)(size / 1024 / 1024));
    res = hipMallocManaged(ptr, size);
    if (res == hipSuccess) {
        // Configure the memory for best speed (this is not supposed to fail).
        CUDA_CHECK(hipMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device));
        GGML_CUDA_LOG_INFO(" => success\n");
    }
    return res;
#else
    (void) device; // only needed by the HIP fallback path
    return cudaMalloc(ptr, size);
#endif
}

static ggml_cuda_device_info ggml_cuda_init() {
#ifdef __HIP_PLATFORM_AMD__
// Workaround for a rocBLAS bug when using multiple graphics cards:
Expand Down Expand Up @@ -11020,7 +11033,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
size_t look_ahead_size = (size_t) (1.05 * size);
look_ahead_size = 256 * ((look_ahead_size + 255)/256);
ggml_cuda_set_device(device);
CUDA_CHECK(cudaMalloc((void **) &ptr, look_ahead_size));
CUDA_CHECK(ggml_cuda_device_malloc(&ptr, look_ahead_size, device));
*actual_size = look_ahead_size;
pool_size += look_ahead_size;
#ifdef DEBUG_CUDA_MALLOC
Expand Down Expand Up @@ -11286,7 +11299,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffe
size = std::max(size, (size_t)1); // cudaMalloc returns null for size 0

void * dev_ptr;
cudaError_t err = cudaMalloc(&dev_ptr, size);
cudaError_t err = ggml_cuda_device_malloc(&dev_ptr, size, buft_ctx->device);
if (err != cudaSuccess) {
// clear the error
cudaGetLastError();
Expand Down Expand Up @@ -11547,7 +11560,7 @@ GGML_CALL static void ggml_backend_cuda_split_buffer_init_tensor(ggml_backend_bu
// currently, init_tensor cannot fail, it needs to be fixed in ggml-backend first
ggml_cuda_set_device(id);
char * buf;
CUDA_CHECK(cudaMalloc(&buf, size));
CUDA_CHECK(ggml_cuda_device_malloc((void**)&buf, size, id));

// set padding to 0 to avoid possible NaN values
if (size > original_size) {
Expand Down
26 changes: 16 additions & 10 deletions llamafile/cuda.c
Original file line number Diff line number Diff line change
Expand Up @@ -778,9 +778,9 @@ static bool import_cuda_impl(void) {

char dso[PATH_MAX];
char bindir[PATH_MAX];
const char *compiler_path;
const char *compiler_path = NULL;
char compiler_path_buf[PATH_MAX];
const char *library_path;
const char *library_path = NULL;
char library_path_buf[PATH_MAX];

// Attempt to load AMD GPU support.
Expand All @@ -791,15 +791,21 @@ static bool import_cuda_impl(void) {

// Get some essential paths.
// ROCm SDK puts BLAS DLLs in same folder as clang++
if (get_rocm_bin_path(compiler_path_buf, "amdclang++") ||
get_rocm_bin_path(compiler_path_buf, "clang++")) {
strcpy(library_path_buf, compiler_path_buf);
dirname(library_path_buf);
compiler_path = compiler_path_buf;
library_path = library_path_buf;
if (!IsWindows()) {
if (get_rocm_bin_path(compiler_path_buf, "hipcc")) {
strcpy(library_path_buf, compiler_path_buf);
dirname(library_path_buf);
compiler_path = compiler_path_buf;
library_path = library_path_buf;
}
} else {
compiler_path = 0;
library_path = 0;
if (get_rocm_bin_path(compiler_path_buf, "amdclang++") ||
get_rocm_bin_path(compiler_path_buf, "clang++")) {
strcpy(library_path_buf, compiler_path_buf);
dirname(library_path_buf);
compiler_path = compiler_path_buf;
library_path = library_path_buf;
}
}

// Get path of GGML DSO for AMD.
Expand Down

0 comments on commit a28250b

Please sign in to comment.