Validate that selections are structured (#2962)

* Validate that selections are structured

WIP

* new checks that switch and conditional branch are proceeded by a
selection merge where necessary

* Don't consider unreachable blocks

* Add some tests

* Changed how labels are marked as seen

* Moved check to more appropriate place
* Labels are now marked as seen when there are encountered in a
terminator instead of when the block is checked
* more tests

* more tests

* Method comment

* new test for a bad case
diff --git a/source/val/validate_cfg.cpp b/source/val/validate_cfg.cpp
index 8d8839e..4801fc5 100644
--- a/source/val/validate_cfg.cpp
+++ b/source/val/validate_cfg.cpp
@@ -590,9 +590,68 @@
   return SPV_SUCCESS;
 }
 
+// Validates that all CFG divergences (i.e. conditional branch or switch) are
+// structured correctly. Either divergence is preceded by a merge instruction
+// or the divergence introduces at most one unseen label.
+spv_result_t ValidateStructuredSelections(
+    ValidationState_t& _, const std::vector<const BasicBlock*>& postorder) {
+  std::unordered_set<uint32_t> seen;
+  for (auto iter = postorder.rbegin(); iter != postorder.rend(); ++iter) {
+    const auto* block = *iter;
+    const auto* terminator = block->terminator();
+    if (!terminator) continue;
+    const auto index = terminator - &_.ordered_instructions()[0];
+    auto* merge = &_.ordered_instructions()[index - 1];
+    // Marks merges and continues as seen.
+    if (merge->opcode() == SpvOpSelectionMerge) {
+      seen.insert(merge->GetOperandAs<uint32_t>(0));
+    } else if (merge->opcode() == SpvOpLoopMerge) {
+      seen.insert(merge->GetOperandAs<uint32_t>(0));
+      seen.insert(merge->GetOperandAs<uint32_t>(1));
+    } else {
+      // Only track the pointer if it is a merge instruction.
+      merge = nullptr;
+    }
+
+    // Skip unreachable blocks.
+    if (!block->reachable()) continue;
+
+    if (terminator->opcode() == SpvOpBranchConditional) {
+      const auto true_label = terminator->GetOperandAs<uint32_t>(1);
+      const auto false_label = terminator->GetOperandAs<uint32_t>(2);
+      // Mark the upcoming blocks as seen now, but only error out if this block
+      // was missing a merge instruction and both labels hadn't been seen
+      // previously.
+      const bool both_unseen =
+          seen.insert(true_label).second && seen.insert(false_label).second;
+      if (!merge && both_unseen) {
+        return _.diag(SPV_ERROR_INVALID_CFG, terminator)
+               << "Selection must be structured";
+      }
+    } else if (terminator->opcode() == SpvOpSwitch) {
+      uint32_t count = 0;
+      // Mark the targets as seen now, but only error out if this block was
+      // missing a merge instruction and there were multiple unseen labels.
+      for (uint32_t i = 1; i < terminator->operands().size(); i += 2) {
+        const auto target = terminator->GetOperandAs<uint32_t>(i);
+        if (seen.insert(target).second) {
+          count++;
+        }
+      }
+      if (!merge && count > 1) {
+        return _.diag(SPV_ERROR_INVALID_CFG, terminator)
+               << "Selection must be structured";
+      }
+    }
+  }
+
+  return SPV_SUCCESS;
+}
+
 spv_result_t StructuredControlFlowChecks(
     ValidationState_t& _, Function* function,
-    const std::vector<std::pair<uint32_t, uint32_t>>& back_edges) {
+    const std::vector<std::pair<uint32_t, uint32_t>>& back_edges,
+    const std::vector<const BasicBlock*>& postorder) {
   /// Check all backedges target only loop headers and have exactly one
   /// back-edge branching to it
 
@@ -709,6 +768,10 @@
     }
   }
 
+  if (auto error = ValidateStructuredSelections(_, postorder)) {
+    return error;
+  }
+
   return SPV_SUCCESS;
 }
 
@@ -931,7 +994,8 @@
 
     /// Structured control flow checks are only required for shader capabilities
     if (_.HasCapability(SpvCapabilityShader)) {
-      if (auto error = StructuredControlFlowChecks(_, &function, back_edges))
+      if (auto error =
+              StructuredControlFlowChecks(_, &function, back_edges, postorder))
         return error;
     }
   }
diff --git a/test/opt/if_conversion_test.cpp b/test/opt/if_conversion_test.cpp
index 03932a9..aa5adea 100644
--- a/test/opt/if_conversion_test.cpp
+++ b/test/opt/if_conversion_test.cpp
@@ -375,14 +375,12 @@
 OpSelectionMerge %12 None
 OpBranchConditional %true %13 %12
 %13 = OpLabel
-OpBranchConditional %true %14 %15
+OpBranchConditional %true %14 %12
 %14 = OpLabel
 OpBranch %12
