Skip to content

Commit

Permalink
Fix circular import in aitemplate (#951)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #951

Fixes import cycle: P852820857

Reviewed By: cxxxs, chenyang78

Differential Revision: D50283535

fbshipit-source-id: e9e7fbc054b5312d1c9a073a0554d0c3a3788cc0
  • Loading branch information
Kronuz authored and facebook-github-bot committed Oct 20, 2023
1 parent b383d29 commit 3c4ba48
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion python/aitemplate/compiler/transform/fuse_expand_bmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from typing import List

from aitemplate.compiler.base import Operator, Tensor
from aitemplate.compiler.ops.tensor.expand import ExpandDimensionType
from aitemplate.compiler.tensor_accessor import TensorAccessor
from aitemplate.compiler.transform.toposort import toposort
from aitemplate.compiler.transform.transform_utils import (
Expand All @@ -39,6 +38,10 @@ def _can_fuse(expand_op: Operator, bmm_op: Operator) -> bool:
"""
determine if expand_op and bmm_op can be fused
"""
from aitemplate.compiler.ops.tensor.expand import ( # inner import to break circular import
ExpandDimensionType,
)

expand_output = expand_op._attrs["outputs"][0]
if expand_output._attrs["is_output"]:
return False
Expand Down

0 comments on commit 3c4ba48

Please sign in to comment.