Compute reduction workspace size with accumulation type (#944)
Summary:
Pull Request resolved: #944

The reduction ops that rely on CUTLASS reduction in the back-end pre-compute the workspace size in the front-end. Previously, the workspace size was computed in terms of the input type. This, however, is not consistent with how CUTLASS computes the workspace size: see the `ElementCompute` [here](https://github.com/NVIDIA/cutlass/blob/ff02da266713bd3365aed65c552412e126c040cb/include/cutlass/reduction/device/tensor_reduce_affine_strided.h#L223), which is actually the accumulation type. As a result, when `float32` accumulation was used for a `float16` or `bfloat16` input type, the pre-computed workspace size was half of what was required.
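
A minimal sketch of the undersizing, assuming a toy dtype-size table (the `ws_bytes` helper is hypothetical and stands in for the front-end pre-computation, not AITemplate API):

```python
# Hypothetical helper illustrating the sizing mismatch.
DTYPE_SIZE = {"float16": 2, "bfloat16": 2, "float32": 4}

def ws_bytes(vector_length: int, n_partial_vectors: int, dtype: str) -> int:
    # CUTLASS writes one partial-reduction vector of ElementCompute per CTA.
    return vector_length * DTYPE_SIZE[dtype] * n_partial_vectors

needed = ws_bytes(8, 1024, "float32")    # what the kernel actually writes
reserved = ws_bytes(8, 1024, "float16")  # old sizing, based on the input type
assert needed == 2 * reserved            # workspace was half the required size
```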

In this diff, the workspace size pre-computation is modified to be done in terms of the accumulation type. As the `use_fp16_acc` flag is set on the backend `Target`, if the target has not been set by the time the workspace size is pre-computed, the `float32` accumulation type is assumed conservatively.
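
For context, a hedged usage sketch of how the flag reaches the target (assuming the standard AITemplate flow where `detect_target` forwards keyword arguments to the `Target`; `output` stands for an already-built reduction graph):

```python
from aitemplate.compiler import compile_model
from aitemplate.testing import detect_target

# With use_fp16_acc=True and float16 inputs, reductions accumulate in fp16,
# so the pre-computed workspace may use the smaller fp16 element size.
target = detect_target(use_fp16_acc=True)
module = compile_model(output, target, "./tmp", "reduce_sum_fp16_acc")
```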

Reviewed By: ipiszy, chenyang78

Differential Revision: D50060329

fbshipit-source-id: 554d29a37cc9e15a72b2d1c36b781b12ec6313e3
aakhundov authored and facebook-github-bot committed Oct 8, 2023
1 parent 74d813f commit 7e0da3f
Showing 2 changed files with 38 additions and 4 deletions.
32 changes: 28 additions & 4 deletions python/aitemplate/compiler/ops/reduce/reduce_common.py
@@ -93,7 +93,11 @@ def _infer_shapes(self, x: Tensor) -> List[IntVar]:
         return output_dims

     def _compute_ws_size_strided(
-        self, extent, reduction_axis, vector_length, dtype
+        self,
+        extent: List[int],
+        reduction_axis: int,
+        vector_length: int,
+        accumulation_type: str,
     ) -> int:
         """
         Compute workspace size for contiguous reduction kernels.
@@ -137,19 +141,36 @@ def _reshape_pow2(ext, count):
         cta_count_z = (inner_count + threadblock_shape_z - 1) // threadblock_shape_z
         if int(cta_count_z) == 1:
             return 0
-        vector_size_bytes = vector_length * get_dtype_size(dtype)
+        vector_size_bytes = vector_length * get_dtype_size(accumulation_type)
         workspace_stride = extent[k_rank - 1] * vector_size_bytes
         return workspace_stride * outer_count * cta_count_z

     def _compute_workspace_size(
-        self, shape: List[IntVar], reduction_axis, dtype
+        self,
+        shape: List[IntVar],
+        reduction_axis: int,
+        input_type: str,
     ) -> int:
         """
         Compute workspace size for the given shape using the same algorithm as
         cutlass's TensorReduction kernel. The only difference is that we use
         the maximum dim value for dynamic dimension, whereas TensorReduction
         uses the real dim value at runtime.
         """
+        # workspace size must be computed in terms of the accumulation dtype;
+        # see the CUTLASS code in TensorReductionAffineStrided for the reference:
+        # https://github.com/NVIDIA/cutlass/blob/ff02da266713bd3365aed65c552412e126c040cb/include/cutlass/reduction/device/tensor_reduce_affine_strided.h#L223
+        accumulation_type = "float32"
+        try:
+            if (
+                backend.target.Target.current()._kwargs.get("use_fp16_acc", False)
+                and input_type == "float16"
+            ):
+                accumulation_type = input_type
+        except RuntimeError:
+            # Target is not set: conservatively
+            # assume float32 accumulation type
+            pass
         # Make sure the last dim is static to pre-compute vector_length.
         # Note that this is a temporary constraint. Once we replace TensorReduction
         # with our own col-reduction kernel, we will remove this entire workaround.
@@ -212,7 +233,10 @@ def _compute_workspace_size(
             max_ws = max(
                 max_ws,
                 self._compute_ws_size_strided(
-                    extent_affine, reduction_axis, vector_length, dtype
+                    extent_affine,
+                    reduction_axis,
+                    vector_length,
+                    accumulation_type,
                 ),
             )
         return max_ws
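
The selection logic added above boils down to a small pure function; a standalone sketch for illustration (the `pick_accumulation_type` helper and its `target_kwargs` argument are hypothetical stand-ins for the `Target.current()._kwargs` lookup):

```python
from typing import Dict, Optional

def pick_accumulation_type(input_type: str, target_kwargs: Optional[Dict]) -> str:
    # Hypothetical mirror of the logic in _compute_workspace_size;
    # target_kwargs=None models Target.current() raising RuntimeError.
    if (
        target_kwargs is not None
        and target_kwargs.get("use_fp16_acc", False)
        and input_type == "float16"
    ):
        return "float16"  # fp16 accumulation explicitly requested
    return "float32"      # conservative default (larger element size)

assert pick_accumulation_type("float16", None) == "float32"
assert pick_accumulation_type("float16", {"use_fp16_acc": True}) == "float16"
assert pick_accumulation_type("bfloat16", {"use_fp16_acc": True}) == "float32"
```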
10 changes: 10 additions & 0 deletions tests/unittest/ops/test_reduce.py
@@ -226,6 +226,16 @@ def test_reduce_sum(self):
             input_type="float16",
             output_type=None,
         )
+        # make sure that the workspace size is computed correctly
+        # for the fp32 accumulator (use_fp16_acc=False)
+        self._run_reduce_sum(
+            dim=1,
+            input_shape=[1024, 2, 1855],
+            keepdim=False,
+            input_type="float16",
+            output_type=None,
+            use_fp16_acc=False,
+        )

     def _run_reduce_mean(
         self,
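
A back-of-the-envelope view of why this shape exercises the fix (illustrative numbers only; the real size also depends on the CUTLASS threadblock shapes and the resulting `cta_count_z`):

```python
# Reducing dim=1 of a [1024, 2, 1855] float16 tensor with fp32 accumulation.
outer, reduced, inner = 1024, 2, 1855
vector_length = 1  # 1855 is odd, so the last dim cannot be vectorized
old_bytes = inner * vector_length * 2 * outer  # sized by the fp16 input type
new_bytes = inner * vector_length * 4 * outer  # sized by the fp32 accumulator
print(new_bytes - old_bytes)  # 3_799_040 bytes the old sizing left out
```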