-%15 = OpLabel
-OpBranch %12
 %12 = OpLabel
-%16 = OpPhi %uint %uint_0 %11 %uint_0 %14 %uint_1 %15
-OpStore %2 %16
+%15 = OpPhi %uint %uint_0 %11 %uint_0 %13 %uint_1 %14
+OpStore %2 %15
 OpReturn
 OpFunctionEnd
 )";
diff --git a/test/val/val_cfg_test.cpp b/test/val/val_cfg_test.cpp
index 547eb57..f06f36c 100644
--- a/test/val/val_cfg_test.cpp
+++ b/test/val/val_cfg_test.cpp
@@ -3026,6 +3026,7 @@
 %undef = OpUndef %bool
 %func = OpFunction %void None %void_fn
 %entry = OpLabel
+OpSelectionMerge %block None
 OpBranchConditional %undef %block %unreachable
 %block = OpLabel
 OpReturn
@@ -3049,6 +3050,7 @@
 %undef = OpUndef %int
 %func = OpFunction %void None %void_fn
 %entry = OpLabel
+OpSelectionMerge %block1 None
 OpSwitch %undef %block1 0 %unreachable 1 %block2
 %block1 = OpLabel
 OpReturn
@@ -3748,7 +3750,355 @@
   EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
 }
 
