summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlan Baker <alanbaker@google.com>2018-07-20 09:27:52 -0400
committerDavid Neto <dneto@google.com>2018-07-20 11:32:43 -0400
commitb49f76fd62f5840d848f8891c599ef91e6fa57bb (patch)
tree9421301a7050930c0795a077a1d6fdd8162ed132
parenteffafedcee7310c50bf5e93c9d4bebc3b1bee49a (diff)
Handle undef literal value in vector shuffle
Fixes #1731 * Updated folding rules related to vector shuffle to account for the undef literal value: * FoldVectorShuffleFeedingShuffle * FoldVectorShuffleFeedingExtract * FoldVectorShuffleWithConstants * These rules would commit memory violations due to treating the undef literal value as an accessible composite component
-rw-r--r--source/opt/const_folding_rules.cpp6
-rw-r--r--source/opt/folding_rules.cpp13
-rw-r--r--test/opt/fold_test.cpp63
3 files changed, 76 insertions, 6 deletions
diff --git a/source/opt/const_folding_rules.cpp b/source/opt/const_folding_rules.cpp
index d1b902e2..6b94986e 100644
--- a/source/opt/const_folding_rules.cpp
+++ b/source/opt/const_folding_rules.cpp
@@ -95,9 +95,13 @@ ConstantFoldingRule FoldVectorShuffleWithConstants() {
}
std::vector<uint32_t> ids;
+ const uint32_t undef_literal_value = 0xffffffff;
for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
uint32_t index = inst->GetSingleWordInOperand(i);
- if (index < c1_components.size()) {
+ if (index == undef_literal_value) {
+ // Don't fold shuffle with undef literal value.
+ return nullptr;
+ } else if (index < c1_components.size()) {
Instruction* member_inst =
const_mgr->GetDefiningInstruction(c1_components[index]);
ids.push_back(member_inst->result_id());
diff --git a/source/opt/folding_rules.cpp b/source/opt/folding_rules.cpp
index 5f683a8a..c3daa6fc 100644
--- a/source/opt/folding_rules.cpp
+++ b/source/opt/folding_rules.cpp
@@ -1502,6 +1502,14 @@ FoldingRule VectorShuffleFeedingExtract() {
uint32_t new_index =
cinst->GetSingleWordInOperand(2 + inst->GetSingleWordInOperand(1));
+ // Extracting an undefined value so fold this extract into an undef.
+ const uint32_t undef_literal_value = 0xffffffff;
+ if (new_index == undef_literal_value) {
+ inst->SetOpcode(SpvOpUndef);
+ inst->SetInOperands({});
+ return true;
+ }
+
// Get the id of the of the vector the elemtent comes from, and update the
// index if needed.
uint32_t new_vector = 0;
@@ -2035,10 +2043,13 @@ FoldingRule VectorShuffleFeedingShuffle() {
std::vector<Operand> new_operands;
new_operands.resize(
2, {SPV_OPERAND_TYPE_ID, {0}}); // Place holders for vector operands.
+ const uint32_t undef_literal = 0xffffffff;
for (uint32_t op = 2; op < inst->NumInOperands(); ++op) {
uint32_t component_index = inst->GetSingleWordInOperand(op);
- if (feeder_is_op0 == (component_index < op0_length)) {
+ // Do not interpret the undefined value literal as coming from operand 1.
+ if (component_index != undef_literal &&
+ feeder_is_op0 == (component_index < op0_length)) {
// This component comes from the feeding_shuffle_inst. Update
// |component_index| to be the index into the operand of the feeder.
diff --git a/test/opt/fold_test.cpp b/test/opt/fold_test.cpp
index aa4d61c5..9a870a6e 100644
--- a/test/opt/fold_test.cpp
+++ b/test/opt/fold_test.cpp
@@ -452,11 +452,12 @@ TEST_P(IntVectorInstructionFoldingTest, Case) {
// Fold the instruction to test.
analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold);
+ SpvOp original_opcode = inst->opcode();
bool succeeded = context->get_instruction_folder().FoldInstruction(inst);
// Make sure the instruction folded as expected.
- EXPECT_TRUE(succeeded);
- if (inst != nullptr) {
+ EXPECT_EQ(succeeded, inst == nullptr || inst->opcode() != original_opcode);
+ if (succeeded && inst != nullptr) {
EXPECT_EQ(inst->opcode(), SpvOpCopyObject);
inst = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
std::vector<SpvOp> opcodes = {SpvOpConstantComposite};
@@ -496,7 +497,25 @@ INSTANTIATE_TEST_CASE_P(TestCase, IntVectorInstructionFoldingTest,
"%2 = OpVectorShuffle %v2int %v2int_null %v2int_2_3 0 3\n" +
"OpReturn\n" +
"OpFunctionEnd",
- 2, {0,3})
+ 2, {0,3}),
+ InstructionFoldingCase<std::vector<uint32_t>>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_int Function\n" +
+ "%load = OpLoad %int %n\n" +
+ "%2 = OpVectorShuffle %v2int %v2int_null %v2int_2_3 4294967295 3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, {0,0}),
+ InstructionFoldingCase<std::vector<uint32_t>>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_int Function\n" +
+ "%load = OpLoad %int %n\n" +
+ "%2 = OpVectorShuffle %v2int %v2int_null %v2int_2_3 0 4294967295 \n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, {0,0})
));
// clang-format on
@@ -5512,7 +5531,22 @@ INSTANTIATE_TEST_CASE_P(CompositeExtractMatchingTest, MatchingInstructionFolding
"%5 = OpCompositeExtract %double %4 3\n" +
"OpReturn\n" +
"OpFunctionEnd",
- 5, false)
+ 5, false),
+ // Test case 7: Extracting the undefined literal value from a vector
+ // shuffle.
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+ "; CHECK: %4 = OpUndef [[int]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_v4int Function\n" +
+ "%2 = OpLoad %v4int %n\n" +
+ "%3 = OpVectorShuffle %v2int %2 %2 2 4294967295\n" +
+ "%4 = OpCompositeExtract %int %3 1\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 4, true)
));
INSTANTIATE_TEST_CASE_P(DotProductMatchingTest, MatchingInstructionFoldingTest,
@@ -5898,6 +5932,27 @@ INSTANTIATE_TEST_CASE_P(VectorShuffleMatchingTest, MatchingInstructionWithNoResu
"%9 = OpVectorShuffle %v4double %7 %8 2 0 1 3\n" +
"OpReturn\n" +
"OpFunctionEnd",
+ 9, true),
+ // Test case 13: Shuffle with undef literal.
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[double:%\\w+]] = OpTypeFloat 64\n" +
+ "; CHECK: [[v4double:%\\w+]] = OpTypeVector [[double]] 2\n" +
+ "; CHECK: OpVectorShuffle\n" +
+ "; CHECK: OpVectorShuffle {{%\\w+}} %7 {{%\\w+}} 2 0 1 4294967295\n" +
+ "; CHECK: OpReturn\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpVariable %_ptr_v4double Function\n" +
+ "%3 = OpVariable %_ptr_v4double Function\n" +
+ "%4 = OpVariable %_ptr_v4double Function\n" +
+ "%5 = OpLoad %v4double %2\n" +
+ "%6 = OpLoad %v4double %3\n" +
+ "%7 = OpLoad %v4double %4\n" +
+ "%8 = OpVectorShuffle %v2double %5 %5 0 1\n" +
+ "%9 = OpVectorShuffle %v4double %7 %8 2 0 1 4294967295\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
9, true)
));
#endif