Skip to content

Commit

Permalink
add cuda_graph for mts_gpu_benchmark (#1012)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1012

Add cuda_graph enablement for AIT. Also leave the hook for AOTI

GPU trace after enabling cudagraph: https://fburl.com/perfdoctor/yybf0z60

log information for verification
I0710 17:58:44.425707 2013974 AITModelImpl.cpp:148] AITModelImpl: loading .so lib /tmp/benchmark_529602.1720659401/_run_on_acc_0/_run_on_acc_0-ait_engine.so
I0710 17:58:44.425731 2013974 AITModelImpl.cpp:149] AITModelImpl: num_runtimes: 1,use_cuda_graph: 1

Reviewed By: chenyang78, guowentian, sijiac

Differential Revision: D59617284
  • Loading branch information
frank-wei authored and facebook-github-bot committed Jul 12, 2024
1 parent a1769f2 commit bf6c8e3
Show file tree
Hide file tree
Showing 8 changed files with 18 additions and 2 deletions.
1 change: 1 addition & 0 deletions fx2ait/fx2ait/ait_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ def _lower_model_to_backend(
torch.float16,
torch.float,
1, # num_runtimes
False,
),
interpreter_result,
)
Expand Down
3 changes: 2 additions & 1 deletion fx2ait/fx2ait/csrc/AITModel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ static auto registerAITModel =
std::vector<std::string>,
std::optional<at::ScalarType>,
std::optional<at::ScalarType>,
int64_t>())
int64_t,
bool>())
.def("forward", &AITModel::forward)
.def("profile", &AITModel::profile)
.def("get_library_path", &AITModel::libraryPath)
Expand Down
4 changes: 3 additions & 1 deletion fx2ait/fx2ait/csrc/AITModelImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,9 @@ AITModelImpl::AITModelImpl(
floating_point_input_dtype_(input_dtype),
floating_point_output_dtype_(output_dtype),
use_cuda_graph_(use_cuda_graph) {
LOG(INFO) << "Loading .so lib " << model_path;
LOG(INFO) << "AITModelImpl: loading .so lib " << model_path;
LOG(INFO) << "AITModelImpl: num_runtimes: " << num_runtimes
<< ",use_cuda_graph: " << use_cuda_graph;
TORCH_CHECK(handle_, "could not dlopen ", model_path, ": ", dlerror());
TORCH_CHECK(num_runtimes > 0, "num_runtimes must be positive");

Expand Down
1 change: 1 addition & 0 deletions fx2ait/fx2ait/lower/lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def lower_pass(
_precision_to_torch_type(lower_settings.precision),
_precision_to_torch_type(lower_settings.output_precision),
1, # num_runtimes
False,
),
interp_res,
lower_settings.trace_ait_module,
Expand Down
2 changes: 2 additions & 0 deletions fx2ait/fx2ait/test/test_fx2ait.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def _test_fx2ait_impl(self, test_serialization=False, test_cuda_graph=False):
torch.float16,
torch.float16,
1, # num_runtimes
False,
)
)
ait_mod.engine.use_cuda_graph = test_cuda_graph
Expand Down Expand Up @@ -140,6 +141,7 @@ def forward(self, a, b, c, d):
torch.float16,
torch.float16,
1, # num_runtimes
False,
),
interp_result,
)
Expand Down
1 change: 1 addition & 0 deletions fx2ait/fx2ait/tools/ait_minimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def lower_mod_default(
torch.float16,
torch.float16,
1, # num_runtimes
False,
),
interpreter_result,
)
Expand Down
3 changes: 3 additions & 0 deletions fx2ait/fx2ait/tools/common_aten2ait.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ def run_test(
torch.float16,
torch.float,
1, # num_runtimes
False,
),
interp_result,
)
Expand Down Expand Up @@ -256,6 +257,7 @@ def run_test_with_dynamic_shape(
torch.float16,
torch.float,
1, # num_runtimes
False,
),
interp_result,
)
Expand Down Expand Up @@ -375,6 +377,7 @@ def benchmark(f, args):
torch.float16,
torch.float,
1, # num_runtimes
False,
),
interp_result,
)
Expand Down
5 changes: 5 additions & 0 deletions fx2ait/fx2ait/tools/common_fx2ait.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ def run_test(
torch_dtype,
torch.float,
1, # num_runtimes
False,
),
interp_result,
)
Expand All @@ -199,6 +200,7 @@ def run_test(
torch_dtype,
torch.float,
1, # num_runtimes
False,
),
interp_result,
)
Expand Down Expand Up @@ -317,6 +319,7 @@ def run_test_with_dynamic_shape(
torch.float16,
torch.float,
1, # num_runtimes
False,
),
interp_result,
)
Expand All @@ -329,6 +332,7 @@ def run_test_with_dynamic_shape(
torch.float16,
torch.float,
1, # num_runtimes
False,
),
interp_result,
)
Expand Down Expand Up @@ -467,6 +471,7 @@ def benchmark(f, args):
torch.float16,
torch.float,
1, # num_runtimes
False,
),
interp_result,
)
Expand Down

0 comments on commit bf6c8e3

Please sign in to comment.