val: add missing relational ops to InvalidTypePass (#6495)
diff --git a/source/val/validate_invalid_type.cpp b/source/val/validate_invalid_type.cpp index 6afc0f8..05c09e2 100644 --- a/source/val/validate_invalid_type.cpp +++ b/source/val/validate_invalid_type.cpp
@@ -100,6 +100,21 @@ case spv::Op::OpIsInf: case spv::Op::OpIsFinite: case spv::Op::OpIsNormal: + case spv::Op::OpFOrdEqual: + case spv::Op::OpFUnordEqual: + case spv::Op::OpFOrdNotEqual: + case spv::Op::OpFUnordNotEqual: + case spv::Op::OpFOrdLessThan: + case spv::Op::OpFUnordLessThan: + case spv::Op::OpFOrdGreaterThan: + case spv::Op::OpFUnordGreaterThan: + case spv::Op::OpFOrdLessThanEqual: + case spv::Op::OpFUnordLessThanEqual: + case spv::Op::OpFOrdGreaterThanEqual: + case spv::Op::OpFUnordGreaterThanEqual: + case spv::Op::OpLessOrGreater: + case spv::Op::OpOrdered: + case spv::Op::OpUnordered: case spv::Op::OpSignBitSet: { const uint32_t operand_type = _.GetOperandTypeId(inst, 2); if (_.IsBfloat16Type(operand_type)) {
diff --git a/test/val/val_invalid_type_test.cpp b/test/val/val_invalid_type_test.cpp index d19b274..5f8cf73 100644 --- a/test/val/val_invalid_type_test.cpp +++ b/test/val/val_invalid_type_test.cpp
@@ -48,6 +48,7 @@ OpExecutionMode %main LocalSize 1 1 1 OpSource GLSL 450 OpName %main "main" +%bool = OpTypeBool %void = OpTypeVoid %bfloat16 = OpTypeFloat 16 BFloat16KHR %func = OpTypeFunction %void @@ -88,6 +89,22 @@ HasSubstr("FMul doesn't support BFloat16 type.")); } +TEST_F(ValidateInvalidType, Bfloat16InvalidRelationalInstruction) { + const std::string body = R"( +%v1 = OpVariable %_ptr_Function_bfloat16 Function +%v2 = OpVariable %_ptr_Function_bfloat16 Function +%12 = OpLoad %bfloat16 %v1 +%14 = OpLoad %bfloat16 %v2 +%15 = OpFOrdEqual %bool %12 %14 +)"; + + CompileSuccessfully(GenerateBFloatCode(body).c_str(), SPV_ENV_VULKAN_1_3); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_6)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("FOrdEqual doesn't support BFloat16 type.")); +} + TEST_F(ValidateInvalidType, Bfloat16InvalidAtomicInstruction) { const std::string body = R"( %val1 = OpAtomicFAddEXT %bfloat16 %bf16_var %u1 %u0 %bf16_1 @@ -128,6 +145,7 @@ OpExecutionMode %main LocalSize 1 1 1 OpSource GLSL 450 OpName %main "main" +%bool = OpTypeBool %void = OpTypeVoid %fp8e4m3 = OpTypeFloat 8 Float8E4M3EXT %fp8e5m2 = OpTypeFloat 8 Float8E5M2EXT @@ -191,6 +209,38 @@ HasSubstr("FMul doesn't support FP8 E4M3/E5M2 types.")); } +TEST_F(ValidateInvalidType, FP8E4M3InvalidRelationalInstruction) { + const std::string body = R"( +%v1 = OpVariable %_ptr_Function_fp8e4m3 Function +%v2 = OpVariable %_ptr_Function_fp8e4m3 Function +%12 = OpLoad %fp8e4m3 %v1 +%14 = OpLoad %fp8e4m3 %v2 +%15 = OpFOrdEqual %bool %12 %14 +)"; + + CompileSuccessfully(GenerateFP8Code(body).c_str(), SPV_ENV_VULKAN_1_3); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_6)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("FOrdEqual doesn't support FP8 E4M3/E5M2 types.")); +} + +TEST_F(ValidateInvalidType, FP8E5M2InvalidRelationalInstruction) { + const std::string body = R"( +%v1 = OpVariable %_ptr_Function_fp8e5m2 Function +%v2 = OpVariable %_ptr_Function_fp8e5m2 Function +%12 = OpLoad %fp8e5m2 %v1 +%14 = OpLoad %fp8e5m2 %v2 +%15 = OpFOrdEqual %bool %12 %14 +)"; + + CompileSuccessfully(GenerateFP8Code(body).c_str(), SPV_ENV_VULKAN_1_3); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_6)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("FOrdEqual doesn't support FP8 E4M3/E5M2 types.")); +} + TEST_F(ValidateInvalidType, FP8E4M3InvalidAtomicInstruction) { const std::string body = R"( %val1 = OpAtomicFAddEXT %fp8e4m3 %fp8e4m3_var %u1 %u0 %fp8e4m3_1