spirv-val: Add Vulkan 32-bit bit op Base (#4758)
diff --git a/source/val/validate_bitwise.cpp b/source/val/validate_bitwise.cpp
index d46b3fc..e6e97c4 100644
--- a/source/val/validate_bitwise.cpp
+++ b/source/val/validate_bitwise.cpp
@@ -14,16 +14,48 @@
// Validates correctness of bitwise instructions.
-#include "source/val/validate.h"
-
#include "source/diagnostic.h"
#include "source/opcode.h"
+#include "source/spirv_target_env.h"
#include "source/val/instruction.h"
+#include "source/val/validate.h"
#include "source/val/validation_state.h"
namespace spvtools {
namespace val {
+// Validates when base and result need to be the same type
+spv_result_t ValidateBaseType(ValidationState_t& _, const Instruction* inst,
+ const uint32_t base_type) {
+ const SpvOp opcode = inst->opcode();
+
+ if (!_.IsIntScalarType(base_type) && !_.IsIntVectorType(base_type)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << _.VkErrorID(4781)
+ << "Expected int scalar or vector type for Base operand: "
+ << spvOpcodeString(opcode);
+ }
+
+ // Vulkan has a restriction to 32 bit for base
+ if (spvIsVulkanEnv(_.context()->target_env)) {
+ if (_.GetBitWidth(base_type) != 32) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << _.VkErrorID(4781)
+ << "Expected 32-bit int type for Base operand: "
+ << spvOpcodeString(opcode);
+ }
+ }
+
+ // OpBitCount just needs same number of components
+ if (base_type != inst->type_id() && opcode != SpvOpBitCount) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Expected Base Type to be equal to Result Type: "
+ << spvOpcodeString(opcode);
+ }
+
+ return SPV_SUCCESS;
+}
+
// Validates correctness of bitwise instructions.
spv_result_t BitwisePass(ValidationState_t& _, const Instruction* inst) {
const SpvOp opcode = inst->opcode();
@@ -109,20 +141,14 @@
}
case SpvOpBitFieldInsert: {
- if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type))
- return _.diag(SPV_ERROR_INVALID_DATA, inst)
- << "Expected int scalar or vector type as Result Type: "
- << spvOpcodeString(opcode);
-
const uint32_t base_type = _.GetOperandTypeId(inst, 2);
const uint32_t insert_type = _.GetOperandTypeId(inst, 3);
const uint32_t offset_type = _.GetOperandTypeId(inst, 4);
const uint32_t count_type = _.GetOperandTypeId(inst, 5);
- if (base_type != result_type)
- return _.diag(SPV_ERROR_INVALID_DATA, inst)
- << "Expected Base Type to be equal to Result Type: "
- << spvOpcodeString(opcode);
+ if (spv_result_t error = ValidateBaseType(_, inst, base_type)) {
+ return error;
+ }
if (insert_type != result_type)
return _.diag(SPV_ERROR_INVALID_DATA, inst)
@@ -143,19 +169,13 @@
case SpvOpBitFieldSExtract:
case SpvOpBitFieldUExtract: {
- if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type))
- return _.diag(SPV_ERROR_INVALID_DATA, inst)
- << "Expected int scalar or vector type as Result Type: "
- << spvOpcodeString(opcode);
-
const uint32_t base_type = _.GetOperandTypeId(inst, 2);
const uint32_t offset_type = _.GetOperandTypeId(inst, 3);
const uint32_t count_type = _.GetOperandTypeId(inst, 4);
- if (base_type != result_type)
- return _.diag(SPV_ERROR_INVALID_DATA, inst)
- << "Expected Base Type to be equal to Result Type: "
- << spvOpcodeString(opcode);
+ if (spv_result_t error = ValidateBaseType(_, inst, base_type)) {
+ return error;
+ }
if (!offset_type || !_.IsIntScalarType(offset_type))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
@@ -170,17 +190,12 @@
}
case SpvOpBitReverse: {
- if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type))
- return _.diag(SPV_ERROR_INVALID_DATA, inst)
- << "Expected int scalar or vector type as Result Type: "
- << spvOpcodeString(opcode);
-
const uint32_t base_type = _.GetOperandTypeId(inst, 2);
- if (base_type != result_type)
- return _.diag(SPV_ERROR_INVALID_DATA, inst)
- << "Expected Base Type to be equal to Result Type: "
- << spvOpcodeString(opcode);
+ if (spv_result_t error = ValidateBaseType(_, inst, base_type)) {
+ return error;
+ }
+
break;
}
@@ -191,15 +206,13 @@
<< spvOpcodeString(opcode);
const uint32_t base_type = _.GetOperandTypeId(inst, 2);
- if (!base_type ||
- (!_.IsIntScalarType(base_type) && !_.IsIntVectorType(base_type)))
- return _.diag(SPV_ERROR_INVALID_DATA, inst)
- << "Expected Base Type to be int scalar or vector: "
- << spvOpcodeString(opcode);
-
const uint32_t base_dimension = _.GetDimension(base_type);
const uint32_t result_dimension = _.GetDimension(result_type);
+ if (spv_result_t error = ValidateBaseType(_, inst, base_type)) {
+ return error;
+ }
+
if (base_dimension != result_dimension)
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected Base dimension to be equal to Result Type "
diff --git a/source/val/validation_state.cpp b/source/val/validation_state.cpp
index 65c1dd6..0be47b9 100644
--- a/source/val/validation_state.cpp
+++ b/source/val/validation_state.cpp
@@ -1892,6 +1892,8 @@
return VUID_WRAP(VUID-StandaloneSpirv-OpImage-04777);
case 4780:
return VUID_WRAP(VUID-StandaloneSpirv-Result-04780);
+ case 4781:
+ return VUID_WRAP(VUID-StandaloneSpirv-Base-04781);
case 4915:
return VUID_WRAP(VUID-StandaloneSpirv-Location-04915);
case 4916:
diff --git a/test/val/val_bitwise_test.cpp b/test/val/val_bitwise_test.cpp
index 1001def..bebaa84 100644
--- a/test/val/val_bitwise_test.cpp
+++ b/test/val/val_bitwise_test.cpp
@@ -340,6 +340,16 @@
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
+TEST_F(ValidateBitwise, OpBitFieldInsertVulkanSuccess) {
+ const std::string body = R"(
+%val1 = OpBitFieldInsert %u32 %u32_1 %u32_2 %s32_1 %s32_2
+%val2 = OpBitFieldInsert %s32vec2 %s32vec2_12 %s32vec2_12 %s32_1 %u32_2
+)";
+
+ CompileSuccessfully(GenerateShaderCode(body).c_str(), SPV_ENV_VULKAN_1_0);
+ ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_0));
+}
+
TEST_F(ValidateBitwise, OpBitFieldInsertWrongResultType) {
const std::string body = R"(
%val1 = OpBitFieldInsert %bool %u64_1 %u64_2 %s32_1 %s32_2
@@ -350,7 +360,7 @@
EXPECT_THAT(
getDiagnosticString(),
HasSubstr(
- "Expected int scalar or vector type as Result Type: BitFieldInsert"));
+ "Expected Base Type to be equal to Result Type: BitFieldInsert"));
}
TEST_F(ValidateBitwise, OpBitFieldInsertWrongBaseType) {
@@ -403,6 +413,20 @@
HasSubstr("Expected Count Type to be int scalar: BitFieldInsert"));
}
+TEST_F(ValidateBitwise, OpBitFieldInsertNot32Vulkan) {
+ const std::string body = R"(
+%val1 = OpBitFieldInsert %u64 %u64_1 %u64_2 %s32_1 %s32_2
+)";
+
+ CompileSuccessfully(GenerateShaderCode(body).c_str(), SPV_ENV_VULKAN_1_0);
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_0));
+ EXPECT_THAT(getDiagnosticString(),
+ AnyVUID("VUID-StandaloneSpirv-Base-04781"));
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("Expected 32-bit int type for Base operand: BitFieldInsert"));
+}
+
TEST_F(ValidateBitwise, OpBitFieldSExtractSuccess) {
const std::string body = R"(
%val1 = OpBitFieldSExtract %u64 %u64_1 %s32_1 %s32_2
@@ -413,6 +437,16 @@
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
+TEST_F(ValidateBitwise, OpBitFieldSExtractVulkanSuccess) {
+ const std::string body = R"(
+%val1 = OpBitFieldSExtract %u32 %u32_1 %s32_1 %s32_2
+%val2 = OpBitFieldSExtract %s32vec2 %s32vec2_12 %s32_1 %u32_2
+)";
+
+ CompileSuccessfully(GenerateShaderCode(body).c_str(), SPV_ENV_VULKAN_1_0);
+ ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_0));
+}
+
TEST_F(ValidateBitwise, OpBitFieldSExtractWrongResultType) {
const std::string body = R"(
%val1 = OpBitFieldSExtract %bool %u64_1 %s32_1 %s32_2
@@ -420,9 +454,10 @@
CompileSuccessfully(GenerateShaderCode(body).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
- EXPECT_THAT(getDiagnosticString(),
- HasSubstr("Expected int scalar or vector type as Result Type: "
- "BitFieldSExtract"));
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr(
+ "Expected Base Type to be equal to Result Type: BitFieldSExtract"));
}
TEST_F(ValidateBitwise, OpBitFieldSExtractWrongBaseType) {
@@ -462,6 +497,20 @@
HasSubstr("Expected Count Type to be int scalar: BitFieldSExtract"));
}
+TEST_F(ValidateBitwise, OpBitFieldSExtractNot32Vulkan) {
+ const std::string body = R"(
+%val1 = OpBitFieldSExtract %u64 %u64_1 %s32_1 %s32_2
+)";
+
+ CompileSuccessfully(GenerateShaderCode(body).c_str(), SPV_ENV_VULKAN_1_0);
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_0));
+ EXPECT_THAT(getDiagnosticString(),
+ AnyVUID("VUID-StandaloneSpirv-Base-04781"));
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("Expected 32-bit int type for Base operand: BitFieldSExtract"));
+}
+
TEST_F(ValidateBitwise, OpBitReverseSuccess) {
const std::string body = R"(
%val1 = OpBitReverse %u64 %u64_1
@@ -472,6 +521,16 @@
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
+TEST_F(ValidateBitwise, OpBitReverseVulkanSuccess) {
+ const std::string body = R"(
+%val1 = OpBitReverse %u32 %u32_1
+%val2 = OpBitReverse %s32vec2 %s32vec2_12
+)";
+
+ CompileSuccessfully(GenerateShaderCode(body).c_str(), SPV_ENV_VULKAN_1_0);
+ ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_0));
+}
+
TEST_F(ValidateBitwise, OpBitReverseWrongResultType) {
const std::string body = R"(
%val1 = OpBitReverse %bool %u64_1
@@ -481,8 +540,7 @@
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(
getDiagnosticString(),
- HasSubstr(
- "Expected int scalar or vector type as Result Type: BitReverse"));
+ HasSubstr("Expected Base Type to be equal to Result Type: BitReverse"));
}
TEST_F(ValidateBitwise, OpBitReverseWrongBaseType) {
@@ -497,16 +555,41 @@
HasSubstr("Expected Base Type to be equal to Result Type: BitReverse"));
}
+TEST_F(ValidateBitwise, OpBitReverseNot32Vulkan) {
+ const std::string body = R"(
+%val1 = OpBitReverse %u64 %u64_1
+)";
+
+ CompileSuccessfully(GenerateShaderCode(body).c_str(), SPV_ENV_VULKAN_1_0);
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_0));
+ EXPECT_THAT(getDiagnosticString(),
+ AnyVUID("VUID-StandaloneSpirv-Base-04781"));
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("Expected 32-bit int type for Base operand: BitReverse"));
+}
+
TEST_F(ValidateBitwise, OpBitCountSuccess) {
const std::string body = R"(
%val1 = OpBitCount %s32 %u64_1
%val2 = OpBitCount %u32vec2 %s32vec2_12
+%val3 = OpBitCount %s64 %s64_1
)";
CompileSuccessfully(GenerateShaderCode(body).c_str());
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
+TEST_F(ValidateBitwise, OpBitCountVulkanSuccess) {
+ const std::string body = R"(
+%val1 = OpBitCount %s32 %u32_1
+%val2 = OpBitCount %u32vec2 %s32vec2_12
+)";
+
+ CompileSuccessfully(GenerateShaderCode(body).c_str(), SPV_ENV_VULKAN_1_0);
+ ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_0));
+}
+
TEST_F(ValidateBitwise, OpBitCountWrongResultType) {
const std::string body = R"(
%val1 = OpBitCount %bool %u64_1
@@ -524,11 +607,14 @@
%val1 = OpBitCount %u32 %f64_1
)";
- CompileSuccessfully(GenerateShaderCode(body).c_str());
- ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ CompileSuccessfully(GenerateShaderCode(body).c_str(), SPV_ENV_VULKAN_1_0);
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_0));
+ EXPECT_THAT(getDiagnosticString(),
+ AnyVUID("VUID-StandaloneSpirv-Base-04781"));
EXPECT_THAT(
getDiagnosticString(),
- HasSubstr("Expected Base Type to be int scalar or vector: BitCount"));
+ HasSubstr(
+ "Expected int scalar or vector type for Base operand: BitCount"));
}
TEST_F(ValidateBitwise, OpBitCountBaseWrongDimension) {
@@ -544,6 +630,19 @@
"BitCount"));
}
+TEST_F(ValidateBitwise, OpBitCountNot32Vulkan) {
+ const std::string body = R"(
+%val1 = OpBitCount %s64 %s64_1
+)";
+
+ CompileSuccessfully(GenerateShaderCode(body).c_str(), SPV_ENV_VULKAN_1_0);
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_0));
+ EXPECT_THAT(getDiagnosticString(),
+ AnyVUID("VUID-StandaloneSpirv-Base-04781"));
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("Expected 32-bit int type for Base operand: BitCount"));
+}
+
} // namespace
} // namespace val
} // namespace spvtools