Validate construct exits (#2459)

Validate structured exits from constructs

* Add checks that exits from a construct are valid
* Add Construct::IsStructuredExit()
 * uses specific rules for each type of construct
* Added a test and check for #2213
* Adding tests for bad loop and continue exits
* Fix identification of continue block that prevented some selections
from having any blocks
diff --git a/source/val/construct.cpp b/source/val/construct.cpp
index e0053ee..7e106e6 100644
--- a/source/val/construct.cpp
+++ b/source/val/construct.cpp
@@ -19,6 +19,7 @@
 #include <unordered_set>
 
 #include "source/val/function.h"
+#include "source/val/validation_state.h"
 
 namespace spvtools {
 namespace val {
@@ -105,7 +106,8 @@
     // A selection construct nested directly within the loop construct is also
     // at the same depth. It is valid, however, to branch directly to the
     // continue target from within the selection construct.
-    if (block_depth == header_depth && type() == ConstructType::kSelection &&
+    if (block != header && block_depth == header_depth &&
+        type() == ConstructType::kSelection &&
         block->is_type(kBlockTypeContinue)) {
       // Continued to outer construct.
       continue;
@@ -126,5 +128,72 @@
   return construct_blocks;
 }
 
+bool Construct::IsStructuredExit(ValidationState_t& _, BasicBlock* dest) const {
+  // Structured Exits:
+  // - Selection:
+  //  - branch to its merge
+  //  - branch to nearest enclosing loop merge or continue
+  // - Loop:
+  //  - branch to its merge
+  //  - branch to its continue
+  // - Continue:
+  //  - branch to loop header
+  //  - branch to loop merge
+  //
+  // Note: we will never see a case construct here.
+  assert(type() != ConstructType::kCase);
+  if (type() == ConstructType::kLoop) {
+    auto header = entry_block();
+    auto terminator = header->terminator();
+    auto index = terminator - &_.ordered_instructions()[0];
+    auto merge_inst = &_.ordered_instructions()[index - 1];
+    auto merge_block_id = merge_inst->GetOperandAs<uint32_t>(0u);
+    auto continue_block_id = merge_inst->GetOperandAs<uint32_t>(1u);
+    if (dest->id() == merge_block_id || dest->id() == continue_block_id) {
+      return true;
+    }
+  } else if (type() == ConstructType::kContinue) {
+    auto loop_construct = corresponding_constructs()[0];
+    auto header = loop_construct->entry_block();
+    auto terminator = header->terminator();
+    auto index = terminator - &_.ordered_instructions()[0];
+    auto merge_inst = &_.ordered_instructions()[index - 1];
+    auto merge_block_id = merge_inst->GetOperandAs<uint32_t>(0u);
+    if (dest == header || dest->id() == merge_block_id) {
+      return true;
+    }
+  } else {
+    assert(type() == ConstructType::kSelection);
+    if (dest == exit_block()) {
+      return true;
+    }
+
+    auto header = entry_block();
+    auto block = header;
+    while (block) {
+      auto terminator = block->terminator();
+      auto index = terminator - &_.ordered_instructions()[0];
+      auto merge_inst = &_.ordered_instructions()[index - 1];
+      if (merge_inst->opcode() == SpvOpLoopMerge) {
+        auto merge_target = merge_inst->GetOperandAs<uint32_t>(0u);
+        auto merge_block = merge_inst->function()->GetBlock(merge_target).first;
+        if (merge_block->dominates(*header)) {
+          block = block->immediate_dominator();
+          continue;
+        }
+
+        auto continue_target = merge_inst->GetOperandAs<uint32_t>(1u);
+        if (dest->id() == merge_target || dest->id() == continue_target) {
+          return true;
+        }
+      }
+
+      block = block->immediate_dominator();
+    }
+  }
+
+  return false;
+}
+
 }  // namespace val
 }  // namespace spvtools
diff --git a/source/val/construct.h b/source/val/construct.h
index c7e7a78..172976d 100644
--- a/source/val/construct.h
+++ b/source/val/construct.h
@@ -23,6 +23,7 @@
 
 namespace spvtools {
 namespace val {
+class ValidationState_t;
 
 /// Functor for ordering BasicBlocks. BasicBlock pointers must not be null.
 struct less_than_id {
@@ -109,6 +110,22 @@
   // calculated.
   ConstructBlockSet blocks(Function* function) const;
 
+  // Returns true if |dest| is structured exit from the construct. Structured
+  // exits depend on the construct type.
+  // Selection:
+  //  * branch to the associated merge
+  //  * branch to the merge or continue of the innermost loop containing the
+  //  selection
+  // Loop:
+  //  * branch to the associated merge or continue
+  // Continue:
+  //  * back-edge to the associated loop header
+  //  * branch to the associated loop merge
+  //
+  // Note: the validator does not generate case constructs. Switches are
+  // checked separately from other constructs.
+  bool IsStructuredExit(ValidationState_t& _, BasicBlock* dest) const;
+
  private:
   /// The type of the construct
   ConstructType type_;
diff --git a/source/val/validate_cfg.cpp b/source/val/validate_cfg.cpp
index 17f144c..b974d26 100644
--- a/source/val/validate_cfg.cpp
+++ b/source/val/validate_cfg.cpp
@@ -652,16 +652,27 @@
       }
     }
 
-    // Check that for all non-header blocks, all predecessors are within this
-    // construct.
     Construct::ConstructBlockSet construct_blocks = construct.blocks(function);
     for (auto block : construct_blocks) {
+      std::string construct_name, header_name, exit_name;
+      std::tie(construct_name, header_name, exit_name) =
+          ConstructNames(construct.type());
+      // Check that all exits from the construct are via structured exits.
+      for (auto succ : *block->successors()) {
+        if (block->reachable() && !construct_blocks.count(succ) &&
+            !construct.IsStructuredExit(_, succ)) {
+          return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
+                 << "block <ID> " << _.getIdName(block->id()) << " exits the "
+                 << construct_name << " headed by <ID> "
+                 << _.getIdName(header->id())
+                 << ", but not via a structured exit";
+        }
+      }
       if (block == header) continue;
+      // Check that for all non-header blocks, all predecessors are within this
+      // construct.
       for (auto pred : *block->predecessors()) {
         if (pred->reachable() && !construct_blocks.count(pred)) {
-          std::string construct_name, header_name, exit_name;
-          std::tie(construct_name, header_name, exit_name) =
-              ConstructNames(construct.type());
           return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(pred->id()))
                  << "block <ID> " << pred->id() << " branches to the "
                  << construct_name << " construct, but not to the "
@@ -680,6 +691,7 @@
       }
     }
   }
+
   return SPV_SUCCESS;
 }
 
@@ -836,7 +848,8 @@
       auto edges = CFA<BasicBlock>::CalculateDominators(
           postorder, function.AugmentedCFGPredecessorsFunction());
       for (auto edge : edges) {
-        edge.first->SetImmediateDominator(edge.second);
+        if (edge.first != edge.second)
+          edge.first->SetImmediateDominator(edge.second);
       }
 
       /// calculate post dominators
diff --git a/test/val/val_cfg_test.cpp b/test/val/val_cfg_test.cpp
index 00a24ef..0bb8fec 100644
--- a/test/val/val_cfg_test.cpp
+++ b/test/val/val_cfg_test.cpp
@@ -3262,6 +3262,147 @@
           "IterationMultiple loop control operand must be greater than zero"));
 }
 
