Skip to content

Commit

Permalink
add fix to torch and torchrec for aot export in tzrec
Browse files Browse the repository at this point in the history
  • Loading branch information
杨熙 committed Dec 31, 2024
1 parent 421568f commit 265b845
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 0 deletions.
1 change: 1 addition & 0 deletions tzrec/acc/_decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim
from torch_tensorrt.dynamo.utils import to_torch_device


from ._decomposition_groups import (
ENABLED_TORCH_DECOMPOSITIONS,
TORCH_TRT_DECOMPOSITIONS,
Expand Down
44 changes: 44 additions & 0 deletions tzrec/acc/aot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,61 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import os
from typing import Dict

import torch
import torch._prims_common as prims_utils
import torch.nn.functional as F
from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
BoundsCheckMode,
)
from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
IntNBitTableBatchedEmbeddingBagsCodegen,
)
from torch import nn
from torch._decomp import decomposition_table, register_decomposition
from torch._prims_common.wrappers import out_wrapper
from torch.export import Dim

from tzrec.utils.fx_util import symbolic_trace
from tzrec.utils.logging_util import logger

# skip default bound check which is not allow by aot
if "ENABLE_AOT" in os.environ:
IntNBitTableBatchedEmbeddingBagsCodegen.__init__ = functools.partialmethod(
IntNBitTableBatchedEmbeddingBagsCodegen.__init__,
bounds_check_mode=BoundsCheckMode.NONE,
)


# add new aten._softmax decomposition which is supported by dynamo
aten = torch._ops.ops.aten
if aten._softmax.default in decomposition_table:
del decomposition_table[aten._softmax.default]
del decomposition_table[aten._softmax.out]


@register_decomposition(aten._softmax)
@out_wrapper()
def _softmax(x: torch.Tensor, dim: int, half_to_float: bool):
# eager softmax returns a contiguous tensor. Ensure that decomp also returns
# a contiguous tensor.
x = x.contiguous()
if half_to_float:
assert x.dtype == torch.half
computation_dtype, result_dtype = prims_utils.elementwise_dtypes(
x, type_promotion_kind=prims_utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
)
x = x.to(computation_dtype)
x_max = torch.max(x, dim, keepdim=True).values
unnormalized = torch.exp(x - x_max)
result = unnormalized / torch.sum(unnormalized, dim, keepdim=True)
if not half_to_float:
result = result.to(result_dtype)
return result


def export_model_aot(
model: nn.Module, data: Dict[str, torch.Tensor], save_dir: str
Expand Down

0 comments on commit 265b845

Please sign in to comment.