spirv-val: Add Vulkan 32-bit bit op Base (#4758)

diff --git a/source/val/validate_bitwise.cpp b/source/val/validate_bitwise.cpp
index d46b3fc..e6e97c4 100644
--- a/source/val/validate_bitwise.cpp
+++ b/source/val/validate_bitwise.cpp
@@ -14,16 +14,48 @@
 
 // Validates correctness of bitwise instructions.
 
-#include "source/val/validate.h"
-
 #include "source/diagnostic.h"
 #include "source/opcode.h"
+#include "source/spirv_target_env.h"
 #include "source/val/instruction.h"
+#include "source/val/validate.h"
 #include "source/val/validation_state.h"
 
 namespace spvtools {
 namespace val {
 
+// Validates when base and result need to be the same type
+spv_result_t ValidateBaseType(ValidationState_t& _, const Instruction* inst,
+                              const uint32_t base_type) {
+  const SpvOp opcode = inst->opcode();
+
+  if (!_.IsIntScalarType(base_type) && !_.IsIntVectorType(base_type)) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << _.VkErrorID(4781)
+           << "Expected int scalar or vector type for Base operand: "
+           << spvOpcodeString(opcode);
+  }
+
+  // Vulkan has a restriction to 32 bit for base
+  if (spvIsVulkanEnv(_.context()->target_env)) {
+    if (_.GetBitWidth(base_type) != 32) {
+      return _.diag(SPV_ERROR_INVALID_DATA, inst)
+             << _.VkErrorID(4781)
+             << "Expected 32-bit int type for Base operand: "
+             << spvOpcodeString(opcode);
+    }
+  }
+
+  // OpBitCount just needs same number of components
+  if (base_type != inst->type_id() && opcode != SpvOpBitCount) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << "Expected Base Type to be equal to Result Type: "
+           << spvOpcodeString(opcode);
+  }
+
+  return SPV_SUCCESS;
+}
+
 // Validates correctness of bitwise instructions.
 spv_result_t BitwisePass(ValidationState_t& _, const Instruction* inst) {
   const SpvOp opcode = inst->opcode();
@@ -109,20 +141,14 @@
     }
 
     case SpvOpBitFieldInsert: {
-      if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type))
-        return _.diag(SPV_ERROR_INVALID_DATA, inst)
-               << "Expected int scalar or vector type as Result Type: "
-               << spvOpcodeString(opcode);
-
       const uint32_t base_type = _.GetOperandTypeId(inst, 2);
       const uint32_t insert_type = _.GetOperandTypeId(inst, 3);
       const uint32_t offset_type = _.GetOperandTypeId(inst, 4);
       const uint32_t count_type = _.GetOperandTypeId(inst, 5);
 
-      if (base_type != result_type)
-        return _.diag(SPV_ERROR_INVALID_DATA, inst)
-               << "Expected Base Type to be equal to Result Type: "
-               << spvOpcodeString(opcode);
+      if (spv_result_t error = ValidateBaseType(_, inst, base_type)) {
+        return error;
+      }
 
       if (insert_type != result_type)
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
@@ -143,19 +169,13 @@
 
     case SpvOpBitFieldSExtract:
     case SpvOpBitFieldUExtract: {
-      if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type))
-        return _.diag(SPV_ERROR_INVALID_DATA, inst)
-               << "Expected int scalar or vector type as Result Type: "
-               << spvOpcodeString(opcode);
-
       const uint32_t base_type = _.GetOperandTypeId(inst, 2);
       const uint32_t offset_type = _.GetOperandTypeId(inst, 3);
       const uint32_t count_type = _.GetOperandTypeId(inst, 4);
 
-      if (base_type != result_type)
-        return _.diag(SPV_ERROR_INVALID_DATA, inst)
-               << "Expected Base Type to be equal to Result Type: "
-               << spvOpcodeString(opcode);
+      if (spv_result_t error = ValidateBaseType(_, inst, base_type)) {
+        return error;
+      }
 
       if (!offset_type || !_.IsIntScalarType(offset_type))
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
@@ -170,17 +190,12 @@
     }
 
     case SpvOpBitReverse: {
-      if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type))
-        return _.diag(SPV_ERROR_INVALID_DATA, inst)
-               << "Expected int scalar or vector type as Result Type: "
-               << spvOpcodeString(opcode);
-
       const uint32_t base_type = _.GetOperandTypeId(inst, 2);
 
-      if (base_type != result_type)
-        return _.diag(SPV_ERROR_INVALID_DATA, inst)
-               << "Expected Base Type to be equal to Result Type: "
-               << spvOpcodeString(opcode);
+      if (spv_result_t error = ValidateBaseType(_, inst, base_type)) {
+        return error;
+      }
+
       break;
     }
 
@@ -191,15 +206,13 @@
                << spvOpcodeString(opcode);
 
       const uint32_t base_type = _.GetOperandTypeId(inst, 2);
