Fix up bit shifts by 32. (#2292)
In C++, a bit shift of the same size as the type is undefined, but it is
defined in spir-v. When folding those cases, we have to be careful. We
cannot simply do the shift in C++.
Fixes https://crbug.com/917697.
diff --git a/source/opt/fold.cpp b/source/opt/fold.cpp
index d6b583f..0604da2 100644
--- a/source/opt/fold.cpp
+++ b/source/opt/fold.cpp
@@ -115,8 +115,10 @@
// Shifting
case SpvOp::SpvOpShiftRightLogical:
- if (b > 32) {
- // This is undefined behaviour. Choose 0 for consistency.
+ if (b >= 32) {
+ // This is undefined behaviour when |b| > 32. Choose 0 for consistency.
+ // When |b| == 32, doing the shift in C++ in undefined, but the result
+ // will be 0, so just return that value.
return 0;
}
return a >> b;
@@ -125,10 +127,21 @@
// This is undefined behaviour. Choose 0 for consistency.
return 0;
}
+ if (b == 32) {
+ // Doing the shift in C++ is undefined, but the result is defined in the
+ // spir-v spec. Find that value another way.
+ if (static_cast<int32_t>(a) >= 0) {
+ return 0;
+ } else {
+ return static_cast<uint32_t>(-1);
+ }
+ }
return (static_cast<int32_t>(a)) >> b;
case SpvOp::SpvOpShiftLeftLogical:
- if (b > 32) {
- // This is undefined behaviour. Choose 0 for consistency.
+ if (b >= 32) {
+ // This is undefined behaviour when |b| > 32. Choose 0 for consistency.
+ // When |b| == 32, doing the shift in C++ in undefined, but the result
+ // will be 0, so just return that value.
return 0;
}
return a << b;
@@ -307,7 +320,8 @@
if (constants[1] != nullptr) {
// When shifting by a value larger than the size of the result, the
// result is undefined. We are setting the undefined behaviour to a
- // result of 0.
+ // result of 0. If the shift amount is the same as the size of the
+ // result, then the result is defined, and it 0.
uint32_t shift_amount = constants[1]->GetU32BitValue();
if (shift_amount >= 32) {
*result = 0;
diff --git a/test/opt/fold_test.cpp b/test/opt/fold_test.cpp
index 1a54421..88248ba 100644
--- a/test/opt/fold_test.cpp
+++ b/test/opt/fold_test.cpp
@@ -168,6 +168,7 @@
%int_2 = OpConstant %int 2
%int_3 = OpConstant %int 3
%int_4 = OpConstant %int 4
+%int_n24 = OpConstant %int -24
%int_min = OpConstant %int -2147483648
%int_max = OpConstant %int 2147483647
%long_0 = OpConstant %long 0
@@ -486,7 +487,7 @@
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
- // Test case 21: fold signed n >> 42 (undefined, so set to zero).
+ // Test case 23: fold signed n >> 42 (undefined, so set to zero).
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
@@ -496,7 +497,7 @@
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
- // Test case 22: fold n << 42 (undefined, so set to zero).
+ // Test case 24: fold n << 42 (undefined, so set to zero).
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
@@ -505,6 +506,38 @@
"%2 = OpShiftLeftLogical %int %load %uint_42\n" +
"OpReturn\n" +
"OpFunctionEnd",
+ 2, 0),
+ // Test case 25: fold -24 >> 32 (defined as -1)
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpShiftRightArithmetic %int %int_n24 %uint_32\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, -1),
+ // Test case 26: fold 2 >> 32 (signed)
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpShiftRightArithmetic %int %int_2 %uint_32\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 0),
+ // Test case 27: fold 2 >> 32 (unsigned)
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpShiftRightLogical %int %int_2 %uint_32\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, 0),
+ // Test case 28: fold 2 << 32
+ InstructionFoldingCase<uint32_t>(
+ Header() + "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpShiftLeftLogical %int %int_2 %uint_32\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
2, 0)
));
// clang-format on