Add reduce_max and reduce_min (facebookincubator#942)
Summary: My understanding is that the max / min reductions are usually not used in trained PT models due to their non-differentiability. That is probably why they were not reflected as AIT ops. However, they may be useful in inference for a manually composed AIT model (or a PT model translated by fx2ait), which was solicited in issue [facebookincubator#941](facebookincubator#941). In this diff, we add the new `reduce_max` and `reduce_min` ops to AIT. The new ops use the existing common back-end implementation shared with `reduce_sum`, with a small extension: the default (initial) accumulator value of 0 is not suitable for min / max reductions (see the illustrative sketch below), so it is now made configurable in the codegen; the `reduce_max` and `reduce_min` ops' codegen sets it to the `lowest()` and `max()` numeric limits of the ElementCompute type, respectively. fx2ait wiring of the new ops to `torch.amax` and `torch.amin` has also been added.

Differential Revision: D49978036
1 parent 2bf9919 · commit 793bc04
Showing 16 changed files with 817 additions and 35 deletions.
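To make the accumulator-identity point from the summary concrete, here is a minimal, stand-alone Python illustration (not part of the commit): folding a max over an all-negative input with the additive identity 0 silently clamps the result at 0, while starting from the lowest representable value, as the new codegen does, yields the true maximum.

    from functools import reduce

    data = [-3.0, -7.5, -1.2]  # all-negative input

    # Using the additive identity 0 as the initial accumulator (what reduce_sum
    # uses) breaks a max reduction: 0.0 is returned even though it is not in the input.
    wrong = reduce(max, data, 0.0)  # -> 0.0

    # Starting from the lowest representable value, the analogue of
    # std::numeric_limits<ElementCompute>::lowest() in the reduce_max codegen,
    # gives the correct answer.
    right = reduce(max, data, float("-inf"))  # -> -1.2

    # Symmetrically, reduce_min needs the largest representable value,
    # std::numeric_limits<ElementCompute>::max(), as its identity.
    right_min = reduce(min, data, float("inf"))  # -> -7.5

This is exactly why the back-end now threads a reduction_identity string through the codegen instead of hard-coding 0.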
@@ -0,0 +1,113 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
""" | ||
reduce_max kernel: | ||
(1) it invokes reduce_3d kernel for reduction_dim = -1 cases; and | ||
(2) invokes reduce_common for all other cases | ||
We do this because there is huge perf difference between reduce_3d and | ||
reduce_common for different reduction dims. We should consider to unify | ||
our implementation later. Ideally, we should fix the perf issue in | ||
reduce_3d for non-neg-dim cases, because reduce_3d can take prologue and | ||
epilogue so it is more general than reduce_common. | ||
""" | ||

from aitemplate.backend import registry
from aitemplate.backend.cuda.reduce import reduce_3d, reduce_common


def _is_last_reduction_dim(func_attrs):
    """return True if the reduction dim is the last dim (i.e. the innermost dim)"""
    axes = func_attrs["reduction_axes"]
    if not len(axes) == 1:
        raise NotImplementedError("Multiple reduction axes are not supported yet")
    reduction_dim = axes[0]
    # make sure our frontend handles negative dims
    assert reduction_dim >= 0, "cannot have negative dim here: {}".format(reduction_dim)
    x = func_attrs["inputs"][0]
    rank = x._rank()
    assert rank >= 1, "rank must be >= 1, got: {}".format(rank)
    return reduction_dim == rank - 1


@registry.reg("cuda.reduce_max.func_decl")
def reduce_max_gen_function_decl(func_attrs):
    """the registered function for generating the reduce_max function declaration

    Parameters
    ----------
    func_attrs : Dict[str, Any]
        holds attributes of this reduce_max op

    Returns
    -------
    str
        returns the rendered function declaration with appropriate replacements
    """
    if _is_last_reduction_dim(func_attrs):
        return reduce_3d.gen_function_decl(func_attrs)
    else:
        return reduce_common.gen_function_decl(func_attrs)


@registry.reg("cuda.reduce_max.gen_function")
def reduce_max_gen_function(func_attrs):
    """the registered function for generating the reduce_max kernel and all of
    its auxiliary functions

    Parameters
    ----------
    func_attrs : Dict[str, Any]
        holds attributes of this reduce_max op

    Returns
    -------
    str
        returns the rendered code for the complete implementation of this reduce_max op
    """
    if _is_last_reduction_dim(func_attrs):
        return reduce_3d.gen_function(
            func_attrs,
            reduce_op="cutlass::maximum",
            reduction_identity="std::numeric_limits<ElementCompute>::lowest()",
        )
    else:
        return reduce_common.gen_function(
            func_attrs,
            reduction_op="cutlass::maximum",
            reduction_identity="std::numeric_limits<ElementCompute>::lowest()",
        )


@registry.reg("cuda.reduce_max.func_call")
def reduce_max_gen_function_call(func_attrs, indent=" "):
    """the registered function for generating a function call to reduce_max

    Parameters
    ----------
    func_attrs : Dict[str, Any]
        holds attributes of this reduce_max op
    indent : str, optional
        indentation for each line of the rendered code (default " ")

    Returns
    -------
    str
        returns rendered code for invoking the reduce op
    """
    if _is_last_reduction_dim(func_attrs):
        return reduce_3d.gen_function_call(func_attrs, indent)
    else:
        return reduce_common.gen_function_call(func_attrs, indent)
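The matching reduce_min back-end file added by this commit is not shown in this diff. Judging from the summary, it presumably mirrors the reduce_max file, swapping the CUTLASS reduction op and the identity value; a hypothetical sketch of its gen_function hook, under that assumption:

    from aitemplate.backend import registry
    from aitemplate.backend.cuda.reduce import reduce_3d, reduce_common


    @registry.reg("cuda.reduce_min.gen_function")
    def reduce_min_gen_function(func_attrs):
        # Assumed to dispatch exactly like reduce_max above, reusing a helper
        # identical to _is_last_reduction_dim; only the reduce op and the
        # reduction identity differ (largest representable value for min).
        if _is_last_reduction_dim(func_attrs):
            return reduce_3d.gen_function(
                func_attrs,
                reduce_op="cutlass::minimum",
                reduction_identity="std::numeric_limits<ElementCompute>::max()",
            )
        return reduce_common.gen_function(
            func_attrs,
            reduction_op="cutlass::minimum",
            reduction_identity="std::numeric_limits<ElementCompute>::max()",
        )

The func_decl and func_call registrations would presumably follow the same pattern, registered under cuda.reduce_min instead of cuda.reduce_max.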