Skip to content

Commit

Permalink
Allow Passing in Size of Dynamic Dimensions to Inference Function (#1025
Browse files Browse the repository at this point in the history
)

Summary:
Pull Request resolved: #1025

Add a new param `dynamic_size` to lower settings, that allows explicitly setting the number to look for that corresponds to the dynamic dimension in inputs.

Reviewed By: frank-wei

Differential Revision: D62448015

fbshipit-source-id: bc3be0d891631c41f74a3386b826d4630b4bda69
  • Loading branch information
oniononion36 authored and facebook-github-bot committed Sep 13, 2024
1 parent bfb1dc2 commit 437b48a
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 43 deletions.
88 changes: 46 additions & 42 deletions fx2ait/fx2ait/find_batch_size_dim.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,55 +21,59 @@ def find_batch_size_dim(
inputs: Any,
can_non_first_dim_be_dynamic: bool = True,
can_dim_value_one_be_dynamic: bool = True,
dynamic_size: int = -1,
# pyre-fixme Invalid type [31]
) -> []:
if isinstance(inputs, torch.Tensor) or len(inputs) <= 1:
return [0]
shapes = [i.shape for i in inputs]
frequency_map = {}
position_scores = {}
first_dims = set()
for shape in shapes:
if len(shape) < 2:
# By pass for rank-1 tensors. MRS model has rank-1 tensor carry no batch_size info
continue
# Dedup shape value for single tensor
first_dims.add(shape[0])
seen_dims = set()
valid_len = len(shape) if can_non_first_dim_be_dynamic else 1
for i in range(valid_len):
dim = shape[i]
if dim not in seen_dims:
frequency_map[dim] = frequency_map.get(dim, 0) + 1
position_scores[dim] = position_scores.get(dim, 0) + i
seen_dims.add(dim)

if len(first_dims) == 1:
# first dim is the same in every input: we use it as batch_size
batch_size = first_dims.pop()
elif frequency_map:
# first dims are different: we use the most frequent dim as batch_size
# if there is more than 1 most frequent dim, we choose the one with the
# lowest position score (i.e., the leftmost of the most frequent ones)
sorted_frequency = sorted(
frequency_map.items(),
key=lambda x: (-x[1], position_scores[x[0]]),
)
if len(sorted_frequency) > 1:
if not can_dim_value_one_be_dynamic and sorted_frequency[0][0] == 1:
# It's often that dim value one indicates a non-dynamic dimension.
# If the user says so, we pick the second most frequent value.
batch_size = sorted_frequency[1][0]
if dynamic_size > 0:
batch_size = dynamic_size
else:
shapes = [i.shape for i in inputs]
frequency_map = {}
position_scores = {}
first_dims = set()
for shape in shapes:
if len(shape) < 2:
# By pass for rank-1 tensors. MRS model has rank-1 tensor carry no batch_size info
continue
# Dedup shape value for single tensor
first_dims.add(shape[0])
seen_dims = set()
valid_len = len(shape) if can_non_first_dim_be_dynamic else 1
for i in range(valid_len):
dim = shape[i]
if dim not in seen_dims:
frequency_map[dim] = frequency_map.get(dim, 0) + 1
position_scores[dim] = position_scores.get(dim, 0) + i
seen_dims.add(dim)
if len(first_dims) == 1:
# first dim is the same in every input: we use it as batch_size
batch_size = first_dims.pop()
elif frequency_map:
# first dims are different: we use the most frequent dim as batch_size
# if there is more than 1 most frequent dim, we choose the one with the
# lowest position score (i.e., the leftmost of the most frequent ones)
sorted_frequency = sorted(
frequency_map.items(),
key=lambda x: (-x[1], position_scores[x[0]]),
)
if len(sorted_frequency) > 1:
if not can_dim_value_one_be_dynamic and sorted_frequency[0][0] == 1:
# It's often that dim value one indicates a non-dynamic dimension.
# If the user says so, we pick the second most frequent value.
batch_size = sorted_frequency[1][0]
else:
batch_size = sorted_frequency[0][0]
else:
batch_size = sorted_frequency[0][0]
if not can_dim_value_one_be_dynamic and sorted_frequency[0][0] == 1:
batch_size = -1
else:
batch_size = sorted_frequency[0][0]
else:
if not can_dim_value_one_be_dynamic and sorted_frequency[0][0] == 1:
batch_size = -1
else:
batch_size = sorted_frequency[0][0]
else:
# no dims to sort: no batch_size
batch_size = -1
# no dims to sort: no batch_size
batch_size = -1

bs_dim = []
for i in inputs:
Expand Down
1 change: 1 addition & 0 deletions fx2ait/fx2ait/lower/lower_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class LowerSettings:
name: str = ""
dll_name: str = "ait_engine.so"
dynamic_profile_strategy: DynamicProfileStrategy = DynamicProfileStrategy.MAX
dynamic_size: int = -1
profile_devs: Any = None
# If None, infer the dtypes from the sample inputs.
precision: Optional[LowerPrecision] = LowerPrecision.FP16
Expand Down
6 changes: 5 additions & 1 deletion fx2ait/fx2ait/tensor_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,10 +477,14 @@ def find_batch_size_dim(
inputs: Any,
can_non_first_dim_be_dynamic: bool = True,
can_dim_value_one_be_dynamic: bool = True,
dynamic_size: int = -1,
# pyre-fixme Invalid type [31]
) -> []:
return find_batch_size_dim_impl(
inputs, can_non_first_dim_be_dynamic, can_dim_value_one_be_dynamic
inputs,
can_non_first_dim_be_dynamic,
can_dim_value_one_be_dynamic,
dynamic_size,
)

@classmethod
Expand Down

0 comments on commit 437b48a

Please sign in to comment.