-      if (!base_type ||
-          (!_.IsIntScalarType(base_type) && !_.IsIntVectorType(base_type)))
-        return _.diag(SPV_ERROR_INVALID_DATA, inst)
-               << "Expected Base Type to be int scalar or vector: "
-               << spvOpcodeString(opcode);
-
       const uint32_t base_dimension = _.GetDimension(base_type);
       const uint32_t result_dimension = _.GetDimension(result_type);
 
+      if (spv_result_t error = ValidateBaseType(_, inst, base_type)) {
+        return error;
+      }
+
       if (base_dimension != result_dimension)
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Expected Base dimension to be equal to Result Type "
diff --git a/source/val/validation_state.cpp b/source/val/validation_state.cpp
index 65c1dd6..0be47b9 100644
--- a/source/val/validation_state.cpp
+++ b/source/val/validation_state.cpp
@@ -1892,6 +1892,8 @@
       return VUID_WRAP(VUID-StandaloneSpirv-OpImage-04777);
     case 4780:
       return VUID_WRAP(VUID-StandaloneSpirv-Result-04780);
+    case 4781:
+      return VUID_WRAP(VUID-StandaloneSpirv-Base-04781);
     case 4915:
       return VUID_WRAP(VUID-StandaloneSpirv-Location-04915);
     case 4916:
diff --git a/test/val/val_bitwise_test.cpp b/test/val/val_bitwise_test.cpp
index 1001def..bebaa84 100644
--- a/test/val/val_bitwise_test.cpp
+++ b/test/val/val_bitwise_test.cpp
@@ -340,6 +340,16 @@
   ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
 }
 
+TEST_F(ValidateBitwise, OpBitFieldInsertVulkanSuccess) {
+  const std::string body = R"(
+%val1 = OpBitFieldInsert %u32 %u32_1 %u32_2 %s32_1 %s32_2
+%val2 = OpBitFieldInsert %s32vec2 %s32vec2_12 %s32vec2_12 %s32_1 %u32_2
+)";
+
+  CompileSuccessfully(GenerateShaderCode(body).c_str(), SPV_ENV_VULKAN_1_0);
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_0));
+}
+
 TEST_F(ValidateBitwise, OpBitFieldInsertWrongResultType) {
   const std::string body = R"(
 %val1 = OpBitFieldInsert %bool %u64_1 %u64_2 %s32_1 %s32_2
@@ -350,7 +360,7 @@
   EXPECT_THAT(
       getDiagnosticString(),
       HasSubstr(
-          "Expected int scalar or vector type as Result Type: BitFieldInsert"));
+          "Expected Base Type to be equal to Result Type: BitFieldInsert"));
 }
 
 TEST_F(ValidateBitwise, OpBitFieldInsertWrongBaseType) {
@@ -403,6 +413,20 @@
       HasSubstr("Expected Count Type to be int scalar: BitFieldInsert"));
 }
 
+TEST_F(ValidateBitwise, OpBitFieldInsertNot32Vulkan) {
+  const std::string body = R"(
+%val1 = OpBitFieldInsert %u64 %u64_1 %u64_2 %s32_1 %s32_2
+)";
+
+  CompileSuccessfully(GenerateShaderCode(body).c_str(), SPV_ENV_VULKAN_1_0);
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_0));
+  EXPECT_THAT(getDiagnosticString(),
+              AnyVUID("VUID-StandaloneSpirv-Base-04781"));
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("Expected 32-bit int type for Base operand: BitFieldInsert"));
+}
+
 TEST_F(ValidateBitwise, OpBitFieldSExtractSuccess) {
   const std::string body = R"(
 %val1 = OpBitFieldSExtract %u64 %u64_1 %s32_1 %s32_2
@@ -413,6 +437,16 @@
   ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
 }
 
+TEST_F(ValidateBitwise, OpBitFieldSExtractVulkanSuccess) {
+  const std::string body = R"(
+%val1 = OpBitFieldSExtract %u32 %u32_1 %s32_1 %s32_2
+%val2 = OpBitFieldSExtract %s32vec2 %s32vec2_12 %s32_1 %u32_2
+)";
+
+  CompileSuccessfully(GenerateShaderCode(body).c_str(), SPV_ENV_VULKAN_1_0);
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_0));
+}
+
 TEST_F(ValidateBitwise, OpBitFieldSExtractWrongResultType) {
   const std::string body = R"(
 %val1 = OpBitFieldSExtract %bool %u64_1 %s32_1 %s32_2
@@ -420,9 +454,10 @@
 
   CompileSuccessfully(GenerateShaderCode(body).c_str());
   ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
-  EXPECT_THAT(getDiagnosticString(),
-              HasSubstr("Expected int scalar or vector type as Result Type: "
-                        "BitFieldSExtract"));
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr(
+          "Expected Base Type to be equal to Result Type: BitFieldSExtract"));
 }
 
 TEST_F(ValidateBitwise, OpBitFieldSExtractWrongBaseType) {
@@ -462,6 +497,20 @@
       HasSubstr("Expected Count Type to be int scalar: BitFieldSExtract"));
 }
 
