Add reduce_max and reduce_min (facebookincubator#942)
Summary:

My understanding is that max / min reductions are usually not used in trained PT models due to their non-differentiability, which is probably why they have not been reflected as AIT ops so far. However, they can be useful in inference for a manually composed AIT model (or a PT model translated by fx2ait), as requested in issue [facebookincubator#941](facebookincubator#941).

In this diff, we add the new `reduce_max` and `reduce_min` ops to AIT. The new ops reuse the common back-end implementation behind `reduce_sum`, with a small extension: the default (initial) accumulator value of 0 is not suitable for min / max reductions, so the reduction identity is now configurable in the codegen. The `reduce_max` and `reduce_min` ops set it to the `lowest()` and `max()` numeric limits of the ElementCompute type, respectively.
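To see why the identity must change, here is a minimal standalone sketch (plain Python with made-up names, not AIT code): a max reduction whose accumulator starts at 0 silently clamps all-negative inputs to 0.

```python
# Minimal sketch of why the accumulator's initial value (the reduction
# identity) matters for min / max reductions. Plain Python, not AIT code.
def reduce_with_identity(values, op, identity):
    acc = identity
    for v in values:
        acc = op(acc, v)
    return acc

# With the proper identity (the lowest representable value), max is correct:
assert reduce_with_identity([-3.0, -1.0, -2.0], max, float("-inf")) == -1.0
# With a sum-style identity of 0, an all-negative input is clamped to 0:
assert reduce_with_identity([-3.0, -1.0, -2.0], max, 0.0) == 0.0
```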

fx2ait wiring of the new ops to `torch.amax` and `torch.amin` has also been added.
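For illustration, a hedged sketch of how the new op might be used in a manually composed AIT model. The exact import paths, tensor attributes, and shapes here are assumptions modeled on AIT's existing reduce ops, which are constructed with `dim` / `keepdim` and then applied to a tensor:

```python
# Hedged usage sketch; names, import paths, and shapes are illustrative only.
from aitemplate.compiler import ops
from aitemplate.frontend import Tensor

x = Tensor(shape=[2, 3, 4], dtype="float16", name="x", is_input=True)
y = ops.reduce_max(dim=1, keepdim=False)(x)  # result shape: [2, 4]
y._attrs["name"] = "y"
y._attrs["is_output"] = True
```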

Differential Revision: D49978036
aakhundov authored and facebook-github-bot committed Oct 6, 2023
1 parent 2bf9919 commit 793bc04
Showing 16 changed files with 817 additions and 35 deletions.
16 changes: 16 additions & 0 deletions fx2ait/fx2ait/acc_tracer/acc_ops.py
@@ -1556,6 +1556,22 @@ def max_dim_reduce(*, input, dim=None, keepdim=False):
     return torch.max(input=input, dim=dim, keepdim=keepdim)


+@register_acc_op_properties(AccOpProperty.unary)
+@register_acc_op_mapping(op_and_target=("call_function", torch.amax))
+@register_acc_op_mapping(op_and_target=("call_method", "amax"))
+@register_acc_op
+def amax(*, input, dim, keepdim=False):
+    return torch.amax(input=input, dim=dim, keepdim=keepdim)
+
+
+@register_acc_op_properties(AccOpProperty.unary)
+@register_acc_op_mapping(op_and_target=("call_function", torch.amin))
+@register_acc_op_mapping(op_and_target=("call_method", "amin"))
+@register_acc_op
+def amin(*, input, dim, keepdim=False):
+    return torch.amin(input=input, dim=dim, keepdim=keepdim)
+
+
 @register_acc_op_properties(AccOpProperty.pointwise)
 @register_acc_op_mapping(op_and_target=("call_function", torch.maximum))
 @register_acc_op_mapping(op_and_target=("call_method", "maximum"))
22 changes: 22 additions & 0 deletions fx2ait/fx2ait/converters/ait_converters.py
@@ -51,7 +51,9 @@
     ndhwc3to8,
     pad_last_dim,
     permute,
+    reduce_max,
     reduce_mean,
+    reduce_min,
     reduce_sum,
     reshape,
     size,
@@ -241,6 +243,26 @@ def acc_ops_mean(
     return create_reduce_op(reduce_mean, args, kwargs, name)


+@ait_converter(acc_ops.amax)
+def acc_ops_amax(
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> ConverterOutput:
+    return create_reduce_op(reduce_max, args, kwargs, name)
+
+
+@ait_converter(acc_ops.amin)
+def acc_ops_amin(
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> ConverterOutput:
+    return create_reduce_op(reduce_min, args, kwargs, name)
+
+
 @ait_converter(acc_ops.linear)
 def acc_ops_linear(
     target: Target,
8 changes: 5 additions & 3 deletions fx2ait/fx2ait/converters/utils.py
@@ -49,17 +49,19 @@ def create_reduce_op(
     dims = kwargs.get("dim", None)
     if dims is None:
         dims = list(range(len(input_val.shape())))
+    if isinstance(dims, int):
+        dims = [dims]
     if len(dims) < 1:
         raise ValueError("No dims to reduce on")
     dim = dims[0]
     keepdim = False if "keepdim" not in kwargs else kwargs["keepdim"]
-    sum_val = op_type(dim=dim, keepdim=keepdim)(input_val)
+    reduced_val = op_type(dim=dim, keepdim=keepdim)(input_val)

     if len(dims) > 1:
-        new_kwargs = {"input": sum_val, "dims": dims[1:]}
+        new_kwargs = {"input": reduced_val, "dims": dims[1:]}
         return create_reduce_op(op_type, args, new_kwargs, name)

-    return sum_val
+    return reduced_val


 def create_binary_op(
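The recursion in `create_reduce_op` peels off one reduction dim at a time. Here is a standalone PyTorch sketch of the same decomposition idea, not the converter itself (it iterates rather than recurses, and reduces the largest dim first so the remaining indices stay valid when `keepdim=False`):

```python
import torch

def reduce_dims(x, dims, op=torch.amax, keepdim=False):
    # Decompose a multi-dim reduction into a chain of single-dim
    # reductions, analogous to create_reduce_op above.
    if isinstance(dims, int):
        dims = [dims]
    # Reduce the largest dim first so the remaining dim indices stay
    # valid when keepdim=False shrinks the tensor's rank.
    for d in sorted(dims, reverse=True):
        x = op(x, dim=d, keepdim=keepdim)
    return x

assert reduce_dims(torch.randn(2, 3, 5), (0, 1)).shape == (5,)
```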
66 changes: 66 additions & 0 deletions fx2ait/fx2ait/test/converters/test_ait_reduce.py
@@ -113,3 +113,69 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             [torch.randn(2, 3, 4).half().cuda()],
             expected_ops={},
         )
+
+
+class TestAminConverter(AITTestCase):
+    @parameterized.expand(
+        [
+            ["default", (1), False],
+            ["keepdim", (1), True],
+            ["negative_dim", (-1), False],
+        ]
+    )
+    def test_amin(self, test_name, dim, keepdim):
+        class TestModule(torch.nn.Module):
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                return torch.amin(x, dim=dim, keepdim=keepdim)
+
+        model = TestModule().cuda()
+        inputs = [torch.randn(1, 2, 3).half().cuda()]
+        self.run_test(model, inputs, expected_ops={acc_ops.amin})
+
+    @parameterized.expand(
+        [
+            ["default", None, False],
+            ["specified_dims", (0, 1, 2), False],
+        ]
+    )
+    def test_amin_multi_dims(self, test_name, dim, keepdim):
+        class TestModule(torch.nn.Module):
+            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+                return y + torch.amin(x, dim=dim, keepdim=keepdim)
+
+        model = TestModule().cuda()
+        inputs = [torch.randn(2, 3, 5).half().cuda()] * 2
+        self.run_test(model, inputs, expected_ops={acc_ops.add, acc_ops.amin})
+
+
+class TestAmaxConverter(AITTestCase):
+    @parameterized.expand(
+        [
+            ["default", (1), False],
+            ["keepdim", (1), True],
+            ["negative_dim", (-1), False],
+        ]
+    )
+    def test_amax(self, test_name, dim, keepdim):
+        class TestModule(torch.nn.Module):
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                return torch.amax(x, dim=dim, keepdim=keepdim)
+
+        model = TestModule().cuda()
+        inputs = [torch.randn(1, 2, 3).half().cuda()]
+        self.run_test(model, inputs, expected_ops={acc_ops.amax})
+
+    @parameterized.expand(
+        [
+            ["default", None, False],
+            ["specified_dims", (0, 1, 2), False],
+        ]
+    )
+    def test_amax_multi_dims(self, test_name, dim, keepdim):
+        class TestModule(torch.nn.Module):
+            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+                return y + torch.amax(x, dim=dim, keepdim=keepdim)
+
+        model = TestModule().cuda()
+        inputs = [torch.randn(2, 3, 5).half().cuda()] * 2
+        self.run_test(model, inputs, expected_ops={acc_ops.add, acc_ops.amax})
4 changes: 4 additions & 0 deletions python/aitemplate/backend/cuda/reduce/__init__.py
@@ -18,7 +18,9 @@
 from aitemplate.backend.cuda.reduce import (
     reduce_3d,
     reduce_common,
+    reduce_max,
     reduce_mean,
+    reduce_min,
     reduce_sum,
     var,
     vector_norm,
@@ -27,7 +29,9 @@
 __all__ = [
     "reduce_3d",
     "reduce_common",
+    "reduce_max",
     "reduce_mean",
+    "reduce_min",
     "reduce_sum",
     "var",
     "vector_norm",
17 changes: 15 additions & 2 deletions python/aitemplate/backend/cuda/reduce/reduce_3d.py
@@ -323,7 +323,17 @@
   ReduceScalarOp reduce_s_op;
   FragmentCompute frag_compute;
+{% if reduction_identity == 'ElementCompute()' %}
+  // initialize the frag_compute with default values
   frag_compute.clear();
+{% else %}
+  // need to initialize the frag_compute with the specific
+  // reduction_identity values, as those are likely non-default
+  for (int i = 0; i < kAlignment; ++i) {
+    frag_compute[i] = {{reduction_identity}};
+  }
+{% endif %}
   if (idx_m < args.extent.row()) {
@@ -414,7 +424,7 @@
   // Tree reduction
   ElementCompute *smem_ptr = shared_storage.exchange.data() + threadIdx.y * Shape::kColumn;
-  ElementCompute result = ElementCompute();
+  ElementCompute result = {{reduction_identity}};
   CUTLASS_PRAGMA_UNROLL
   for (
@@ -465,7 +475,7 @@
   // Certain shape combinations require an additional reduction step
   if (kLgResidual) {
-    result = ElementCompute();
+    result = {{reduction_identity}};
     int const kResidualVector = (1 << kLgResidual);
     cutlass::Array<ElementCompute, kResidualVector> fetch;
@@ -800,6 +810,7 @@ def gen_function(
     epilogue_scalar_template=DEFAULT_EPILOGUE_SCALAR_TEMPLATE,
     extra_code_str="",
     accumulation_type=None,
+    reduction_identity="ElementCompute()",
 ) -> str:
     """a common function for generating a reduce-family kernel
@@ -927,6 +938,7 @@ def gen_function(
         acc_type,
         output_accessors,
         output_alignment,
+        reduction_identity,
     )
     exec_paths = EXEC_COND_TEMPLATE.render(
         indent=" ",
@@ -950,6 +962,7 @@
     kernel_src = KERNEL_SRC_TEMPLATE.render(
         extra_code=extra_code_str,
         reduce_op=reduce_op,
+        reduction_identity=reduction_identity,
         reduce_kernel_instance=reduce_instance,
         alignments=alignments,
         prologue_code=prologue_code,
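As a side note on the mechanism: AIT's codegen renders these kernels through Jinja2, so the new `reduction_identity` parameter is just a template variable. A small sketch of how it renders (the template below is abbreviated to a single line; it is not the actual AIT template):

```python
from jinja2 import Template

snippet = Template("ElementCompute result = {{reduction_identity}};")

# The default preserves the old zero-initialized behavior:
print(snippet.render(reduction_identity="ElementCompute()"))
# -> ElementCompute result = ElementCompute();

# reduce_max overrides it with the lowest representable value:
print(snippet.render(reduction_identity="std::numeric_limits<ElementCompute>::lowest()"))
# -> ElementCompute result = std::numeric_limits<ElementCompute>::lowest();
```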
9 changes: 7 additions & 2 deletions python/aitemplate/backend/cuda/reduce/reduce_common.py
@@ -127,7 +127,7 @@
       src_dims[0], src_dims[1], src_dims[2], src_dims[3]
   );
   Layout src_layout(Layout::packed(src_extent));
-  ElementCompute reduction_identity = ElementCompute();
+  ElementCompute reduction_identity = {{reduction_identity}};
   TensorReduction reduction(src_extent, reduction_axis);
   ReductionOp reduction_op = ReductionOp();
   assert(dst_ptr);
@@ -189,7 +189,11 @@ def gen_function_decl(func_attrs):
     )


-def gen_function(func_attrs, reduction_op):
+def gen_function(
+    func_attrs,
+    reduction_op,
+    reduction_identity="ElementCompute()",
+):
     backend_spec = CUDASpec()
     elem_input_type = backend_spec.dtype_to_lib_type(
         func_attrs["inputs"][0]._attrs["dtype"]
@@ -221,6 +225,7 @@ def gen_function(func_attrs, reduction_op):
     return SRC_TEMPLATE.render(
         func_name=func_attrs["name"],
         reduction_op=reduction_op,
+        reduction_identity=reduction_identity,
         exec_paths=exec_paths,
         workspace_ptr=workspace_ptr,
         accumulation_type=accumulation_type,
113 changes: 113 additions & 0 deletions python/aitemplate/backend/cuda/reduce/reduce_max.py
@@ -0,0 +1,113 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
reduce_max kernel:
(1) it invokes reduce_3d kernel for reduction_dim = -1 cases; and
(2) invokes reduce_common for all other cases
We do this because there is huge perf difference between reduce_3d and
reduce_common for different reduction dims. We should consider to unify
our implementation later. Ideally, we should fix the perf issue in
reduce_3d for non-neg-dim cases, because reduce_3d can take prologue and
epilogue so it is more general than reduce_common.
"""

from aitemplate.backend import registry
from aitemplate.backend.cuda.reduce import reduce_3d, reduce_common


def _is_last_reduction_dim(func_attrs):
    """Return True if the reduction dim is the last dim (i.e. the innermost dim)."""
    axes = func_attrs["reduction_axes"]
    if not len(axes) == 1:
        raise NotImplementedError("Multiple reduction axes are not supported yet")
    reduction_dim = axes[0]
    # make sure our frontend handles negative dims
    assert reduction_dim >= 0, "cannot have negative dim here: {}".format(reduction_dim)
    x = func_attrs["inputs"][0]
    rank = x._rank()
    assert rank >= 1, "rank must be >= 1, got: {}".format(rank)
    return reduction_dim == rank - 1


@registry.reg("cuda.reduce_max.func_decl")
def reduce_max_gen_function_decl(func_attrs):
    """the registered function for generating the reduce_max function declaration

    Parameters
    ----------
    func_attrs : Dict[str, Any]
        holds attributes of this reduce_max op

    Returns
    -------
    str
        returns the rendered function declaration with appropriate replacements
    """
    if _is_last_reduction_dim(func_attrs):
        return reduce_3d.gen_function_decl(func_attrs)
    else:
        return reduce_common.gen_function_decl(func_attrs)


@registry.reg("cuda.reduce_max.gen_function")
def reduce_max_gen_function(func_attrs):
    """the registered function for generating the reduce_max kernel and all of
    its auxiliary functions

    Parameters
    ----------
    func_attrs : Dict[str, Any]
        holds attributes of this reduce_max op

    Returns
    -------
    str
        returns the rendered code for the complete implementation of this reduce_max op
    """
    if _is_last_reduction_dim(func_attrs):
        return reduce_3d.gen_function(
            func_attrs,
            reduce_op="cutlass::maximum",
            reduction_identity="std::numeric_limits<ElementCompute>::lowest()",
        )
    else:
        return reduce_common.gen_function(
            func_attrs,
            reduction_op="cutlass::maximum",
            reduction_identity="std::numeric_limits<ElementCompute>::lowest()",
        )


@registry.reg("cuda.reduce_max.func_call")
def reduce_max_gen_function_call(func_attrs, indent=" "):
    """the registered function for generating a function call to reduce_max

    Parameters
    ----------
    func_attrs : Dict[str, Any]
        holds attributes of this reduce_max op
    indent : str, optional
        indentation for each line of the rendered code (default " ")

    Returns
    -------
    str
        returns rendered code for invoking the reduce op
    """
    if _is_last_reduction_dim(func_attrs):
        return reduce_3d.gen_function_call(func_attrs, indent)
    else:
        return reduce_common.gen_function_call(func_attrs, indent)
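The remaining changed files (including the new `reduce_min.py`) are not shown in this view. Per the summary above, the reduce_min codegen presumably mirrors reduce_max with the opposite op and identity; a hedged sketch of what its gen_function registration might look like (the actual file contents are not shown here, so all names below are inferred):

```python
# Hedged sketch only: the actual reduce_min.py is not shown in this diff.
# Per the summary, reduce_min mirrors reduce_max with cutlass::minimum and
# the max() numeric limit of ElementCompute as the reduction identity.
from aitemplate.backend import registry
from aitemplate.backend.cuda.reduce import reduce_3d, reduce_common


def _is_last_reduction_dim(func_attrs):
    # Same check as in reduce_max.py above.
    axes = func_attrs["reduction_axes"]
    return axes[0] == func_attrs["inputs"][0]._rank() - 1


@registry.reg("cuda.reduce_min.gen_function")
def reduce_min_gen_function(func_attrs):
    if _is_last_reduction_dim(func_attrs):
        return reduce_3d.gen_function(
            func_attrs,
            reduce_op="cutlass::minimum",
            reduction_identity="std::numeric_limits<ElementCompute>::max()",
        )
    else:
        return reduce_common.gen_function(
            func_attrs,
            reduction_op="cutlass::minimum",
            reduction_identity="std::numeric_limits<ElementCompute>::max()",
        )
```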