+TEST_F(ValidateCFG, InvalidSelectionExit) {
+  const std::string text = R"(
+OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %1 "main"
+OpExecutionMode %1 OriginUpperLeft
+%2 = OpTypeVoid
+%3 = OpTypeBool
+%4 = OpConstantTrue %3
+%5 = OpTypeFunction %2
+%1 = OpFunction %2 None %5
+%6 = OpLabel
+OpSelectionMerge %7 None
+OpBranchConditional %4 %7 %8
+%8 = OpLabel
+OpSelectionMerge %9 None
+OpBranchConditional %4 %10 %9
+%10 = OpLabel
+OpBranch %7
+%9 = OpLabel
+OpBranch %7
+%7 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  CompileSuccessfully(text);
+  EXPECT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("block <ID> 10[%10] exits the selection headed by <ID> "
+                        "8[%8], but not via a structured exit"));
+}
+
+TEST_F(ValidateCFG, InvalidLoopExit) {
+  const std::string text = R"(
+OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %1 "main"
+OpExecutionMode %1 OriginUpperLeft
+%2 = OpTypeVoid
+%3 = OpTypeBool
+%4 = OpConstantTrue %3
+%5 = OpTypeFunction %2
+%1 = OpFunction %2 None %5
+%6 = OpLabel
+OpSelectionMerge %7 None
+OpBranchConditional %4 %7 %8
+%8 = OpLabel
+OpLoopMerge %9 %10 None
+OpBranchConditional %4 %9 %11
+%11 = OpLabel
+OpBranchConditional %4 %7 %10
+%10 = OpLabel
+OpBranch %8
+%9 = OpLabel
+OpBranch %7
+%7 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  CompileSuccessfully(text);
+  EXPECT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("block <ID> 11[%11] exits the loop headed by <ID> "
+                        "8[%8], but not via a structured exit"));
+}
+
+TEST_F(ValidateCFG, InvalidContinueExit) {
+  const std::string text = R"(
+OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %1 "main"
+OpExecutionMode %1 OriginUpperLeft
+%2 = OpTypeVoid
+%3 = OpTypeBool
+%4 = OpConstantTrue %3
+%5 = OpTypeFunction %2
+%1 = OpFunction %2 None %5
+%6 = OpLabel
+OpSelectionMerge %7 None
+OpBranchConditional %4 %7 %8
+%8 = OpLabel
+OpLoopMerge %9 %10 None
+OpBranchConditional %4 %9 %10
+%10 = OpLabel
+OpBranch %11
+%11 = OpLabel
+OpBranchConditional %4 %8 %7
+%9 = OpLabel
+OpBranch %7
+%7 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  CompileSuccessfully(text);
+  EXPECT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("block <ID> 11[%11] exits the continue headed by <ID> "
+                        "10[%10], but not via a structured exit"));
+}
+
+TEST_F(ValidateCFG, InvalidSelectionExitBackedge) {
+  const std::string text = R"(
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+%1 = OpTypeVoid
+%2 = OpTypeBool
+%3 = OpUndef %2
+%4 = OpTypeFunction %1
+%5 = OpFunction %1 None %4
+%6 = OpLabel
+OpBranch %7
+%7 = OpLabel
+OpLoopMerge %8 %9 None
+OpBranchConditional %3 %8 %9
+%9 = OpLabel
+OpSelectionMerge %10 None
+OpBranchConditional %3 %11 %12
+%11 = OpLabel
+OpBranch %13
+%12 = OpLabel
+OpBranch %13
+%13 = OpLabel
+OpBranch %7
+%10 = OpLabel
+OpUnreachable
+%8 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  CompileSuccessfully(text);
+  EXPECT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("block <ID> 13[%13] exits the selection headed by <ID> "
+                        "9[%9], but not via a structured exit"));
+}
+
 /// TODO(umar): Nested CFG constructs
 
 }  // namespace