Validator: OpGroupNonUniformBallotBitCount validation (#1486)
diff --git a/source/val/validate_non_uniform.cpp b/source/val/validate_non_uniform.cpp
index 89e82c6..94c7c2c 100644
--- a/source/val/validate_non_uniform.cpp
+++ b/source/val/validate_non_uniform.cpp
@@ -63,6 +63,27 @@
   return SPV_SUCCESS;
 }
 
+spv_result_t ValidateGroupNonUniformBallotBitCount(ValidationState_t& _,
+                                                   const Instruction* inst) {
+  // Scope is already checked by ValidateExecutionScope() above.
+
+  const uint32_t result_type = inst->type_id();
+  if (!_.IsUnsignedIntScalarType(result_type)) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << "Expected Result Type to be an unsigned integer type scalar.";
+  }
+
+  const auto value = inst->GetOperandAs<uint32_t>(4);
+  const auto value_type = _.FindDef(value)->type_id();
+  if (!_.IsUnsignedIntVectorType(value_type) ||
+      _.GetDimension(value_type) != 4) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected Value to be a "
+                                                   "vector of four components "
+                                                   "of integer type scalar";
+  }
+  return SPV_SUCCESS;
+}
+
 }  // namespace
 
 // Validates correctness of non-uniform group instructions.
@@ -76,6 +97,13 @@
     }
   }
 
+  switch (opcode) {
+    case SpvOpGroupNonUniformBallotBitCount:
+      return ValidateGroupNonUniformBallotBitCount(_, inst);
+    default:
+      break;
+  }
+
   return SPV_SUCCESS;
 }
 
diff --git a/test/val/val_non_uniform_test.cpp b/test/val/val_non_uniform_test.cpp
index 6ff5c12..0621d9b 100644
--- a/test/val/val_non_uniform_test.cpp
+++ b/test/val/val_non_uniform_test.cpp
@@ -55,8 +55,10 @@
 %func = OpTypeFunction %void
 %bool = OpTypeBool
 %u32 = OpTypeInt 32 0
+%int = OpTypeInt 32 1
 %float = OpTypeFloat 32
 %u32vec4 = OpTypeVector %u32 4
+%u32vec3 = OpTypeVector %u32 3
 
 %true = OpConstantTrue %bool
 %false = OpConstantFalse %bool
@@ -66,6 +68,7 @@
 %float_0 = OpConstant %float 0
 
 %u32vec4_null = OpConstantComposite %u32vec4 %u32_0 %u32_0 %u32_0 %u32_0
+%u32vec3_null = OpConstantComposite %u32vec3 %u32_0 %u32_0 %u32_0
 
 %cross_device = OpConstant %u32 0
 %device = OpConstant %u32 1
@@ -94,8 +97,8 @@
 SpvScope scopes[] = {SpvScopeCrossDevice, SpvScopeDevice, SpvScopeWorkgroup,
                      SpvScopeSubgroup, SpvScopeInvocation};
 
-using GroupNonUniformScope = spvtest::ValidateBase<
-    std::tuple<std::string, std::string, SpvScope, std::string>>;
+using GroupNonUniform = spvtest::ValidateBase<
+    std::tuple<std::string, std::string, SpvScope, std::string, std::string>>;
 
 std::string ConvertScope(SpvScope scope) {
   switch (scope) {
@@ -114,11 +117,12 @@
   }
 }
 