+TEST_F(ValidateBitwise, OpBitFieldSExtractNot32Vulkan) {
+  const std::string body = R"(
+%val1 = OpBitFieldSExtract %u64 %u64_1 %s32_1 %s32_2
+)";
+
+  CompileSuccessfully(GenerateShaderCode(body).c_str(), SPV_ENV_VULKAN_1_0);
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_0));
+  EXPECT_THAT(getDiagnosticString(),
+              AnyVUID("VUID-StandaloneSpirv-Base-04781"));
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("Expected 32-bit int type for Base operand: BitFieldSExtract"));
+}
+
 TEST_F(ValidateBitwise, OpBitReverseSuccess) {
   const std::string body = R"(
 %val1 = OpBitReverse %u64 %u64_1
@@ -472,6 +521,16 @@
   ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
 }
 
+TEST_F(ValidateBitwise, OpBitReverseVulkanSuccess) {
+  const std::string body = R"(
+%val1 = OpBitReverse %u32 %u32_1
+%val2 = OpBitReverse %s32vec2 %s32vec2_12
+)";
+
+  CompileSuccessfully(GenerateShaderCode(body).c_str(), SPV_ENV_VULKAN_1_0);
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_0));
+}
+
 TEST_F(ValidateBitwise, OpBitReverseWrongResultType) {
   const std::string body = R"(
 %val1 = OpBitReverse %bool %u64_1
@@ -481,8 +540,7 @@
   ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
   EXPECT_THAT(
       getDiagnosticString(),
-      HasSubstr(
-          "Expected int scalar or vector type as Result Type: BitReverse"));
+      HasSubstr("Expected Base Type to be equal to Result Type: BitReverse"));
 }
 
 TEST_F(ValidateBitwise, OpBitReverseWrongBaseType) {
@@ -497,16 +555,41 @@
       HasSubstr("Expected Base Type to be equal to Result Type: BitReverse"));
 }
 
+TEST_F(ValidateBitwise, OpBitReverseNot32Vulkan) {
+  const std::string body = R"(
+%val1 = OpBitReverse %u64 %u64_1
+)";
+
+  CompileSuccessfully(GenerateShaderCode(body).c_str(), SPV_ENV_VULKAN_1_0);
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_0));
+  EXPECT_THAT(getDiagnosticString(),
+              AnyVUID("VUID-StandaloneSpirv-Base-04781"));
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("Expected 32-bit int type for Base operand: BitReverse"));
+}
+
 TEST_F(ValidateBitwise, OpBitCountSuccess) {
   const std::string body = R"(
 %val1 = OpBitCount %s32 %u64_1
 %val2 = OpBitCount %u32vec2 %s32vec2_12
+%val3 = OpBitCount %s64 %s64_1
 )";
 
   CompileSuccessfully(GenerateShaderCode(body).c_str());
   ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
 }
 
+TEST_F(ValidateBitwise, OpBitCountVulkanSuccess) {
+  const std::string body = R"(
+%val1 = OpBitCount %s32 %u32_1
+%val2 = OpBitCount %u32vec2 %s32vec2_12
+)";
+
+  CompileSuccessfully(GenerateShaderCode(body).c_str(), SPV_ENV_VULKAN_1_0);
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_0));
+}
+
 TEST_F(ValidateBitwise, OpBitCountWrongResultType) {
   const std::string body = R"(
 %val1 = OpBitCount %bool %u64_1
@@ -524,11 +607,14 @@
 %val1 = OpBitCount %u32 %f64_1
 )";
 
-  CompileSuccessfully(GenerateShaderCode(body).c_str());
-  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  CompileSuccessfully(GenerateShaderCode(body).c_str(), SPV_ENV_VULKAN_1_0);
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_0));
+  EXPECT_THAT(getDiagnosticString(),
+              AnyVUID("VUID-StandaloneSpirv-Base-04781"));
   EXPECT_THAT(
       getDiagnosticString(),
-      HasSubstr("Expected Base Type to be int scalar or vector: BitCount"));
+      HasSubstr(
+          "Expected int scalar or vector type for Base operand: BitCount"));
 }
 
 TEST_F(ValidateBitwise, OpBitCountBaseWrongDimension) {
@@ -544,6 +630,19 @@
                 "BitCount"));
 }
 
+TEST_F(ValidateBitwise, OpBitCountNot32Vulkan) {
+  const std::string body = R"(
+%val1 = OpBitCount %s64 %s64_1
+)";
+
+  CompileSuccessfully(GenerateShaderCode(body).c_str(), SPV_ENV_VULKAN_1_0);
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_0));
+  EXPECT_THAT(getDiagnosticString(),
+              AnyVUID("VUID-StandaloneSpirv-Base-04781"));
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("Expected 32-bit int type for Base operand: BitCount"));
+}
+
 }  // namespace
 }  // namespace val
 }  // namespace spvtools