Add reduce_min and reduce_max #942

Closed · wants to merge 1 commit
16 changes: 16 additions & 0 deletions fx2ait/fx2ait/acc_tracer/acc_ops.py
@@ -1556,6 +1556,22 @@ def max_dim_reduce(*, input, dim=None, keepdim=False):
return torch.max(input=input, dim=dim, keepdim=keepdim)


@register_acc_op_properties(AccOpProperty.unary)
@register_acc_op_mapping(op_and_target=("call_function", torch.amax))
@register_acc_op_mapping(op_and_target=("call_method", "amax"))
@register_acc_op
def amax(*, input, dim, keepdim=False):
return torch.amax(input=input, dim=dim, keepdim=keepdim)


@register_acc_op_properties(AccOpProperty.unary)
@register_acc_op_mapping(op_and_target=("call_function", torch.amin))
@register_acc_op_mapping(op_and_target=("call_method", "amin"))
@register_acc_op
def amin(*, input, dim, keepdim=False):
return torch.amin(input=input, dim=dim, keepdim=keepdim)


@register_acc_op_properties(AccOpProperty.pointwise)
@register_acc_op_mapping(op_and_target=("call_function", torch.maximum))
@register_acc_op_mapping(op_and_target=("call_method", "maximum"))
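The two new acc ops are thin keyword-only wrappers, so both the functional and the method spelling of torch.amax/torch.amin normalize to the same acc op during tracing. A minimal sketch of that normalization, assuming the usual fx2ait acc_tracer.trace entry point:

import torch
from fx2ait.acc_tracer import acc_ops, acc_tracer

class M(torch.nn.Module):
    def forward(self, x):
        # functional and method forms should both map to acc_ops.amax / acc_ops.amin
        return torch.amax(x, dim=1) + x.amin(dim=1)

traced = acc_tracer.trace(M(), [torch.randn(2, 3)])
targets = {n.target for n in traced.graph.nodes if n.op == "call_function"}
assert acc_ops.amax in targets and acc_ops.amin in targets
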
22 changes: 22 additions & 0 deletions fx2ait/fx2ait/converters/ait_converters.py
@@ -51,7 +51,9 @@
ndhwc3to8,
pad_last_dim,
permute,
reduce_max,
reduce_mean,
reduce_min,
reduce_sum,
reshape,
size,
@@ -241,6 +243,26 @@ def acc_ops_mean(
return create_reduce_op(reduce_mean, args, kwargs, name)


@ait_converter(acc_ops.amax)
def acc_ops_amax(
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> ConverterOutput:
return create_reduce_op(reduce_max, args, kwargs, name)


@ait_converter(acc_ops.amin)
def acc_ops_amin(
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> ConverterOutput:
return create_reduce_op(reduce_min, args, kwargs, name)


@ait_converter(acc_ops.linear)
def acc_ops_linear(
target: Target,
8 changes: 5 additions & 3 deletions fx2ait/fx2ait/converters/utils.py
@@ -49,17 +49,19 @@ def create_reduce_op(
dims = kwargs.get("dim", None)
if dims is None:
dims = list(range(len(input_val.shape())))
if isinstance(dims, int):
dims = [dims]
if len(dims) < 1:
raise ValueError("No dims to reduce on")
dim = dims[0]
keepdim = False if "keepdim" not in kwargs else kwargs["keepdim"]
sum_val = op_type(dim=dim, keepdim=keepdim)(input_val)
reduced_val = op_type(dim=dim, keepdim=keepdim)(input_val)

if len(dims) > 1:
new_kwargs = {"input": sum_val, "dims": dims[1:]}
new_kwargs = {"input": reduced_val, "dims": dims[1:]}
return create_reduce_op(op_type, args, new_kwargs, name)

return sum_val
return reduced_val


def create_binary_op(
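create_reduce_op lowers a multi-dim reduction into a chain of single-dim AIT reduce ops, peeling off one dim per recursive call. A standalone sketch of the same idea in plain PyTorch (keepdim=True is assumed here so the remaining dim indices stay valid between steps):

import torch

def chained_reduce(x, dims, op=torch.amax):
    out = x
    for d in dims:
        out = op(out, dim=d, keepdim=True)  # one axis per step, as in create_reduce_op
    return out

x = torch.randn(2, 3, 5)
chained = chained_reduce(x, dims=(0, 1, 2))
assert torch.equal(chained.reshape(()), torch.amax(x, dim=(0, 1, 2)))
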
66 changes: 66 additions & 0 deletions fx2ait/fx2ait/test/converters/test_ait_reduce.py
@@ -113,3 +113,69 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
[torch.randn(2, 3, 4).half().cuda()],
expected_ops={},
)


class TestAminConverter(AITTestCase):
@parameterized.expand(
[
["default", (1), False],
["keepdim", (1), True],
["negative_dim", (-1), False],
]
)
def test_amin(self, test_name, dim, keepdim):
class TestModule(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.amin(x, dim=dim, keepdim=keepdim)

model = TestModule().cuda()
inputs = [torch.randn(1, 2, 3).half().cuda()]
self.run_test(model, inputs, expected_ops={acc_ops.amin})

@parameterized.expand(
[
["default", None, False],
["specified_dims", (0, 1, 2), False],
]
)
def test_amin_multi_dims(self, test_name, dim, keepdim):
class TestModule(torch.nn.Module):
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return y + torch.amin(x, dim=dim, keepdim=keepdim)

model = TestModule().cuda()
inputs = [torch.randn(2, 3, 5).half().cuda()] * 2
self.run_test(model, inputs, expected_ops={acc_ops.add, acc_ops.amin})


class TestAmaxConverter(AITTestCase):
@parameterized.expand(
[
["default", (1), False],
["keepdim", (1), True],
["negative_dim", (-1), False],
]
)
def test_amax(self, test_name, dim, keepdim):
class TestModule(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.amax(x, dim=dim, keepdim=keepdim)

model = TestModule().cuda()
inputs = [torch.randn(1, 2, 3).half().cuda()]
self.run_test(model, inputs, expected_ops={acc_ops.amax})

@parameterized.expand(
[
["default", None, False],
["specified_dims", (0, 1, 2), False],
]
)
def test_amax_multi_dims(self, test_name, dim, keepdim):
class TestModule(torch.nn.Module):
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return y + torch.amax(x, dim=dim, keepdim=keepdim)

model = TestModule().cuda()
inputs = [torch.randn(2, 3, 5).half().cuda()] * 2
self.run_test(model, inputs, expected_ops={acc_ops.add, acc_ops.amax})
4 changes: 4 additions & 0 deletions python/aitemplate/backend/cuda/reduce/__init__.py
@@ -18,7 +18,9 @@
from aitemplate.backend.cuda.reduce import (
reduce_3d,
reduce_common,
reduce_max,
reduce_mean,
reduce_min,
reduce_sum,
var,
vector_norm,
@@ -27,7 +29,9 @@
__all__ = [
"reduce_3d",
"reduce_common",
"reduce_max",
"reduce_mean",
"reduce_min",
"reduce_sum",
"var",
"vector_norm",
17 changes: 15 additions & 2 deletions python/aitemplate/backend/cuda/reduce/reduce_3d.py
@@ -323,7 +323,17 @@
ReduceScalarOp reduce_s_op;

FragmentCompute frag_compute;

{% if reduction_identity == 'ElementCompute()' %}
// initialize the frag_compute with default values
frag_compute.clear();
{% else %}
// need to initialize the frag_compute with the specific
// reduction_identity values, as those are likely non-default
for (int i = 0; i < kAlignment; ++i) {
frag_compute[i] = {{reduction_identity}};
}
{% endif %}

if (idx_m < args.extent.row()) {

@@ -414,7 +424,7 @@
// Tree reduction
ElementCompute *smem_ptr = shared_storage.exchange.data() + threadIdx.y * Shape::kColumn;

ElementCompute result = ElementCompute();
ElementCompute result = {{reduction_identity}};

CUTLASS_PRAGMA_UNROLL
for (
@@ -465,7 +475,7 @@

// Certain shape combinations require an additional reduction step
if (kLgResidual) {
result = ElementCompute();
result = {{reduction_identity}};

int const kResidualVector = (1 << kLgResidual);
cutlass::Array<ElementCompute, kResidualVector> fetch;
@@ -800,6 +810,7 @@ def gen_function(
epilogue_scalar_template=DEFAULT_EPILOGUE_SCALAR_TEMPLATE,
extra_code_str="",
accumulation_type=None,
reduction_identity="ElementCompute()",
) -> str:
"""a common function for generating a reduce-family kernel

@@ -927,6 +938,7 @@
acc_type,
output_accessors,
output_alignment,
reduction_identity,
)
exec_paths = EXEC_COND_TEMPLATE.render(
indent=" ",
@@ -950,6 +962,7 @@
kernel_src = KERNEL_SRC_TEMPLATE.render(
extra_code=extra_code_str,
reduce_op=reduce_op,
reduction_identity=reduction_identity,
reduce_kernel_instance=reduce_instance,
alignments=alignments,
prologue_code=prologue_code,
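The template change above matters because the accumulator fragments and the tree-reduction temporaries were previously zero-initialized via ElementCompute(), which is only a valid identity for sums. A standalone sketch of why max/min need numeric_limits-based identities:

import torch

x = torch.tensor([-3.0, -1.5, -7.0])
# a zero-initialized lane is harmless for a sum but corrupts a max over all-negative data
with_zero_identity = torch.cat([x, torch.zeros(1)]).max()
with_lowest_identity = torch.cat([x, torch.full((1,), torch.finfo(torch.float32).min)]).max()
assert with_zero_identity != x.max()     # 0 wins over every real element: wrong
assert with_lowest_identity == x.max()   # lowest() is neutral for max: correct
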
9 changes: 7 additions & 2 deletions python/aitemplate/backend/cuda/reduce/reduce_common.py
@@ -127,7 +127,7 @@
src_dims[0], src_dims[1], src_dims[2], src_dims[3]
);
Layout src_layout(Layout::packed(src_extent));
ElementCompute reduction_identity = ElementCompute();
ElementCompute reduction_identity = {{reduction_identity}};
TensorReduction reduction(src_extent, reduction_axis);
ReductionOp reduction_op = ReductionOp();
assert(dst_ptr);
@@ -189,7 +189,11 @@ def gen_function_decl(func_attrs):
)


def gen_function(func_attrs, reduction_op):
def gen_function(
func_attrs,
reduction_op,
reduction_identity="ElementCompute()",
):
backend_spec = CUDASpec()
elem_input_type = backend_spec.dtype_to_lib_type(
func_attrs["inputs"][0]._attrs["dtype"]
@@ -221,6 +225,7 @@ def gen_function(func_attrs, reduction_op):
return SRC_TEMPLATE.render(
func_name=func_attrs["name"],
reduction_op=reduction_op,
reduction_identity=reduction_identity,
exec_paths=exec_paths,
workspace_ptr=workspace_ptr,
accumulation_type=accumulation_type,
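For reference, a sketch of the (reduce op, identity) pairings that the backends render into these CUDA templates. The reduce_min file is not shown in this excerpt, so its row is assumed by symmetry, and the sum/mean row reflects the pre-existing zero default:

# assumed pairing table for illustration, not code from the PR
REDUCTION_IDENTITY = {
    "cutlass::plus": "ElementCompute()",                                  # sum/mean: zero identity
    "cutlass::maximum": "std::numeric_limits<ElementCompute>::lowest()",  # reduce_max
    "cutlass::minimum": "std::numeric_limits<ElementCompute>::max()",     # reduce_min (assumed)
}
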
113 changes: 113 additions & 0 deletions python/aitemplate/backend/cuda/reduce/reduce_max.py
@@ -0,0 +1,113 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
reduce_max kernel:
(1) it invokes the reduce_3d kernel when the reduction dim is the last dim; and
(2) it invokes reduce_common for all other cases.

We do this because there is a huge perf difference between reduce_3d and
reduce_common for different reduction dims. We should consider unifying the
implementations later. Ideally, we would fix the perf issue in reduce_3d for
the non-last-dim cases, because reduce_3d can take a prologue and an epilogue,
so it is more general than reduce_common.
"""

from aitemplate.backend import registry
from aitemplate.backend.cuda.reduce import reduce_3d, reduce_common


def _is_last_reduction_dim(func_attrs):
"""return true if the reduction dim is the last dim (i.e. inner most dim)"""
axes = func_attrs["reduction_axes"]
if len(axes) != 1:
raise NotImplementedError("Multiple reduction axes are not supported yet")
reduction_dim = axes[0]
# make sure our frontend handles negative dims
assert reduction_dim >= 0, "cannot have negative dim here: {}".format(reduction_dim)
x = func_attrs["inputs"][0]
rank = x._rank()
assert rank >= 1, "rank must be >= 1, got: {}".format(rank)
return reduction_dim == rank - 1


@registry.reg("cuda.reduce_max.func_decl")
def reduce_max_gen_function_decl(func_attrs):
"""the registered function for generating reduce_max function declaration

Parameters
----------
func_attrs : Dict[str, Any]
holds attributes of this reduce_max op

Returns
-------
str
returns the rendered function declaration with appropriate replacements
"""
if _is_last_reduction_dim(func_attrs):
return reduce_3d.gen_function_decl(func_attrs)
else:
return reduce_common.gen_function_decl(func_attrs)


@registry.reg("cuda.reduce_max.gen_function")
def reduce_max_gen_function(func_attrs):
"""the registered function for generating reduce_max kernel and all of
its auxiliary functions

Parameters
----------
func_attrs : Dict[str, Any]
holds attributes of this reduce_max op

Returns
-------
str
returns the rendered code for the complete implementation of this reduce max op
"""
if _is_last_reduction_dim(func_attrs):
return reduce_3d.gen_function(
func_attrs,
reduce_op="cutlass::maximum",
reduction_identity="std::numeric_limits<ElementCompute>::lowest()",
)
else:
return reduce_common.gen_function(
func_attrs,
reduction_op="cutlass::maximum",
reduction_identity="std::numeric_limits<ElementCompute>::lowest()",
)


@registry.reg("cuda.reduce_max.func_call")
def reduce_max_gen_function_call(func_attrs, indent=" "):
"""the registered function for generating a function call to reduce_mean

Parameters
----------
func_attrs : Dict[str, Any]
holds attributes of this reduce_max op
indent : str, optional
indentation for each line of the rendered code (default " ")

Returns
-------
str
returns rendered code for invoking the reduce op
"""
if _is_last_reduction_dim(func_attrs):
return reduce_3d.gen_function_call(func_attrs, indent)
else:
return reduce_common.gen_function_call(func_attrs, indent)
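
Provided the PR also exposes reduce_max (and reduce_min) in aitemplate.compiler.ops alongside reduce_mean/reduce_sum, the new backend can be exercised end to end with the standard AITemplate flow. A hedged sketch, where the op exposure in ops and the exact output layout are assumptions and the compile/run entry points are the usual ones:

import torch
from aitemplate.compiler import compile_model, ops
from aitemplate.frontend import Tensor
from aitemplate.testing import detect_target

X = Tensor(shape=[2, 3, 4], dtype="float16", name="X", is_input=True)
Y = ops.reduce_max(dim=2)(X)  # last-dim reduction, so the reduce_3d path is taken
Y._attrs["name"] = "Y"
Y._attrs["is_output"] = True

module = compile_model(Y, detect_target(), "./tmp", "reduce_max_demo")
x = torch.randn(2, 3, 4).half().cuda()
y = torch.empty(2, 3).half().cuda()
module.run_with_tensors({"X": x}, {"Y": y})
assert torch.allclose(y, torch.amax(x, dim=2), atol=1e-2, rtol=1e-2)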