From 67c0e39e0d6b2b393ddcbbbf8e194a15db1c5d7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=A8=E7=86=99?= Date: Thu, 2 Jan 2025 10:43:59 +0800 Subject: [PATCH] fix code style bug --- .pre-commit-config.yaml | 2 +- docs/source/conf.py | 2 +- scripts/pyre_check.py | 2 +- setup.py | 2 +- tzrec/acc/aot_utils.py | 6 ++++-- tzrec/models/model.py | 1 + 6 files changed, 9 insertions(+), 6 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2926b4b..7056902 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,7 +4,7 @@ repos: hooks: - id: insert-license files: \.py$ - args: ["--license-filepath", "data/.license_header.txt", "--use-current-year"] + args: ["--license-filepath", "data/.license_header.txt", "--allow-past-years"] - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.7.1 hooks: diff --git a/docs/source/conf.py b/docs/source/conf.py index 79c4a65..e4ceec4 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, Alibaba Group; +# Copyright (c) 2024-2025, Alibaba Group; # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/scripts/pyre_check.py b/scripts/pyre_check.py index 741ecc3..631ec2e 100644 --- a/scripts/pyre_check.py +++ b/scripts/pyre_check.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, Alibaba Group; +# Copyright (c) 2024-2025, Alibaba Group; # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/setup.py b/setup.py index d749cff..2d2d1df 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, Alibaba Group; +# Copyright (c) 2024-2025, Alibaba Group; # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tzrec/acc/aot_utils.py b/tzrec/acc/aot_utils.py index bd3aad4..5ebe4a8 100644 --- a/tzrec/acc/aot_utils.py +++ b/tzrec/acc/aot_utils.py @@ -32,6 +32,7 @@ # skip default bound check which is not allow by aot if "ENABLE_AOT" in os.environ: + # pyre-ignore [8] IntNBitTableBatchedEmbeddingBagsCodegen.__init__ = functools.partialmethod( IntNBitTableBatchedEmbeddingBagsCodegen.__init__, bounds_check_mode=BoundsCheckMode.NONE, @@ -45,9 +46,10 @@ del decomposition_table[aten._softmax.out] +# pyre-ignore [56] @register_decomposition(aten._softmax) @out_wrapper() -def _softmax(x: torch.Tensor, dim: int, half_to_float: bool): +def _softmax(x: torch.Tensor, dim: int, half_to_float: bool) -> torch.Tensor: # eager softmax returns a contiguous tensor. Ensure that decomp also returns # a contiguous tensor. x = x.contiguous() @@ -67,7 +69,7 @@ def _softmax(x: torch.Tensor, dim: int, half_to_float: bool): def export_model_aot( model: nn.Module, data: Dict[str, torch.Tensor], save_dir: str -) -> None: +) -> torch.export.ExportedProgram: """Export aot model. Args: diff --git a/tzrec/models/model.py b/tzrec/models/model.py index 562c728..f7292a5 100644 --- a/tzrec/models/model.py +++ b/tzrec/models/model.py @@ -241,6 +241,7 @@ def forward( class ScriptWrapperAOT(ScriptWrapper): """Model inference wrapper for aot export.""" + # pyre-ignore [14] def forward( self, data: Dict[str, torch.Tensor],