-/// TODO(umar): Nested CFG constructs
+TEST_F(ValidateCFG, MissingMergeConditionalBranchBad) {
+  const std::string text = R"(
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+%void = OpTypeVoid
+%void_fn = OpTypeFunction %void
+%bool = OpTypeBool
+%undef = OpUndef %bool
+%func = OpFunction %void None %void_fn
+%entry = OpLabel
+OpBranchConditional %undef %then %else
+%then = OpLabel
+OpReturn
+%else = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  CompileSuccessfully(text);
+  EXPECT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(), HasSubstr("Selection must be structured"));
+}
+
+TEST_F(ValidateCFG, MissingMergeSwitchBad) {
+  const std::string text = R"(
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+%void = OpTypeVoid
+%void_fn = OpTypeFunction %void
+%int = OpTypeInt 32 0
+%undef = OpUndef %int
+%func = OpFunction %void None %void_fn
+%entry = OpLabel
+OpSwitch %undef %then 0 %else
+%then = OpLabel
+OpReturn
+%else = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  CompileSuccessfully(text);
+  EXPECT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(), HasSubstr("Selection must be structured"));
+}
+
+TEST_F(ValidateCFG, MissingMergeSwitchBad2) {
+  const std::string text = R"(
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+%void = OpTypeVoid
+%void_fn = OpTypeFunction %void
+%int = OpTypeInt 32 0
+%undef = OpUndef %int
+%func = OpFunction %void None %void_fn
+%entry = OpLabel
+OpSwitch %undef %then 0 %then 1 %then 2 %else
+%then = OpLabel
+OpReturn
+%else = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  CompileSuccessfully(text);
+  EXPECT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(), HasSubstr("Selection must be structured"));
+}
+
+TEST_F(ValidateCFG, MissingMergeOneBranchToMergeGood) {
+  const std::string text = R"(
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+%void = OpTypeVoid
+%void_fn = OpTypeFunction %void
+%bool = OpTypeBool
+%undef = OpUndef %bool
+%func = OpFunction %void None %void_fn
+%entry = OpLabel
+OpSelectionMerge %b3 None
+OpBranchConditional %undef %b1 %b2
+%b1 = OpLabel
+OpBranchConditional %undef %b2 %b3
+%b2 = OpLabel
+OpBranch %b3
+%b3 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  CompileSuccessfully(text);
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateCFG, MissingMergeSameTargetConditionalBranchGood) {
+  const std::string text = R"(
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+%void = OpTypeVoid
+%void_fn = OpTypeFunction %void
+%bool = OpTypeBool
+%undef = OpUndef %bool
+%func = OpFunction %void None %void_fn
+%entry = OpLabel
+OpBranchConditional %undef %then %then
+%then = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  CompileSuccessfully(text);
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateCFG, MissingMergeOneTargetSwitchGood) {
+  const std::string text = R"(
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+%void = OpTypeVoid
+%void_fn = OpTypeFunction %void
+%int = OpTypeInt 32 0
+%undef = OpUndef %int
+%func = OpFunction %void None %void_fn
+%entry = OpLabel
+OpSwitch %undef %then 0 %then 1 %then
+%then = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  CompileSuccessfully(text);
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateCFG, MissingMergeOneUnseenTargetSwitchGood) {
+  const std::string text = R"(
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+%void = OpTypeVoid
+%void_fn = OpTypeFunction %void
+%int = OpTypeInt 32 0
+%undef_int = OpUndef %int
+%bool = OpTypeBool
+%undef_bool = OpUndef %bool
+%func = OpFunction %void None %void_fn
+%entry = OpLabel
+OpSelectionMerge %merge None
+OpBranchConditional %undef_bool %merge %b1
+%b1 = OpLabel
+OpSwitch %undef_int %b2 0 %b2 1 %merge 2 %b2
+%b2 = OpLabel
+OpBranch %merge
+%merge = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  CompileSuccessfully(text);
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateCFG, MissingMergeLoopBreakGood) {
+  const std::string text = R"(
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+%void = OpTypeVoid
+%void_fn = OpTypeFunction %void
+%bool = OpTypeBool
+%undef = OpUndef %bool
+%func = OpFunction %void None %void_fn
+%entry = OpLabel
+OpBranch %loop
+%loop = OpLabel
+OpLoopMerge %exit %continue None
+OpBranch %body
+%body = OpLabel
+OpBranchConditional %undef %body2 %exit
+%body2 = OpLabel
+OpBranch %continue
+%continue = OpLabel
+OpBranch %loop
+%exit = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  CompileSuccessfully(text);
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateCFG, MissingMergeLoopContinueGood) {
+  const std::string text = R"(
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+%void = OpTypeVoid
+%void_fn = OpTypeFunction %void
+%bool = OpTypeBool
+%undef = OpUndef %bool
+%func = OpFunction %void None %void_fn
+%entry = OpLabel
+OpBranch %loop
+%loop = OpLabel
+OpLoopMerge %exit %continue None
+OpBranch %body
+%body = OpLabel
+OpBranchConditional %undef %body2 %continue
+%body2 = OpLabel
+OpBranch %continue
+%continue = OpLabel
+OpBranch %loop
+%exit = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  CompileSuccessfully(text);
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateCFG, MissingMergeSwitchBreakGood) {
+  const std::string text = R"(
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+%void = OpTypeVoid
+%void_fn = OpTypeFunction %void
+%bool = OpTypeBool
+%undef = OpUndef %bool
+%int = OpTypeInt 32 0
+%int_0 = OpConstant %int 0
+%func = OpFunction %void None %void_fn
+%entry = OpLabel
+OpSelectionMerge %merge None
+OpSwitch %int_0 %merge 1 %b1
+%b1 = OpLabel
+OpBranchConditional %undef %merge %b2
+%b2 = OpLabel
+OpBranch %merge
+%merge = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  CompileSuccessfully(text);
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateCFG, MissingMergeSwitchFallThroughGood) {
+  const std::string text = R"(
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+%void = OpTypeVoid
+%void_fn = OpTypeFunction %void
+%bool = OpTypeBool
+%undef = OpUndef %bool
+%int = OpTypeInt 32 0
+%int_0 = OpConstant %int 0
+%func = OpFunction %void None %void_fn
+%entry = OpLabel
+OpSelectionMerge %merge None
+OpSwitch %int_0 %b1 1 %b2
+%b1 = OpLabel
+OpBranchConditional %undef %b3 %b2
+%b2 = OpLabel
+OpBranch %merge
+%b3 = OpLabel
+OpBranch %merge
+%merge = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  CompileSuccessfully(text);
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateCFG, MissingMergeInALoopBad) {
+  const std::string text = R"(
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+%void = OpTypeVoid
+%void_fn = OpTypeFunction %void
+%bool = OpTypeBool
+%undef = OpUndef %bool
+%func = OpFunction %void None %void_fn
+%entry = OpLabel
+OpBranch %loop
+%loop = OpLabel
+OpLoopMerge %exit %continue None
+OpBranch %body
+%body = OpLabel
+OpBranchConditional %undef %b1 %b2
+%b1 = OpLabel
+OpBranch %exit
+%b2 = OpLabel
+OpBranch %continue
+%continue = OpLabel
+OpBranch %loop
+%exit = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  CompileSuccessfully(text);
+  EXPECT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(), HasSubstr("Selection must be structured"));
+}
+
+TEST_F(ValidateCFG, MissingMergeCrissCrossBad) {
+  const std::string text = R"(
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+%void = OpTypeVoid
+%void_fn = OpTypeFunction %void
+%bool = OpTypeBool
+%undef = OpUndef %bool
+%func = OpFunction %void None %void_fn
+%entry = OpLabel
+OpSelectionMerge %merge None
+OpBranchConditional %undef %b1 %b2
+%b1 = OpLabel
+OpBranchConditional %undef %b3 %b4
+%b2 = OpLabel
+OpBranchConditional %undef %b3 %b4
+%b3 = OpLabel
+OpBranch %merge
+%b4 = OpLabel
+OpBranch %merge
+%merge = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  CompileSuccessfully(text);
+  EXPECT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(), HasSubstr("Selection must be structured"));
+}
 
 }  // namespace
 }  // namespace val