spirv-val: Add Vulkan 32-bit bit op Base (#4758)
diff --git a/source/val/validate_bitwise.cpp b/source/val/validate_bitwise.cpp index d46b3fc..e6e97c4 100644 --- a/source/val/validate_bitwise.cpp +++ b/source/val/validate_bitwise.cpp
@@ -14,16 +14,48 @@ // Validates correctness of bitwise instructions. -#include "source/val/validate.h" - #include "source/diagnostic.h" #include "source/opcode.h" +#include "source/spirv_target_env.h" #include "source/val/instruction.h" +#include "source/val/validate.h" #include "source/val/validation_state.h" namespace spvtools { namespace val { +// Validates when base and result need to be the same type +spv_result_t ValidateBaseType(ValidationState_t& _, const Instruction* inst, + const uint32_t base_type) { + const SpvOp opcode = inst->opcode(); + + if (!_.IsIntScalarType(base_type) && !_.IsIntVectorType(base_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << _.VkErrorID(4781) + << "Expected int scalar or vector type for Base operand: " + << spvOpcodeString(opcode); + } + + // Vulkan has a restriction to 32 bit for base + if (spvIsVulkanEnv(_.context()->target_env)) { + if (_.GetBitWidth(base_type) != 32) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << _.VkErrorID(4781) + << "Expected 32-bit int type for Base operand: " + << spvOpcodeString(opcode); + } + } + + // OpBitCount just needs same number of components + if (base_type != inst->type_id() && opcode != SpvOpBitCount) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Base Type to be equal to Result Type: " + << spvOpcodeString(opcode); + } + + return SPV_SUCCESS; +} + // Validates correctness of bitwise instructions. spv_result_t BitwisePass(ValidationState_t& _, const Instruction* inst) { const SpvOp opcode = inst->opcode(); @@ -109,20 +141,14 @@ } case SpvOpBitFieldInsert: { - if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type)) - return _.diag(SPV_ERROR_INVALID_DATA, inst) - << "Expected int scalar or vector type as Result Type: " - << spvOpcodeString(opcode); - const uint32_t base_type = _.GetOperandTypeId(inst, 2); const uint32_t insert_type = _.GetOperandTypeId(inst, 3); const uint32_t offset_type = _.GetOperandTypeId(inst, 4); const uint32_t count_type = _.GetOperandTypeId(inst, 5); - if (base_type != result_type) - return _.diag(SPV_ERROR_INVALID_DATA, inst) - << "Expected Base Type to be equal to Result Type: " - << spvOpcodeString(opcode); + if (spv_result_t error = ValidateBaseType(_, inst, base_type)) { + return error; + } if (insert_type != result_type) return _.diag(SPV_ERROR_INVALID_DATA, inst) @@ -143,19 +169,13 @@ case SpvOpBitFieldSExtract: case SpvOpBitFieldUExtract: { - if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type)) - return _.diag(SPV_ERROR_INVALID_DATA, inst) - << "Expected int scalar or vector type as Result Type: " - << spvOpcodeString(opcode); - const uint32_t base_type = _.GetOperandTypeId(inst, 2); const uint32_t offset_type = _.GetOperandTypeId(inst, 3); const uint32_t count_type = _.GetOperandTypeId(inst, 4); - if (base_type != result_type) - return _.diag(SPV_ERROR_INVALID_DATA, inst) - << "Expected Base Type to be equal to Result Type: " - << spvOpcodeString(opcode); + if (spv_result_t error = ValidateBaseType(_, inst, base_type)) { + return error; + } if (!offset_type || !_.IsIntScalarType(offset_type)) return _.diag(SPV_ERROR_INVALID_DATA, inst) @@ -170,17 +190,12 @@ } case SpvOpBitReverse: { - if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type)) - return _.diag(SPV_ERROR_INVALID_DATA, inst) - << "Expected int scalar or vector type as Result Type: " - << spvOpcodeString(opcode); - const uint32_t base_type = _.GetOperandTypeId(inst, 2); - if (base_type != result_type) - return _.diag(SPV_ERROR_INVALID_DATA, inst) - << "Expected Base Type to be equal to Result Type: " - << spvOpcodeString(opcode); + if (spv_result_t error = ValidateBaseType(_, inst, base_type)) { + return error; + } + break; } @@ -191,15 +206,13 @@ << spvOpcodeString(opcode); const uint32_t base_type = _.GetOperandTypeId(inst, 2); - if (!base_type || - (!_.IsIntScalarType(base_type) && !_.IsIntVectorType(base_type))) - return _.diag(SPV_ERROR_INVALID_DATA, inst) - << "Expected Base Type to be int scalar or vector: " - << spvOpcodeString(opcode); - const uint32_t base_dimension = _.GetDimension(base_type); const uint32_t result_dimension = _.GetDimension(result_type); + if (spv_result_t error = ValidateBaseType(_, inst, base_type)) { + return error; + } + if (base_dimension != result_dimension) return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected Base dimension to be equal to Result Type "
diff --git a/source/val/validation_state.cpp b/source/val/validation_state.cpp index 65c1dd6..0be47b9 100644 --- a/source/val/validation_state.cpp +++ b/source/val/validation_state.cpp
@@ -1892,6 +1892,8 @@ return VUID_WRAP(VUID-StandaloneSpirv-OpImage-04777); case 4780: return VUID_WRAP(VUID-StandaloneSpirv-Result-04780); + case 4781: + return VUID_WRAP(VUID-StandaloneSpirv-Base-04781); case 4915: return VUID_WRAP(VUID-StandaloneSpirv-Location-04915); case 4916:
diff --git a/test/val/val_bitwise_test.cpp b/test/val/val_bitwise_test.cpp index 1001def..bebaa84 100644 --- a/test/val/val_bitwise_test.cpp +++ b/test/val/val_bitwise_test.cpp
@@ -340,6 +340,16 @@ ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); } +TEST_F(ValidateBitwise, OpBitFieldInsertVulkanSuccess) { + const std::string body = R"( +%val1 = OpBitFieldInsert %u32 %u32_1 %u32_2 %s32_1 %s32_2 +%val2 = OpBitFieldInsert %s32vec2 %s32vec2_12 %s32vec2_12 %s32_1 %u32_2 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str(), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_0)); +} + TEST_F(ValidateBitwise, OpBitFieldInsertWrongResultType) { const std::string body = R"( %val1 = OpBitFieldInsert %bool %u64_1 %u64_2 %s32_1 %s32_2 @@ -350,7 +360,7 @@ EXPECT_THAT( getDiagnosticString(), HasSubstr( - "Expected int scalar or vector type as Result Type: BitFieldInsert")); + "Expected Base Type to be equal to Result Type: BitFieldInsert")); } TEST_F(ValidateBitwise, OpBitFieldInsertWrongBaseType) { @@ -403,6 +413,20 @@ HasSubstr("Expected Count Type to be int scalar: BitFieldInsert")); } +TEST_F(ValidateBitwise, OpBitFieldInsertNot32Vulkan) { + const std::string body = R"( +%val1 = OpBitFieldInsert %u64 %u64_1 %u64_2 %s32_1 %s32_2 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str(), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT(getDiagnosticString(), + AnyVUID("VUID-StandaloneSpirv-Base-04781")); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected 32-bit int type for Base operand: BitFieldInsert")); +} + TEST_F(ValidateBitwise, OpBitFieldSExtractSuccess) { const std::string body = R"( %val1 = OpBitFieldSExtract %u64 %u64_1 %s32_1 %s32_2 @@ -413,6 +437,16 @@ ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); } +TEST_F(ValidateBitwise, OpBitFieldSExtractVulkanSuccess) { + const std::string body = R"( +%val1 = OpBitFieldSExtract %u32 %u32_1 %s32_1 %s32_2 +%val2 = OpBitFieldSExtract %s32vec2 %s32vec2_12 %s32_1 %u32_2 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str(), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_0)); +} + TEST_F(ValidateBitwise, OpBitFieldSExtractWrongResultType) { const std::string body = R"( %val1 = OpBitFieldSExtract %bool %u64_1 %s32_1 %s32_2 @@ -420,9 +454,10 @@ CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); - EXPECT_THAT(getDiagnosticString(), - HasSubstr("Expected int scalar or vector type as Result Type: " - "BitFieldSExtract")); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected Base Type to be equal to Result Type: BitFieldSExtract")); } TEST_F(ValidateBitwise, OpBitFieldSExtractWrongBaseType) { @@ -462,6 +497,20 @@ HasSubstr("Expected Count Type to be int scalar: BitFieldSExtract")); } +TEST_F(ValidateBitwise, OpBitFieldSExtractNot32Vulkan) { + const std::string body = R"( +%val1 = OpBitFieldSExtract %u64 %u64_1 %s32_1 %s32_2 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str(), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT(getDiagnosticString(), + AnyVUID("VUID-StandaloneSpirv-Base-04781")); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected 32-bit int type for Base operand: BitFieldSExtract")); +} + TEST_F(ValidateBitwise, OpBitReverseSuccess) { const std::string body = R"( %val1 = OpBitReverse %u64 %u64_1 @@ -472,6 +521,16 @@ ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); } +TEST_F(ValidateBitwise, OpBitReverseVulkanSuccess) { + const std::string body = R"( +%val1 = OpBitReverse %u32 %u32_1 +%val2 = OpBitReverse %s32vec2 %s32vec2_12 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str(), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_0)); +} + TEST_F(ValidateBitwise, OpBitReverseWrongResultType) { const std::string body = R"( %val1 = OpBitReverse %bool %u64_1 @@ -481,8 +540,7 @@ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - HasSubstr( - "Expected int scalar or vector type as Result Type: BitReverse")); + HasSubstr("Expected Base Type to be equal to Result Type: BitReverse")); } TEST_F(ValidateBitwise, OpBitReverseWrongBaseType) { @@ -497,16 +555,41 @@ HasSubstr("Expected Base Type to be equal to Result Type: BitReverse")); } +TEST_F(ValidateBitwise, OpBitReverseNot32Vulkan) { + const std::string body = R"( +%val1 = OpBitReverse %u64 %u64_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str(), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT(getDiagnosticString(), + AnyVUID("VUID-StandaloneSpirv-Base-04781")); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected 32-bit int type for Base operand: BitReverse")); +} + TEST_F(ValidateBitwise, OpBitCountSuccess) { const std::string body = R"( %val1 = OpBitCount %s32 %u64_1 %val2 = OpBitCount %u32vec2 %s32vec2_12 +%val3 = OpBitCount %s64 %s64_1 )"; CompileSuccessfully(GenerateShaderCode(body).c_str()); ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); } +TEST_F(ValidateBitwise, OpBitCountVulkanSuccess) { + const std::string body = R"( +%val1 = OpBitCount %s32 %u32_1 +%val2 = OpBitCount %u32vec2 %s32vec2_12 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str(), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_0)); +} + TEST_F(ValidateBitwise, OpBitCountWrongResultType) { const std::string body = R"( %val1 = OpBitCount %bool %u64_1 @@ -524,11 +607,14 @@ %val1 = OpBitCount %u32 %f64_1 )"; - CompileSuccessfully(GenerateShaderCode(body).c_str()); - ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + CompileSuccessfully(GenerateShaderCode(body).c_str(), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT(getDiagnosticString(), + AnyVUID("VUID-StandaloneSpirv-Base-04781")); EXPECT_THAT( getDiagnosticString(), - HasSubstr("Expected Base Type to be int scalar or vector: BitCount")); + HasSubstr( + "Expected int scalar or vector type for Base operand: BitCount")); } TEST_F(ValidateBitwise, OpBitCountBaseWrongDimension) { @@ -544,6 +630,19 @@ "BitCount")); } +TEST_F(ValidateBitwise, OpBitCountNot32Vulkan) { + const std::string body = R"( +%val1 = OpBitCount %s64 %s64_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str(), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT(getDiagnosticString(), + AnyVUID("VUID-StandaloneSpirv-Base-04781")); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected 32-bit int type for Base operand: BitCount")); +} + } // namespace } // namespace val } // namespace spvtools