Protect against out-of-bounds references when folding OpCompositeExtract (#2774)
This fixes #2608.
The original test case had an out-of-bounds reference that ended up
folding into OpCompositeExtract that was indexing right outside the
constant composite.
The returned constant would then cause a segfault during constant
propagation.
diff --git a/source/opt/const_folding_rules.cpp b/source/opt/const_folding_rules.cpp
index 10fcde4..23a7998 100644
--- a/source/opt/const_folding_rules.cpp
+++ b/source/opt/const_folding_rules.cpp
@@ -55,6 +55,9 @@
auto cc = c->AsCompositeConstant();
assert(cc != nullptr);
auto components = cc->GetComponents();
+ // Protect against invalid IR. Refuse to fold if the index is out
+ // of bounds.
+ if (element_index >= components.size()) return nullptr;
c = components[element_index];
}
return c;
diff --git a/test/opt/fold_test.cpp b/test/opt/fold_test.cpp
index 3ea3204..a9f3089 100644
--- a/test/opt/fold_test.cpp
+++ b/test/opt/fold_test.cpp
@@ -11,6 +11,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include "source/opt/fold.h"
+
#include <limits>
#include <memory>
#include <string>
@@ -22,7 +24,6 @@
#include "gtest/gtest.h"
#include "source/opt/build_module.h"
#include "source/opt/def_use_manager.h"
-#include "source/opt/fold.h"
#include "source/opt/ir_context.h"
#include "source/opt/module.h"
#include "spirv-tools/libspirv.hpp"
@@ -2980,7 +2981,17 @@
"%4 = OpCompositeExtract %int %3 0\n" +
"OpReturn\n" +
"OpFunctionEnd",
- 4, INT_7_ID)
+ 4, INT_7_ID),
+ // Test case 13: https://github.com/KhronosGroup/SPIRV-Tools/issues/2608
+ // Out of bounds access. Do not fold.
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1\n" +
+ "%3 = OpCompositeExtract %float %2 4\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 3, 0)
));
INSTANTIATE_TEST_SUITE_P(CompositeConstructFoldingTest, GeneralInstructionFoldingTest,