Add SPV_KHR_bfloat16 support (#6057)
* Add SPV_KHR_bfloat16 support
* Update DEPS to include SPIRV-Headers with bfloat16 support
* Fix unit test errors and format
* Add validation to invalid uses of bfloat16
* Add tests
* Roll back to previous commit
* Fix build error
* Add FPEncoding for opt::analysis::Float
* Address the comments
* Fix build error
* format
---------
Co-authored-by: Stu Smith <19190608+stu-s@users.noreply.github.com>
Co-authored-by: David Neto <dneto@google.com>
diff --git a/Android.mk b/Android.mk
index b2b96d9..1c150c4 100644
--- a/Android.mk
+++ b/Android.mk
@@ -76,7 +76,8 @@
source/val/validate_scopes.cpp \
source/val/validate_small_type_uses.cpp \
source/val/validate_tensor_layout.cpp \
- source/val/validate_type.cpp
+ source/val/validate_type.cpp\
+ source/val/validate_invalid_type.cpp
SPVTOOLS_OPT_SRC_FILES := \
source/opt/aggressive_dead_code_elim_pass.cpp \
diff --git a/BUILD.gn b/BUILD.gn
index 5fd00ca..82cf0e5 100644
--- a/BUILD.gn
+++ b/BUILD.gn
@@ -559,6 +559,7 @@
"source/val/validate_small_type_uses.cpp",
"source/val/validate_tensor_layout.cpp",
"source/val/validate_type.cpp",
+ "source/val/validate_invalid_type.cpp",
"source/val/validation_state.cpp",
"source/val/validation_state.h",
]
diff --git a/source/CMakeLists.txt b/source/CMakeLists.txt
index b5608f0..cc5bc53 100644
--- a/source/CMakeLists.txt
+++ b/source/CMakeLists.txt
@@ -336,6 +336,7 @@
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_small_type_uses.cpp
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_tensor_layout.cpp
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_type.cpp
+ ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_invalid_type.cpp
${CMAKE_CURRENT_SOURCE_DIR}/val/decoration.h
${CMAKE_CURRENT_SOURCE_DIR}/val/basic_block.cpp
${CMAKE_CURRENT_SOURCE_DIR}/val/construct.cpp
diff --git a/source/name_mapper.cpp b/source/name_mapper.cpp
index 24c24e2..ae3ef49 100644
--- a/source/name_mapper.cpp
+++ b/source/name_mapper.cpp
@@ -211,7 +211,12 @@
} break;
case spv::Op::OpTypeFloat: {
const auto bit_width = inst.words[2];
- // TODO: Handle optional fpencoding enum once actually used.
+ if (inst.num_words > 3) {
+ if (spv::FPEncoding(inst.words[3]) == spv::FPEncoding::BFloat16KHR) {
+ SaveName(result_id, "bfloat16");
+ break;
+ }
+ }
switch (bit_width) {
case 16:
SaveName(result_id, "half");
diff --git a/source/opt/type_manager.cpp b/source/opt/type_manager.cpp
index e0c4b80..e37b433 100644
--- a/source/opt/type_manager.cpp
+++ b/source/opt/type_manager.cpp
@@ -792,9 +792,13 @@
type = new Integer(inst.GetSingleWordInOperand(0),
inst.GetSingleWordInOperand(1));
break;
- case spv::Op::OpTypeFloat:
- type = new Float(inst.GetSingleWordInOperand(0));
- break;
+ case spv::Op::OpTypeFloat: {
+ const spv::FPEncoding encoding =
+ inst.NumInOperands() > 1
+ ? static_cast<spv::FPEncoding>(inst.GetSingleWordInOperand(1))
+ : spv::FPEncoding::Max;
+ type = new Float(inst.GetSingleWordInOperand(0), encoding);
+ } break;
case spv::Op::OpTypeVector:
type = new Vector(GetType(inst.GetSingleWordInOperand(0)),
inst.GetSingleWordInOperand(1));
diff --git a/source/opt/types.cpp b/source/opt/types.cpp
index 2023719..a35f871 100644
--- a/source/opt/types.cpp
+++ b/source/opt/types.cpp
@@ -309,17 +309,26 @@
bool Float::IsSameImpl(const Type* that, IsSameCache*) const {
const Float* ft = that->AsFloat();
- return ft && width_ == ft->width_ && HasSameDecorations(that);
+ return ft && width_ == ft->width_ && encoding_ == ft->encoding_ &&
+ HasSameDecorations(that);
}
std::string Float::str() const {
std::ostringstream oss;
- oss << "float" << width_;
+ switch (encoding_) {
+ case spv::FPEncoding::BFloat16KHR:
+ assert(width_ == 16);
+ oss << "bfloat16";
+ break;
+ default:
+ oss << "float" << width_;
+ break;
+ }
return oss.str();
}
size_t Float::ComputeExtraStateHash(size_t hash, SeenTypes*) const {
- return hash_combine(hash, width_);
+ return hash_combine(hash, width_, encoding_);
}
Vector::Vector(const Type* type, uint32_t count)
diff --git a/source/opt/types.h b/source/opt/types.h
index 1418331..99b3cd8 100644
--- a/source/opt/types.h
+++ b/source/opt/types.h
@@ -182,9 +182,9 @@
// non-composite type.
uint64_t NumberOfComponents() const;
-// A bunch of methods for casting this type to a given type. Returns this if the
-// cast can be done, nullptr otherwise.
-// clang-format off
+ // A bunch of methods for casting this type to a given type. Returns this if
+ // the cast can be done, nullptr otherwise.
+ // clang-format off
#define DeclareCastMethod(target) \
virtual target* As##target() { return nullptr; } \
virtual const target* As##target() const { return nullptr; }
@@ -267,7 +267,8 @@
class Float : public Type {
public:
- Float(uint32_t w) : Type(kFloat), width_(w) {}
+ Float(uint32_t w, spv::FPEncoding encoding = spv::FPEncoding::Max)
+ : Type(kFloat), width_(w), encoding_(encoding) {}
Float(const Float&) = default;
std::string str() const override;
@@ -275,13 +276,15 @@
Float* AsFloat() override { return this; }
const Float* AsFloat() const override { return this; }
uint32_t width() const { return width_; }
+ spv::FPEncoding encoding() const { return encoding_; }
size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;
private:
bool IsSameImpl(const Type* that, IsSameCache*) const override;
- uint32_t width_; // bit width
+ uint32_t width_; // bit width
+ spv::FPEncoding encoding_; // FPEncoding
};
class Vector : public Type {
diff --git a/source/val/validate.cpp b/source/val/validate.cpp
index 2d10347..4c46d2b 100644
--- a/source/val/validate.cpp
+++ b/source/val/validate.cpp
@@ -367,6 +367,7 @@
if (auto error = RayReorderNVPass(*vstate, &instruction)) return error;
if (auto error = MeshShadingPass(*vstate, &instruction)) return error;
if (auto error = TensorLayoutPass(*vstate, &instruction)) return error;
+ if (auto error = InvalidTypePass(*vstate, &instruction)) return error;
}
// Validate the preconditions involving adjacent instructions. e.g.
diff --git a/source/val/validate.h b/source/val/validate.h
index 5514ff7..5d13a7b 100644
--- a/source/val/validate.h
+++ b/source/val/validate.h
@@ -223,6 +223,9 @@
/// Validates correctness of mesh shading instructions.
spv_result_t MeshShadingPass(ValidationState_t& _, const Instruction* inst);
+/// Validates correctness of certain special type instructions.
+spv_result_t InvalidTypePass(ValidationState_t& _, const Instruction* inst);
+
/// Calculates the reachability of basic blocks.
void ReachabilityPass(ValidationState_t& _);
diff --git a/source/val/validate_arithmetics.cpp b/source/val/validate_arithmetics.cpp
index d252ec9..38281be 100644
--- a/source/val/validate_arithmetics.cpp
+++ b/source/val/validate_arithmetics.cpp
@@ -224,6 +224,14 @@
<< "Expected float scalar type as Result Type: "
<< spvOpcodeString(opcode);
+ if (_.IsBfloat16ScalarType(result_type)) {
+ if (!_.HasCapability(spv::Capability::BFloat16DotProductKHR)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "OpDot Result Type <id> " << _.getIdName(result_type)
+ << "requires BFloat16DotProductKHR be declared.";
+ }
+ }
+
uint32_t first_vector_num_components = 0;
for (size_t operand_index = 2; operand_index < inst->operands().size();
diff --git a/source/val/validate_capability.cpp b/source/val/validate_capability.cpp
index 81d2ad5..303142b 100644
--- a/source/val/validate_capability.cpp
+++ b/source/val/validate_capability.cpp
@@ -100,6 +100,7 @@
case spv::Capability::GeometryStreams:
case spv::Capability::Float16:
case spv::Capability::Int8:
+ case spv::Capability::BFloat16TypeKHR:
return true;
default:
break;
diff --git a/source/val/validate_conversion.cpp b/source/val/validate_conversion.cpp
index c459ec3..8bf87ad 100644
--- a/source/val/validate_conversion.cpp
+++ b/source/val/validate_conversion.cpp
@@ -264,11 +264,18 @@
<< spvOpcodeString(opcode);
}
- if (_.GetBitWidth(result_type) == _.GetBitWidth(input_type))
- return _.diag(SPV_ERROR_INVALID_DATA, inst)
- << "Expected input to have different bit width from Result "
- "Type: "
- << spvOpcodeString(opcode);
+ // Scalar type
+ const uint32_t resScalarType = _.GetComponentType(result_type);
+ const uint32_t inputScalartype = _.GetComponentType(input_type);
+ if (_.GetBitWidth(resScalarType) == _.GetBitWidth(inputScalartype))
+ if ((_.IsBfloat16ScalarType(resScalarType) &&
+ _.IsBfloat16ScalarType(inputScalartype)) ||
+ (!_.IsBfloat16ScalarType(inputScalartype) &&
+ !_.IsBfloat16ScalarType(resScalarType)))
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Expected input to have different bit width from Result "
+ "Type: "
+ << spvOpcodeString(opcode);
break;
}
diff --git a/source/val/validate_interfaces.cpp b/source/val/validate_interfaces.cpp
index 68809b5..e01a08e 100644
--- a/source/val/validate_interfaces.cpp
+++ b/source/val/validate_interfaces.cpp
@@ -642,6 +642,28 @@
has_callable_data = true;
break;
}
+ case spv::StorageClass::Input:
+ case spv::StorageClass::Output: {
+ auto result_type = _.FindDef(interface_var->type_id());
+ if (_.ContainsType(result_type->GetOperandAs<uint32_t>(2),
+ [](const Instruction* inst) {
+ if (inst &&
+ inst->opcode() == spv::Op::OpTypeFloat) {
+ if (inst->words().size() > 3) {
+ if (inst->GetOperandAs<spv::FPEncoding>(2) ==
+ spv::FPEncoding::BFloat16KHR) {
+ return true;
+ }
+ }
+ }
+ return false;
+ })) {
+ return _.diag(SPV_ERROR_INVALID_ID, interface_var)
+ << _.VkErrorID(10370) << "Bfloat16 OpVariable <id> "
+ << _.getIdName(interface_var->id()) << " must not be declared "
+ << "with a Storage Class of Input or Output.";
+ }
+ }
default:
break;
}
diff --git a/source/val/validate_invalid_type.cpp b/source/val/validate_invalid_type.cpp
new file mode 100644
index 0000000..a9dcd29
--- /dev/null
+++ b/source/val/validate_invalid_type.cpp
@@ -0,0 +1,139 @@
+// Copyright (c) 2025 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Performs validation of invalid type instructions.
+
+#include <vector>
+
+#include "source/opcode.h"
+#include "source/val/instruction.h"
+#include "source/val/validate.h"
+#include "source/val/validation_state.h"
+
+namespace spvtools {
+namespace val {
+
+// Validates correctness of certain special type instructions.
+spv_result_t InvalidTypePass(ValidationState_t& _, const Instruction* inst) {
+ const spv::Op opcode = inst->opcode();
+
+ switch (opcode) {
+ // OpExtInst
+ case spv::Op::OpExtInst:
+ // Arithmetic Instructions
+ case spv::Op::OpFAdd:
+ case spv::Op::OpFSub:
+ case spv::Op::OpFMul:
+ case spv::Op::OpFDiv:
+ case spv::Op::OpFRem:
+ case spv::Op::OpFMod:
+ case spv::Op::OpFNegate:
+ // Derivative Instructions
+ case spv::Op::OpDPdx:
+ case spv::Op::OpDPdy:
+ case spv::Op::OpFwidth:
+ case spv::Op::OpDPdxFine:
+ case spv::Op::OpDPdyFine:
+ case spv::Op::OpFwidthFine:
+ case spv::Op::OpDPdxCoarse:
+ case spv::Op::OpDPdyCoarse:
+ case spv::Op::OpFwidthCoarse:
+ // Atomic Instructions
+ case spv::Op::OpAtomicFAddEXT:
+ case spv::Op::OpAtomicFMinEXT:
+ case spv::Op::OpAtomicFMaxEXT:
+ case spv::Op::OpAtomicLoad:
+ case spv::Op::OpAtomicExchange:
+ // Group and Subgroup Instructions
+ case spv::Op::OpGroupNonUniformRotateKHR:
+ case spv::Op::OpGroupNonUniformBroadcast:
+ case spv::Op::OpGroupNonUniformShuffle:
+ case spv::Op::OpGroupNonUniformShuffleXor:
+ case spv::Op::OpGroupNonUniformShuffleUp:
+ case spv::Op::OpGroupNonUniformShuffleDown:
+ case spv::Op::OpGroupNonUniformQuadBroadcast:
+ case spv::Op::OpGroupNonUniformQuadSwap:
+ case spv::Op::OpGroupNonUniformBroadcastFirst:
+ case spv::Op::OpGroupNonUniformFAdd:
+ case spv::Op::OpGroupNonUniformFMul:
+ case spv::Op::OpGroupNonUniformFMin: {
+ const uint32_t result_type = inst->type_id();
+ if (_.IsBfloat16ScalarType(result_type) ||
+ _.IsBfloat16VectorType(result_type)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << spvOpcodeString(opcode) << " doesn't support BFloat16 type.";
+ }
+ break;
+ }
+
+ case spv::Op::OpAtomicStore: {
+ uint32_t data_type =
+ _.FindDef(inst->GetOperandAs<uint32_t>(3))->type_id();
+ if (_.IsBfloat16VectorType(data_type)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << spvOpcodeString(opcode) << " doesn't support BFloat16 type.";
+ }
+ break;
+ }
+ // Relational and Logical Instructions
+ case spv::Op::OpIsNan:
+ case spv::Op::OpIsInf:
+ case spv::Op::OpIsFinite:
+ case spv::Op::OpIsNormal:
+ case spv::Op::OpSignBitSet: {
+ const uint32_t operand_type = _.GetOperandTypeId(inst, 2);
+ if (_.IsBfloat16ScalarType(operand_type) ||
+ _.IsBfloat16VectorType(operand_type)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << spvOpcodeString(opcode) << " doesn't support BFloat16 type.";
+ }
+ break;
+ }
+
+ case spv::Op::OpGroupNonUniformAllEqual: {
+ const auto value_type = _.GetOperandTypeId(inst, 3);
+ if (_.IsBfloat16ScalarType(value_type) ||
+ _.IsBfloat16VectorType(value_type)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << spvOpcodeString(opcode) << " doesn't support BFloat16 type.";
+ }
+ break;
+ }
+
+ case spv::Op::OpMatrixTimesMatrix: {
+ const uint32_t result_type = inst->type_id();
+ uint32_t res_num_rows = 0;
+ uint32_t res_num_cols = 0;
+ uint32_t res_col_type = 0;
+ uint32_t res_component_type = 0;
+ if (_.GetMatrixTypeInfo(result_type, &res_num_rows, &res_num_cols,
+ &res_col_type, &res_component_type)) {
+ if (_.IsBfloat16ScalarType(res_component_type)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << spvOpcodeString(opcode)
+ << " doesn't support BFloat16 type.";
+ }
+ }
+ break;
+ }
+
+ default:
+ break;
+ }
+
+ return SPV_SUCCESS;
+}
+
+} // namespace val
+} // namespace spvtools
diff --git a/source/val/validate_type.cpp b/source/val/validate_type.cpp
index b5fa04f..4377271 100644
--- a/source/val/validate_type.cpp
+++ b/source/val/validate_type.cpp
@@ -115,6 +115,15 @@
if (num_bits == 32) {
return SPV_SUCCESS;
}
+ auto operands = inst->words();
+ if (operands.size() > 3) {
+ if (operands[3] != 0) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Current FPEncoding only supports BFloat16KHR.";
+ }
+ return SPV_SUCCESS;
+ }
+
if (num_bits == 16) {
if (_.features().declare_float16_type) {
return SPV_SUCCESS;
@@ -643,6 +652,15 @@
<< " is not a scalar numerical type.";
}
+ if (_.IsBfloat16ScalarType(component_type_id)) {
+ if (!_.HasCapability(spv::Capability::BFloat16CooperativeMatrixKHR)) {
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
+ << "OpTypeCooperativeMatrix Component Type <id> "
+ << _.getIdName(component_type_id)
+ << "require BFloat16CooperativeMatrixKHR be declared.";
+ }
+ }
+
const auto scope_index = 2;
const auto scope_id = inst->GetOperandAs<uint32_t>(scope_index);
const auto scope = _.FindDef(scope_id);
diff --git a/source/val/validation_state.cpp b/source/val/validation_state.cpp
index db7e020..96d80f0 100644
--- a/source/val/validation_state.cpp
+++ b/source/val/validation_state.cpp
@@ -946,6 +946,32 @@
return inst && inst->opcode() == spv::Op::OpTypeVoid;
}
+bool ValidationState_t::IsBfloat16ScalarType(uint32_t id) const {
+ const Instruction* inst = FindDef(id);
+ if (inst && inst->opcode() == spv::Op::OpTypeFloat) {
+ if (inst->words().size() > 3) {
+ if (inst->GetOperandAs<spv::FPEncoding>(2) ==
+ spv::FPEncoding::BFloat16KHR) {
+ return true;
+ }
+ }
+ }
+ return false;
+}
+
+bool ValidationState_t::IsBfloat16VectorType(uint32_t id) const {
+ const Instruction* inst = FindDef(id);
+ if (!inst) {
+ return false;
+ }
+
+ if (inst->opcode() == spv::Op::OpTypeVector) {
+ return IsBfloat16ScalarType(GetComponentType(id));
+ }
+
+ return false;
+}
+
bool ValidationState_t::IsFloatScalarType(uint32_t id) const {
const Instruction* inst = FindDef(id);
return inst && inst->opcode() == spv::Op::OpTypeFloat;
@@ -1768,6 +1794,10 @@
const auto f = [type, width](const Instruction* inst) {
if (inst->opcode() == type) {
+ // Bfloat16 is a special type.
+ if (type == spv::Op::OpTypeFloat && inst->words().size() > 3)
+ return false;
+
return inst->GetOperandAs<uint32_t>(1u) == width;
}
return false;
@@ -2536,6 +2566,8 @@
case 10213:
// This use to be a standalone, but maintenance8 will set allow_offset_texture_operand now
return VUID_WRAP(VUID-RuntimeSpirv-Offset-10213);
+ case 10370:
+ return VUID_WRAP(VUID-StandaloneSpirv-OpTypeFloat-10370);
case 10583:
return VUID_WRAP(VUID-StandaloneSpirv-Component-10583);
case 10684:
diff --git a/source/val/validation_state.h b/source/val/validation_state.h
index e97d3d3..8cc87a4 100644
--- a/source/val/validation_state.h
+++ b/source/val/validation_state.h
@@ -633,6 +633,8 @@
// Returns true iff |id| is a type corresponding to the name of the function.
// Only works for types not for objects.
bool IsVoidType(uint32_t id) const;
+ bool IsBfloat16ScalarType(uint32_t id) const;
+ bool IsBfloat16VectorType(uint32_t id) const;
bool IsFloatScalarType(uint32_t id) const;
bool IsFloatArrayType(uint32_t id) const;
bool IsFloatVectorType(uint32_t id) const;
diff --git a/test/opt/types_test.cpp b/test/opt/types_test.cpp
index 4ceeb14..01c8e90 100644
--- a/test/opt/types_test.cpp
+++ b/test/opt/types_test.cpp
@@ -364,6 +364,20 @@
}
}
+TEST(Types, FloatFPEncoding) {
+ std::vector<spv::FPEncoding> encodings = {
+ spv::FPEncoding::BFloat16KHR,
+ spv::FPEncoding::Max,
+ };
+ std::vector<std::unique_ptr<Float>> types;
+ for (spv::FPEncoding encoding : encodings) {
+ types.emplace_back(new Float(16, encoding));
+ }
+ for (size_t i = 0; i < encodings.size(); i++) {
+ EXPECT_EQ(encodings[i], types[i]->encoding());
+ }
+}
+
TEST(Types, VectorElementCount) {
auto s32 = MakeUnique<Integer>(32, true);
for (uint32_t c : {2, 3, 4}) {
diff --git a/test/val/CMakeLists.txt b/test/val/CMakeLists.txt
index 9d6f6ea..c356753 100644
--- a/test/val/CMakeLists.txt
+++ b/test/val/CMakeLists.txt
@@ -38,6 +38,7 @@
val_derivatives_test.cpp
val_entry_point_test.cpp
val_explicit_reserved_test.cpp
+ val_invalid_type_test.cpp
val_extensions_test.cpp
val_extension_spv_khr_expect_assume_test.cpp
val_extension_spv_khr_linkonce_odr_test.cpp
diff --git a/test/val/val_arithmetics_test.cpp b/test/val/val_arithmetics_test.cpp
index 2a15b28..85653c5 100644
--- a/test/val/val_arithmetics_test.cpp
+++ b/test/val/val_arithmetics_test.cpp
@@ -1166,6 +1166,50 @@
"vector size of the right operand: OuterProduct"));
}
+std::string GenerateBFloatCode(const std::string& main_body) {
+ const std::string prefix =
+ R"(
+OpCapability Shader
+OpCapability BFloat16TypeKHR
+OpCapability BFloat16DotProductKHR
+OpCapability BFloat16CooperativeMatrixKHR
+OpExtension "SPV_KHR_bfloat16"
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+OpExecutionMode %main LocalSize 1 1 1
+OpSource GLSL 450
+OpName %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%bfloat16 = OpTypeFloat 16 BFloat16KHR
+%_ptr_Function_bfloat16 = OpTypePointer Function %bfloat16
+%v2bfloat16 = OpTypeVector %bfloat16 2
+%_ptr_Function_v2bfloat16 = OpTypePointer Function %v2bfloat16
+%main = OpFunction %void None %func
+%main_entry = OpLabel)";
+
+ const std::string suffix =
+ R"(
+OpReturn
+OpFunctionEnd)";
+
+ return prefix + main_body + suffix;
+}
+
+TEST_F(ValidateArithmetics, DotBfloat16) {
+ const std::string body = R"(
+%v1 = OpVariable %_ptr_Function_v2bfloat16 Function
+%v2 = OpVariable %_ptr_Function_v2bfloat16 Function
+%12 = OpLoad %v2bfloat16 %v1
+%14 = OpLoad %v2bfloat16 %v2
+%15 = OpDot %bfloat16 %12 %14
+)";
+
+ CompileSuccessfully(GenerateBFloatCode(body).c_str());
+ ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
std::string GenerateCoopMatCode(const std::string& extra_types,
const std::string& main_body) {
const std::string prefix =
diff --git a/test/val/val_conversion_test.cpp b/test/val/val_conversion_test.cpp
index 52d40d3..f0f5c46 100644
--- a/test/val/val_conversion_test.cpp
+++ b/test/val/val_conversion_test.cpp
@@ -39,6 +39,7 @@
const std::string capabilities =
R"(
OpCapability Shader
+OpCapability Float16
OpCapability Int64
OpCapability Float64)";
@@ -54,6 +55,7 @@
%func = OpTypeFunction %void
%bool = OpTypeBool
%f32 = OpTypeFloat 32
+%f16 = OpTypeFloat 16
%u32 = OpTypeInt 32 0
%s32 = OpTypeInt 32 1
%f64 = OpTypeFloat 64
@@ -84,6 +86,8 @@
%f32_3 = OpConstant %f32 3
%f32_4 = OpConstant %f32 4
+%f16_1 = OpConstant %f16 1
+
%s32_0 = OpConstant %s32 0
%s32_1 = OpConstant %s32 1
%s32_2 = OpConstant %s32 2
@@ -593,6 +597,24 @@
"Result Type: FConvert"));
}
+TEST_F(ValidateConversion, FConvertFloat16ToBFloat16) {
+ const std::string extensions = R"(
+OpCapability BFloat16TypeKHR
+OpExtension "SPV_KHR_bfloat16"
+)";
+
+ const std::string types = R"(
+%bf16 = OpTypeFloat 16 BFloat16KHR
+)";
+
+ const std::string body = R"(
+%val = OpFConvert %bf16 %f16_1
+)";
+
+ CompileSuccessfully(GenerateShaderCode(body, extensions, "", types).c_str());
+ ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
TEST_F(ValidateConversion, QuantizeToF16Success) {
const std::string body = R"(
%val1 = OpQuantizeToF16 %f32 %f32_1
diff --git a/test/val/val_data_test.cpp b/test/val/val_data_test.cpp
index 349e5e9..a6b90fd 100644
--- a/test/val/val_data_test.cpp
+++ b/test/val/val_data_test.cpp
@@ -71,6 +71,15 @@
OpCapability Int64
OpMemoryModel Logical GLSL450
)";
+std::string header_with_bfloat16 = R"(
+ OpCapability Shader
+ OpCapability Linkage
+ OpCapability BFloat16TypeKHR
+ OpCapability BFloat16DotProductKHR
+ OpCapability BFloat16CooperativeMatrixKHR
+ OpExtension "SPV_KHR_bfloat16"
+ OpMemoryModel Logical GLSL450
+)";
std::string header_with_float16 = R"(
OpCapability Shader
OpCapability Linkage
@@ -334,6 +343,25 @@
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
+TEST_F(ValidateData, bfloat16_good) {
+ std::string str = header_with_bfloat16 + "%2 = OpTypeFloat 16 BFloat16KHR";
+ CompileSuccessfully(str.c_str());
+ ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateData, cooperative_matrix_bfloat16_good) {
+ std::string str = header_with_bfloat16 + R"(
+%u32 = OpTypeInt 32 0
+%u32_16 = OpConstant %u32 16
+%useA = OpConstant %u32 0
+%subgroup = OpConstant %u32 3
+%bf16 = OpTypeFloat 16 BFloat16KHR
+%bf16matA = OpTypeCooperativeMatrixKHR %bf16 %subgroup %u32_16 %u32_16 %useA
+)";
+ CompileSuccessfully(str.c_str());
+ ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
TEST_F(ValidateData, float16_buffer_good) {
std::string str = header_with_float16_buffer + "%2 = OpTypeFloat 16";
CompileSuccessfully(str.c_str());
@@ -347,8 +375,50 @@
EXPECT_THAT(getDiagnosticString(), HasSubstr(missing_float16_cap_error));
}
+TEST_F(ValidateData, bfloat16_bad) {
+ std::string str = header + "%2 = OpTypeFloat 16 BFloat16KHR";
+ CompileSuccessfully(str.c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_CAPABILITY, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("requires one of these capabilities: BFloat16TypeKHR"));
+}
+
+TEST_F(ValidateData, dot_bfloat16_bad) {
+ std::string str = R"(
+ OpCapability Shader
+ OpCapability BFloat16TypeKHR
+ OpExtension "SPV_KHR_bfloat16"
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+ OpSource GLSL 450
+ OpName %main "main"
+ %void = OpTypeVoid
+ %3 = OpTypeFunction %void
+ %bfloat16 = OpTypeFloat 16 BFloat16KHR
+%_ptr_Function_bfloat16 = OpTypePointer Function %bfloat16
+ %v2bfloat16 = OpTypeVector %bfloat16 2
+%_ptr_Function_v2bfloat16 = OpTypePointer Function %v2bfloat16
+ %main = OpFunction %void None %3
+ %5 = OpLabel
+ %v1 = OpVariable %_ptr_Function_v2bfloat16 Function
+ %v2 = OpVariable %_ptr_Function_v2bfloat16 Function
+ %12 = OpLoad %v2bfloat16 %v1
+ %14 = OpLoad %v2bfloat16 %v2
+ %15 = OpDot %bfloat16 %12 %14
+ OpReturn
+ OpFunctionEnd
+)";
+ CompileSuccessfully(str.c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("requires BFloat16DotProductKHR be declared."));
+}
+
TEST_F(ValidateData, float64_good) {
std::string str = header_with_float64 + "%2 = OpTypeFloat 64";
+
CompileSuccessfully(str.c_str());
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
@@ -627,7 +697,8 @@
OpExtension "SPV_KHR_variable_pointers"
OpExtension "SPV_KHR_16bit_storage"
OpMemoryModel Logical GLSL450
- OpDecorate %_ FPRoundingMode )" + mode + R"(
+ OpDecorate %_ FPRoundingMode )" +
+ mode + R"(
%half = OpTypeFloat 16
%float = OpTypeFloat 32
%float_1_25 = OpConstant %float 1.25
diff --git a/test/val/val_interfaces_test.cpp b/test/val/val_interfaces_test.cpp
index 2aa8033..6dff110 100644
--- a/test/val/val_interfaces_test.cpp
+++ b/test/val/val_interfaces_test.cpp
@@ -2022,6 +2022,36 @@
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_3));
}
+TEST_F(ValidateInterfacesTest, ValidInstructionType) {
+ const std::string text = R"(
+OpCapability Shader
+OpCapability BFloat16TypeKHR
+OpExtension "SPV_KHR_bfloat16"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %in %out
+OpExecutionMode %main OriginUpperLeft
+OpDecorate %in Location 0
+OpDecorate %out Location 0
+%void = OpTypeVoid
+%bfloat16 = OpTypeFloat 16 BFloat16KHR
+%in_ptr = OpTypePointer Input %bfloat16
+%out_ptr = OpTypePointer Output %bfloat16
+%in = OpVariable %in_ptr Input
+%out = OpVariable %out_ptr Output
+%void_fn = OpTypeFunction %void
+%main = OpFunction %void None %void_fn
+%entry = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+ CompileSuccessfully(text, SPV_ENV_VULKAN_1_3);
+ ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_3));
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("Bfloat16 OpVariable <id> '2[%2]' must not be declared "
+ "with a Storage Class of Input or Output.\n"));
+}
+
} // namespace
} // namespace val
} // namespace spvtools
diff --git a/test/val/val_invalid_type_test.cpp b/test/val/val_invalid_type_test.cpp
new file mode 100644
index 0000000..dc911e1
--- /dev/null
+++ b/test/val/val_invalid_type_test.cpp
@@ -0,0 +1,117 @@
+// Copyright (c) 2025 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Tests for invalid types.
+
+#include <string>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "source/spirv_target_env.h"
+#include "test/unit_spirv.h"
+#include "test/val/val_fixtures.h"
+
+namespace spvtools {
+namespace val {
+namespace {
+
+using ::testing::HasSubstr;
+using ::testing::Not;
+using ::testing::Values;
+
+using ValidateInvalidType = spvtest::ValidateBase<bool>;
+
+std::string GenerateBFloatCode(const std::string& main_body) {
+ const std::string prefix =
+ R"(
+OpCapability Shader
+OpCapability BFloat16TypeKHR
+OpCapability AtomicFloat16AddEXT
+OpCapability GroupNonUniformShuffle
+OpExtension "SPV_EXT_shader_atomic_float16_add"
+OpExtension "SPV_KHR_bfloat16"
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+OpExecutionMode %main LocalSize 1 1 1
+OpSource GLSL 450
+OpName %main "main"
+%void = OpTypeVoid
+%bfloat16 = OpTypeFloat 16 BFloat16KHR
+%func = OpTypeFunction %void
+%u32 = OpTypeInt 32 0
+%u1 = OpConstant %u32 1
+%u0 = OpConstant %u32 0
+%u3 = OpConstant %u32 3
+%bf16_1 = OpConstant %bfloat16 1
+%_ptr_Function_bfloat16 = OpTypePointer Function %bfloat16
+%v2bfloat16 = OpTypeVector %bfloat16 2
+%_ptr_Function_v2bfloat16 = OpTypePointer Function %v2bfloat16
+%bf16_ptr = OpTypePointer Workgroup %bfloat16
+%bf16_var = OpVariable %bf16_ptr Workgroup
+%main = OpFunction %void None %func
+%main_entry = OpLabel)";
+
+ const std::string suffix =
+ R"(
+OpReturn
+OpFunctionEnd)";
+
+ return prefix + main_body + suffix;
+}
+
+TEST_F(ValidateInvalidType, Bfloat16InvalidArithmeticInstruction) {
+ const std::string body = R"(
+%v1 = OpVariable %_ptr_Function_bfloat16 Function
+%v2 = OpVariable %_ptr_Function_bfloat16 Function
+%12 = OpLoad %bfloat16 %v1
+%14 = OpLoad %bfloat16 %v2
+%15 = OpFMul %bfloat16 %12 %14
+)";
+
+ CompileSuccessfully(GenerateBFloatCode(body).c_str(), SPV_ENV_VULKAN_1_3);
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA,
+ ValidateInstructions(SPV_ENV_UNIVERSAL_1_6));
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("FMul doesn't support BFloat16 type."));
+}
+
+TEST_F(ValidateInvalidType, Bfloat16InvalidAtomicInstruction) {
+ const std::string body = R"(
+%val1 = OpAtomicFAddEXT %bfloat16 %bf16_var %u1 %u0 %bf16_1
+)";
+
+ CompileSuccessfully(GenerateBFloatCode(body).c_str(), SPV_ENV_VULKAN_1_3);
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA,
+ ValidateInstructions(SPV_ENV_UNIVERSAL_1_6));
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("AtomicFAddEXT doesn't support BFloat16 type."));
+}
+
+TEST_F(ValidateInvalidType, Bfloat16InvalidGroupNonUniformShuffle) {
+ const std::string body = R"(
+%val1 = OpGroupNonUniformShuffle %bfloat16 %u3 %bf16_1 %u0
+)";
+
+ CompileSuccessfully(GenerateBFloatCode(body).c_str(), SPV_ENV_VULKAN_1_3);
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA,
+ ValidateInstructions(SPV_ENV_UNIVERSAL_1_6));
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("GroupNonUniformShuffle doesn't support BFloat16 type."));
+}
+
+} // namespace
+} // namespace val
+} // namespace spvtools