-TEST_P(GroupNonUniformScope, Vulkan1p1) {
+TEST_P(GroupNonUniform, Vulkan1p1) {
   std::string opcode = std::get<0>(GetParam());
   std::string type = std::get<1>(GetParam());
   SpvScope execution_scope = std::get<2>(GetParam());
   std::string args = std::get<3>(GetParam());
+  std::string error = std::get<4>(GetParam());
 
   std::ostringstream sstr;
   sstr << "%result = " << opcode << " ";
@@ -128,22 +132,28 @@
 
   CompileSuccessfully(GenerateShaderCode(sstr.str()), SPV_ENV_VULKAN_1_1);
   spv_result_t result = ValidateInstructions(SPV_ENV_VULKAN_1_1);
-  if (execution_scope == SpvScopeSubgroup) {
-    EXPECT_EQ(SPV_SUCCESS, result);
+  if (error == "") {
+    if (execution_scope == SpvScopeSubgroup) {
+      EXPECT_EQ(SPV_SUCCESS, result);
+    } else {
+      EXPECT_EQ(SPV_ERROR_INVALID_DATA, result);
+      EXPECT_THAT(
+          getDiagnosticString(),
+          HasSubstr(
+              "in Vulkan environment Execution scope is limited to Subgroup"));
+    }
   } else {
     EXPECT_EQ(SPV_ERROR_INVALID_DATA, result);
-    EXPECT_THAT(
-        getDiagnosticString(),
-        HasSubstr(
-            "in Vulkan environment Execution scope is limited to Subgroup"));
+    EXPECT_THAT(getDiagnosticString(), HasSubstr(error));
   }
 }
 
-TEST_P(GroupNonUniformScope, Spirv1p3) {
+TEST_P(GroupNonUniform, Spirv1p3) {
   std::string opcode = std::get<0>(GetParam());
   std::string type = std::get<1>(GetParam());
   SpvScope execution_scope = std::get<2>(GetParam());
   std::string args = std::get<3>(GetParam());
+  std::string error = std::get<4>(GetParam());
 
   std::ostringstream sstr;
   sstr << "%result = " << opcode << " ";
@@ -153,99 +163,127 @@
 
   CompileSuccessfully(GenerateShaderCode(sstr.str()), SPV_ENV_UNIVERSAL_1_3);
   spv_result_t result = ValidateInstructions(SPV_ENV_UNIVERSAL_1_3);
-  if (execution_scope == SpvScopeSubgroup ||
-      execution_scope == SpvScopeWorkgroup) {
-    EXPECT_EQ(SPV_SUCCESS, result);
+  if (error == "") {
+    if (execution_scope == SpvScopeSubgroup ||
+        execution_scope == SpvScopeWorkgroup) {
+      EXPECT_EQ(SPV_SUCCESS, result);
+    } else {
+      EXPECT_EQ(SPV_ERROR_INVALID_DATA, result);
+      EXPECT_THAT(
+          getDiagnosticString(),
+          HasSubstr("Execution scope is limited to Subgroup or Workgroup"));
+    }
   } else {
     EXPECT_EQ(SPV_ERROR_INVALID_DATA, result);
-    EXPECT_THAT(
-        getDiagnosticString(),
-        HasSubstr("Execution scope is limited to Subgroup or Workgroup"));
+    EXPECT_THAT(getDiagnosticString(), HasSubstr(error));
   }
 }
 
-INSTANTIATE_TEST_CASE_P(GroupNonUniformElect, GroupNonUniformScope,
+INSTANTIATE_TEST_CASE_P(GroupNonUniformElect, GroupNonUniform,
                         Combine(Values("OpGroupNonUniformElect"),
-                                Values("%bool"), ValuesIn(scopes), Values("")));
+                                Values("%bool"), ValuesIn(scopes), Values(""),
+                                Values("")));
 
-INSTANTIATE_TEST_CASE_P(GroupNonUniformVote, GroupNonUniformScope,
+INSTANTIATE_TEST_CASE_P(GroupNonUniformVote, GroupNonUniform,
                         Combine(Values("OpGroupNonUniformAll",
                                        "OpGroupNonUniformAny",
                                        "OpGroupNonUniformAllEqual"),
                                 Values("%bool"), ValuesIn(scopes),
-                                Values("%true")));
+                                Values("%true"), Values("")));
 
-INSTANTIATE_TEST_CASE_P(GroupNonUniformBroadcast, GroupNonUniformScope,
+INSTANTIATE_TEST_CASE_P(GroupNonUniformBroadcast, GroupNonUniform,
                         Combine(Values("OpGroupNonUniformBroadcast"),
                                 Values("%bool"), ValuesIn(scopes),
-                                Values("%true %u32_0")));
+                                Values("%true %u32_0"), Values("")));
 
-INSTANTIATE_TEST_CASE_P(GroupNonUniformBroadcastFirst, GroupNonUniformScope,
+INSTANTIATE_TEST_CASE_P(GroupNonUniformBroadcastFirst, GroupNonUniform,
                         Combine(Values("OpGroupNonUniformBroadcastFirst"),
                                 Values("%bool"), ValuesIn(scopes),
-                                Values("%true")));
+                                Values("%true"), Values("")));
 
-INSTANTIATE_TEST_CASE_P(GroupNonUniformBallot, GroupNonUniformScope,
+INSTANTIATE_TEST_CASE_P(GroupNonUniformBallot, GroupNonUniform,
                         Combine(Values("OpGroupNonUniformBallot"),
                                 Values("%u32vec4"), ValuesIn(scopes),
-                                Values("%true")));
+                                Values("%true"), Values("")));
 
-INSTANTIATE_TEST_CASE_P(GroupNonUniformInverseBallot, GroupNonUniformScope,
+INSTANTIATE_TEST_CASE_P(GroupNonUniformInverseBallot, GroupNonUniform,
                         Combine(Values("OpGroupNonUniformInverseBallot"),
                                 Values("%bool"), ValuesIn(scopes),
-                                Values("%u32vec4_null")));
+                                Values("%u32vec4_null"), Values("")));
 
-INSTANTIATE_TEST_CASE_P(GroupNonUniformBallotBitExtract, GroupNonUniformScope,
+INSTANTIATE_TEST_CASE_P(GroupNonUniformBallotBitExtract, GroupNonUniform,
                         Combine(Values("OpGroupNonUniformBallotBitExtract"),
                                 Values("%bool"), ValuesIn(scopes),
-                                Values("%u32vec4_null %u32_0")));
+                                Values("%u32vec4_null %u32_0"), Values("")));
 
-INSTANTIATE_TEST_CASE_P(GroupNonUniformBallotBitCount, GroupNonUniformScope,
+INSTANTIATE_TEST_CASE_P(GroupNonUniformBallotBitCount, GroupNonUniform,
                         Combine(Values("OpGroupNonUniformBallotBitCount"),
                                 Values("%u32"), ValuesIn(scopes),
-                                Values("Reduce %u32vec4_null")));
+                                Values("Reduce %u32vec4_null"), Values("")));
 
