diff --git a/fx2ait/fx2ait/acc_tracer/acc_ops.py b/fx2ait/fx2ait/acc_tracer/acc_ops.py
index b202777de..041ebf9f9 100644
--- a/fx2ait/fx2ait/acc_tracer/acc_ops.py
+++ b/fx2ait/fx2ait/acc_tracer/acc_ops.py
@@ -1556,6 +1556,22 @@ def max_dim_reduce(*, input, dim=None, keepdim=False):
     return torch.max(input=input, dim=dim, keepdim=keepdim)
 
 
+@register_acc_op_properties(AccOpProperty.unary)
+@register_acc_op_mapping(op_and_target=("call_function", torch.amax))
+@register_acc_op_mapping(op_and_target=("call_method", "amax"))
+@register_acc_op
+def amax(*, input, dim, keepdim=False):
+    return torch.amax(input=input, dim=dim, keepdim=keepdim)
+
+
+@register_acc_op_properties(AccOpProperty.unary)
+@register_acc_op_mapping(op_and_target=("call_function", torch.amin))
+@register_acc_op_mapping(op_and_target=("call_method", "amin"))
+@register_acc_op
+def amin(*, input, dim, keepdim=False):
+    return torch.amin(input=input, dim=dim, keepdim=keepdim)
+
+
 @register_acc_op_properties(AccOpProperty.pointwise)
 @register_acc_op_mapping(op_and_target=("call_function", torch.maximum))
 @register_acc_op_mapping(op_and_target=("call_method", "maximum"))
diff --git a/fx2ait/fx2ait/converters/ait_converters.py b/fx2ait/fx2ait/converters/ait_converters.py
index 3965341b0..1839adcbf 100644
--- a/fx2ait/fx2ait/converters/ait_converters.py
+++ b/fx2ait/fx2ait/converters/ait_converters.py
@@ -51,7 +51,9 @@
     ndhwc3to8,
     pad_last_dim,
     permute,
+    reduce_max,
     reduce_mean,
+    reduce_min,
     reduce_sum,
     reshape,
     size,
@@ -241,6 +243,26 @@ def acc_ops_mean(
     return create_reduce_op(reduce_mean, args, kwargs, name)
 
 
+@ait_converter(acc_ops.amax)
+def acc_ops_amax(
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> ConverterOutput:
+    return create_reduce_op(reduce_max, args, kwargs, name)
+
+
+@ait_converter(acc_ops.amin)
+def acc_ops_amin(
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> ConverterOutput:
+    return create_reduce_op(reduce_min, args, kwargs, name)
+
+
 @ait_converter(acc_ops.linear)
 def acc_ops_linear(
     target: Target,
diff --git a/fx2ait/fx2ait/converters/utils.py b/fx2ait/fx2ait/converters/utils.py
index c94f62102..d64cb084a 100644
--- a/fx2ait/fx2ait/converters/utils.py
+++ b/fx2ait/fx2ait/converters/utils.py
@@ -49,17 +49,19 @@ def create_reduce_op(
     dims = kwargs.get("dim", None)
     if dims is None:
         dims = list(range(len(input_val.shape())))
+    if isinstance(dims, int):
+        dims = [dims]
     if len(dims) < 1:
         raise ValueError("No dims to reduce on")
     dim = dims[0]
 
     keepdim = False if "keepdim" not in kwargs else kwargs["keepdim"]
-    sum_val = op_type(dim=dim, keepdim=keepdim)(input_val)
+    reduced_val = op_type(dim=dim, keepdim=keepdim)(input_val)
     if len(dims) > 1:
-        new_kwargs = {"input": sum_val, "dims": dims[1:]}
+        new_kwargs = {"input": reduced_val, "dims": dims[1:]}
         return create_reduce_op(op_type, args, new_kwargs, name)
 
-    return sum_val
+    return reduced_val
 
 
 def create_binary_op(
diff --git a/fx2ait/fx2ait/test/converters/test_ait_reduce.py b/fx2ait/fx2ait/test/converters/test_ait_reduce.py
index 009b7cfa4..d5ca1c328 100644
--- a/fx2ait/fx2ait/test/converters/test_ait_reduce.py
+++ b/fx2ait/fx2ait/test/converters/test_ait_reduce.py
@@ -113,3 +113,69 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             [torch.randn(2, 3, 4).half().cuda()],
             expected_ops={},
         )
+
+
+class TestAminConverter(AITTestCase):
+    @parameterized.expand(
+        [
+            ["default", (1), False],
+            ["keepdim", (1), True],
+            ["negative_dim", (-1), False],
+        ]
+    )
+    def test_amin(self, test_name, dim, keepdim):
+        class TestModule(torch.nn.Module):
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                return torch.amin(x, dim=dim, keepdim=keepdim)
+
+        model = TestModule().cuda()
+        inputs = [torch.randn(1, 2, 3).half().cuda()]
+        self.run_test(model, inputs, expected_ops={acc_ops.amin})
+
+    @parameterized.expand(
+        [
+            ["default", None, False],
+            ["specified_dims", (0, 1, 2), False],
+        ]
+    )
+    def test_amin_multi_dims(self, test_name, dim, keepdim):
+        class TestModule(torch.nn.Module):
+            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+                return y + torch.amin(x, dim=dim, keepdim=keepdim)
+
+        model = TestModule().cuda()
+        inputs = [torch.randn(2, 3, 5).half().cuda()] * 2
+        self.run_test(model, inputs, expected_ops={acc_ops.add, acc_ops.amin})
+
+
+class TestAmaxConverter(AITTestCase):
+    @parameterized.expand(
+        [
+            ["default", (1), False],
+            ["keepdim", (1), True],
+            ["negative_dim", (-1), False],
+        ]
+    )
+    def test_amax(self, test_name, dim, keepdim):
+        class TestModule(torch.nn.Module):
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                return torch.amax(x, dim=dim, keepdim=keepdim)
+
+        model = TestModule().cuda()
+        inputs = [torch.randn(1, 2, 3).half().cuda()]
+        self.run_test(model, inputs, expected_ops={acc_ops.amax})
+
+    @parameterized.expand(
+        [
+            ["default", None, False],
+            ["specified_dims", (0, 1, 2), False],
+        ]
+    )
+    def test_amax_multi_dims(self, test_name, dim, keepdim):
+        class TestModule(torch.nn.Module):
+            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+                return y + torch.amax(x, dim=dim, keepdim=keepdim)
+
+        model = TestModule().cuda()
+        inputs = [torch.randn(2, 3, 5).half().cuda()] * 2
+        self.run_test(model, inputs, expected_ops={acc_ops.add, acc_ops.amax})
diff --git a/python/aitemplate/backend/cuda/reduce/__init__.py b/python/aitemplate/backend/cuda/reduce/__init__.py
index feb5cde4c..9aa2a5bf2 100644
--- a/python/aitemplate/backend/cuda/reduce/__init__.py
+++ b/python/aitemplate/backend/cuda/reduce/__init__.py
@@ -18,7 +18,9 @@
 from aitemplate.backend.cuda.reduce import (
     reduce_3d,
     reduce_common,
+    reduce_max,
     reduce_mean,
+    reduce_min,
     reduce_sum,
     var,
     vector_norm,
@@ -27,7 +29,9 @@
 __all__ = [
     "reduce_3d",
     "reduce_common",
+    "reduce_max",
     "reduce_mean",
+    "reduce_min",
     "reduce_sum",
     "var",
     "vector_norm",
diff --git a/python/aitemplate/backend/cuda/reduce/reduce_3d.py b/python/aitemplate/backend/cuda/reduce/reduce_3d.py
index a259d3974..04a81416a 100644
--- a/python/aitemplate/backend/cuda/reduce/reduce_3d.py
+++ b/python/aitemplate/backend/cuda/reduce/reduce_3d.py
@@ -323,7 +323,17 @@
     ReduceScalarOp reduce_s_op;
 
     FragmentCompute frag_compute;
+
+{% if reduction_identity == 'ElementCompute()' %}
+    // initialize the frag_compute with default values
     frag_compute.clear();
+{% else %}
+    // need to initialize the frag_compute with the specific
+    // reduction_identity values, as those are likely non-default
+    for (int i = 0; i < kAlignment; ++i) {
+      frag_compute[i] = {{reduction_identity}};
+    }
+{% endif %}
 
     if (idx_m < args.extent.row()) {
 
@@ -414,7 +424,7 @@
     // Tree reduction
     ElementCompute *smem_ptr = shared_storage.exchange.data() + threadIdx.y * Shape::kColumn;
 
-    ElementCompute result = ElementCompute();
+    ElementCompute result = {{reduction_identity}};
 
     CUTLASS_PRAGMA_UNROLL
     for (
@@ -465,7 +475,7 @@
     // Certain shape combinations require an additional reduction step
     if (kLgResidual) {
 
-      result = ElementCompute();
+      result = {{reduction_identity}};
 
      int const kResidualVector = (1 << kLgResidual);
      cutlass::Array<ElementCompute, kResidualVector> fetch;
@@ -800,6 +810,7 @@ def gen_function(
     epilogue_scalar_template=DEFAULT_EPILOGUE_SCALAR_TEMPLATE,
     extra_code_str="",
     accumulation_type=None,
+    reduction_identity="ElementCompute()",
 ) -> str:
     """a common function for generating a reduce-family kernel
 
@@ -927,6 +938,7 @@ def gen_function(
         acc_type,
         output_accessors,
         output_alignment,
+        reduction_identity,
     )
     exec_paths = EXEC_COND_TEMPLATE.render(
         indent=" ",
@@ -950,6 +962,7 @@ def gen_function(
     kernel_src = KERNEL_SRC_TEMPLATE.render(
         extra_code=extra_code_str,
         reduce_op=reduce_op,
+        reduction_identity=reduction_identity,
         reduce_kernel_instance=reduce_instance,
         alignments=alignments,
         prologue_code=prologue_code,
diff --git a/python/aitemplate/backend/cuda/reduce/reduce_common.py b/python/aitemplate/backend/cuda/reduce/reduce_common.py
index 43e40c7e7..161f93427 100644
--- a/python/aitemplate/backend/cuda/reduce/reduce_common.py
+++ b/python/aitemplate/backend/cuda/reduce/reduce_common.py
@@ -127,7 +127,7 @@
       src_dims[0], src_dims[1], src_dims[2], src_dims[3]
   );
   Layout src_layout(Layout::packed(src_extent));
-  ElementCompute reduction_identity = ElementCompute();
+  ElementCompute reduction_identity = {{reduction_identity}};
   TensorReduction reduction(src_extent, reduction_axis);
   ReductionOp reduction_op = ReductionOp();
   assert(dst_ptr);
@@ -189,7 +189,11 @@ def gen_function_decl(func_attrs):
     )
 
 
-def gen_function(func_attrs, reduction_op):
+def gen_function(
+    func_attrs,
+    reduction_op,
+    reduction_identity="ElementCompute()",
+):
     backend_spec = CUDASpec()
     elem_input_type = backend_spec.dtype_to_lib_type(
         func_attrs["inputs"][0]._attrs["dtype"]
@@ -221,6 +225,7 @@ def gen_function(func_attrs, reduction_op):
     return SRC_TEMPLATE.render(
         func_name=func_attrs["name"],
         reduction_op=reduction_op,
+        reduction_identity=reduction_identity,
         exec_paths=exec_paths,
         workspace_ptr=workspace_ptr,
         accumulation_type=accumulation_type,
diff --git a/python/aitemplate/backend/cuda/reduce/reduce_max.py b/python/aitemplate/backend/cuda/reduce/reduce_max.py
new file mode 100644
index 000000000..1b7961f5a
--- /dev/null
+++ b/python/aitemplate/backend/cuda/reduce/reduce_max.py
@@ -0,0 +1,113 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""
+reduce_max kernel:
+  (1) invokes the reduce_3d kernel for reduction_dim = -1 (last-dim) cases; and
+  (2) invokes reduce_common for all other cases
+
+We do this because there is a huge perf difference between reduce_3d and
+reduce_common for different reduction dims. We should consider unifying
+the two implementations later. Ideally, we should fix the perf issue in
+reduce_3d for non-last-dim cases, because reduce_3d can take a prologue and
+an epilogue, so it is more general than reduce_common.
+"""
+
+from aitemplate.backend import registry
+from aitemplate.backend.cuda.reduce import reduce_3d, reduce_common
+
+
+def _is_last_reduction_dim(func_attrs):
+    """return True if the reduction dim is the last (i.e. innermost) dim"""
+    axes = func_attrs["reduction_axes"]
+    if not len(axes) == 1:
+        raise NotImplementedError("Multiple reduction axes are not supported yet")
+    reduction_dim = axes[0]
+    # make sure our frontend has already handled negative dims
+    assert reduction_dim >= 0, "cannot have negative dim here: {}".format(reduction_dim)
+    x = func_attrs["inputs"][0]
+    rank = x._rank()
+    assert rank >= 1, "rank must be >= 1, got: {}".format(rank)
+    return reduction_dim == rank - 1
+
+
+@registry.reg("cuda.reduce_max.func_decl")
+def reduce_max_gen_function_decl(func_attrs):
+    """the registered function for generating the reduce_max function declaration
+
+    Parameters
+    ----------
+    func_attrs : Dict[str, Any]
+        holds attributes of this reduce_max op
+
+    Returns
+    -------
+    str
+        returns the rendered function declaration with appropriate replacements
+    """
+    if _is_last_reduction_dim(func_attrs):
+        return reduce_3d.gen_function_decl(func_attrs)
+    else:
+        return reduce_common.gen_function_decl(func_attrs)
+
+
+@registry.reg("cuda.reduce_max.gen_function")
+def reduce_max_gen_function(func_attrs):
+    """the registered function for generating the reduce_max kernel and all of
+    its auxiliary functions
+
+    Parameters
+    ----------
+    func_attrs : Dict[str, Any]
+        holds attributes of this reduce_max op
+
+    Returns
+    -------
+    str
+        returns the rendered code for the complete implementation of this reduce_max op
+    """
+    if _is_last_reduction_dim(func_attrs):
+        return reduce_3d.gen_function(
+            func_attrs,
+            reduce_op="cutlass::maximum",
+            reduction_identity="std::numeric_limits<ElementCompute>::lowest()",
+        )
+    else:
+        return reduce_common.gen_function(
+            func_attrs,
+            reduction_op="cutlass::maximum",
+            reduction_identity="std::numeric_limits<ElementCompute>::lowest()",
+        )
+
+
+@registry.reg("cuda.reduce_max.func_call")
+def reduce_max_gen_function_call(func_attrs, indent="  "):
+    """the registered function for generating a function call to reduce_max
+
+    Parameters
+    ----------
+    func_attrs : Dict[str, Any]
+        holds attributes of this reduce_max op
+    indent : str, optional
+        indentation for each line of the rendered code (default "  ")
+
+    Returns
+    -------
+    str
+        returns rendered code for invoking the reduce op
+    """
+    if _is_last_reduction_dim(func_attrs):
+        return reduce_3d.gen_function_call(func_attrs, indent)
+    else:
+        return reduce_common.gen_function_call(func_attrs, indent)
diff --git a/python/aitemplate/backend/cuda/reduce/reduce_min.py b/python/aitemplate/backend/cuda/reduce/reduce_min.py
new file mode 100644
index 000000000..933cf0e5e
--- /dev/null
+++ b/python/aitemplate/backend/cuda/reduce/reduce_min.py
@@ -0,0 +1,113 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""
+reduce_min kernel:
+  (1) invokes the reduce_3d kernel for reduction_dim = -1 (last-dim) cases; and
+  (2) invokes reduce_common for all other cases
+
+We do this because there is a huge perf difference between reduce_3d and
+reduce_common for different reduction dims. We should consider unifying
+the two implementations later. Ideally, we should fix the perf issue in
+reduce_3d for non-last-dim cases, because reduce_3d can take a prologue and
+an epilogue, so it is more general than reduce_common.
+"""
+
+from aitemplate.backend import registry
+from aitemplate.backend.cuda.reduce import reduce_3d, reduce_common
+
+
+def _is_last_reduction_dim(func_attrs):
+    """return True if the reduction dim is the last (i.e. innermost) dim"""
+    axes = func_attrs["reduction_axes"]
+    if not len(axes) == 1:
+        raise NotImplementedError("Multiple reduction axes are not supported yet")
+    reduction_dim = axes[0]
+    # make sure our frontend has already handled negative dims
+    assert reduction_dim >= 0, "cannot have negative dim here: {}".format(reduction_dim)
+    x = func_attrs["inputs"][0]
+    rank = x._rank()
+    assert rank >= 1, "rank must be >= 1, got: {}".format(rank)
+    return reduction_dim == rank - 1
+
+
+@registry.reg("cuda.reduce_min.func_decl")
+def reduce_min_gen_function_decl(func_attrs):
+    """the registered function for generating the reduce_min function declaration
+
+    Parameters
+    ----------
+    func_attrs : Dict[str, Any]
+        holds attributes of this reduce_min op
+
+    Returns
+    -------
+    str
+        returns the rendered function declaration with appropriate replacements
+    """
+    if _is_last_reduction_dim(func_attrs):
+        return reduce_3d.gen_function_decl(func_attrs)
+    else:
+        return reduce_common.gen_function_decl(func_attrs)
+
+
+@registry.reg("cuda.reduce_min.gen_function")
+def reduce_min_gen_function(func_attrs):
+    """the registered function for generating the reduce_min kernel and all of
+    its auxiliary functions
+
+    Parameters
+    ----------
+    func_attrs : Dict[str, Any]
+        holds attributes of this reduce_min op
+
+    Returns
+    -------
+    str
+        returns the rendered code for the complete implementation of this reduce_min op
+    """
+    if _is_last_reduction_dim(func_attrs):
+        return reduce_3d.gen_function(
+            func_attrs,
+            reduce_op="cutlass::minimum",
+            reduction_identity="std::numeric_limits<ElementCompute>::max()",
+        )
+    else:
+        return reduce_common.gen_function(
+            func_attrs,
+            reduction_op="cutlass::minimum",
+            reduction_identity="std::numeric_limits<ElementCompute>::max()",
+        )
+
+
+@registry.reg("cuda.reduce_min.func_call")
+def reduce_min_gen_function_call(func_attrs, indent="  "):
+    """the registered function for generating a function call to reduce_min
+
+    Parameters
+    ----------
+    func_attrs : Dict[str, Any]
+        holds attributes of this reduce_min op
+    indent : str, optional
+        indentation for each line of the rendered code (default "  ")
+
+    Returns
+    -------
+    str
+        returns rendered code for invoking the reduce op
+    """
+    if _is_last_reduction_dim(func_attrs):
+        return reduce_3d.gen_function_call(func_attrs, indent)
+    else:
+        return reduce_common.gen_function_call(func_attrs, indent)
diff --git a/python/aitemplate/backend/cuda/reduce/reduce_small_axis.py b/python/aitemplate/backend/cuda/reduce/reduce_small_axis.py
index 2db0f5524..e21a91557 100644
--- a/python/aitemplate/backend/cuda/reduce/reduce_small_axis.py
+++ b/python/aitemplate/backend/cuda/reduce/reduce_small_axis.py
@@ -80,7 +80,8 @@
     const ElementInput *input,
     int64_t num_rows,
     int64_t batch_stride_input,
-    int64_t batch_stride_output) {
+    int64_t batch_stride_output,
+    ElementCompute reduction_identity) {
   int block_batch = blockIdx.y;
   // index within the batch
   const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
@@ -133,7 +134,7 @@
   CUTLASS_PRAGMA_UNROLL
   for (int64_t i = 0; i < num_elems_per_thread / num_cols; i++) {
     static_assert(num_elems_per_thread % num_rows_per_thread == 0);
-    FragmentCompute frag_compute = FragmentCompute(0);
+    FragmentCompute frag_compute = FragmentCompute(reduction_identity);
     CUTLASS_PRAGMA_UNROLL
     for (int64_t j = 0; j < num_cols; j++) {
       int64_t read_idx = i * num_cols + j;
@@ -166,13 +167,13 @@
 {% endif %}
 }
 
-template <typename ElemOutputType, typename ElemInputType>
+template <typename ElementOutput, typename ElementInput>
 void reduce_mean_launcher_small_axis(
-    ElemOutputType *output,
-    ElemInputType *input,
+    ElementOutput *output,
+    ElementInput *input,
     int64_t num_batches,
     int64_t num_rows,
     int64_t batch_stride_input,
@@ -180,16 +181,16 @@
     cudaStream_t stream
 ) {
   constexpr int64_t num_read_v =
-      sizeof({{read_vec_type}}) / sizeof(ElemInputType);
+      sizeof({{read_vec_type}}) / sizeof(ElementInput);
   constexpr int64_t row_gcd = std::gcd(num_cols, num_read_v);
   constexpr int64_t num_rows_per_thread = num_read_v / row_gcd;
 {% if output_accessor.is_contiguous %}
   constexpr int64_t num_write_bytes_v =
-      num_rows_per_thread * sizeof(ElemOutputType);
+      num_rows_per_thread * sizeof(ElementOutput);
 {% else %}
   constexpr int64_t num_write_bytes_v =
       std::min(num_rows_per_thread,
               static_cast<int64_t>({{output_access_alignment}})) *
-      sizeof(ElemOutputType);
+      sizeof(ElementOutput);
 {% endif %}
 
   assert(num_rows % num_rows_per_thread == 0);
@@ -201,9 +202,9 @@
 #define HANDLE_ONE_WRITE_VEC(write_bytes, write_vec_type) \\
   if (write_bytes == num_write_bytes_v) { \\
-    reduce_small_in_v_out_v<{{read_vec_type}}, write_vec_type, ElemOutputType, ElemInputType> \\
+    reduce_small_in_v_out_v<{{read_vec_type}}, write_vec_type, ElementOutput, ElementInput> \\
-  if constexpr (std::is_same_v<ElemOutputType, cutlass::half_t>) {
+  if constexpr (std::is_same_v<ElementOutput, cutlass::half_t>) {
     HANDLE_ONE_WRITE_VEC(2, cutlass::half_t)
   }
-  else if constexpr (std::is_same_v<ElemOutputType, cutlass::bfloat16_t>) {
+  else if constexpr (std::is_same_v<ElementOutput, cutlass::bfloat16_t>) {
     HANDLE_ONE_WRITE_VEC(2, cutlass::bfloat16_t)
   }
   throw std::runtime_error("unsupported vector size for write");
@@ -232,10 +234,10 @@
   }
 }
 
-template <typename ElemOutputType, typename ElemInputType>
+template <typename ElementOutput, typename ElementInput>
 void reduce_mean_launcher_small_axis_column_major(
-    ElemOutputType *output,
-    ElemInputType *input,
+    ElementOutput *output,
+    ElementInput *input,
     int64_t num_batches,
     int64_t num_rows,
     int64_t num_columns,
@@ -379,6 +381,7 @@ def get_exec_cond_and_kernel(
     acc_type,
     output_accessors,
     output_alignment,
+    reduction_identity,
 ) -> str:
     """return a pair that contains the execution condition for this special
     reduction kernel and the source code of this reduction kernel
@@ -450,6 +453,7 @@ def get_exec_cond_and_kernel(
     )
     kernel_src = KERNEL_SRC_TEMPLATE.render(
         reduce_op=reduce_op,
+        reduction_identity=reduction_identity,
         prologue_code=prologue_code,
         epilogue_scalar_code=epilogue_scalar_code,
         read_vec_type=read_vec_type,
diff --git a/python/aitemplate/compiler/ops/reduce/__init__.py b/python/aitemplate/compiler/ops/reduce/__init__.py
index 1fdff06c8..335c329c2 100644
--- a/python/aitemplate/compiler/ops/reduce/__init__.py
+++ b/python/aitemplate/compiler/ops/reduce/__init__.py
@@ -15,10 +15,19 @@
 """
 Reduce module init.
""" +from aitemplate.compiler.ops.reduce.reduce_max import reduce_max from aitemplate.compiler.ops.reduce.reduce_mean import reduce_mean +from aitemplate.compiler.ops.reduce.reduce_min import reduce_min from aitemplate.compiler.ops.reduce.reduce_sum import reduce_sum from aitemplate.compiler.ops.reduce.var import var from aitemplate.compiler.ops.reduce.vector_norm import vector_norm -__all__ = ["reduce_mean", "reduce_sum", "var", "vector_norm"] +__all__ = [ + "reduce_max", + "reduce_mean", + "reduce_min", + "reduce_sum", + "var", + "vector_norm", +] diff --git a/python/aitemplate/compiler/ops/reduce/reduce_common.py b/python/aitemplate/compiler/ops/reduce/reduce_common.py index 1dafa717f..c7054b2f5 100644 --- a/python/aitemplate/compiler/ops/reduce/reduce_common.py +++ b/python/aitemplate/compiler/ops/reduce/reduce_common.py @@ -253,7 +253,7 @@ def __call__(self, x: Tensor) -> Tensor: # Note that this is a temprary solution only for col-reduction reduce_sum # kernels that invoke cutlass's TensorReduction kernel. Once we have our # own implementation, we will remove the workaround. - if self._attrs["op"] == "reduce_sum" and ( + if self._attrs["op"] in ("reduce_sum", "reduce_min", "reduce_max") and ( self._attrs["reduction_axes"][0] != input_rank - 1 ): ws_size = self._compute_workspace_size( diff --git a/python/aitemplate/compiler/ops/reduce/reduce_max.py b/python/aitemplate/compiler/ops/reduce/reduce_max.py new file mode 100644 index 000000000..3a58ae3e0 --- /dev/null +++ b/python/aitemplate/compiler/ops/reduce/reduce_max.py @@ -0,0 +1,46 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +reduce_max op +""" +from aitemplate.compiler.ops.reduce.reduce_common import reduce_base + +# pylint: disable=C0103 + + +class reduce_max(reduce_base): + """ + Implements the reduce_max op. + + * .attr.:`dim` : int or tuple of python:ints + the dimension or dimensions to reduce + + * .attr.:`keepdim` : bool, optional + keep the reduced dimensions if True, default is False + + * .attr.:`dtype` : str, optional + the type of the return tensor. If it is not None, + the input tensor is cast to dtype before reduction. + + Args: + input (Tensor): the input tensor. + + Return: + Tensor that contains the max of all elements in the input tensor. + """ + + def __init__(self, dim, keepdim=False, dtype=None) -> None: + super().__init__(dim, keepdim, dtype) + self._attrs["op"] = "reduce_max" diff --git a/python/aitemplate/compiler/ops/reduce/reduce_min.py b/python/aitemplate/compiler/ops/reduce/reduce_min.py new file mode 100644 index 000000000..f1ef6c2e6 --- /dev/null +++ b/python/aitemplate/compiler/ops/reduce/reduce_min.py @@ -0,0 +1,46 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""
+reduce_min op
+"""
+from aitemplate.compiler.ops.reduce.reduce_common import reduce_base
+
+# pylint: disable=C0103
+
+
+class reduce_min(reduce_base):
+    """
+    Implements the reduce_min op.
+
+    * :attr:`dim` : int or tuple of python:ints
+        the dimension or dimensions to reduce
+
+    * :attr:`keepdim` : bool, optional
+        keep the reduced dimensions if True, default is False
+
+    * :attr:`dtype` : str, optional
+        the type of the return tensor. If it is not None,
+        the input tensor is cast to dtype before reduction.
+
+    Args:
+        input (Tensor): the input tensor.
+
+    Return:
+        Tensor that contains the min of the input tensor's elements
+        over the given dimension(s).
+    """
+
+    def __init__(self, dim, keepdim=False, dtype=None) -> None:
+        super().__init__(dim, keepdim, dtype)
+        self._attrs["op"] = "reduce_min"
diff --git a/python/aitemplate/compiler/public/__init__.py b/python/aitemplate/compiler/public/__init__.py
index f8d4e4e9e..9d9f9bc33 100644
--- a/python/aitemplate/compiler/public/__init__.py
+++ b/python/aitemplate/compiler/public/__init__.py
@@ -44,7 +44,9 @@
 from aitemplate.compiler.ops.gemm_universal.gemm_rrr import gemm_rrr
 
 """Reduce"""
+from aitemplate.compiler.ops.reduce.reduce_max import reduce_max
 from aitemplate.compiler.ops.reduce.reduce_mean import reduce_mean
+from aitemplate.compiler.ops.reduce.reduce_min import reduce_min
 from aitemplate.compiler.ops.reduce.reduce_sum import reduce_sum
 from aitemplate.compiler.ops.reduce.var import var
 from aitemplate.compiler.ops.reduce.vector_norm import vector_norm
diff --git a/tests/unittest/ops/test_reduce.py b/tests/unittest/ops/test_reduce.py
index 38fe1522e..9e1f226e4 100644
--- a/tests/unittest/ops/test_reduce.py
+++ b/tests/unittest/ops/test_reduce.py
@@ -74,11 +74,13 @@ def _run_reduce(
         dll_name = f"test_{self.test_count}.so"
         module = compile_model(Y, target, "./tmp", test_name, dll_name=dll_name)
         X_pt = get_random_torch_tensor(input_shape, input_type)
-        dtype_pt = string_to_torch_dtype(output_type)
-        if keepdim is None:
-            Y_pt = torch_reduce_op(X_pt, dim, dtype=dtype_pt)
-        else:
-            Y_pt = torch_reduce_op(X_pt, dim, keepdim=keepdim, dtype=dtype_pt)
+        pt_args = [X_pt, dim]
+        pt_kwargs = {}
+        if keepdim is not None:
+            pt_kwargs["keepdim"] = keepdim
+        if torch_reduce_op not in (torch.amin, torch.amax):
+            pt_kwargs["dtype"] = string_to_torch_dtype(output_type)
+        Y_pt = torch_reduce_op(*pt_args, **pt_kwargs)
         y = torch.empty_like(Y_pt)
         module.run_with_tensors([X_pt], [y])
@@ -116,7 +118,11 @@ def _run_reduce_sum(
 
     def test_reduce_sum(self):
         self._run_reduce_sum(
-            dim=0, input_shape=[1], keepdim=True, input_type="float16", output_type=None
+            dim=0,
+            input_shape=[1],
+            keepdim=True,
+            input_type="float16",
+            output_type=None,
         )
         self._run_reduce_sum(
             dim=1,
@@ -160,7 +166,6 @@ def test_reduce_sum(self):
             input_type="float16",
             output_type="float16",
         )
-
         self._run_reduce_sum(
             dim=0,
             input_shape=[4],
@@ -408,6 +413,268 @@ def test_reduce_mean(self):
             use_fp16_acc=True,
         )
 
+    def _run_reduce_max(
+        self,
+        *,
+        dim,
+        input_shape,
+        keepdim,
+        input_type="float16",
+        output_type=None,
+        use_fp16_acc=False,
+    ):
+        self._run_reduce(
test_name=f"reduce_max_{input_type}_{output_type}", + reduce_op=ops.reduce_max, + torch_reduce_op=torch.amax, + dim=dim, + input_shape=input_shape, + keepdim=keepdim, + input_type=input_type, + output_type=output_type, + use_fp16_acc=use_fp16_acc, + ) + + def test_reduce_max(self): + self._run_reduce_max( + dim=0, + input_shape=[1], + keepdim=True, + input_type="float16", + output_type=None, + ) + self._run_reduce_max( + dim=1, + input_shape=[1, 4], + keepdim=True, + input_type="float16", + output_type=None, + ) + self._run_reduce_max( + dim=0, + input_shape=[1, 4], + keepdim=True, + input_type="float16", + output_type=None, + ) + self._run_reduce_max( + dim=0, + input_shape=[2, 4], + keepdim=True, + input_type="float16", + output_type=None, + ) + self._run_reduce_max( + dim=0, + input_shape=[1, 2, 1], + keepdim=True, + input_type="float16", + output_type=None, + ) + self._run_reduce_max( + dim=1, + input_shape=[1, 2, 1], + keepdim=True, + input_type="float16", + output_type=None, + ) + self._run_reduce_max( + dim=2, + input_shape=[5, 4, 3], + keepdim=True, + input_type="float16", + output_type="float16", + ) + self._run_reduce_max( + dim=0, + input_shape=[4], + keepdim=False, + input_type="float16", + output_type=None, + ) + self._run_reduce_max( + dim=0, + input_shape=[1, 4], + keepdim=False, + input_type="float16", + output_type=None, + ) + self._run_reduce_max( + dim=0, + input_shape=[5, 4, 3], + keepdim=False, + input_type="float16", + output_type="float16", + ) + self._run_reduce_max( + dim=1, + input_shape=[5, 4, 3], + keepdim=False, + input_type="float16", + output_type="float16", + ) + self._run_reduce_max( + dim=1, + input_shape=[5, 4, 3], + keepdim=False, + input_type="float16", + output_type="float16", + use_fp16_acc=True, + ) + self._run_reduce_max( + dim=2, + input_shape=[5, 4, 3], + keepdim=False, + input_type="float16", + output_type="float16", + ) + self._run_reduce_max( + dim=-1, + input_shape=[1, 1000000, 6], + keepdim=True, + input_type="float16", + output_type=None, + ) + # allocate workspace for the strided tensor_reduce kernel + self._run_reduce_max( + dim=2, + input_shape=[1, 1, 8, 128], + keepdim=True, + input_type="float16", + output_type=None, + ) + + def _run_reduce_min( + self, + *, + dim, + input_shape, + keepdim, + input_type="float16", + output_type=None, + use_fp16_acc=False, + ): + self._run_reduce( + test_name=f"reduce_min_{input_type}_{output_type}", + reduce_op=ops.reduce_min, + torch_reduce_op=torch.amin, + dim=dim, + input_shape=input_shape, + keepdim=keepdim, + input_type=input_type, + output_type=output_type, + use_fp16_acc=use_fp16_acc, + ) + + def test_reduce_min(self): + self._run_reduce_min( + dim=0, + input_shape=[1], + keepdim=True, + input_type="float16", + output_type=None, + ) + self._run_reduce_min( + dim=1, + input_shape=[1, 4], + keepdim=True, + input_type="float16", + output_type=None, + ) + self._run_reduce_min( + dim=0, + input_shape=[1, 4], + keepdim=True, + input_type="float16", + output_type=None, + ) + self._run_reduce_min( + dim=0, + input_shape=[2, 4], + keepdim=True, + input_type="float16", + output_type=None, + ) + self._run_reduce_min( + dim=0, + input_shape=[1, 2, 1], + keepdim=True, + input_type="float16", + output_type=None, + ) + self._run_reduce_min( + dim=1, + input_shape=[1, 2, 1], + keepdim=True, + input_type="float16", + output_type=None, + ) + self._run_reduce_min( + dim=2, + input_shape=[5, 4, 3], + keepdim=True, + input_type="float16", + output_type="float16", + ) + self._run_reduce_min( + dim=0, + 
+            input_shape=[4],
+            keepdim=False,
+            input_type="float16",
+            output_type=None,
+        )
+        self._run_reduce_min(
+            dim=0,
+            input_shape=[1, 4],
+            keepdim=False,
+            input_type="float16",
+            output_type=None,
+        )
+        self._run_reduce_min(
+            dim=0,
+            input_shape=[5, 4, 3],
+            keepdim=False,
+            input_type="float16",
+            output_type="float16",
+        )
+        self._run_reduce_min(
+            dim=1,
+            input_shape=[5, 4, 3],
+            keepdim=False,
+            input_type="float16",
+            output_type="float16",
+        )
+        self._run_reduce_min(
+            dim=1,
+            input_shape=[5, 4, 3],
+            keepdim=False,
+            input_type="float16",
+            output_type="float16",
+            use_fp16_acc=True,
+        )
+        self._run_reduce_min(
+            dim=2,
+            input_shape=[5, 4, 3],
+            keepdim=False,
+            input_type="float16",
+            output_type="float16",
+        )
+        self._run_reduce_min(
+            dim=-1,
+            input_shape=[1, 1000000, 6],
+            keepdim=True,
+            input_type="float16",
+            output_type=None,
+        )
+        # allocate workspace for the strided tensor_reduce kernel
+        self._run_reduce_min(
+            dim=2,
+            input_shape=[1, 1, 8, 128],
+            keepdim=True,
+            input_type="float16",
+            output_type=None,
+        )
+
     def _run_batched_reduce(
         self,
         *,
@@ -566,6 +833,33 @@ def test_reduce_sum_float32(self):
             atol=1e-5,
         )
 
+    @unittest.skipIf(detect_target().name() == "rocm", "fp32 not supported in ROCm")
+    def test_reduce_max_float32(self):
+        # reduce_smallaxis
+        self._run_reduce_max(
+            dim=1,
+            input_shape=[1, 4],
+            keepdim=True,
+            input_type="float32",
+            output_type=None,
+        )
+        # reduce_3d
+        self._run_reduce_max(
+            dim=-2,
+            input_shape=[3, 2048, 4],
+            keepdim=False,
+            input_type="float32",
+            output_type=None,
+        )
+        # reduce (common) 2d
+        self._run_reduce_max(
+            dim=-1,
+            input_shape=[1270, 1223],
+            keepdim=False,
+            input_type="float32",
+            output_type=None,
+        )
+
     @unittest.skipIf(detect_target().name() == "rocm", "fp32 not supported in ROCm")
     def test_reduce_sum_bfloat16(self):
         # reduce_smallaxis
@@ -599,6 +893,33 @@ def test_reduce_sum_bfloat16(self):
             atol=1e-0,
         )
 
+    @unittest.skipIf(detect_target().name() == "rocm", "bfloat16 not supported in ROCm")
+    def test_reduce_max_bfloat16(self):
+        # reduce_smallaxis
+        self._run_reduce_max(
+            dim=1,
+            input_shape=[1, 4],
+            keepdim=True,
+            input_type="bfloat16",
+            output_type=None,
+        )
+        # reduce_3d
+        self._run_reduce_max(
+            dim=-2,
+            input_shape=[3, 2048, 4],
+            keepdim=False,
+            input_type="bfloat16",
+            output_type=None,
+        )
+        # reduce (common) 2d
+        self._run_reduce_max(
+            dim=-1,
+            input_shape=[1270, 1223],
+            keepdim=False,
+            input_type="bfloat16",
+            output_type=None,
+        )
+
 
 if __name__ == "__main__":
     unittest.main()
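
---

Usage sketch (not part of the patch): a minimal end-to-end check of the new ops, modeled on the unit tests in this diff. `compile_model`, `Tensor`, `ops`, and `detect_target` are the same AITemplate entry points the test harness above uses; the workdir `./tmp` and the test name are illustrative placeholders.

```python
import torch

from aitemplate.compiler import compile_model, ops
from aitemplate.frontend import Tensor
from aitemplate.testing import detect_target

# Build a small AIT graph that max-reduces over the last dim.
X = Tensor(shape=[5, 4, 3], dtype="float16", name="X", is_input=True)
Y = ops.reduce_max(dim=-1, keepdim=False)(X)
Y._attrs["name"] = "Y"
Y._attrs["is_output"] = True

# Codegen and compile for the detected GPU target.
module = compile_model(Y, detect_target(), "./tmp", "reduce_max_example")

# Compare against the PyTorch reference (torch.amax, as the tests do).
x = torch.randn(5, 4, 3).half().cuda()
y_ref = torch.amax(x, dim=-1)
y = torch.empty_like(y_ref)
module.run_with_tensors([x], [y])
torch.testing.assert_close(y, y_ref, rtol=1e-3, atol=1e-3)
```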
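Why the `reduction_identity` knob threaded through reduce_3d, reduce_common, and reduce_small_axis matters: the old templates zero-initialized accumulators (`frag_compute.clear()`, `ElementCompute()`), which is only correct for sum/mean. A plain-Python sketch of the failure mode the identity fixes:

```python
# Conceptual sketch only; the generated CUDA initializes its accumulators
# with {{reduction_identity}}, e.g. std::numeric_limits<ElementCompute>::lowest()
# for reduce_max and ::max() for reduce_min.
def reduce_with_identity(values, op, identity):
    acc = identity  # analogous to initializing frag_compute / result
    for v in values:
        acc = op(acc, v)
    return acc

vals = [-3.0, -7.0, -1.5]  # an all-negative input exposes the bug
assert reduce_with_identity(vals, max, float("-inf")) == -1.5  # correct identity
assert reduce_with_identity(vals, max, 0.0) == 0.0  # zero identity: wrong answer
```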