
Commit

use composite cache key
mertalev committed Dec 30, 2024
1 parent 5940861 commit d9f8c1a
Showing 3 changed files with 154 additions and 63 deletions.
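Upstream onnxruntime keys the MIOpen conv algorithm benchmark cache on the input shape alone and clears the whole cache whenever the weight shape changes; when several threads share the same kernel state, that clear-and-rebenchmark pattern can race. The patch below instead derives a single composite key from both the input and weight shapes and guards the shared LRU map with a mutex. As a rough illustration of the composite-key idea, here is a minimal sketch (hypothetical names, a plain std::vector<int64_t> standing in for TensorShapeVector, and a boost-style hash_combine in place of the patch's vector_hash helper):

// Sketch only: combine the hashes of the input and weight shapes into one
// cache key, so a weight-shape change selects a different entry instead of
// forcing a cache clear. Not the onnxruntime code.
#include <cstddef>
#include <cstdint>
#include <functional>
#include <vector>

using DimVector = std::vector<int64_t>;  // stand-in for TensorShapeVector

std::size_t HashDims(const DimVector& dims) {
  std::size_t seed = dims.size();
  for (int64_t d : dims) {
    // boost-style hash_combine step
    seed ^= std::hash<int64_t>{}(d) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
  }
  return seed;
}

std::size_t CompositeConvKey(const DimVector& x_dims, const DimVector& w_dims) {
  // Shift one hash so identical input and weight shapes do not cancel out
  // under XOR.
  return HashDims(x_dims) ^ (HashDims(w_dims) << 1);
}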
149 changes: 149 additions & 0 deletions machine-learning/0001-fix-rocm-conv-thread-safety.patch
@@ -0,0 +1,149 @@
From 60e2603045f220d560ed92e03c8804a806f5a325 Mon Sep 17 00:00:00 2001
From: mertalev <[email protected]>
Date: Fri, 20 Dec 2024 00:59:21 -0500
Subject: [PATCH] fix: avoid race condition for rocm conv algo caching

---
onnxruntime/core/providers/rocm/nn/conv.cc | 8 ++++----
onnxruntime/core/providers/rocm/nn/conv.h | 13 +++++++++++--
.../core/providers/rocm/nn/conv_transpose.cc | 8 ++++----
3 files changed, 19 insertions(+), 10 deletions(-)

diff --git a/onnxruntime/core/providers/rocm/nn/conv.cc b/onnxruntime/core/providers/rocm/nn/conv.cc
index d7f47d07a8..98b6b69212 100644
--- a/onnxruntime/core/providers/rocm/nn/conv.cc
+++ b/onnxruntime/core/providers/rocm/nn/conv.cc
@@ -127,7 +127,6 @@ Status Conv<T, NHWC>::UpdateState(OpKernelContext* context, bool bias_expected)

if (w_dims_changed) {
s_.last_w_dims = gsl::make_span(w_dims);
- s_.cached_benchmark_fwd_results.clear();
}

ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X->Shape(), W->Shape(), channels_last, channels_last));
@@ -278,7 +277,8 @@ Status Conv<T, NHWC>::UpdateState(OpKernelContext* context, bool bias_expected)
HIP_CALL_THROW(hipMemsetAsync(s_.b_zero, 0, malloc_size, Stream(context)));
}

- if (!s_.cached_benchmark_fwd_results.contains(x_dims_miopen)) {
+ const std::size_t algo_key = HashConvAlgoKey(x_dims_miopen, w_dims);
+ if (!s_.cached_benchmark_fwd_results.contains(algo_key)) {
miopenConvAlgoPerf_t perf;
int algo_count = 1;
const ROCMExecutionProvider* rocm_ep = static_cast<const ROCMExecutionProvider*>(this->Info().GetExecutionProvider());
@@ -301,9 +301,9 @@ Status Conv<T, NHWC>::UpdateState(OpKernelContext* context, bool bias_expected)
algo_search_workspace.get(),
max_ws_size,
false)); // Do not do exhaustive algo search.
- s_.cached_benchmark_fwd_results.insert(x_dims_miopen, {perf.fwd_algo, perf.memory});
+ s_.cached_benchmark_fwd_results.insert(algo_key, {perf.fwd_algo, perf.memory});
}
- const auto& perf = s_.cached_benchmark_fwd_results.at(x_dims_miopen);
+ const auto& perf = s_.cached_benchmark_fwd_results.at(algo_key);
s_.fwd_algo = perf.fwd_algo;
s_.workspace_bytes = perf.memory;
} else {
diff --git a/onnxruntime/core/providers/rocm/nn/conv.h b/onnxruntime/core/providers/rocm/nn/conv.h
index bc9846203e..0b07ee3f08 100644
--- a/onnxruntime/core/providers/rocm/nn/conv.h
+++ b/onnxruntime/core/providers/rocm/nn/conv.h
@@ -43,6 +43,10 @@ struct vector_hash {
}
};

+inline std::size_t HashConvAlgoKey(const TensorShapeVector& x_dims, const TensorShapeVector& w_dims) {
+  return vector_hash()(x_dims) ^ (vector_hash()(w_dims) << 1);
+}
+
template <typename Key, typename T,
typename Hash = std::hash<Key>,
typename KeyEqual = std::equal_to<Key>,
@@ -52,6 +56,7 @@ class lru_unordered_map {
lru_unordered_map(size_t max_size) : max_size_(max_size) {}

void insert(const Key& key, const T& value) {
+ std::lock_guard<std::mutex> guard(mutex_);
auto it = items_.find(key);
if (it != items_.end()) {
it->second.value = value;
@@ -69,6 +74,7 @@ class lru_unordered_map {
}

T& at(const Key& key) {
+ std::lock_guard<std::mutex> guard(mutex_);
auto it = items_.find(key);
if (it == items_.end()) {
throw std::out_of_range("There is no such key in cache");
@@ -78,6 +84,7 @@ class lru_unordered_map {
}

bool contains(const Key& key) const {
+ std::lock_guard<std::mutex> guard(mutex_);
return items_.find(key) != items_.end();
}

@@ -86,6 +93,7 @@ class lru_unordered_map {
}

void clear() {
+ std::lock_guard<std::mutex> guard(mutex_);
items_.clear();
lru_list_.clear();
}
@@ -106,6 +114,7 @@ class lru_unordered_map {
size_t max_size_;
std::unordered_map<Key, value_type, Hash, KeyEqual, MapAllocator> items_;
list_type lru_list_;
+ mutable std::mutex mutex_;
};

// cached miopen descriptors
@@ -148,8 +157,8 @@ struct MiopenConvState {
decltype(AlgoPerfType().memory) memory;
};

- lru_unordered_map<TensorShapeVector, PerfFwdResultParams, vector_hash> cached_benchmark_fwd_results{MAX_CACHED_ALGO_PERF_RESULTS};
- lru_unordered_map<TensorShapeVector, PerfBwdResultParams, vector_hash> cached_benchmark_bwd_results{MAX_CACHED_ALGO_PERF_RESULTS};
+ lru_unordered_map<std::size_t, PerfFwdResultParams> cached_benchmark_fwd_results{MAX_CACHED_ALGO_PERF_RESULTS};
+ lru_unordered_map<std::size_t, PerfBwdResultParams> cached_benchmark_bwd_results{MAX_CACHED_ALGO_PERF_RESULTS};

// Some properties needed to support asymmetric padded Conv nodes
bool post_slicing_required;
diff --git a/onnxruntime/core/providers/rocm/nn/conv_transpose.cc b/onnxruntime/core/providers/rocm/nn/conv_transpose.cc
index 7447113fdf..495aafa200 100644
--- a/onnxruntime/core/providers/rocm/nn/conv_transpose.cc
+++ b/onnxruntime/core/providers/rocm/nn/conv_transpose.cc
@@ -76,7 +76,6 @@ Status ConvTranspose<T, NHWC>::DoConvTranspose(OpKernelContext* context, bool dy

if (w_dims_changed) {
s_.last_w_dims = gsl::make_span(w_dims);
- s_.cached_benchmark_bwd_results.clear();
}

ConvTransposeAttributes::Prepare p;
@@ -127,7 +126,8 @@ Status ConvTranspose<T, NHWC>::DoConvTranspose(OpKernelContext* context, bool dy

y_data = reinterpret_cast<HipT*>(p.Y->MutableData<T>());

- if (!s_.cached_benchmark_bwd_results.contains(x_dims)) {
+ const std::size_t algo_key = HashConvAlgoKey(x_dims_miopen, w_dims);
+ if (!s_.cached_benchmark_bwd_results.contains(algo_key)) {
IAllocatorUniquePtr<void> algo_search_workspace = GetScratchBuffer<void>(AlgoSearchWorkspaceSize, context->GetComputeStream());

miopenConvAlgoPerf_t perf;
@@ -147,10 +147,10 @@ Status ConvTranspose<T, NHWC>::DoConvTranspose(OpKernelContext* context, bool dy
algo_search_workspace.get(),
AlgoSearchWorkspaceSize,
false));
- s_.cached_benchmark_bwd_results.insert(x_dims, {perf.bwd_data_algo, perf.memory});
+ s_.cached_benchmark_bwd_results.insert(algo_key, {perf.bwd_data_algo, perf.memory});
}

- const auto& perf = s_.cached_benchmark_bwd_results.at(x_dims);
+ const auto& perf = s_.cached_benchmark_bwd_results.at(algo_key);
s_.bwd_data_algo = perf.bwd_data_algo;
s_.workspace_bytes = perf.memory;
}
--
2.43.0
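
The locking added to lru_unordered_map is deliberately coarse: the operations touched by the patch — insert(), at(), contains(), and clear() — each take the same mutex, which is declared mutable so the const contains() can lock it too. Because contains() and insert() lock independently, two threads can still both miss and both run the algorithm search; the second insert then simply overwrites the entry, which looks acceptable here since any valid search result can be cached. A condensed sketch of the pattern (hypothetical class name, no LRU eviction, and at() returning by value so the copy happens under the lock, unlike the patch's T& at()):

// Sketch only: the same lock-guard-per-operation pattern the patch applies
// to lru_unordered_map, minus the LRU bookkeeping.
#include <mutex>
#include <stdexcept>
#include <unordered_map>

template <typename Key, typename T>
class locked_map {
 public:
  void insert(const Key& key, const T& value) {
    std::lock_guard<std::mutex> guard(mutex_);
    items_[key] = value;  // overwriting an existing entry is harmless here
  }

  T at(const Key& key) const {
    std::lock_guard<std::mutex> guard(mutex_);
    auto it = items_.find(key);
    if (it == items_.end()) {
      throw std::out_of_range("There is no such key in cache");
    }
    return it->second;  // copy while the lock is held
  }

  bool contains(const Key& key) const {
    std::lock_guard<std::mutex> guard(mutex_);
    return items_.find(key) != items_.end();
  }

 private:
  std::unordered_map<Key, T> items_;
  mutable std::mutex mutex_;  // mutable so const methods can lock
};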

58 changes: 0 additions & 58 deletions machine-learning/0001-guard-algo-benchmark-results.patch

This file was deleted.

10 changes: 5 additions & 5 deletions machine-learning/Dockerfile
@@ -32,16 +32,16 @@ RUN wget -nv https://github.com/Kitware/CMake/releases/download/v3.27.3/cmake-3.
ENV PATH /code/cmake-3.27.3-linux-x86_64/bin:${PATH}

# Prepare onnxruntime repository & build onnxruntime
-RUN git clone --single-branch --branch v1.19.2 --recursive "https://github.com/Microsoft/onnxruntime" onnxruntime
+RUN git clone --single-branch --branch v1.20.1 --recursive "https://github.com/Microsoft/onnxruntime" onnxruntime
WORKDIR /code/onnxruntime
# Fix for multi-threading based on comments in https://github.com/microsoft/onnxruntime/pull/19567
-COPY ./0001-guard-algo-benchmark-results.patch /tmp/
-RUN git apply /tmp/0001-guard-algo-benchmark-results.patch
+COPY ./0001-fix-rocm-conv-thread-safety.patch /tmp/
+RUN git apply /tmp/0001-fix-rocm-conv-thread-safety.patch

RUN /bin/sh ./dockerfiles/scripts/install_common_deps.sh
# Note: the `parallel` setting uses a substantial amount of RAM
RUN ./build.sh --allow_running_as_root --config Release --build_wheel --update --build --parallel 13 --cmake_extra_defines\
-ONNXRUNTIME_VERSION=1.19.2 --use_rocm --rocm_home=/opt/rocm
+ONNXRUNTIME_VERSION=1.20.1 --use_rocm --rocm_home=/opt/rocm
RUN mv /code/onnxruntime/build/Linux/Release/dist/*.whl /opt/

FROM builder-${DEVICE} AS builder
@@ -112,7 +112,7 @@ COPY --from=builder-armnn \
/opt/ann/build.sh \
/opt/armnn/

-FROM rocm/dev-ubuntu-24.04:6.2.4-complete AS prod-rocm
+FROM rocm/dev-ubuntu-22.04:6.1.2-complete AS prod-rocm


FROM prod-${DEVICE} AS prod