-INSTANTIATE_TEST_CASE_P(GroupNonUniformBallotFind, GroupNonUniformScope,
+INSTANTIATE_TEST_CASE_P(GroupNonUniformBallotFind, GroupNonUniform,
                         Combine(Values("OpGroupNonUniformBallotFindLSB",
                                        "OpGroupNonUniformBallotFindMSB"),
                                 Values("%u32"), ValuesIn(scopes),
-                                Values("%u32vec4_null")));
+                                Values("%u32vec4_null"), Values("")));
 
-INSTANTIATE_TEST_CASE_P(GroupNonUniformShuffle, GroupNonUniformScope,
+INSTANTIATE_TEST_CASE_P(GroupNonUniformShuffle, GroupNonUniform,
                         Combine(Values("OpGroupNonUniformShuffle",
                                        "OpGroupNonUniformShuffleXor",
                                        "OpGroupNonUniformShuffleUp",
                                        "OpGroupNonUniformShuffleDown"),
                                 Values("%u32"), ValuesIn(scopes),
-                                Values("%u32_0 %u32_0")));
+                                Values("%u32_0 %u32_0"), Values("")));
 
 INSTANTIATE_TEST_CASE_P(
-    GroupNonUniformIntegerArithmetic, GroupNonUniformScope,
+    GroupNonUniformIntegerArithmetic, GroupNonUniform,
     Combine(Values("OpGroupNonUniformIAdd", "OpGroupNonUniformIMul",
                    "OpGroupNonUniformSMin", "OpGroupNonUniformUMin",
                    "OpGroupNonUniformSMax", "OpGroupNonUniformUMax",
                    "OpGroupNonUniformBitwiseAnd", "OpGroupNonUniformBitwiseOr",
                    "OpGroupNonUniformBitwiseXor"),
-            Values("%u32"), ValuesIn(scopes), Values("Reduce %u32_0")));
+            Values("%u32"), ValuesIn(scopes), Values("Reduce %u32_0"),
+            Values("")));
 
 INSTANTIATE_TEST_CASE_P(
-    GroupNonUniformFloatArithmetic, GroupNonUniformScope,
+    GroupNonUniformFloatArithmetic, GroupNonUniform,
     Combine(Values("OpGroupNonUniformFAdd", "OpGroupNonUniformFMul",
                    "OpGroupNonUniformFMin", "OpGroupNonUniformFMax"),
-            Values("%float"), ValuesIn(scopes), Values("Reduce %float_0")));
+            Values("%float"), ValuesIn(scopes), Values("Reduce %float_0"),
+            Values("")));
 
-INSTANTIATE_TEST_CASE_P(GroupNonUniformLogicalArithmetic, GroupNonUniformScope,
+INSTANTIATE_TEST_CASE_P(GroupNonUniformLogicalArithmetic, GroupNonUniform,
                         Combine(Values("OpGroupNonUniformLogicalAnd",
                                        "OpGroupNonUniformLogicalOr",
                                        "OpGroupNonUniformLogicalXor"),
                                 Values("%bool"), ValuesIn(scopes),
-                                Values("Reduce %true")));
+                                Values("Reduce %true"), Values("")));
 
-INSTANTIATE_TEST_CASE_P(GroupNonUniformQuad, GroupNonUniformScope,
+INSTANTIATE_TEST_CASE_P(GroupNonUniformQuad, GroupNonUniform,
                         Combine(Values("OpGroupNonUniformQuadBroadcast",
                                        "OpGroupNonUniformQuadSwap"),
                                 Values("%u32"), ValuesIn(scopes),
-                                Values("%u32_0 %u32_0")));
+                                Values("%u32_0 %u32_0"), Values("")));
+
+INSTANTIATE_TEST_CASE_P(GroupNonUniformBallotBitCountScope, GroupNonUniform,
+                        Combine(Values("OpGroupNonUniformBallotBitCount"),
+                                Values("%u32"), ValuesIn(scopes),
+                                Values("Reduce %u32vec4_null"), Values("")));
+
+INSTANTIATE_TEST_CASE_P(
+    GroupNonUniformBallotBitCountBadResultType, GroupNonUniform,
+    Combine(
+        Values("OpGroupNonUniformBallotBitCount"), Values("%float", "%int"),
+        Values(SpvScopeSubgroup), Values("Reduce %u32vec4_null"),
+        Values("Expected Result Type to be an unsigned integer type scalar.")));
+
+INSTANTIATE_TEST_CASE_P(GroupNonUniformBallotBitCountBadValue, GroupNonUniform,
+                        Combine(Values("OpGroupNonUniformBallotBitCount"),
+                                Values("%u32"), Values(SpvScopeSubgroup),
+                                Values("Reduce %u32vec3_null", "Reduce %u32_0",
+                                       "Reduce %float_0"),
+                                Values("Expected Value to be a vector of four "
+                                       "components of integer type scalar")));
 
 }  // namespace
 }  // namespace val