Skip to content

Commit

Permalink
Merge pull request #373 from PietroGhg/pietro/reqd_work_group_size
Browse files Browse the repository at this point in the history
Support reqd_work_group_size with less than 3 operands
  • Loading branch information
PietroGhg authored Feb 22, 2024
2 parents 9b15dd8 + f916970 commit f728550
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 8 deletions.
31 changes: 31 additions & 0 deletions modules/compiler/test/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,34 @@ TEST_F(CompilerUtilsTest, ReplaceFunctionInMetadata) {
ASSERT_EQ(Kernels.size(), 1);
ASSERT_EQ(Kernels[0].Name, Bar->getName());
}

TEST_F(CompilerUtilsTest, ParseReqWGSize) {
auto M = parseModule(R"(
define void @foo() !reqd_work_group_size !1 {
entry:
ret void
}
define void @bar() !reqd_work_group_size !2 {
entry:
ret void
}
!1 = !{i32 30}
!2 = !{i32 30, i32 20}
)");

auto check = [&M](llvm::StringRef Name, const std::array<uint64_t, 3> &Res) {
auto *const F = M->getFunction(Name);
if (auto Parsed = parseRequiredWGSMetadata(*F)) {
for (auto [El, Exp] : llvm::zip(*Parsed, Res)) {
EXPECT_EQ(El, Exp);
}
} else {
GTEST_FAIL();
}
};

check("foo", {30, 1, 1});
check("bar", {30, 20, 1});
}
14 changes: 6 additions & 8 deletions modules/compiler/utils/source/metadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -293,14 +293,12 @@ std::optional<unsigned> isSchedulingParameter(const Function &f, unsigned idx) {
std::optional<std::array<uint64_t, 3>> parseRequiredWGSMetadata(
const Function &f) {
if (auto mdnode = f.getMetadata(ReqdWGSizeMD)) {
auto *const op0 = mdconst::extract<ConstantInt>(mdnode->getOperand(0));
auto *const op1 = mdconst::extract<ConstantInt>(mdnode->getOperand(1));
auto *const op2 = mdconst::extract<ConstantInt>(mdnode->getOperand(2));

// KLOCWORK "UNINIT.STACK.ARRAY.MUST" possible false positive
// This is normal std::array initialization
std::array<uint64_t, 3> wgs = {
{op0->getZExtValue(), op1->getZExtValue(), op2->getZExtValue()}};
std::array<uint64_t, 3> wgs = {0, 1, 1};
assert(mdnode->getNumOperands() >= 1 && mdnode->getNumOperands() <= 3 &&
"Unsupported number of operands in reqd_work_group_size");
for (const auto &[idx, op] : enumerate(mdnode->operands())) {
wgs[idx] = mdconst::extract<ConstantInt>(op)->getZExtValue();
}
return wgs;
}
return std::nullopt;
Expand Down

0 comments on commit f728550

Please sign in to comment.